A-Mahla commited on
Commit
f5d0df5
·
unverified ·
1 Parent(s): 9493fef

mute logs (#13)

Browse files

* NEW agent

* Handle stop agent feature

* CHG agent max step

* FIX model

* ADD files

* ADD error status (#14)

* CHG error status

* CHG error status

* CHG error status

cua2-core/src/cua2_core/app.py CHANGED
@@ -21,11 +21,13 @@ async def lifespan(app: FastAPI):
21
  if not os.getenv("HF_TOKEN"):
22
  raise ValueError("HF_TOKEN is not set")
23
 
 
 
24
  websocket_manager = WebSocketManager()
25
 
26
  sandbox_service = SandboxService()
27
 
28
- agent_service = AgentService(websocket_manager, sandbox_service)
29
 
30
  # Store services in app state for access in routes
31
  app.state.websocket_manager = websocket_manager
 
21
  if not os.getenv("HF_TOKEN"):
22
  raise ValueError("HF_TOKEN is not set")
23
 
24
+ num_workers = int(os.getenv("NUM_WORKERS", "1"))
25
+
26
  websocket_manager = WebSocketManager()
27
 
28
  sandbox_service = SandboxService()
29
 
30
+ agent_service = AgentService(websocket_manager, sandbox_service, num_workers)
31
 
32
  # Store services in app state for access in routes
33
  app.state.websocket_manager = websocket_manager
cua2-core/src/cua2_core/models/models.py CHANGED
@@ -126,6 +126,10 @@ class AgentTraceMetadata(BaseModel):
126
  numberOfSteps: int = 0
127
  maxSteps: int = 0
128
  completed: bool = False
 
 
 
 
129
 
130
 
131
  class AgentTrace(BaseModel):
@@ -157,6 +161,7 @@ class AgentStartEvent(BaseModel):
157
 
158
  type: Literal["agent_start"] = "agent_start"
159
  agentTrace: AgentTrace
 
160
 
161
 
162
  class AgentProgressEvent(BaseModel):
@@ -172,7 +177,9 @@ class AgentCompleteEvent(BaseModel):
172
 
173
  type: Literal["agent_complete"] = "agent_complete"
174
  traceMetadata: AgentTraceMetadata
175
- final_state: Literal["success", "stopped", "max_steps_reached", "error"]
 
 
176
 
177
 
178
  class AgentErrorEvent(BaseModel):
@@ -292,6 +299,10 @@ class ActiveTask(BaseModel):
292
  step_duration: float | None = None,
293
  step_numberOfSteps: int | None = None,
294
  completed: bool | None = None,
 
 
 
 
295
  ):
296
  """Update trace metadata"""
297
  with self._file_lock:
@@ -305,6 +316,8 @@ class ActiveTask(BaseModel):
305
  self.traceMetadata.numberOfSteps += step_numberOfSteps
306
  if completed is not None:
307
  self.traceMetadata.completed = completed
 
 
308
 
309
 
310
  #################### API Routes Models ########################
 
126
  numberOfSteps: int = 0
127
  maxSteps: int = 0
128
  completed: bool = False
129
+ final_state: (
130
+ Literal["success", "stopped", "max_steps_reached", "error", "sandbox_timeout"]
131
+ | None
132
+ ) = None
133
 
134
 
135
  class AgentTrace(BaseModel):
 
161
 
162
  type: Literal["agent_start"] = "agent_start"
163
  agentTrace: AgentTrace
164
+ status: Literal["max_sandboxes_reached", "success"] = "success"
165
 
166
 
167
  class AgentProgressEvent(BaseModel):
 
177
 
178
  type: Literal["agent_complete"] = "agent_complete"
179
  traceMetadata: AgentTraceMetadata
180
+ final_state: Literal[
181
+ "success", "stopped", "max_steps_reached", "error", "sandbox_timeout"
182
+ ]
183
 
184
 
185
  class AgentErrorEvent(BaseModel):
 
299
  step_duration: float | None = None,
300
  step_numberOfSteps: int | None = None,
301
  completed: bool | None = None,
302
+ final_state: Literal[
303
+ "success", "stopped", "max_steps_reached", "error", "sandbox_timeout"
304
+ ]
305
+ | None = None,
306
  ):
307
  """Update trace metadata"""
308
  with self._file_lock:
 
316
  self.traceMetadata.numberOfSteps += step_numberOfSteps
