Spaces:
Running
Running
A-Mahla
commited on
Amir/handle sandbox (#18)
Browse files* ADD sandbox gestion
* ADD better sandbox gestion
* ADD better sandbox gestion
* ADD better sandbox gestion
* ADD better sandbox gestion
* ADD better sandbox gestion
- README.md +26 -0
- cua2-core/pyproject.toml +1 -0
- cua2-core/pytest.ini +1 -0
- cua2-core/src/cua2_core/models/models.py +13 -2
- cua2-core/src/cua2_core/routes/websocket.py +9 -2
- cua2-core/src/cua2_core/services/agent_service.py +136 -37
- cua2-core/src/cua2_core/services/archival_service.py +423 -0
- cua2-core/src/cua2_core/services/sandbox_service.py +133 -43
- cua2-core/tests/test_archival_service.py +585 -0
- cua2-front/src/App.tsx +2 -1
- cua2-front/src/components/sandbox/completionview/CompletionView.tsx +12 -12
- cua2-front/src/hooks/useAgentWebSocket.ts +27 -9
- cua2-front/src/pages/Index.tsx +0 -147
- cua2-front/src/pages/Task.tsx +24 -13
- cua2-front/src/stores/agentStore.ts +65 -56
- cua2-front/src/types/agent.ts +1 -0
README.md
CHANGED
|
@@ -93,8 +93,34 @@ cp env.example .env
|
|
| 93 |
Edit `.env` with your configuration:
|
| 94 |
- API keys for OpenAI/LiteLLM
|
| 95 |
- Database connections (if applicable)
|
|
|
|
| 96 |
- Other service credentials
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
### 4. Start Development Servers
|
| 99 |
|
| 100 |
#### Option 1: Using Makefile (Recommended)
|
|
|
|
| 93 |
Edit `.env` with your configuration:
|
| 94 |
- API keys for OpenAI/LiteLLM
|
| 95 |
- Database connections (if applicable)
|
| 96 |
+
- HuggingFace credentials for data archival (optional)
|
| 97 |
- Other service credentials
|
| 98 |
|
| 99 |
+
#### Data Archival Configuration (Optional)
|
| 100 |
+
|
| 101 |
+
CUA2 includes an automatic data archival feature that backs up old trace data to HuggingFace datasets:
|
| 102 |
+
|
| 103 |
+
```bash
|
| 104 |
+
# HuggingFace token for uploading archived data
|
| 105 |
+
HF_TOKEN=your_huggingface_token_here
|
| 106 |
+
|
| 107 |
+
# HuggingFace dataset repository ID (e.g., "username/dataset-name")
|
| 108 |
+
HF_DATASET_REPO=your_username/your_dataset_repo
|
| 109 |
+
|
| 110 |
+
# Check interval (default: 30 minutes)
|
| 111 |
+
ARCHIVE_INTERVAL_MINUTES=30
|
| 112 |
+
|
| 113 |
+
# Age threshold - folders older than this will be archived (default: 30 minutes)
|
| 114 |
+
FOLDER_AGE_THRESHOLD_MINUTES=30
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
**How it works:**
|
| 118 |
+
1. Every 30 minutes (configurable), the system checks the `data/` folder for trace folders
|
| 119 |
+
2. Folders older than 30 minutes (configurable) are compressed into `.tar.gz` archives
|
| 120 |
+
3. Archives are uploaded to your HuggingFace dataset repository
|
| 121 |
+
4. After verifying successful upload, local folders are deleted to free up space
|
| 122 |
+
5. This keeps your disk usage minimal while preserving all agent traces in the cloud
|
| 123 |
+
|
| 124 |
### 4. Start Development Servers
|
| 125 |
|
| 126 |
#### Option 1: Using Makefile (Recommended)
|
cua2-core/pyproject.toml
CHANGED
|
@@ -36,6 +36,7 @@ dependencies = [
|
|
| 36 |
"smolagents[openai,litellm]==1.22.0",
|
| 37 |
"openai==2.6.1",
|
| 38 |
"e2b-desktop==2.1.0",
|
|
|
|
| 39 |
]
|
| 40 |
|
| 41 |
[project.optional-dependencies]
|
|
|
|
| 36 |
"smolagents[openai,litellm]==1.22.0",
|
| 37 |
"openai==2.6.1",
|
| 38 |
"e2b-desktop==2.1.0",
|
| 39 |
+
"huggingface_hub==1.1.2",
|
| 40 |
]
|
| 41 |
|
| 42 |
[project.optional-dependencies]
|
cua2-core/pytest.ini
CHANGED
|
@@ -11,3 +11,4 @@ addopts =
|
|
| 11 |
markers =
|
| 12 |
unit: Unit tests
|
| 13 |
integration: Integration tests
|
|
|
|
|
|
| 11 |
markers =
|
| 12 |
unit: Unit tests
|
| 13 |
integration: Integration tests
|
| 14 |
+
slow: Slow tests that take more time to execute
|
cua2-core/src/cua2_core/models/models.py
CHANGED
|
@@ -3,6 +3,7 @@ import os
|
|
| 3 |
import threading
|
| 4 |
from datetime import datetime
|
| 5 |
from typing import Annotated, Literal, Optional
|
|
|
|
| 6 |
|
| 7 |
from cua2_core.services.agent_utils.function_parser import FunctionCall
|
| 8 |
from pydantic import BaseModel, Field, PrivateAttr, field_serializer, model_validator
|
|
@@ -106,6 +107,15 @@ class AgentStep(BaseModel):
|
|
| 106 |
thought: Optional[str] = None
|
| 107 |
actions: list[AgentAction] = []
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
@field_serializer("actions")
|
| 110 |
def serialize_actions(self, actions: list[AgentAction], _info):
|
| 111 |
"""Convert actions to list of strings when dumping (controlled by context)"""
|
|
@@ -206,6 +216,7 @@ class HeartbeatEvent(BaseModel):
|
|
| 206 |
"""Heartbeat event"""
|
| 207 |
|
| 208 |
type: Literal["heartbeat"] = "heartbeat"
|
|
|
|
| 209 |
|
| 210 |
|
| 211 |
WebSocketEvent: TypeAlias = Annotated[
|
|
@@ -266,7 +277,7 @@ class ActiveTask(BaseModel):
|
|
| 266 |
self.model_dump(
|
| 267 |
mode="json",
|
| 268 |
exclude={"_file_locks"},
|
| 269 |
-
context={"actions_as_json": True},
|
| 270 |
),
|
| 271 |
f,
|
| 272 |
indent=2,
|
|
@@ -286,7 +297,7 @@ class ActiveTask(BaseModel):
|
|
| 286 |
self.model_dump(
|
| 287 |
mode="json",
|
| 288 |
exclude={"_file_locks"},
|
| 289 |
-
context={"actions_as_json": True},
|
| 290 |
),
|
| 291 |
f,
|
| 292 |
indent=2,
|
|
|
|
| 3 |
import threading
|
| 4 |
from datetime import datetime
|
| 5 |
from typing import Annotated, Literal, Optional
|
| 6 |
+
from uuid import uuid4
|
| 7 |
|
| 8 |
from cua2_core.services.agent_utils.function_parser import FunctionCall
|
| 9 |
from pydantic import BaseModel, Field, PrivateAttr, field_serializer, model_validator
|
|
|
|
| 107 |
thought: Optional[str] = None
|
| 108 |
actions: list[AgentAction] = []
|
| 109 |
|
| 110 |
+
@field_serializer("image")
|
| 111 |
+
def serialize_image(self, image: str, _info):
|
| 112 |
+
"""Convert image to path when dumping to JSON"""
|
| 113 |
+
|
| 114 |
+
if _info.context and _info.context.get("image_as_path", True):
|
| 115 |
+
return f"{self.traceId}-{self.stepId}.png"
|
| 116 |
+
|
| 117 |
+
return image
|
| 118 |
+
|
| 119 |
@field_serializer("actions")
|
| 120 |
def serialize_actions(self, actions: list[AgentAction], _info):
|
| 121 |
"""Convert actions to list of strings when dumping (controlled by context)"""
|
|
|
|
| 216 |
"""Heartbeat event"""
|
| 217 |
|
| 218 |
type: Literal["heartbeat"] = "heartbeat"
|
| 219 |
+
uuid: str = Field(default_factory=lambda: str(uuid4()))
|
| 220 |
|
| 221 |
|
| 222 |
WebSocketEvent: TypeAlias = Annotated[
|
|
|
|
| 277 |
self.model_dump(
|
| 278 |
mode="json",
|
| 279 |
exclude={"_file_locks"},
|
| 280 |
+
context={"actions_as_json": True, "image_as_path": True},
|
| 281 |
),
|
| 282 |
f,
|
| 283 |
indent=2,
|
|
|
|
| 297 |
self.model_dump(
|
| 298 |
mode="json",
|
| 299 |
exclude={"_file_locks"},
|
| 300 |
+
context={"actions_as_json": True, "image_as_path": True},
|
| 301 |
),
|
| 302 |
f,
|
| 303 |
indent=2,
|
cua2-core/src/cua2_core/routes/websocket.py
CHANGED
|
@@ -21,8 +21,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 21 |
await websocket_manager.connect(websocket)
|
| 22 |
|
| 23 |
try:
|
| 24 |
-
|
| 25 |
-
|
|
|
|
| 26 |
await websocket_manager.send_message(welcome_message, websocket)
|
| 27 |
|
| 28 |
# Keep the connection alive and wait for messages
|
|
@@ -100,5 +101,11 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 100 |
except Exception as e:
|
| 101 |
print(f"WebSocket connection error: {e}")
|
| 102 |
finally:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
# Ensure cleanup happens
|
| 104 |
websocket_manager.disconnect(websocket)
|
|
|
|
| 21 |
await websocket_manager.connect(websocket)
|
| 22 |
|
| 23 |
try:
|
| 24 |
+
welcome_message = HeartbeatEvent(
|
| 25 |
+
uuid=await agent_service.create_id_and_sandbox(websocket)
|
| 26 |
+
)
|
| 27 |
await websocket_manager.send_message(welcome_message, websocket)
|
| 28 |
|
| 29 |
# Keep the connection alive and wait for messages
|
|
|
|
| 101 |
except Exception as e:
|
| 102 |
print(f"WebSocket connection error: {e}")
|
| 103 |
finally:
|
| 104 |
+
# Cleanup tasks and sandboxes associated with this websocket
|
| 105 |
+
try:
|
| 106 |
+
await agent_service.cleanup_tasks_for_websocket(websocket)
|
| 107 |
+
except Exception as e:
|
| 108 |
+
print(f"Error cleaning up tasks for websocket: {e}")
|
| 109 |
+
|
| 110 |
# Ensure cleanup happens
|
| 111 |
websocket_manager.disconnect(websocket)
|
cua2-core/src/cua2_core/services/agent_service.py
CHANGED
|
@@ -6,6 +6,7 @@ import os
|
|
| 6 |
import time
|
| 7 |
from io import BytesIO
|
| 8 |
from typing import Callable, Literal
|
|
|
|
| 9 |
|
| 10 |
from cua2_core.models.models import (
|
| 11 |
ActiveTask,
|
|
@@ -17,12 +18,14 @@ from cua2_core.models.models import (
|
|
| 17 |
from cua2_core.services.agent_utils.desktop_agent import E2BVisionAgent
|
| 18 |
from cua2_core.services.agent_utils.function_parser import parse_function_call
|
| 19 |
from cua2_core.services.agent_utils.get_model import get_model
|
|
|
|
| 20 |
from cua2_core.services.sandbox_service import SandboxService
|
| 21 |
from cua2_core.websocket.websocket_manager import WebSocketException, WebSocketManager
|
| 22 |
from e2b_desktop import Sandbox, TimeoutException
|
| 23 |
from fastapi import WebSocket
|
| 24 |
from PIL import Image
|
| 25 |
from smolagents import ActionStep, AgentImage, AgentMaxStepsError, TaskStep
|
|
|
|
| 26 |
|
| 27 |
logger = logging.getLogger(__name__)
|
| 28 |
|
|
@@ -50,6 +53,35 @@ class AgentService:
|
|
| 50 |
self._lock = asyncio.Lock()
|
| 51 |
self.max_sandboxes = int(600 / num_workers)
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
async def process_user_task(
|
| 54 |
self, trace: AgentTrace, websocket: WebSocket
|
| 55 |
) -> str | None:
|
|
@@ -60,6 +92,9 @@ class AgentService:
|
|
| 60 |
trace.traceMetadata = AgentTraceMetadata(traceId=trace_id)
|
| 61 |
|
| 62 |
async with self._lock:
|
|
|
|
|
|
|
|
|
|
| 63 |
active_task = ActiveTask(
|
| 64 |
message_id=trace_id,
|
| 65 |
instruction=trace.instruction,
|
|
@@ -79,9 +114,11 @@ class AgentService:
|
|
| 79 |
|
| 80 |
# Store the task and websocket for this task
|
| 81 |
self.active_tasks[trace_id] = active_task
|
| 82 |
-
self.task_websockets[trace_id] = websocket
|
| 83 |
self.last_screenshot[trace_id] = None
|
| 84 |
|
|
|
|
|
|
|
|
|
|
| 85 |
asyncio.create_task(self._agent_processing(trace_id))
|
| 86 |
|
| 87 |
return trace_id
|
|
@@ -111,8 +148,15 @@ class AgentService:
|
|
| 111 |
|
| 112 |
model = get_model(self.active_tasks[message_id].model_id)
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
if sandbox is None:
|
| 117 |
raise Exception("No sandbox available: pool limit reached")
|
| 118 |
|
|
@@ -180,7 +224,12 @@ class AgentService:
|
|
| 180 |
|
| 181 |
finally:
|
| 182 |
# Send completion event
|
| 183 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
await self.websocket_manager.send_agent_complete(
|
| 185 |
metadata=self.active_tasks[message_id].traceMetadata,
|
| 186 |
websocket=websocket,
|
|
@@ -210,9 +259,12 @@ class AgentService:
|
|
| 210 |
if message_id in self.last_screenshot:
|
| 211 |
del self.last_screenshot[message_id]
|
| 212 |
|
|
|
|
|
|
|
|
|
|
| 213 |
# Release sandbox back to the pool
|
| 214 |
if sandbox:
|
| 215 |
-
await self.sandbox_service.release_sandbox(
|
| 216 |
|
| 217 |
async def _agent_processing(
|
| 218 |
self,
|
|
@@ -236,24 +288,8 @@ class AgentService:
|
|
| 236 |
if memory_step.step_number > agent.max_steps:
|
| 237 |
raise AgentStopException("Max steps reached")
|
| 238 |
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
image = self.last_screenshot[message_id]
|
| 242 |
-
assert image is not None
|
| 243 |
-
|
| 244 |
-
for previous_memory_step in (
|
| 245 |
-
agent.memory.steps
|
| 246 |
-
): # Remove previous screenshots from logs for lean processing
|
| 247 |
-
if (
|
| 248 |
-
isinstance(previous_memory_step, ActionStep)
|
| 249 |
-
and previous_memory_step.step_number is not None
|
| 250 |
-
and previous_memory_step.step_number <= memory_step.step_number - 1
|
| 251 |
-
):
|
| 252 |
-
previous_memory_step.observations_images = None
|
| 253 |
-
elif isinstance(previous_memory_step, TaskStep):
|
| 254 |
-
previous_memory_step.task_images = None
|
| 255 |
-
|
| 256 |
-
memory_step.observations_images = [image.copy()]
|
| 257 |
|
| 258 |
model_output = (
|
| 259 |
memory_step.model_output_message.content
|
|
@@ -280,6 +316,35 @@ class AgentService:
|
|
| 280 |
"""The task failed due to an error""" # TODO: To Handle in front
|
| 281 |
)
|
| 282 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
if memory_step.observations_images:
|
| 284 |
image = memory_step.observations_images[0]
|
| 285 |
buffered = BytesIO()
|
|
@@ -295,11 +360,7 @@ class AgentService:
|
|
| 295 |
stepId=str(memory_step.step_number),
|
| 296 |
image=image_base64,
|
| 297 |
thought=thought,
|
| 298 |
-
actions=
|
| 299 |
-
parse_function_call(action_sequence)
|
| 300 |
-
)
|
| 301 |
-
if action_sequence
|
| 302 |
-
else None,
|
| 303 |
error=memory_step.error.message if memory_step.error else None,
|
| 304 |
duration=memory_step.timing.duration,
|
| 305 |
inputTokensUsed=memory_step.token_usage.input_tokens,
|
|
@@ -317,15 +378,16 @@ class AgentService:
|
|
| 317 |
self.active_tasks[message_id].update_step(step)
|
| 318 |
|
| 319 |
websocket = self.task_websockets.get(message_id)
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
|
|
|
| 329 |
|
| 330 |
if self.active_tasks[message_id].traceMetadata.completed:
|
| 331 |
raise AgentStopException("Task not completed")
|
|
@@ -419,3 +481,40 @@ class AgentService:
|
|
| 419 |
self.active_tasks[trace_id].update_trace_metadata(
|
| 420 |
completed=True,
|
| 421 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
import time
|
| 7 |
from io import BytesIO
|
| 8 |
from typing import Callable, Literal
|
| 9 |
+
from uuid import uuid4
|
| 10 |
|
| 11 |
from cua2_core.models.models import (
|
| 12 |
ActiveTask,
|
|
|
|
| 18 |
from cua2_core.services.agent_utils.desktop_agent import E2BVisionAgent
|
| 19 |
from cua2_core.services.agent_utils.function_parser import parse_function_call
|
| 20 |
from cua2_core.services.agent_utils.get_model import get_model
|
| 21 |
+
from cua2_core.services.archival_service import ArchivalService
|
| 22 |
from cua2_core.services.sandbox_service import SandboxService
|
| 23 |
from cua2_core.websocket.websocket_manager import WebSocketException, WebSocketManager
|
| 24 |
from e2b_desktop import Sandbox, TimeoutException
|
| 25 |
from fastapi import WebSocket
|
| 26 |
from PIL import Image
|
| 27 |
from smolagents import ActionStep, AgentImage, AgentMaxStepsError, TaskStep
|
| 28 |
+
from starlette.websockets import WebSocketState
|
| 29 |
|
| 30 |
logger = logging.getLogger(__name__)
|
| 31 |
|
|
|
|
| 53 |
self._lock = asyncio.Lock()
|
| 54 |
self.max_sandboxes = int(600 / num_workers)
|
| 55 |
|
| 56 |
+
# Initialize archival service in dedicated process
|
| 57 |
+
self.archival_service = ArchivalService(
|
| 58 |
+
hf_token=os.getenv("HF_TOKEN"),
|
| 59 |
+
hf_dataset_repo="smolagents/cua_traces",
|
| 60 |
+
data_dir="data",
|
| 61 |
+
archive_interval_minutes=30,
|
| 62 |
+
folder_age_threshold_minutes=30,
|
| 63 |
+
)
|
| 64 |
+
# Start the archival service process
|
| 65 |
+
self.archival_service.start()
|
| 66 |
+
|
| 67 |
+
def _update_archival_active_tasks(self):
|
| 68 |
+
"""
|
| 69 |
+
Update the archival service with current active task IDs.
|
| 70 |
+
Should be called whenever tasks are added or removed.
|
| 71 |
+
"""
|
| 72 |
+
if self.archival_service.is_alive():
|
| 73 |
+
self.archival_service.update_active_tasks(set(self.active_tasks.keys()))
|
| 74 |
+
|
| 75 |
+
async def create_id_and_sandbox(self, websocket: WebSocket) -> str:
|
| 76 |
+
"""Create a new ID and sandbox"""
|
| 77 |
+
async with self._lock:
|
| 78 |
+
uuid = str(uuid4())
|
| 79 |
+
while uuid in self.active_tasks:
|
| 80 |
+
uuid = str(uuid4())
|
| 81 |
+
self.task_websockets[uuid] = websocket
|
| 82 |
+
await self.sandbox_service.acquire_sandbox(uuid)
|
| 83 |
+
return uuid
|
| 84 |
+
|
| 85 |
async def process_user_task(
|
| 86 |
self, trace: AgentTrace, websocket: WebSocket
|
| 87 |
) -> str | None:
|
|
|
|
| 92 |
trace.traceMetadata = AgentTraceMetadata(traceId=trace_id)
|
| 93 |
|
| 94 |
async with self._lock:
|
| 95 |
+
if self.task_websockets[trace_id] != websocket:
|
| 96 |
+
raise WebSocketException("WebSocket mismatch")
|
| 97 |
+
|
| 98 |
active_task = ActiveTask(
|
| 99 |
message_id=trace_id,
|
| 100 |
instruction=trace.instruction,
|
|
|
|
| 114 |
|
| 115 |
# Store the task and websocket for this task
|
| 116 |
self.active_tasks[trace_id] = active_task
|
|
|
|
| 117 |
self.last_screenshot[trace_id] = None
|
| 118 |
|
| 119 |
+
# Update archival service with new active task
|
| 120 |
+
self._update_archival_active_tasks()
|
| 121 |
+
|
| 122 |
asyncio.create_task(self._agent_processing(trace_id))
|
| 123 |
|
| 124 |
return trace_id
|
|
|
|
| 148 |
|
| 149 |
model = get_model(self.active_tasks[message_id].model_id)
|
| 150 |
|
| 151 |
+
max_attempts = 10
|
| 152 |
+
for _ in range(max_attempts):
|
| 153 |
+
response = await self.sandbox_service.acquire_sandbox(message_id)
|
| 154 |
+
if response.sandbox is not None and response.state == "ready":
|
| 155 |
+
sandbox = response.sandbox
|
| 156 |
+
break
|
| 157 |
+
elif response.state == "max_sandboxes_reached":
|
| 158 |
+
raise Exception("No sandbox available: pool limit reached")
|
| 159 |
+
await asyncio.sleep(2)
|
| 160 |
if sandbox is None:
|
| 161 |
raise Exception("No sandbox available: pool limit reached")
|
| 162 |
|
|
|
|
| 224 |
|
| 225 |
finally:
|
| 226 |
# Send completion event
|
| 227 |
+
# Check if websocket is still connected before sending
|
| 228 |
+
if (
|
| 229 |
+
not websocket_exception
|
| 230 |
+
and websocket
|
| 231 |
+
and websocket.client_state == WebSocketState.CONNECTED
|
| 232 |
+
):
|
| 233 |
await self.websocket_manager.send_agent_complete(
|
| 234 |
metadata=self.active_tasks[message_id].traceMetadata,
|
| 235 |
websocket=websocket,
|
|
|
|
| 259 |
if message_id in self.last_screenshot:
|
| 260 |
del self.last_screenshot[message_id]
|
| 261 |
|
| 262 |
+
# Update archival service after task removal
|
| 263 |
+
self._update_archival_active_tasks()
|
| 264 |
+
|
| 265 |
# Release sandbox back to the pool
|
| 266 |
if sandbox:
|
| 267 |
+
await self.sandbox_service.release_sandbox(message_id)
|
| 268 |
|
| 269 |
async def _agent_processing(
|
| 270 |
self,
|
|
|
|
| 288 |
if memory_step.step_number > agent.max_steps:
|
| 289 |
raise AgentStopException("Max steps reached")
|
| 290 |
|
| 291 |
+
if self.active_tasks[message_id].traceMetadata.completed:
|
| 292 |
+
raise AgentStopException("Task not completed")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
model_output = (
|
| 295 |
memory_step.model_output_message.content
|
|
|
|
| 316 |
"""The task failed due to an error""" # TODO: To Handle in front
|
| 317 |
)
|
| 318 |
|
| 319 |
+
agent_actions = (
|
| 320 |
+
AgentAction.from_function_calls(parse_function_call(action_sequence))
|
| 321 |
+
if action_sequence
|
| 322 |
+
else None
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
if not (
|
| 326 |
+
agent_actions is not None
|
| 327 |
+
and any(action.function_name == "wait" for action in agent_actions)
|
| 328 |
+
):
|
| 329 |
+
time.sleep(3)
|
| 330 |
+
|
| 331 |
+
image = self.last_screenshot[message_id]
|
| 332 |
+
assert image is not None
|
| 333 |
+
|
| 334 |
+
for previous_memory_step in (
|
| 335 |
+
agent.memory.steps
|
| 336 |
+
): # Remove previous screenshots from logs for lean processing
|
| 337 |
+
if (
|
| 338 |
+
isinstance(previous_memory_step, ActionStep)
|
| 339 |
+
and previous_memory_step.step_number is not None
|
| 340 |
+
and previous_memory_step.step_number <= memory_step.step_number - 1
|
| 341 |
+
):
|
| 342 |
+
previous_memory_step.observations_images = None
|
| 343 |
+
elif isinstance(previous_memory_step, TaskStep):
|
| 344 |
+
previous_memory_step.task_images = None
|
| 345 |
+
|
| 346 |
+
memory_step.observations_images = [image.copy()]
|
| 347 |
+
|
| 348 |
if memory_step.observations_images:
|
| 349 |
image = memory_step.observations_images[0]
|
| 350 |
buffered = BytesIO()
|
|
|
|
| 360 |
stepId=str(memory_step.step_number),
|
| 361 |
image=image_base64,
|
| 362 |
thought=thought,
|
| 363 |
+
actions=agent_actions,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
error=memory_step.error.message if memory_step.error else None,
|
| 365 |
duration=memory_step.timing.duration,
|
| 366 |
inputTokensUsed=memory_step.token_usage.input_tokens,
|
|
|
|
| 378 |
self.active_tasks[message_id].update_step(step)
|
| 379 |
|
| 380 |
websocket = self.task_websockets.get(message_id)
|
| 381 |
+
if websocket and websocket.client_state == WebSocketState.CONNECTED:
|
| 382 |
+
future = asyncio.run_coroutine_threadsafe(
|
| 383 |
+
self.websocket_manager.send_agent_progress(
|
| 384 |
+
step=step,
|
| 385 |
+
metadata=self.active_tasks[message_id].traceMetadata,
|
| 386 |
+
websocket=websocket,
|
| 387 |
+
),
|
| 388 |
+
loop,
|
| 389 |
+
)
|
| 390 |
+
future.result()
|
| 391 |
|
| 392 |
if self.active_tasks[message_id].traceMetadata.completed:
|
| 393 |
raise AgentStopException("Task not completed")
|
|
|
|
| 481 |
self.active_tasks[trace_id].update_trace_metadata(
|
| 482 |
completed=True,
|
| 483 |
)
|
| 484 |
+
|
| 485 |
+
async def cleanup_tasks_for_websocket(self, websocket: WebSocket):
|
| 486 |
+
"""
|
| 487 |
+
Clean up all tasks associated with a disconnected websocket.
|
| 488 |
+
This will stop the tasks and release their sandboxes.
|
| 489 |
+
"""
|
| 490 |
+
tasks_to_cleanup = []
|
| 491 |
+
|
| 492 |
+
# Find all message_ids associated with this websocket
|
| 493 |
+
async with self._lock:
|
| 494 |
+
for message_id, ws in list(self.task_websockets.items()):
|
| 495 |
+
if ws == websocket:
|
| 496 |
+
tasks_to_cleanup.append(message_id)
|
| 497 |
+
logger.info(
|
| 498 |
+
f"Marking task {message_id} for cleanup due to websocket disconnect"
|
| 499 |
+
)
|
| 500 |
+
|
| 501 |
+
# Cleanup each task
|
| 502 |
+
for message_id in tasks_to_cleanup:
|
| 503 |
+
try:
|
| 504 |
+
# Mark task as completed to stop the agent
|
| 505 |
+
if message_id in self.active_tasks:
|
| 506 |
+
self.active_tasks[message_id].update_trace_metadata(
|
| 507 |
+
completed=True,
|
| 508 |
+
)
|
| 509 |
+
logger.info(
|
| 510 |
+
f"Stopped task {message_id} due to websocket disconnect"
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
# Release the sandbox immediately
|
| 514 |
+
await self.sandbox_service.release_sandbox(message_id)
|
| 515 |
+
logger.info(
|
| 516 |
+
f"Released sandbox for task {message_id} due to websocket disconnect"
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
except Exception as e:
|
| 520 |
+
logger.error(f"Error cleaning up task {message_id}: {e}", exc_info=True)
|
cua2-core/src/cua2_core/services/archival_service.py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Service for automatic data archival to HuggingFace datasets.
|
| 3 |
+
|
| 4 |
+
This service runs in a dedicated process that periodically:
|
| 5 |
+
1. Scans for old trace data folders
|
| 6 |
+
2. Compresses them into tar.gz archives
|
| 7 |
+
3. Uploads to HuggingFace dataset repository
|
| 8 |
+
4. Verifies successful upload
|
| 9 |
+
5. Deletes local files only after verification
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
import multiprocessing
|
| 14 |
+
import multiprocessing.synchronize
|
| 15 |
+
import os
|
| 16 |
+
import shutil
|
| 17 |
+
import signal
|
| 18 |
+
import tarfile
|
| 19 |
+
import time
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Any
|
| 22 |
+
|
| 23 |
+
from huggingface_hub import HfApi, hf_hub_download
|
| 24 |
+
from huggingface_hub.utils import HfHubHTTPError
|
| 25 |
+
|
| 26 |
+
# Configure logging for the process
|
| 27 |
+
logging.basicConfig(
|
| 28 |
+
level=logging.INFO,
|
| 29 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 30 |
+
)
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ArchivalService:
|
| 35 |
+
"""Service for handling automatic data archival to HuggingFace in a dedicated process"""
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
hf_token: str | None = os.getenv("HF_TOKEN"),
|
| 40 |
+
hf_dataset_repo: str | None = "smolagents/cua_traces",
|
| 41 |
+
data_dir: str = "data",
|
| 42 |
+
archive_interval_minutes: int = 30,
|
| 43 |
+
folder_age_threshold_minutes: int = 30,
|
| 44 |
+
):
|
| 45 |
+
"""
|
| 46 |
+
Initialize the archival service.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
hf_token: HuggingFace API token
|
| 50 |
+
hf_dataset_repo: HuggingFace dataset repository ID (e.g., "username/dataset-name")
|
| 51 |
+
data_dir: Directory containing trace data folders
|
| 52 |
+
archive_interval_minutes: How often to check for old folders
|
| 53 |
+
folder_age_threshold_minutes: Minimum age before archival
|
| 54 |
+
"""
|
| 55 |
+
self.hf_token = hf_token
|
| 56 |
+
self.hf_dataset_repo = hf_dataset_repo
|
| 57 |
+
self.data_dir = data_dir
|
| 58 |
+
self.archive_interval_minutes = archive_interval_minutes
|
| 59 |
+
self.folder_age_threshold_minutes = folder_age_threshold_minutes
|
| 60 |
+
|
| 61 |
+
# Multiprocessing components
|
| 62 |
+
self._process: multiprocessing.Process | None = None
|
| 63 |
+
self._stop_event: multiprocessing.synchronize.Event = multiprocessing.Event()
|
| 64 |
+
self._manager = multiprocessing.Manager()
|
| 65 |
+
self._active_tasks: Any = self._manager.dict() # DictProxy type
|
| 66 |
+
|
| 67 |
+
def start(self):
|
| 68 |
+
"""Start the archival service in a dedicated process."""
|
| 69 |
+
if self._process and self._process.is_alive():
|
| 70 |
+
logger.warning("Archival service is already running")
|
| 71 |
+
return
|
| 72 |
+
|
| 73 |
+
if not self.hf_token or not self.hf_dataset_repo:
|
| 74 |
+
logger.warning(
|
| 75 |
+
"HuggingFace credentials or dataset repo not configured. Data archival disabled."
|
| 76 |
+
)
|
| 77 |
+
return
|
| 78 |
+
|
| 79 |
+
self._stop_event.clear()
|
| 80 |
+
self._process = multiprocessing.Process(
|
| 81 |
+
target=_archival_worker_process,
|
| 82 |
+
args=(
|
| 83 |
+
self.hf_token,
|
| 84 |
+
self.hf_dataset_repo,
|
| 85 |
+
self.data_dir,
|
| 86 |
+
self.archive_interval_minutes,
|
| 87 |
+
self.folder_age_threshold_minutes,
|
| 88 |
+
self._stop_event,
|
| 89 |
+
self._active_tasks,
|
| 90 |
+
),
|
| 91 |
+
daemon=True,
|
| 92 |
+
name="ArchivalWorker",
|
| 93 |
+
)
|
| 94 |
+
self._process.start()
|
| 95 |
+
logger.info(
|
| 96 |
+
f"Started archival service in process {self._process.pid}. "
|
| 97 |
+
f"Checking every {self.archive_interval_minutes} minutes."
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
def stop(self, timeout: float = 10.0):
|
| 101 |
+
"""
|
| 102 |
+
Stop the archival service process.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
timeout: Maximum time to wait for process to terminate (seconds)
|
| 106 |
+
"""
|
| 107 |
+
if not self._process or not self._process.is_alive():
|
| 108 |
+
return
|
| 109 |
+
|
| 110 |
+
logger.info(f"Stopping archival service (PID: {self._process.pid})...")
|
| 111 |
+
self._stop_event.set()
|
| 112 |
+
|
| 113 |
+
self._process.join(timeout=timeout)
|
| 114 |
+
|
| 115 |
+
if self._process.is_alive():
|
| 116 |
+
logger.warning("Archival process did not stop gracefully, terminating...")
|
| 117 |
+
self._process.terminate()
|
| 118 |
+
self._process.join(timeout=2.0)
|
| 119 |
+
|
| 120 |
+
if self._process.is_alive():
|
| 121 |
+
logger.error("Force killing archival process...")
|
| 122 |
+
self._process.kill()
|
| 123 |
+
self._process.join()
|
| 124 |
+
|
| 125 |
+
logger.info("Archival service stopped")
|
| 126 |
+
self._process = None
|
| 127 |
+
|
| 128 |
+
def update_active_tasks(self, active_task_ids: set[str]):
|
| 129 |
+
"""
|
| 130 |
+
Update the set of active task IDs.
|
| 131 |
+
The archival process will skip folders for these tasks.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
active_task_ids: Set of currently active trace IDs
|
| 135 |
+
"""
|
| 136 |
+
# Clear and update the shared dict
|
| 137 |
+
self._active_tasks.clear()
|
| 138 |
+
for task_id in active_task_ids:
|
| 139 |
+
self._active_tasks[task_id] = True
|
| 140 |
+
|
| 141 |
+
def is_alive(self) -> bool:
|
| 142 |
+
"""Check if the archival process is running."""
|
| 143 |
+
return self._process is not None and self._process.is_alive()
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def _archival_worker_process(
    hf_token: str,
    hf_dataset_repo: str,
    data_dir: str,
    archive_interval_minutes: int,
    folder_age_threshold_minutes: int,
    stop_event: multiprocessing.synchronize.Event,
    active_tasks: Any,
):
    """
    Entry point of the archival worker, run in its own process.

    Installs signal handlers, builds an HfApi client local to this
    process, then loops: sleep one interval (interruptibly), archive old
    trace folders, repeat until the stop event is set.

    Args:
        hf_token: HuggingFace API token
        hf_dataset_repo: HuggingFace dataset repository
        data_dir: Data directory path
        archive_interval_minutes: Check interval
        folder_age_threshold_minutes: Folder age threshold
        stop_event: Event to signal process shutdown
        active_tasks: Shared dict of active task IDs
    """

    def _on_signal(signum, frame):
        """Translate SIGTERM/SIGINT into a cooperative stop request."""
        logger.info(f"Received signal {signum}, stopping archival worker...")
        stop_event.set()

    signal.signal(signal.SIGTERM, _on_signal)
    signal.signal(signal.SIGINT, _on_signal)

    # Client must be created here: HfApi is not shared across processes.
    hf_api = HfApi(token=hf_token)

    logger.info(
        f"Archival worker started (PID: {os.getpid()}). "
        f"Checking every {archive_interval_minutes} minutes."
    )

    def _sleep_one_interval() -> bool:
        """Sleep the interval in 1s slices; False means stop was requested."""
        for _ in range(archive_interval_minutes * 60):
            if stop_event.is_set():
                return False
            time.sleep(1)
        return not stop_event.is_set()

    while not stop_event.is_set():
        try:
            if not _sleep_one_interval():
                break

            logger.info("Starting data archival check...")
            _process_old_folders(
                data_dir=data_dir,
                folder_age_threshold_minutes=folder_age_threshold_minutes,
                active_tasks=active_tasks,
                hf_api=hf_api,
                hf_dataset_repo=hf_dataset_repo,
                hf_token=hf_token,
            )

        except Exception as e:
            # Log and keep looping: only the stop event ends the worker.
            logger.error(f"Error in archival worker: {e}", exc_info=True)

    logger.info("Archival worker shutting down gracefully")
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _process_old_folders(
    data_dir: str,
    folder_age_threshold_minutes: int,
    active_tasks: Any,
    hf_api: HfApi,
    hf_dataset_repo: str,
    hf_token: str,
):
    """
    Archive every eligible trace folder under *data_dir*.

    A folder is eligible when it is older than the threshold and its
    trace id is not in *active_tasks*. Eligible folders are compressed,
    uploaded to the HuggingFace dataset repo, verified, and only then
    deleted locally. Runs inside the archival worker process.
    """
    if not os.path.exists(data_dir):
        logger.warning(f"Data directory {data_dir} does not exist")
        return

    now = time.time()
    threshold_seconds = folder_age_threshold_minutes * 60

    try:
        candidates = [
            entry
            for entry in Path(data_dir).iterdir()
            if entry.is_dir() and entry.name.startswith("trace-")
        ]
    except Exception as e:
        logger.error(f"Error listing data directory: {e}", exc_info=True)
        return

    for folder in candidates:
        try:
            folder_name = folder.name
            age_seconds = now - folder.stat().st_mtime

            # Folder names follow trace-{uuid}-{model_name}.
            # NOTE(review): split("-", 2)[1] only yields the first hyphen-
            # separated segment; if trace ids themselves contain hyphens
            # (standard UUIDs do) the active-task match below may never
            # hit — confirm the id format against the folder writer.
            parts = folder_name.split("-", 2)
            if len(parts) < 2:
                logger.warning(f"Unexpected folder name format: {folder_name}")
                continue
            trace_id = parts[1]

            # Never touch folders that a live task is still writing to.
            if trace_id in active_tasks:
                logger.debug(f"Skipping active task folder: {folder_name}")
                continue

            if age_seconds < threshold_seconds:
                logger.debug(
                    f"Folder {folder_name} is not old enough ({age_seconds / 60:.1f} minutes)"
                )
                continue

            logger.info(
                f"Processing old folder: {folder_name} (age: {age_seconds / 60:.1f} minutes)"
            )

            archive_path = _compress_folder(folder)
            if not archive_path:
                logger.error(f"Failed to compress folder: {folder_name}")
                continue

            if not _upload_to_huggingface(hf_api, hf_dataset_repo, archive_path):
                logger.error(f"Failed to upload archive: {archive_path.name}")
                # Drop the half-finished local archive; the folder stays.
                archive_path.unlink(missing_ok=True)
                continue

            if _verify_file_in_repo(hf_dataset_repo, hf_token, archive_path.name):
                logger.info(
                    f"Successfully verified {archive_path.name} in HuggingFace repo"
                )
                # Safe to reclaim local disk: folder first, then archive.
                shutil.rmtree(folder)
                logger.info(f"Deleted local folder: {folder_name}")
                archive_path.unlink(missing_ok=True)
                logger.info(f"Deleted local archive: {archive_path.name}")
            else:
                logger.error(
                    f"Could not verify {archive_path.name} in repo. Keeping local files."
                )
                # Keep both folder and archive so a later pass can retry.

        except Exception as e:
            logger.error(f"Error processing folder {folder.name}: {e}", exc_info=True)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def _compress_folder(folder_path: Path) -> Path | None:
    """
    Create a gzip-compressed tarball of *folder_path* next to it.

    Args:
        folder_path: Directory to archive.

    Returns:
        Path of the resulting ``<folder>.tar.gz``, or None on any failure.
    """
    try:
        archive_path = folder_path.parent / f"{folder_path.name}.tar.gz"
        logger.info(f"Compressing {folder_path.name} to {archive_path.name}")

        with tarfile.open(archive_path, "w:gz") as tar:
            # arcname keeps the folder itself as the single top-level entry.
            tar.add(folder_path, arcname=folder_path.name)

        size_bytes = archive_path.stat().st_size
        logger.info(
            f"Created archive {archive_path.name} ({size_bytes / 1024 / 1024:.2f} MB)"
        )
        return archive_path
    except Exception as e:
        logger.error(f"Error compressing folder {folder_path}: {e}", exc_info=True)
        return None
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def _upload_to_huggingface(
    hf_api: HfApi, hf_dataset_repo: str, archive_path: Path
) -> bool:
    """
    Push *archive_path* into the root of a HuggingFace dataset repo.

    Args:
        hf_api: HuggingFace API client
        hf_dataset_repo: HuggingFace dataset repository ID
        archive_path: Path to the archive file

    Returns:
        True if upload succeeded, False otherwise
    """
    try:
        logger.info(
            f"Uploading {archive_path.name} to HuggingFace repo {hf_dataset_repo}"
        )
        hf_api.upload_file(
            path_or_fileobj=str(archive_path),
            path_in_repo=archive_path.name,
            repo_id=hf_dataset_repo,
            repo_type="dataset",
        )
    except Exception as e:
        logger.error(
            f"Error uploading {archive_path.name} to HuggingFace: {e}", exc_info=True
        )
        return False

    logger.info(f"Successfully uploaded {archive_path.name} to HuggingFace")
    return True
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def _verify_file_in_repo(hf_dataset_repo: str, hf_token: str, filename: str) -> bool:
    """
    Verify that *filename* exists in the HuggingFace dataset repository.

    Note: the existence probe is ``hf_hub_download``, which actually
    downloads the file into the local HF cache; a 404 from the hub means
    the file is missing.

    Args:
        hf_dataset_repo: HuggingFace dataset repository ID
        hf_token: HuggingFace API token
        filename: Name of the file to verify

    Returns:
        True if file exists in repo, False otherwise
    """
    try:
        # BUGFIX: the log f-strings had no placeholders, so every message
        # was the useless literal "Verifying (unknown) ...". Interpolate
        # the actual filename.
        logger.info(f"Verifying {filename} exists in HuggingFace repo")

        # A successful download proves the file exists; a missing file
        # raises HfHubHTTPError with status 404.
        # (local_dir_use_symlinks is deprecated/no-op and was dropped;
        # cache_dir=None is the default and was dropped too.)
        hf_hub_download(
            repo_id=hf_dataset_repo,
            filename=filename,
            repo_type="dataset",
            token=hf_token,
            local_files_only=False,
        )

        logger.info(f"Verified {filename} exists in repo")
        return True

    except HfHubHTTPError as e:
        if e.response.status_code == 404:
            logger.error(f"File {filename} not found in repo (404)")
        else:
            logger.error(f"HTTP error verifying file: {e}")
        return False
    except Exception as e:
        logger.error(f"Error verifying {filename} in repo: {e}", exc_info=True)
        return False
|
cua2-core/src/cua2_core/services/sandbox_service.py
CHANGED
|
@@ -1,20 +1,26 @@
|
|
| 1 |
import asyncio
|
| 2 |
-
import logging
|
| 3 |
import os
|
| 4 |
import time
|
| 5 |
from datetime import datetime
|
| 6 |
-
from typing import Any
|
| 7 |
|
| 8 |
from e2b_desktop import Sandbox
|
| 9 |
-
|
| 10 |
-
logger = logging.getLogger(__name__)
|
| 11 |
|
| 12 |
SANDBOX_METADATA: dict[str, dict[str, Any]] = {}
|
| 13 |
-
SANDBOX_TIMEOUT =
|
|
|
|
| 14 |
WIDTH = 1280
|
| 15 |
HEIGHT = 960
|
| 16 |
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
class SandboxService:
|
| 19 |
def __init__(self, max_sandboxes: int = 50):
|
| 20 |
if not os.getenv("E2B_API_KEY"):
|
|
@@ -24,70 +30,154 @@ class SandboxService:
|
|
| 24 |
self.sandbox_metadata: dict[str, dict[str, Any]] = {}
|
| 25 |
self.sandbox_lock = asyncio.Lock()
|
| 26 |
|
| 27 |
-
async def
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
if (
|
| 32 |
session_hash in self.sandboxes
|
| 33 |
and session_hash in self.sandbox_metadata
|
| 34 |
-
and
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
| 36 |
):
|
| 37 |
print(f"Reusing Sandbox for session {session_hash}")
|
| 38 |
self.sandbox_metadata[session_hash]["last_accessed"] = current_time
|
| 39 |
-
return
|
| 40 |
-
|
| 41 |
-
if session_hash in self.sandboxes:
|
| 42 |
-
try:
|
| 43 |
-
print(f"Closing expired sandbox for session {session_hash}")
|
| 44 |
-
await asyncio.to_thread(self.sandboxes[session_hash].kill)
|
| 45 |
-
except Exception as e:
|
| 46 |
-
print(f"Error closing expired sandbox: {str(e)}")
|
| 47 |
-
elif len(self.sandboxes) >= self.max_sandboxes:
|
| 48 |
-
return None
|
| 49 |
-
|
| 50 |
-
print(f"Creating new sandbox for session {session_hash}")
|
| 51 |
-
|
| 52 |
-
def create_and_setup_sandbox():
|
| 53 |
-
desktop = Sandbox.create(
|
| 54 |
-
api_key=os.getenv("E2B_API_KEY"),
|
| 55 |
-
resolution=(WIDTH, HEIGHT),
|
| 56 |
-
dpi=96,
|
| 57 |
-
timeout=SANDBOX_TIMEOUT,
|
| 58 |
-
template="k0wmnzir0zuzye6dndlw",
|
| 59 |
)
|
| 60 |
-
desktop.stream.start(require_auth=True)
|
| 61 |
-
setup_cmd = """sudo mkdir -p /usr/lib/firefox-esr/distribution && echo '{"policies":{"OverrideFirstRunPage":"","OverridePostUpdatePage":"","DisableProfileImport":true,"DontCheckDefaultBrowser":true}}' | sudo tee /usr/lib/firefox-esr/distribution/policies.json > /dev/null"""
|
| 62 |
-
desktop.commands.run(setup_cmd)
|
| 63 |
-
time.sleep(3)
|
| 64 |
-
return desktop
|
| 65 |
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
self.sandbox_metadata[session_hash] = {
|
|
|
|
| 72 |
"created_at": current_time,
|
| 73 |
"last_accessed": current_time,
|
| 74 |
}
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
async def release_sandbox(self, session_hash: str):
|
|
|
|
|
|
|
|
|
|
| 78 |
async with self.sandbox_lock:
|
| 79 |
if session_hash in self.sandboxes:
|
| 80 |
print(f"Releasing sandbox for session {session_hash}")
|
| 81 |
-
|
| 82 |
del self.sandboxes[session_hash]
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
async def cleanup_sandboxes(self):
|
|
|
|
|
|
|
|
|
|
| 86 |
async with self.sandbox_lock:
|
| 87 |
for session_hash in list(self.sandboxes.keys()):
|
| 88 |
-
|
| 89 |
del self.sandboxes[session_hash]
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
|
| 93 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import asyncio
|
|
|
|
| 2 |
import os
|
| 3 |
import time
|
| 4 |
from datetime import datetime
|
| 5 |
+
from typing import Any, Literal
|
| 6 |
|
| 7 |
from e2b_desktop import Sandbox
|
| 8 |
+
from pydantic import BaseModel
|
|
|
|
| 9 |
|
| 10 |
SANDBOX_METADATA: dict[str, dict[str, Any]] = {}
|
| 11 |
+
SANDBOX_TIMEOUT = 500
|
| 12 |
+
SANDBOX_CREATION_TIMEOUT = 180
|
| 13 |
WIDTH = 1280
|
| 14 |
HEIGHT = 960
|
| 15 |
|
| 16 |
|
| 17 |
+
class SandboxResponse(BaseModel):
    """Outcome of a sandbox acquisition attempt.

    ``sandbox`` is populated only when ``state`` is ``"ready"``; while
    provisioning is still in flight ("creating") or capacity is exhausted
    ("max_sandboxes_reached") it is ``None``.
    """

    # Sandbox is a plain (non-pydantic) external type, so pydantic must be
    # told to accept it without validation.
    model_config = {"arbitrary_types_allowed": True}

    sandbox: Sandbox | None
    state: Literal["creating", "ready", "max_sandboxes_reached"]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
class SandboxService:
|
| 25 |
def __init__(self, max_sandboxes: int = 50):
|
| 26 |
if not os.getenv("E2B_API_KEY"):
|
|
|
|
| 30 |
self.sandbox_metadata: dict[str, dict[str, Any]] = {}
|
| 31 |
self.sandbox_lock = asyncio.Lock()
|
| 32 |
|
| 33 |
+
async def _create_sandbox_background(
    self, session_hash: str, expired_sandbox: Sandbox | None
):
    """Provision a sandbox for *session_hash* without blocking the caller.

    Kills *expired_sandbox* (if any) first, then creates and configures a
    fresh desktop sandbox in a worker thread. On success the sandbox is
    registered and its metadata flipped to "ready"; on failure the
    metadata entry is removed so a later acquire can retry.
    """
    # Dispose of the stale sandbox before provisioning its replacement.
    if expired_sandbox:
        try:
            print(f"Closing expired sandbox for session {session_hash}")
            await asyncio.to_thread(expired_sandbox.kill)
        except Exception as e:
            print(f"Error closing expired sandbox: {str(e)}")

    def create_and_setup_sandbox():
        # Blocking E2B SDK calls; executed via asyncio.to_thread below.
        desktop = Sandbox.create(
            api_key=os.getenv("E2B_API_KEY"),
            resolution=(WIDTH, HEIGHT),
            dpi=96,
            timeout=SANDBOX_TIMEOUT,
            template="k0wmnzir0zuzye6dndlw",
        )
        desktop.stream.start(require_auth=True)
        # Suppress Firefox first-run / default-browser prompts in the VM.
        setup_cmd = """sudo mkdir -p /usr/lib/firefox-esr/distribution && echo '{"policies":{"OverrideFirstRunPage":"","OverridePostUpdatePage":"","DisableProfileImport":true,"DontCheckDefaultBrowser":true}}' | sudo tee /usr/lib/firefox-esr/distribution/policies.json > /dev/null"""
        desktop.commands.run(setup_cmd)
        time.sleep(3)
        return desktop

    try:
        desktop = await asyncio.to_thread(create_and_setup_sandbox)
        print(f"Sandbox ID for session {session_hash} is {desktop.sandbox_id}.")

        # Publish the sandbox and flip its state under the lock.
        async with self.sandbox_lock:
            self.sandboxes[session_hash] = desktop
            self.sandbox_metadata[session_hash]["state"] = "ready"

    except Exception as e:
        print(f"Error creating sandbox for session {session_hash}: {str(e)}")
        # Drop the "creating" marker so the session is free to retry.
        async with self.sandbox_lock:
            if session_hash in self.sandbox_metadata:
                del self.sandbox_metadata[session_hash]
|
| 76 |
+
|
| 77 |
+
async def acquire_sandbox(self, session_hash: str) -> SandboxResponse:
    """Get (or start provisioning) the sandbox for *session_hash*.

    Always returns immediately:
      - "ready" plus the live sandbox when a fresh one already exists,
      - "max_sandboxes_reached" when capacity is exhausted,
      - "creating" when provisioning was started (or is in flight);
        callers are expected to poll again later.
    """
    current_time = datetime.now()
    expired_sandbox = None

    # Quick check under lock - only check state and mark creation
    async with self.sandbox_lock:
        metadata = self.sandbox_metadata.get(session_hash)

        # Reuse only a sandbox that is "ready" and still inside the
        # creation-validity window.
        if (
            session_hash in self.sandboxes
            and metadata is not None
            and metadata.get("state") == "ready"
            and (current_time - metadata["created_at"]).total_seconds()
            < SANDBOX_CREATION_TIMEOUT
        ):
            print(f"Reusing Sandbox for session {session_hash}")
            metadata["last_accessed"] = current_time
            return SandboxResponse(
                sandbox=self.sandboxes[session_hash], state="ready"
            )

        # Another coroutine already kicked off creation for this session.
        if metadata is not None and metadata.get("state") == "creating":
            print(f"Sandbox for session {session_hash} is already being created")
            return SandboxResponse(sandbox=None, state="creating")

        # Mark expired sandbox for cleanup (remove from dict within lock)
        if session_hash in self.sandboxes:
            print(f"Marking expired sandbox for session {session_hash} for cleanup")
            expired_sandbox = self.sandboxes.pop(session_hash)
            self.sandbox_metadata.pop(session_hash, None)

        # Check if we have capacity
        if len(self.sandboxes) >= self.max_sandboxes:
            return SandboxResponse(sandbox=None, state="max_sandboxes_reached")

        # Mark that we're creating this sandbox
        print(f"Creating new sandbox for session {session_hash}")
        self.sandbox_metadata[session_hash] = {
            "state": "creating",
            "created_at": current_time,
            "last_accessed": current_time,
        }

    # Start sandbox creation in background without waiting.
    asyncio.create_task(
        self._create_sandbox_background(session_hash, expired_sandbox)
    )

    async with self.sandbox_lock:
        # BUGFIX: the background task may already have failed and deleted
        # the metadata entry, so direct indexing raced with a KeyError.
        # Use .get() and report "creating" so the caller simply retries.
        state = self.sandbox_metadata.get(session_hash, {}).get("state")
        if state == "ready" and session_hash in self.sandboxes:
            return SandboxResponse(
                sandbox=self.sandboxes[session_hash], state="ready"
            )

    return SandboxResponse(sandbox=None, state="creating")
|
| 144 |
|
| 145 |
async def release_sandbox(self, session_hash: str):
    """Detach and kill the sandbox bound to *session_hash*, if any."""
    doomed = None

    # Remove the sandbox from the registries while holding the lock.
    async with self.sandbox_lock:
        if session_hash in self.sandboxes:
            print(f"Releasing sandbox for session {session_hash}")
            doomed = self.sandboxes.pop(session_hash)
            self.sandbox_metadata.pop(session_hash, None)

    # The blocking kill() runs outside the lock, in a worker thread.
    if doomed:
        try:
            await asyncio.to_thread(doomed.kill)
        except Exception as e:
            print(f"Error killing sandbox for session {session_hash}: {str(e)}")
|
| 163 |
|
| 164 |
async def cleanup_sandboxes(self):
    """Tear down every tracked sandbox (e.g. on application shutdown)."""
    doomed = []

    # Drain both registries atomically under the lock.
    async with self.sandbox_lock:
        for session_hash in list(self.sandboxes.keys()):
            doomed.append((session_hash, self.sandboxes.pop(session_hash)))
            self.sandbox_metadata.pop(session_hash, None)

    # Kill everything outside the lock so other coroutines are not blocked.
    for session_hash, sandbox in doomed:
        try:
            await asyncio.to_thread(sandbox.kill)
        except Exception as e:
            print(f"Error killing sandbox for session {session_hash}: {str(e)}")
|
| 181 |
|
| 182 |
|
| 183 |
if __name__ == "__main__":
|
cua2-core/tests/test_archival_service.py
ADDED
|
@@ -0,0 +1,585 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for the ArchivalService multiprocessing implementation.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import shutil
|
| 7 |
+
import tarfile
|
| 8 |
+
import tempfile
|
| 9 |
+
import time
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from unittest.mock import MagicMock, Mock, patch
|
| 12 |
+
|
| 13 |
+
import pytest
|
| 14 |
+
from cua2_core.services.archival_service import (
|
| 15 |
+
ArchivalService,
|
| 16 |
+
_compress_folder,
|
| 17 |
+
_process_old_folders,
|
| 18 |
+
_upload_to_huggingface,
|
| 19 |
+
_verify_file_in_repo,
|
| 20 |
+
)
|
| 21 |
+
from huggingface_hub.utils import HfHubHTTPError
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@pytest.fixture
def temp_data_dir():
    """Provide a fresh temporary directory, deleted again on teardown."""
    path = tempfile.mkdtemp()
    yield path
    # Teardown: remove the tree if the test left it behind.
    if os.path.exists(path):
        shutil.rmtree(path)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@pytest.fixture
def mock_hf_api():
    """A stand-in HuggingFace API client whose upload_file is a no-op."""
    api = MagicMock()
    api.upload_file.return_value = None
    return api
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@pytest.fixture
def archival_service(temp_data_dir):
    """An ArchivalService wired to the temp dir; worker stopped on teardown."""
    svc = ArchivalService(
        hf_token="test_token",
        hf_dataset_repo="test_user/test_repo",
        data_dir=temp_data_dir,
        archive_interval_minutes=1,  # keep the worker loop short in tests
        folder_age_threshold_minutes=1,
    )
    yield svc
    # Teardown: never leak the worker process.
    if svc.is_alive():
        svc.stop(timeout=5.0)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class TestArchivalServiceInitialization:
    """Construction-time behavior of ArchivalService."""

    def test_init_with_defaults(self):
        """Defaults are applied and no worker process exists yet."""
        with patch.dict(os.environ, {"HF_TOKEN": "env_token"}, clear=False):
            # The env token is passed explicitly: os.getenv is evaluated
            # when the constructor's default is resolved, not per call.
            svc = ArchivalService(hf_token=os.getenv("HF_TOKEN"))
            assert svc.hf_token == "env_token"
            assert svc.hf_dataset_repo == "smolagents/cua_traces"
            assert svc.data_dir == "data"
            assert svc.archive_interval_minutes == 30
            assert svc.folder_age_threshold_minutes == 30
            assert svc._process is None
            assert not svc.is_alive()

    def test_init_with_custom_values(self):
        """Every constructor argument is stored verbatim."""
        svc = ArchivalService(
            hf_token="custom_token",
            hf_dataset_repo="custom/repo",
            data_dir="/custom/path",
            archive_interval_minutes=60,
            folder_age_threshold_minutes=120,
        )
        assert svc.hf_token == "custom_token"
        assert svc.hf_dataset_repo == "custom/repo"
        assert svc.data_dir == "/custom/path"
        assert svc.archive_interval_minutes == 60
        assert svc.folder_age_threshold_minutes == 120

    def test_init_multiprocessing_components(self):
        """The IPC plumbing (stop event, manager, shared dict) is created."""
        svc = ArchivalService(hf_token="test", hf_dataset_repo="test/test")
        assert svc._stop_event is not None
        assert svc._manager is not None
        assert svc._active_tasks is not None
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class TestArchivalServiceLifecycle:
    """Start/stop behavior of the archival worker process."""

    def test_start_service(self, archival_service):
        """start() spawns a live worker process with a real PID."""
        archival_service.start()
        time.sleep(0.5)  # let the child process come up
        assert archival_service.is_alive()
        assert archival_service._process is not None
        assert archival_service._process.pid is not None

    def test_start_without_credentials(self, temp_data_dir):
        """Without HF credentials the worker must never start."""
        svc = ArchivalService(
            hf_token=None,
            hf_dataset_repo=None,
            data_dir=temp_data_dir,
        )
        svc.start()
        assert not svc.is_alive()
        assert svc._process is None

    def test_start_already_running(self, archival_service):
        """A second start() must not spawn a second process."""
        archival_service.start()
        time.sleep(0.5)
        original_pid = archival_service._process.pid
        archival_service.start()
        assert archival_service._process.pid == original_pid

    def test_stop_service(self, archival_service):
        """stop() terminates the worker and clears the handle."""
        archival_service.start()
        time.sleep(0.5)
        assert archival_service.is_alive()
        archival_service.stop(timeout=5.0)
        assert not archival_service.is_alive()
        assert archival_service._process is None

    def test_stop_not_running(self, archival_service):
        """stop() on a never-started service is a harmless no-op."""
        archival_service.stop(timeout=1.0)
        assert not archival_service.is_alive()

    def test_stop_with_timeout(self, archival_service):
        """A tiny timeout exercises the terminate/kill escalation path."""
        archival_service.start()
        time.sleep(0.5)
        archival_service.stop(timeout=0.001)
        time.sleep(0.5)
        assert not archival_service.is_alive()

    def test_is_alive_returns_false_when_not_started(self, archival_service):
        """is_alive() is False before start() is ever called."""
        assert not archival_service.is_alive()
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class TestActiveTasksManagement:
|
| 173 |
+
"""Test active tasks management."""
|
| 174 |
+
|
| 175 |
+
def test_update_active_tasks(self, archival_service):
|
| 176 |
+
"""Test updating active tasks."""
|
| 177 |
+
task_ids = {"task-1", "task-2", "task-3"}
|
| 178 |
+
|
| 179 |
+
archival_service.update_active_tasks(task_ids)
|
| 180 |
+
|
| 181 |
+
# Verify tasks are in shared dict
|
| 182 |
+
for task_id in task_ids:
|
| 183 |
+
assert task_id in archival_service._active_tasks
|
| 184 |
+
|
| 185 |
+
def test_update_active_tasks_clears_old(self, archival_service):
|
| 186 |
+
"""Test that updating active tasks clears old ones."""
|
| 187 |
+
archival_service.update_active_tasks({"task-1", "task-2"})
|
| 188 |
+
assert "task-1" in archival_service._active_tasks
|
| 189 |
+
|
| 190 |
+
archival_service.update_active_tasks({"task-3"})
|
| 191 |
+
assert "task-1" not in archival_service._active_tasks
|
| 192 |
+
assert "task-3" in archival_service._active_tasks
|
| 193 |
+
|
| 194 |
+
def test_update_active_tasks_empty_set(self, archival_service):
|
| 195 |
+
"""Test updating with empty set."""
|
| 196 |
+
archival_service.update_active_tasks({"task-1"})
|
| 197 |
+
archival_service.update_active_tasks(set())
|
| 198 |
+
|
| 199 |
+
assert len(archival_service._active_tasks) == 0
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class TestCompressFolder:
|
| 203 |
+
"""Test folder compression functionality."""
|
| 204 |
+
|
| 205 |
+
def test_compress_folder_success(self, temp_data_dir):
|
| 206 |
+
"""Test successful folder compression."""
|
| 207 |
+
# Create a test folder with some files
|
| 208 |
+
test_folder = Path(temp_data_dir) / "trace-test-123-model"
|
| 209 |
+
test_folder.mkdir()
|
| 210 |
+
(test_folder / "file1.txt").write_text("test content 1")
|
| 211 |
+
(test_folder / "file2.txt").write_text("test content 2")
|
| 212 |
+
|
| 213 |
+
archive_path = _compress_folder(test_folder)
|
| 214 |
+
|
| 215 |
+
assert archive_path is not None
|
| 216 |
+
assert archive_path.exists()
|
| 217 |
+
assert archive_path.name == "trace-test-123-model.tar.gz"
|
| 218 |
+
assert archive_path.suffix == ".gz"
|
| 219 |
+
|
| 220 |
+
# Verify archive contents
|
| 221 |
+
with tarfile.open(archive_path, "r:gz") as tar:
|
| 222 |
+
members = tar.getnames()
|
| 223 |
+
assert "trace-test-123-model/file1.txt" in members
|
| 224 |
+
assert "trace-test-123-model/file2.txt" in members
|
| 225 |
+
|
| 226 |
+
# Cleanup
|
| 227 |
+
archive_path.unlink()
|
| 228 |
+
|
| 229 |
+
def test_compress_folder_empty_folder(self, temp_data_dir):
|
| 230 |
+
"""Test compressing an empty folder."""
|
| 231 |
+
test_folder = Path(temp_data_dir) / "trace-empty-456-model"
|
| 232 |
+
test_folder.mkdir()
|
| 233 |
+
|
| 234 |
+
archive_path = _compress_folder(test_folder)
|
| 235 |
+
|
| 236 |
+
assert archive_path is not None
|
| 237 |
+
assert archive_path.exists()
|
| 238 |
+
|
| 239 |
+
# Cleanup
|
| 240 |
+
archive_path.unlink()
|
| 241 |
+
|
| 242 |
+
def test_compress_folder_nonexistent(self, temp_data_dir):
|
| 243 |
+
"""Test compressing a nonexistent folder."""
|
| 244 |
+
test_folder = Path(temp_data_dir) / "nonexistent"
|
| 245 |
+
|
| 246 |
+
archive_path = _compress_folder(test_folder)
|
| 247 |
+
|
| 248 |
+
assert archive_path is None
|
| 249 |
+
|
| 250 |
+
def test_compress_folder_with_subdirectories(self, temp_data_dir):
|
| 251 |
+
"""Test compressing folder with subdirectories."""
|
| 252 |
+
test_folder = Path(temp_data_dir) / "trace-nested-789-model"
|
| 253 |
+
test_folder.mkdir()
|
| 254 |
+
subdir = test_folder / "subdir"
|
| 255 |
+
subdir.mkdir()
|
| 256 |
+
(subdir / "nested.txt").write_text("nested content")
|
| 257 |
+
|
| 258 |
+
archive_path = _compress_folder(test_folder)
|
| 259 |
+
|
| 260 |
+
assert archive_path is not None
|
| 261 |
+
|
| 262 |
+
# Verify nested structure preserved
|
| 263 |
+
with tarfile.open(archive_path, "r:gz") as tar:
|
| 264 |
+
members = tar.getnames()
|
| 265 |
+
assert "trace-nested-789-model/subdir/nested.txt" in members
|
| 266 |
+
|
| 267 |
+
# Cleanup
|
| 268 |
+
archive_path.unlink()
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class TestUploadToHuggingFace:
|
| 272 |
+
"""Test HuggingFace upload functionality."""
|
| 273 |
+
|
| 274 |
+
def test_upload_success(self, mock_hf_api, temp_data_dir):
|
| 275 |
+
"""Test successful upload to HuggingFace."""
|
| 276 |
+
# Create a test archive
|
| 277 |
+
archive_path = Path(temp_data_dir) / "test-archive.tar.gz"
|
| 278 |
+
archive_path.write_text("test archive content")
|
| 279 |
+
|
| 280 |
+
result = _upload_to_huggingface(mock_hf_api, "test/repo", archive_path)
|
| 281 |
+
|
| 282 |
+
assert result is True
|
| 283 |
+
mock_hf_api.upload_file.assert_called_once_with(
|
| 284 |
+
path_or_fileobj=str(archive_path),
|
| 285 |
+
path_in_repo="test-archive.tar.gz",
|
| 286 |
+
repo_id="test/repo",
|
| 287 |
+
repo_type="dataset",
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
def test_upload_failure(self, mock_hf_api, temp_data_dir):
|
| 291 |
+
"""Test upload failure."""
|
| 292 |
+
mock_hf_api.upload_file.side_effect = Exception("Upload failed")
|
| 293 |
+
|
| 294 |
+
archive_path = Path(temp_data_dir) / "test-archive.tar.gz"
|
| 295 |
+
archive_path.write_text("test archive content")
|
| 296 |
+
|
| 297 |
+
result = _upload_to_huggingface(mock_hf_api, "test/repo", archive_path)
|
| 298 |
+
|
| 299 |
+
assert result is False
|
| 300 |
+
|
| 301 |
+
def test_upload_nonexistent_file(self, mock_hf_api, temp_data_dir):
|
| 302 |
+
"""Test uploading a nonexistent file."""
|
| 303 |
+
archive_path = Path(temp_data_dir) / "nonexistent.tar.gz"
|
| 304 |
+
|
| 305 |
+
# Make the mock raise an exception when trying to upload nonexistent file
|
| 306 |
+
mock_hf_api.upload_file.side_effect = FileNotFoundError("File not found")
|
| 307 |
+
|
| 308 |
+
result = _upload_to_huggingface(mock_hf_api, "test/repo", archive_path)
|
| 309 |
+
|
| 310 |
+
assert result is False
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class TestVerifyFileInRepo:
|
| 314 |
+
"""Test file verification functionality."""
|
| 315 |
+
|
| 316 |
+
@patch("cua2_core.services.archival_service.hf_hub_download")
|
| 317 |
+
def test_verify_success(self, mock_download):
|
| 318 |
+
"""Test successful file verification."""
|
| 319 |
+
mock_download.return_value = "/path/to/file"
|
| 320 |
+
|
| 321 |
+
result = _verify_file_in_repo("test/repo", "test_token", "test.tar.gz")
|
| 322 |
+
|
| 323 |
+
assert result is True
|
| 324 |
+
mock_download.assert_called_once()
|
| 325 |
+
|
| 326 |
+
@patch("cua2_core.services.archival_service.hf_hub_download")
|
| 327 |
+
def test_verify_file_not_found(self, mock_download):
|
| 328 |
+
"""Test verification when file not found (404)."""
|
| 329 |
+
mock_response = Mock()
|
| 330 |
+
mock_response.status_code = 404
|
| 331 |
+
error = HfHubHTTPError("Not found", response=mock_response)
|
| 332 |
+
mock_download.side_effect = error
|
| 333 |
+
|
| 334 |
+
result = _verify_file_in_repo("test/repo", "test_token", "test.tar.gz")
|
| 335 |
+
|
| 336 |
+
assert result is False
|
| 337 |
+
|
| 338 |
+
@patch("cua2_core.services.archival_service.hf_hub_download")
|
| 339 |
+
def test_verify_other_http_error(self, mock_download):
|
| 340 |
+
"""Test verification with other HTTP errors."""
|
| 341 |
+
mock_response = Mock()
|
| 342 |
+
mock_response.status_code = 500
|
| 343 |
+
error = HfHubHTTPError("Server error", response=mock_response)
|
| 344 |
+
mock_download.side_effect = error
|
| 345 |
+
|
| 346 |
+
result = _verify_file_in_repo("test/repo", "test_token", "test.tar.gz")
|
| 347 |
+
|
| 348 |
+
assert result is False
|
| 349 |
+
|
| 350 |
+
@patch("cua2_core.services.archival_service.hf_hub_download")
|
| 351 |
+
def test_verify_generic_exception(self, mock_download):
|
| 352 |
+
"""Test verification with generic exception."""
|
| 353 |
+
mock_download.side_effect = Exception("Generic error")
|
| 354 |
+
|
| 355 |
+
result = _verify_file_in_repo("test/repo", "test_token", "test.tar.gz")
|
| 356 |
+
|
| 357 |
+
assert result is False
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class TestProcessOldFolders:
|
| 361 |
+
"""Test old folder processing logic."""
|
| 362 |
+
|
| 363 |
+
def test_process_old_folders_basic(self, temp_data_dir, mock_hf_api):
|
| 364 |
+
"""Test processing old folders."""
|
| 365 |
+
# Create an old folder (modify mtime to make it old)
|
| 366 |
+
old_folder = Path(temp_data_dir) / "trace-old123-model"
|
| 367 |
+
old_folder.mkdir()
|
| 368 |
+
(old_folder / "data.json").write_text('{"test": "data"}')
|
| 369 |
+
|
| 370 |
+
# Make it old by modifying mtime
|
| 371 |
+
old_time = time.time() - 3600 # 1 hour ago
|
| 372 |
+
os.utime(old_folder, (old_time, old_time))
|
| 373 |
+
|
| 374 |
+
active_tasks = {}
|
| 375 |
+
|
| 376 |
+
with patch(
|
| 377 |
+
"cua2_core.services.archival_service._verify_file_in_repo",
|
| 378 |
+
return_value=True,
|
| 379 |
+
):
|
| 380 |
+
_process_old_folders(
|
| 381 |
+
data_dir=temp_data_dir,
|
| 382 |
+
folder_age_threshold_minutes=1,
|
| 383 |
+
active_tasks=active_tasks,
|
| 384 |
+
hf_api=mock_hf_api,
|
| 385 |
+
hf_dataset_repo="test/repo",
|
| 386 |
+
hf_token="test_token",
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
# Folder should be deleted after successful archival
|
| 390 |
+
assert not old_folder.exists()
|
| 391 |
+
|
| 392 |
+
# Upload should have been called
|
| 393 |
+
assert mock_hf_api.upload_file.called
|
| 394 |
+
|
| 395 |
+
def test_process_folders_skips_active_tasks(self, temp_data_dir, mock_hf_api):
|
| 396 |
+
"""Test that active tasks are skipped."""
|
| 397 |
+
# Create a folder for an active task
|
| 398 |
+
active_folder = Path(temp_data_dir) / "trace-active456-model"
|
| 399 |
+
active_folder.mkdir()
|
| 400 |
+
(active_folder / "data.json").write_text('{"test": "data"}')
|
| 401 |
+
|
| 402 |
+
# Make it old
|
| 403 |
+
old_time = time.time() - 3600
|
| 404 |
+
os.utime(active_folder, (old_time, old_time))
|
| 405 |
+
|
| 406 |
+
# Mark as active
|
| 407 |
+
active_tasks = {"active456": True}
|
| 408 |
+
|
| 409 |
+
_process_old_folders(
|
| 410 |
+
data_dir=temp_data_dir,
|
| 411 |
+
folder_age_threshold_minutes=1,
|
| 412 |
+
active_tasks=active_tasks,
|
| 413 |
+
hf_api=mock_hf_api,
|
| 414 |
+
hf_dataset_repo="test/repo",
|
| 415 |
+
hf_token="test_token",
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
# Folder should still exist (not archived)
|
| 419 |
+
assert active_folder.exists()
|
| 420 |
+
|
| 421 |
+
# Upload should not have been called
|
| 422 |
+
assert not mock_hf_api.upload_file.called
|
| 423 |
+
|
| 424 |
+
def test_process_folders_skips_recent(self, temp_data_dir, mock_hf_api):
|
| 425 |
+
"""Test that recent folders are skipped."""
|
| 426 |
+
# Create a recent folder
|
| 427 |
+
recent_folder = Path(temp_data_dir) / "trace-recent789-model"
|
| 428 |
+
recent_folder.mkdir()
|
| 429 |
+
(recent_folder / "data.json").write_text('{"test": "data"}')
|
| 430 |
+
|
| 431 |
+
# Folder is fresh (current time)
|
| 432 |
+
|
| 433 |
+
active_tasks = {}
|
| 434 |
+
|
| 435 |
+
_process_old_folders(
|
| 436 |
+
data_dir=temp_data_dir,
|
| 437 |
+
folder_age_threshold_minutes=60, # 60 minutes threshold
|
| 438 |
+
active_tasks=active_tasks,
|
| 439 |
+
hf_api=mock_hf_api,
|
| 440 |
+
hf_dataset_repo="test/repo",
|
| 441 |
+
hf_token="test_token",
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
# Folder should still exist (too recent)
|
| 445 |
+
assert recent_folder.exists()
|
| 446 |
+
|
| 447 |
+
# Upload should not have been called
|
| 448 |
+
assert not mock_hf_api.upload_file.called
|
| 449 |
+
|
| 450 |
+
def test_process_folders_keeps_on_verification_failure(
|
| 451 |
+
self, temp_data_dir, mock_hf_api
|
| 452 |
+
):
|
| 453 |
+
"""Test that folders are kept if verification fails."""
|
| 454 |
+
old_folder = Path(temp_data_dir) / "trace-verify-fail-model"
|
| 455 |
+
old_folder.mkdir()
|
| 456 |
+
(old_folder / "data.json").write_text('{"test": "data"}')
|
| 457 |
+
|
| 458 |
+
old_time = time.time() - 3600
|
| 459 |
+
os.utime(old_folder, (old_time, old_time))
|
| 460 |
+
|
| 461 |
+
active_tasks = {}
|
| 462 |
+
|
| 463 |
+
# Mock verification to fail
|
| 464 |
+
with patch(
|
| 465 |
+
"cua2_core.services.archival_service._verify_file_in_repo",
|
| 466 |
+
return_value=False,
|
| 467 |
+
):
|
| 468 |
+
_process_old_folders(
|
| 469 |
+
data_dir=temp_data_dir,
|
| 470 |
+
folder_age_threshold_minutes=1,
|
| 471 |
+
active_tasks=active_tasks,
|
| 472 |
+
hf_api=mock_hf_api,
|
| 473 |
+
hf_dataset_repo="test/repo",
|
| 474 |
+
hf_token="test_token",
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
# Folder should still exist (verification failed)
|
| 478 |
+
assert old_folder.exists()
|
| 479 |
+
|
| 480 |
+
def test_process_folders_handles_nonexistent_dir(self, mock_hf_api):
|
| 481 |
+
"""Test handling of nonexistent data directory."""
|
| 482 |
+
# Should not raise exception
|
| 483 |
+
_process_old_folders(
|
| 484 |
+
data_dir="/nonexistent/path",
|
| 485 |
+
folder_age_threshold_minutes=1,
|
| 486 |
+
active_tasks={},
|
| 487 |
+
hf_api=mock_hf_api,
|
| 488 |
+
hf_dataset_repo="test/repo",
|
| 489 |
+
hf_token="test_token",
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
# No uploads should occur
|
| 493 |
+
assert not mock_hf_api.upload_file.called
|
| 494 |
+
|
| 495 |
+
def test_process_folders_handles_bad_folder_names(self, temp_data_dir, mock_hf_api):
|
| 496 |
+
"""Test handling of folders with unexpected name format."""
|
| 497 |
+
# Create folder with bad name format
|
| 498 |
+
bad_folder = Path(temp_data_dir) / "trace-invalid" # Missing parts
|
| 499 |
+
bad_folder.mkdir()
|
| 500 |
+
|
| 501 |
+
old_time = time.time() - 3600
|
| 502 |
+
os.utime(bad_folder, (old_time, old_time))
|
| 503 |
+
|
| 504 |
+
# Should not raise exception
|
| 505 |
+
_process_old_folders(
|
| 506 |
+
data_dir=temp_data_dir,
|
| 507 |
+
folder_age_threshold_minutes=1,
|
| 508 |
+
active_tasks={},
|
| 509 |
+
hf_api=mock_hf_api,
|
| 510 |
+
hf_dataset_repo="test/repo",
|
| 511 |
+
hf_token="test_token",
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
# Folder should still exist (invalid name)
|
| 515 |
+
assert bad_folder.exists()
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
class TestIntegration:
|
| 519 |
+
"""Integration tests for the complete archival workflow."""
|
| 520 |
+
|
| 521 |
+
@pytest.mark.slow
|
| 522 |
+
def test_full_archival_workflow(self, temp_data_dir):
|
| 523 |
+
"""Test the complete archival workflow end-to-end."""
|
| 524 |
+
# Create service
|
| 525 |
+
service = ArchivalService(
|
| 526 |
+
hf_token="test_token",
|
| 527 |
+
hf_dataset_repo="test/repo",
|
| 528 |
+
data_dir=temp_data_dir,
|
| 529 |
+
archive_interval_minutes=1,
|
| 530 |
+
folder_age_threshold_minutes=1,
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
# Create test folder
|
| 534 |
+
test_folder = Path(temp_data_dir) / "trace-integration-test-model"
|
| 535 |
+
test_folder.mkdir()
|
| 536 |
+
(test_folder / "test.json").write_text('{"test": "data"}')
|
| 537 |
+
|
| 538 |
+
# Make it old
|
| 539 |
+
old_time = time.time() - 3600
|
| 540 |
+
os.utime(test_folder, (old_time, old_time))
|
| 541 |
+
|
| 542 |
+
# Start service (mocked to prevent actual HF upload)
|
| 543 |
+
with (
|
| 544 |
+
patch(
|
| 545 |
+
"cua2_core.services.archival_service._upload_to_huggingface",
|
| 546 |
+
return_value=True,
|
| 547 |
+
),
|
| 548 |
+
patch(
|
| 549 |
+
"cua2_core.services.archival_service._verify_file_in_repo",
|
| 550 |
+
return_value=True,
|
| 551 |
+
),
|
| 552 |
+
):
|
| 553 |
+
service.start()
|
| 554 |
+
time.sleep(2) # Wait for at least one cycle
|
| 555 |
+
|
| 556 |
+
service.stop(timeout=5.0)
|
| 557 |
+
|
| 558 |
+
assert not service.is_alive()
|
| 559 |
+
|
| 560 |
+
def test_service_survives_worker_errors(self, temp_data_dir):
|
| 561 |
+
"""Test that service continues running despite worker errors."""
|
| 562 |
+
service = ArchivalService(
|
| 563 |
+
hf_token="test_token",
|
| 564 |
+
hf_dataset_repo="test/repo",
|
| 565 |
+
data_dir=temp_data_dir,
|
| 566 |
+
archive_interval_minutes=1,
|
| 567 |
+
folder_age_threshold_minutes=1,
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
# Mock to raise exception
|
| 571 |
+
with patch(
|
| 572 |
+
"cua2_core.services.archival_service._process_old_folders",
|
| 573 |
+
side_effect=Exception("Test error"),
|
| 574 |
+
):
|
| 575 |
+
service.start()
|
| 576 |
+
time.sleep(2)
|
| 577 |
+
|
| 578 |
+
# Service should still be alive
|
| 579 |
+
assert service.is_alive()
|
| 580 |
+
|
| 581 |
+
service.stop(timeout=5.0)
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
if __name__ == "__main__":
|
| 585 |
+
pytest.main([__file__, "-v", "--tb=short"])
|
cua2-front/src/App.tsx
CHANGED
|
@@ -15,9 +15,10 @@ const App = () => {
|
|
| 15 |
// Initialize WebSocket connection at app level so it persists across route changes
|
| 16 |
const { stopCurrentTask } = useAgentWebSocket({ url: getWebSocketUrl() });
|
| 17 |
|
| 18 |
-
// Store
|
| 19 |
(window as Window & { __stopCurrentTask?: () => void }).__stopCurrentTask = stopCurrentTask;
|
| 20 |
|
|
|
|
| 21 |
return (
|
| 22 |
<ThemeProvider theme={theme}>
|
| 23 |
<CssBaseline />
|
|
|
|
| 15 |
// Initialize WebSocket connection at app level so it persists across route changes
|
| 16 |
const { stopCurrentTask } = useAgentWebSocket({ url: getWebSocketUrl() });
|
| 17 |
|
| 18 |
+
// Store functions in window for global access
|
| 19 |
(window as Window & { __stopCurrentTask?: () => void }).__stopCurrentTask = stopCurrentTask;
|
| 20 |
|
| 21 |
+
|
| 22 |
return (
|
| 23 |
<ThemeProvider theme={theme}>
|
| 24 |
<CssBaseline />
|
cua2-front/src/components/sandbox/completionview/CompletionView.tsx
CHANGED
|
@@ -1,18 +1,18 @@
|
|
| 1 |
-
import
|
| 2 |
-
import
|
| 3 |
-
import CheckIcon from '@mui/icons-material/Check';
|
| 4 |
-
import CloseIcon from '@mui/icons-material/Close';
|
| 5 |
-
import StopCircleIcon from '@mui/icons-material/StopCircle';
|
| 6 |
-
import HourglassEmptyIcon from '@mui/icons-material/HourglassEmpty';
|
| 7 |
import AddIcon from '@mui/icons-material/Add';
|
| 8 |
-
import SmartToyIcon from '@mui/icons-material/SmartToy';
|
| 9 |
import AssignmentIcon from '@mui/icons-material/Assignment';
|
| 10 |
import ChatBubbleOutlineIcon from '@mui/icons-material/ChatBubbleOutline';
|
| 11 |
-
import
|
|
|
|
|
|
|
|
|
|
| 12 |
import InputIcon from '@mui/icons-material/Input';
|
| 13 |
import OutputIcon from '@mui/icons-material/Output';
|
| 14 |
-
import
|
| 15 |
-
import
|
|
|
|
|
|
|
| 16 |
import { DownloadGifButton } from './DownloadGifButton';
|
| 17 |
import { DownloadJsonButton } from './DownloadJsonButton';
|
| 18 |
|
|
@@ -65,14 +65,14 @@ export const CompletionView: React.FC<CompletionViewProps> = ({
|
|
| 65 |
case 'sandbox_timeout':
|
| 66 |
return {
|
| 67 |
icon: <AccessTimeIcon sx={{ fontSize: 28 }} />,
|
| 68 |
-
title: 'Sandbox
|
| 69 |
color: 'error.main',
|
| 70 |
};
|
| 71 |
case 'failure':
|
| 72 |
default:
|
| 73 |
return {
|
| 74 |
icon: <CloseIcon sx={{ fontSize: 28 }} />,
|
| 75 |
-
title: 'Task Failed',
|
| 76 |
color: 'error.main',
|
| 77 |
};
|
| 78 |
}
|
|
|
|
| 1 |
+
import { AgentStep, AgentTrace, FinalStep } from '@/types/agent';
|
| 2 |
+
import AccessTimeIcon from '@mui/icons-material/AccessTime';
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import AddIcon from '@mui/icons-material/Add';
|
|
|
|
| 4 |
import AssignmentIcon from '@mui/icons-material/Assignment';
|
| 5 |
import ChatBubbleOutlineIcon from '@mui/icons-material/ChatBubbleOutline';
|
| 6 |
+
import CheckIcon from '@mui/icons-material/Check';
|
| 7 |
+
import CloseIcon from '@mui/icons-material/Close';
|
| 8 |
+
import FormatListNumberedIcon from '@mui/icons-material/FormatListNumbered';
|
| 9 |
+
import HourglassEmptyIcon from '@mui/icons-material/HourglassEmpty';
|
| 10 |
import InputIcon from '@mui/icons-material/Input';
|
| 11 |
import OutputIcon from '@mui/icons-material/Output';
|
| 12 |
+
import SmartToyIcon from '@mui/icons-material/SmartToy';
|
| 13 |
+
import StopCircleIcon from '@mui/icons-material/StopCircle';
|
| 14 |
+
import { Alert, Box, Button, Divider, Paper, Typography } from '@mui/material';
|
| 15 |
+
import React from 'react';
|
| 16 |
import { DownloadGifButton } from './DownloadGifButton';
|
| 17 |
import { DownloadJsonButton } from './DownloadJsonButton';
|
| 18 |
|
|
|
|
| 65 |
case 'sandbox_timeout':
|
| 66 |
return {
|
| 67 |
icon: <AccessTimeIcon sx={{ fontSize: 28 }} />,
|
| 68 |
+
title: 'Max Sandbox Time Reached',
|
| 69 |
color: 'error.main',
|
| 70 |
};
|
| 71 |
case 'failure':
|
| 72 |
default:
|
| 73 |
return {
|
| 74 |
icon: <CloseIcon sx={{ fontSize: 28 }} />,
|
| 75 |
+
title: 'Task Failed (Agent Internal Error)',
|
| 76 |
color: 'error.main',
|
| 77 |
};
|
| 78 |
}
|
cua2-front/src/hooks/useAgentWebSocket.ts
CHANGED
|
@@ -1,8 +1,7 @@
|
|
|
|
|
|
|
|
| 1 |
import { useCallback, useEffect } from 'react';
|
| 2 |
import { useWebSocket } from './useWebSocket';
|
| 3 |
-
import { useAgentStore } from '@/stores/agentStore';
|
| 4 |
-
import { WebSocketEvent, AgentTrace, AgentStep } from '@/types/agent';
|
| 5 |
-
import { ulid } from 'ulid';
|
| 6 |
|
| 7 |
interface UseAgentWebSocketOptions {
|
| 8 |
url: string;
|
|
@@ -11,6 +10,8 @@ interface UseAgentWebSocketOptions {
|
|
| 11 |
export const useAgentWebSocket = ({ url }: UseAgentWebSocketOptions) => {
|
| 12 |
const {
|
| 13 |
setTrace,
|
|
|
|
|
|
|
| 14 |
updateTraceWithStep,
|
| 15 |
completeTrace,
|
| 16 |
setIsAgentProcessing,
|
|
@@ -52,6 +53,7 @@ export const useAgentWebSocket = ({ url }: UseAgentWebSocketOptions) => {
|
|
| 52 |
numberOfSteps: 0,
|
| 53 |
maxSteps: 200,
|
| 54 |
completed: false,
|
|
|
|
| 55 |
},
|
| 56 |
};
|
| 57 |
|
|
@@ -94,10 +96,13 @@ export const useAgentWebSocket = ({ url }: UseAgentWebSocketOptions) => {
|
|
| 94 |
|
| 95 |
case 'heartbeat':
|
| 96 |
console.log('Heartbeat received:', event);
|
|
|
|
|
|
|
| 97 |
break;
|
|
|
|
| 98 |
}
|
| 99 |
},
|
| 100 |
-
[setTrace, updateTraceWithStep, completeTrace, setIsAgentProcessing, setIsConnectingToE2B, setVncUrl, setError, resetAgent]
|
| 101 |
);
|
| 102 |
|
| 103 |
// Handle WebSocket errors
|
|
@@ -113,10 +118,16 @@ export const useAgentWebSocket = ({ url }: UseAgentWebSocketOptions) => {
|
|
| 113 |
onError: handleWebSocketError,
|
| 114 |
});
|
| 115 |
|
| 116 |
-
// Sync connection state to store
|
| 117 |
useEffect(() => {
|
| 118 |
setIsConnected(isConnected);
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
// Create a global sendNewTask function that can be called from anywhere
|
| 122 |
useEffect(() => {
|
|
@@ -125,7 +136,13 @@ export const useAgentWebSocket = ({ url }: UseAgentWebSocketOptions) => {
|
|
| 125 |
// Reset agent state before starting a new task
|
| 126 |
resetAgent();
|
| 127 |
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
const trace: AgentTrace = {
|
| 130 |
id: traceId,
|
| 131 |
instruction,
|
|
@@ -140,7 +157,8 @@ export const useAgentWebSocket = ({ url }: UseAgentWebSocketOptions) => {
|
|
| 140 |
numberOfSteps: 0,
|
| 141 |
maxSteps: 200, // Default max steps, will be updated by backend
|
| 142 |
completed: false,
|
| 143 |
-
|
|
|
|
| 144 |
};
|
| 145 |
|
| 146 |
setTrace(trace);
|
|
@@ -155,7 +173,7 @@ export const useAgentWebSocket = ({ url }: UseAgentWebSocketOptions) => {
|
|
| 155 |
|
| 156 |
console.log('Task sent:', trace);
|
| 157 |
};
|
| 158 |
-
}, [setTrace, setIsAgentProcessing, setIsConnectingToE2B, sendMessage, resetAgent]);
|
| 159 |
|
| 160 |
// Function to stop the current task
|
| 161 |
const stopCurrentTask = useCallback(() => {
|
|
|
|
| 1 |
+
import { useAgentStore } from '@/stores/agentStore';
|
| 2 |
+
import { AgentTrace, AgentTraceMetadata, WebSocketEvent } from '@/types/agent';
|
| 3 |
import { useCallback, useEffect } from 'react';
|
| 4 |
import { useWebSocket } from './useWebSocket';
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
interface UseAgentWebSocketOptions {
|
| 7 |
url: string;
|
|
|
|
| 10 |
export const useAgentWebSocket = ({ url }: UseAgentWebSocketOptions) => {
|
| 11 |
const {
|
| 12 |
setTrace,
|
| 13 |
+
traceId,
|
| 14 |
+
setTraceId,
|
| 15 |
updateTraceWithStep,
|
| 16 |
completeTrace,
|
| 17 |
setIsAgentProcessing,
|
|
|
|
| 53 |
numberOfSteps: 0,
|
| 54 |
maxSteps: 200,
|
| 55 |
completed: false,
|
| 56 |
+
final_state: null,
|
| 57 |
},
|
| 58 |
};
|
| 59 |
|
|
|
|
| 96 |
|
| 97 |
case 'heartbeat':
|
| 98 |
console.log('Heartbeat received:', event);
|
| 99 |
+
setTraceId(event.uuid);
|
| 100 |
+
console.log('TraceId set from backend:', event.uuid);
|
| 101 |
break;
|
| 102 |
+
|
| 103 |
}
|
| 104 |
},
|
| 105 |
+
[setTrace, updateTraceWithStep, completeTrace, setIsAgentProcessing, setIsConnectingToE2B, setVncUrl, setError, resetAgent, setTraceId, traceId]
|
| 106 |
);
|
| 107 |
|
| 108 |
// Handle WebSocket errors
|
|
|
|
| 118 |
onError: handleWebSocketError,
|
| 119 |
});
|
| 120 |
|
| 121 |
+
// Sync connection state to store and clear traceId on disconnect
|
| 122 |
useEffect(() => {
|
| 123 |
setIsConnected(isConnected);
|
| 124 |
+
|
| 125 |
+
// Clear traceId when websocket disconnects
|
| 126 |
+
if (!isConnected) {
|
| 127 |
+
setTraceId(null);
|
| 128 |
+
console.log('WebSocket disconnected - traceId cleared');
|
| 129 |
+
}
|
| 130 |
+
}, [isConnected, setIsConnected, setTraceId]);
|
| 131 |
|
| 132 |
// Create a global sendNewTask function that can be called from anywhere
|
| 133 |
useEffect(() => {
|
|
|
|
| 136 |
// Reset agent state before starting a new task
|
| 137 |
resetAgent();
|
| 138 |
|
| 139 |
+
// Ensure traceId is set before creating trace
|
| 140 |
+
if (!traceId) {
|
| 141 |
+
console.error('Internal error: Cannot send task. TraceId not set. Refreshing page...');
|
| 142 |
+
window.location.reload();
|
| 143 |
+
return;
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
const trace: AgentTrace = {
|
| 147 |
id: traceId,
|
| 148 |
instruction,
|
|
|
|
| 157 |
numberOfSteps: 0,
|
| 158 |
maxSteps: 200, // Default max steps, will be updated by backend
|
| 159 |
completed: false,
|
| 160 |
+
final_state: null,
|
| 161 |
+
} as AgentTraceMetadata,
|
| 162 |
};
|
| 163 |
|
| 164 |
setTrace(trace);
|
|
|
|
| 173 |
|
| 174 |
console.log('Task sent:', trace);
|
| 175 |
};
|
| 176 |
+
}, [setTrace, setIsAgentProcessing, setIsConnectingToE2B, sendMessage, resetAgent, traceId]);
|
| 177 |
|
| 178 |
// Function to stop the current task
|
| 179 |
const stopCurrentTask = useCallback(() => {
|
cua2-front/src/pages/Index.tsx
DELETED
|
@@ -1,147 +0,0 @@
|
|
| 1 |
-
import { Header, Metadata, StackSteps, VNCStream } from '@/components/mock';
|
| 2 |
-
import { getWebSocketUrl } from '@/config/api';
|
| 3 |
-
import { useWebSocket } from '@/hooks/useWebSocket';
|
| 4 |
-
import { AgentStep, AgentTrace, WebSocketEvent } from '@/types/agent';
|
| 5 |
-
import { useState } from 'react';
|
| 6 |
-
import { ulid } from 'ulid';
|
| 7 |
-
|
| 8 |
-
const Index = () => {
|
| 9 |
-
const [trace, setTrace] = useState<AgentTrace>();
|
| 10 |
-
const [isAgentProcessing, setIsAgentProcessing] = useState(false);
|
| 11 |
-
const [vncUrl, setVncUrl] = useState<string>('');
|
| 12 |
-
const [selectedModelId, setSelectedModelId] = useState<string>("Qwen/Qwen3-VL-30B-A3B-Instruct");
|
| 13 |
-
|
| 14 |
-
// #################### WebSocket Connection ########################
|
| 15 |
-
|
| 16 |
-
// WebSocket connection - Automatically configured based on environment
|
| 17 |
-
const WS_URL = getWebSocketUrl();
|
| 18 |
-
|
| 19 |
-
const handleWebSocketMessage = (event: WebSocketEvent) => {
|
| 20 |
-
console.log('WebSocket event received:', event);
|
| 21 |
-
|
| 22 |
-
switch (event.type) {
|
| 23 |
-
case 'agent_start':
|
| 24 |
-
setIsAgentProcessing(true);
|
| 25 |
-
setTrace(event.agentTrace);
|
| 26 |
-
console.log('Agent start received:', event.agentTrace);
|
| 27 |
-
break;
|
| 28 |
-
|
| 29 |
-
case 'agent_progress':
|
| 30 |
-
// Add new step from a agent trace run with image, generated text, actions, tokens and timestamp
|
| 31 |
-
setTrace(prev => {
|
| 32 |
-
const existingSteps = prev?.steps || [] as AgentStep[];
|
| 33 |
-
const stepExists = existingSteps.some(step => step.stepId === event.agentStep.stepId);
|
| 34 |
-
|
| 35 |
-
if (!stepExists) {
|
| 36 |
-
return {
|
| 37 |
-
...prev,
|
| 38 |
-
steps: [...existingSteps, event.agentStep],
|
| 39 |
-
traceMetadata: event.traceMetadata,
|
| 40 |
-
isRunning: true
|
| 41 |
-
};
|
| 42 |
-
}
|
| 43 |
-
return prev;
|
| 44 |
-
});
|
| 45 |
-
console.log('Agent progress received:', event.agentStep);
|
| 46 |
-
break;
|
| 47 |
-
|
| 48 |
-
case 'agent_complete':
|
| 49 |
-
setIsAgentProcessing(false);
|
| 50 |
-
setTrace(trace => {
|
| 51 |
-
return trace.id === event.traceMetadata.traceId
|
| 52 |
-
? {
|
| 53 |
-
...trace,
|
| 54 |
-
isRunning: false,
|
| 55 |
-
metadata: event.traceMetadata,
|
| 56 |
-
}
|
| 57 |
-
: trace;
|
| 58 |
-
});
|
| 59 |
-
console.log('Agent complete received:', event.traceMetadata);
|
| 60 |
-
break;
|
| 61 |
-
|
| 62 |
-
case 'agent_error':
|
| 63 |
-
setIsAgentProcessing(false);
|
| 64 |
-
// TODO: Handle agent error
|
| 65 |
-
console.log('Agent error received:', event.error);
|
| 66 |
-
break;
|
| 67 |
-
|
| 68 |
-
case 'vnc_url_set':
|
| 69 |
-
setVncUrl(event.vncUrl);
|
| 70 |
-
// TODO: Handle VNC URL set
|
| 71 |
-
console.log('VNC URL set received:', event.vncUrl);
|
| 72 |
-
break;
|
| 73 |
-
|
| 74 |
-
case 'vnc_url_unset':
|
| 75 |
-
setVncUrl('');
|
| 76 |
-
// TODO: Handle VNC URL unset
|
| 77 |
-
console.log('VNC URL unset received:');
|
| 78 |
-
break;
|
| 79 |
-
|
| 80 |
-
case 'heartbeat':
|
| 81 |
-
console.log('Heartbeat received:', event);
|
| 82 |
-
break;
|
| 83 |
-
}
|
| 84 |
-
};
|
| 85 |
-
|
| 86 |
-
const handleWebSocketError = () => {
|
| 87 |
-
// WebSocket Frontend Error handling
|
| 88 |
-
|
| 89 |
-
};
|
| 90 |
-
|
| 91 |
-
const { isConnected, connectionState, sendMessage, manualReconnect } = useWebSocket({
|
| 92 |
-
url: WS_URL,
|
| 93 |
-
onMessage: handleWebSocketMessage,
|
| 94 |
-
onError: handleWebSocketError,
|
| 95 |
-
});
|
| 96 |
-
|
| 97 |
-
// #################### Frontend Functionality ########################
|
| 98 |
-
|
| 99 |
-
const handleModelId = (modelId: string) => {
|
| 100 |
-
setSelectedModelId(modelId);
|
| 101 |
-
};
|
| 102 |
-
|
| 103 |
-
const handleSendNewTask = (content: string, modelId: string) => {
|
| 104 |
-
const trace: AgentTrace = {
|
| 105 |
-
id: ulid(),
|
| 106 |
-
instruction: content,
|
| 107 |
-
modelId: selectedModelId,
|
| 108 |
-
timestamp: new Date(),
|
| 109 |
-
isRunning: true,
|
| 110 |
-
};
|
| 111 |
-
|
| 112 |
-
setTrace(trace);
|
| 113 |
-
|
| 114 |
-
// Send message to Python backend via WebSocket
|
| 115 |
-
sendMessage({
|
| 116 |
-
type: 'user_task',
|
| 117 |
-
trace: trace,
|
| 118 |
-
});
|
| 119 |
-
};
|
| 120 |
-
|
| 121 |
-
// #################### Mock Frontend Rendering ########################
|
| 122 |
-
|
| 123 |
-
return (
|
| 124 |
-
<div style={{ height: '100%', width: '100%', display: 'flex', flexDirection: 'column', backgroundColor: '#f3f4f6' }}>
|
| 125 |
-
<Header
|
| 126 |
-
isConnected={isConnected}
|
| 127 |
-
isAgentProcessing={isAgentProcessing}
|
| 128 |
-
onSendTask={handleSendNewTask}
|
| 129 |
-
/>
|
| 130 |
-
|
| 131 |
-
<div style={{ flex: 1, display: 'flex', justifyContent: 'center', alignItems: 'center', overflow: 'hidden', minHeight: 0, padding: '32px' }}>
|
| 132 |
-
<div style={{ width: '100%', height: '100%', maxWidth: '1400px', maxHeight: '900px', display: 'flex', flexDirection: 'row', overflow: 'hidden' }}>
|
| 133 |
-
{/* Left Side: VNC Stream + Metadata */}
|
| 134 |
-
<div style={{ flex: 1, display: 'flex', flexDirection: 'column', padding: '20px 12px', gap: '20px', minWidth: 0 }}>
|
| 135 |
-
<VNCStream vncUrl={vncUrl} />
|
| 136 |
-
<Metadata trace={trace} />
|
| 137 |
-
</div>
|
| 138 |
-
|
| 139 |
-
{/* Right Side: Stack Steps */}
|
| 140 |
-
<StackSteps trace={trace} />
|
| 141 |
-
</div>
|
| 142 |
-
</div>
|
| 143 |
-
</div>
|
| 144 |
-
);
|
| 145 |
-
};
|
| 146 |
-
|
| 147 |
-
export default Index;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cua2-front/src/pages/Task.tsx
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
-
import React, { useEffect } from 'react';
|
| 2 |
-
import { useNavigate } from 'react-router-dom';
|
| 3 |
-
import { useAgentStore, selectTrace, selectIsAgentProcessing, selectVncUrl, selectMetadata, selectSelectedStep } from '@/stores/agentStore';
|
| 4 |
import { Header, SandboxViewer, StepsList, Timeline } from '@/components';
|
|
|
|
| 5 |
import { Box } from '@mui/material';
|
|
|
|
|
|
|
| 6 |
|
| 7 |
const Task = () => {
|
| 8 |
const navigate = useNavigate();
|
|
@@ -25,8 +25,19 @@ const Task = () => {
|
|
| 25 |
|
| 26 |
// Handler for going back to home
|
| 27 |
const handleBackToHome = () => {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
useAgentStore.getState().resetAgent();
|
| 29 |
-
|
|
|
|
|
|
|
| 30 |
};
|
| 31 |
|
| 32 |
// Determine if we should show success/fail status (same logic as SandboxViewer)
|
|
@@ -66,15 +77,15 @@ const Task = () => {
|
|
| 66 |
overflowX: 'hidden',
|
| 67 |
}}
|
| 68 |
>
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
{/* Left Side: OS Stream + Metadata */}
|
| 79 |
<Box
|
| 80 |
sx={{
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import { Header, SandboxViewer, StepsList, Timeline } from '@/components';
|
| 2 |
+
import { selectIsAgentProcessing, selectMetadata, selectSelectedStep, selectTrace, selectVncUrl, useAgentStore } from '@/stores/agentStore';
|
| 3 |
import { Box } from '@mui/material';
|
| 4 |
+
import { useEffect } from 'react';
|
| 5 |
+
import { useNavigate } from 'react-router-dom';
|
| 6 |
|
| 7 |
const Task = () => {
|
| 8 |
const navigate = useNavigate();
|
|
|
|
| 25 |
|
| 26 |
// Handler for going back to home
|
| 27 |
const handleBackToHome = () => {
|
| 28 |
+
const currentTrace = useAgentStore.getState().trace;
|
| 29 |
+
|
| 30 |
+
// Stop the current task if it's running
|
| 31 |
+
const stopTask = (window as Window & { __stopCurrentTask?: () => void }).__stopCurrentTask;
|
| 32 |
+
if (stopTask) {
|
| 33 |
+
stopTask();
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
// Reset frontend state
|
| 37 |
useAgentStore.getState().resetAgent();
|
| 38 |
+
|
| 39 |
+
// Reload the page to reconnect websocket
|
| 40 |
+
window.location.href = '/';
|
| 41 |
};
|
| 42 |
|
| 43 |
// Determine if we should show success/fail status (same logic as SandboxViewer)
|
|
|
|
| 77 |
overflowX: 'hidden',
|
| 78 |
}}
|
| 79 |
>
|
| 80 |
+
<Box
|
| 81 |
+
sx={{
|
| 82 |
+
width: '100%',
|
| 83 |
+
display: 'flex',
|
| 84 |
+
flexDirection: { xs: 'column', md: 'row' },
|
| 85 |
+
p: { xs: 2, md: 4 },
|
| 86 |
+
pb: { xs: 2, md: 3 },
|
| 87 |
+
}}
|
| 88 |
+
>
|
| 89 |
{/* Left Side: OS Stream + Metadata */}
|
| 90 |
<Box
|
| 91 |
sx={{
|
cua2-front/src/stores/agentStore.ts
CHANGED
|
@@ -5,6 +5,7 @@ import { devtools } from 'zustand/middleware';
|
|
| 5 |
interface AgentState {
|
| 6 |
// State
|
| 7 |
trace?: AgentTrace;
|
|
|
|
| 8 |
isAgentProcessing: boolean;
|
| 9 |
isConnectingToE2B: boolean; // New state for E2B connection
|
| 10 |
vncUrl: string;
|
|
@@ -19,6 +20,7 @@ interface AgentState {
|
|
| 19 |
|
| 20 |
// Actions
|
| 21 |
setTrace: (trace: AgentTrace | undefined) => void;
|
|
|
|
| 22 |
updateTraceWithStep: (step: AgentStep, metadata: AgentTraceMetadata) => void;
|
| 23 |
completeTrace: (metadata: AgentTraceMetadata, finalState?: 'success' | 'stopped' | 'max_steps_reached' | 'error' | 'sandbox_timeout') => void;
|
| 24 |
setIsAgentProcessing: (processing: boolean) => void;
|
|
@@ -36,6 +38,7 @@ interface AgentState {
|
|
| 36 |
|
| 37 |
const initialState = {
|
| 38 |
trace: undefined,
|
|
|
|
| 39 |
isAgentProcessing: false,
|
| 40 |
isConnectingToE2B: false,
|
| 41 |
vncUrl: '',
|
|
@@ -58,6 +61,10 @@ export const useAgentStore = create<AgentState>()(
|
|
| 58 |
setTrace: (trace) =>
|
| 59 |
set({ trace }, false, 'setTrace'),
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
// Update trace with a new step
|
| 62 |
updateTraceWithStep: (step, metadata) =>
|
| 63 |
set(
|
|
@@ -90,62 +97,62 @@ export const useAgentStore = create<AgentState>()(
|
|
| 90 |
'updateTraceWithStep'
|
| 91 |
),
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
},
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
};
|
| 145 |
-
},
|
| 146 |
-
false,
|
| 147 |
-
'completeTrace'
|
| 148 |
-
),
|
| 149 |
|
| 150 |
// Set processing state
|
| 151 |
setIsAgentProcessing: (isAgentProcessing) =>
|
|
@@ -227,10 +234,11 @@ export const useAgentStore = create<AgentState>()(
|
|
| 227 |
toggleDarkMode: () =>
|
| 228 |
set((state) => ({ isDarkMode: !state.isDarkMode }), false, 'toggleDarkMode'),
|
| 229 |
|
| 230 |
-
// Reset agent state
|
| 231 |
resetAgent: () =>
|
| 232 |
set((state) => ({
|
| 233 |
...initialState,
|
|
|
|
| 234 |
isDarkMode: state.isDarkMode, // Keep dark mode preference
|
| 235 |
isConnected: state.isConnected, // Keep connection status
|
| 236 |
selectedModelId: state.selectedModelId, // Keep selected model
|
|
@@ -244,6 +252,7 @@ export const useAgentStore = create<AgentState>()(
|
|
| 244 |
|
| 245 |
// Selectors for better performance
|
| 246 |
export const selectTrace = (state: AgentState) => state.trace;
|
|
|
|
| 247 |
export const selectIsAgentProcessing = (state: AgentState) => state.isAgentProcessing;
|
| 248 |
export const selectIsConnectingToE2B = (state: AgentState) => state.isConnectingToE2B;
|
| 249 |
export const selectVncUrl = (state: AgentState) => state.vncUrl;
|
|
|
|
| 5 |
interface AgentState {
|
| 6 |
// State
|
| 7 |
trace?: AgentTrace;
|
| 8 |
+
traceId: string | null; // Set by backend heartbeat, persists during connection
|
| 9 |
isAgentProcessing: boolean;
|
| 10 |
isConnectingToE2B: boolean; // New state for E2B connection
|
| 11 |
vncUrl: string;
|
|
|
|
| 20 |
|
| 21 |
// Actions
|
| 22 |
setTrace: (trace: AgentTrace | undefined) => void;
|
| 23 |
+
setTraceId: (traceId: string | null) => void;
|
| 24 |
updateTraceWithStep: (step: AgentStep, metadata: AgentTraceMetadata) => void;
|
| 25 |
completeTrace: (metadata: AgentTraceMetadata, finalState?: 'success' | 'stopped' | 'max_steps_reached' | 'error' | 'sandbox_timeout') => void;
|
| 26 |
setIsAgentProcessing: (processing: boolean) => void;
|
|
|
|
| 38 |
|
| 39 |
const initialState = {
|
| 40 |
trace: undefined,
|
| 41 |
+
traceId: null, // Will be set by backend heartbeat
|
| 42 |
isAgentProcessing: false,
|
| 43 |
isConnectingToE2B: false,
|
| 44 |
vncUrl: '',
|
|
|
|
| 61 |
setTrace: (trace) =>
|
| 62 |
set({ trace }, false, 'setTrace'),
|
| 63 |
|
| 64 |
+
// Set trace ID (set by backend heartbeat, only cleared on disconnect)
|
| 65 |
+
setTraceId: (traceId) =>
|
| 66 |
+
set({ traceId }, false, 'setTraceId'),
|
| 67 |
+
|
| 68 |
// Update trace with a new step
|
| 69 |
updateTraceWithStep: (step, metadata) =>
|
| 70 |
set(
|
|
|
|
| 97 |
'updateTraceWithStep'
|
| 98 |
),
|
| 99 |
|
| 100 |
+
// Complete the trace
|
| 101 |
+
completeTrace: (metadata, finalState?: 'success' | 'stopped' | 'max_steps_reached' | 'error' | 'sandbox_timeout') =>
|
| 102 |
+
set(
|
| 103 |
+
(state) => {
|
| 104 |
+
if (!state.trace) return state;
|
| 105 |
+
|
| 106 |
+
// Preserve existing maxSteps if new metadata has 0
|
| 107 |
+
const updatedMetadata = {
|
| 108 |
+
...metadata,
|
| 109 |
+
maxSteps: metadata.maxSteps > 0
|
| 110 |
+
? metadata.maxSteps
|
| 111 |
+
: (state.trace.traceMetadata?.maxSteps || 200),
|
| 112 |
+
completed: true,
|
| 113 |
+
};
|
| 114 |
+
|
| 115 |
+
// Determine the final step type based on final_state from backend
|
| 116 |
+
let stepType: 'success' | 'failure' | 'stopped' | 'max_steps_reached' | 'sandbox_timeout';
|
| 117 |
+
let stepMessage: string | undefined;
|
| 118 |
+
|
| 119 |
+
if (finalState === 'stopped') {
|
| 120 |
+
stepType = 'stopped';
|
| 121 |
+
stepMessage = 'Task stopped by user';
|
| 122 |
+
} else if (finalState === 'max_steps_reached') {
|
| 123 |
+
stepType = 'max_steps_reached';
|
| 124 |
+
stepMessage = 'Maximum steps reached';
|
| 125 |
+
} else if (finalState === 'sandbox_timeout') {
|
| 126 |
+
stepType = 'sandbox_timeout';
|
| 127 |
+
stepMessage = 'Sandbox timeout';
|
| 128 |
+
} else if (finalState === 'error' || state.error) {
|
| 129 |
+
stepType = 'failure';
|
| 130 |
+
stepMessage = state.error || 'Task failed';
|
| 131 |
+
} else {
|
| 132 |
+
stepType = 'success';
|
| 133 |
+
stepMessage = undefined;
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
const finalStep: FinalStep = {
|
| 137 |
+
type: stepType,
|
| 138 |
+
message: stepMessage,
|
| 139 |
+
metadata: updatedMetadata,
|
| 140 |
+
};
|
| 141 |
+
|
| 142 |
+
return {
|
| 143 |
+
trace: {
|
| 144 |
+
...state.trace,
|
| 145 |
+
isRunning: false,
|
| 146 |
+
traceMetadata: updatedMetadata,
|
| 147 |
+
},
|
| 148 |
+
finalStep,
|
| 149 |
+
// Keep error in state for display
|
| 150 |
+
selectedStepIndex: null, // Reset to live mode on completion
|
| 151 |
+
};
|
| 152 |
},
|
| 153 |
+
false,
|
| 154 |
+
'completeTrace'
|
| 155 |
+
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
// Set processing state
|
| 158 |
setIsAgentProcessing: (isAgentProcessing) =>
|
|
|
|
| 234 |
toggleDarkMode: () =>
|
| 235 |
set((state) => ({ isDarkMode: !state.isDarkMode }), false, 'toggleDarkMode'),
|
| 236 |
|
| 237 |
+
// Reset agent state (but preserve traceId from backend during connection)
|
| 238 |
resetAgent: () =>
|
| 239 |
set((state) => ({
|
| 240 |
...initialState,
|
| 241 |
+
traceId: state.traceId, // IMPORTANT: Keep traceId from backend
|
| 242 |
isDarkMode: state.isDarkMode, // Keep dark mode preference
|
| 243 |
isConnected: state.isConnected, // Keep connection status
|
| 244 |
selectedModelId: state.selectedModelId, // Keep selected model
|
|
|
|
| 252 |
|
| 253 |
// Selectors for better performance
|
| 254 |
export const selectTrace = (state: AgentState) => state.trace;
|
| 255 |
+
export const selectTraceId = (state: AgentState) => state.traceId;
|
| 256 |
export const selectIsAgentProcessing = (state: AgentState) => state.isAgentProcessing;
|
| 257 |
export const selectIsConnectingToE2B = (state: AgentState) => state.isConnectingToE2B;
|
| 258 |
export const selectVncUrl = (state: AgentState) => state.vncUrl;
|
cua2-front/src/types/agent.ts
CHANGED
|
@@ -80,6 +80,7 @@ interface VncUrlUnsetEvent {
|
|
| 80 |
|
| 81 |
interface HeartbeatEvent {
|
| 82 |
type: 'heartbeat';
|
|
|
|
| 83 |
}
|
| 84 |
|
| 85 |
export type WebSocketEvent =
|
|
|
|
| 80 |
|
| 81 |
interface HeartbeatEvent {
|
| 82 |
type: 'heartbeat';
|
| 83 |
+
uuid: string;
|
| 84 |
}
|
| 85 |
|
| 86 |
export type WebSocketEvent =
|