Spaces:
Running
Running
A-Mahla
committed on
mute logs (#13)
Browse files
* NEW agent
* Handle stop agent feature
* CHG agent max step
* FIX model
* ADD files
* ADD error status (#14)
* CHG error status
* CHG error status
* CHG error status
cua2-core/src/cua2_core/app.py
CHANGED
|
@@ -21,11 +21,13 @@ async def lifespan(app: FastAPI):
|
|
| 21 |
if not os.getenv("HF_TOKEN"):
|
| 22 |
raise ValueError("HF_TOKEN is not set")
|
| 23 |
|
|
|
|
|
|
|
| 24 |
websocket_manager = WebSocketManager()
|
| 25 |
|
| 26 |
sandbox_service = SandboxService()
|
| 27 |
|
| 28 |
-
agent_service = AgentService(websocket_manager, sandbox_service)
|
| 29 |
|
| 30 |
# Store services in app state for access in routes
|
| 31 |
app.state.websocket_manager = websocket_manager
|
|
|
|
| 21 |
if not os.getenv("HF_TOKEN"):
|
| 22 |
raise ValueError("HF_TOKEN is not set")
|
| 23 |
|
| 24 |
+
num_workers = int(os.getenv("NUM_WORKERS", "1"))
|
| 25 |
+
|
| 26 |
websocket_manager = WebSocketManager()
|
| 27 |
|
| 28 |
sandbox_service = SandboxService()
|
| 29 |
|
| 30 |
+
agent_service = AgentService(websocket_manager, sandbox_service, num_workers)
|
| 31 |
|
| 32 |
# Store services in app state for access in routes
|
| 33 |
app.state.websocket_manager = websocket_manager
|
cua2-core/src/cua2_core/models/models.py
CHANGED
|
@@ -126,6 +126,10 @@ class AgentTraceMetadata(BaseModel):
|
|
| 126 |
numberOfSteps: int = 0
|
| 127 |
maxSteps: int = 0
|
| 128 |
completed: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
|
| 131 |
class AgentTrace(BaseModel):
|
|
@@ -157,6 +161,7 @@ class AgentStartEvent(BaseModel):
|
|
| 157 |
|
| 158 |
type: Literal["agent_start"] = "agent_start"
|
| 159 |
agentTrace: AgentTrace
|
|
|
|
| 160 |
|
| 161 |
|
| 162 |
class AgentProgressEvent(BaseModel):
|
|
@@ -172,7 +177,9 @@ class AgentCompleteEvent(BaseModel):
|
|
| 172 |
|
| 173 |
type: Literal["agent_complete"] = "agent_complete"
|
| 174 |
traceMetadata: AgentTraceMetadata
|
| 175 |
-
final_state: Literal[
|
|
|
|
|
|
|
| 176 |
|
| 177 |
|
| 178 |
class AgentErrorEvent(BaseModel):
|
|
@@ -292,6 +299,10 @@ class ActiveTask(BaseModel):
|
|
| 292 |
step_duration: float | None = None,
|
| 293 |
step_numberOfSteps: int | None = None,
|
| 294 |
completed: bool | None = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
):
|
| 296 |
"""Update trace metadata"""
|
| 297 |
with self._file_lock:
|
|
@@ -305,6 +316,8 @@ class ActiveTask(BaseModel):
|
|
| 305 |
self.traceMetadata.numberOfSteps += step_numberOfSteps
|
| 306 |
if completed is not None:
|
| 307 |
self.traceMetadata.completed = completed
|
|
|
|
|
|
|
| 308 |
|
| 309 |
|
| 310 |
#################### API Routes Models ########################
|
|
|
|
| 126 |
numberOfSteps: int = 0
|
| 127 |
maxSteps: int = 0
|
| 128 |
completed: bool = False
|
| 129 |
+
final_state: (
|
| 130 |
+
Literal["success", "stopped", "max_steps_reached", "error", "sandbox_timeout"]
|
| 131 |
+
| None
|
| 132 |
+
) = None
|
| 133 |
|
| 134 |
|
| 135 |
class AgentTrace(BaseModel):
|
|
|
|
| 161 |
|
| 162 |
type: Literal["agent_start"] = "agent_start"
|
| 163 |
agentTrace: AgentTrace
|
| 164 |
+
status: Literal["max_sandboxes_reached", "success"] = "success"
|
| 165 |
|
| 166 |
|
| 167 |
class AgentProgressEvent(BaseModel):
|
|
|
|
| 177 |
|
| 178 |
type: Literal["agent_complete"] = "agent_complete"
|
| 179 |
traceMetadata: AgentTraceMetadata
|
| 180 |
+
final_state: Literal[
|
| 181 |
+
"success", "stopped", "max_steps_reached", "error", "sandbox_timeout"
|
| 182 |
+
]
|
| 183 |
|
| 184 |
|
| 185 |
class AgentErrorEvent(BaseModel):
|
|
|
|
| 299 |
step_duration: float | None = None,
|
| 300 |
step_numberOfSteps: int | None = None,
|
| 301 |
completed: bool | None = None,
|
| 302 |
+
final_state: Literal[
|
| 303 |
+
"success", "stopped", "max_steps_reached", "error", "sandbox_timeout"
|
| 304 |
+
]
|
| 305 |
+
| None = None,
|
| 306 |
):
|
| 307 |
"""Update trace metadata"""
|
| 308 |
with self._file_lock:
|
|
|
|
| 316 |
self.traceMetadata.numberOfSteps += step_numberOfSteps
|
| 317 |
if completed is not None:
|
| 318 |
self.traceMetadata.completed = completed
|
| 319 |
+
if final_state is not None:
|
| 320 |
+
self.traceMetadata.final_state = final_state
|
| 321 |
|
| 322 |
|
| 323 |
#################### API Routes Models ########################
|
cua2-core/src/cua2_core/services/agent_service.py
CHANGED
|
@@ -19,7 +19,7 @@ from cua2_core.services.agent_utils.function_parser import parse_function_call
|
|
| 19 |
from cua2_core.services.agent_utils.get_model import get_model
|
| 20 |
from cua2_core.services.sandbox_service import SandboxService
|
| 21 |
from cua2_core.websocket.websocket_manager import WebSocketException, WebSocketManager
|
| 22 |
-
from e2b_desktop import Sandbox
|
| 23 |
from fastapi import WebSocket
|
| 24 |
from PIL import Image
|
| 25 |
from smolagents import ActionStep, AgentImage, AgentMaxStepsError, TaskStep
|
|
@@ -37,33 +37,50 @@ class AgentService:
|
|
| 37 |
"""Service for handling agent tasks and processing"""
|
| 38 |
|
| 39 |
def __init__(
|
| 40 |
-
self,
|
|
|
|
|
|
|
|
|
|
| 41 |
):
|
| 42 |
self.active_tasks: dict[str, ActiveTask] = {}
|
| 43 |
self.websocket_manager: WebSocketManager = websocket_manager
|
| 44 |
self.task_websockets: dict[str, WebSocket] = {}
|
| 45 |
self.sandbox_service: SandboxService = sandbox_service
|
| 46 |
-
self.last_screenshot: dict[str, AgentImage] = {}
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
async def process_user_task(
|
|
|
|
|
|
|
| 49 |
"""Process a user task and return the trace ID"""
|
| 50 |
|
| 51 |
trace_id = trace.id
|
| 52 |
trace.steps = []
|
| 53 |
trace.traceMetadata = AgentTraceMetadata(traceId=trace_id)
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
asyncio.create_task(self._agent_processing(trace_id))
|
| 69 |
|
|
@@ -87,7 +104,9 @@ class AgentService:
|
|
| 87 |
websocket = self.task_websockets.get(message_id)
|
| 88 |
|
| 89 |
await self.websocket_manager.send_agent_start(
|
| 90 |
-
active_task=self.active_tasks[message_id],
|
|
|
|
|
|
|
| 91 |
)
|
| 92 |
|
| 93 |
model = get_model(self.active_tasks[message_id].model_id)
|
|
@@ -145,6 +164,9 @@ class AgentService:
|
|
| 145 |
except WebSocketException:
|
| 146 |
websocket_exception = True
|
| 147 |
|
|
|
|
|
|
|
|
|
|
| 148 |
except (Exception, KeyboardInterrupt):
|
| 149 |
import traceback
|
| 150 |
|
|
@@ -170,17 +192,23 @@ class AgentService:
|
|
| 170 |
|
| 171 |
novnc_active = False
|
| 172 |
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
| 174 |
if message_id in self.active_tasks:
|
| 175 |
self.active_tasks[message_id].store_model()
|
| 176 |
-
del self.active_tasks[message_id]
|
| 177 |
|
| 178 |
-
# Clean up
|
| 179 |
-
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
|
| 185 |
# Release sandbox back to the pool
|
| 186 |
if sandbox:
|
|
@@ -211,7 +239,7 @@ class AgentService:
|
|
| 211 |
time.sleep(3)
|
| 212 |
|
| 213 |
image = self.last_screenshot[message_id]
|
| 214 |
-
|
| 215 |
|
| 216 |
for previous_memory_step in (
|
| 217 |
agent.memory.steps
|
|
@@ -262,8 +290,6 @@ class AgentService:
|
|
| 262 |
else:
|
| 263 |
image_base64 = None
|
| 264 |
|
| 265 |
-
logger.info(memory_step)
|
| 266 |
-
|
| 267 |
step = AgentStep(
|
| 268 |
traceId=message_id,
|
| 269 |
stepId=str(memory_step.step_number),
|
|
|
|
| 19 |
from cua2_core.services.agent_utils.get_model import get_model
|
| 20 |
from cua2_core.services.sandbox_service import SandboxService
|
| 21 |
from cua2_core.websocket.websocket_manager import WebSocketException, WebSocketManager
|
| 22 |
+
from e2b_desktop import Sandbox, TimeoutException
|
| 23 |
from fastapi import WebSocket
|
| 24 |
from PIL import Image
|
| 25 |
from smolagents import ActionStep, AgentImage, AgentMaxStepsError, TaskStep
|
|
|
|
| 37 |
"""Service for handling agent tasks and processing"""
|
| 38 |
|
| 39 |
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
websocket_manager: WebSocketManager,
|
| 42 |
+
sandbox_service: SandboxService,
|
| 43 |
+
num_workers: int,
|
| 44 |
):
|
| 45 |
self.active_tasks: dict[str, ActiveTask] = {}
|
| 46 |
self.websocket_manager: WebSocketManager = websocket_manager
|
| 47 |
self.task_websockets: dict[str, WebSocket] = {}
|
| 48 |
self.sandbox_service: SandboxService = sandbox_service
|
| 49 |
+
self.last_screenshot: dict[str, AgentImage | None] = {}
|
| 50 |
+
self._lock = asyncio.Lock()
|
| 51 |
+
self.max_sandboxes = int(600 / num_workers)
|
| 52 |
|
| 53 |
+
async def process_user_task(
|
| 54 |
+
self, trace: AgentTrace, websocket: WebSocket
|
| 55 |
+
) -> str | None:
|
| 56 |
"""Process a user task and return the trace ID"""
|
| 57 |
|
| 58 |
trace_id = trace.id
|
| 59 |
trace.steps = []
|
| 60 |
trace.traceMetadata = AgentTraceMetadata(traceId=trace_id)
|
| 61 |
|
| 62 |
+
async with self._lock:
|
| 63 |
+
active_task = ActiveTask(
|
| 64 |
+
message_id=trace_id,
|
| 65 |
+
instruction=trace.instruction,
|
| 66 |
+
model_id=trace.modelId,
|
| 67 |
+
timestamp=trace.timestamp,
|
| 68 |
+
steps=trace.steps,
|
| 69 |
+
traceMetadata=trace.traceMetadata,
|
| 70 |
+
)
|
| 71 |
|
| 72 |
+
if len(self.active_tasks) >= self.max_sandboxes:
|
| 73 |
+
await self.websocket_manager.send_agent_start(
|
| 74 |
+
active_task=active_task,
|
| 75 |
+
status="max_sandboxes_reached",
|
| 76 |
+
websocket=websocket,
|
| 77 |
+
)
|
| 78 |
+
return trace_id
|
| 79 |
+
|
| 80 |
+
# Store the task and websocket for this task
|
| 81 |
+
self.active_tasks[trace_id] = active_task
|
| 82 |
+
self.task_websockets[trace_id] = websocket
|
| 83 |
+
self.last_screenshot[trace_id] = None
|
| 84 |
|
| 85 |
asyncio.create_task(self._agent_processing(trace_id))
|
| 86 |
|
|
|
|
| 104 |
websocket = self.task_websockets.get(message_id)
|
| 105 |
|
| 106 |
await self.websocket_manager.send_agent_start(
|
| 107 |
+
active_task=self.active_tasks[message_id],
|
| 108 |
+
websocket=websocket,
|
| 109 |
+
status="success",
|
| 110 |
)
|
| 111 |
|
| 112 |
model = get_model(self.active_tasks[message_id].model_id)
|
|
|
|
| 164 |
except WebSocketException:
|
| 165 |
websocket_exception = True
|
| 166 |
|
| 167 |
+
except TimeoutException:
|
| 168 |
+
final_state = "sandbox_timeout"
|
| 169 |
+
|
| 170 |
except (Exception, KeyboardInterrupt):
|
| 171 |
import traceback
|
| 172 |
|
|
|
|
| 192 |
|
| 193 |
novnc_active = False
|
| 194 |
|
| 195 |
+
self.active_tasks[message_id].update_trace_metadata(
|
| 196 |
+
final_state=final_state,
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
if message_id in self.active_tasks:
|
| 200 |
self.active_tasks[message_id].store_model()
|
|
|
|
| 201 |
|
| 202 |
+
# Clean up
|
| 203 |
+
async with self._lock:
|
| 204 |
+
if message_id in self.active_tasks:
|
| 205 |
+
del self.active_tasks[message_id]
|
| 206 |
+
|
| 207 |
+
if message_id in self.task_websockets:
|
| 208 |
+
del self.task_websockets[message_id]
|
| 209 |
|
| 210 |
+
if message_id in self.last_screenshot:
|
| 211 |
+
del self.last_screenshot[message_id]
|
| 212 |
|
| 213 |
# Release sandbox back to the pool
|
| 214 |
if sandbox:
|
|
|
|
| 239 |
time.sleep(3)
|
| 240 |
|
| 241 |
image = self.last_screenshot[message_id]
|
| 242 |
+
assert image is not None
|
| 243 |
|
| 244 |
for previous_memory_step in (
|
| 245 |
agent.memory.steps
|
|
|
|
| 290 |
else:
|
| 291 |
image_base64 = None
|
| 292 |
|
|
|
|
|
|
|
| 293 |
step = AgentStep(
|
| 294 |
traceId=message_id,
|
| 295 |
stepId=str(memory_step.step_number),
|
cua2-core/src/cua2_core/websocket/websocket_manager.py
CHANGED
|
@@ -62,7 +62,12 @@ class WebSocketManager:
|
|
| 62 |
self.disconnect(websocket)
|
| 63 |
raise WebSocketException()
|
| 64 |
|
| 65 |
-
async def send_agent_start(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
"""Send agent start event"""
|
| 67 |
event = AgentStartEvent(
|
| 68 |
agentTrace=AgentTrace(
|
|
@@ -74,6 +79,7 @@ class WebSocketManager:
|
|
| 74 |
traceMetadata=active_task.traceMetadata,
|
| 75 |
isRunning=True,
|
| 76 |
),
|
|
|
|
| 77 |
)
|
| 78 |
await self.send_message(event, websocket)
|
| 79 |
|
|
@@ -94,7 +100,9 @@ class WebSocketManager:
|
|
| 94 |
self,
|
| 95 |
metadata: AgentTraceMetadata,
|
| 96 |
websocket: WebSocket,
|
| 97 |
-
final_state: Literal[
|
|
|
|
|
|
|
| 98 |
):
|
| 99 |
"""Send agent complete event"""
|
| 100 |
event = AgentCompleteEvent(traceMetadata=metadata, final_state=final_state)
|
|
|
|
| 62 |
self.disconnect(websocket)
|
| 63 |
raise WebSocketException()
|
| 64 |
|
| 65 |
+
async def send_agent_start(
|
| 66 |
+
self,
|
| 67 |
+
active_task: ActiveTask,
|
| 68 |
+
websocket: WebSocket,
|
| 69 |
+
status: Literal["max_sandboxes_reached", "success"],
|
| 70 |
+
):
|
| 71 |
"""Send agent start event"""
|
| 72 |
event = AgentStartEvent(
|
| 73 |
agentTrace=AgentTrace(
|
|
|
|
| 79 |
traceMetadata=active_task.traceMetadata,
|
| 80 |
isRunning=True,
|
| 81 |
),
|
| 82 |
+
status=status,
|
| 83 |
)
|
| 84 |
await self.send_message(event, websocket)
|
| 85 |
|
|
|
|
| 100 |
self,
|
| 101 |
metadata: AgentTraceMetadata,
|
| 102 |
websocket: WebSocket,
|
| 103 |
+
final_state: Literal[
|
| 104 |
+
"success", "stopped", "max_steps_reached", "error", "timeout"
|
| 105 |
+
],
|
| 106 |
):
|
| 107 |
"""Send agent complete event"""
|
| 108 |
event = AgentCompleteEvent(traceMetadata=metadata, final_state=final_state)
|
cua2-front/src/types/agent.ts
CHANGED
|
@@ -35,6 +35,7 @@ export interface AgentTraceMetadata {
|
|
| 35 |
numberOfSteps: number;
|
| 36 |
maxSteps: number;
|
| 37 |
completed: boolean;
|
|
|
|
| 38 |
}
|
| 39 |
|
| 40 |
export interface FinalStep {
|
|
@@ -48,6 +49,7 @@ export interface FinalStep {
|
|
| 48 |
interface AgentStartEvent {
|
| 49 |
type: 'agent_start';
|
| 50 |
agentTrace: AgentTrace;
|
|
|
|
| 51 |
}
|
| 52 |
|
| 53 |
interface AgentProgressEvent {
|
|
@@ -59,7 +61,7 @@ interface AgentProgressEvent {
|
|
| 59 |
interface AgentCompleteEvent {
|
| 60 |
type: 'agent_complete';
|
| 61 |
traceMetadata: AgentTraceMetadata;
|
| 62 |
-
final_state: 'success' | 'stopped' | 'max_steps_reached' | 'error';
|
| 63 |
}
|
| 64 |
|
| 65 |
interface AgentErrorEvent {
|
|
|
|
| 35 |
numberOfSteps: number;
|
| 36 |
maxSteps: number;
|
| 37 |
completed: boolean;
|
| 38 |
+
final_state: 'success' | 'stopped' | 'max_steps_reached' | 'error' | 'sandbox_timeout' | null;
|
| 39 |
}
|
| 40 |
|
| 41 |
export interface FinalStep {
|
|
|
|
| 49 |
interface AgentStartEvent {
|
| 50 |
type: 'agent_start';
|
| 51 |
agentTrace: AgentTrace;
|
| 52 |
+
status: 'max_sandboxes_reached' | 'success';
|
| 53 |
}
|
| 54 |
|
| 55 |
interface AgentProgressEvent {
|
|
|
|
| 61 |
interface AgentCompleteEvent {
|
| 62 |
type: 'agent_complete';
|
| 63 |
traceMetadata: AgentTraceMetadata;
|
| 64 |
+
final_state: 'success' | 'stopped' | 'max_steps_reached' | 'error' | 'sandbox_timeout';
|
| 65 |
}
|
| 66 |
|
| 67 |
interface AgentErrorEvent {
|
entrypoint.sh
CHANGED
|
@@ -19,9 +19,10 @@ echo "nginx started successfully"
|
|
| 19 |
cd $HOME/app/cua2-core
|
| 20 |
|
| 21 |
# Set default number of workers if not specified
|
| 22 |
-
|
| 23 |
|
| 24 |
echo "Starting backend with $WORKERS worker(s)..."
|
| 25 |
|
| 26 |
# Use uv to run the application
|
| 27 |
-
|
|
|
|
|
|
| 19 |
cd $HOME/app/cua2-core
|
| 20 |
|
| 21 |
# Set default number of workers if not specified
|
| 22 |
+
NUM_WORKERS=${NUM_WORKERS:-1}
|
| 23 |
|
| 24 |
echo "Starting backend with $WORKERS worker(s)..."
|
| 25 |
|
| 26 |
# Use uv to run the application
|
| 27 |
+
echo "uv run uvicorn cua2_core.main:app --host 0.0.0.0 --port 8000 --workers $NUM_WORKERS --log-level error > /dev/null"
|
| 28 |
+
exec uv run uvicorn cua2_core.main:app --host 0.0.0.0 --port 8000 --workers $NUM_WORKERS --log-level error > /dev/null
|