A-Mahla committed on
Commit
9493fef
·
unverified ·
1 Parent(s): 65f0789

ADD agent stop feature (#11)

Browse files

* NEW agent

* Handle stop agent feature

* CHG agent max step

* FIX model

cua2-core/src/cua2_core/models/models.py CHANGED
@@ -82,7 +82,7 @@ class AgentAction(FunctionCall):
82
  return f"Open: {url}"
83
 
84
  elif action_type == "launch":
85
- url = args.get("url") or args.get("arg_0")
86
  return f"Open: {url}"
87
 
88
  elif action_type == "final_answer":
@@ -172,6 +172,7 @@ class AgentCompleteEvent(BaseModel):
172
 
173
  type: Literal["agent_complete"] = "agent_complete"
174
  traceMetadata: AgentTraceMetadata
 
175
 
176
 
177
  class AgentErrorEvent(BaseModel):
@@ -222,6 +223,13 @@ class UserTaskMessage(BaseModel):
222
  agent_trace: AgentTrace | None = None
223
 
224
 
 
 
 
 
 
 
 
225
  ##################### Agent Service ########################
226
 
227
 
@@ -256,6 +264,7 @@ class ActiveTask(BaseModel):
256
  f,
257
  indent=2,
258
  )
 
259
 
260
  def update_step(self, step: AgentStep):
261
  """Update step"""
@@ -276,6 +285,27 @@ class ActiveTask(BaseModel):
276
  indent=2,
277
  )
278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
 
280
  #################### API Routes Models ########################
281
 
 
82
  return f"Open: {url}"
83
 
84
  elif action_type == "launch":
85
+ url = args.get("app") or args.get("arg_0")
86
  return f"Open: {url}"
87
 
88
  elif action_type == "final_answer":
 
172
 
173
  type: Literal["agent_complete"] = "agent_complete"
174
  traceMetadata: AgentTraceMetadata
175
+ final_state: Literal["success", "stopped", "max_steps_reached", "error"]
176
 
177
 
178
  class AgentErrorEvent(BaseModel):
 
223
  agent_trace: AgentTrace | None = None
224
 
225
 
226
+ class StopTask(BaseModel):
227
+ """Stop task message"""
228
+
229
+ event_type: Literal["stop_task"]
230
+ traceId: str
231
+
232
+
233
  ##################### Agent Service ########################
234
 
235
 
 
264
  f,
265
  indent=2,
266
  )
267
+ return self
268
 
269
  def update_step(self, step: AgentStep):
270
  """Update step"""
 
285
  indent=2,
286
  )
287
 
288
+ def update_trace_metadata(
289
+ self,
290
+ step_input_tokens_used: int | None = None,
291
+ step_output_tokens_used: int | None = None,
292
+ step_duration: float | None = None,
293
+ step_numberOfSteps: int | None = None,
294
+ completed: bool | None = None,
295
+ ):
296
+ """Update trace metadata"""
297
+ with self._file_lock:
298
+ if step_input_tokens_used is not None:
299
+ self.traceMetadata.inputTokensUsed += step_input_tokens_used
300
+ if step_output_tokens_used is not None:
301
+ self.traceMetadata.outputTokensUsed += step_output_tokens_used
302
+ if step_duration is not None:
303
+ self.traceMetadata.duration += step_duration
304
+ if step_numberOfSteps is not None:
305
+ self.traceMetadata.numberOfSteps += step_numberOfSteps
306
+ if completed is not None:
307
+ self.traceMetadata.completed = completed
308
+
309
 
310
  #################### API Routes Models ########################
311
 
cua2-core/src/cua2_core/routes/websocket.py CHANGED
@@ -59,6 +59,16 @@ async def websocket_endpoint(websocket: WebSocket):
59
  else:
60
  print("No trace data in message")
61
 
 
 
 
 
 
 
 
 
 
 
62
  except json.JSONDecodeError as e:
63
  print(f"JSON decode error: {e}")
64
  from cua2_core.models.models import AgentErrorEvent
 
