A-Mahla committed on
Commit
e0d4a07
·
1 Parent(s): 97e46c6

ADD Backend V1 (#2)

Browse files

* ADD backend logic

* FIX pre-commit

* FIX pre-commit

* FIX pre-commit

* FIX Agent loop

* ADD pytest

* ADD pytest

* CHG github workflow

.github/workflows/pre-commit.yml CHANGED
@@ -31,4 +31,4 @@ jobs:
31
 
32
  - name: Run pre-commit
33
  run: |
34
- uv run pre-commit run --all-files --show-diff-on-failure
 
31
 
32
  - name: Run pre-commit
33
  run: |
34
+ make pre-commit
Makefile CHANGED
@@ -23,6 +23,14 @@ dev-frontend:
23
 
24
  pre-commit:
25
  uv run pre-commit run --all-files --show-diff-on-failure
 
 
 
 
 
 
 
 
26
 
27
  clean:
28
  find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
 
23
 
24
  pre-commit:
25
  uv run pre-commit run --all-files --show-diff-on-failure
26
+ $(MAKE) test
27
+
28
+ # Run tests
29
+ test:
30
+ cd cua2-core && uv run pytest tests/ -v
31
+
32
+ test-coverage:
33
+ cd cua2-core && uv run pytest tests/ -v --cov=cua2_core --cov-report=html --cov-report=term
34
 
35
  clean:
36
  find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
cua2-core/pyproject.toml CHANGED
@@ -33,9 +33,9 @@ dependencies = [
33
  "httpx>=0.27.1",
34
  "asyncio-mqtt==0.16.1",
35
  "aiofiles==23.2.1",
36
- "smolagents[openai,litellm]==1.15.0",
37
- "openai==1.91.0",
38
- "litellm[proxy]==1.63.14",
39
  ]
40
 
41
  [project.optional-dependencies]
 
33
  "httpx>=0.27.1",
34
  "asyncio-mqtt==0.16.1",
35
  "aiofiles==23.2.1",
36
+ "smolagents[openai,litellm]==1.22.0",
37
+ "openai==2.6.1",
38
+ "e2b-desktop==2.1.0",
39
  ]
40
 
41
  [project.optional-dependencies]
cua2-core/pytest.ini ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [pytest]
2
+ testpaths = tests
3
+ python_files = test_*.py
4
+ python_classes = Test*
5
+ python_functions = test_*
6
+ addopts =
7
+ -v
8
+ --strict-markers
9
+ --tb=short
10
+ --disable-warnings
11
+ markers =
12
+ unit: Unit tests
13
+ integration: Integration tests
cua2-core/src/cua2_core/app.py CHANGED
@@ -1,6 +1,7 @@
1
  from contextlib import asynccontextmanager
2
 
3
  from cua2_core.services.agent_service import AgentService
 
4
  from cua2_core.websocket.websocket_manager import WebSocketManager
5
  from dotenv import load_dotenv
6
  from fastapi import FastAPI
@@ -16,23 +17,23 @@ async def lifespan(app: FastAPI):
16
  # Startup: Initialize services
17
  print("Initializing services...")
18
 
19
- # Initialize WebSocket manager
20
  websocket_manager = WebSocketManager()
21
 
22
- # Initialize agent service with websocket manager dependency
23
- agent_service = AgentService(websocket_manager)
 
24
 
25
  # Store services in app state for access in routes
26
  app.state.websocket_manager = websocket_manager
 
27
  app.state.agent_service = agent_service
28
 
29
  print("Services initialized successfully")
30
 
31
  yield
32
 
33
- # Shutdown: Clean up resources
34
  print("Shutting down services...")
35
- # Add any cleanup logic here if needed
36
  print("Services shut down successfully")
37
 
38
 
 
1
  from contextlib import asynccontextmanager
2
 
3
  from cua2_core.services.agent_service import AgentService
4
+ from cua2_core.services.sandbox_service import SandboxService
5
  from cua2_core.websocket.websocket_manager import WebSocketManager
6
  from dotenv import load_dotenv
7
  from fastapi import FastAPI
 
17
  # Startup: Initialize services
18
  print("Initializing services...")
19
 
 
20
  websocket_manager = WebSocketManager()
21
 
22
+ sandbox_service = SandboxService()
23
+
24
+ agent_service = AgentService(websocket_manager, sandbox_service)
25
 
26
  # Store services in app state for access in routes
27
  app.state.websocket_manager = websocket_manager
28
+ app.state.sandbox_service = sandbox_service
29
  app.state.agent_service = agent_service
30
 
31
  print("Services initialized successfully")
32
 
33
  yield
34
 
 
35
  print("Shutting down services...")
36
+ await sandbox_service.cleanup_sandboxes()
37
  print("Services shut down successfully")
38
 
39
 
cua2-core/src/cua2_core/models/models.py CHANGED
@@ -1,70 +1,91 @@
1
  import json
2
  import os
 
3
  from datetime import datetime
4
  from typing import Annotated, Literal, Optional
5
 
6
- from pydantic import BaseModel, Field, field_serializer, model_validator
 
7
  from typing_extensions import TypeAlias
8
 
9
  #################### Backend -> Frontend ########################
10
 
11
 
12
- class AgentAction(BaseModel):
13
  """Agent action structure"""
14
 
15
- actionType: Literal[
16
- "click",
17
- "write",
18
- "press",
19
- "scroll",
20
- "wait",
21
- "open",
22
- "launch_app",
23
- "refresh",
24
- "go_back",
25
- ]
26
- actionArguments: dict
27
 
28
  def to_string(self) -> str:
29
  """Convert action to a human-readable string"""
30
- action_type = self.actionType
31
- args = self.actionArguments
32
 
33
  if action_type == "click":
34
- x = args.get("x", "?")
35
- y = args.get("y", "?")
36
  return f"Click at coordinates ({x}, {y})"
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  elif action_type == "write":
39
- text = args.get("text", "")
40
  return f"Type text: '{text}'"
41
 
42
  elif action_type == "press":
43
- key = args.get("key", "")
44
  return f"Press key: {key}"
45
 
 
 
 
 
 
 
 
 
 
 
46
  elif action_type == "scroll":
47
- direction = args.get("direction", "down")
48
- amount = args.get("amount", 2)
 
 
49
  return f"Scroll {direction} by {amount}"
50
 
51
  elif action_type == "wait":
52
- seconds = args.get("seconds", 0)
53
  return f"Wait for {seconds} seconds"
54
 
55
  elif action_type == "open":
56
- file_or_url = args.get("file_or_url", "")
57
- return f"Open: {file_or_url}"
58
-
59
- elif action_type == "launch_app":
60
- app_name = args.get("app_name", "")
61
- return f"Launch app: {app_name}"
62
 
63
- elif action_type == "refresh":
64
- return "Refresh the current page"
 
65
 
66
- elif action_type == "go_back":
67
- return "Go back one page"
68
 
69
 
70
  class AgentStep(BaseModel):
@@ -85,10 +106,10 @@ class AgentStep(BaseModel):
85
  def serialize_actions(self, actions: list[AgentAction], _info):
86
  """Convert actions to list of strings when dumping (controlled by context)"""
87
 
88
- if _info.context and _info.context.get("actions_as_json", False):
89
  return [action.model_dump(mode="json") for action in actions]
90
 
91
- return [action.to_string() for action in actions]
92
 
93
 
94
  class AgentTraceMetadata(BaseModel):
@@ -100,6 +121,7 @@ class AgentTraceMetadata(BaseModel):
100
  duration: float = 0.0 # in seconds
101
  numberOfSteps: int = 0
102
  maxSteps: int = 0
 
103
 
104
 
105
  class AgentTrace(BaseModel):
@@ -204,29 +226,54 @@ class ActiveTask(BaseModel):
204
 
205
  message_id: str
206
  instruction: str
207
- modelId: str
208
  timestamp: datetime = datetime.now()
209
  steps: list[AgentStep] = []
210
  traceMetadata: AgentTraceMetadata = AgentTraceMetadata()
 
211
 
212
  @property
213
  def trace_path(self):
214
  """Trace path"""
215
- return f"data/trace-{self.message_id}-{self.modelId}"
216
 
217
  @model_validator(mode="after")
218
  def store_model(self):
219
  """Validate model ID"""
220
- self.traceMetadata.traceId = self.message_id
221
- os.makedirs(self.trace_path, exist_ok=True)
222
- with open(f"{self.trace_path}/tasks.json", "w") as f:
223
- json.dump(
224
- self.model_dump(mode="json", context={"actions_as_json": True}),
225
- f,
226
- indent=2,
227
- )
228
-
229
- return self
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
 
232
  class HealthResponse(BaseModel):
@@ -249,3 +296,22 @@ class ActiveTasksResponse(BaseModel):
249
 
250
  active_tasks: dict[str, ActiveTask]
251
  total_connections: int
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
  import os
3
+ import threading
4
  from datetime import datetime
5
  from typing import Annotated, Literal, Optional
6
 
7
+ from cua2_core.services.agent_utils.function_parser import FunctionCall
8
+ from pydantic import BaseModel, Field, PrivateAttr, field_serializer, model_validator
9
  from typing_extensions import TypeAlias
10
 
11
  #################### Backend -> Frontend ########################
12
 
13
 
14
+ class AgentAction(FunctionCall):
15
  """Agent action structure"""
16
 
17
+ @classmethod
18
+ def from_function_calls(
19
+ cls, function_calls: list[FunctionCall]
20
+ ) -> list["AgentAction"]:
21
+ list_of_actions = [cls(**action.model_dump()) for action in function_calls]
22
+ for action in list_of_actions:
23
+ action.description = action.to_string()
24
+ return list_of_actions
 
 
 
 
25
 
26
  def to_string(self) -> str:
27
  """Convert action to a human-readable string"""
28
+ action_type = self.function_name
29
+ args = self.parameters
30
 
31
  if action_type == "click":
32
+ x = args.get("x") or args.get("arg_0")
33
+ y = args.get("y") or args.get("arg_1")
34
  return f"Click at coordinates ({x}, {y})"
35
 
36
+ if action_type == "right_click":
37
+ x = args.get("x") or args.get("arg_0")
38
+ y = args.get("y") or args.get("arg_1")
39
+ return f"Right click at coordinates ({x}, {y})"
40
+
41
+ if action_type == "double_click":
42
+ x = args.get("x") or args.get("arg_0")
43
+ y = args.get("y") or args.get("arg_1")
44
+ return f"Double click at coordinates ({x}, {y})"
45
+
46
+ if action_type == "move_mouse":
47
+ x = args.get("x") or args.get("arg_0")
48
+ y = args.get("y") or args.get("arg_1")
49
+ return f"Move mouse to coordinates ({x}, {y})"
50
+
51
  elif action_type == "write":
52
+ text = args.get("text") or args.get("arg_0")
53
  return f"Type text: '{text}'"
54
 
55
  elif action_type == "press":
56
+ key = args.get("key") or args.get("arg_0")
57
  return f"Press key: {key}"
58
 
59
+ elif action_type == "go_back":
60
+ return "Go back one page"
61
+
62
+ elif action_type == "drag":
63
+ x1 = args.get("x1") or args.get("arg_0")
64
+ y1 = args.get("y1") or args.get("arg_1")
65
+ x2 = args.get("x2") or args.get("arg_2")
66
+ y2 = args.get("y2") or args.get("arg_3")
67
+ return f"Drag from ({x1}, {y1}) to ({x2}, {y2})"
68
+
69
  elif action_type == "scroll":
70
+ x = args.get("x") or args.get("arg_0")
71
+ y = args.get("y") or args.get("arg_1")
72
+ direction = args.get("direction") or args.get("arg_2")
73
+ amount = args.get("amount") or args.get("arg_3") or 2
74
  return f"Scroll {direction} by {amount}"
75
 
76
  elif action_type == "wait":
77
+ seconds = args.get("seconds") or args.get("arg_0")
78
  return f"Wait for {seconds} seconds"
79
 
80
  elif action_type == "open":
81
+ url = args.get("url") or args.get("arg_0")
82
+ return f"Open: {url}"
 
 
 
 
83
 
84
+ elif action_type == "final_answer":
85
+ answer = args.get("answer") or args.get("arg_0")
86
+ return f"Final answer: {answer}"
87
 
88
+ return "Unknown action"
 
89
 
90
 
91
  class AgentStep(BaseModel):
 
106
  def serialize_actions(self, actions: list[AgentAction], _info):
107
  """Convert actions to list of strings when dumping (controlled by context)"""
108
 
109
+ if _info.context and _info.context.get("actions_as_json", True):
110
  return [action.model_dump(mode="json") for action in actions]
111
 
112
+ return [action.description for action in actions]
113
 
114
 
115
  class AgentTraceMetadata(BaseModel):
 
121
  duration: float = 0.0 # in seconds
122
  numberOfSteps: int = 0
123
  maxSteps: int = 0
124
+ completed: bool = False
125
 
126
 
127
  class AgentTrace(BaseModel):
 
226
 
227
  message_id: str
