Spaces:
Running
Running
A-Mahla
committed on
ADD agent stop feature (#11)
Browse files
* NEW agent
* Handle stop agent feature
* CHG agent max step
* FIX model
- cua2-core/src/cua2_core/models/models.py +31 -1
- cua2-core/src/cua2_core/routes/websocket.py +10 -0
- cua2-core/src/cua2_core/services/agent_service.py +47 -21
- cua2-core/src/cua2_core/services/agent_utils/desktop_agent.py +1 -1
- cua2-core/src/cua2_core/websocket/websocket_manager.py +6 -3
- cua2-front/src/types/agent.ts +7 -0
cua2-core/src/cua2_core/models/models.py
CHANGED
|
@@ -82,7 +82,7 @@ class AgentAction(FunctionCall):
|
|
| 82 |
return f"Open: {url}"
|
| 83 |
|
| 84 |
elif action_type == "launch":
|
| 85 |
-
url = args.get("
|
| 86 |
return f"Open: {url}"
|
| 87 |
|
| 88 |
elif action_type == "final_answer":
|
|
@@ -172,6 +172,7 @@ class AgentCompleteEvent(BaseModel):
|
|
| 172 |
|
| 173 |
type: Literal["agent_complete"] = "agent_complete"
|
| 174 |
traceMetadata: AgentTraceMetadata
|
|
|
|
| 175 |
|
| 176 |
|
| 177 |
class AgentErrorEvent(BaseModel):
|
|
@@ -222,6 +223,13 @@ class UserTaskMessage(BaseModel):
|
|
| 222 |
agent_trace: AgentTrace | None = None
|
| 223 |
|
| 224 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
##################### Agent Service ########################
|
| 226 |
|
| 227 |
|
|
@@ -256,6 +264,7 @@ class ActiveTask(BaseModel):
|
|
| 256 |
f,
|
| 257 |
indent=2,
|
| 258 |
)
|
|
|
|
| 259 |
|
| 260 |
def update_step(self, step: AgentStep):
|
| 261 |
"""Update step"""
|
|
@@ -276,6 +285,27 @@ class ActiveTask(BaseModel):
|
|
| 276 |
indent=2,
|
| 277 |
)
|
| 278 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
#################### API Routes Models ########################
|
| 281 |
|
|
|
|
| 82 |
return f"Open: {url}"
|
| 83 |
|
| 84 |
elif action_type == "launch":
|
| 85 |
+
url = args.get("app") or args.get("arg_0")
|
| 86 |
return f"Open: {url}"
|
| 87 |
|
| 88 |
elif action_type == "final_answer":
|
|
|
|
| 172 |
|
| 173 |
type: Literal["agent_complete"] = "agent_complete"
|
| 174 |
traceMetadata: AgentTraceMetadata
|
| 175 |
+
final_state: Literal["success", "stopped", "max_steps_reached", "error"]
|
| 176 |
|
| 177 |
|
| 178 |
class AgentErrorEvent(BaseModel):
|
|
|
|
| 223 |
agent_trace: AgentTrace | None = None
|
| 224 |
|
| 225 |
|
| 226 |
+
class StopTask(BaseModel):
|
| 227 |
+
"""Stop task message"""
|
| 228 |
+
|
| 229 |
+
event_type: Literal["stop_task"]
|
| 230 |
+
traceId: str
|
| 231 |
+
|
| 232 |
+
|
| 233 |
##################### Agent Service ########################
|
| 234 |
|
| 235 |
|
|
|
|
| 264 |
f,
|
| 265 |
indent=2,
|
| 266 |
)
|
| 267 |
+
return self
|
| 268 |
|
| 269 |
def update_step(self, step: AgentStep):
|
| 270 |
"""Update step"""
|
|
|
|
| 285 |
indent=2,
|
| 286 |
)
|
| 287 |
|
| 288 |
+
def update_trace_metadata(
|
| 289 |
+
self,
|
| 290 |
+
step_input_tokens_used: int | None = None,
|
| 291 |
+
step_output_tokens_used: int | None = None,
|
| 292 |
+
step_duration: float | None = None,
|
| 293 |
+
step_numberOfSteps: int | None = None,
|
| 294 |
+
completed: bool | None = None,
|
| 295 |
+
):
|
| 296 |
+
"""Update trace metadata"""
|
| 297 |
+
with self._file_lock:
|
| 298 |
+
if step_input_tokens_used is not None:
|
| 299 |
+
self.traceMetadata.inputTokensUsed += step_input_tokens_used
|
| 300 |
+
if step_output_tokens_used is not None:
|
| 301 |
+
self.traceMetadata.outputTokensUsed += step_output_tokens_used
|
| 302 |
+
if step_duration is not None:
|
| 303 |
+
self.traceMetadata.duration += step_duration
|
| 304 |
+
if step_numberOfSteps is not None:
|
| 305 |
+
self.traceMetadata.numberOfSteps += step_numberOfSteps
|
| 306 |
+
if completed is not None:
|
| 307 |
+
self.traceMetadata.completed = completed
|
| 308 |
+
|
| 309 |
|
| 310 |
#################### API Routes Models ########################
|
| 311 |
|
cua2-core/src/cua2_core/routes/websocket.py
CHANGED
|
@@ -59,6 +59,16 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 59 |
else:
|
| 60 |
print("No trace data in message")
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
except json.JSONDecodeError as e:
|
| 63 |
print(f"JSON decode error: {e}")
|
| 64 |
from cua2_core.models.models import AgentErrorEvent
|
|
|
|
| 59 |
else:
|
| 60 |
print("No trace data in message")
|
| 61 |
|
| 62 |
+
elif message_data.get("type") == "stop_task":
|
| 63 |
+
# Extract and parse the trace
|
| 64 |
+
trace_id = message_data.get("trace_id")
|
| 65 |
+
if trace_id:
|
| 66 |
+
# Stop the task
|
| 67 |
+
await agent_service.stop_task(trace_id)
|
| 68 |
+
print(f"Stopped task: {trace_id}")
|
| 69 |
+
else:
|
| 70 |
+
print("No trace ID in message")
|
| 71 |
+
|
| 72 |
except json.JSONDecodeError as e:
|
| 73 |
print(f"JSON decode error: {e}")
|
| 74 |
from cua2_core.models.models import AgentErrorEvent
|
cua2-core/src/cua2_core/services/agent_service.py
CHANGED
|
@@ -27,6 +27,12 @@ from smolagents import ActionStep, AgentImage, AgentMaxStepsError, TaskStep
|
|
| 27 |
logger = logging.getLogger(__name__)
|
| 28 |
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
class AgentService:
|
| 31 |
"""Service for handling agent tasks and processing"""
|
| 32 |
|
|
@@ -74,6 +80,7 @@ class AgentService:
|
|
| 74 |
agent = None
|
| 75 |
novnc_active = False
|
| 76 |
websocket_exception = False
|
|
|
|
| 77 |
|
| 78 |
try:
|
| 79 |
# Get the websocket for this task
|
|
@@ -129,9 +136,14 @@ class AgentService:
|
|
| 129 |
|
| 130 |
self.active_tasks[message_id].traceMetadata.completed = True
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
except WebSocketException:
|
| 133 |
websocket_exception = True
|
| 134 |
-
pass
|
| 135 |
|
| 136 |
except (Exception, KeyboardInterrupt):
|
| 137 |
import traceback
|
|
@@ -139,6 +151,7 @@ class AgentService:
|
|
| 139 |
logger.error(
|
| 140 |
f"Error processing task: {traceback.format_exc()}", exc_info=True
|
| 141 |
)
|
|
|
|
| 142 |
await self.websocket_manager.send_agent_error(
|
| 143 |
error="Error processing task", websocket=websocket
|
| 144 |
)
|
|
@@ -149,6 +162,7 @@ class AgentService:
|
|
| 149 |
await self.websocket_manager.send_agent_complete(
|
| 150 |
metadata=self.active_tasks[message_id].traceMetadata,
|
| 151 |
websocket=websocket,
|
|
|
|
| 152 |
)
|
| 153 |
|
| 154 |
if novnc_active:
|
|
@@ -191,6 +205,9 @@ class AgentService:
|
|
| 191 |
def step_callback(memory_step: ActionStep, agent: E2BVisionAgent):
|
| 192 |
assert memory_step.step_number is not None
|
| 193 |
|
|
|
|
|
|
|
|
|
|
| 194 |
time.sleep(3)
|
| 195 |
|
| 196 |
image = self.last_screenshot[message_id]
|
|
@@ -215,9 +232,7 @@ class AgentService:
|
|
| 215 |
if memory_step.model_output_message
|
| 216 |
else None
|
| 217 |
)
|
| 218 |
-
if
|
| 219 |
-
memory_step.error, AgentMaxStepsError
|
| 220 |
-
):
|
| 221 |
model_output = memory_step.action_output
|
| 222 |
|
| 223 |
thought = (
|
|
@@ -229,11 +244,14 @@ class AgentService:
|
|
| 229 |
)
|
| 230 |
else None
|
| 231 |
)
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
| 237 |
if memory_step.observations_images:
|
| 238 |
image = memory_step.observations_images[0]
|
| 239 |
buffered = BytesIO()
|
|
@@ -244,6 +262,8 @@ class AgentService:
|
|
| 244 |
else:
|
| 245 |
image_base64 = None
|
| 246 |
|
|
|
|
|
|
|
| 247 |
step = AgentStep(
|
| 248 |
traceId=message_id,
|
| 249 |
stepId=str(memory_step.step_number),
|
|
@@ -260,18 +280,14 @@ class AgentService:
|
|
| 260 |
outputTokensUsed=memory_step.token_usage.output_tokens,
|
| 261 |
step_evaluation="neutral",
|
| 262 |
)
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
message_id
|
| 272 |
-
].traceMetadata.duration += memory_step.timing.duration
|
| 273 |
-
|
| 274 |
-
# Add step to active task
|
| 275 |
self.active_tasks[message_id].update_step(step)
|
| 276 |
|
| 277 |
websocket = self.task_websockets.get(message_id)
|
|
@@ -285,6 +301,9 @@ class AgentService:
|
|
| 285 |
)
|
| 286 |
future.result()
|
| 287 |
|
|
|
|
|
|
|
|
|
|
| 288 |
step_filename = f"{message_id}-{memory_step.step_number}"
|
| 289 |
screenshot_bytes = agent.desktop.screenshot()
|
| 290 |
image = Image.open(BytesIO(screenshot_bytes))
|
|
@@ -367,3 +386,10 @@ class AgentService:
|
|
| 367 |
raise ValueError(f"Step {step_id} not found in trace")
|
| 368 |
except (ValueError, KeyError, TypeError) as e:
|
| 369 |
raise ValueError(f"Error processing step update: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
logger = logging.getLogger(__name__)
|
| 28 |
|
| 29 |
|
| 30 |
+
class AgentStopException(Exception):
|
| 31 |
+
"""Exception for agent stop"""
|
| 32 |
+
|
| 33 |
+
pass
|
| 34 |
+
|
| 35 |
+
|
| 36 |
class AgentService:
|
| 37 |
"""Service for handling agent tasks and processing"""
|
| 38 |
|
|
|
|
| 80 |
agent = None
|
| 81 |
novnc_active = False
|
| 82 |
websocket_exception = False
|
| 83 |
+
final_state = "success"
|
| 84 |
|
| 85 |
try:
|
| 86 |
# Get the websocket for this task
|
|
|
|
| 136 |
|
| 137 |
self.active_tasks[message_id].traceMetadata.completed = True
|
| 138 |
|
| 139 |
+
except AgentStopException as e:
|
| 140 |
+
if str(e) == "Max steps reached":
|
| 141 |
+
final_state = "max_steps_reached"
|
| 142 |
+
elif str(e) == "Task not completed":
|
| 143 |
+
final_state = "stopped"
|
| 144 |
+
|
| 145 |
except WebSocketException:
|
| 146 |
websocket_exception = True
|
|
|
|
| 147 |
|
| 148 |
except (Exception, KeyboardInterrupt):
|
| 149 |
import traceback
|
|
|
|
| 151 |
logger.error(
|
| 152 |
f"Error processing task: {traceback.format_exc()}", exc_info=True
|
| 153 |
)
|
| 154 |
+
final_state = "error"
|
| 155 |
await self.websocket_manager.send_agent_error(
|
| 156 |
error="Error processing task", websocket=websocket
|
| 157 |
)
|
|
|
|
| 162 |
await self.websocket_manager.send_agent_complete(
|
| 163 |
metadata=self.active_tasks[message_id].traceMetadata,
|
| 164 |
websocket=websocket,
|
| 165 |
+
final_state=final_state,
|
| 166 |
)
|
| 167 |
|
| 168 |
if novnc_active:
|
|
|
|
| 205 |
def step_callback(memory_step: ActionStep, agent: E2BVisionAgent):
|
| 206 |
assert memory_step.step_number is not None
|
| 207 |
|
| 208 |
+
if memory_step.step_number > agent.max_steps:
|
| 209 |
+
raise AgentStopException("Max steps reached")
|
| 210 |
+
|
| 211 |
time.sleep(3)
|
| 212 |
|
| 213 |
image = self.last_screenshot[message_id]
|
|
|
|
| 232 |
if memory_step.model_output_message
|
| 233 |
else None
|
| 234 |
)
|
| 235 |
+
if isinstance(memory_step.error, AgentMaxStepsError):
|
|
|
|
|
|
|
| 236 |
model_output = memory_step.action_output
|
| 237 |
|
| 238 |
thought = (
|
|
|
|
| 244 |
)
|
| 245 |
else None
|
| 246 |
)
|
| 247 |
+
|
| 248 |
+
if model_output is not None:
|
| 249 |
+
action_sequence = model_output.split("```")[1]
|
| 250 |
+
else:
|
| 251 |
+
action_sequence = (
|
| 252 |
+
"""The task failed due to an error""" # TODO: To Handle in front
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
if memory_step.observations_images:
|
| 256 |
image = memory_step.observations_images[0]
|
| 257 |
buffered = BytesIO()
|
|
|
|
| 262 |
else:
|
| 263 |
image_base64 = None
|
| 264 |
|
| 265 |
+
logger.info(memory_step)
|
| 266 |
+
|
| 267 |
step = AgentStep(
|
| 268 |
traceId=message_id,
|
| 269 |
stepId=str(memory_step.step_number),
|
|
|
|
| 280 |
outputTokensUsed=memory_step.token_usage.output_tokens,
|
| 281 |
step_evaluation="neutral",
|
| 282 |
)
|
| 283 |
+
|
| 284 |
+
self.active_tasks[message_id].update_trace_metadata(
|
| 285 |
+
step_input_tokens_used=memory_step.token_usage.input_tokens,
|
| 286 |
+
step_output_tokens_used=memory_step.token_usage.output_tokens,
|
| 287 |
+
step_duration=memory_step.timing.duration,
|
| 288 |
+
step_numberOfSteps=1,
|
| 289 |
+
)
|
| 290 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
self.active_tasks[message_id].update_step(step)
|
| 292 |
|
| 293 |
websocket = self.task_websockets.get(message_id)
|
|
|
|
| 301 |
)
|
| 302 |
future.result()
|
| 303 |
|
| 304 |
+
if self.active_tasks[message_id].traceMetadata.completed:
|
| 305 |
+
raise AgentStopException("Task not completed")
|
| 306 |
+
|
| 307 |
step_filename = f"{message_id}-{memory_step.step_number}"
|
| 308 |
screenshot_bytes = agent.desktop.screenshot()
|
| 309 |
image = Image.open(BytesIO(screenshot_bytes))
|
|
|
|
| 386 |
raise ValueError(f"Step {step_id} not found in trace")
|
| 387 |
except (ValueError, KeyError, TypeError) as e:
|
| 388 |
raise ValueError(f"Error processing step update: {e}")
|
| 389 |
+
|
| 390 |
+
async def stop_task(self, trace_id: str):
|
| 391 |
+
"""Stop a task"""
|
| 392 |
+
if trace_id in self.active_tasks:
|
| 393 |
+
self.active_tasks[trace_id].update_trace_metadata(
|
| 394 |
+
completed=True,
|
| 395 |
+
)
|
cua2-core/src/cua2_core/services/agent_utils/desktop_agent.py
CHANGED
|
@@ -20,7 +20,7 @@ class E2BVisionAgent(CodeAgent):
|
|
| 20 |
model: Model,
|
| 21 |
data_dir: str,
|
| 22 |
desktop: Sandbox,
|
| 23 |
-
max_steps: int =
|
| 24 |
verbosity_level: LogLevel = 2,
|
| 25 |
planning_interval: int | None = None,
|
| 26 |
use_v1_prompt: bool = False,
|
|
|
|
| 20 |
model: Model,
|
| 21 |
data_dir: str,
|
| 22 |
desktop: Sandbox,
|
| 23 |
+
max_steps: int = 30,
|
| 24 |
verbosity_level: LogLevel = 2,
|
| 25 |
planning_interval: int | None = None,
|
| 26 |
use_v1_prompt: bool = False,
|
cua2-core/src/cua2_core/websocket/websocket_manager.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import asyncio
|
| 2 |
import json
|
| 3 |
-
from typing import Dict, Set
|
| 4 |
|
| 5 |
from cua2_core.models.models import (
|
| 6 |
ActiveTask,
|
|
@@ -91,10 +91,13 @@ class WebSocketManager:
|
|
| 91 |
await self.send_message(event, websocket)
|
| 92 |
|
| 93 |
async def send_agent_complete(
|
| 94 |
-
self,
|
|
|
|
|
|
|
|
|
|
| 95 |
):
|
| 96 |
"""Send agent complete event"""
|
| 97 |
-
event = AgentCompleteEvent(traceMetadata=metadata)
|
| 98 |
await self.send_message(event, websocket)
|
| 99 |
|
| 100 |
async def send_agent_error(self, error: str, websocket: WebSocket):
|
|
|
|
| 1 |
import asyncio
|
| 2 |
import json
|
| 3 |
+
from typing import Dict, Literal, Set
|
| 4 |
|
| 5 |
from cua2_core.models.models import (
|
| 6 |
ActiveTask,
|
|
|
|
| 91 |
await self.send_message(event, websocket)
|
| 92 |
|
| 93 |
async def send_agent_complete(
|
| 94 |
+
self,
|
| 95 |
+
metadata: AgentTraceMetadata,
|
| 96 |
+
websocket: WebSocket,
|
| 97 |
+
final_state: Literal["success", "stopped", "max_steps_reached", "error"],
|
| 98 |
):
|
| 99 |
"""Send agent complete event"""
|
| 100 |
+
event = AgentCompleteEvent(traceMetadata=metadata, final_state=final_state)
|
| 101 |
await self.send_message(event, websocket)
|
| 102 |
|
| 103 |
async def send_agent_error(self, error: str, websocket: WebSocket):
|
cua2-front/src/types/agent.ts
CHANGED
|
@@ -59,6 +59,7 @@ interface AgentProgressEvent {
|
|
| 59 |
interface AgentCompleteEvent {
|
| 60 |
type: 'agent_complete';
|
| 61 |
traceMetadata: AgentTraceMetadata;
|
|
|
|
| 62 |
}
|
| 63 |
|
| 64 |
interface AgentErrorEvent {
|
|
@@ -95,3 +96,9 @@ export interface UserTaskMessage {
|
|
| 95 |
type: 'user_task';
|
| 96 |
trace: AgentTrace;
|
| 97 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
interface AgentCompleteEvent {
|
| 60 |
type: 'agent_complete';
|
| 61 |
traceMetadata: AgentTraceMetadata;
|
| 62 |
+
final_state: 'success' | 'stopped' | 'max_steps_reached' | 'error';
|
| 63 |
}
|
| 64 |
|
| 65 |
interface AgentErrorEvent {
|
|
|
|
| 96 |
type: 'user_task';
|
| 97 |
trace: AgentTrace;
|
| 98 |
}
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
export interface StopTaskMessage {
|
| 102 |
+
type: 'stop_task';
|
| 103 |
+
traceId: string;
|
| 104 |
+
}
|