59
  else:
60
  print("No trace data in message")
61
 
62
+ elif message_data.get("type") == "stop_task":
63
+ # Extract and parse the trace
64
+ trace_id = message_data.get("trace_id")
65
+ if trace_id:
66
+ # Stop the task
67
+ await agent_service.stop_task(trace_id)
68
+ print(f"Stopped task: {trace_id}")
69
+ else:
70
+ print("No trace ID in message")
71
+
72
  except json.JSONDecodeError as e:
73
  print(f"JSON decode error: {e}")
74
  from cua2_core.models.models import AgentErrorEvent
cua2-core/src/cua2_core/services/agent_service.py CHANGED
@@ -27,6 +27,12 @@ from smolagents import ActionStep, AgentImage, AgentMaxStepsError, TaskStep
27
  logger = logging.getLogger(__name__)
28
 
29
 
 
 
 
 
 
 
30
  class AgentService:
31
  """Service for handling agent tasks and processing"""
32
 
@@ -74,6 +80,7 @@ class AgentService:
74
  agent = None
75
  novnc_active = False
76
  websocket_exception = False
 
77
 
78
  try:
79
  # Get the websocket for this task
@@ -129,9 +136,14 @@ class AgentService:
129
 
130
  self.active_tasks[message_id].traceMetadata.completed = True
131
 
 
 
 
 
 
 
132
  except WebSocketException:
133
  websocket_exception = True
134
- pass
135
 
136
  except (Exception, KeyboardInterrupt):
137
  import traceback
@@ -139,6 +151,7 @@ class AgentService:
139
  logger.error(
140
  f"Error processing task: {traceback.format_exc()}", exc_info=True
141
  )
 
142
  await self.websocket_manager.send_agent_error(
143
  error="Error processing task", websocket=websocket
144
  )
@@ -149,6 +162,7 @@ class AgentService:
149
  await self.websocket_manager.send_agent_complete(
150
  metadata=self.active_tasks[message_id].traceMetadata,
151
  websocket=websocket,
 
152
  )
153
 
154
  if novnc_active:
@@ -191,6 +205,9 @@ class AgentService:
191
  def step_callback(memory_step: ActionStep, agent: E2BVisionAgent):
192
  assert memory_step.step_number is not None
193
 
 
 
 
194
  time.sleep(3)
195
 
196
  image = self.last_screenshot[message_id]
@@ -215,9 +232,7 @@ class AgentService:
215
  if memory_step.model_output_message
216
  else None
217
  )
218
- if model_output is None and isinstance(
219
- memory_step.error, AgentMaxStepsError
220
- ):
221
  model_output = memory_step.action_output
222
 
223
  thought = (
@@ -229,11 +244,14 @@ class AgentService:
229
  )
230
  else None
231
  )
232
- action_sequence = (
233
- model_output.split("```")[1]
234
- if model_output and memory_step.error is None
235
- else None
236
- )
 
 
 
237
  if memory_step.observations_images:
238
  image = memory_step.observations_images[0]
239
  buffered = BytesIO()
@@ -244,6 +262,8 @@ class AgentService:
244
  else:
245
  image_base64 = None
246
 
 
 
247
  step = AgentStep(
248
  traceId=message_id,
249
  stepId=str(memory_step.step_number),
@@ -260,18 +280,14 @@ class AgentService:
260
  outputTokensUsed=memory_step.token_usage.output_tokens,
261
  step_evaluation="neutral",
262
  )
263
- self.active_tasks[
264
- message_id
265
- ].traceMetadata.inputTokensUsed += memory_step.token_usage.input_tokens
266
- self.active_tasks[
267
- message_id
268
- ].traceMetadata.outputTokensUsed += memory_step.token_usage.output_tokens
269
- self.active_tasks[message_id].traceMetadata.numberOfSteps += 1
270
- self.active_tasks[
271
- message_id
272
- ].traceMetadata.duration += memory_step.timing.duration
273
-
274
- # Add step to active task
275
  self.active_tasks[message_id].update_step(step)