317
  if completed is not None:
318
  self.traceMetadata.completed = completed
319
+ if final_state is not None:
320
+ self.traceMetadata.final_state = final_state
321
 
322
 
323
  #################### API Routes Models ########################
cua2-core/src/cua2_core/services/agent_service.py CHANGED
@@ -19,7 +19,7 @@ from cua2_core.services.agent_utils.function_parser import parse_function_call
19
  from cua2_core.services.agent_utils.get_model import get_model
20
  from cua2_core.services.sandbox_service import SandboxService
21
  from cua2_core.websocket.websocket_manager import WebSocketException, WebSocketManager
22
- from e2b_desktop import Sandbox
23
  from fastapi import WebSocket
24
  from PIL import Image
25
  from smolagents import ActionStep, AgentImage, AgentMaxStepsError, TaskStep
@@ -37,33 +37,50 @@ class AgentService:
37
  """Service for handling agent tasks and processing"""
38
 
39
  def __init__(
40
- self, websocket_manager: WebSocketManager, sandbox_service: SandboxService
 
 
 
41
  ):
42
  self.active_tasks: dict[str, ActiveTask] = {}
43
  self.websocket_manager: WebSocketManager = websocket_manager
44
  self.task_websockets: dict[str, WebSocket] = {}
45
  self.sandbox_service: SandboxService = sandbox_service
46
- self.last_screenshot: dict[str, AgentImage] = {}
 
 
47
 
48
- async def process_user_task(self, trace: AgentTrace, websocket: WebSocket) -> str:
 
 
49
  """Process a user task and return the trace ID"""
50
 
51
  trace_id = trace.id
52
  trace.steps = []
53
  trace.traceMetadata = AgentTraceMetadata(traceId=trace_id)
54
 
55
- # Store the task
56
- self.active_tasks[trace_id] = ActiveTask(
57
- message_id=trace_id,
58
- instruction=trace.instruction,
59
- model_id=trace.modelId,
60
- timestamp=trace.timestamp,
61
- steps=trace.steps,
62
- traceMetadata=trace.traceMetadata,
63
- )
64
 
65
- # Store the websocket for this task
66
- self.task_websockets[trace_id] = websocket
 
 
 
 
 
 
 
 
 
 
67
 
68
  asyncio.create_task(self._agent_processing(trace_id))
69
 
@@ -87,7 +104,9 @@ class AgentService:
87
  websocket = self.task_websockets.get(message_id)
88
 
89
  await self.websocket_manager.send_agent_start(
90
- active_task=self.active_tasks[message_id], websocket=websocket
 
 
91
  )
92
 
93
  model = get_model(self.active_tasks[message_id].model_id)
@@ -145,6 +164,9 @@ class AgentService:
145
  except WebSocketException:
146
  websocket_exception = True
147
 
 
 
 
148
  except (Exception, KeyboardInterrupt):
149
  import traceback
150
 
@@ -170,17 +192,23 @@ class AgentService:
170
 
171
  novnc_active = False
172
 
173
- # Clean up
 
 
 
174
  if message_id in self.active_tasks:
175
  self.active_tasks[message_id].store_model()
176
- del self.active_tasks[message_id]
177
 
178
- # Clean up websocket reference
179
- if message_id in self.task_websockets:
180
- del self.task_websockets[message_id]
 
 
 
 
181
 
182
- if message_id in self.last_screenshot:
183
- del self.last_screenshot[message_id]
184
 
185
  # Release sandbox back to the pool
186
  if sandbox:
@@ -211,7 +239,7 @@ class AgentService:
211
  time.sleep(3)
212
 
213
  image = self.last_screenshot[message_id]
214
- # agent.last_marked_screenshot = AgentImage(screenshot_path)
215
 