228
  instruction: str
229
+ model_id: str
230
  timestamp: datetime = datetime.now()
231
  steps: list[AgentStep] = []
232
  traceMetadata: AgentTraceMetadata = AgentTraceMetadata()
233
+ _file_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
234
 
235
  @property
236
  def trace_path(self):
237
  """Trace path"""
238
+ return f"data/trace-{self.message_id}-{self.model_id.replace('/', '-')}"
239
 
240
  @model_validator(mode="after")
241
  def store_model(self):
242
  """Validate model ID"""
243
+ with self._file_lock:
244
+ os.makedirs(self.trace_path, exist_ok=True)
245
+ with open(f"{self.trace_path}/tasks.json", "w") as f:
246
+ json.dump(
247
+ self.model_dump(
248
+ mode="json",
249
+ exclude={"_file_lock"},
250
+ context={"actions_as_json": True},
251
+ ),
252
+ f,
253
+ indent=2,
254
+ )
255
+ return self
256
+ def update_step(self, step: AgentStep):
257
+ """Update step"""
258
+ with self._file_lock:
259
+ if int(step.stepId) <= len(self.steps):
260
+ self.steps[int(step.stepId) - 1] = step
261
+ else:
262
+ self.steps.append(step)
263
+ self.traceMetadata.numberOfSteps = len(self.steps)
264
+ with open(f"{self.trace_path}/tasks.json", "w") as f:
265
+ json.dump(
266
+ self.model_dump(
267
+ mode="json",
268
+ exclude={"_file_lock"},
269
+ context={"actions_as_json": True},
270
+ ),
271
+ f,
272
+ indent=2,
273
+ )
274
+
275
+
276
+ #################### API Routes Models ########################
277
 
278
 
279
  class HealthResponse(BaseModel):
 
296
 
297
  active_tasks: dict[str, ActiveTask]
298
  total_connections: int
299
+
300
+
301
+ class UpdateStepRequest(BaseModel):
302
+ """Request model for updating a step"""
303
+
304
+ step_evaluation: Literal["like", "dislike", "neutral"]
305
+
306
+
307
+ class UpdateStepResponse(BaseModel):
308
+ """Response model for step update"""
309
+
310
+ success: bool
311
+ message: str
312
+
313
+
314
+ class AvailableModelsResponse(BaseModel):
315
+ """Response for available models"""
316
+
317
+ models: list[str]
cua2-core/src/cua2_core/routes/routes.py CHANGED
@@ -2,11 +2,13 @@ from datetime import datetime
2
 
3
  # Get services from app state
4
  from cua2_core.models.models import (
5
- ActiveTasksResponse,
6
  HealthResponse,
7
- TaskStatusResponse,
 
8
  )
9
  from cua2_core.services.agent_service import AgentService
 
10
  from cua2_core.websocket.websocket_manager import WebSocketManager
11
  from fastapi import APIRouter, Depends, HTTPException, Request
12
 
@@ -36,24 +38,31 @@ async def health_check(
36
  )
37
 
38
 
39
- @router.get("/tasks", response_model=ActiveTasksResponse)
40
- async def get_active_tasks(
41
- agent_service: AgentService = Depends(get_agent_service),
42
- websocket_manager: WebSocketManager = Depends(get_websocket_manager),
43
- ):
44
- """Get currently active tasks"""
45
- return ActiveTasksResponse(
46
- active_tasks=agent_service.get_active_tasks(),
47
- total_connections=websocket_manager.get_connection_count(),
48
- )
49
 
50
 
51
- @router.get("/tasks/{task_id}", response_model=TaskStatusResponse)
52
- async def get_task_status(
53
- task_id: str, agent_service: AgentService = Depends(get_agent_service)
 
 
 
54
  ):
55
- """Get status of a specific task"""
56
- task_status = agent_service.get_task_status(task_id)
57
- if task_status is None:
58
- raise HTTPException(status_code=404, detail="Task not found")
59
- return TaskStatusResponse(task_id=task_id, status=task_status)
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  # Get services from app state
4
  from cua2_core.models.models import (
5
+ AvailableModelsResponse,
6
  HealthResponse,
7
+ UpdateStepRequest,
8
+ UpdateStepResponse,
9
  )
10
  from cua2_core.services.agent_service import AgentService
11
+ from cua2_core.services.agent_utils.get_model import AVAILABLE_MODELS
12
  from cua2_core.websocket.websocket_manager import WebSocketManager
13
  from fastapi import APIRouter, Depends, HTTPException, Request
14
 
 
38
  )
39
 
40
 
41
+ @router.get("/models", response_model=AvailableModelsResponse)
42
+ async def get_available_models():
43
+ """Get list of all available model IDs"""
44
+ return AvailableModelsResponse(models=AVAILABLE_MODELS)
 
 
 
 
 
 
45
 
46
 
47
+ @router.patch("/traces/{trace_id}/steps/{step_id}", response_model=UpdateStepResponse)
48
+ async def update_trace_step(
49
+ trace_id: str,
50
+ step_id: str,
51
+ request: UpdateStepRequest,
52
+ agent_service: AgentService = Depends(get_agent_service),
53
  ):
54
+ """Update a specific step in a trace (e.g., update step evaluation)"""
55
+ try:
56
+ agent_service.update_trace_step(
57
+ trace_id=trace_id,
58
+ step_id=step_id,
59
+ step_evaluation=request.step_evaluation,
60
+ )
61
+ return UpdateStepResponse(
62
+ success=True,
63
+ message="Step updated successfully",
64
+ )
65
+ except ValueError as e:
66
+ raise HTTPException(status_code=400, detail=str(e))
67
+ except FileNotFoundError as e:
68
+ raise HTTPException(status_code=404, detail=str(e))
cua2-core/src/cua2_core/routes/websocket.py CHANGED
@@ -3,6 +3,8 @@ import json
3
  # Get services from app state
4
  from cua2_core.app import app
5
  from cua2_core.models.models import AgentTrace, HeartbeatEvent
 
 
6
  from fastapi import APIRouter, WebSocket, WebSocketDisconnect
7
 
8
  # Create router
@@ -13,15 +15,15 @@ router = APIRouter()
13
  async def websocket_endpoint(websocket: WebSocket):
14
  """WebSocket endpoint for real-time communication"""
15
 
16
- websocket_manager = app.state.websocket_manager
17
- agent_service = app.state.agent_service
18
 
19
  await websocket_manager.connect(websocket)
20
 
21
  try:
22
  # Send welcome heartbeat
23
  welcome_message = HeartbeatEvent(type="heartbeat")
24
- await websocket_manager.send_personal_message(welcome_message, websocket)
25
 
26
  # Keep the connection alive and wait for messages
27
  while True:
@@ -50,7 +52,9 @@ async def websocket_endpoint(websocket: WebSocket):
50
  trace = AgentTrace(**trace_data)
51
 
52
  # Process the user task with the trace
53
- trace_id = await agent_service.process_user_task(trace)
 
 
54
  print(f"Started processing trace: {trace_id}")
55
  else:
56
  print("No trace data in message")
@@ -62,9 +66,7 @@ async def websocket_endpoint(websocket: WebSocket):
62
  error_response = AgentErrorEvent(
63
  type="agent_error", error="Invalid JSON format"
64
  )
65
- await websocket_manager.send_personal_message(
66
- error_response, websocket
67
- )
68
 
69
  except Exception as e:
70
  print(f"Error processing message: {e}")
@@ -76,9 +78,7 @@ async def websocket_endpoint(websocket: WebSocket):
76
  error_response = AgentErrorEvent(
77
  type="agent_error", error=f"Error processing message: {str(e)}"
78
  )
79
- await websocket_manager.send_personal_message(
80
- error_response, websocket
81
- )
82
 
83
  except Exception as e:
84
  print(f"Error receiving WebSocket message: {e}")
 
3
  # Get services from app state
4
  from cua2_core.app import app
5
  from cua2_core.models.models import AgentTrace, HeartbeatEvent
6
+ from cua2_core.services.agent_service import AgentService
7
+ from cua2_core.websocket.websocket_manager import WebSocketManager
8
  from fastapi import APIRouter, WebSocket, WebSocketDisconnect
9
 
10
  # Create router
 
15
  async def websocket_endpoint(websocket: WebSocket):
16
  """WebSocket endpoint for real-time communication"""
17
 
18
+ websocket_manager: WebSocketManager = app.state.websocket_manager
19
+ agent_service: AgentService = app.state.agent_service
20
 
21
  await websocket_manager.connect(websocket)
22
 
23
  try:
24
  # Send welcome heartbeat
25
  welcome_message = HeartbeatEvent(type="heartbeat")
26
+ await websocket_manager.send_message(welcome_message, websocket)
27
 
28
  # Keep the connection alive and wait for messages
29
  while True:
 
52
  trace = AgentTrace(**trace_data)
53
 
54
  # Process the user task with the trace
55
+ trace_id = await agent_service.process_user_task(
56
+ trace, websocket
57
+ )
58
  print(f"Started processing trace: {trace_id}")
59
  else:
60
  print("No trace data in message")
 
66
  error_response = AgentErrorEvent(
67
  type="agent_error", error="Invalid JSON format"
68
  )
69
+ await websocket_manager.send_message(error_response, websocket)
 
 
70
 
71
  except Exception as e:
72
  print(f"Error processing message: {e}")
 
78
  error_response = AgentErrorEvent(
79
  type="agent_error", error=f"Error processing message: {str(e)}"
80
  )
81
+ await websocket_manager.send_message(error_response, websocket)
 
 
82
 
83
  except Exception as e:
84
  print(f"Error receiving WebSocket message: {e}")
cua2-core/src/cua2_core/services/agent_service.py CHANGED
@@ -1,39 +1,45 @@
1
  import asyncio
2
  import base64
3
  import json
4
- from pathlib import Path
5
- from typing import Optional
 
 
 
6
 
7
  from cua2_core.models.models import (
8
  ActiveTask,
9
  AgentAction,
10
- AgentCompleteEvent,
11
- AgentErrorEvent,
12
- AgentProgressEvent,
13
- AgentStartEvent,
14
  AgentStep,
15
  AgentTrace,
16
  AgentTraceMetadata,
17
- VncUrlSetEvent,
18
- VncUrlUnsetEvent,
19
  )
20
- from cua2_core.websocket.websocket_manager import WebSocketManager
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  class AgentService:
24
  """Service for handling agent tasks and processing"""
25
 
26
- def __init__(self, websocket_manager):
 
 
27
  self.active_tasks: dict[str, ActiveTask] = {}
28
  self.websocket_manager: WebSocketManager = websocket_manager
29
- self.simulation_data_path = (
30
- Path(__file__).parent / "simulation_metadata" / "simulated_trace.json"
31
- )
32
- self.simulation_images_path = (
33
- Path(__file__).parent / "simulation_metadata" / "images"
34
- )
35
 
36
- async def process_user_task(self, trace: AgentTrace) -> str:
37
  """Process a user task and return the trace ID"""
38
 
39
  trace_id = trace.id
@@ -44,123 +50,326 @@ class AgentService:
44
  self.active_tasks[trace_id] = ActiveTask(
45
  message_id=trace_id,
46
  instruction=trace.instruction,
47
- modelId=trace.modelId,
48
  timestamp=trace.timestamp,
49
  steps=trace.steps,
50
  traceMetadata=trace.traceMetadata,
51
  )
52
 
53
- # Start the agent processing in the background
54
- asyncio.create_task(self._simulate_agent_processing(trace))
 
 
55
 
56
  return trace_id
57
 
58
- async def _simulate_agent_processing(self, trace: AgentTrace):
59
- """Simulate agent processing using simulated_trace.json data"""
60
- trace_id = trace.id
 
 
 
 
 
 
 
 
61
 
62
  try:
63
- # Load simulation data
64
- with open(self.simulation_data_path, "r") as f:
65
- simulation_data = json.load(f)
 
 
 
66
 
67
- # Send agent start event with the initial trace
68
- start_event = AgentStartEvent(type="agent_start", agentTrace=trace)
69
- await self.websocket_manager.broadcast(start_event)
70
 
71
- # mock VNC URL
72
- vnc_url = "https://www.youtube.com/embed/VCutEsRSJ5A?si=PT0ETJ7zIJ9ywhGW"
73
- vnc_set_event = VncUrlSetEvent(type="vnc_url_set", vncUrl=vnc_url)
74
- await self.websocket_manager.broadcast(vnc_set_event)
75
 
76
- trace_metadata = AgentTraceMetadata(traceId=trace_id, maxSteps=20)
 
77
 
78
- # Process each step from the simulation data
79
- for step_data in simulation_data["steps"]:
80
- # Wait before sending the next step to simulate processing time
81
- await asyncio.sleep(step_data["duration"])
 
 
82
 
