Spaces:
Running
Running
Amir Mahla
commited on
Commit
·
02be0d4
1
Parent(s):
1d83133
FIX race condition
Browse files
cua2-core/src/cua2_core/services/agent_service.py
CHANGED
|
@@ -31,6 +31,10 @@ from starlette.websockets import WebSocketState
|
|
| 31 |
|
| 32 |
logger = logging.getLogger(__name__)
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
class AgentStopException(Exception):
|
| 36 |
"""Exception for agent stop"""
|
|
@@ -119,26 +123,27 @@ class AgentService:
|
|
| 119 |
"""Create a new ID and sandbox"""
|
| 120 |
# Prevent sandbox creation for the first 30 seconds after app start
|
| 121 |
# This prevents spawning sandboxes for all users already connected when app restarts
|
| 122 |
-
elapsed_time = time.time() - self._start_time
|
| 123 |
-
if elapsed_time < 30:
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
|
| 136 |
async with self._lock:
|
| 137 |
uuid = str(uuid4())
|
| 138 |
while uuid in self.active_tasks:
|
| 139 |
uuid = str(uuid4())
|
| 140 |
self.task_websockets[uuid] = websocket
|
| 141 |
-
|
|
|
|
| 142 |
return uuid
|
| 143 |
|
| 144 |
async def process_user_task(
|
|
@@ -150,32 +155,32 @@ class AgentService:
|
|
| 150 |
trace.steps = []
|
| 151 |
trace.traceMetadata = AgentTraceMetadata(traceId=trace_id)
|
| 152 |
|
| 153 |
-
trace_id_to_release = None
|
| 154 |
async with self._lock:
|
| 155 |
if self.task_websockets[trace_id] != websocket:
|
| 156 |
# Release sandbox before raising exception to prevent leak
|
| 157 |
# Do this outside the lock to avoid deadlock
|
| 158 |
-
trace_id_to_release = trace_id
|
| 159 |
# Remove from task_websockets since we're rejecting this
|
| 160 |
if trace_id in self.task_websockets:
|
| 161 |
del self.task_websockets[trace_id]
|
| 162 |
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
try:
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
except Exception as e:
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
raise WebSocketException("WebSocket mismatch")
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
active_task = ActiveTask(
|
| 180 |
message_id=trace_id,
|
| 181 |
instruction=trace.instruction,
|
|
@@ -320,10 +325,18 @@ class AgentService:
|
|
| 320 |
image = Image.open(BytesIO(screenshot_bytes))
|
| 321 |
self.last_screenshot[message_id] = (image, step_filename)
|
| 322 |
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
|
| 328 |
self.active_tasks[message_id].traceMetadata.completed = True
|
| 329 |
|
|
|
|
| 31 |
|
| 32 |
logger = logging.getLogger(__name__)
|
| 33 |
|
| 34 |
+
# Timeout constants to prevent stuck threads
|
| 35 |
+
AGENT_RUN_TIMEOUT = 1000 # 10 minutes - maximum time for agent.run() to complete
|
| 36 |
+
SANDBOX_KILL_TIMEOUT = 30 # 30 seconds - maximum time for sandbox.kill() to complete
|
| 37 |
+
|
| 38 |
|
| 39 |
class AgentStopException(Exception):
|
| 40 |
"""Exception for agent stop"""
|
|
|
|
| 123 |
"""Create a new ID and sandbox"""
|
| 124 |
# Prevent sandbox creation for the first 30 seconds after app start
|
| 125 |
# This prevents spawning sandboxes for all users already connected when app restarts
|
| 126 |
+
# elapsed_time = time.time() - self._start_time
|
| 127 |
+
# if elapsed_time < 30:
|
| 128 |
+
# logger.info(
|
| 129 |
+
# f"Skipping sandbox creation (app started {elapsed_time:.1f}s ago, "
|
| 130 |
+
# f"waiting for 30s grace period)"
|
| 131 |
+
# )
|
| 132 |
+
# # Still create UUID and register websocket, but don't acquire sandbox
|
| 133 |
+
# async with self._lock:
|
| 134 |
+
# uuid = str(uuid4())
|
| 135 |
+
# while uuid in self.active_tasks:
|
| 136 |
+
# uuid = str(uuid4())
|
| 137 |
+
# self.task_websockets[uuid] = websocket
|
| 138 |
+
# return uuid
|
| 139 |
|
| 140 |
async with self._lock:
|
| 141 |
uuid = str(uuid4())
|
| 142 |
while uuid in self.active_tasks:
|
| 143 |
uuid = str(uuid4())
|
| 144 |
self.task_websockets[uuid] = websocket
|
| 145 |
+
logger.info(f"Created UUID {uuid} and registered websocket")
|
| 146 |
+
# await self.sandbox_service.acquire_sandbox(uuid)
|
| 147 |
return uuid
|
| 148 |
|
| 149 |
async def process_user_task(
|
|
|
|
| 155 |
trace.steps = []
|
| 156 |
trace.traceMetadata = AgentTraceMetadata(traceId=trace_id)
|
| 157 |
|
| 158 |
+
# trace_id_to_release = None
|
| 159 |
async with self._lock:
|
| 160 |
if self.task_websockets[trace_id] != websocket:
|
| 161 |
# Release sandbox before raising exception to prevent leak
|
| 162 |
# Do this outside the lock to avoid deadlock
|
| 163 |
+
# trace_id_to_release = trace_id
|
| 164 |
# Remove from task_websockets since we're rejecting this
|
| 165 |
if trace_id in self.task_websockets:
|
| 166 |
del self.task_websockets[trace_id]
|
| 167 |
|
| 168 |
+
# # Release sandbox outside of lock if there was a mismatch
|
| 169 |
+
# if trace_id_to_release:
|
| 170 |
+
# try:
|
| 171 |
+
# await self.sandbox_service.release_sandbox(trace_id_to_release)
|
| 172 |
+
# logger.info(
|
| 173 |
+
# f"Released sandbox for {trace_id_to_release} due to WebSocket mismatch"
|
| 174 |
+
# )
|
| 175 |
+
# except Exception as e:
|
| 176 |
+
# logger.error(
|
| 177 |
+
# f"Error releasing sandbox for {trace_id_to_release}: {e}",
|
| 178 |
+
# exc_info=True,
|
| 179 |
+
# )
|
| 180 |
+
# raise WebSocketException("WebSocket mismatch")
|
| 181 |
+
|
| 182 |
+
# # Continue with normal processing if no mismatch
|
| 183 |
+
# async with self._lock:
|
| 184 |
active_task = ActiveTask(
|
| 185 |
message_id=trace_id,
|
| 186 |
instruction=trace.instruction,
|
|
|
|
| 325 |
image = Image.open(BytesIO(screenshot_bytes))
|
| 326 |
self.last_screenshot[message_id] = (image, step_filename)
|
| 327 |
|
| 328 |
+
try:
|
| 329 |
+
await asyncio.wait_for(
|
| 330 |
+
asyncio.to_thread(agent.run, user_content),
|
| 331 |
+
timeout=AGENT_RUN_TIMEOUT,
|
| 332 |
+
)
|
| 333 |
+
except asyncio.TimeoutError:
|
| 334 |
+
logger.error(
|
| 335 |
+
f"Agent run timed out after {AGENT_RUN_TIMEOUT} seconds for {message_id}"
|
| 336 |
+
)
|
| 337 |
+
raise Exception(
|
| 338 |
+
f"Agent run timed out after {AGENT_RUN_TIMEOUT} seconds"
|
| 339 |
+
)
|
| 340 |
|
| 341 |
self.active_tasks[message_id].traceMetadata.completed = True
|
| 342 |
|
cua2-core/src/cua2_core/services/sandbox_service.py
CHANGED
|
@@ -9,6 +9,10 @@ from pydantic import BaseModel
|
|
| 9 |
|
| 10 |
SANDBOX_TIMEOUT = 500
|
| 11 |
SANDBOX_READY_TIMEOUT = 200 # Seconds before a sandbox expires
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
WIDTH = 1280
|
| 13 |
HEIGHT = 960
|
| 14 |
|
|
@@ -140,7 +144,10 @@ class SandboxService:
|
|
| 140 |
"""Background task to create a sandbox"""
|
| 141 |
desktop = None
|
| 142 |
try:
|
| 143 |
-
desktop = await asyncio.
|
|
|
|
|
|
|
|
|
|
| 144 |
print(
|
| 145 |
f"Sandbox created for session {session_hash}, ID: {desktop.sandbox_id}"
|
| 146 |
)
|
|
@@ -174,6 +181,16 @@ class SandboxService:
|
|
| 174 |
self.sandboxes[session_hash] = SandboxEntry(desktop)
|
| 175 |
print(f"Sandbox {session_hash} is now ready")
|
| 176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
except Exception as e:
|
| 178 |
error_msg = str(e)
|
| 179 |
import traceback
|
|
@@ -208,7 +225,14 @@ class SandboxService:
|
|
| 208 |
async def _kill_sandbox_safe(self, sandbox: Sandbox, session_hash: str):
|
| 209 |
"""Safely kill a sandbox with error handling"""
|
| 210 |
try:
|
| 211 |
-
await asyncio.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
except Exception as e:
|
| 213 |
print(f"Error killing sandbox for session {session_hash}: {str(e)}")
|
| 214 |
|
|
|
|
| 9 |
|
| 10 |
SANDBOX_TIMEOUT = 500
|
| 11 |
SANDBOX_READY_TIMEOUT = 200 # Seconds before a sandbox expires
|
| 12 |
+
SANDBOX_CREATION_THREAD_TIMEOUT = (
|
| 13 |
+
300 # Timeout for sandbox creation thread to prevent hanging
|
| 14 |
+
)
|
| 15 |
+
SANDBOX_KILL_TIMEOUT = 30 # Timeout for sandbox.kill() to prevent hanging
|
| 16 |
WIDTH = 1280
|
| 17 |
HEIGHT = 960
|
| 18 |
|
|
|
|
| 144 |
"""Background task to create a sandbox"""
|
| 145 |
desktop = None
|
| 146 |
try:
|
| 147 |
+
desktop = await asyncio.wait_for(
|
| 148 |
+
asyncio.to_thread(self._create_and_setup_sandbox),
|
| 149 |
+
timeout=SANDBOX_CREATION_THREAD_TIMEOUT,
|
| 150 |
+
)
|
| 151 |
print(
|
| 152 |
f"Sandbox created for session {session_hash}, ID: {desktop.sandbox_id}"
|
| 153 |
)
|
|
|
|
| 181 |
self.sandboxes[session_hash] = SandboxEntry(desktop)
|
| 182 |
print(f"Sandbox {session_hash} is now ready")
|
| 183 |
|
| 184 |
+
except asyncio.TimeoutError:
|
| 185 |
+
error_msg = f"Sandbox creation timed out after {SANDBOX_CREATION_THREAD_TIMEOUT} seconds"
|
| 186 |
+
print(f"Error creating sandbox for session {session_hash}: {error_msg}")
|
| 187 |
+
|
| 188 |
+
async with self.lock:
|
| 189 |
+
self.pending.discard(session_hash)
|
| 190 |
+
# Store error so agent service can retrieve it
|
| 191 |
+
self.creation_errors[session_hash] = error_msg
|
| 192 |
+
if desktop:
|
| 193 |
+
asyncio.create_task(self._kill_sandbox_safe(desktop, session_hash))
|
| 194 |
except Exception as e:
|
| 195 |
error_msg = str(e)
|
| 196 |
import traceback
|
|
|
|
| 225 |
async def _kill_sandbox_safe(self, sandbox: Sandbox, session_hash: str):
|
| 226 |
"""Safely kill a sandbox with error handling"""
|
| 227 |
try:
|
| 228 |
+
await asyncio.wait_for(
|
| 229 |
+
asyncio.to_thread(sandbox.kill),
|
| 230 |
+
timeout=SANDBOX_KILL_TIMEOUT,
|
| 231 |
+
)
|
| 232 |
+
except asyncio.TimeoutError:
|
| 233 |
+
print(
|
| 234 |
+
f"Sandbox kill timed out after {SANDBOX_KILL_TIMEOUT} seconds for session {session_hash}"
|
| 235 |
+
)
|
| 236 |
except Exception as e:
|
| 237 |
print(f"Error killing sandbox for session {session_hash}: {str(e)}")
|
| 238 |
|