Duibonduil's picture
Upload 8 files
a27d8ba verified
# coding: utf-8
import json
import traceback
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Dict, List
from openai import RateLimitError
from pydantic import BaseModel, ConfigDict, Field
from aworld.core.common import ActionResult
class PolicyMetadata(BaseModel):
    """Per-step bookkeeping: start/end timestamps, step number and input token count."""

    start_time: float  # timestamp taken when the step began
    end_time: float  # timestamp taken when the step finished
    number: int  # the step's number
    input_tokens: int  # input token count recorded for this step

    @property
    def duration_seconds(self) -> float:
        """Elapsed time of the step, computed as ``end_time - start_time``.

        NOTE(review): reported as seconds, which assumes both timestamps are in seconds.
        """
        elapsed = self.end_time - self.start_time
        return elapsed
class AgentBrain(BaseModel):
    """Current state of the agent, expressed as optional free-text fields.

    Every field defaults to None. Fields are annotated ``Optional[str]`` so the
    annotation matches the default and an explicit ``None`` passes validation.
    """

    evaluation_previous_goal: Optional[str] = None
    memory: Optional[str] = None
    thought: Optional[str] = None
    next_goal: Optional[str] = None
class AgentHistory(BaseModel):
    """History item for one agent step: model output, action results and context."""

    model_output: Optional[BaseModel] = None
    result: List[ActionResult]
    metadata: Optional[PolicyMetadata] = None
    content: Optional[str] = None
    base64_img: Optional[str] = None

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """Serialize this item to a plain dict.

        Keyword arguments are accepted for signature compatibility with
        ``BaseModel.model_dump`` but are ignored. ``result`` entries are dumped
        with ``exclude_none=True``. Absent ``model_output`` / ``metadata``
        serialize as None.

        NOTE(review): ``model_output`` is dumped as a plain dict, but its field type
        is the bare ``BaseModel``; re-validating that dict may not restore the
        original subclass or its fields — verify if round-tripping matters.
        """
        return {
            'model_output': self.model_output.model_dump() if self.model_output else None,
            'result': [r.model_dump(exclude_none=True) for r in self.result],
            'metadata': self.metadata.model_dump() if self.metadata else None,
            # The declared field is ``content``; the previous ``self.xml_content``
            # does not exist and raised AttributeError on every call.
            'content': self.content,
            'base64_img': self.base64_img,
        }
class AgentHistoryList(BaseModel):
    """Ordered collection of AgentHistory items with JSON save/load helpers."""

    history: List[AgentHistory]

    def total_duration_seconds(self) -> float:
        """Return the summed ``duration_seconds`` of every item that has metadata.

        Items whose ``metadata`` is None are skipped; an empty history yields 0.0.
        """
        return sum(
            (h.metadata.duration_seconds for h in self.history if h.metadata),
            0.0,
        )

    def save_to_file(self, filepath: str | Path) -> None:
        """Serialize the history to *filepath* as indented JSON (UTF-8 file).

        Parent directories are created as needed and an existing file is
        overwritten. Errors (e.g. ``OSError`` from the filesystem, ``TypeError``
        from ``json.dump`` on a non-serializable value) propagate unchanged.
        """
        path = Path(filepath)
        path.parent.mkdir(parents=True, exist_ok=True)
        data = self.model_dump()
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """Serialize using each item's custom ``AgentHistory.model_dump``.

        Keyword arguments are forwarded to every item's ``model_dump``.
        """
        return {
            'history': [h.model_dump(**kwargs) for h in self.history],
        }

    @classmethod
    def load_from_file(cls, filepath: str | Path) -> 'AgentHistoryList':
        """Read JSON from *filepath* and validate it into an ``AgentHistoryList``.

        NOTE(review): ``model_output`` is saved as a plain dict while its field type
        is the bare ``BaseModel``, so its contents may not survive the round trip — verify.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.model_validate(data)
class AgentError:
    """Container for agent error messages and error formatting."""

    VALIDATION_ERROR = 'Invalid model output format. Please follow the correct schema.'
    RATE_LIMIT_ERROR = 'Rate limit reached. Waiting before retry.'
    NO_VALID_ACTION = 'No valid action found'

    @staticmethod
    def format_error(error: Exception, include_trace: bool = False) -> str:
        """Format *error* as a message string.

        Args:
            error: the exception to describe.
            include_trace: when True, append the traceback attached to *error*.

        Returns:
            ``RATE_LIMIT_ERROR`` for a ``RateLimitError``; otherwise ``str(error)``,
            optionally followed by ``'\\nStacktrace:\\n'`` and the formatted traceback.
        """
        if isinstance(error, RateLimitError):
            return AgentError.RATE_LIMIT_ERROR
        if include_trace:
            # Format the traceback carried by *error* itself. traceback.format_exc()
            # reports whichever exception is currently being handled (or
            # 'NoneType: None' outside an except block), which may not be *error*.
            trace = ''.join(
                traceback.format_exception(type(error), error, error.__traceback__)
            )
            return f'{error}\nStacktrace:\n{trace}'
        return str(error)
class AgentState(BaseModel):
    """Run-time state kept for one Agent: identity, counters, history and control flags."""

    # A fresh random UUID4, rendered as a string, for each new instance.
    agent_id: str = Field(default_factory=lambda: f'{uuid.uuid4()}')
    # Step counter, starting at 1.
    n_steps: int = 1
    # Counter of consecutive failures, starting at 0.
    consecutive_failures: int = 0
    # Most recent list of ActionResult values, or None.
    last_result: Optional[List[ActionResult]] = None
    # Each instance gets its own empty AgentHistoryList (never shared).
    history: AgentHistoryList = Field(default_factory=lambda: AgentHistoryList(history=[]))
    # Most recent plan text, or None.
    last_plan: Optional[str] = None
    # Control flags, both off by default.
    paused: bool = False
    stopped: bool = False
@dataclass
class AgentStepInfo:
    """Position of the current step relative to the run's step budget."""

    number: int
    max_steps: int

    def is_last_step(self) -> bool:
        """Return True when ``number`` is at or past ``max_steps - 1``.

        ``max_steps - 1`` is treated as the final step, so ``number`` behaves
        like a zero-based index; values beyond it also count as last.
        """
        final_index = self.max_steps - 1
        return final_index <= self.number