83
- # Load and encode the image
84
- image_path = (
85
- self.simulation_images_path / step_data["image"].split("/")[-1]
86
- )
87
- with open(image_path, "rb") as img_file:
88
- image_bytes = img_file.read()
89
- image_base64 = f"data:image/png;base64,{base64.b64encode(image_bytes).decode('utf-8')}"
90
-
91
- # Convert actions to AgentAction objects
92
- actions = [
93
- AgentAction(
94
- actionType=action["actionType"],
95
- actionArguments=action["actionArguments"],
96
- )
97
- for action in step_data["actions"]
98
- ]
99
 
100
- # Create agent step
101
- agent_step = AgentStep(
102
- traceId=trace_id,
103
- stepId=step_data["stepId"],
104
- image=image_base64,
105
- thought=step_data["thought"],
106
- actions=actions,
107
- error="",
108
- duration=step_data["duration"],
109
- inputTokensUsed=step_data["inputTokensUsed"],
110
- outputTokensUsed=step_data["outputTokensUsed"],
111
- step_evaluation=step_data["step_evaluation"],
112
  )
 
 
 
 
113
 
114
- trace_metadata.numberOfSteps += 1
115
- trace_metadata.duration += step_data["duration"]
116
- trace_metadata.inputTokensUsed += step_data["inputTokensUsed"]
117
- trace_metadata.outputTokensUsed += step_data["outputTokensUsed"]
 
118
 
119
- # Send progress event
120
- progress_event = AgentProgressEvent(
121
- type="agent_progress",
122
- agentStep=agent_step,
123
- traceMetadata=trace_metadata,
124
- )
125
- await self.websocket_manager.broadcast(progress_event)
 
126
 
127
- # Update active task
128
- self.active_tasks[trace_id].steps.append(agent_step)
 
129
 
130
- # Unset VNC URL before completion
131
- vnc_unset_event = VncUrlUnsetEvent(type="vnc_url_unset")
132
- await self.websocket_manager.broadcast(vnc_unset_event)
133
 
 
 
 
 
 
 
 
 
134
  # Send completion event
135
- complete_event = AgentCompleteEvent(
136
- type="agent_complete", traceMetadata=trace_metadata
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  )
138
- await self.websocket_manager.broadcast(complete_event)
 
 
 
 
 
 
 
 
139
 
140
- # Update active task with final metadata
141
- self.active_tasks[trace_id].traceMetadata = trace_metadata
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
- # Clean up after a delay
144
- await asyncio.sleep(1)
145
- if trace_id in self.active_tasks:
146
- del self.active_tasks[trace_id]
147
 
148
- except Exception as e:
149
- print(f"Error in agent simulation: {str(e)}")
150
- # Send error event
151
- error_event = AgentErrorEvent(
152
- type="agent_error", error=f"Error processing task: {str(e)}"
 
 
 
153
  )
154
- await self.websocket_manager.broadcast(error_event)
155
 
156
- # Clean up
157
- if trace_id in self.active_tasks:
158
- del self.active_tasks[trace_id]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
- def get_active_tasks(self) -> dict:
161
- """Get currently active tasks"""
162
- return self.active_tasks.copy()
163
 
164
- def get_task_status(self, message_id: str) -> Optional[dict]:
165
- """Get status of a specific task"""
166
- return self.active_tasks.get(message_id)
 
 
 
 
 
1
  import asyncio
2
  import base64
3
  import json
4
+ import logging
5
+ import os
6
+ import time
7
+ from io import BytesIO
8
+ from typing import Callable, Literal
9
 
10
  from cua2_core.models.models import (
11
  ActiveTask,
12
  AgentAction,
 
 
 
 
13
  AgentStep,
14
  AgentTrace,
15
  AgentTraceMetadata,
 
 
16
  )
17
+ from cua2_core.services.agent_utils.desktop_agent import E2BVisionAgent
18
+ from cua2_core.services.agent_utils.function_parser import parse_function_call
19
+ from cua2_core.services.agent_utils.get_model import get_model
20
+ from cua2_core.services.sandbox_service import SandboxService
21
+ from cua2_core.websocket.websocket_manager import WebSocketException, WebSocketManager
22
+ from e2b_desktop import Sandbox
23
+ from fastapi import WebSocket
24
+ from PIL import Image
25
+ from smolagents import ActionStep, AgentImage, AgentMaxStepsError, TaskStep
26
+
27
+ logger = logging.getLogger(__name__)
28
 
29
 
30
  class AgentService:
31
  """Service for handling agent tasks and processing"""
32
 
33
+ def __init__(
34
+ self, websocket_manager: WebSocketManager, sandbox_service: SandboxService
35
+ ):
36
  self.active_tasks: dict[str, ActiveTask] = {}
37
  self.websocket_manager: WebSocketManager = websocket_manager
38
+ self.task_websockets: dict[str, WebSocket] = {}
39
+ self.sandbox_service: SandboxService = sandbox_service
40
+ self.last_screenshot: dict[str, AgentImage] = {}
 
 
 
41
 
42
+ async def process_user_task(self, trace: AgentTrace, websocket: WebSocket) -> str:
43
  """Process a user task and return the trace ID"""
44
 
45
  trace_id = trace.id
 
50
  self.active_tasks[trace_id] = ActiveTask(
51
  message_id=trace_id,
52
  instruction=trace.instruction,
53
+ model_id=trace.modelId,
54
  timestamp=trace.timestamp,
55
  steps=trace.steps,
56
  traceMetadata=trace.traceMetadata,
57
  )
58
 
59
+ # Store the websocket for this task
60
+ self.task_websockets[trace_id] = websocket
61
+
62
+ asyncio.create_task(self._agent_processing(trace_id))
63
 
64
  return trace_id
65
 
66
+ async def _agent_runner(
67
+ self,
68
+ message_id: str,
69
+ step_callback: Callable[[ActionStep, E2BVisionAgent], None],
70
+ ):
71
+ """Run the task with the appropriate agent"""
72
+
73
+ sandbox: Sandbox | None = None
74
+ agent = None
75
+ novnc_active = False
76
+ websocket_exception = False
77
 
78
  try:
79
+ # Get the websocket for this task
80
+ websocket = self.task_websockets.get(message_id)
81
+
82
+ await self.websocket_manager.send_agent_start(
83
+ active_task=self.active_tasks[message_id], websocket=websocket
84
+ )
85
 
86
+ model = get_model(self.active_tasks[message_id].model_id)
 
 
87
 
88
+ # Acquire a sandbox from the pool
89
+ sandbox = await self.sandbox_service.acquire_sandbox(message_id)
90
+ if sandbox is None:
91
+ raise Exception("No sandbox available: pool limit reached")
92
 
93
+ data_dir = self.active_tasks[message_id].trace_path
94
+ user_content = self.active_tasks[message_id].instruction
95
 
96
+ agent = E2BVisionAgent(
97
+ model=model,
98
+ data_dir=data_dir,
99
+ desktop=sandbox,
100
+ step_callbacks=[step_callback],
101
+ )
102
 
103
+ self.active_tasks[message_id].traceMetadata.maxSteps = agent.max_steps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
+ await self.websocket_manager.send_vnc_url_set(
106
+ vnc_url=sandbox.stream.get_url(
107
+ auto_connect=True,
108
+ view_only=True,
109
+ resize="scale",
110
+ auth_key=sandbox.stream.get_auth_key(),
 
 
 
 
 
 
111
  )
112
+ or "",
113
+ websocket=websocket,
114
+ )
115
+ novnc_active = True
116
 
117
+ step_filename = f"{message_id}-1"
118
+ screenshot_bytes = agent.desktop.screenshot()
119
+ image = Image.open(BytesIO(screenshot_bytes))
120
+ screenshot_path = os.path.join(agent.data_dir, f"{step_filename}.png")
121
+ image.save(screenshot_path)
122
 
123
+ self.last_screenshot[message_id] = image
124
+
125
+ await asyncio.to_thread(
126
+ agent.run,
127
+ user_content,
128
+ )
129
+
130
+ self.active_tasks[message_id].traceMetadata.completed = True
131
 
132
+ except WebSocketException:
133
+ websocket_exception = True
134
+ pass
135
 
136
+ except (Exception, KeyboardInterrupt):
137
+ import traceback
 
138
 
139
+ logger.error(
140
+ f"Error processing task: {traceback.format_exc()}", exc_info=True
141
+ )
142
+ await self.websocket_manager.send_agent_error(
143
+ error="Error processing task", websocket=websocket
144
+ )
145
+
146
+ finally:
147
  # Send completion event
148
+ if not websocket_exception:
149
+ await self.websocket_manager.send_agent_complete(
150
+ metadata=self.active_tasks[message_id].traceMetadata,
151
+ websocket=websocket,
152
+ )
153
+
154
+ if novnc_active:
155
+ await self.websocket_manager.send_vnc_url_unset(websocket=websocket)
156
+
157
+ novnc_active = False
158
+
159
+ # Clean up
160
+ if message_id in self.active_tasks:
161
+ self.active_tasks[message_id].store_model()
162
+ del self.active_tasks[message_id]
163
+
164
+ # Clean up websocket reference
165
+ if message_id in self.task_websockets:
166
+ del self.task_websockets[message_id]
167
+
168
+ if message_id in self.last_screenshot:
169
+ del self.last_screenshot[message_id]
170
+
171
+ # Release sandbox back to the pool
172
+ if sandbox:
173
+ await self.sandbox_service.release_sandbox(sandbox)
174
+
175
async def _agent_processing(
    self,
    message_id: str,
):
    """Process the user task with the appropriate agent.

    Builds a per-step callback that converts each agent memory step into an
    ``AgentStep``, accumulates token/duration/step counters on the active
    task's ``traceMetadata``, sends the progress over the task's websocket and
    caches a fresh desktop screenshot for the next step. The callback is then
    handed to ``self._agent_runner``.

    Args:
        message_id: Key of the task in ``self.active_tasks``; also used as the
            ``traceId`` of every emitted step.

    Raises:
        KeyError: If ``message_id`` is not present in ``self.active_tasks``.
    """
    active_task = self.active_tasks[message_id]

    # Ensure the trace directory exists
    os.makedirs(active_task.trace_path, exist_ok=True)

    # Capture the event loop while still in the async context: the callback
    # blocks on ``future.result()``, so it must hand coroutines to this loop
    # with ``run_coroutine_threadsafe`` instead of awaiting them.
    loop = asyncio.get_running_loop()

    def step_callback(memory_step: ActionStep, agent: E2BVisionAgent):
        """Publish one finished agent step and refresh the cached screenshot."""
        assert memory_step.step_number is not None

        # Fixed pause kept from the original implementation
        # (presumably lets the desktop settle after the action — TODO confirm).
        time.sleep(3)

        if message_id in self.last_screenshot:
            memory_step.observations_images = [
                self.last_screenshot[message_id].copy()
            ]
        else:
            # No cached screenshot for this task. The previous code indexed
            # ``self.last_screenshot[message_id]`` here, which can only raise
            # KeyError in this branch; capture the desktop directly instead.
            image = Image.open(BytesIO(agent.desktop.screenshot()))

            # Remove previous screenshots from logs for lean processing
            for previous_memory_step in agent.memory.steps:
                if (
                    isinstance(previous_memory_step, ActionStep)
                    and previous_memory_step.step_number is not None
                    and previous_memory_step.step_number
                    <= memory_step.step_number - 1
                ):
                    previous_memory_step.observations_images = None
                elif isinstance(previous_memory_step, TaskStep):
                    previous_memory_step.task_images = None

            memory_step.observations_images = [image.copy()]

        model_output = (
            memory_step.model_output_message.content
            if memory_step.model_output_message
            else None
        )
        max_steps_reached = isinstance(memory_step.error, AgentMaxStepsError)
        if model_output is None and max_steps_reached:
            model_output = memory_step.action_output

        # The model output is "<thought>```<code>```..."; split it once.
        thought = None
        action_sequence = None
        if model_output:
            fenced_parts = model_output.split("```")
            if memory_step.error is None or max_steps_reached:
                thought = fenced_parts[0].replace("\nAction:\n", "")
            # Guard the index: an output without a code fence has one part only.
            if memory_step.error is None and len(fenced_parts) > 1:
                action_sequence = fenced_parts[1]

        if memory_step.observations_images:
            image = memory_step.observations_images[0]
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            image_base64 = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode('utf-8')}"
            del buffered
            del image
        else:
            image_base64 = None

        step = AgentStep(
            traceId=message_id,
            stepId=str(memory_step.step_number),
            image=image_base64,
            thought=thought,
            actions=AgentAction.from_function_calls(
                parse_function_call(action_sequence)
            )
            if action_sequence
            else None,
            error=memory_step.error.message if memory_step.error else None,
            duration=memory_step.timing.duration,
            inputTokensUsed=memory_step.token_usage.input_tokens,
            outputTokensUsed=memory_step.token_usage.output_tokens,
            step_evaluation="neutral",
        )

        # Accumulate per-trace counters on the active task
        task = self.active_tasks[message_id]
        task.traceMetadata.inputTokensUsed += memory_step.token_usage.input_tokens
        task.traceMetadata.outputTokensUsed += memory_step.token_usage.output_tokens
        task.traceMetadata.numberOfSteps += 1
        task.traceMetadata.duration += memory_step.timing.duration

        # Add step to active task
        task.update_step(step)

        websocket = self.task_websockets.get(message_id)
        future = asyncio.run_coroutine_threadsafe(
            self.websocket_manager.send_agent_progress(
                step=step,
                metadata=task.traceMetadata,
                websocket=websocket,
            ),
            loop,
        )
        # Wait for the send so steps reach the client in order
        future.result()

        # Capture the desktop after the step, save it in the agent's data
        # directory and cache it for the next callback invocation.
        step_filename = f"{message_id}-{memory_step.step_number}"
        screenshot_bytes = agent.desktop.screenshot()
        image = Image.open(BytesIO(screenshot_bytes))
        screenshot_path = os.path.join(agent.data_dir, f"{step_filename}.png")
        image.save(screenshot_path)
        # Plain assignment replaces any previous entry; the former ``del``
        # raised KeyError whenever no screenshot was cached yet.
        self.last_screenshot[message_id] = image

    await self._agent_runner(message_id, step_callback)
