A-Mahla committed
Commit 8f4ea43 · unverified · 1 parent: dcaec95

Amir/handle sandbox (#18)

* ADD sandbox management

* ADD better sandbox management

* ADD better sandbox management

* ADD better sandbox management

* ADD better sandbox management

* ADD better sandbox management

README.md CHANGED
@@ -93,8 +93,34 @@ cp env.example .env
 Edit `.env` with your configuration:
 - API keys for OpenAI/LiteLLM
 - Database connections (if applicable)
+- HuggingFace credentials for data archival (optional)
 - Other service credentials
 
+#### Data Archival Configuration (Optional)
+
+CUA2 includes an automatic data archival feature that backs up old trace data to HuggingFace datasets:
+
+```bash
+# HuggingFace token for uploading archived data
+HF_TOKEN=your_huggingface_token_here
+
+# HuggingFace dataset repository ID (e.g., "username/dataset-name")
+HF_DATASET_REPO=your_username/your_dataset_repo
+
+# Check interval (default: 30 minutes)
+ARCHIVE_INTERVAL_MINUTES=30
+
+# Age threshold - folders older than this will be archived (default: 30 minutes)
+FOLDER_AGE_THRESHOLD_MINUTES=30
+```
+
+**How it works:**
+1. Every 30 minutes (configurable), the system checks the `data/` folder for trace folders
+2. Folders older than 30 minutes (configurable) are compressed into `.tar.gz` archives
+3. Archives are uploaded to your HuggingFace dataset repository
+4. After verifying successful upload, local folders are deleted to free up space
+5. This keeps your disk usage minimal while preserving all agent traces in the cloud
+
 ### 4. Start Development Servers
 
 #### Option 1: Using Makefile (Recommended)
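
Note: archived traces remain retrievable with the standard `huggingface_hub` client. A minimal sketch, assuming an archive produced by this feature (the repo id, token, and archive name below are placeholders; real archives follow the `trace-{uuid}-{model}.tar.gz` naming used by the archival service):

```python
import tarfile

from huggingface_hub import hf_hub_download

# Placeholder names; list the dataset repo to find real archive files.
archive = hf_hub_download(
    repo_id="your_username/your_dataset_repo",   # matches HF_DATASET_REPO
    filename="trace-<uuid>-<model>.tar.gz",      # hypothetical archive name
    repo_type="dataset",
    token="your_huggingface_token_here",
)

# Unpack the trace folder locally for inspection.
with tarfile.open(archive, "r:gz") as tar:
    tar.extractall("restored_traces")
```
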
cua2-core/pyproject.toml CHANGED
@@ -36,6 +36,7 @@ dependencies = [
     "smolagents[openai,litellm]==1.22.0",
     "openai==2.6.1",
     "e2b-desktop==2.1.0",
+    "huggingface_hub==1.1.2",
 ]
 
 [project.optional-dependencies]
cua2-core/pytest.ini CHANGED
@@ -11,3 +11,4 @@ addopts =
 markers =
     unit: Unit tests
     integration: Integration tests
+    slow: Slow tests that take more time to execute
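
With the `slow` marker registered, long-running tests can be deselected in quick runs via `pytest -m "not slow"`. A minimal sketch of a test opting in (the test body is illustrative, not from the repo):

```python
import time

import pytest


@pytest.mark.slow
def test_waits_for_archival_cycle():
    # Illustrative placeholder for anything that genuinely takes a while,
    # e.g. waiting on the archival worker's polling loop.
    time.sleep(2)
    assert True
```
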
cua2-core/src/cua2_core/models/models.py CHANGED
@@ -3,6 +3,7 @@ import os
 import threading
 from datetime import datetime
 from typing import Annotated, Literal, Optional
+from uuid import uuid4
 
 from cua2_core.services.agent_utils.function_parser import FunctionCall
 from pydantic import BaseModel, Field, PrivateAttr, field_serializer, model_validator
@@ -106,6 +107,15 @@ class AgentStep(BaseModel):
     thought: Optional[str] = None
     actions: list[AgentAction] = []
 
+    @field_serializer("image")
+    def serialize_image(self, image: str, _info):
+        """Convert image to path when dumping to JSON"""
+
+        if _info.context and _info.context.get("image_as_path", True):
+            return f"{self.traceId}-{self.stepId}.png"
+
+        return image
+
     @field_serializer("actions")
     def serialize_actions(self, actions: list[AgentAction], _info):
         """Convert actions to list of strings when dumping (controlled by context)"""
@@ -206,6 +216,7 @@ class HeartbeatEvent(BaseModel):
     """Heartbeat event"""
 
     type: Literal["heartbeat"] = "heartbeat"
+    uuid: str = Field(default_factory=lambda: str(uuid4()))
 
 
 WebSocketEvent: TypeAlias = Annotated[
@@ -266,7 +277,7 @@ class ActiveTask(BaseModel):
                 self.model_dump(
                     mode="json",
                     exclude={"_file_locks"},
-                    context={"actions_as_json": True},
+                    context={"actions_as_json": True, "image_as_path": True},
                 ),
                 f,
                 indent=2,
@@ -286,7 +297,7 @@ class ActiveTask(BaseModel):
                 self.model_dump(
                     mode="json",
                     exclude={"_file_locks"},
-                    context={"actions_as_json": True},
+                    context={"actions_as_json": True, "image_as_path": True},
                 ),
                 f,
                 indent=2,
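
The `image_as_path` context flag added here is what keeps the on-disk trace JSON small: the same step model dumps either the raw base64 screenshot (for the websocket) or just a filename (for persistence). A minimal self-contained sketch of the Pydantic v2 pattern, using a simplified stand-in for `AgentStep`:

```python
from pydantic import BaseModel, field_serializer


class Step(BaseModel):
    traceId: str
    stepId: str
    image: str  # base64-encoded screenshot

    @field_serializer("image")
    def serialize_image(self, image: str, _info):
        # With context={"image_as_path": True}, emit a filename instead
        # of the (potentially multi-megabyte) base64 payload.
        if _info.context and _info.context.get("image_as_path"):
            return f"{self.traceId}-{self.stepId}.png"
        return image


step = Step(traceId="t1", stepId="0", image="iVBORw0KGgo...")
print(step.model_dump()["image"])                                 # full base64
print(step.model_dump(context={"image_as_path": True})["image"])  # "t1-0.png"
```
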
cua2-core/src/cua2_core/routes/websocket.py CHANGED
@@ -21,8 +21,9 @@ async def websocket_endpoint(websocket: WebSocket):
     await websocket_manager.connect(websocket)
 
     try:
-        # Send welcome heartbeat
-        welcome_message = HeartbeatEvent(type="heartbeat")
+        welcome_message = HeartbeatEvent(
+            uuid=await agent_service.create_id_and_sandbox(websocket)
+        )
         await websocket_manager.send_message(welcome_message, websocket)
 
         # Keep the connection alive and wait for messages
@@ -100,5 +101,11 @@ async def websocket_endpoint(websocket: WebSocket):
     except Exception as e:
         print(f"WebSocket connection error: {e}")
     finally:
+        # Cleanup tasks and sandboxes associated with this websocket
+        try:
+            await agent_service.cleanup_tasks_for_websocket(websocket)
+        except Exception as e:
+            print(f"Error cleaning up tasks for websocket: {e}")
+
         # Ensure cleanup happens
         websocket_manager.disconnect(websocket)
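
The welcome heartbeat now doubles as a session handshake: the server allocates an id, pre-warms a sandbox under it, and returns the id in the heartbeat's `uuid`. A hypothetical client-side sketch (the endpoint URL and the `websockets` package are assumptions for illustration, not part of this diff):

```python
import asyncio
import json

import websockets  # third-party client library, assumed for illustration


async def main():
    # Hypothetical endpoint; adjust host/path to the actual deployment.
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        welcome = json.loads(await ws.recv())
        assert welcome["type"] == "heartbeat"
        # Per this diff, the uuid identifies the session whose sandbox
        # was pre-created server-side.
        session_id = welcome["uuid"]
        print(f"Session {session_id} ready")


asyncio.run(main())
```
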
cua2-core/src/cua2_core/services/agent_service.py CHANGED
@@ -6,6 +6,7 @@ import os
 import time
 from io import BytesIO
 from typing import Callable, Literal
+from uuid import uuid4
 
 from cua2_core.models.models import (
     ActiveTask,
@@ -17,12 +18,14 @@ from cua2_core.models.models import (
 from cua2_core.services.agent_utils.desktop_agent import E2BVisionAgent
 from cua2_core.services.agent_utils.function_parser import parse_function_call
 from cua2_core.services.agent_utils.get_model import get_model
+from cua2_core.services.archival_service import ArchivalService
 from cua2_core.services.sandbox_service import SandboxService
 from cua2_core.websocket.websocket_manager import WebSocketException, WebSocketManager
 from e2b_desktop import Sandbox, TimeoutException
 from fastapi import WebSocket
 from PIL import Image
 from smolagents import ActionStep, AgentImage, AgentMaxStepsError, TaskStep
+from starlette.websockets import WebSocketState
 
 logger = logging.getLogger(__name__)
 
@@ -50,6 +53,35 @@ class AgentService:
         self._lock = asyncio.Lock()
         self.max_sandboxes = int(600 / num_workers)
 
+        # Initialize archival service in dedicated process
+        self.archival_service = ArchivalService(
+            hf_token=os.getenv("HF_TOKEN"),
+            hf_dataset_repo="smolagents/cua_traces",
+            data_dir="data",
+            archive_interval_minutes=30,
+            folder_age_threshold_minutes=30,
+        )
+        # Start the archival service process
+        self.archival_service.start()
+
+    def _update_archival_active_tasks(self):
+        """
+        Update the archival service with current active task IDs.
+        Should be called whenever tasks are added or removed.
+        """
+        if self.archival_service.is_alive():
+            self.archival_service.update_active_tasks(set(self.active_tasks.keys()))
+
+    async def create_id_and_sandbox(self, websocket: WebSocket) -> str:
+        """Create a new ID and sandbox"""
+        async with self._lock:
+            uuid = str(uuid4())
+            while uuid in self.active_tasks:
+                uuid = str(uuid4())
+            self.task_websockets[uuid] = websocket
+            await self.sandbox_service.acquire_sandbox(uuid)
+            return uuid
+
     async def process_user_task(
         self, trace: AgentTrace, websocket: WebSocket
     ) -> str | None:
@@ -60,6 +92,9 @@ class AgentService:
         trace.traceMetadata = AgentTraceMetadata(traceId=trace_id)
 
         async with self._lock:
+            if self.task_websockets[trace_id] != websocket:
+                raise WebSocketException("WebSocket mismatch")
+
             active_task = ActiveTask(
                 message_id=trace_id,
                 instruction=trace.instruction,
@@ -79,9 +114,11 @@ class AgentService:
 
             # Store the task and websocket for this task
             self.active_tasks[trace_id] = active_task
-            self.task_websockets[trace_id] = websocket
             self.last_screenshot[trace_id] = None
 
+            # Update archival service with new active task
+            self._update_archival_active_tasks()
+
             asyncio.create_task(self._agent_processing(trace_id))
 
             return trace_id
@@ -111,8 +148,15 @@ class AgentService:
 
         model = get_model(self.active_tasks[message_id].model_id)
 
-        # Acquire a sandbox from the pool
-        sandbox = await self.sandbox_service.acquire_sandbox(message_id)
+        max_attempts = 10
+        for _ in range(max_attempts):
+            response = await self.sandbox_service.acquire_sandbox(message_id)
+            if response.sandbox is not None and response.state == "ready":
+                sandbox = response.sandbox
+                break
+            elif response.state == "max_sandboxes_reached":
+                raise Exception("No sandbox available: pool limit reached")
+            await asyncio.sleep(2)
         if sandbox is None:
             raise Exception("No sandbox available: pool limit reached")
 
@@ -180,7 +224,12 @@ class AgentService:
 
         finally:
             # Send completion event
-            if not websocket_exception:
+            # Check if websocket is still connected before sending
+            if (
+                not websocket_exception
+                and websocket
+                and websocket.client_state == WebSocketState.CONNECTED
+            ):
                 await self.websocket_manager.send_agent_complete(
                     metadata=self.active_tasks[message_id].traceMetadata,
                     websocket=websocket,
@@ -210,9 +259,12 @@ class AgentService:
             if message_id in self.last_screenshot:
                 del self.last_screenshot[message_id]
 
+            # Update archival service after task removal
+            self._update_archival_active_tasks()
+
             # Release sandbox back to the pool
             if sandbox:
-                await self.sandbox_service.release_sandbox(sandbox)
+                await self.sandbox_service.release_sandbox(message_id)
 
     async def _agent_processing(
         self,
@@ -236,24 +288,8 @@ class AgentService:
             if memory_step.step_number > agent.max_steps:
                 raise AgentStopException("Max steps reached")
 
-            time.sleep(3)
-
-            image = self.last_screenshot[message_id]
-            assert image is not None
-
-            for previous_memory_step in (
-                agent.memory.steps
-            ):  # Remove previous screenshots from logs for lean processing
-                if (
-                    isinstance(previous_memory_step, ActionStep)
-                    and previous_memory_step.step_number is not None
-                    and previous_memory_step.step_number <= memory_step.step_number - 1
-                ):
-                    previous_memory_step.observations_images = None
-                elif isinstance(previous_memory_step, TaskStep):
-                    previous_memory_step.task_images = None
-
-            memory_step.observations_images = [image.copy()]
+            if self.active_tasks[message_id].traceMetadata.completed:
+                raise AgentStopException("Task not completed")
 
             model_output = (
                 memory_step.model_output_message.content
@@ -280,6 +316,35 @@ class AgentService:
                     """The task failed due to an error"""  # TODO: To Handle in front
                 )
 
+            agent_actions = (
+                AgentAction.from_function_calls(parse_function_call(action_sequence))
+                if action_sequence
+                else None
+            )
+
+            if not (
+                agent_actions is not None
+                and any(action.function_name == "wait" for action in agent_actions)
+            ):
+                time.sleep(3)
+
+            image = self.last_screenshot[message_id]
+            assert image is not None
+
+            for previous_memory_step in (
+                agent.memory.steps
+            ):  # Remove previous screenshots from logs for lean processing
+                if (
+                    isinstance(previous_memory_step, ActionStep)
+                    and previous_memory_step.step_number is not None
+                    and previous_memory_step.step_number <= memory_step.step_number - 1
+                ):
+                    previous_memory_step.observations_images = None
+                elif isinstance(previous_memory_step, TaskStep):
+                    previous_memory_step.task_images = None
+
+            memory_step.observations_images = [image.copy()]
+
             if memory_step.observations_images:
                 image = memory_step.observations_images[0]
                 buffered = BytesIO()
@@ -295,11 +360,7 @@ class AgentService:
                 stepId=str(memory_step.step_number),
                 image=image_base64,
                 thought=thought,
-                actions=AgentAction.from_function_calls(
-                    parse_function_call(action_sequence)
-                )
-                if action_sequence
-                else None,
+                actions=agent_actions,
                 error=memory_step.error.message if memory_step.error else None,
                 duration=memory_step.timing.duration,
                 inputTokensUsed=memory_step.token_usage.input_tokens,
@@ -317,15 +378,16 @@ class AgentService:
             self.active_tasks[message_id].update_step(step)
 
             websocket = self.task_websockets.get(message_id)
-            future = asyncio.run_coroutine_threadsafe(
-                self.websocket_manager.send_agent_progress(
-                    step=step,
-                    metadata=self.active_tasks[message_id].traceMetadata,
-                    websocket=websocket,
-                ),
-                loop,
-            )
-            future.result()
+            if websocket and websocket.client_state == WebSocketState.CONNECTED:
+                future = asyncio.run_coroutine_threadsafe(
+                    self.websocket_manager.send_agent_progress(
+                        step=step,
+                        metadata=self.active_tasks[message_id].traceMetadata,
+                        websocket=websocket,
+                    ),
+                    loop,
+                )
                future.result()
 
             if self.active_tasks[message_id].traceMetadata.completed:
                 raise AgentStopException("Task not completed")
@@ -419,3 +481,40 @@ class AgentService:
         self.active_tasks[trace_id].update_trace_metadata(
             completed=True,
         )
+
+    async def cleanup_tasks_for_websocket(self, websocket: WebSocket):
+        """
+        Clean up all tasks associated with a disconnected websocket.
+        This will stop the tasks and release their sandboxes.
+        """
+        tasks_to_cleanup = []
+
+        # Find all message_ids associated with this websocket
+        async with self._lock:
+            for message_id, ws in list(self.task_websockets.items()):
+                if ws == websocket:
+                    tasks_to_cleanup.append(message_id)
+                    logger.info(
+                        f"Marking task {message_id} for cleanup due to websocket disconnect"
+                    )
+
+        # Cleanup each task
+        for message_id in tasks_to_cleanup:
+            try:
+                # Mark task as completed to stop the agent
+                if message_id in self.active_tasks:
+                    self.active_tasks[message_id].update_trace_metadata(
+                        completed=True,
+                    )
+                    logger.info(
+                        f"Stopped task {message_id} due to websocket disconnect"
+                    )
+
+                # Release the sandbox immediately
+                await self.sandbox_service.release_sandbox(message_id)
+                logger.info(
+                    f"Released sandbox for task {message_id} due to websocket disconnect"
+                )
+
+            except Exception as e:
+                logger.error(f"Error cleaning up task {message_id}: {e}", exc_info=True)
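
One pattern in this file is worth calling out: the step callback runs inside the agent's worker thread, so progress events must be marshalled onto the server's event loop. That is what the `asyncio.run_coroutine_threadsafe(...)` / `future.result()` pair does, now guarded by a connection-state check. A minimal self-contained sketch of the pattern (names are illustrative):

```python
import asyncio
import threading


async def send_progress(step: int):
    # Stand-in for websocket_manager.send_agent_progress(...)
    print(f"sent step {step}")


def agent_thread(loop: asyncio.AbstractEventLoop) -> None:
    # Runs outside the event loop: schedule the coroutine onto the loop
    # and block this thread until it has actually completed.
    for step in range(3):
        future = asyncio.run_coroutine_threadsafe(send_progress(step), loop)
        future.result()


async def main():
    loop = asyncio.get_running_loop()
    worker = threading.Thread(target=agent_thread, args=(loop,))
    worker.start()
    # Join in a thread so the loop stays free to run the scheduled coroutines.
    await asyncio.to_thread(worker.join)


asyncio.run(main())
```
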
cua2-core/src/cua2_core/services/archival_service.py ADDED
@@ -0,0 +1,423 @@
+"""
+Service for automatic data archival to HuggingFace datasets.
+
+This service runs in a dedicated process that periodically:
+1. Scans for old trace data folders
+2. Compresses them into tar.gz archives
+3. Uploads to HuggingFace dataset repository
+4. Verifies successful upload
+5. Deletes local files only after verification
+"""
+
+import logging
+import multiprocessing
+import multiprocessing.synchronize
+import os
+import shutil
+import signal
+import tarfile
+import time
+from pathlib import Path
+from typing import Any
+
+from huggingface_hub import HfApi, hf_hub_download
+from huggingface_hub.utils import HfHubHTTPError
+
+# Configure logging for the process
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+
+class ArchivalService:
+    """Service for handling automatic data archival to HuggingFace in a dedicated process"""
+
+    def __init__(
+        self,
+        hf_token: str | None = os.getenv("HF_TOKEN"),
+        hf_dataset_repo: str | None = "smolagents/cua_traces",
+        data_dir: str = "data",
+        archive_interval_minutes: int = 30,
+        folder_age_threshold_minutes: int = 30,
+    ):
+        """
+        Initialize the archival service.
+
+        Args:
+            hf_token: HuggingFace API token
+            hf_dataset_repo: HuggingFace dataset repository ID (e.g., "username/dataset-name")
+            data_dir: Directory containing trace data folders
+            archive_interval_minutes: How often to check for old folders
+            folder_age_threshold_minutes: Minimum age before archival
+        """
+        self.hf_token = hf_token
+        self.hf_dataset_repo = hf_dataset_repo
+        self.data_dir = data_dir
+        self.archive_interval_minutes = archive_interval_minutes
+        self.folder_age_threshold_minutes = folder_age_threshold_minutes
+
+        # Multiprocessing components
+        self._process: multiprocessing.Process | None = None
+        self._stop_event: multiprocessing.synchronize.Event = multiprocessing.Event()
+        self._manager = multiprocessing.Manager()
+        self._active_tasks: Any = self._manager.dict()  # DictProxy type
+
+    def start(self):
+        """Start the archival service in a dedicated process."""
+        if self._process and self._process.is_alive():
+            logger.warning("Archival service is already running")
+            return
+
+        if not self.hf_token or not self.hf_dataset_repo:
+            logger.warning(
+                "HuggingFace credentials or dataset repo not configured. Data archival disabled."
+            )
+            return
+
+        self._stop_event.clear()
+        self._process = multiprocessing.Process(
+            target=_archival_worker_process,
+            args=(
+                self.hf_token,
+                self.hf_dataset_repo,
+                self.data_dir,
+                self.archive_interval_minutes,
+                self.folder_age_threshold_minutes,
+                self._stop_event,
+                self._active_tasks,
+            ),
+            daemon=True,
+            name="ArchivalWorker",
+        )
+        self._process.start()
+        logger.info(
+            f"Started archival service in process {self._process.pid}. "
+            f"Checking every {self.archive_interval_minutes} minutes."
+        )
+
+    def stop(self, timeout: float = 10.0):
+        """
+        Stop the archival service process.
+
+        Args:
+            timeout: Maximum time to wait for process to terminate (seconds)
+        """
+        if not self._process or not self._process.is_alive():
+            return
+
+        logger.info(f"Stopping archival service (PID: {self._process.pid})...")
+        self._stop_event.set()
+
+        self._process.join(timeout=timeout)
+
+        if self._process.is_alive():
+            logger.warning("Archival process did not stop gracefully, terminating...")
+            self._process.terminate()
+            self._process.join(timeout=2.0)
+
+        if self._process.is_alive():
+            logger.error("Force killing archival process...")
+            self._process.kill()
+            self._process.join()
+
+        logger.info("Archival service stopped")
+        self._process = None
+
+    def update_active_tasks(self, active_task_ids: set[str]):
+        """
+        Update the set of active task IDs.
+        The archival process will skip folders for these tasks.
+
+        Args:
+            active_task_ids: Set of currently active trace IDs
+        """
+        # Clear and update the shared dict
+        self._active_tasks.clear()
+        for task_id in active_task_ids:
+            self._active_tasks[task_id] = True
+
+    def is_alive(self) -> bool:
+        """Check if the archival process is running."""
+        return self._process is not None and self._process.is_alive()
+
+
+def _archival_worker_process(
+    hf_token: str,
+    hf_dataset_repo: str,
+    data_dir: str,
+    archive_interval_minutes: int,
+    folder_age_threshold_minutes: int,
+    stop_event: multiprocessing.synchronize.Event,
+    active_tasks: Any,
+):
+    """
+    Worker process that performs the archival operations.
+    Runs in a separate process from the main application.
+
+    Args:
+        hf_token: HuggingFace API token
+        hf_dataset_repo: HuggingFace dataset repository
+        data_dir: Data directory path
+        archive_interval_minutes: Check interval
+        folder_age_threshold_minutes: Folder age threshold
+        stop_event: Event to signal process shutdown
+        active_tasks: Shared dict of active task IDs
+    """
+
+    def signal_handler(signum, frame):
+        """Handle termination signals gracefully."""
+        logger.info(f"Received signal {signum}, stopping archival worker...")
+        stop_event.set()
+
+    # Register signal handlers
+    signal.signal(signal.SIGTERM, signal_handler)
+    signal.signal(signal.SIGINT, signal_handler)
+
+    # Initialize HuggingFace API in this process
+    hf_api = HfApi(token=hf_token)
+
+    logger.info(
+        f"Archival worker started (PID: {os.getpid()}). "
+        f"Checking every {archive_interval_minutes} minutes."
+    )
+
+    # Main worker loop
+    while not stop_event.is_set():
+        try:
+            # Sleep in small intervals to be responsive to stop_event
+            for _ in range(archive_interval_minutes * 60):
+                if stop_event.is_set():
+                    break
+                time.sleep(1)
+
+            if stop_event.is_set():
+                break
+
+            logger.info("Starting data archival check...")
+            _process_old_folders(
+                data_dir=data_dir,
+                folder_age_threshold_minutes=folder_age_threshold_minutes,
+                active_tasks=active_tasks,
+                hf_api=hf_api,
+                hf_dataset_repo=hf_dataset_repo,
+                hf_token=hf_token,
+            )
+
+        except Exception as e:
+            logger.error(f"Error in archival worker: {e}", exc_info=True)
+            # Continue running despite errors
+
+    logger.info("Archival worker shutting down gracefully")
+
+
+def _process_old_folders(
+    data_dir: str,
+    folder_age_threshold_minutes: int,
+    active_tasks: Any,
+    hf_api: HfApi,
+    hf_dataset_repo: str,
+    hf_token: str,
+):
+    """
+    Process and archive folders older than the threshold.
+    Runs in the archival worker process.
+    """
+    if not os.path.exists(data_dir):
+        logger.warning(f"Data directory {data_dir} does not exist")
+        return
+
+    data_path = Path(data_dir)
+    current_time = time.time()
+    threshold_seconds = folder_age_threshold_minutes * 60
+
+    # Get all trace folders
+    try:
+        trace_folders = [
+            f for f in data_path.iterdir() if f.is_dir() and f.name.startswith("trace-")
+        ]
+    except Exception as e:
+        logger.error(f"Error listing data directory: {e}", exc_info=True)
+        return
+
+    for folder in trace_folders:
+        try:
+            # Check if folder is old enough and not currently active
+            folder_mtime = folder.stat().st_mtime
+            folder_age_seconds = current_time - folder_mtime
+
+            # Extract trace_id from folder name (format: trace-{uuid}-{model_name})
+            folder_name = folder.name
+            parts = folder_name.split("-", 2)  # Split into ['trace', uuid, model_name]
+            if len(parts) < 2:
+                logger.warning(f"Unexpected folder name format: {folder_name}")
+                continue
+
+            trace_id = parts[1]
+
+            # Skip if folder is still being used by an active task
+            if trace_id in active_tasks:
+                logger.debug(f"Skipping active task folder: {folder_name}")
+                continue
+
+            # Check if folder is old enough
+            if folder_age_seconds < threshold_seconds:
+                logger.debug(
+                    f"Folder {folder_name} is not old enough ({folder_age_seconds / 60:.1f} minutes)"
+                )
+                continue
+
+            logger.info(
+                f"Processing old folder: {folder_name} (age: {folder_age_seconds / 60:.1f} minutes)"
+            )
+
+            # Compress the folder
+            archive_path = _compress_folder(folder)
+
+            if not archive_path:
+                logger.error(f"Failed to compress folder: {folder_name}")
+                continue
+
+            # Upload to HuggingFace
+            uploaded = _upload_to_huggingface(hf_api, hf_dataset_repo, archive_path)
+
+            if not uploaded:
+                logger.error(f"Failed to upload archive: {archive_path.name}")
+                # Clean up the local archive file
+                archive_path.unlink(missing_ok=True)
+                continue
+
+            # Verify the file exists in the repo
+            verified = _verify_file_in_repo(
+                hf_dataset_repo, hf_token, archive_path.name
+            )
+
+            if verified:
+                logger.info(
+                    f"Successfully verified {archive_path.name} in HuggingFace repo"
+                )
+
+                # Delete the local folder
+                shutil.rmtree(folder)
+                logger.info(f"Deleted local folder: {folder_name}")
+
+                # Delete the local archive
+                archive_path.unlink(missing_ok=True)
+                logger.info(f"Deleted local archive: {archive_path.name}")
+            else:
+                logger.error(
+                    f"Could not verify {archive_path.name} in repo. Keeping local files."
+                )
+                # Keep both the folder and archive for safety
+
+        except Exception as e:
+            logger.error(f"Error processing folder {folder.name}: {e}", exc_info=True)
+
+
+def _compress_folder(folder_path: Path) -> Path | None:
+    """
+    Compress a folder into a tar.gz archive.
+
+    Args:
+        folder_path: Path to the folder to compress
+
+    Returns:
+        Path to the created archive file, or None if failed
+    """
+    try:
+        archive_name = f"{folder_path.name}.tar.gz"
+        archive_path = folder_path.parent / archive_name
+
+        logger.info(f"Compressing {folder_path.name} to {archive_name}")
+
+        with tarfile.open(archive_path, "w:gz") as tar:
+            tar.add(folder_path, arcname=folder_path.name)
+
+        archive_size = archive_path.stat().st_size
+        logger.info(
+            f"Created archive {archive_name} ({archive_size / 1024 / 1024:.2f} MB)"
+        )
+
+        return archive_path
+
+    except Exception as e:
+        logger.error(f"Error compressing folder {folder_path}: {e}", exc_info=True)
+        return None
+
+
+def _upload_to_huggingface(
+    hf_api: HfApi, hf_dataset_repo: str, archive_path: Path
+) -> bool:
+    """
+    Upload an archive file to HuggingFace dataset repository.
+
+    Args:
+        hf_api: HuggingFace API client
+        hf_dataset_repo: HuggingFace dataset repository ID
+        archive_path: Path to the archive file
+
+    Returns:
+        True if upload succeeded, False otherwise
+    """
+    try:
+        logger.info(
+            f"Uploading {archive_path.name} to HuggingFace repo {hf_dataset_repo}"
+        )
+
+        hf_api.upload_file(
+            path_or_fileobj=str(archive_path),
+            path_in_repo=archive_path.name,
+            repo_id=hf_dataset_repo,
+            repo_type="dataset",
+        )
+
+        logger.info(f"Successfully uploaded {archive_path.name} to HuggingFace")
+        return True
+
+    except Exception as e:
+        logger.error(
+            f"Error uploading {archive_path.name} to HuggingFace: {e}", exc_info=True
+        )
+        return False
+
+
+def _verify_file_in_repo(hf_dataset_repo: str, hf_token: str, filename: str) -> bool:
+    """
+    Verify that a file exists in the HuggingFace repository.
+
+    Args:
+        hf_dataset_repo: HuggingFace dataset repository ID
+        hf_token: HuggingFace API token
+        filename: Name of the file to verify
+
+    Returns:
+        True if file exists in repo, False otherwise
+    """
+    try:
+        logger.info(f"Verifying {filename} exists in HuggingFace repo")
+
+        # Try to get file info - this will raise an error if file doesn't exist
+        hf_hub_download(
+            repo_id=hf_dataset_repo,
+            filename=filename,
+            repo_type="dataset",
+            token=hf_token,
+            local_dir_use_symlinks=False,
+            # Just check if file exists without actually downloading
+            cache_dir=None,
+            local_files_only=False,
+        )
+
+        logger.info(f"Verified {filename} exists in repo")
+        return True
+
+    except HfHubHTTPError as e:
+        if e.response.status_code == 404:
+            logger.error(f"File {filename} not found in repo (404)")
+        else:
+            logger.error(f"HTTP error verifying file: {e}")
+        return False
+    except Exception as e:
+        logger.error(f"Error verifying {filename} in repo: {e}", exc_info=True)
+        return False
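
From the embedding application's point of view, the lifecycle uses only the public surface defined above. A short sketch (values are placeholders):

```python
from cua2_core.services.archival_service import ArchivalService

# Placeholder values; in the app these come from env vars / AgentService.
service = ArchivalService(
    hf_token="hf_...",
    hf_dataset_repo="username/dataset-name",
    data_dir="data",
    archive_interval_minutes=30,
    folder_age_threshold_minutes=30,
)

service.start()  # no-op (with a warning) if credentials are missing

# Folders whose trace ids appear here are never archived.
service.update_active_tasks({"trace-id-1", "trace-id-2"})

if service.is_alive():
    service.stop(timeout=10.0)  # graceful join, then terminate/kill
```
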
cua2-core/src/cua2_core/services/sandbox_service.py CHANGED
@@ -1,20 +1,26 @@
 import asyncio
-import logging
 import os
 import time
 from datetime import datetime
-from typing import Any
+from typing import Any, Literal
 
 from e2b_desktop import Sandbox
-
-logger = logging.getLogger(__name__)
+from pydantic import BaseModel
 
 SANDBOX_METADATA: dict[str, dict[str, Any]] = {}
-SANDBOX_TIMEOUT = 300
+SANDBOX_TIMEOUT = 500
+SANDBOX_CREATION_TIMEOUT = 180
 WIDTH = 1280
 HEIGHT = 960
 
 
+class SandboxResponse(BaseModel):
+    model_config = {"arbitrary_types_allowed": True}
+
+    sandbox: Sandbox | None
+    state: Literal["creating", "ready", "max_sandboxes_reached"]
+
+
 class SandboxService:
     def __init__(self, max_sandboxes: int = 50):
         if not os.getenv("E2B_API_KEY"):
@@ -24,70 +30,154 @@ class SandboxService:
         self.sandbox_metadata: dict[str, dict[str, Any]] = {}
         self.sandbox_lock = asyncio.Lock()
 
-    async def acquire_sandbox(self, session_hash: str) -> Sandbox | None:
-        async with self.sandbox_lock:
-            current_time = datetime.now()
+    async def _create_sandbox_background(
+        self, session_hash: str, expired_sandbox: Sandbox | None
+    ):
+        """Background task to create and setup a sandbox."""
+        # Kill expired sandbox first
+        if expired_sandbox:
+            try:
+                print(f"Closing expired sandbox for session {session_hash}")
+                await asyncio.to_thread(expired_sandbox.kill)
+            except Exception as e:
+                print(f"Error closing expired sandbox: {str(e)}")
+
+        def create_and_setup_sandbox():
+            desktop = Sandbox.create(
+                api_key=os.getenv("E2B_API_KEY"),
+                resolution=(WIDTH, HEIGHT),
+                dpi=96,
+                timeout=SANDBOX_TIMEOUT,
+                template="k0wmnzir0zuzye6dndlw",
+            )
+            desktop.stream.start(require_auth=True)
+            setup_cmd = """sudo mkdir -p /usr/lib/firefox-esr/distribution && echo '{"policies":{"OverrideFirstRunPage":"","OverridePostUpdatePage":"","DisableProfileImport":true,"DontCheckDefaultBrowser":true}}' | sudo tee /usr/lib/firefox-esr/distribution/policies.json > /dev/null"""
+            desktop.commands.run(setup_cmd)
+            time.sleep(3)
+            return desktop
+
+        try:
+            desktop = await asyncio.to_thread(create_and_setup_sandbox)
+            print(f"Sandbox ID for session {session_hash} is {desktop.sandbox_id}.")
+
+            # Update sandbox state under lock
+            async with self.sandbox_lock:
+                self.sandboxes[session_hash] = desktop
+                self.sandbox_metadata[session_hash]["state"] = "ready"
+
+        except Exception as e:
+            print(f"Error creating sandbox for session {session_hash}: {str(e)}")
+            # Clean up metadata on failure
+            async with self.sandbox_lock:
+                if session_hash in self.sandbox_metadata:
+                    del self.sandbox_metadata[session_hash]
+
+    async def acquire_sandbox(self, session_hash: str) -> SandboxResponse:
+        current_time = datetime.now()
+        should_create = False
+        expired_sandbox = None
+
+        # Quick check under lock - only check state and mark creation
+        async with self.sandbox_lock:
+            # Check if sandbox exists and is ready
             if (
                 session_hash in self.sandboxes
                 and session_hash in self.sandbox_metadata
-                and current_time - self.sandbox_metadata[session_hash]["created_at"]
-                < SANDBOX_TIMEOUT
+                and self.sandbox_metadata[session_hash].get("state") == "ready"
+                and (
+                    current_time - self.sandbox_metadata[session_hash]["created_at"]
+                ).total_seconds()
+                < SANDBOX_CREATION_TIMEOUT
             ):
                 print(f"Reusing Sandbox for session {session_hash}")
                 self.sandbox_metadata[session_hash]["last_accessed"] = current_time
-                return self.sandboxes[session_hash]
-
-            if session_hash in self.sandboxes:
-                try:
-                    print(f"Closing expired sandbox for session {session_hash}")
-                    await asyncio.to_thread(self.sandboxes[session_hash].kill)
-                except Exception as e:
-                    print(f"Error closing expired sandbox: {str(e)}")
-            elif len(self.sandboxes) >= self.max_sandboxes:
-                return None
-
-            print(f"Creating new sandbox for session {session_hash}")
-
-            def create_and_setup_sandbox():
-                desktop = Sandbox.create(
-                    api_key=os.getenv("E2B_API_KEY"),
-                    resolution=(WIDTH, HEIGHT),
-                    dpi=96,
-                    timeout=SANDBOX_TIMEOUT,
-                    template="k0wmnzir0zuzye6dndlw",
-                )
-                desktop.stream.start(require_auth=True)
-                setup_cmd = """sudo mkdir -p /usr/lib/firefox-esr/distribution && echo '{"policies":{"OverrideFirstRunPage":"","OverridePostUpdatePage":"","DisableProfileImport":true,"DontCheckDefaultBrowser":true}}' | sudo tee /usr/lib/firefox-esr/distribution/policies.json > /dev/null"""
-                desktop.commands.run(setup_cmd)
-                time.sleep(3)
-                return desktop
-
-            desktop = await asyncio.to_thread(create_and_setup_sandbox)
-
-            print(f"Sandbox ID for session {session_hash} is {desktop.sandbox_id}.")
-
-            self.sandboxes[session_hash] = desktop
-            self.sandbox_metadata[session_hash] = {
-                "created_at": current_time,
-                "last_accessed": current_time,
-            }
-            return desktop
+                return SandboxResponse(
+                    sandbox=self.sandboxes[session_hash], state="ready"
+                )
+
+            # Check if sandbox is already being created
+            if (
+                session_hash in self.sandbox_metadata
+                and self.sandbox_metadata[session_hash].get("state") == "creating"
+            ):
+                print(f"Sandbox for session {session_hash} is already being created")
+                return SandboxResponse(sandbox=None, state="creating")
+
+            # Mark expired sandbox for cleanup (remove from dict within lock)
+            if session_hash in self.sandboxes:
+                print(f"Marking expired sandbox for session {session_hash} for cleanup")
+                expired_sandbox = self.sandboxes[session_hash]
+                del self.sandboxes[session_hash]
+                if session_hash in self.sandbox_metadata:
+                    del self.sandbox_metadata[session_hash]
+
+            # Check if we have capacity
+            if len(self.sandboxes) >= self.max_sandboxes:
+                return SandboxResponse(sandbox=None, state="max_sandboxes_reached")
+
+            # Mark that we're creating this sandbox
+            print(f"Creating new sandbox for session {session_hash}")
+            self.sandbox_metadata[session_hash] = {
+                "state": "creating",
+                "created_at": current_time,
+                "last_accessed": current_time,
+            }
+            should_create = True
+
+        # Start sandbox creation in background without waiting
+        if should_create:
+            asyncio.create_task(
+                self._create_sandbox_background(session_hash, expired_sandbox)
+            )
+
+        async with self.sandbox_lock:
+            if self.sandbox_metadata[session_hash]["state"] == "creating":
+                return SandboxResponse(sandbox=None, state="creating")
+            if self.sandbox_metadata[session_hash]["state"] == "ready":
+                return SandboxResponse(
+                    sandbox=self.sandboxes[session_hash], state="ready"
+                )
+
+        return SandboxResponse(sandbox=None, state="creating")
 
     async def release_sandbox(self, session_hash: str):
+        sandbox_to_kill = None
+
+        # Remove from dictionaries under lock
         async with self.sandbox_lock:
             if session_hash in self.sandboxes:
                 print(f"Releasing sandbox for session {session_hash}")
-                await asyncio.to_thread(self.sandboxes[session_hash].kill)
+                sandbox_to_kill = self.sandboxes[session_hash]
                 del self.sandboxes[session_hash]
-                del self.sandbox_metadata[session_hash]
+            if session_hash in self.sandbox_metadata:
+                del self.sandbox_metadata[session_hash]
+
+        # Kill sandbox outside of lock
+        if sandbox_to_kill:
+            try:
+                await asyncio.to_thread(sandbox_to_kill.kill)
+            except Exception as e:
+                print(f"Error killing sandbox for session {session_hash}: {str(e)}")
 
     async def cleanup_sandboxes(self):
+        sandboxes_to_kill = []
+
+        # Collect sandboxes under lock
         async with self.sandbox_lock:
             for session_hash in list(self.sandboxes.keys()):
-                await asyncio.to_thread(self.sandboxes[session_hash].kill)
+                sandboxes_to_kill.append((session_hash, self.sandboxes[session_hash]))
                 del self.sandboxes[session_hash]
-                del self.sandbox_metadata[session_hash]
+                if session_hash in self.sandbox_metadata:
+                    del self.sandbox_metadata[session_hash]
+
+        # Kill all sandboxes outside of lock
+        for session_hash, sandbox in sandboxes_to_kill:
+            try:
+                await asyncio.to_thread(sandbox.kill)
+            except Exception as e:
+                print(f"Error killing sandbox for session {session_hash}: {str(e)}")
 
 
 if __name__ == "__main__":
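
Since `acquire_sandbox` now returns immediately with a `SandboxResponse` rather than blocking on creation, callers poll until the state flips to `"ready"` (mirroring the retry loop added in `agent_service.py`). A minimal sketch of a caller, assuming the classes defined above:

```python
import asyncio

from cua2_core.services.sandbox_service import SandboxService


async def wait_for_sandbox(service: SandboxService, session_hash: str):
    # Poll until the background creation task finishes or capacity runs out.
    for _ in range(10):
        response = await service.acquire_sandbox(session_hash)
        if response.state == "ready":
            return response.sandbox
        if response.state == "max_sandboxes_reached":
            raise RuntimeError("No sandbox available: pool limit reached")
        await asyncio.sleep(2)  # still "creating"; try again shortly
    raise TimeoutError(f"Sandbox for {session_hash} not ready in time")
```
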
cua2-core/tests/test_archival_service.py ADDED
@@ -0,0 +1,585 @@
+"""
+Tests for the ArchivalService multiprocessing implementation.
+"""
+
+import os
+import shutil
+import tarfile
+import tempfile
+import time
+from pathlib import Path
+from unittest.mock import MagicMock, Mock, patch
+
+import pytest
+from cua2_core.services.archival_service import (
+    ArchivalService,
+    _compress_folder,
+    _process_old_folders,
+    _upload_to_huggingface,
+    _verify_file_in_repo,
+)
+from huggingface_hub.utils import HfHubHTTPError
+
+
+@pytest.fixture
+def temp_data_dir():
+    """Create a temporary data directory for testing."""
+    temp_dir = tempfile.mkdtemp()
+    yield temp_dir
+    # Cleanup
+    if os.path.exists(temp_dir):
+        shutil.rmtree(temp_dir)
+
+
+@pytest.fixture
+def mock_hf_api():
+    """Create a mock HuggingFace API client."""
+    mock_api = MagicMock()
+    mock_api.upload_file.return_value = None
+    return mock_api
+
+
+@pytest.fixture
+def archival_service(temp_data_dir):
+    """Create an ArchivalService instance for testing."""
+    service = ArchivalService(
+        hf_token="test_token",
+        hf_dataset_repo="test_user/test_repo",
+        data_dir=temp_data_dir,
+        archive_interval_minutes=1,  # Short interval for testing
+        folder_age_threshold_minutes=1,
+    )
+    yield service
+    # Cleanup - stop the service if running
+    if service.is_alive():
+        service.stop(timeout=5.0)
+
+
+class TestArchivalServiceInitialization:
+    """Test ArchivalService initialization."""
+
+    def test_init_with_defaults(self):
+        """Test initialization with default values."""
+        with patch.dict(os.environ, {"HF_TOKEN": "env_token"}, clear=False):
+            # Need to pass the env token explicitly since os.getenv is called at function definition time
+            service = ArchivalService(hf_token=os.getenv("HF_TOKEN"))
+            assert service.hf_token == "env_token"
+            assert service.hf_dataset_repo == "smolagents/cua_traces"
+            assert service.data_dir == "data"
+            assert service.archive_interval_minutes == 30
+            assert service.folder_age_threshold_minutes == 30
+            assert service._process is None
+            assert not service.is_alive()
+
+    def test_init_with_custom_values(self):
+        """Test initialization with custom values."""
+        service = ArchivalService(
+            hf_token="custom_token",
+            hf_dataset_repo="custom/repo",
+            data_dir="/custom/path",
+            archive_interval_minutes=60,
+            folder_age_threshold_minutes=120,
+        )
+        assert service.hf_token == "custom_token"
+        assert service.hf_dataset_repo == "custom/repo"
+        assert service.data_dir == "/custom/path"
+        assert service.archive_interval_minutes == 60
+        assert service.folder_age_threshold_minutes == 120
+
+    def test_init_multiprocessing_components(self):
+        """Test that multiprocessing components are initialized."""
+        service = ArchivalService(hf_token="test", hf_dataset_repo="test/test")
+        assert service._stop_event is not None
+        assert service._manager is not None
+        assert service._active_tasks is not None
+
+
+class TestArchivalServiceLifecycle:
+    """Test ArchivalService lifecycle management."""
+
+    def test_start_service(self, archival_service):
+        """Test starting the archival service."""
+        archival_service.start()
+
+        # Give the process a moment to start
+        time.sleep(0.5)
+
+        assert archival_service.is_alive()
+        assert archival_service._process is not None
+        assert archival_service._process.pid is not None
+
+    def test_start_without_credentials(self, temp_data_dir):
+        """Test starting service without credentials logs warning."""
+        service = ArchivalService(
+            hf_token=None,
+            hf_dataset_repo=None,
+            data_dir=temp_data_dir,
+        )
+
+        service.start()
+
+        # Service should not start
+        assert not service.is_alive()
+        assert service._process is None
+
+    def test_start_already_running(self, archival_service):
+        """Test starting service when already running."""
+        archival_service.start()
+        time.sleep(0.5)
+
+        pid1 = archival_service._process.pid
+
+        # Try to start again
+        archival_service.start()
+
+        # Should be same process
+        assert archival_service._process.pid == pid1
+
+    def test_stop_service(self, archival_service):
+        """Test stopping the archival service."""
+        archival_service.start()
+        time.sleep(0.5)
+        assert archival_service.is_alive()
+
+        archival_service.stop(timeout=5.0)
+
+        assert not archival_service.is_alive()
+        assert archival_service._process is None
+
+    def test_stop_not_running(self, archival_service):
+        """Test stopping service when not running."""
+        # Should not raise any errors
+        archival_service.stop(timeout=1.0)
+        assert not archival_service.is_alive()
+
+    def test_stop_with_timeout(self, archival_service):
+        """Test stop with timeout and force termination."""
+        archival_service.start()
+        time.sleep(0.5)
+
+        # Stop with very short timeout to test force kill path
+        archival_service.stop(timeout=0.001)
+
+        # Process should be stopped one way or another
+        time.sleep(0.5)
+        assert not archival_service.is_alive()
+
+    def test_is_alive_returns_false_when_not_started(self, archival_service):
+        """Test is_alive returns False when service not started."""
+        assert not archival_service.is_alive()
+
+
+class TestActiveTasksManagement:
+    """Test active tasks management."""
+
+    def test_update_active_tasks(self, archival_service):
+        """Test updating active tasks."""
+        task_ids = {"task-1", "task-2", "task-3"}
+
+        archival_service.update_active_tasks(task_ids)
+
+        # Verify tasks are in shared dict
+        for task_id in task_ids:
+            assert task_id in archival_service._active_tasks
+
+    def test_update_active_tasks_clears_old(self, archival_service):
+        """Test that updating active tasks clears old ones."""
+        archival_service.update_active_tasks({"task-1", "task-2"})
+        assert "task-1" in archival_service._active_tasks
+
+        archival_service.update_active_tasks({"task-3"})
+        assert "task-1" not in archival_service._active_tasks
+        assert "task-3" in archival_service._active_tasks
+
+    def test_update_active_tasks_empty_set(self, archival_service):
+        """Test updating with empty set."""
+        archival_service.update_active_tasks({"task-1"})
+        archival_service.update_active_tasks(set())
+
+        assert len(archival_service._active_tasks) == 0
+
+
+class TestCompressFolder:
+    """Test folder compression functionality."""
+
+    def test_compress_folder_success(self, temp_data_dir):
+        """Test successful folder compression."""
+        # Create a test folder with some files
+        test_folder = Path(temp_data_dir) / "trace-test-123-model"
+        test_folder.mkdir()
+        (test_folder / "file1.txt").write_text("test content 1")
+        (test_folder / "file2.txt").write_text("test content 2")
+
+        archive_path = _compress_folder(test_folder)
+
+        assert archive_path is not None
+        assert archive_path.exists()
+        assert archive_path.name == "trace-test-123-model.tar.gz"
+        assert archive_path.suffix == ".gz"
+
+        # Verify archive contents
+        with tarfile.open(archive_path, "r:gz") as tar:
+            members = tar.getnames()
+            assert "trace-test-123-model/file1.txt" in members
+            assert "trace-test-123-model/file2.txt" in members
+
+        # Cleanup
+        archive_path.unlink()
+
+    def test_compress_folder_empty_folder(self, temp_data_dir):
+        """Test compressing an empty folder."""
+        test_folder = Path(temp_data_dir) / "trace-empty-456-model"
+        test_folder.mkdir()
+
+        archive_path = _compress_folder(test_folder)
+
+        assert archive_path is not None
+        assert archive_path.exists()
+
+        # Cleanup
+        archive_path.unlink()
+
+    def test_compress_folder_nonexistent(self, temp_data_dir):
+        """Test compressing a nonexistent folder."""
+        test_folder = Path(temp_data_dir) / "nonexistent"
+
+        archive_path = _compress_folder(test_folder)
+
+        assert archive_path is None
+
+    def test_compress_folder_with_subdirectories(self, temp_data_dir):
+        """Test compressing folder with subdirectories."""
+        test_folder = Path(temp_data_dir) / "trace-nested-789-model"
+        test_folder.mkdir()
+        subdir = test_folder / "subdir"
+        subdir.mkdir()
+        (subdir / "nested.txt").write_text("nested content")
+
+        archive_path = _compress_folder(test_folder)
+
+        assert archive_path is not None
+
+        # Verify nested structure preserved
+        with tarfile.open(archive_path, "r:gz") as tar:
+            members = tar.getnames()
+            assert "trace-nested-789-model/subdir/nested.txt" in members
+
+        # Cleanup
+        archive_path.unlink()
+
+
+class TestUploadToHuggingFace:
+    """Test HuggingFace upload functionality."""
+
+    def test_upload_success(self, mock_hf_api, temp_data_dir):
+        """Test successful upload to HuggingFace."""
+        # Create a test archive
+        archive_path = Path(temp_data_dir) / "test-archive.tar.gz"
+        archive_path.write_text("test archive content")
+
+        result = _upload_to_huggingface(mock_hf_api, "test/repo", archive_path)
+
+        assert result is True
+        mock_hf_api.upload_file.assert_called_once_with(
+            path_or_fileobj=str(archive_path),
+            path_in_repo="test-archive.tar.gz",
+            repo_id="test/repo",
+            repo_type="dataset",
+        )
+
+    def test_upload_failure(self, mock_hf_api, temp_data_dir):
+        """Test upload failure."""
+        mock_hf_api.upload_file.side_effect = Exception("Upload failed")
+
+        archive_path = Path(temp_data_dir) / "test-archive.tar.gz"
+        archive_path.write_text("test archive content")
+
+        result = _upload_to_huggingface(mock_hf_api, "test/repo", archive_path)
+
+        assert result is False
+
+    def test_upload_nonexistent_file(self, mock_hf_api, temp_data_dir):
+        """Test uploading a nonexistent file."""
+        archive_path = Path(temp_data_dir) / "nonexistent.tar.gz"
+
+        # Make the mock raise an exception when trying to upload nonexistent file
+        mock_hf_api.upload_file.side_effect = FileNotFoundError("File not found")
+
+        result = _upload_to_huggingface(mock_hf_api, "test/repo", archive_path)
+
+        assert result is False
+
+
+class TestVerifyFileInRepo:
+    """Test file verification functionality."""
+
+    @patch("cua2_core.services.archival_service.hf_hub_download")
+    def test_verify_success(self, mock_download):
+        """Test successful file verification."""
+        mock_download.return_value = "/path/to/file"
+
+        result = _verify_file_in_repo("test/repo", "test_token", "test.tar.gz")
+
+        assert result is True
+        mock_download.assert_called_once()
+
+    @patch("cua2_core.services.archival_service.hf_hub_download")
+    def test_verify_file_not_found(self, mock_download):
+        """Test verification when file not found (404)."""
+        mock_response = Mock()
+        mock_response.status_code = 404
+        error = HfHubHTTPError("Not found", response=mock_response)
+        mock_download.side_effect = error
+
+        result = _verify_file_in_repo("test/repo", "test_token", "test.tar.gz")
+
+        assert result is False
+
+    @patch("cua2_core.services.archival_service.hf_hub_download")
+    def test_verify_other_http_error(self, mock_download):
+        """Test verification with other HTTP errors."""
+        mock_response = Mock()
+        mock_response.status_code = 500
+        error = HfHubHTTPError("Server error", response=mock_response)
+        mock_download.side_effect = error
+
+        result = _verify_file_in_repo("test/repo", "test_token", "test.tar.gz")
+
+        assert result is False
+
+    @patch("cua2_core.services.archival_service.hf_hub_download")
+    def test_verify_generic_exception(self, mock_download):
+        """Test verification with generic exception."""
+        mock_download.side_effect = Exception("Generic error")
+
+        result = _verify_file_in_repo("test/repo", "test_token", "test.tar.gz")
+
+        assert result is False
+
+
+class TestProcessOldFolders:
+    """Test old folder processing logic."""
+
+    def test_process_old_folders_basic(self, temp_data_dir, mock_hf_api):
+        """Test processing old folders."""
+        # Create an old folder (modify mtime to make it old)
+        old_folder = Path(temp_data_dir) / "trace-old123-model"
+        old_folder.mkdir()
+        (old_folder / "data.json").write_text('{"test": "data"}')
+
+        # Make it old by modifying mtime
+        old_time = time.time() - 3600  # 1 hour ago
+        os.utime(old_folder, (old_time, old_time))
+
+        active_tasks = {}
+
+        with patch(
+            "cua2_core.services.archival_service._verify_file_in_repo",
+            return_value=True,
+        ):
+            _process_old_folders(
+                data_dir=temp_data_dir,
+                folder_age_threshold_minutes=1,
+                active_tasks=active_tasks,
+                hf_api=mock_hf_api,
+                hf_dataset_repo="test/repo",
+                hf_token="test_token",
+            )
+
+        # Folder should be deleted after successful archival
+        assert not old_folder.exists()
+
+        # Upload should have been called
+        assert mock_hf_api.upload_file.called
+
+    def test_process_folders_skips_active_tasks(self, temp_data_dir, mock_hf_api):
+        """Test that active tasks are skipped."""
+        # Create a folder for an active task
+        active_folder = Path(temp_data_dir) / "trace-active456-model"
+        active_folder.mkdir()
+        (active_folder / "data.json").write_text('{"test": "data"}')
+
+        # Make it old
+        old_time = time.time() - 3600
+        os.utime(active_folder, (old_time, old_time))
+
+        # Mark as active
+        active_tasks = {"active456": True}
+
+        _process_old_folders(
+            data_dir=temp_data_dir,
+            folder_age_threshold_minutes=1,
+            active_tasks=active_tasks,
+            hf_api=mock_hf_api,
+            hf_dataset_repo="test/repo",
+            hf_token="test_token",
+        )
+
+        # Folder should still exist (not archived)
+        assert active_folder.exists()
+
+        # Upload should not have been called
+        assert not mock_hf_api.upload_file.called
+
+    def test_process_folders_skips_recent(self, temp_data_dir, mock_hf_api):
+        """Test that recent folders are skipped."""
+        # Create a recent folder
+        recent_folder = Path(temp_data_dir) / "trace-recent789-model"
+        recent_folder.mkdir()
+        (recent_folder / "data.json").write_text('{"test": "data"}')
+
+        # Folder is fresh (current time)
+
+        active_tasks = {}
+
+        _process_old_folders(
+            data_dir=temp_data_dir,
+            folder_age_threshold_minutes=60,  # 60 minutes threshold
+            active_tasks=active_tasks,
+            hf_api=mock_hf_api,
440
+ hf_dataset_repo="test/repo",
441
+ hf_token="test_token",
442
+ )
443
+
444
+ # Folder should still exist (too recent)
445
+ assert recent_folder.exists()
446
+
447
+ # Upload should not have been called
448
+ assert not mock_hf_api.upload_file.called
449
+
450
+ def test_process_folders_keeps_on_verification_failure(
451
+ self, temp_data_dir, mock_hf_api
452
+ ):
453
+ """Test that folders are kept if verification fails."""
454
+ old_folder = Path(temp_data_dir) / "trace-verify-fail-model"
455
+ old_folder.mkdir()
456
+ (old_folder / "data.json").write_text('{"test": "data"}')
457
+
458
+ old_time = time.time() - 3600
459
+ os.utime(old_folder, (old_time, old_time))
460
+
461
+ active_tasks = {}
462
+
463
+ # Mock verification to fail
464
+ with patch(
465
+ "cua2_core.services.archival_service._verify_file_in_repo",
466
+ return_value=False,
467
+ ):
468
+ _process_old_folders(
469
+ data_dir=temp_data_dir,
470
+ folder_age_threshold_minutes=1,
471
+ active_tasks=active_tasks,
472
+ hf_api=mock_hf_api,
473
+ hf_dataset_repo="test/repo",
474
+ hf_token="test_token",
475
+ )
476
+
477
+ # Folder should still exist (verification failed)
478
+ assert old_folder.exists()
479
+
480
+ def test_process_folders_handles_nonexistent_dir(self, mock_hf_api):
481
+ """Test handling of nonexistent data directory."""
482
+ # Should not raise exception
483
+ _process_old_folders(
484
+ data_dir="/nonexistent/path",
485
+ folder_age_threshold_minutes=1,
486
+ active_tasks={},
487
+ hf_api=mock_hf_api,
488
+ hf_dataset_repo="test/repo",
489
+ hf_token="test_token",
490
+ )
491
+
492
+ # No uploads should occur
493
+ assert not mock_hf_api.upload_file.called
494
+
495
+ def test_process_folders_handles_bad_folder_names(self, temp_data_dir, mock_hf_api):
496
+ """Test handling of folders with unexpected name format."""
497
+ # Create folder with bad name format
498
+ bad_folder = Path(temp_data_dir) / "trace-invalid" # Missing parts
499
+ bad_folder.mkdir()
500
+
501
+ old_time = time.time() - 3600
502
+ os.utime(bad_folder, (old_time, old_time))
503
+
504
+ # Should not raise exception
505
+ _process_old_folders(
506
+ data_dir=temp_data_dir,
507
+ folder_age_threshold_minutes=1,
508
+ active_tasks={},
509
+ hf_api=mock_hf_api,
510
+ hf_dataset_repo="test/repo",
511
+ hf_token="test_token",
512
+ )
513
+
514
+ # Folder should still exist (invalid name)
515
+ assert bad_folder.exists()
516
+
517
+
518
+ class TestIntegration:
519
+ """Integration tests for the complete archival workflow."""
520
+
521
+ @pytest.mark.slow
522
+ def test_full_archival_workflow(self, temp_data_dir):
523
+ """Test the complete archival workflow end-to-end."""
524
+ # Create service
525
+ service = ArchivalService(
526
+ hf_token="test_token",
527
+ hf_dataset_repo="test/repo",
528
+ data_dir=temp_data_dir,
529
+ archive_interval_minutes=1,
530
+ folder_age_threshold_minutes=1,
531
+ )
532
+
533
+ # Create test folder
534
+ test_folder = Path(temp_data_dir) / "trace-integration-test-model"
535
+ test_folder.mkdir()
536
+ (test_folder / "test.json").write_text('{"test": "data"}')
537
+
538
+ # Make it old
539
+ old_time = time.time() - 3600
540
+ os.utime(test_folder, (old_time, old_time))
541
+
542
+ # Start service (mocked to prevent actual HF upload)
543
+ with (
544
+ patch(
545
+ "cua2_core.services.archival_service._upload_to_huggingface",
546
+ return_value=True,
547
+ ),
548
+ patch(
549
+ "cua2_core.services.archival_service._verify_file_in_repo",
550
+ return_value=True,
551
+ ),
552
+ ):
553
+ service.start()
554
+ time.sleep(2) # Wait for at least one cycle
555
+
556
+ service.stop(timeout=5.0)
557
+
558
+ assert not service.is_alive()
559
+
560
+ def test_service_survives_worker_errors(self, temp_data_dir):
561
+ """Test that service continues running despite worker errors."""
562
+ service = ArchivalService(
563
+ hf_token="test_token",
564
+ hf_dataset_repo="test/repo",
565
+ data_dir=temp_data_dir,
566
+ archive_interval_minutes=1,
567
+ folder_age_threshold_minutes=1,
568
+ )
569
+
570
+ # Mock to raise exception
571
+ with patch(
572
+ "cua2_core.services.archival_service._process_old_folders",
573
+ side_effect=Exception("Test error"),
574
+ ):
575
+ service.start()
576
+ time.sleep(2)
577
+
578
+ # Service should still be alive
579
+ assert service.is_alive()
580
+
581
+ service.stop(timeout=5.0)
582
+
583
+
584
+ if __name__ == "__main__":
585
+ pytest.main([__file__, "-v", "--tb=short"])
cua2-front/src/App.tsx CHANGED
@@ -15,9 +15,10 @@ const App = () => {
15
  // Initialize WebSocket connection at app level so it persists across route changes
16
  const { stopCurrentTask } = useAgentWebSocket({ url: getWebSocketUrl() });
17
 
18
- // Store stopCurrentTask in window for global access
19
  (window as Window & { __stopCurrentTask?: () => void }).__stopCurrentTask = stopCurrentTask;
20
 
 
21
  return (
22
  <ThemeProvider theme={theme}>
23
  <CssBaseline />
 
15
  // Initialize WebSocket connection at app level so it persists across route changes
16
  const { stopCurrentTask } = useAgentWebSocket({ url: getWebSocketUrl() });
17
 
18
+ // Expose stopCurrentTask on window for global access
19
  (window as Window & { __stopCurrentTask?: () => void }).__stopCurrentTask = stopCurrentTask;
20
 
21
+
22
  return (
23
  <ThemeProvider theme={theme}>
24
  <CssBaseline />
cua2-front/src/components/sandbox/completionview/CompletionView.tsx CHANGED
@@ -1,18 +1,18 @@
1
- import React from 'react';
2
- import { Box, Typography, Button, Divider, Alert, Paper } from '@mui/material';
3
- import CheckIcon from '@mui/icons-material/Check';
4
- import CloseIcon from '@mui/icons-material/Close';
5
- import StopCircleIcon from '@mui/icons-material/StopCircle';
6
- import HourglassEmptyIcon from '@mui/icons-material/HourglassEmpty';
7
  import AddIcon from '@mui/icons-material/Add';
8
- import SmartToyIcon from '@mui/icons-material/SmartToy';
9
  import AssignmentIcon from '@mui/icons-material/Assignment';
10
  import ChatBubbleOutlineIcon from '@mui/icons-material/ChatBubbleOutline';
11
- import AccessTimeIcon from '@mui/icons-material/AccessTime';
 
 
 
12
  import InputIcon from '@mui/icons-material/Input';
13
  import OutputIcon from '@mui/icons-material/Output';
14
- import FormatListNumberedIcon from '@mui/icons-material/FormatListNumbered';
15
- import { FinalStep, AgentTrace, AgentStep } from '@/types/agent';
 
 
16
  import { DownloadGifButton } from './DownloadGifButton';
17
  import { DownloadJsonButton } from './DownloadJsonButton';
18
 
@@ -65,14 +65,14 @@ export const CompletionView: React.FC<CompletionViewProps> = ({
65
  case 'sandbox_timeout':
66
  return {
67
  icon: <AccessTimeIcon sx={{ fontSize: 28 }} />,
68
- title: 'Sandbox Timeout',
69
  color: 'error.main',
70
  };
71
  case 'failure':
72
  default:
73
  return {
74
  icon: <CloseIcon sx={{ fontSize: 28 }} />,
75
- title: 'Task Failed',
76
  color: 'error.main',
77
  };
78
  }
 
1
+ import { AgentStep, AgentTrace, FinalStep } from '@/types/agent';
2
+ import AccessTimeIcon from '@mui/icons-material/AccessTime';
 
 
 
 
3
  import AddIcon from '@mui/icons-material/Add';
 
4
  import AssignmentIcon from '@mui/icons-material/Assignment';
5
  import ChatBubbleOutlineIcon from '@mui/icons-material/ChatBubbleOutline';
6
+ import CheckIcon from '@mui/icons-material/Check';
7
+ import CloseIcon from '@mui/icons-material/Close';
8
+ import FormatListNumberedIcon from '@mui/icons-material/FormatListNumbered';
9
+ import HourglassEmptyIcon from '@mui/icons-material/HourglassEmpty';
10
  import InputIcon from '@mui/icons-material/Input';
11
  import OutputIcon from '@mui/icons-material/Output';
12
+ import SmartToyIcon from '@mui/icons-material/SmartToy';
13
+ import StopCircleIcon from '@mui/icons-material/StopCircle';
14
+ import { Alert, Box, Button, Divider, Paper, Typography } from '@mui/material';
15
+ import React from 'react';
16
  import { DownloadGifButton } from './DownloadGifButton';
17
  import { DownloadJsonButton } from './DownloadJsonButton';
18
 
 
65
  case 'sandbox_timeout':
66
  return {
67
  icon: <AccessTimeIcon sx={{ fontSize: 28 }} />,
68
+ title: 'Max Sandbox Time Reached',
69
  color: 'error.main',
70
  };
71
  case 'failure':
72
  default:
73
  return {
74
  icon: <CloseIcon sx={{ fontSize: 28 }} />,
75
+ title: 'Task Failed (Agent Internal Error)',
76
  color: 'error.main',
77
  };
78
  }
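
The two title changes above are easier to read against the full final-state mapping. A minimal sketch only, not the component's actual code: the titles for states this diff does not touch are placeholders.

```typescript
// Sketch: the component resolves banner titles in a switch; a lookup table
// over the final-state union is equivalent. Only the 'sandbox_timeout' and
// 'failure' titles come from this diff; the rest are placeholders.
type FinalState = 'success' | 'stopped' | 'max_steps_reached' | 'sandbox_timeout' | 'failure';

const STATUS_TITLES: Record<FinalState, string> = {
  success: 'Task Completed',                     // placeholder
  stopped: 'Task Stopped',                       // placeholder
  max_steps_reached: 'Maximum Steps Reached',    // placeholder
  sandbox_timeout: 'Max Sandbox Time Reached',   // from this diff
  failure: 'Task Failed (Agent Internal Error)', // from this diff
};

export const titleFor = (state: FinalState): string => STATUS_TITLES[state];
```
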
cua2-front/src/hooks/useAgentWebSocket.ts CHANGED
@@ -1,8 +1,7 @@
 
 
1
  import { useCallback, useEffect } from 'react';
2
  import { useWebSocket } from './useWebSocket';
3
- import { useAgentStore } from '@/stores/agentStore';
4
- import { WebSocketEvent, AgentTrace, AgentStep } from '@/types/agent';
5
- import { ulid } from 'ulid';
6
 
7
  interface UseAgentWebSocketOptions {
8
  url: string;
@@ -11,6 +10,8 @@ interface UseAgentWebSocketOptions {
11
  export const useAgentWebSocket = ({ url }: UseAgentWebSocketOptions) => {
12
  const {
13
  setTrace,
 
 
14
  updateTraceWithStep,
15
  completeTrace,
16
  setIsAgentProcessing,
@@ -52,6 +53,7 @@ export const useAgentWebSocket = ({ url }: UseAgentWebSocketOptions) => {
52
  numberOfSteps: 0,
53
  maxSteps: 200,
54
  completed: false,
 
55
  },
56
  };
57
 
@@ -94,10 +96,13 @@ export const useAgentWebSocket = ({ url }: UseAgentWebSocketOptions) => {
94
 
95
  case 'heartbeat':
96
  console.log('Heartbeat received:', event);
 
 
97
  break;
 
98
  }
99
  },
100
- [setTrace, updateTraceWithStep, completeTrace, setIsAgentProcessing, setIsConnectingToE2B, setVncUrl, setError, resetAgent]
101
  );
102
 
103
  // Handle WebSocket errors
@@ -113,10 +118,16 @@ export const useAgentWebSocket = ({ url }: UseAgentWebSocketOptions) => {
113
  onError: handleWebSocketError,
114
  });
115
 
116
- // Sync connection state to store
117
  useEffect(() => {
118
  setIsConnected(isConnected);
119
- }, [isConnected, setIsConnected]);
 
 
 
 
 
 
120
 
121
  // Create a global sendNewTask function that can be called from anywhere
122
  useEffect(() => {
@@ -125,7 +136,13 @@ export const useAgentWebSocket = ({ url }: UseAgentWebSocketOptions) => {
125
  // Reset agent state before starting a new task
126
  resetAgent();
127
 
128
- const traceId = ulid();
 
 
 
 
 
 
129
  const trace: AgentTrace = {
130
  id: traceId,
131
  instruction,
@@ -140,7 +157,8 @@ export const useAgentWebSocket = ({ url }: UseAgentWebSocketOptions) => {
140
  numberOfSteps: 0,
141
  maxSteps: 200, // Default max steps, will be updated by backend
142
  completed: false,
143
- },
 
144
  };
145
 
146
  setTrace(trace);
@@ -155,7 +173,7 @@ export const useAgentWebSocket = ({ url }: UseAgentWebSocketOptions) => {
155
 
156
  console.log('Task sent:', trace);
157
  };
158
- }, [setTrace, setIsAgentProcessing, setIsConnectingToE2B, sendMessage, resetAgent]);
159
 
160
  // Function to stop the current task
161
  const stopCurrentTask = useCallback(() => {
 
1
+ import { useAgentStore } from '@/stores/agentStore';
2
+ import { AgentTrace, AgentTraceMetadata, WebSocketEvent } from '@/types/agent';
3
  import { useCallback, useEffect } from 'react';
4
  import { useWebSocket } from './useWebSocket';
 
 
 
5
 
6
  interface UseAgentWebSocketOptions {
7
  url: string;
 
10
  export const useAgentWebSocket = ({ url }: UseAgentWebSocketOptions) => {
11
  const {
12
  setTrace,
13
+ traceId,
14
+ setTraceId,
15
  updateTraceWithStep,
16
  completeTrace,
17
  setIsAgentProcessing,
 
53
  numberOfSteps: 0,
54
  maxSteps: 200,
55
  completed: false,
56
+ final_state: null,
57
  },
58
  };
59
 
 
96
 
97
  case 'heartbeat':
98
  console.log('Heartbeat received:', event);
99
+ setTraceId(event.uuid);
100
+ console.log('TraceId set from backend:', event.uuid);
101
  break;
102
+
103
  }
104
  },
105
+ [setTrace, updateTraceWithStep, completeTrace, setIsAgentProcessing, setIsConnectingToE2B, setVncUrl, setError, resetAgent, setTraceId, traceId]
106
  );
107
 
108
  // Handle WebSocket errors
 
118
  onError: handleWebSocketError,
119
  });
120
 
121
+ // Sync connection state to store and clear traceId on disconnect
122
  useEffect(() => {
123
  setIsConnected(isConnected);
124
+
125
+ // Clear traceId when websocket disconnects
126
+ if (!isConnected) {
127
+ setTraceId(null);
128
+ console.log('WebSocket disconnected - traceId cleared');
129
+ }
130
+ }, [isConnected, setIsConnected, setTraceId]);
131
 
132
  // Create a global sendNewTask function that can be called from anywhere
133
  useEffect(() => {
 
136
  // Reset agent state before starting a new task
137
  resetAgent();
138
 
139
+ // Ensure traceId is set before creating trace
140
+ if (!traceId) {
141
+ console.error('Internal error: Cannot send task. TraceId not set. Refreshing page...');
142
+ window.location.reload();
143
+ return;
144
+ }
145
+
146
  const trace: AgentTrace = {
147
  id: traceId,
148
  instruction,
 
157
  numberOfSteps: 0,
158
  maxSteps: 200, // Default max steps, will be updated by backend
159
  completed: false,
160
+ final_state: null,
161
+ } as AgentTraceMetadata,
162
  };
163
 
164
  setTrace(trace);
 
173
 
174
  console.log('Task sent:', trace);
175
  };
176
+ }, [setTrace, setIsAgentProcessing, setIsConnectingToE2B, sendMessage, resetAgent, traceId]);
177
 
178
  // Function to stop the current task
179
  const stopCurrentTask = useCallback(() => {
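
The net effect of this hunk is a small contract: trace ids are now allocated by the backend and delivered over the heartbeat, so the frontend no longer generates them locally (the `ulid` import is removed). A minimal sketch of that flow, assuming the `HeartbeatEvent` shape added in `types/agent.ts` below; the setter stands in for the store's `setTraceId` action.

```typescript
// HeartbeatEvent mirrors the shape added in types/agent.ts.
interface HeartbeatEvent {
  type: 'heartbeat';
  uuid: string; // trace id allocated by the backend
}

// Every heartbeat (re)publishes the backend's trace id.
export function onHeartbeat(
  event: HeartbeatEvent,
  setTraceId: (id: string | null) => void,
): void {
  setTraceId(event.uuid);
}

// The id is only cleared when the connection drops, matching the
// disconnect handling in the hook above.
export function onDisconnect(setTraceId: (id: string | null) => void): void {
  setTraceId(null);
}
```
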
cua2-front/src/pages/Index.tsx DELETED
@@ -1,147 +0,0 @@
1
- import { Header, Metadata, StackSteps, VNCStream } from '@/components/mock';
2
- import { getWebSocketUrl } from '@/config/api';
3
- import { useWebSocket } from '@/hooks/useWebSocket';
4
- import { AgentStep, AgentTrace, WebSocketEvent } from '@/types/agent';
5
- import { useState } from 'react';
6
- import { ulid } from 'ulid';
7
-
8
- const Index = () => {
9
- const [trace, setTrace] = useState<AgentTrace>();
10
- const [isAgentProcessing, setIsAgentProcessing] = useState(false);
11
- const [vncUrl, setVncUrl] = useState<string>('');
12
- const [selectedModelId, setSelectedModelId] = useState<string>("Qwen/Qwen3-VL-30B-A3B-Instruct");
13
-
14
- // #################### WebSocket Connection ########################
15
-
16
- // WebSocket connection - Automatically configured based on environment
17
- const WS_URL = getWebSocketUrl();
18
-
19
- const handleWebSocketMessage = (event: WebSocketEvent) => {
20
- console.log('WebSocket event received:', event);
21
-
22
- switch (event.type) {
23
- case 'agent_start':
24
- setIsAgentProcessing(true);
25
- setTrace(event.agentTrace);
26
- console.log('Agent start received:', event.agentTrace);
27
- break;
28
-
29
- case 'agent_progress':
30
- // Add new step from an agent trace run with image, generated text, actions, tokens, and timestamp
31
- setTrace(prev => {
32
- const existingSteps = prev?.steps || [] as AgentStep[];
33
- const stepExists = existingSteps.some(step => step.stepId === event.agentStep.stepId);
34
-
35
- if (!stepExists) {
36
- return {
37
- ...prev,
38
- steps: [...existingSteps, event.agentStep],
39
- traceMetadata: event.traceMetadata,
40
- isRunning: true
41
- };
42
- }
43
- return prev;
44
- });
45
- console.log('Agent progress received:', event.agentStep);
46
- break;
47
-
48
- case 'agent_complete':
49
- setIsAgentProcessing(false);
50
- setTrace(trace => {
51
- return trace.id === event.traceMetadata.traceId
52
- ? {
53
- ...trace,
54
- isRunning: false,
55
- metadata: event.traceMetadata,
56
- }
57
- : trace;
58
- });
59
- console.log('Agent complete received:', event.traceMetadata);
60
- break;
61
-
62
- case 'agent_error':
63
- setIsAgentProcessing(false);
64
- // TODO: Handle agent error
65
- console.log('Agent error received:', event.error);
66
- break;
67
-
68
- case 'vnc_url_set':
69
- setVncUrl(event.vncUrl);
70
- // TODO: Handle VNC URL set
71
- console.log('VNC URL set received:', event.vncUrl);
72
- break;
73
-
74
- case 'vnc_url_unset':
75
- setVncUrl('');
76
- // TODO: Handle VNC URL unset
77
- console.log('VNC URL unset received:');
78
- break;
79
-
80
- case 'heartbeat':
81
- console.log('Heartbeat received:', event);
82
- break;
83
- }
84
- };
85
-
86
- const handleWebSocketError = () => {
87
- // WebSocket Frontend Error handling
88
-
89
- };
90
-
91
- const { isConnected, connectionState, sendMessage, manualReconnect } = useWebSocket({
92
- url: WS_URL,
93
- onMessage: handleWebSocketMessage,
94
- onError: handleWebSocketError,
95
- });
96
-
97
- // #################### Frontend Functionality ########################
98
-
99
- const handleModelId = (modelId: string) => {
100
- setSelectedModelId(modelId);
101
- };
102
-
103
- const handleSendNewTask = (content: string, modelId: string) => {
104
- const trace: AgentTrace = {
105
- id: ulid(),
106
- instruction: content,
107
- modelId: selectedModelId,
108
- timestamp: new Date(),
109
- isRunning: true,
110
- };
111
-
112
- setTrace(trace);
113
-
114
- // Send message to Python backend via WebSocket
115
- sendMessage({
116
- type: 'user_task',
117
- trace: trace,
118
- });
119
- };
120
-
121
- // #################### Mock Frontend Rendering ########################
122
-
123
- return (
124
- <div style={{ height: '100%', width: '100%', display: 'flex', flexDirection: 'column', backgroundColor: '#f3f4f6' }}>
125
- <Header
126
- isConnected={isConnected}
127
- isAgentProcessing={isAgentProcessing}
128
- onSendTask={handleSendNewTask}
129
- />
130
-
131
- <div style={{ flex: 1, display: 'flex', justifyContent: 'center', alignItems: 'center', overflow: 'hidden', minHeight: 0, padding: '32px' }}>
132
- <div style={{ width: '100%', height: '100%', maxWidth: '1400px', maxHeight: '900px', display: 'flex', flexDirection: 'row', overflow: 'hidden' }}>
133
- {/* Left Side: VNC Stream + Metadata */}
134
- <div style={{ flex: 1, display: 'flex', flexDirection: 'column', padding: '20px 12px', gap: '20px', minWidth: 0 }}>
135
- <VNCStream vncUrl={vncUrl} />
136
- <Metadata trace={trace} />
137
- </div>
138
-
139
- {/* Right Side: Stack Steps */}
140
- <StackSteps trace={trace} />
141
- </div>
142
- </div>
143
- </div>
144
- );
145
- };
146
-
147
- export default Index;
cua2-front/src/pages/Task.tsx CHANGED
@@ -1,8 +1,8 @@
1
- import React, { useEffect } from 'react';
2
- import { useNavigate } from 'react-router-dom';
3
- import { useAgentStore, selectTrace, selectIsAgentProcessing, selectVncUrl, selectMetadata, selectSelectedStep } from '@/stores/agentStore';
4
  import { Header, SandboxViewer, StepsList, Timeline } from '@/components';
 
5
  import { Box } from '@mui/material';
 
 
6
 
7
  const Task = () => {
8
  const navigate = useNavigate();
@@ -25,8 +25,19 @@ const Task = () => {
25
 
26
  // Handler for going back to home
27
  const handleBackToHome = () => {
 
 
 
 
 
 
 
 
 
28
  useAgentStore.getState().resetAgent();
29
- navigate('/');
 
 
30
  };
31
 
32
  // Determine if we should show success/fail status (same logic as SandboxViewer)
@@ -66,15 +77,15 @@ const Task = () => {
66
  overflowX: 'hidden',
67
  }}
68
  >
69
- <Box
70
- sx={{
71
- width: '100%',
72
- display: 'flex',
73
- flexDirection: { xs: 'column', md: 'row' },
74
- p: { xs: 2, md: 4 },
75
- pb: { xs: 2, md: 3 },
76
- }}
77
- >
78
  {/* Left Side: OS Stream + Metadata */}
79
  <Box
80
  sx={{
 
 
 
 
1
  import { Header, SandboxViewer, StepsList, Timeline } from '@/components';
2
+ import { selectIsAgentProcessing, selectMetadata, selectSelectedStep, selectTrace, selectVncUrl, useAgentStore } from '@/stores/agentStore';
3
  import { Box } from '@mui/material';
4
+ import { useEffect } from 'react';
5
+ import { useNavigate } from 'react-router-dom';
6
 
7
  const Task = () => {
8
  const navigate = useNavigate();
 
25
 
26
  // Handler for going back to home
27
  const handleBackToHome = () => {
28
+ const currentTrace = useAgentStore.getState().trace;
29
+
30
+ // Stop the current task if it's running
31
+ const stopTask = (window as Window & { __stopCurrentTask?: () => void }).__stopCurrentTask;
32
+ if (stopTask) {
33
+ stopTask();
34
+ }
35
+
36
+ // Reset frontend state
37
  useAgentStore.getState().resetAgent();
38
+
39
+ // Reload the page to reconnect the WebSocket
40
+ window.location.href = '/';
41
  };
42
 
43
  // Determine if we should show success/fail status (same logic as SandboxViewer)
 
77
  overflowX: 'hidden',
78
  }}
79
  >
80
+ <Box
81
+ sx={{
82
+ width: '100%',
83
+ display: 'flex',
84
+ flexDirection: { xs: 'column', md: 'row' },
85
+ p: { xs: 2, md: 4 },
86
+ pb: { xs: 2, md: 3 },
87
+ }}
88
+ >
89
  {/* Left Side: OS Stream + Metadata */}
90
  <Box
91
  sx={{
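
Both `App.tsx` and `Task.tsx` reach `window.__stopCurrentTask` through an inline cast. A hedged alternative, not part of this commit: declare the global once so call sites stay uncast.

```typescript
// Suggestion sketch only: a single global declaration replaces the repeated
// `(window as Window & { __stopCurrentTask?: () => void })` casts.
declare global {
  interface Window {
    __stopCurrentTask?: () => void;
  }
}

export function stopRunningTask(): void {
  // Optional call: a no-op until App.tsx has registered the handler.
  window.__stopCurrentTask?.();
}
```
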
cua2-front/src/stores/agentStore.ts CHANGED
@@ -5,6 +5,7 @@ import { devtools } from 'zustand/middleware';
5
  interface AgentState {
6
  // State
7
  trace?: AgentTrace;
 
8
  isAgentProcessing: boolean;
9
  isConnectingToE2B: boolean; // New state for E2B connection
10
  vncUrl: string;
@@ -19,6 +20,7 @@ interface AgentState {
19
 
20
  // Actions
21
  setTrace: (trace: AgentTrace | undefined) => void;
 
22
  updateTraceWithStep: (step: AgentStep, metadata: AgentTraceMetadata) => void;
23
  completeTrace: (metadata: AgentTraceMetadata, finalState?: 'success' | 'stopped' | 'max_steps_reached' | 'error' | 'sandbox_timeout') => void;
24
  setIsAgentProcessing: (processing: boolean) => void;
@@ -36,6 +38,7 @@ interface AgentState {
36
 
37
  const initialState = {
38
  trace: undefined,
 
39
  isAgentProcessing: false,
40
  isConnectingToE2B: false,
41
  vncUrl: '',
@@ -58,6 +61,10 @@ export const useAgentStore = create<AgentState>()(
58
  setTrace: (trace) =>
59
  set({ trace }, false, 'setTrace'),
60
 
 
 
 
 
61
  // Update trace with a new step
62
  updateTraceWithStep: (step, metadata) =>
63
  set(
@@ -90,62 +97,62 @@ export const useAgentStore = create<AgentState>()(
90
  'updateTraceWithStep'
91
  ),
92
 
93
- // Complete the trace
94
- completeTrace: (metadata, finalState?: 'success' | 'stopped' | 'max_steps_reached' | 'error' | 'sandbox_timeout') =>
95
- set(
96
- (state) => {
97
- if (!state.trace) return state;
98
-
99
- // Preserve existing maxSteps if new metadata has 0
100
- const updatedMetadata = {
101
- ...metadata,
102
- maxSteps: metadata.maxSteps > 0
103
- ? metadata.maxSteps
104
- : (state.trace.traceMetadata?.maxSteps || 200),
105
- completed: true,
106
- };
107
-
108
- // Determine the final step type based on final_state from backend
109
- let stepType: 'success' | 'failure' | 'stopped' | 'max_steps_reached' | 'sandbox_timeout';
110
- let stepMessage: string | undefined;
111
-
112
- if (finalState === 'stopped') {
113
- stepType = 'stopped';
114
- stepMessage = 'Task stopped by user';
115
- } else if (finalState === 'max_steps_reached') {
116
- stepType = 'max_steps_reached';
117
- stepMessage = 'Maximum steps reached';
118
- } else if (finalState === 'sandbox_timeout') {
119
- stepType = 'sandbox_timeout';
120
- stepMessage = 'Sandbox timeout';
121
- } else if (finalState === 'error' || state.error) {
122
- stepType = 'failure';
123
- stepMessage = state.error || 'Task failed';
124
- } else {
125
- stepType = 'success';
126
- stepMessage = undefined;
127
- }
128
-
129
- const finalStep: FinalStep = {
130
- type: stepType,
131
- message: stepMessage,
132
- metadata: updatedMetadata,
133
- };
134
-
135
- return {
136
- trace: {
137
- ...state.trace,
138
- isRunning: false,
139
- traceMetadata: updatedMetadata,
 
 
 
 
 
140
  },
141
- finalStep,
142
- // Keep error in state for display
143
- selectedStepIndex: null, // Reset to live mode on completion
144
- };
145
- },
146
- false,
147
- 'completeTrace'
148
- ),
149
 
150
  // Set processing state
151
  setIsAgentProcessing: (isAgentProcessing) =>
@@ -227,10 +234,11 @@ export const useAgentStore = create<AgentState>()(
227
  toggleDarkMode: () =>
228
  set((state) => ({ isDarkMode: !state.isDarkMode }), false, 'toggleDarkMode'),
229
 
230
- // Reset agent state
231
  resetAgent: () =>
232
  set((state) => ({
233
  ...initialState,
 
234
  isDarkMode: state.isDarkMode, // Keep dark mode preference
235
  isConnected: state.isConnected, // Keep connection status
236
  selectedModelId: state.selectedModelId, // Keep selected model
@@ -244,6 +252,7 @@ export const useAgentStore = create<AgentState>()(
244
 
245
  // Selectors for better performance
246
  export const selectTrace = (state: AgentState) => state.trace;
 
247
  export const selectIsAgentProcessing = (state: AgentState) => state.isAgentProcessing;
248
  export const selectIsConnectingToE2B = (state: AgentState) => state.isConnectingToE2B;
249
  export const selectVncUrl = (state: AgentState) => state.vncUrl;
 
5
  interface AgentState {
6
  // State
7
  trace?: AgentTrace;
8
+ traceId: string | null; // Set by backend heartbeat, persists during connection
9
  isAgentProcessing: boolean;
10
  isConnectingToE2B: boolean; // New state for E2B connection
11
  vncUrl: string;
 
20
 
21
  // Actions
22
  setTrace: (trace: AgentTrace | undefined) => void;
23
+ setTraceId: (traceId: string | null) => void;
24
  updateTraceWithStep: (step: AgentStep, metadata: AgentTraceMetadata) => void;
25
  completeTrace: (metadata: AgentTraceMetadata, finalState?: 'success' | 'stopped' | 'max_steps_reached' | 'error' | 'sandbox_timeout') => void;
26
  setIsAgentProcessing: (processing: boolean) => void;
 
38
 
39
  const initialState = {
40
  trace: undefined,
41
+ traceId: null, // Will be set by backend heartbeat
42
  isAgentProcessing: false,
43
  isConnectingToE2B: false,
44
  vncUrl: '',
 
61
  setTrace: (trace) =>
62
  set({ trace }, false, 'setTrace'),
63
 
64
+ // Set trace ID (set by backend heartbeat, only cleared on disconnect)
65
+ setTraceId: (traceId) =>
66
+ set({ traceId }, false, 'setTraceId'),
67
+
68
  // Update trace with a new step
69
  updateTraceWithStep: (step, metadata) =>
70
  set(
 
97
  'updateTraceWithStep'
98
  ),
99
 
100
+ // Complete the trace
101
+ completeTrace: (metadata, finalState?: 'success' | 'stopped' | 'max_steps_reached' | 'error' | 'sandbox_timeout') =>
102
+ set(
103
+ (state) => {
104
+ if (!state.trace) return state;
105
+
106
+ // Preserve existing maxSteps if new metadata has 0
107
+ const updatedMetadata = {
108
+ ...metadata,
109
+ maxSteps: metadata.maxSteps > 0
110
+ ? metadata.maxSteps
111
+ : (state.trace.traceMetadata?.maxSteps || 200),
112
+ completed: true,
113
+ };
114
+
115
+ // Determine the final step type based on final_state from backend
116
+ let stepType: 'success' | 'failure' | 'stopped' | 'max_steps_reached' | 'sandbox_timeout';
117
+ let stepMessage: string | undefined;
118
+
119
+ if (finalState === 'stopped') {
120
+ stepType = 'stopped';
121
+ stepMessage = 'Task stopped by user';
122
+ } else if (finalState === 'max_steps_reached') {
123
+ stepType = 'max_steps_reached';
124
+ stepMessage = 'Maximum steps reached';
125
+ } else if (finalState === 'sandbox_timeout') {
126
+ stepType = 'sandbox_timeout';
127
+ stepMessage = 'Sandbox timeout';
128
+ } else if (finalState === 'error' || state.error) {
129
+ stepType = 'failure';
130
+ stepMessage = state.error || 'Task failed';
131
+ } else {
132
+ stepType = 'success';
133
+ stepMessage = undefined;
134
+ }
135
+
136
+ const finalStep: FinalStep = {
137
+ type: stepType,
138
+ message: stepMessage,
139
+ metadata: updatedMetadata,
140
+ };
141
+
142
+ return {
143
+ trace: {
144
+ ...state.trace,
145
+ isRunning: false,
146
+ traceMetadata: updatedMetadata,
147
+ },
148
+ finalStep,
149
+ // Keep error in state for display
150
+ selectedStepIndex: null, // Reset to live mode on completion
151
+ };
152
  },
153
+ false,
154
+ 'completeTrace'
155
+ ),
 
 
 
 
 
156
 
157
  // Set processing state
158
  setIsAgentProcessing: (isAgentProcessing) =>
 
234
  toggleDarkMode: () =>
235
  set((state) => ({ isDarkMode: !state.isDarkMode }), false, 'toggleDarkMode'),
236
 
237
+ // Reset agent state (but preserve traceId from backend during connection)
238
  resetAgent: () =>
239
  set((state) => ({
240
  ...initialState,
241
+ traceId: state.traceId, // IMPORTANT: Keep traceId from backend
242
  isDarkMode: state.isDarkMode, // Keep dark mode preference
243
  isConnected: state.isConnected, // Keep connection status
244
  selectedModelId: state.selectedModelId, // Keep selected model
 
252
 
253
  // Selectors for better performance
254
  export const selectTrace = (state: AgentState) => state.trace;
255
+ export const selectTraceId = (state: AgentState) => state.traceId;
256
  export const selectIsAgentProcessing = (state: AgentState) => state.isAgentProcessing;
257
  export const selectIsConnectingToE2B = (state: AgentState) => state.isConnectingToE2B;
258
  export const selectVncUrl = (state: AgentState) => state.vncUrl;
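
A short usage sketch for the new `traceId` state, assuming the store and `selectTraceId` selector exported above; the hook name is illustrative, not part of the commit.

```typescript
import { selectTraceId, useAgentStore } from '@/stores/agentStore';

// A task may only be sent once the backend heartbeat has assigned an id;
// resetAgent() deliberately preserves it between tasks.
export function useCanSendTask(): boolean {
  return useAgentStore(selectTraceId) !== null;
}
```
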
cua2-front/src/types/agent.ts CHANGED
@@ -80,6 +80,7 @@ interface VncUrlUnsetEvent {
80
 
81
  interface HeartbeatEvent {
82
  type: 'heartbeat';
 
83
  }
84
 
85
  export type WebSocketEvent =
 
80
 
81
  interface HeartbeatEvent {
82
  type: 'heartbeat';
83
+ uuid: string;
84
  }
85
 
86
  export type WebSocketEvent =
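
Because `WebSocketEvent` is a discriminated union on `type`, adding `uuid` here means `event.uuid` only type-checks inside the heartbeat branch. A self-contained sketch; the second event shape is an assumption based on the handlers earlier in this commit, not a copy of `types/agent.ts`.

```typescript
type WebSocketEvent =
  | { type: 'heartbeat'; uuid: string }
  | { type: 'vnc_url_set'; vncUrl: string };

// The switch narrows `event` per case, so `uuid` and `vncUrl` are each
// visible only in their own branch.
function describe(event: WebSocketEvent): string {
  switch (event.type) {
    case 'heartbeat':
      return `heartbeat for trace ${event.uuid}`;
    case 'vnc_url_set':
      return `sandbox stream at ${event.vncUrl}`;
  }
}
```
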