216
  for previous_memory_step in (
217
  agent.memory.steps
@@ -262,8 +290,6 @@ class AgentService:
262
  else:
263
  image_base64 = None
264
 
265
- logger.info(memory_step)
266
-
267
  step = AgentStep(
268
  traceId=message_id,
269
  stepId=str(memory_step.step_number),
 
19
  from cua2_core.services.agent_utils.get_model import get_model
20
  from cua2_core.services.sandbox_service import SandboxService
21
  from cua2_core.websocket.websocket_manager import WebSocketException, WebSocketManager
22
+ from e2b_desktop import Sandbox, TimeoutException
23
  from fastapi import WebSocket
24
  from PIL import Image
25
  from smolagents import ActionStep, AgentImage, AgentMaxStepsError, TaskStep
 
37
  """Service for handling agent tasks and processing"""
38
 
39
  def __init__(
40
+ self,
41
+ websocket_manager: WebSocketManager,
42
+ sandbox_service: SandboxService,
43
+ num_workers: int,
44
  ):
45
  self.active_tasks: dict[str, ActiveTask] = {}
46
  self.websocket_manager: WebSocketManager = websocket_manager
47
  self.task_websockets: dict[str, WebSocket] = {}
48
  self.sandbox_service: SandboxService = sandbox_service
49
+ self.last_screenshot: dict[str, AgentImage | None] = {}
50
+ self._lock = asyncio.Lock()
51
+ self.max_sandboxes = int(600 / num_workers)
52
 
53
+ async def process_user_task(
54
+ self, trace: AgentTrace, websocket: WebSocket
55
+ ) -> str | None:
56
  """Process a user task and return the trace ID"""
57
 
58
  trace_id = trace.id
59
  trace.steps = []
60
  trace.traceMetadata = AgentTraceMetadata(traceId=trace_id)
61
 
62
+ async with self._lock:
63
+ active_task = ActiveTask(
64
+ message_id=trace_id,
65
+ instruction=trace.instruction,
66
+ model_id=trace.modelId,
67
+ timestamp=trace.timestamp,
68
+ steps=trace.steps,
69
+ traceMetadata=trace.traceMetadata,
70
+ )
71
 
72
+ if len(self.active_tasks) >= self.max_sandboxes:
73
+ await self.websocket_manager.send_agent_start(
74
+ active_task=active_task,
75
+ status="max_sandboxes_reached",
76
+ websocket=websocket,
77
+ )
78
+ return trace_id
79
+
80
+ # Store the task and websocket for this task
81
+ self.active_tasks[trace_id] = active_task
82
+ self.task_websockets[trace_id] = websocket
83
+ self.last_screenshot[trace_id] = None
84
 
85
  asyncio.create_task(self._agent_processing(trace_id))
86
 
 
104
  websocket = self.task_websockets.get(message_id)
105
 
106
  await self.websocket_manager.send_agent_start(
107
+ active_task=self.active_tasks[message_id],
108
+ websocket=websocket,
109
+ status="success",
110
  )
111
 
112
  model = get_model(self.active_tasks[message_id].model_id)
 
164
  except WebSocketException:
165
  websocket_exception = True
166
 
167
+ except TimeoutException:
168
+ final_state = "sandbox_timeout"
169
+
170
  except (Exception, KeyboardInterrupt):
171
  import traceback
172
 
 
192
 
193
  novnc_active = False
194
 
195
+ self.active_tasks[message_id].update_trace_metadata(
196
+ final_state=final_state,
197
+ )
198
+
199
  if message_id in self.active_tasks:
200
  self.active_tasks[message_id].store_model()
 
201
 
202
+ # Clean up
203
+ async with self._lock:
204
+ if message_id in self.active_tasks:
205
+ del self.active_tasks[message_id]
206
+
207
+ if message_id in self.task_websockets:
208
+ del self.task_websockets[message_id]
209
 
210
+ if message_id in self.last_screenshot:
211
+ del self.last_screenshot[message_id]
212
 
213
  # Release sandbox back to the pool
214
  if sandbox:
 
239
  time.sleep(3)
240
 
241
  image = self.last_screenshot[message_id]
242
+ assert image is not None
243
 
244
  for previous_memory_step in (
245
  agent.memory.steps
 
290
  else:
291
  image_base64 = None
292
 
 
 
293
  step = AgentStep(
294
  traceId=message_id,
295
  stepId=str(memory_step.step_number),
cua2-core/src/cua2_core/websocket/websocket_manager.py CHANGED
@@ -62,7 +62,12 @@ class WebSocketManager:
62
  self.disconnect(websocket)
63
  raise WebSocketException()
64
 
65
- async def send_agent_start(self, active_task: ActiveTask, websocket: WebSocket):
 
 
 
 
 
66
  """Send agent start event"""
67
  event = AgentStartEvent(
68
  agentTrace=AgentTrace(
@@ -74,6 +79,7 @@ class WebSocketManager:
74
  traceMetadata=active_task.traceMetadata,
75
  isRunning=True,
76
  ),
 
77
  )
78
  await self.send_message(event, websocket)
79
 
@@ -94,7 +100,9 @@ class WebSocketManager:
94
  self,
95
  metadata: AgentTraceMetadata,
96
  websocket: WebSocket,
97
- final_state: Literal["success", "stopped", "max_steps_reached", "error"],
 
 
98
  ):
99
  """Send agent complete event"""
100
  event = AgentCompleteEvent(traceMetadata=metadata, final_state=final_state)
 
62
  self.disconnect(websocket)
63
  raise WebSocketException()
64
 
65
+ async def send_agent_start(
66
+ self,
67
+ active_task: ActiveTask,
68
+ websocket: WebSocket,
69
+ status: Literal["max_sandboxes_reached", "success"],
70
+ ):
71
  """Send agent start event"""
72
  event = AgentStartEvent(
73
  agentTrace=AgentTrace(
 
79
  traceMetadata=active_task.traceMetadata,
80
  isRunning=True,
81
  ),
82
+ status=status,
83
  )
84
  await self.send_message(event, websocket)
85
 
 
100
  self,
101
  metadata: AgentTraceMetadata,
102
  websocket: WebSocket,
103
+ final_state: Literal[
104
+ "success", "stopped", "max_steps_reached", "error", "timeout"
105
+ ],
106
  ):
107
  """Send agent complete event"""
108
  event = AgentCompleteEvent(traceMetadata=metadata, final_state=final_state)
cua2-front/src/types/agent.ts CHANGED
@@ -35,6 +35,7 @@ export interface AgentTraceMetadata {
35
  numberOfSteps: number;
36
  maxSteps: number;
37
  completed: boolean;
 
38
  }
39
 
40
  export interface FinalStep {
@@ -48,6 +49,7 @@ export interface FinalStep {
48
  interface AgentStartEvent {
49
  type: 'agent_start';
50
  agentTrace: AgentTrace;
 
51
  }
52
 
53
  interface AgentProgressEvent {
@@ -59,7 +61,7 @@ interface AgentProgressEvent {
59
  interface AgentCompleteEvent {
60
  type: 'agent_complete';
61
  traceMetadata: AgentTraceMetadata;
62
- final_state: 'success' | 'stopped' | 'max_steps_reached' | 'error';
63
  }
64
 
65
  interface AgentErrorEvent {
 
35
  numberOfSteps: number;
36
  maxSteps: number;
37
  completed: boolean;
38
+ final_state: 'success' | 'stopped' | 'max_steps_reached' | 'error' | 'sandbox_timeout' | null;
39
  }
40
 
41
  export interface FinalStep {
 
49
  interface AgentStartEvent {
50
  type: 'agent_start';
51
  agentTrace: AgentTrace;
52
+ status: 'max_sandboxes_reached' | 'success';
53
  }
54
 
55
  interface AgentProgressEvent {
 
61
  interface AgentCompleteEvent {
62
  type: 'agent_complete';
63
  traceMetadata: AgentTraceMetadata;
64
+ final_state: 'success' | 'stopped' | 'max_steps_reached' | 'error' | 'sandbox_timeout';
65
  }
66
 
67
  interface AgentErrorEvent {
entrypoint.sh CHANGED
@@ -19,9 +19,10 @@ echo "nginx started successfully"
19
  cd $HOME/app/cua2-core
20
 
21
  # Set default number of workers if not specified
22
- WORKERS=${WORKERS:-1}
23
 
24
  echo "Starting backend with $WORKERS worker(s)..."
25
 
26
  # Use uv to run the application
27
- exec uv run uvicorn cua2_core.main:app --host 0.0.0.0 --port 8000 --workers $WORKERS --log-level info
 
 
19
  cd $HOME/app/cua2-core
20
 
21
  # Set default number of workers if not specified
22
+ NUM_WORKERS=${NUM_WORKERS:-1}
23
 
24
  echo "Starting backend with $WORKERS worker(s)..."
25
 
26
  # Use uv to run the application
27
+ echo "uv run uvicorn cua2_core.main:app --host 0.0.0.0 --port 8000 --workers $NUM_WORKERS --log-level error > /dev/null"
28
+ exec uv run uvicorn cua2_core.main:app --host 0.0.0.0 --port 8000 --workers $NUM_WORKERS --log-level error > /dev/null