276
 
277
  websocket = self.task_websockets.get(message_id)
@@ -285,6 +301,9 @@ class AgentService:
285
  )
286
  future.result()
287
 
 
 
 
288
  step_filename = f"{message_id}-{memory_step.step_number}"
289
  screenshot_bytes = agent.desktop.screenshot()
290
  image = Image.open(BytesIO(screenshot_bytes))
@@ -367,3 +386,10 @@ class AgentService:
367
  raise ValueError(f"Step {step_id} not found in trace")
368
  except (ValueError, KeyError, TypeError) as e:
369
  raise ValueError(f"Error processing step update: {e}")
 
 
 
 
 
 
 
 
27
  logger = logging.getLogger(__name__)
28
 
29
 
30
+ class AgentStopException(Exception):
31
+ """Exception for agent stop"""
32
+
33
+ pass
34
+
35
+
36
  class AgentService:
37
  """Service for handling agent tasks and processing"""
38
 
 
80
  agent = None
81
  novnc_active = False
82
  websocket_exception = False
83
+ final_state = "success"
84
 
85
  try:
86
  # Get the websocket for this task
 
136
 
137
  self.active_tasks[message_id].traceMetadata.completed = True
138
 
139
+ except AgentStopException as e:
140
+ if str(e) == "Max steps reached":
141
+ final_state = "max_steps_reached"
142
+ elif str(e) == "Task not completed":
143
+ final_state = "stopped"
144
+
145
  except WebSocketException:
146
  websocket_exception = True
 
147
 
148
  except (Exception, KeyboardInterrupt):
149
  import traceback
 
151
  logger.error(
152
  f"Error processing task: {traceback.format_exc()}", exc_info=True
153
  )
154
+ final_state = "error"
155
  await self.websocket_manager.send_agent_error(
156
  error="Error processing task", websocket=websocket
157
  )
 
162
  await self.websocket_manager.send_agent_complete(
163
  metadata=self.active_tasks[message_id].traceMetadata,
164
  websocket=websocket,
165
+ final_state=final_state,
166
  )
167
 
168
  if novnc_active:
 
205
  def step_callback(memory_step: ActionStep, agent: E2BVisionAgent):
206
  assert memory_step.step_number is not None
207
 
208
+ if memory_step.step_number > agent.max_steps:
209
+ raise AgentStopException("Max steps reached")
210
+
211
  time.sleep(3)
212
 
213
  image = self.last_screenshot[message_id]
 
232
  if memory_step.model_output_message
233
  else None
234
  )
235
+ if isinstance(memory_step.error, AgentMaxStepsError):
 
 
236
  model_output = memory_step.action_output
237
 
238
  thought = (
 
244
  )
245
  else None
246
  )
247
+
248
+ if model_output is not None:
249
+ action_sequence = model_output.split("```")[1]
250
+ else:
251
+ action_sequence = (
252
+ """The task failed due to an error""" # TODO: To Handle in front
253
+ )
254
+
255
  if memory_step.observations_images:
256
  image = memory_step.observations_images[0]
257
  buffered = BytesIO()
 
262
  else:
263
  image_base64 = None
264
 
265
+ logger.info(memory_step)
266
+
267
  step = AgentStep(
268
  traceId=message_id,
269
  stepId=str(memory_step.step_number),
 
280
  outputTokensUsed=memory_step.token_usage.output_tokens,
281
  step_evaluation="neutral",
282
  )
283
+
284
+ self.active_tasks[message_id].update_trace_metadata(
285
+ step_input_tokens_used=memory_step.token_usage.input_tokens,
286
+ step_output_tokens_used=memory_step.token_usage.output_tokens,
287
+ step_duration=memory_step.timing.duration,
288
+ step_numberOfSteps=1,
289
+ )
290
+
 
 
 
 
291
  self.active_tasks[message_id].update_step(step)
292
 
293
  websocket = self.task_websockets.get(message_id)
 
301
  )
302
  future.result()