303
+
304
def update_trace_step(
    self,
    trace_id: str,
    step_id: str,
    step_evaluation: Literal["like", "dislike", "neutral"],
) -> "AgentStep":
    """
    Update a specific step in a trace (e.g., update step evaluation)

    The trace is looked up in ``self.active_tasks`` first; when it is not
    active, the ``tasks.json`` of the first ``data/trace-<trace_id>*``
    directory is updated in place.

    Args:
        trace_id: The trace ID
        step_id: The step ID (1-indexed)
        step_evaluation: The evaluation value to set

    Returns:
        The updated AgentStep

    Raises:
        ValueError: If step_id is invalid or step not found
        FileNotFoundError: If trace not found
    """
    # Try to find in active tasks first
    active_task = self.active_tasks.get(trace_id)

    if active_task:
        # Task is still active
        try:
            step_index = int(step_id) - 1
            if 0 <= step_index < len(active_task.steps):
                step = active_task.steps[step_index]
                step.step_evaluation = step_evaluation
                active_task.update_step(step)
                # Return the step, as documented (the original fell through
                # and returned None for active tasks).
                return step
            raise ValueError(f"Step {step_id} not found in trace")
        except (ValueError, TypeError) as e:
            raise ValueError(f"Invalid step_id format: {e}") from e

    # Task is not active, try to load from file
    data_dir = "data"
    trace_dirs = [
        d for d in os.listdir(data_dir) if d.startswith(f"trace-{trace_id}")
    ]

    if not trace_dirs:
        raise FileNotFoundError("Trace not found")

    trace_path = os.path.join(data_dir, trace_dirs[0])
    tasks_file = os.path.join(trace_path, "tasks.json")

    if not os.path.exists(tasks_file):
        raise FileNotFoundError("Trace data not found")

    try:
        # Load the trace data
        with open(tasks_file, "r") as f:
            task_data = json.load(f)

        # Find and update the step
        step_index = int(step_id) - 1
        if not 0 <= step_index < len(task_data["steps"]):
            raise ValueError(f"Step {step_id} not found in trace")

        task_data["steps"][step_index]["step_evaluation"] = step_evaluation

        # Save the updated data
        with open(tasks_file, "w") as f:
            json.dump(task_data, f, indent=2)

        # Convert to AgentStep for response
        return AgentStep(**task_data["steps"][step_index])
    except (ValueError, KeyError, TypeError) as e:
        raise ValueError(f"Error processing step update: {e}") from e
cua2-core/src/cua2_core/services/agent_utils/desktop_agent.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import unicodedata
4
+
5
+ from cua2_core.services.agent_utils.prompt import E2B_SYSTEM_PROMPT_TEMPLATE
6
+
7
+ # E2B imports
8
+ from e2b_desktop import Sandbox
9
+
10
+ # smolagents imports
11
+ from smolagents import CodeAgent, Model, tool
12
+ from smolagents.monitoring import LogLevel
13
+
14
+
15
class E2BVisionAgent(CodeAgent):
    """Code agent that drives an E2B desktop sandbox through mouse/keyboard tools."""

    def __init__(
        self,
        model: Model,
        data_dir: str,
        desktop: Sandbox,
        max_steps: int = 200,
        verbosity_level: LogLevel = 2,
        planning_interval: int | None = None,
        use_v1_prompt: bool = False,
        **kwargs,
    ):
        """Bind the sandbox, configure the base agent and register desktop tools.

        Args:
            model: Model forwarded to the base ``CodeAgent``.
            data_dir: Directory created here (if missing) for screenshots/steps.
            desktop: Sandbox whose screen size is queried and which the tools drive.
            max_steps: Forwarded to the base agent.
            verbosity_level: Forwarded to the base agent.
            planning_interval: Forwarded to the base agent.
            use_v1_prompt: Stored on the instance; not otherwise used in this class.
            **kwargs: Extra keyword arguments forwarded to the base agent.
        """
        self.desktop = desktop
        self.data_dir = data_dir
        self.planning_interval = planning_interval

        # Query the sandbox once for the screen resolution
        self.width, self.height = self.desktop.get_screen_size()
        print(f"Screen size: {self.width}x{self.height}")

        # Make sure the output directory is there
        os.makedirs(self.data_dir, exist_ok=True)
        print(f"Screenshots and steps will be saved to: {self.data_dir}")

        self.use_v1_prompt = use_v1_prompt

        # The base agent starts without tools; they are attached further down
        super().__init__(
            tools=[],
            model=model,
            max_steps=max_steps,
            verbosity_level=verbosity_level,
            planning_interval=self.planning_interval,
            stream_outputs=True,
            **kwargs,
        )

        # Fill the resolution placeholders of the system prompt template
        system_prompt = E2B_SYSTEM_PROMPT_TEMPLATE
        for placeholder, pixels in (
            ("<<resolution_x>>", self.width),
            ("<<resolution_y>>", self.height),
        ):
            system_prompt = system_prompt.replace(placeholder, str(pixels))
        self.prompt_templates["system_prompt"] = system_prompt

        # Expose the screen size through the agent state
        self.state["screen_width"] = self.width
        self.state["screen_height"] = self.height

        self.logger.log("Setting up agent tools...")
        self._setup_desktop_tools()

    def _setup_desktop_tools(self):
        """Define the desktop tools as closures over ``self`` and register them.

        NOTE: the docstrings of the nested tool functions are consumed by the
        ``@tool`` decorator, so they are part of the agent's behaviour and are
        kept word for word.
        """

        @tool
        def click(x: int, y: int) -> str:
            """
            Performs a left-click at the specified coordinates
            Args:
                x: The x coordinate (horizontal position)
                y: The y coordinate (vertical position)
            """
            self.desktop.move_mouse(x, y)
            self.desktop.left_click()
            self.click_coordinates = [x, y]
            outcome = f"Clicked at coordinates ({x}, {y})"
            self.logger.log(outcome)
            return outcome

        @tool
        def right_click(x: int, y: int) -> str:
            """
            Performs a right-click at the specified coordinates
            Args:
                x: The x coordinate (horizontal position)
                y: The y coordinate (vertical position)
            """
            self.desktop.move_mouse(x, y)
            self.desktop.right_click()
            self.click_coordinates = [x, y]
            outcome = f"Right-clicked at coordinates ({x}, {y})"
            self.logger.log(outcome)
            return outcome

        @tool
        def double_click(x: int, y: int) -> str:
            """
            Performs a double-click at the specified coordinates
            Args:
                x: The x coordinate (horizontal position)
                y: The y coordinate (vertical position)
            """
            self.desktop.move_mouse(x, y)
            self.desktop.double_click()
            self.click_coordinates = [x, y]
            outcome = f"Double-clicked at coordinates ({x}, {y})"
            self.logger.log(outcome)
            return outcome

        @tool
        def move_mouse(x: int, y: int) -> str:
            """
            Moves the mouse cursor to the specified coordinates
            Args:
                x: The x coordinate (horizontal position)
                y: The y coordinate (vertical position)
            """
            self.desktop.move_mouse(x, y)
            outcome = f"Moved mouse to coordinates ({x}, {y})"
            self.logger.log(outcome)
            return outcome

        def strip_combining_marks(raw):
            # Decompose (NFD) and drop combining characters, e.g. "é" -> "e"
            decomposed = unicodedata.normalize("NFD", raw)
            return "".join(ch for ch in decomposed if not unicodedata.combining(ch))

        @tool
        def write(text: str) -> str:
            """
            Types the specified text at the current cursor position.
            Args:
                text: The text to type
            """
            clean_text = strip_combining_marks(text)
            self.desktop.write(clean_text, delay_in_ms=75)
            outcome = f"Typed text: '{clean_text}'"
            self.logger.log(outcome)
            return outcome

        @tool
        def press(key: str) -> str:
            """
            Presses a keyboard key
            Args:
                key: The key to press (e.g. "enter", "space", "backspace", etc.).
            """
            self.desktop.press(key)
            outcome = f"Pressed key: {key}"
            self.logger.log(outcome)
            return outcome

        @tool
        def go_back() -> str:
            """
            Goes back to the previous page in the browser. If using this tool doesn't work, just click the button directly.
            Args:
            """
            self.desktop.press(["alt", "left"])
            outcome = "Went back one page"
            self.logger.log(outcome)
            return outcome

        @tool
        def drag(x1: int, y1: int, x2: int, y2: int) -> str:
            """
            Clicks [x1, y1], drags mouse to [x2, y2], then release click.
            Args:
                x1: origin x coordinate
                y1: origin y coordinate
                x2: end x coordinate
                y2: end y coordinate
            """
            origin, destination = [x1, y1], [x2, y2]
            self.desktop.drag(origin, destination)
            outcome = f"Dragged and dropped from [{x1}, {y1}] to [{x2}, {y2}]"
            self.logger.log(outcome)
            return outcome

        @tool
        def scroll(x: int, y: int, direction: str = "down", amount: int = 2) -> str:
            """
            Moves the mouse to selected coordinates, then uses the scroll button: this could scroll the page or zoom, depending on the app. DO NOT use scroll to move through linux desktop menus.
            Args:
                x: The x coordinate (horizontal position) of the element to scroll/zoom
                y: The y coordinate (vertical position) of the element to scroll/zoom
                direction: The direction to scroll ("up" or "down"), defaults to "down". For zoom, "up" zooms in, "down" zooms out.
                amount: The amount to scroll. A good amount is 1 or 2.
            """
            self.desktop.move_mouse(x, y)
            self.desktop.scroll(direction=direction, amount=amount)
            outcome = f"Scrolled {direction} by {amount}"
            self.logger.log(outcome)
            return outcome

        @tool
        def wait(seconds: float) -> str:
            """
            Waits for the specified number of seconds. Very useful in case the prior order is still executing (for example starting very heavy applications like browsers or office apps)
            Args:
                seconds: Number of seconds to wait, generally 3 is enough.
            """
            time.sleep(seconds)
            outcome = f"Waited for {seconds} seconds"
            self.logger.log(outcome)
            return outcome

        @tool
        def open(url: str) -> str:
            """
            Directly opens a browser with the specified url: use this at start of web searches rather than trying to click the browser.
            Args:
                url: The URL to open
            """
            # Default to https when no scheme was supplied
            has_scheme = url.startswith(("http://", "https://"))
            target = url if has_scheme else "https://" + url

            self.desktop.open(target)
            # Give the page a moment to load
            time.sleep(2)
            self.logger.log(f"Opening URL: {target}")
            return f"Opened URL: {target}"

        # Attach the tools in a fixed order (insertion order of ``self.tools``)
        registry = (
            ("click", click),
            ("right_click", right_click),
            ("double_click", double_click),
            ("move_mouse", move_mouse),
            ("write", write),
            ("press", press),
            ("scroll", scroll),
            ("wait", wait),
            ("open", open),
            ("go_back", go_back),
            ("drag", drag),
        )
        for tool_name, tool_fn in registry:
            self.tools[tool_name] = tool_fn
