|
|
from enum import Enum |
|
|
from typing import Any, List, Literal, Optional, Union |
|
|
|
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
|
|
|
class Role(str, Enum):
    """Allowed values for a chat message's author role.

    Members subclass ``str``, so each member compares equal to its plain
    string value (e.g. ``Role.USER == "user"``).
    """

    # Declaration order matters: ROLE_VALUES is derived by iterating this enum.
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"
|
|
|
|
|
|
|
|
# Plain string values of every Role member, in declaration order.
ROLE_VALUES = tuple(map(lambda role_member: role_member.value, Role))

# Literal type over exactly those strings. At runtime Literal unpacks the
# tuple into its individual string parameters.
# NOTE(review): static type checkers cannot evaluate a Literal built from a
# runtime tuple, so they will not narrow this type — confirm that is acceptable.
ROLE_TYPE = Literal[ROLE_VALUES]
|
|
|
|
|
|
|
|
class ToolChoice(str, Enum):
    """Allowed tool-choice settings.

    Members subclass ``str``, so ``ToolChoice.AUTO == "auto"`` holds.
    """

    # Declaration order matters: TOOL_CHOICE_VALUES is derived by iterating this enum.
    NONE = "none"
    AUTO = "auto"
    REQUIRED = "required"
|
|
|
|
|
|
|
|
# Plain string values of every ToolChoice member, in declaration order.
TOOL_CHOICE_VALUES = tuple(map(lambda choice_member: choice_member.value, ToolChoice))

# Literal type over exactly those strings. At runtime Literal unpacks the
# tuple into its individual string parameters.
# NOTE(review): static type checkers cannot evaluate a Literal built from a
# runtime tuple — confirm that is acceptable.
TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES]
|
|
|
|
|
|
|
|
class AgentState(str, Enum):
    """Execution states an agent can be in.

    Unlike the other enums in this module, the values are upper-case strings
    identical to the member names.
    """

    IDLE = "IDLE"
    RUNNING = "RUNNING"
    FINISHED = "FINISHED"
    ERROR = "ERROR"
|
|
|
|
|
|
|
|
class Function(BaseModel):
    """Function name plus its arguments, as carried inside a ToolCall."""

    # Name of the function being called.
    name: str
    # Arguments kept as one raw string.
    # NOTE(review): presumably JSON-encoded by the LLM — confirm with callers.
    arguments: str
|
|
|
|
|
|
|
|
class ToolCall(BaseModel):
    """One tool/function call attached to a chat message."""

    # Identifier of this particular call (required).
    id: str
    # Call kind; "function" unless the caller supplies something else.
    type: str = Field(default="function")
    # Which function to call and with what arguments (required).
    function: Function
