Amir Mahla committed on
Commit
02be0d4
·
1 Parent(s): 1d83133

FIX race condition

Browse files
cua2-core/src/cua2_core/services/agent_service.py CHANGED
@@ -31,6 +31,10 @@ from starlette.websockets import WebSocketState
31
 
32
  logger = logging.getLogger(__name__)
33
 
 
 
 
 
34
 
35
  class AgentStopException(Exception):
36
  """Exception for agent stop"""
@@ -119,26 +123,27 @@ class AgentService:
119
  """Create a new ID and sandbox"""
120
  # Prevent sandbox creation for the first 30 seconds after app start
121
  # This prevents spawning sandboxes for all users already connected when app restarts
122
- elapsed_time = time.time() - self._start_time
123
- if elapsed_time < 30:
124
- logger.info(
125
- f"Skipping sandbox creation (app started {elapsed_time:.1f}s ago, "
126
- f"waiting for 30s grace period)"
127
- )
128
- # Still create UUID and register websocket, but don't acquire sandbox
129
- async with self._lock:
130
- uuid = str(uuid4())
131
- while uuid in self.active_tasks:
132
- uuid = str(uuid4())
133
- self.task_websockets[uuid] = websocket
134
- return uuid
135
 
136
  async with self._lock:
137
  uuid = str(uuid4())
138
  while uuid in self.active_tasks:
139
  uuid = str(uuid4())
140
  self.task_websockets[uuid] = websocket
141
- await self.sandbox_service.acquire_sandbox(uuid)
 
142
  return uuid
143
 
