Spaces:
Sleeping
Sleeping
| # coding: utf-8 | |
| import json | |
| import traceback | |
| import uuid | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any, Optional, Dict, List | |
| from langchain_core.load import dumpd, load | |
| from langchain_core.messages import BaseMessage, AIMessage, ToolMessage, SystemMessage, HumanMessage | |
| from openai import RateLimitError | |
| from pydantic import BaseModel, ConfigDict, Field, model_serializer, model_validator | |
| from aworld.core.agent.base import AgentResult | |
| from aworld.core.common import ActionResult, Observation | |
class MessageMetadata(BaseModel):
    """Bookkeeping attached to a single managed message."""

    # Token count charged for the message. Callers in this module pass
    # rough estimates (e.g. MessageHistory.add_model_output).
    tokens: int = Field(default=0)
class ManagedMessage(BaseModel):
    """A LangChain message paired with its bookkeeping metadata.

    ``message`` is serialized with LangChain's ``dumpd`` and restored with
    ``load`` so the model can round-trip through plain dict data.
    """

    message: BaseMessage
    metadata: MessageMetadata = Field(default_factory=MessageMetadata)

    # BaseMessage is not a pydantic-native type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # https://github.com/pydantic/pydantic/discussions/7558
    @model_serializer(mode='wrap')
    def to_json(self, original_dump):
        """Serialize the model, replacing ``message`` with LangChain's dict form.

        Args:
            original_dump: pydantic's default serializer handler for this model.

        Returns:
            The default dump with ``'message'`` replaced by ``dumpd(self.message)``.
        """
        data = original_dump(self)
        # NOTE: We override the message field to use langchain's dict serialization.
        data['message'] = dumpd(self.message)
        return data

    @model_validator(mode='before')
    @classmethod
    def validate(
        cls,
        value: Any,
        *,
        strict: bool | None = None,
        from_attributes: bool | None = None,
        context: Any | None = None,
    ) -> Any:
        """Rebuild ``message`` from its ``dumpd`` form before field validation.

        Only dict input carrying a ``'message'`` key is rewritten. A new dict is
        returned so the caller's data is not mutated. Any other input passes
        through unchanged. The keyword-only parameters are accepted for
        signature compatibility and are not used.
        """
        if isinstance(value, dict) and 'message' in value:
            # NOTE: langchain's load converts the serialized dict back into a BaseMessage.
            value = {**value, 'message': load(value['message'])}
        return value
class MessageHistory(BaseModel):
    """Ordered managed messages plus a running token total.

    ``current_tokens`` is adjusted by the add/remove helpers below using each
    message's ``metadata.tokens``.
    """

    messages: list[ManagedMessage] = Field(default_factory=list)
    current_tokens: int = 0

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def add_message(self, message: BaseMessage, metadata: MessageMetadata, position: int | None = None) -> None:
        """Store ``message`` with ``metadata`` and add its tokens to the total.

        With ``position`` None the message is appended; otherwise it is
        inserted at that index using ``list.insert`` semantics.
        """
        managed = ManagedMessage(message=message, metadata=metadata)
        if position is not None:
            self.messages.insert(position, managed)
        else:
            self.messages.append(managed)
        self.current_tokens += metadata.tokens

    def add_model_output(self, output) -> None:
        """Record ``output`` as an ``AgentOutput`` tool call followed by an empty tool reply.

        ``output`` must provide pydantic's ``model_dump``. Its explicitly set
        fields become the tool-call arguments. Both messages share the fixed
        tool-call id ``'1'``.
        """
        call = {
            'name': 'AgentOutput',
            'args': output.model_dump(mode='json', exclude_unset=True),
            'id': '1',
            'type': 'tool_call',
        }
        # Fixed token estimates: 100 for the tool call, 10 for the empty reply.
        self.add_message(AIMessage(content='', tool_calls=[call]), MessageMetadata(tokens=100))
        self.add_message(ToolMessage(content='', tool_call_id='1'), MessageMetadata(tokens=10))

    def get_messages(self) -> list[BaseMessage]:
        """Return the bare LangChain messages in stored order."""
        return [managed.message for managed in self.messages]

    def get_total_tokens(self) -> int:
        """Return the running token total."""
        return self.current_tokens

    def remove_oldest_message(self) -> None:
        """Drop the first message that is not a SystemMessage, if any."""
        victim = next(
            (idx for idx, managed in enumerate(self.messages) if not isinstance(managed.message, SystemMessage)),
            None,
        )
        if victim is None:
            return
        self.current_tokens -= self.messages[victim].metadata.tokens
        del self.messages[victim]

    def remove_last_state_message(self) -> None:
        """Drop the final message if it is a HumanMessage and more than two messages exist."""
        if len(self.messages) <= 2:
            return
        last = self.messages[-1]
        if not isinstance(last.message, HumanMessage):
            return
        self.current_tokens -= last.metadata.tokens
        self.messages.pop()
class MessageManagerState(BaseModel):
    """Serializable state kept for the message manager."""

    # Conversation so far, with per-message token bookkeeping.
    history: MessageHistory = Field(default_factory=MessageHistory)
    # NOTE(review): presumably the next tool-call id to issue; not read in this file — confirm.
    tool_id: int = Field(default=1)

    model_config = ConfigDict(arbitrary_types_allowed=True)
class AgentSettings(BaseModel):
    """Tunable options for the agent."""

    max_failures: int = 3
    # NOTE(review): unit not shown here; presumably seconds — confirm.
    retry_delay: int = 10
    save_history: bool = True
    history_path: str | None = None
    max_actions_per_step: int = 10
    validate_output: bool = False
    message_context: str | None = None