|
|
|
|
|
|
|
|
class Message(BaseModel):
    """A single chat message in the conversation.

    Messages support ``+`` for building message lists: ``msg + other_msg``,
    ``msg + [..]`` and ``[..] + msg`` each return a new ``list``.
    """

    # Author of the message; must be one of the Role string values (ROLE_TYPE).
    role: ROLE_TYPE = Field(...)
    # Text body of the message, if any.
    content: Optional[str] = Field(default=None)
    # Tool calls attached to the message (populated by from_tool_calls).
    tool_calls: Optional[List[ToolCall]] = Field(default=None)
    # Optional name (populated by tool_message).
    name: Optional[str] = Field(default=None)
    # Tool-call ID associated with this message (populated by tool_message).
    tool_call_id: Optional[str] = Field(default=None)
    # Optional base64-encoded image.
    base64_image: Optional[str] = Field(default=None)

    def __add__(self, other) -> List["Message"]:
        """Support ``Message + list`` and ``Message + Message``.

        Args:
            other: A list of messages, or a single Message.

        Returns:
            A new list with this message first, followed by ``other`` (or the
            items of ``other`` when it is a list).

        For any other operand type this returns ``NotImplemented``. Python can
        then try the other operand's reflected method, and otherwise raises
        the standard ``TypeError: unsupported operand type(s) for +``.
        """
        if isinstance(other, list):
            return [self] + other
        if isinstance(other, Message):
            return [self, other]
        return NotImplemented

    def __radd__(self, other) -> List["Message"]:
        """Support ``list + Message``.

        Args:
            other: The left-hand operand, expected to be a list of messages.

        Returns:
            A new list containing the items of ``other`` followed by this
            message. For non-list operands returns ``NotImplemented``, so
            Python raises the standard ``TypeError``.
        """
        if isinstance(other, list):
            return other + [self]
        return NotImplemented

    def to_dict(self) -> dict:
        """Convert the message to a plain dict.

        ``role`` is always present. Every other field is included only when
        it is not None, in the order: content, tool_calls, name, tool_call_id,
        base64_image. Tool calls are converted to dicts with ``model_dump()``.
        """
        message = {"role": self.role}
        if self.content is not None:
            message["content"] = self.content
        if self.tool_calls is not None:
            # model_dump() is the pydantic v2 API (.dict() is deprecated there);
            # from_tool_calls below already relies on model_dump().
            message["tool_calls"] = [
                tool_call.model_dump() for tool_call in self.tool_calls
            ]
        if self.name is not None:
            message["name"] = self.name
        if self.tool_call_id is not None:
            message["tool_call_id"] = self.tool_call_id
        if self.base64_image is not None:
            message["base64_image"] = self.base64_image
        return message

    @classmethod
    def user_message(
        cls, content: str, base64_image: Optional[str] = None
    ) -> "Message":
        """Create a message with role ``user``."""
        return cls(role=Role.USER, content=content, base64_image=base64_image)

    @classmethod
    def system_message(cls, content: str) -> "Message":
        """Create a message with role ``system``."""
        return cls(role=Role.SYSTEM, content=content)

    @classmethod
    def assistant_message(
        cls, content: Optional[str] = None, base64_image: Optional[str] = None
    ) -> "Message":
        """Create a message with role ``assistant``."""
        return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)

    @classmethod
    def tool_message(
        cls,
        content: str,
        name: Optional[str],
        tool_call_id: str,
        base64_image: Optional[str] = None,
    ) -> "Message":
        """Create a message with role ``tool``.

        Args:
            content: Message text.
            name: Value stored in the ``name`` field.
            tool_call_id: Value stored in the ``tool_call_id`` field.
            base64_image: Optional base64-encoded image.
        """
        return cls(
            role=Role.TOOL,
            content=content,
            name=name,
            tool_call_id=tool_call_id,
            base64_image=base64_image,
        )

    @classmethod
    def from_tool_calls(
        cls,
        tool_calls: List[Any],
        content: Union[str, List[str]] = "",
        base64_image: Optional[str] = None,
        **kwargs,
    ) -> "Message":
        """Create an assistant message from raw tool calls.

        Args:
            tool_calls: Raw tool calls from the LLM. Each item must expose
                ``.id`` and ``.function.model_dump()``.
            content: Message text. NOTE(review): the ``content`` field is
                ``Optional[str]``, so a list is expected to be rejected by
                validation — confirm whether the ``List[str]`` hint is intended.
            base64_image: Optional base64-encoded image.
            **kwargs: Extra field values forwarded to the constructor.

        Returns:
            A Message with role ``assistant`` whose ``tool_calls`` are built
            from ``tool_calls`` (``type`` is always ``"function"``).
        """
        formatted_calls = [
            {"id": call.id, "function": call.function.model_dump(), "type": "function"}
            for call in tool_calls
        ]
        return cls(
            role=Role.ASSISTANT,
            content=content,
            tool_calls=formatted_calls,
            base64_image=base64_image,
            **kwargs,
        )
|
|
|
|
|
|
|
|
class Memory(BaseModel):
    """Ordered store of conversation messages with a size cap.

    When adding messages pushes the count above ``max_messages``, the oldest
    messages are dropped so that only the newest ``max_messages`` remain.
    """

    # Stored messages, oldest first.
    messages: List[Message] = Field(default_factory=list)
    # Maximum number of messages kept after each add.
    max_messages: int = Field(default=100)

    def _trim(self) -> None:
        """Keep only the newest ``max_messages`` messages if over the cap."""
        if len(self.messages) > self.max_messages:
            # NOTE(review): for max_messages == 0 this slice is [-0:], which
            # keeps everything (and negative caps keep an odd tail). This is
            # preserved as-is — confirm the intended meaning of non-positive caps.
            self.messages = self.messages[-self.max_messages :]

    def add_message(self, message: Message) -> None:
        """Append one message, then enforce the size cap."""
        self.messages.append(message)
        self._trim()

    def add_messages(self, messages: List[Message]) -> None:
        """Append several messages in order, then enforce the size cap."""
        self.messages.extend(messages)
        self._trim()

    def clear(self) -> None:
        """Remove all stored messages (in place)."""
        self.messages.clear()

    def get_recent_messages(self, n: int) -> List[Message]:
        """Return a new list of up to ``n`` most recent messages, oldest first.

        Returns an empty list when ``n <= 0``. A bare ``messages[-n:]`` would
        wrongly return the whole history for ``n == 0`` and a truncated head
        for negative ``n``.
        """
        if n <= 0:
            return []
        return self.messages[-n:]

    def to_dict_list(self) -> List[dict]:
        """Convert every stored message with ``Message.to_dict``, in order."""
        return [msg.to_dict() for msg in self.messages]
|
|
|