Spaces:
Sleeping
Sleeping
# coding: utf-8 | |
import json | |
import traceback | |
import uuid | |
from dataclasses import dataclass | |
from pathlib import Path | |
from typing import Any, Optional, Dict, List | |
from openai import RateLimitError | |
from pydantic import BaseModel, ConfigDict, Field | |
from aworld.core.common import ActionResult | |
class PolicyMetadata(BaseModel):
    """Metadata for a single step, including timing information.

    Attributes:
        start_time: Timestamp at which the step started (assumed to be in
            seconds, e.g. from ``time.time()`` -- confirm with the caller).
        end_time: Timestamp at which the step ended (same unit as
            ``start_time``).
        number: Step number.
        input_tokens: Input token count recorded for the step.
    """

    start_time: float
    end_time: float
    number: int
    input_tokens: int

    @property
    def duration_seconds(self) -> float:
        """Step duration in seconds, i.e. ``end_time - start_time``.

        Exposed as a property because callers read it as an attribute
        (for example when summing durations over an agent history).
        """
        return self.end_time - self.start_time
class AgentBrain(BaseModel):
    """Current state of the agent.

    All fields are optional free-text values that default to ``None``.
    They are annotated ``Optional[str]`` so that an explicit ``None`` passes
    validation. A bare ``str`` annotation would reject it.
    """

    evaluation_previous_goal: Optional[str] = None
    memory: Optional[str] = None
    thought: Optional[str] = None
    next_goal: Optional[str] = None
class AgentHistory(BaseModel):
    """History item for agent actions.

    Attributes:
        model_output: Model output recorded for the step, if any.
        result: Action results recorded for the step.
        metadata: Step metadata (timing / token counts), if recorded.
        content: Optional text content attached to the step.
        base64_img: Optional base64-encoded image attached to the step.
    """

    model_output: Optional[BaseModel] = None
    result: List[ActionResult]
    metadata: Optional[PolicyMetadata] = None
    content: Optional[str] = None
    base64_img: Optional[str] = None

    # NOTE(review): presumably needed for the field types above
    # (ActionResult / bare BaseModel) -- confirm before removing.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """Serialize this history item to a plain dict.

        Nested models are serialized with their own ``model_dump``. Each
        action result is dumped with ``exclude_none=True``. ``kwargs`` is
        accepted for signature compatibility with ``BaseModel.model_dump``
        but is ignored.

        Returns:
            A dict with keys ``model_output``, ``result``, ``metadata``,
            ``content`` and ``base64_img``.
        """
        return {
            'model_output': self.model_output.model_dump() if self.model_output is not None else None,
            'result': [r.model_dump(exclude_none=True) for r in self.result],
            'metadata': self.metadata.model_dump() if self.metadata is not None else None,
            # The field is ``content``; there is no ``xml_content`` attribute.
            'content': self.content,
            'base64_img': self.base64_img,
        }
class AgentHistoryList(BaseModel):
    """Ordered list of agent history items with JSON save/load helpers."""

    history: List[AgentHistory]

    def total_duration_seconds(self) -> float:
        """Return the summed duration, in seconds, of all recorded steps.

        Items whose ``metadata`` is falsy (e.g. ``None``) contribute nothing.
        """
        return sum(
            (h.metadata.duration_seconds for h in self.history if h.metadata),
            0.0,
        )

    def save_to_file(self, filepath: str | Path) -> None:
        """Save the history to ``filepath`` as indented JSON (UTF-8).

        Missing parent directories are created. The whole document is
        serialized before the file is opened, so a serialization error does
        not truncate an existing file. All errors propagate to the caller.
        """
        path = Path(filepath)
        path.parent.mkdir(parents=True, exist_ok=True)
        text = json.dumps(self.model_dump(), indent=2)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(text)

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """Serialize via each item's custom ``AgentHistory.model_dump``.

        ``kwargs`` is forwarded to every item's ``model_dump``.
        """
        return {
            'history': [h.model_dump(**kwargs) for h in self.history],
        }

    @classmethod
    def load_from_file(cls, filepath: str | Path) -> 'AgentHistoryList':
        """Load and validate a history from a JSON file.

        The file is expected to be in the layout written by ``save_to_file``.

        Raises:
            OSError: If the file cannot be read.
            json.JSONDecodeError: If the file is not valid JSON.
            pydantic.ValidationError: If the data does not match the model.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.model_validate(data)
class AgentError:
    """Container for agent error messages and error formatting."""

    VALIDATION_ERROR = 'Invalid model output format. Please follow the correct schema.'
    RATE_LIMIT_ERROR = 'Rate limit reached. Waiting before retry.'
    NO_VALID_ACTION = 'No valid action found'

    @staticmethod
    def format_error(error: Exception, include_trace: bool = False) -> str:
        """Format an error message based on its type.

        Args:
            error: The exception to describe.
            include_trace: If true, append the stack trace attached to
                ``error``.

        Returns:
            ``AgentError.RATE_LIMIT_ERROR`` when ``error`` is an
            ``openai.RateLimitError``. Otherwise ``str(error)``, optionally
            followed by ``"\\nStacktrace:\\n"`` and the formatted traceback.
        """
        if isinstance(error, RateLimitError):
            return AgentError.RATE_LIMIT_ERROR
        if include_trace:
            # Format the traceback carried by ``error`` itself.
            # traceback.format_exc() reports whichever exception is currently
            # being handled, and "NoneType: None" outside an except block.
            trace = ''.join(
                traceback.format_exception(type(error), error, error.__traceback__)
            )
            return f'{str(error)}\nStacktrace:\n{trace}'
        return str(error)
def _new_agent_id() -> str:
    """Return a freshly generated UUID4 rendered as a string."""
    return str(uuid.uuid4())


def _empty_history() -> 'AgentHistoryList':
    """Return a new AgentHistoryList with no items."""
    return AgentHistoryList(history=[])


class AgentState(BaseModel):
    """Mutable state tracked for a single agent.

    Every instance gets its own generated id and its own empty history.
    """

    # New UUID4 string per instance unless one is supplied.
    agent_id: str = Field(default_factory=_new_agent_id)
    # Step counter; starts at 1.
    n_steps: int = 1
    consecutive_failures: int = 0
    last_result: Optional[List['ActionResult']] = None
    # Built by a factory so instances never share one history object.
    history: AgentHistoryList = Field(default_factory=_empty_history)
    last_plan: Optional[str] = None
    paused: bool = False
    stopped: bool = False
@dataclass
class AgentStepInfo:
    """Position of the current step within a bounded agent run.

    ``number`` is compared against ``max_steps - 1``, so it is treated as a
    zero-based step index.
    """

    number: int
    max_steps: int

    def is_last_step(self) -> bool:
        """Return True if this is the last allowed step, or past it."""
        return self.number >= self.max_steps - 1