class PolicyMetadata(BaseModel):
    """Metadata for a single step, including timing information."""

    # NOTE(review): presumably wall-clock timestamps in seconds — confirm with the producer.
    start_time: float
    end_time: float
    # NOTE(review): presumably the step index — confirm.
    number: int
    input_tokens: int

    @property
    def duration_seconds(self) -> float:
        """Step duration, computed as ``end_time - start_time``.

        Exposed as a property because callers read it as an attribute
        (e.g. ``AgentHistoryList.total_duration_seconds``).
        """
        return self.end_time - self.start_time
class AgentBrain(BaseModel):
    """Current state of the agent as three free-text fields."""

    evaluation_previous_goal: str
    memory: str
    next_goal: str
class AgentHistory(BaseModel):
    """One recorded agent step: model output, action results and optional extras."""

    model_output: Optional[BaseModel] = None
    result: List[ActionResult]
    metadata: Optional[PolicyMetadata] = None
    content: Optional[str] = None
    base64_img: Optional[str] = None

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """Serialize the step to a plain dict.

        ``kwargs`` are accepted for compatibility with ``BaseModel.model_dump``
        but are ignored. Nested models are dumped with fixed options, and
        ``result`` entries drop ``None`` fields.
        """
        return {
            'model_output': self.model_output.model_dump() if self.model_output else None,
            'result': [r.model_dump(exclude_none=True) for r in self.result],
            'metadata': self.metadata.model_dump() if self.metadata else None,
            'content': self.content,
            'base64_img': self.base64_img,
        }
class AgentHistoryList(BaseModel):
    """Ordered collection of AgentHistory steps with JSON file persistence."""

    history: List[AgentHistory]

    def total_duration_seconds(self) -> float:
        """Return the summed ``duration_seconds`` of every step that has metadata."""
        total = 0.0
        for h in self.history:
            if h.metadata:
                total += h.metadata.duration_seconds
        return total

    def save_to_file(self, filepath: str | Path) -> None:
        """Write the history to ``filepath`` as indented JSON.

        Parent directories are created as needed. Errors from directory
        creation, serialization or writing propagate to the caller.
        """
        path = Path(filepath)
        path.parent.mkdir(parents=True, exist_ok=True)
        data = self.model_dump()
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    def model_dump(self, **kwargs) -> Dict[str, Any]:
        """Serialize through each step's own ``model_dump``, forwarding ``kwargs``."""
        return {
            'history': [h.model_dump(**kwargs) for h in self.history],
        }

    @classmethod
    def load_from_file(cls, filepath: str | Path) -> 'AgentHistoryList':
        """Read a JSON file (as written by ``save_to_file``) and validate it into a new instance.

        NOTE(review): ``AgentHistory.model_output`` is typed as plain
        ``BaseModel``, so its concrete fields are presumably not restored — confirm.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls.model_validate(data)
class AgentError:
    """Error-message constants and a formatter for agent failures."""

    VALIDATION_ERROR = 'Invalid model output format. Please follow the correct schema.'
    RATE_LIMIT_ERROR = 'Rate limit reached. Waiting before retry.'
    NO_VALID_ACTION = 'No valid action found'

    @staticmethod
    def format_error(error: Exception, include_trace: bool = False) -> str:
        """Return a human-readable message for ``error``.

        Args:
            error: the exception to describe.
            include_trace: when True, append the current traceback.

        Returns:
            ``RATE_LIMIT_ERROR`` for an OpenAI ``RateLimitError`` (regardless of
            ``include_trace``); otherwise ``str(error)``, optionally followed by
            a stack trace.
        """
        if isinstance(error, RateLimitError):
            return AgentError.RATE_LIMIT_ERROR
        if include_trace:
            # traceback.format_exc() describes the exception currently being
            # handled, so the trace is only meaningful inside an ``except`` block.
            return f'{error}\nStacktrace:\n{traceback.format_exc()}'
        return str(error)
class AgentState(BaseModel):
    """Aggregate state information for one Agent."""

    # A fresh random UUID4 string per instance.
    agent_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    n_steps: int = 1
    consecutive_failures: int = 0
    last_result: List[ActionResult] | None = None
    # Starts empty for every instance.
    history: AgentHistoryList = Field(default_factory=lambda: AgentHistoryList(history=[]))
    last_plan: str | None = None
    paused: bool = False
    stopped: bool = False
    message_manager_state: MessageManagerState = Field(default_factory=MessageManagerState)
@dataclass
class AgentStepInfo:
    """Current step number and the step budget it is measured against."""

    # NOTE(review): presumably 0-based, given the ``max_steps - 1`` threshold below.
    number: int
    max_steps: int

    def is_last_step(self) -> bool:
        """Return True once ``number`` has reached ``max_steps - 1`` or beyond."""
        return self.number >= self.max_steps - 1
@dataclass
class Trajectory:
    """Stores the agent's history: one (observation, info, AgentResult) tuple per step."""

    # default_factory gives each instance its own list.
    history: List[tuple[Observation, Dict[str, Any], AgentResult]] = field(default_factory=list)

    def add_step(self, observation: Observation, info: Dict[str, Any], agent_result: AgentResult):
        """Append one step to the history."""
        self.history.append((observation, info, agent_result))

    def get_history(self) -> List[tuple[Observation, Dict[str, Any], AgentResult]]:
        """Return the history list itself (not a copy), in insertion order."""
        return self.history