cua2-core/src/cua2_core/services/agent_utils/function_parser.py ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Function parser for extracting function names, parameter names, and values from string function calls.
4
+ Supports both mobile and pyautogui function patterns.
5
+ """
6
+
7
+ import re
8
+ from collections import OrderedDict
9
+ from typing import Any, Dict, List, Tuple
10
+
11
+ from pydantic import BaseModel
12
+
13
+
14
class FunctionCall(BaseModel):
    """Represents a parsed function call with its parameters."""

    # Name as written in the source text, including any dotted prefix
    function_name: str
    # Positional arguments are stored under "arg_<index>", keywords under their name
    parameters: Dict[str, Any]
    # Text of the call this object was built from
    original_string: str
    description: str = ""

    def to_string(self) -> str:
        """
        Reconstruct the function call string from the parsed data.

        Keys of the exact form ``arg_<digits>`` are emitted as positional
        arguments, ordered by index; every other key is emitted as
        ``name=value`` in insertion order.

        Returns:
            String representation of the function call

        Examples:
            >>> call = FunctionCall(
            ...     function_name="mobile.wait",
            ...     parameters={"seconds": 3},
            ...     original_string="mobile.wait(seconds=3)",
            ... )
            >>> call.to_string()
            'mobile.wait(seconds=3)'

            >>> call = FunctionCall(
            ...     function_name="function",
            ...     parameters={"arg_0": 1, "arg_1": 2, "x": 0.5},
            ...     original_string="function(1, 2, x=0.5)",
            ... )
            >>> call.to_string()
            'function(1, 2, x=0.5)'
        """
        if not self.parameters:
            return f"{self.function_name}()"

        # Separate positional and named arguments
        positional_args = []
        named_args = []

        for name, value in self.parameters.items():
            # Only "arg_<digits>" is positional. A keyword such as "arg_name"
            # used to reach int("name") and raise ValueError.
            index_match = re.fullmatch(r"arg_(\d+)", name)
            if index_match:
                positional_args.append((int(index_match.group(1)), value))
            else:
                named_args.append((name, value))

        # Sort positional arguments by index
        positional_args.sort(key=lambda item: item[0])

        param_parts = [self._value_to_string(value) for _, value in positional_args]
        param_parts.extend(
            f"{name}={self._value_to_string(value)}" for name, value in named_args
        )

        return f"{self.function_name}({', '.join(param_parts)})"

    def _value_to_string(self, value: Any) -> str:
        """
        Convert a value to its string representation for function calls.

        Strings are wrapped in single quotes (no escaping is applied), lists
        and tuples are both rendered with square brackets, booleans are
        rendered lowercase.

        Args:
            value: The value to convert

        Returns:
            String representation of the value
        """
        if isinstance(value, str):
            # Quote strings
            return f"'{value}'"
        elif isinstance(value, (list, tuple)):
            # Convert lists/tuples to string representation
            items = [self._value_to_string(item) for item in value]
            return f"[{', '.join(items)}]"
        elif isinstance(value, dict):
            # Convert dictionaries to string representation
            items = [f"'{k}': {self._value_to_string(v)}" for k, v in value.items()]
            return f"{{{', '.join(items)}}}"
        elif isinstance(value, bool):
            # Convert booleans to lowercase (checked before the generic
            # branch because bool is a subclass of int)
            return str(value).lower()
        elif value is None:
            return "None"
        else:
            # Numbers and other types
            return str(value)
98
+
99
+
100
def _find_closing_paren(text: str, start: int) -> int:
    """Return the index of the ``)`` closing the call opened just before ``start``.

    Nested parentheses and quoted strings (with backslash escapes) are
    skipped. Returns -1 when no balanced closing parenthesis exists.
    """
    depth = 1
    quote = None
    escaped = False
    for index in range(start, len(text)):
        char = text[index]
        if quote is not None:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == quote or char == "\n":
                # A newline ends the string as well: single-line literals
                # cannot span lines, and this keeps one stray quote from
                # swallowing the calls on the following lines.
                quote = None
        elif char in ("'", '"'):
            quote = char
        elif char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return index
    return -1


def parse_function_call(
    function_string: str, pattern_to_match: list[str] | None = None
) -> List[FunctionCall]:
    """
    Parse a function call string and extract all function calls found.

    A call is an identifier (dots allowed) directly followed by ``(`` and its
    argument list. The argument list ends at the matching ``)``, so nested
    calls and parentheses inside quoted strings stay part of the arguments.
    If the parentheses/quotes are unbalanced, the first ``)`` is used instead.

    Args:
        function_string: String representation of function calls
        pattern_to_match: Optional substrings; when given, only calls whose
            function name contains at least one of them are returned

    Returns:
        List of FunctionCall objects with parsed information

    Examples:
        >>> parse_function_call("mobile.wait(seconds=3)")
        [FunctionCall(function_name='mobile.wait', parameters={'seconds': 3}, ...)]

        >>> parse_function_call("mobile. wait(seconds=3)")
        [FunctionCall(function_name='wait', parameters={'seconds': 3}, ...)]

        >>> parse_function_call("write(text='hello (world)')")
        [FunctionCall(function_name='write', parameters={'text': 'hello (world)'}, ...)]
    """
    # Remove any leading/trailing whitespace
    function_string = function_string.strip()

    # Identifier (optionally dotted) immediately followed by "("
    call_start = re.compile(r"([a-zA-Z_][a-zA-Z0-9_.]*)\(")

    results = []
    position = 0
    while True:
        match = call_start.search(function_string, position)
        if match is None:
            break

        function_name = match.group(1)
        args_start = match.end()

        args_end = _find_closing_paren(function_string, args_start)
        if args_end == -1:
            # Unbalanced input: fall back to the first ")" like the previous
            # regex-based implementation did.
            args_end = function_string.find(")", args_start)
            if args_end == -1:
                break

        params_string = function_string[args_start:args_end]
        position = args_end + 1

        if pattern_to_match and all(
            wanted not in function_name for wanted in pattern_to_match
        ):
            continue

        # Parse parameters
        parameters = parse_parameters(params_string)

        # Create the original string for this specific function call
        original_string = f"{function_name}({params_string})"

        results.append(
            FunctionCall(
                function_name=function_name,
                parameters=parameters,
                original_string=original_string,
            )
        )

    return results
160
+
161
+
162
def parse_parameters(params_string: str) -> Dict[str, Any]:
    """
    Parse parameter string and extract parameter names and values.

    Positional arguments are stored under ``arg_0``, ``arg_1``, ... in order
    of appearance; keyword arguments keep their own name, including names
    that merely start with ``arg_``.

    Args:
        params_string: String containing parameters (e.g., "x=0.5, y=0.6, text='hello'")

    Returns:
        Dictionary mapping parameter names to their values

    Examples:
        >>> parse_parameters("x=0.5, y=0.6")
        {'x': 0.5, 'y': 0.6}

        >>> parse_parameters("app_name='drupe'")
        {'app_name': 'drupe'}

        >>> parse_parameters("'text'")
        {'arg_0': 'text'}

        >>> parse_parameters("1, 3, 4")
        {'arg_0': 1, 'arg_1': 3, 'arg_2': 4}

        >>> parse_parameters("arg1, arg2, x=0.5")
        {'arg_0': 'arg1', 'arg_1': 'arg2', 'x': 0.5}

        >>> parse_parameters("arg_name='x'")
        {'arg_name': 'x'}
    """
    if not params_string.strip():
        return {}

    parameters = OrderedDict()

    # Split by commas, but be careful with commas inside quotes or brackets
    param_parts = split_parameters(params_string)

    positional_index = 0

    for part in param_parts:
        part = part.strip()
        if not part:
            continue

        # Parse individual parameter
        name, value = parse_single_parameter(part)

        # parse_single_parameter reports every positional argument as "arg_0".
        # Renumber only those; a keyword that is itself spelled "arg_0=..." or
        # "arg_foo=..." keeps its name (it used to be renamed to a positional
        # slot because of a startswith("arg_") check).
        is_positional = name == "arg_0" and not re.match(r"arg_0\s*=\s*.+$", part)
        if is_positional:
            name = f"arg_{positional_index}"
            positional_index += 1

        parameters[name] = value

    return parameters
214
+
215
+
216
def split_parameters(params_string: str) -> List[str]:
    """
    Split parameter string by commas, respecting quotes and brackets.

    A comma separates two parameters only when it is outside any quoted
    string and outside every ``()``, ``[]`` and ``{}`` pair. Inside a quoted
    string a backslash escapes the next character, so ``'it\\'s, ok'`` stays
    one parameter.

    Args:
        params_string: String containing parameters

    Returns:
        List of individual parameter strings, each stripped of surrounding
        whitespace. An empty trailing part is dropped.

    Examples:
        >>> split_parameters("x=[1, 2], y='a, b'")
        ['x=[1, 2]', "y='a, b'"]
    """
    openers = {"(": "paren", "[": "bracket", "{": "brace"}
    closers = {")": "paren", "]": "bracket", "}": "brace"}
    # One counter per bracket kind, as before: a comma splits only when all are zero
    depth = {"paren": 0, "bracket": 0, "brace": 0}

    parts: List[str] = []
    current: List[str] = []  # characters of the parameter being collected
    quote_char = None  # the quote that opened the current string, if any
    escaped = False  # True right after a backslash inside a string

    for char in params_string:
        if quote_char is not None:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == quote_char:
                quote_char = None
        elif char in ('"', "'"):
            quote_char = char
        elif char in openers:
            depth[openers[char]] += 1
        elif char in closers:
            depth[closers[char]] -= 1
        elif char == "," and not any(depth.values()):
            parts.append("".join(current).strip())
            current = []
            continue

        current.append(char)

    tail = "".join(current).strip()
    if tail:
        parts.append(tail)

    return parts
271
+
272
+
273
def parse_single_parameter(param_string: str) -> Tuple[str, Any]:
    """
    Turn one parameter fragment into a ``(name, value)`` pair.

    A fragment of the form ``identifier = something`` is a keyword argument;
    anything else is positional and is reported under the placeholder name
    ``"arg_0"``.

    Args:
        param_string: String like "x=0.5" or "app_name='drupe'" or just "value"

    Returns:
        Tuple of (parameter_name, parameter_value)

    Examples:
        >>> parse_single_parameter("x=0.5")
        ('x', 0.5)

        >>> parse_single_parameter("app_name='drupe'")
        ('app_name', 'drupe')

        >>> parse_single_parameter("'text'")
        ('arg_0', 'text')

        >>> parse_single_parameter("123")
        ('arg_0', 123)
    """
    keyword = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*(.+)$", param_string)

    if keyword is None:
        # Positional parameter - the whole fragment is the value
        return "arg_0", parse_value(param_string)

    key, raw_value = keyword.groups()
    return key, parse_value(raw_value.strip())
313
+
314
+
315
def parse_value(value_string: str) -> Any:
    """
    Parse a value string into appropriate Python type.

    Checks are applied in this order: quoted string, list, dict, boolean,
    ``None``, integer, float. Anything that matches none of them is returned
    unchanged as a string (e.g. a bare identifier).

    Args:
        value_string: String representation of a value

    Returns:
        Parsed value (int, float, str, list, etc.)

    Examples:
        >>> parse_value("3")
        3

        >>> parse_value("3.14")
        3.14

        >>> parse_value("1e3")
        1000.0

        >>> parse_value("'hello'")
        'hello'

        >>> parse_value("[0.581, 0.898]")
        [0.581, 0.898]
    """
    value_string = value_string.strip()

    # String values (quoted). Require two characters so a lone quote is not
    # treated as an empty quoted string.
    if (
        len(value_string) >= 2
        and value_string[0] == value_string[-1]
        and value_string[0] in ("'", '"')
    ):
        return value_string[1:-1]

    # List values
    if value_string.startswith("[") and value_string.endswith("]"):
        return parse_list(value_string)

    # Dictionary values
    if value_string.startswith("{") and value_string.endswith("}"):
        return parse_dict(value_string)

    # Boolean values (case-insensitive)
    lowered = value_string.lower()
    if lowered in ("true", "false"):
        return lowered == "true"

    # None value
    if lowered == "none":
        return None

    # Numeric values: integer first, then float. Trying float regardless of a
    # "." also covers exponent notation such as "1e3".
    try:
        return int(value_string)
    except ValueError:
        pass

    # Only try float when a digit is present, so that bare words accepted by
    # float() ("nan", "inf", "infinity") remain plain strings.
    if any(char.isdigit() for char in value_string):
        try:
            return float(value_string)
        except ValueError:
            pass

    # Not a recognised literal: keep the raw text
    return value_string
377
+
378
+
379
def parse_list(list_string: str) -> List[Any]:
    """
    Convert the textual form of a list into a Python list.

    The first and last characters are assumed to be the enclosing brackets
    and are dropped; the remainder is split on top-level commas and each
    element is parsed on its own.

    Args:
        list_string: String like "[0.581, 0.898]"

    Returns:
        List of parsed values (empty when there is nothing between the brackets)

    Examples:
        >>> parse_list("[0.581, 0.898]")
        [0.581, 0.898]
    """
    inner = list_string[1:-1].strip()
    if inner:
        # Nested structures are kept intact by split_parameters
        return [parse_value(element.strip()) for element in split_parameters(inner)]
    return []
402
+
403
+
404
def parse_dict(dict_string: str) -> Dict[str, Any]:
    """
    Parse a dictionary string into a Python dict.

    Each top-level ``key: value`` entry is split at the first colon that is
    outside quotes and brackets, so quoted keys containing a colon
    (``{'a:b': 1}``) are handled. Entries without any colon are ignored.

    Args:
        dict_string: String like "{'key': 'value'}"

    Returns:
        Dictionary of parsed key-value pairs
    """

    def find_separator(entry: str) -> int:
        """Index of the first colon outside quotes/brackets, or -1."""
        depth = 0
        quote = None
        escaped = False
        for index, char in enumerate(entry):
            if quote is not None:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == quote:
                    quote = None
            elif char in ("'", '"'):
                quote = char
            elif char in "([{":
                depth += 1
            elif char in ")]}":
                depth -= 1
            elif char == ":" and depth == 0:
                return index
        return -1

    # Remove outer braces
    content = dict_string[1:-1].strip()
    if not content:
        return {}

    result = {}
    # Split by commas, respecting nested structures
    for part in split_parameters(content):
        part = part.strip()

        separator = find_separator(part)
        if separator == -1:
            # Unbalanced quotes/brackets: fall back to the very first colon,
            # which is what the previous implementation always used.
            separator = part.find(":")
        if separator == -1:
            continue

        key = parse_value(part[:separator].strip())
        value = parse_value(part[separator + 1 :].strip())
        result[key] = value

    return result
432
+
433
+
434
def parse_multiple_functions(function_strings: List[str]) -> List[FunctionCall]:
    """
    Parse every string of a list and gather all calls found.

    A string whose parsing raises is reported with a printed warning and
    skipped; the remaining strings are still processed.

    Args:
        function_strings: List of function call strings

    Returns:
        List of FunctionCall objects, in input order
    """
    collected = []
    for candidate in function_strings:
        try:
            collected.extend(parse_function_call(candidate))
        except Exception as e:
            print(f"Warning: Could not parse function call '{candidate}': {e}")
    return collected
454
+
455
+
456
def extract_function_calls_from_text(text: str) -> List[FunctionCall]:
    """
    Locate call-like snippets in free text and parse each of them.

    A snippet is an identifier (dots allowed) followed by a parenthesised
    argument list that ends at the first ``)``.

    Args:
        text: Text containing function calls

    Returns:
        List of FunctionCall objects
    """
    call_regex = re.compile(r"[a-zA-Z_][a-zA-Z0-9_.]*\([^)]*\)")
    return parse_multiple_functions(call_regex.findall(text))
472
+
473
+
474
+ # Example usage and testing
475
if __name__ == "__main__":
    # Manual smoke test: parse sample calls, extract calls from a text block,
    # then check that parsed calls can be rendered back to their source text.
    divider = "=" * 50

    parser_samples = [
        "mobile.home()",
        "mobile.open_app(app_name='drupe')",
        "mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])",
        "mobile.back()",
        "mobile.long_press(x=0.799, y=0.911)",
        "mobile.terminate(status='success')",
        "answer('text')",
        "pyautogui.hscroll(page=-0.1)",
        "pyautogui.scroll(page=-0.1)",
        "pyautogui.scroll(0.13)",
        "pyautogui.click(x=0.8102, y=0.9463)",
        "pyautogui.hotkey(keys=['ctrl', 'c'])",
        "pyautogui.press(keys='enter')",
        "pyautogui.press(keys=['enter'])",
        "pyautogui.moveTo(x=0.04, y=0.405)",
        "pyautogui.write(message='bread buns')",
        "pyautogui.dragTo(x=0.8102, y=0.9463)",
        "mobile.wait(seconds=3)\nmobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])",
        # Several positional arguments, alone and mixed with keywords
        "function(arg1, arg2, arg3)",
        "function('hello', 123, x=0.5)",
        "function(arg1, arg2, named_param='value')",
        "function(1, 2, 3, 4, 5)",
        "function('a', 'b', 'c', x=1, y=2)",
    ]

    print("Testing function parser:")
    print(divider)

    for sample in parser_samples:
        try:
            calls = parse_function_call(sample)
            print(f"✓ {sample}")
            for call in calls:
                print(f" Function: {call.function_name}")
                print(f" Parameters: {call.parameters}")
                print()
        except Exception as e:
            print(f"✗ {sample}")
            print(f" Error: {e}")
            print()

    # Extraction from a block of text
    print("Testing text extraction:")
    print(divider)

    sample_text = """
    mobile.wait(seconds=3)
    mobile.open_app(app_name='drupe')
    pyautogui.click(x=0.8102, y=0.9463)
    pyautogui.write(message='bread buns')
    """

    for found in extract_function_calls_from_text(sample_text):
        print(f"Found: {found.function_name} with params: {found.parameters}")

    # Round trip: parse, render, compare with the input
    print("\nTesting function call reconstruction:")
    print(divider)

    round_trip_samples = [
        "mobile.wait(seconds=3)",
        "mobile.home()",
        "mobile.open_app(app_name='drupe')",
        "mobile.swipe(from_coord=[0.581, 0.898], to_coord=[0.601, 0.518])",
        "answer('text')",
        "pyautogui.scroll(0.13)",
        "pyautogui.click(x=0.8102, y=0.9463)",
        "pyautogui.hotkey(keys=['ctrl', 'c'])",
        "function(1, 2, 3)",
        "function('hello', 123, x=0.5, y=0.8)",
        "function([1, 3], 'arg2', named_param='value')",
    ]

    for sample in round_trip_samples:
        for call in parse_function_call(sample):
            rebuilt = call.to_string()
            print(f"Original: {sample}")
            print(f"Reconstructed: {rebuilt}")
            print(f"Match: {sample == rebuilt}")
            assert sample == rebuilt
            print()
cua2-core/src/cua2_core/services/agent_utils/get_model.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import InferenceClientModel, Model
2
+
3
+ # Available model IDs
4
# Qwen3-VL checkpoints, listed per size as Instruct then Thinking.
AVAILABLE_MODELS = [
    # 2B
    "Qwen/Qwen3-VL-2B-Instruct",
    "Qwen/Qwen3-VL-2B-Thinking",
    # 4B
    "Qwen/Qwen3-VL-4B-Instruct",
    "Qwen/Qwen3-VL-4B-Thinking",
    # 8B
    "Qwen/Qwen3-VL-8B-Instruct",
    "Qwen/Qwen3-VL-8B-Thinking",
    # 30B-A3B
    "Qwen/Qwen3-VL-30B-A3B-Instruct",
    "Qwen/Qwen3-VL-30B-A3B-Thinking",
]
14
+
15
+
16
def get_model(model_id: str) -> Model:
    """Create an ``InferenceClientModel`` for the given model id.

    Args:
        model_id: Identifier passed unchanged as ``model_id`` to the model
            constructor; it is not checked against ``AVAILABLE_MODELS`` here.

    Returns:
        The newly constructed model.
    """
    model = InferenceClientModel(model_id=model_id)
    return model
cua2-core/src/cua2_core/services/agent_utils/prompt.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+
3
+ E2B_SYSTEM_PROMPT_TEMPLATE = """You are a computer-use automation assistant controlling a full desktop remotely.
4
+ The current date is <<current_date>>.
5
+
6
+ <mission>
7
+ Your objective is to complete a given task step-by-step by interacting with the desktop.
8
+ At every step, you:
9
+ 1. Observe the latest screenshot (always analyze it carefully).
10
+ 2. Reflect briefly on what you see and what to do next.
11
+ 3. Produce **one precise action**, formatted exactly as Python code in a fenced block.
12
+
13
+ You will receive a new screenshot after each action.
14
+ Never skip the structure below.
15
+ </mission>
16
+
17
+ ---
18
+
19
+ <action_process>
20
+ For every step, strictly follow this format:
21
+
22
+ Short term goal: what you’re trying to accomplish in this step.
23
+ What I see: describe key elements visible on the desktop.
24
+ Reflection: reasoning that justifies your next move (mention errors or corrections if needed).
25
+ **Action:**
26
+ ```python
27
+ click(x, y)
28
+ ```<end_code>
29
+ </action_process>
30
+
31
+ ---
32
+
33
+ <environment>
34
+ The desktop resolution is <<resolution_x>>x<<resolution_y>> pixels.
35
+ You can only interact through the following tools:
36
+
37
+ {%- for tool in tools.values() %}
38
+ - **{{ tool.name }}**: {{ tool.description }}
39
+ - Inputs: {{ tool.inputs }}
40
+ - Returns: {{ tool.output_type }}
41
+ {%- endfor %}
42
+
43
+ If a task requires a specific application or website, **use**:
44
+ ```python
45
+ open("app_or_url")
46
+ ```
47
+ to launch it before interacting.
48
+ Never manually click the browser icon — use `open_url()` directly for web pages.
49
+ </environment>
50
+
51
+ ---
52
+
53
+ <click_guidelines>
54
+ - Always click using **real, visible coordinates** based on the current screenshot.
55
+ - Click precisely **in the center** of the intended target (button, text, icon).
56
+ - Avoid random or approximate coordinates.
57
+ - If nothing changes after a click, check if you misclicked (green crosshair = last click position).
58
+ - If a menu item shows a ▶ (triangle), it means it expands—click directly on the text, not the icon.
59
+ - Use `scroll()` only within scrollable views (webpages, app lists, etc.).
60
+ </click_guidelines>
61
+
62
+ ---
63
+
64
+ <workflow_guidelines>
65
+ - **ALWAYS START** by analyzing if the task requires opening an application or URL. If so, your **first action** must be:
66
+ - For websites: `open_url("https://google.com")`
67
+ - For applications: `open("app_name")`
68
+ - Never manually navigate to apps via clicking icons—use the open tools directly.
69
+ - Complete one atomic action per step: e.g., **click**, **type**, or **wait**.
70
+ - Never combine multiple tool calls in one step.
71
+ - Validate that your previous action succeeded before continuing.
72
+ - If the interface hasn't changed, adjust your strategy instead of repeating endlessly.
73
+ - Use `wait(seconds)` for short delays if the interface is loading.
74
+ - Always conclude with:
75
+ ```python
76
+ final_answer("Answer the user's question or summarize the task")
77
+ ```
78
+ once the task is fully completed and verified. Answer the user's question or summarize the task.
79
+ </workflow_guidelines>
80
+
81
+ ---
82
+
83
+ <example>
84
+ Task: *Open a text editor and write “Hello World”*
85
+
86
+ Step 1
87
+ Short term goal: Launch the text editor.
88
+ What I see: “Text Editor” visible under Accessories.
89
+ Reflection: Clicking directly on “Text Editor”.
90
+ Action:
91
+ ```python
92
+ open("text_editor")
93
+ ```<end_code>
94
+
95
+ Step 2
96
+ Short term goal: click on the text editor page.
97
+ What I see: Text editor page.
98
+ Reflection: Click on the text editor page to write "Hello World".
99
+ Action:
100
+ ```python
101
+ click(52, 10)
102
+ ```<end_code>
103
+
104
+ Step 3
105
+ Short term goal: Type text.
106
+ What I see: Empty notepad open.
107
+ Reflection: Ready to type.
108
+ Action:
109
+ ```python
110
+ write("Hello World")
111
+ ```<end_code>
112
+
113
+ Step 4
114
+ Short term goal: Verify text and conclude.
115
+ What I see: “Hello World” visible in notepad.
116
+ Reflection: Task successful.
117
+ Action:
118
+ ```python
119
+ final_answer("The task is complete and the text 'Hello World' is visible in the notepad.")
120
+ ```<end_code>
121
+ </example>
122
+
123
+ ---
124
+
125
+ <core_principles>
126
+ - Think visually and spatially.
127
+ - Always ground your reasoning in what’s visible in the screenshot.
128
+ - Never assume what’s on the next screen.
129
+ - Always check the result of your last action.
130
+ - Be deliberate, consistent, and patient.
131
+ - **ALWAYS START** by analyzing if the task requires opening an application or URL. If so, your **first action** must be:
132
+ - For websites: `open_url("https://google.com")`
133
+ - For applications: `open("app_name")`
134
+ - **NEVER** manually navigate to apps via clicking icons—use the open tools directly.
135
+ </core_principles>
136
+ """.replace("<<current_date>>", datetime.now().strftime("%A, %d-%B-%Y"))
cua2-core/src/cua2_core/services/sandbox_service.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ import os
4
+ import time
5
+ from datetime import datetime
6
+ from typing import Any
7
+
8
+ from e2b_desktop import Sandbox
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ SANDBOX_METADATA: dict[str, dict[str, Any]] = {}
13
+ SANDBOX_TIMEOUT = 300
14
+ WIDTH = 1280
15
+ HEIGHT = 960
16
+
17
+
18
class SandboxService:
    """Create, cache and dispose of E2B desktop sandboxes, one per session.

    Sandboxes are keyed by ``session_hash``. Every mutation of the registry
    happens while holding ``sandbox_lock``.
    """

    def __init__(self, max_sandboxes: int = 50):
        """Initialise an empty registry.

        Args:
            max_sandboxes: Upper bound on simultaneously registered sandboxes.

        Raises:
            ValueError: If the ``E2B_API_KEY`` environment variable is unset.
        """
        if not os.getenv("E2B_API_KEY"):
            raise ValueError("E2B_API_KEY is not set")
        self.max_sandboxes = max_sandboxes
        self.sandboxes: dict[str, Sandbox] = {}
        self.sandbox_metadata: dict[str, dict[str, Any]] = {}
        self.sandbox_lock = asyncio.Lock()

    def _forget(self, session_hash: str) -> None:
        """Drop every registry entry for ``session_hash`` (missing keys are fine)."""
        self.sandboxes.pop(session_hash, None)
        self.sandbox_metadata.pop(session_hash, None)

    async def acquire_sandbox(self, session_hash: str) -> "Sandbox | None":
        """Return the session's sandbox, creating a fresh one when needed.

        A registered sandbox younger than ``SANDBOX_TIMEOUT`` seconds is
        reused. An older one is killed and replaced.

        Args:
            session_hash: Identifier of the requesting session.

        Returns:
            The sandbox for the session, or ``None`` when the session has no
            sandbox and ``max_sandboxes`` are already registered.
        """
        async with self.sandbox_lock:
            current_time = datetime.now()

            sandbox = self.sandboxes.get(session_hash)
            metadata = self.sandbox_metadata.get(session_hash)
            if sandbox is not None and metadata is not None:
                # Subtracting datetimes yields a timedelta; convert it to
                # seconds before comparing with the integer timeout
                # (timedelta < int raises TypeError).
                age_seconds = (current_time - metadata["created_at"]).total_seconds()
                if age_seconds < SANDBOX_TIMEOUT:
                    print(f"Reusing Sandbox for session {session_hash}")
                    metadata["last_accessed"] = current_time
                    return sandbox

            if sandbox is not None:
                try:
                    print(f"Closing expired sandbox for session {session_hash}")
                    await asyncio.to_thread(sandbox.kill)
                except Exception as e:
                    print(f"Error closing expired sandbox: {str(e)}")
                finally:
                    # Never keep a reference to a dead sandbox, even if the
                    # replacement below fails to start.
                    self._forget(session_hash)
            elif len(self.sandboxes) >= self.max_sandboxes:
                return None

            print(f"Creating new sandbox for session {session_hash}")

            def create_and_setup_sandbox():
                desktop = Sandbox.create(
                    api_key=os.getenv("E2B_API_KEY"),
                    resolution=(WIDTH, HEIGHT),
                    dpi=96,
                    timeout=SANDBOX_TIMEOUT,
                    template="k0wmnzir0zuzye6dndlw",
                )
                desktop.stream.start(require_auth=True)
                setup_cmd = """sudo mkdir -p /usr/lib/firefox-esr/distribution && echo '{"policies":{"OverrideFirstRunPage":"","OverridePostUpdatePage":"","DisableProfileImport":true,"DontCheckDefaultBrowser":true}}' | sudo tee /usr/lib/firefox-esr/distribution/policies.json > /dev/null"""
                desktop.commands.run(setup_cmd)
                time.sleep(3)
                return desktop

            # The blocking SDK calls run in a worker thread.
            desktop = await asyncio.to_thread(create_and_setup_sandbox)

            print(f"Sandbox ID for session {session_hash} is {desktop.sandbox_id}.")

            self.sandboxes[session_hash] = desktop
            self.sandbox_metadata[session_hash] = {
                "created_at": current_time,
                "last_accessed": current_time,
            }
            return desktop

    async def release_sandbox(self, session_hash: str):
        """Kill and unregister the session's sandbox, if one is registered.

        An error raised by ``kill`` still propagates to the caller, but the
        registry entries are removed regardless.
        """
        async with self.sandbox_lock:
            sandbox = self.sandboxes.get(session_hash)
            if sandbox is None:
                return
            print(f"Releasing sandbox for session {session_hash}")
            try:
                await asyncio.to_thread(sandbox.kill)
            finally:
                self._forget(session_hash)

    async def cleanup_sandboxes(self):
        """Kill and unregister every sandbox on a best-effort basis.

        A failure to kill one sandbox is reported and does not prevent the
        remaining sandboxes from being cleaned up.
        """
        async with self.sandbox_lock:
            for session_hash in list(self.sandboxes.keys()):
                try:
                    await asyncio.to_thread(self.sandboxes[session_hash].kill)
                except Exception as e:
                    print(f"Error closing sandbox for session {session_hash}: {str(e)}")
                finally:
                    self._forget(session_hash)
cua2-core/src/cua2_core/websocket/websocket_manager.py CHANGED
@@ -1,11 +1,29 @@
1
  import asyncio
2
  import json
3
- from typing import Dict, Optional, Set
4
-
5
- from cua2_core.models.models import AgentTraceMetadata, WebSocketEvent
 
 
 
 
 
 
 
 
 
 
 
 
6
  from fastapi import WebSocket
7
 
8
 
 
 
 
 
 
 
9
  class WebSocketManager:
10
  """Manages WebSocket connections and broadcasting"""
11
 
@@ -29,90 +47,72 @@ class WebSocketManager:
29
  f"WebSocket disconnected. Total connections: {len(self.active_connections)}"
30
  )
31
 
32
- async def send_personal_message(
33
- self, message: WebSocketEvent, websocket: WebSocket
34
- ):
35
  """Send a message to a specific WebSocket connection"""
36
  try:
37
- await websocket.send_text(json.dumps(message.model_dump(mode="json")))
 
 
 
 
38
  except Exception as e:
39
  print(f"Error sending personal message: {e}")
40
  # Only disconnect if the connection is still in our set
41
  if websocket in self.active_connections:
42
  self.disconnect(websocket)
 
43
 
44
- async def broadcast(self, message: WebSocketEvent):
45
- """Broadcast a message to all connected WebSockets"""
46
- if not self.active_connections:
47
- return
48
-
49
- # Create a list of connections to remove if they fail
50
- disconnected = []
51
-
52
- for connection in self.active_connections.copy():
53
- try:
54
- await connection.send_text(json.dumps(message.model_dump(mode="json")))
55
- except Exception as e:
56
- print(f"Error broadcasting to connection: {e}")
57
- disconnected.append(connection)
58
-
59
- # Remove failed connections
60
- for connection in disconnected:
61
- if connection in self.active_connections:
62
- self.disconnect(connection)
63
-
64
- async def send_agent_start(self, content: str, message_id: str):
65
  """Send agent start event"""
66
- event = WebSocketEvent(
67
- type="agent_start", content=content, messageId=message_id
 
 
 
 
 
 
 
 
68
  )
69
- await self.broadcast(event)
70
 
71
- async def send_agent_progress(self, content: str, message_id: str):
 
 
 
 
 
72
  """Send agent progress event"""
73
- event = WebSocketEvent(
74
- type="agent_progress", content=content, messageId=message_id
 
75
  )
76
- await self.broadcast(event)
77
 
78
  async def send_agent_complete(
79
- self,
80
- content: str,
81
- message_id: str,
82
- metadata: Optional[AgentTraceMetadata] = None,
83
  ):
84
  """Send agent complete event"""
85
- event = WebSocketEvent(
86
- type="agent_complete",
87
- content=content,
88
- messageId=message_id,
89
- metadata=metadata,
90
- )
91
- await self.broadcast(event)
92
 
93
- async def send_agent_error(self, content: str, message_id: Optional[str] = None):
94
  """Send agent error event"""
95
- event = WebSocketEvent(
96
- type="agent_error", content=content, messageId=message_id
97
- )
98
- await self.broadcast(event)
99
 
100
- async def send_vnc_url_set(self, vnc_url: str, content: Optional[str] = None):
101
  """Send VNC URL set event"""
102
- event = WebSocketEvent(
103
- type="vnc_url_set",
104
- content=content or f"VNC stream available at: {vnc_url}",
105
  vncUrl=vnc_url,
106
  )
107
- await self.broadcast(event)
108
 
109
- async def send_vnc_url_unset(self, content: Optional[str] = None):
110
  """Send VNC URL unset event (reset to default display)"""
111
- event = WebSocketEvent(
112
- type="vnc_url_unset",
113
- content=content or "VNC stream disconnected, showing default display",
114
- )
115
- await self.broadcast(event)
116
 
117
  def get_connection_count(self) -> int:
118
  """Get the number of active connections"""
 
1
  import asyncio
2
  import json
3
+ from typing import Dict, Set
4
+
5
+ from cua2_core.models.models import (
6
+ ActiveTask,
7
+ AgentCompleteEvent,
8
+ AgentErrorEvent,
9
+ AgentProgressEvent,
10
+ AgentStartEvent,
11
+ AgentStep,
12
+ AgentTrace,
13
+ AgentTraceMetadata,
14
+ VncUrlSetEvent,
15
+ VncUrlUnsetEvent,
16
+ WebSocketEvent,
17
+ )
18
  from fastapi import WebSocket
19
 
20
 
21
class WebSocketException(Exception):
    """Raised when a message cannot be delivered over a WebSocket."""
25
+
26
+
27
  class WebSocketManager:
28
  """Manages WebSocket connections and broadcasting"""
29
 
 
47
  f"WebSocket disconnected. Total connections: {len(self.active_connections)}"
48
  )
49
 
50
+ async def send_message(self, message: WebSocketEvent, websocket: WebSocket):
 
 
51
  """Send a message to a specific WebSocket connection"""
52
  try:
53
+ await websocket.send_text(
54
+ json.dumps(
55
+ message.model_dump(mode="json", context={"actions_as_json": False})
56
+ )
57
+ )
58
  except Exception as e:
59
  print(f"Error sending personal message: {e}")
60
  # Only disconnect if the connection is still in our set
61
  if websocket in self.active_connections:
62
  self.disconnect(websocket)
63
+ raise WebSocketException()
64
 
65
+ async def send_agent_start(self, active_task: ActiveTask, websocket: WebSocket):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  """Send agent start event"""
67
+ event = AgentStartEvent(
68
+ agentTrace=AgentTrace(
69
+ id=active_task.message_id,
70
+ timestamp=active_task.timestamp,
71
+ instruction=active_task.instruction,
72
+ modelId=active_task.model_id,
73
+ steps=active_task.steps,
74
+ traceMetadata=active_task.traceMetadata,
75
+ isRunning=True,
76
+ ),
77
  )
78
+ await self.send_message(event, websocket)
79
 
80
+ async def send_agent_progress(
81
+ self,
82
+ step: AgentStep,
83
+ metadata: AgentTraceMetadata,
84
+ websocket: WebSocket,
85
+ ):
86
  """Send agent progress event"""
87
+ event = AgentProgressEvent(
88
+ agentStep=step,
89
+ traceMetadata=metadata,
90
  )
91
+ await self.send_message(event, websocket)
92
 
93
  async def send_agent_complete(
94
+ self, metadata: AgentTraceMetadata, websocket: WebSocket
 
 
 
95
  ):
96
  """Send agent complete event"""
97
+ event = AgentCompleteEvent(traceMetadata=metadata)
98
+ await self.send_message(event, websocket)
 
 
 
 
 
99
 
100
+ async def send_agent_error(self, error: str, websocket: WebSocket):
101
  """Send agent error event"""
102
+ event = AgentErrorEvent(error=error)
103
+ await self.send_message(event, websocket)
 
 
104
 
105
+ async def send_vnc_url_set(self, vnc_url: str, websocket: WebSocket):
106
  """Send VNC URL set event"""
107
+ event = VncUrlSetEvent(
 
 
108
  vncUrl=vnc_url,
109
  )
110
+ await self.send_message(event, websocket)
111
 
112
+ async def send_vnc_url_unset(self, websocket: WebSocket):
113
  """Send VNC URL unset event (reset to default display)"""
114
+ event = VncUrlUnsetEvent()
115
+ await self.send_message(event, websocket)
 
 
 
116
 
117
  def get_connection_count(self) -> int:
118
  """Get the number of active connections"""
cua2-core/tests/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Tests for cua2-core"""
cua2-core/tests/test_routes.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unittest.mock import Mock
2
+
3
+ import pytest
4
+ from cua2_core.models.models import AvailableModelsResponse, UpdateStepResponse
5
+ from cua2_core.routes.routes import router
6
+ from cua2_core.services.agent_service import AgentService
7
+ from cua2_core.services.agent_utils.get_model import AVAILABLE_MODELS
8
+ from fastapi import FastAPI
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from fastapi.testclient import TestClient
11
+
12
+
13
@pytest.fixture
def mock_agent_service():
    """Provide an ``AgentService``-spec'd mock with no active tasks."""
    agent_service = Mock(spec=AgentService)
    agent_service.active_tasks = {}
    # Replaced with a plain Mock so tests can set return values / side effects.
    agent_service.update_trace_step = Mock()
    return agent_service