144
  async def process_user_task(
@@ -150,32 +155,32 @@ class AgentService:
150
  trace.steps = []
151
  trace.traceMetadata = AgentTraceMetadata(traceId=trace_id)
152
 
153
- trace_id_to_release = None
154
  async with self._lock:
155
  if self.task_websockets[trace_id] != websocket:
156
  # Release sandbox before raising exception to prevent leak
157
  # Do this outside the lock to avoid deadlock
158
- trace_id_to_release = trace_id
159
  # Remove from task_websockets since we're rejecting this
160
  if trace_id in self.task_websockets:
161
  del self.task_websockets[trace_id]
162
 
163
- # Release sandbox outside of lock if there was a mismatch
164
- if trace_id_to_release:
165
- try:
166
- await self.sandbox_service.release_sandbox(trace_id_to_release)
167
- logger.info(
168
- f"Released sandbox for {trace_id_to_release} due to WebSocket mismatch"
169
- )
170
- except Exception as e:
171
- logger.error(
172
- f"Error releasing sandbox for {trace_id_to_release}: {e}",
173
- exc_info=True,
174
- )
175
- raise WebSocketException("WebSocket mismatch")
176
-
177
- # Continue with normal processing if no mismatch
178
- async with self._lock:
179
  active_task = ActiveTask(
180
  message_id=trace_id,
181
  instruction=trace.instruction,
@@ -320,10 +325,18 @@ class AgentService:
320
  image = Image.open(BytesIO(screenshot_bytes))
321
  self.last_screenshot[message_id] = (image, step_filename)
322
 
323
- await asyncio.to_thread(
324
- agent.run,
325
- user_content,
326
- )
 
 
 
 
 
 
 
 
327
 
328
  self.active_tasks[message_id].traceMetadata.completed = True
329
 
 
31
 
32
  logger = logging.getLogger(__name__)
33
 
34
+ # Timeout constants to prevent stuck threads
35
+ AGENT_RUN_TIMEOUT = 1000 # 1000 seconds (~16.7 minutes) - maximum time for agent.run() to complete
36
+ SANDBOX_KILL_TIMEOUT = 30 # 30 seconds - maximum time for sandbox.kill() to complete
37
+
38
 
39
  class AgentStopException(Exception):
40
  """Exception for agent stop"""
 
123
  """Create a new ID and sandbox"""
124
  # Prevent sandbox creation for the first 30 seconds after app start
125
  # This prevents spawning sandboxes for all users already connected when app restarts
126
+ # elapsed_time = time.time() - self._start_time
127
+ # if elapsed_time < 30:
128
+ # logger.info(
129
+ # f"Skipping sandbox creation (app started {elapsed_time:.1f}s ago, "
130
+ # f"waiting for 30s grace period)"
131
+ # )
132
+ # # Still create UUID and register websocket, but don't acquire sandbox
133
+ # async with self._lock:
134
+ # uuid = str(uuid4())
135
+ # while uuid in self.active_tasks:
136
+ # uuid = str(uuid4())
137
+ # self.task_websockets[uuid] = websocket
138
+ # return uuid
139
 
140
  async with self._lock:
141
  uuid = str(uuid4())
142
  while uuid in self.active_tasks:
143
  uuid = str(uuid4())
144
  self.task_websockets[uuid] = websocket
145
+ logger.info(f"Created UUID {uuid} and registered websocket")
146
+ # await self.sandbox_service.acquire_sandbox(uuid)
147
  return uuid
148
 
149
  async def process_user_task(
 
155
  trace.steps = []
156
  trace.traceMetadata = AgentTraceMetadata(traceId=trace_id)
157
 
158
+ # trace_id_to_release = None
159
  async with self._lock:
160
  if self.task_websockets[trace_id] != websocket:
161
  # Release sandbox before raising exception to prevent leak
162
  # Do this outside the lock to avoid deadlock
163
+ # trace_id_to_release = trace_id
164
  # Remove from task_websockets since we're rejecting this
165
  if trace_id in self.task_websockets:
166
  del self.task_websockets[trace_id]
167
 
168
+ # # Release sandbox outside of lock if there was a mismatch
169
+ # if trace_id_to_release:
170
+ # try:
171
+ # await self.sandbox_service.release_sandbox(trace_id_to_release)
172
+ # logger.info(
173
+ # f"Released sandbox for {trace_id_to_release} due to WebSocket mismatch"
174
+ # )
175
+ # except Exception as e:
176
+ # logger.error(
177
+ # f"Error releasing sandbox for {trace_id_to_release}: {e}",
178
+ # exc_info=True,
179
+ # )
180
+ # raise WebSocketException("WebSocket mismatch")
181
+
182
+ # # Continue with normal processing if no mismatch
183
+ # async with self._lock:
184
  active_task = ActiveTask(
185
  message_id=trace_id,
186
  instruction=trace.instruction,
 
325
  image = Image.open(BytesIO(screenshot_bytes))
326
  self.last_screenshot[message_id] = (image, step_filename)
327
 
328
+ try:
329
+ await asyncio.wait_for(
330
+ asyncio.to_thread(agent.run, user_content),
331
+ timeout=AGENT_RUN_TIMEOUT,
332
+ )
333
+ except asyncio.TimeoutError:
334
+ logger.error(
335
+ f"Agent run timed out after {AGENT_RUN_TIMEOUT} seconds for {message_id}"
336
+ )
337
+ raise Exception(
338
+ f"Agent run timed out after {AGENT_RUN_TIMEOUT} seconds"
339
+ )
340
 
341
  self.active_tasks[message_id].traceMetadata.completed = True
342
 
cua2-core/src/cua2_core/services/sandbox_service.py CHANGED
@@ -9,6 +9,10 @@ from pydantic import BaseModel
9
 
10
  SANDBOX_TIMEOUT = 500
11
  SANDBOX_READY_TIMEOUT = 200 # Seconds before a sandbox expires
 
 
 
 
12
  WIDTH = 1280
13
  HEIGHT = 960
14
 
@@ -140,7 +144,10 @@ class SandboxService:
140
  """Background task to create a sandbox"""
141
  desktop = None
142
  try:
143
- desktop = await asyncio.to_thread(self._create_and_setup_sandbox)
 
 
 
144
  print(
145
  f"Sandbox created for session {session_hash}, ID: {desktop.sandbox_id}"
146
  )
@@ -174,6 +181,16 @@ class SandboxService:
174
  self.sandboxes[session_hash] = SandboxEntry(desktop)
175
  print(f"Sandbox {session_hash} is now ready")
176
 
 
 
 
 
 
 
 
 
 
 
177
  except Exception as e:
178
  error_msg = str(e)
179
  import traceback
@@ -208,7 +225,14 @@ class SandboxService:
208
  async def _kill_sandbox_safe(self, sandbox: Sandbox, session_hash: str):
209
  """Safely kill a sandbox with error handling"""
210
  try:
211
- await asyncio.to_thread(sandbox.kill)
 
 
 
 
 
 
 
212
  except Exception as e:
213
  print(f"Error killing sandbox for session {session_hash}: {str(e)}")
214
 
 
9
 
10
  SANDBOX_TIMEOUT = 500
11
  SANDBOX_READY_TIMEOUT = 200 # Seconds before a sandbox expires
12
+ SANDBOX_CREATION_THREAD_TIMEOUT = (
13
+ 300 # Timeout for sandbox creation thread to prevent hanging
14
+ )
15
+ SANDBOX_KILL_TIMEOUT = 30 # Timeout for sandbox.kill() to prevent hanging
16
  WIDTH = 1280
17
  HEIGHT = 960
18
 
 
144
  """Background task to create a sandbox"""
145
  desktop = None
146
  try:
147
+ desktop = await asyncio.wait_for(
148
+ asyncio.to_thread(self._create_and_setup_sandbox),
149
+ timeout=SANDBOX_CREATION_THREAD_TIMEOUT,
150
+ )
151
  print(
152
  f"Sandbox created for session {session_hash}, ID: {desktop.sandbox_id}"
153
  )
 
181
  self.sandboxes[session_hash] = SandboxEntry(desktop)
182
  print(f"Sandbox {session_hash} is now ready")
183
 
184
+ except asyncio.TimeoutError:
185
+ error_msg = f"Sandbox creation timed out after {SANDBOX_CREATION_THREAD_TIMEOUT} seconds"
186
+ print(f"Error creating sandbox for session {session_hash}: {error_msg}")
187
+
188
+ async with self.lock:
189
+ self.pending.discard(session_hash)
190
+ # Store error so agent service can retrieve it
191
+ self.creation_errors[session_hash] = error_msg
192
+ if desktop:
193
+ asyncio.create_task(self._kill_sandbox_safe(desktop, session_hash))
194
  except Exception as e:
195
  error_msg = str(e)
196
  import traceback
 
225
  async def _kill_sandbox_safe(self, sandbox: Sandbox, session_hash: str):
226
  """Safely kill a sandbox with error handling"""
227
  try:
228
+ await asyncio.wait_for(
229
+ asyncio.to_thread(sandbox.kill),
230
+ timeout=SANDBOX_KILL_TIMEOUT,
231
+ )
232
+ except asyncio.TimeoutError:
233
+ print(
234
+ f"Sandbox kill timed out after {SANDBOX_KILL_TIMEOUT} seconds for session {session_hash}"
235
+ )
236
  except Exception as e:
237
  print(f"Error killing sandbox for session {session_hash}: {str(e)}")
238