303
 
304
+ if self.active_tasks[message_id].traceMetadata.completed:
305
+ raise AgentStopException("Task not completed")
306
+
307
  step_filename = f"{message_id}-{memory_step.step_number}"
308
  screenshot_bytes = agent.desktop.screenshot()
309
  image = Image.open(BytesIO(screenshot_bytes))
 
386
  raise ValueError(f"Step {step_id} not found in trace")
387
  except (ValueError, KeyError, TypeError) as e:
388
  raise ValueError(f"Error processing step update: {e}")
389
+
390
+ async def stop_task(self, trace_id: str):
391
+ """Stop a task"""
392
+ if trace_id in self.active_tasks:
393
+ self.active_tasks[trace_id].update_trace_metadata(
394
+ completed=True,
395
+ )
cua2-core/src/cua2_core/services/agent_utils/desktop_agent.py CHANGED
@@ -20,7 +20,7 @@ class E2BVisionAgent(CodeAgent):
20
  model: Model,
21
  data_dir: str,
22
  desktop: Sandbox,
23
- max_steps: int = 200,
24
  verbosity_level: LogLevel = 2,
25
  planning_interval: int | None = None,
26
  use_v1_prompt: bool = False,
 
20
  model: Model,
21
  data_dir: str,
22
  desktop: Sandbox,
23
+ max_steps: int = 30,
24
  verbosity_level: LogLevel = 2,
25
  planning_interval: int | None = None,
26
  use_v1_prompt: bool = False,
cua2-core/src/cua2_core/websocket/websocket_manager.py CHANGED
@@ -1,6 +1,6 @@
1
  import asyncio
2
  import json
3
- from typing import Dict, Set
4
 
5
  from cua2_core.models.models import (
6
  ActiveTask,
@@ -91,10 +91,13 @@ class WebSocketManager:
91
  await self.send_message(event, websocket)
92
 
93
  async def send_agent_complete(
94
- self, metadata: AgentTraceMetadata, websocket: WebSocket
 
 
 
95
  ):
96
  """Send agent complete event"""
97
- event = AgentCompleteEvent(traceMetadata=metadata)
98
  await self.send_message(event, websocket)
99
 
100
  async def send_agent_error(self, error: str, websocket: WebSocket):
 
1
  import asyncio
2
  import json
3
+ from typing import Dict, Literal, Set
4
 
5
  from cua2_core.models.models import (
6
  ActiveTask,
 
91
  await self.send_message(event, websocket)
92
 
93
  async def send_agent_complete(
94
+ self,
95
+ metadata: AgentTraceMetadata,
96
+ websocket: WebSocket,
97
+ final_state: Literal["success", "stopped", "max_steps_reached", "error"],
98
  ):
99
  """Send agent complete event"""
100
+ event = AgentCompleteEvent(traceMetadata=metadata, final_state=final_state)
101
  await self.send_message(event, websocket)
102
 
103
  async def send_agent_error(self, error: str, websocket: WebSocket):
cua2-front/src/types/agent.ts CHANGED
@@ -59,6 +59,7 @@ interface AgentProgressEvent {
59
  interface AgentCompleteEvent {
60
  type: 'agent_complete';
61
  traceMetadata: AgentTraceMetadata;
 
62
  }
63
 
64
  interface AgentErrorEvent {
@@ -95,3 +96,9 @@ export interface UserTaskMessage {
95
  type: 'user_task';
96
  trace: AgentTrace;
97
  }
 
 
 
 
 
 
 
59
  interface AgentCompleteEvent {
60
  type: 'agent_complete';
61
  traceMetadata: AgentTraceMetadata;
62
+ final_state: 'success' | 'stopped' | 'max_steps_reached' | 'error';
63
  }
64
 
65
  interface AgentErrorEvent {
 
96
  type: 'user_task';
97
  trace: AgentTrace;
98
  }
99
+
100
+
101
+ export interface StopTaskMessage {
102
+ type: 'stop_task';
103
+ traceId: string;
104
+ }