20
+
21
+
22
@pytest.fixture
def mock_websocket_manager():
    """Provide a mock WebSocket manager that reports zero connections."""
    websocket_manager = Mock()
    websocket_manager.get_connection_count = Mock(return_value=0)
    return websocket_manager
28
+
29
+
30
@pytest.fixture
def app(mock_agent_service, mock_websocket_manager):
    """Build a FastAPI app that serves the real router with mocked services."""
    application = FastAPI(title="Test App")

    # CORS middleware, fully permissive
    cors_options = {
        "allow_origins": ["*"],
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    application.add_middleware(CORSMiddleware, **cors_options)

    # Routes under test
    application.include_router(router)

    # Mocked services stored on the app state
    application.state.agent_service = mock_agent_service
    application.state.websocket_manager = mock_websocket_manager

    return application
53
+
54
+
55
@pytest.fixture
def client(app):
    """Provide a ``TestClient`` bound to the test application."""
    test_client = TestClient(app)
    return test_client
59
+
60
+
61
class TestGetAvailableModels:
    """Tests for the GET /models endpoint."""

    @staticmethod
    def _get_models_payload(client):
        """Request /models, require HTTP 200 and return the decoded JSON body."""
        response = client.get("/models")
        assert response.status_code == 200
        return response.json()

    def test_get_available_models_success(self, client):
        """The endpoint returns a non-empty list equal to AVAILABLE_MODELS."""
        payload = self._get_models_payload(client)

        assert "models" in payload
        models = payload["models"]
        assert isinstance(models, list)
        assert len(models) > 0

        # The served list must be exactly the configured one
        assert models == AVAILABLE_MODELS

    def test_get_available_models_structure(self, client):
        """The body validates against the AvailableModelsResponse schema."""
        payload = self._get_models_payload(client)

        # Round-trip through the Pydantic model
        parsed = AvailableModelsResponse(**payload)
        assert parsed.models == AVAILABLE_MODELS

    def test_get_available_models_content(self, client):
        """Specific expected model ids are present in the response."""
        payload = self._get_models_payload(client)

        for expected_model in (
            "Qwen/Qwen3-VL-2B-Instruct",
            "Qwen/Qwen3-VL-30B-A3B-Instruct",
        ):
            assert expected_model in payload["models"]
104
+
105
+
106
class TestUpdateTraceStep:
    """Tests for the PATCH /traces/{trace_id}/steps/{step_id} endpoint."""

    @staticmethod
    def _patch_step(client, trace_id, step_id, payload):
        """Issue the PATCH request for one step and return the raw response."""
        return client.patch(f"/traces/{trace_id}/steps/{step_id}", json=payload)

    def test_update_trace_step_success(self, client, mock_agent_service):
        """A valid 'like' evaluation yields the success payload."""
        mock_agent_service.update_trace_step.return_value = None

        response = self._patch_step(
            client, "test-trace-123", "1", {"step_evaluation": "like"}
        )

        assert response.status_code == 200
        body = response.json()
        assert body["success"] is True
        assert body["message"] == "Step updated successfully"

        # The route must forward the path and body values to the service
        mock_agent_service.update_trace_step.assert_called_once_with(
            trace_id="test-trace-123", step_id="1", step_evaluation="like"
        )

    def test_update_trace_step_with_dislike(self, client, mock_agent_service):
        """A 'dislike' evaluation is accepted and forwarded."""
        mock_agent_service.update_trace_step.return_value = None

        response = self._patch_step(
            client, "test-trace-456", "2", {"step_evaluation": "dislike"}
        )

        assert response.status_code == 200
        mock_agent_service.update_trace_step.assert_called_once_with(
            trace_id="test-trace-456", step_id="2", step_evaluation="dislike"
        )

    def test_update_trace_step_with_neutral(self, client, mock_agent_service):
        """A 'neutral' evaluation is accepted and forwarded."""
        mock_agent_service.update_trace_step.return_value = None

        response = self._patch_step(
            client, "test-trace-789", "3", {"step_evaluation": "neutral"}
        )

        assert response.status_code == 200
        mock_agent_service.update_trace_step.assert_called_once_with(
            trace_id="test-trace-789", step_id="3", step_evaluation="neutral"
        )

    def test_update_trace_step_invalid_evaluation(self, client, mock_agent_service):
        """An unknown evaluation value fails request validation."""
        response = self._patch_step(
            client, "test-trace-123", "1", {"step_evaluation": "invalid"}
        )

        assert response.status_code == 422

    def test_update_trace_step_value_error(self, client, mock_agent_service):
        """A ValueError from the service is reported as HTTP 400."""
        mock_agent_service.update_trace_step.side_effect = ValueError(
            "Invalid step_id format"
        )

        response = self._patch_step(
            client, "test-trace-123", "invalid", {"step_evaluation": "like"}
        )

        assert response.status_code == 400
        assert "Invalid step_id format" in response.json()["detail"]

    def test_update_trace_step_not_found(self, client, mock_agent_service):
        """A FileNotFoundError from the service is reported as HTTP 404."""
        mock_agent_service.update_trace_step.side_effect = FileNotFoundError(
            "Trace not found"
        )

        response = self._patch_step(
            client, "nonexistent-trace", "1", {"step_evaluation": "like"}
        )

        assert response.status_code == 404
        assert "Trace not found" in response.json()["detail"]

    def test_update_trace_step_step_not_found(self, client, mock_agent_service):
        """A missing step (ValueError from the service) is reported as HTTP 400."""
        mock_agent_service.update_trace_step.side_effect = ValueError(
            "Step 999 not found in trace"
        )

        response = self._patch_step(
            client, "test-trace-123", "999", {"step_evaluation": "like"}
        )

        assert response.status_code == 400
        assert "Step 999 not found in trace" in response.json()["detail"]

    def test_update_trace_step_missing_request_body(self, client, mock_agent_service):
        """An empty JSON body fails request validation."""
        response = self._patch_step(client, "test-trace-123", "1", {})

        assert response.status_code == 422

    def test_update_trace_step_with_special_characters(
        self, client, mock_agent_service
    ):
        """A long, dash-separated trace id is routed and forwarded unchanged."""
        long_trace_id = "trace-01K960P4FA2BVC058EZDXQEB5E-Qwen-Qwen3-VL-30B-A3B-Instruct"
        mock_agent_service.update_trace_step.return_value = None

        response = self._patch_step(
            client, long_trace_id, "1", {"step_evaluation": "like"}
        )

        assert response.status_code == 200
        mock_agent_service.update_trace_step.assert_called_once_with(
            trace_id=long_trace_id, step_id="1", step_evaluation="like"
        )

    def test_update_trace_step_response_structure(self, client, mock_agent_service):
        """The body validates against the UpdateStepResponse schema."""
        mock_agent_service.update_trace_step.return_value = None

        response = self._patch_step(
            client, "test-trace-123", "1", {"step_evaluation": "like"}
        )

        assert response.status_code == 200

        # Round-trip through the Pydantic model
        parsed = UpdateStepResponse(**response.json())
        assert parsed.success is True
        assert parsed.message == "Step updated successfully"
285
+
286
+
287
class TestRoutesIntegration:
    """Smoke tests checking that the routes are wired into the app."""

    def test_models_endpoint_available(self, client):
        """GET /models answers with HTTP 200."""
        assert client.get("/models").status_code == 200

    def test_update_step_endpoint_available(self, client, mock_agent_service):
        """PATCH on a step answers with HTTP 200 when the service succeeds."""
        mock_agent_service.update_trace_step.return_value = None

        payload = {"step_evaluation": "like"}
        response = client.patch("/traces/test/steps/1", json=payload)
        assert response.status_code == 200

    def test_invalid_route(self, client):
        """An unknown path answers with HTTP 404."""
        assert client.get("/invalid-route").status_code == 404
308
+
309
+
310
+ if __name__ == "__main__":
311
+ pytest.main([__file__, "-v"])
cua2-front/src/components/mock/TaskButton.tsx CHANGED
@@ -12,8 +12,8 @@ export const TaskButton: React.FC<TaskButtonProps> = ({ isAgentProcessing, isCon
12
  onClick={() => {
13
  if (!isAgentProcessing && isConnected) {
14
  onSendTask(
15
- "Complete the online form by clicking through the required fields",
16
- "anthropic/claude-sonnet-4-5-20250929"
17
  );
18
  }
19
  }}
@@ -56,7 +56,7 @@ export const TaskButton: React.FC<TaskButtonProps> = ({ isAgentProcessing, isCon
56
  )}
57
  </div>
58
  <p style={{ fontSize: '15px', fontWeight: 500, color: '#1f2937' }}>
59
- Complete the online form by clicking through the required fields
60
  </p>
61
  </div>
62
  <div style={{
@@ -67,7 +67,7 @@ export const TaskButton: React.FC<TaskButtonProps> = ({ isAgentProcessing, isCon
67
  }}>
68
  <span style={{ fontSize: '11px', fontWeight: 600, color: 'rgba(0, 0, 0, 0.6)', textTransform: 'uppercase', letterSpacing: '1px' }}>Model</span>
69
  <p style={{ fontSize: '12px', fontWeight: 600, color: '#1f2937', marginTop: '2px', whiteSpace: 'nowrap' }}>
70
- claude-sonnet-4-5-20250929
71
  </p>
72
  </div>
73
  </div>
 
12
  onClick={() => {
13
  if (!isAgentProcessing && isConnected) {
14
  onSendTask(
15
+ "Find the price of a NVIDIA RTX 4090 GPU",
16
+ "Qwen/Qwen3-VL-30B-A3B-Instruct"
17
  );
18
  }
19
  }}
 
56
  )}
57
  </div>
58
  <p style={{ fontSize: '15px', fontWeight: 500, color: '#1f2937' }}>
59
+ Find the price of a NVIDIA RTX 4090 GPU
60
  </p>
61
  </div>
62
  <div style={{
 
67
  }}>
68
  <span style={{ fontSize: '11px', fontWeight: 600, color: 'rgba(0, 0, 0, 0.6)', textTransform: 'uppercase', letterSpacing: '1px' }}>Model</span>
69
  <p style={{ fontSize: '12px', fontWeight: 600, color: '#1f2937', marginTop: '2px', whiteSpace: 'nowrap' }}>
70
+ Qwen/Qwen3-VL-30B-A3B-Instruct
71
  </p>
72
  </div>
73
  </div>
cua2-front/src/pages/Index.tsx CHANGED
@@ -1,16 +1,14 @@
1
- import React from 'react';
2
  import { useWebSocket } from '@/hooks/useWebSocket';
3
- import { WebSocketEvent } from '@/types/agent';
4
  import { useState } from 'react';
5
- import { AgentTrace, AgentStep } from '@/types/agent';
6
  import { ulid } from 'ulid';
7
- import { Header, VNCStream, Metadata, StackSteps } from '@/components/mock';
8
 
9
  const Index = () => {
10
  const [trace, setTrace] = useState<AgentTrace>();
11
  const [isAgentProcessing, setIsAgentProcessing] = useState(false);
12
  const [vncUrl, setVncUrl] = useState<string>('');
13
- const [selectedModelId, setSelectedModelId] = useState<string>("claude-sonnet-4-5-20250929");
14
 
15
  // #################### WebSocket Connection ########################
16
 
@@ -51,12 +49,12 @@ const Index = () => {
51
  setIsAgentProcessing(false);
52
  setTrace(trace => {
53
  return trace.id === event.traceMetadata.traceId
54
- ? {
55
- ...trace,
56
- isRunning: false,
57
- metadata: event.traceMetadata,
58
- }
59
- : trace;
60
  });
61
  console.log('Agent complete received:', event.traceMetadata);
62
  break;
 
1
+ import { Header, Metadata, StackSteps, VNCStream } from '@/components/mock';
2
  import { useWebSocket } from '@/hooks/useWebSocket';
3
+ import { AgentStep, AgentTrace, WebSocketEvent } from '@/types/agent';
4
  import { useState } from 'react';
 
5
  import { ulid } from 'ulid';
 
6
 
7
  const Index = () => {
8
  const [trace, setTrace] = useState<AgentTrace>();
9
  const [isAgentProcessing, setIsAgentProcessing] = useState(false);
10
  const [vncUrl, setVncUrl] = useState<string>('');
11
+ const [selectedModelId, setSelectedModelId] = useState<string>("Qwen/Qwen3-VL-30B-A3B-Instruct");
12
 
13
  // #################### WebSocket Connection ########################
14
 
 
49
  setIsAgentProcessing(false);
50
  setTrace(trace => {
51
  return trace.id === event.traceMetadata.traceId
52
+ ? {
53
+ ...trace,
54
+ isRunning: false,
55
+ metadata: event.traceMetadata,
56
+ }
57
+ : trace;
58
  });
59
  console.log('Agent complete received:', event.traceMetadata);
60
  break;
cua2-front/src/types/agent.ts CHANGED
@@ -82,3 +82,18 @@ export interface UserTaskMessage {
82
  type: 'user_task';
83
  trace: AgentTrace;
84
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  type: 'user_task';
83
  trace: AgentTrace;
84
  }
85
+
86
+ // #################### API Routes Types ########################
87
+
88
/** Response body listing the model ids the backend offers (presumably served by GET /models — confirm against the routes). */
export interface AvailableModelsResponse {
  // Model identifiers, e.g. "Qwen/Qwen3-VL-30B-A3B-Instruct".
  models: string[];
}
91
+
92
/** Request body carrying a user's evaluation of a single agent step. */
export interface UpdateStepRequest {
  // Snake-case key; the backend tests send the body under this exact name.
  step_evaluation: 'like' | 'dislike' | 'neutral';
}
95
+
96
/** Response body returned after a step-evaluation update. */
export interface UpdateStepResponse {
  // Whether the update was applied.
  success: boolean;
  // Human-readable status text.
  message: string;
}