| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from dataclasses import dataclass |
| from typing import Any, Dict, List, Optional |
|
|
| from tenacity import retry |
| from tenacity.stop import stop_after_attempt |
| from tenacity.wait import wait_exponential |
|
|
| from camel.agents import BaseAgent |
| from camel.configs import ChatGPTConfig |
| from camel.messages import ChatMessage, MessageType, SystemMessage |
| from camel.model_backend import ModelBackend, ModelFactory |
| from camel.typing import ModelType, RoleType |
| from camel.utils import ( |
| get_model_token_limit, |
| num_tokens_from_messages, |
| openai_api_key_required, |
| ) |
|
|
|
|
@dataclass(frozen=True)
class ChatAgentResponse:
    r"""Response of a ChatAgent.

    Attributes:
        msgs (List[ChatMessage]): A list of zero, one or several messages.
            If the list is empty, there is some error in message generation.
            If the list has one message, this is normal mode.
            If the list has several messages, this is the critic mode.
        terminated (bool): A boolean indicating whether the agent decided
            to terminate the chat session.
        info (Dict[str, Any]): Extra information about the chat message.
    """
    msgs: List[ChatMessage]
    terminated: bool
    info: Dict[str, Any]

    @property
    def msg(self) -> Optional[ChatMessage]:
        r"""The single message carried by this response.

        Returns:
            Optional[ChatMessage]: The only element of :attr:`msgs`, or
            ``None`` when both :attr:`msgs` and :attr:`info` are empty.

        Raises:
            RuntimeError: If the response is terminated, holds more than one
                message, or holds no message while :attr:`info` is non-empty.
        """
        if self.terminated:
            raise RuntimeError("error in ChatAgentResponse, info:{}".format(str(self.info)))

        count = len(self.msgs)
        if count == 1:
            return self.msgs[0]
        if count > 1:
            raise RuntimeError("Property msg is only available for a single message in msgs")

        # No messages at all: only report an error when there is info to show.
        if len(self.info) == 0:
            return None
        raise RuntimeError("Empty msgs in ChatAgentResponse, info:{}".format(str(self.info)))
|
|
|
|
class ChatAgent(BaseAgent):
    r"""Class for managing conversations of CAMEL Chat Agents.

    Args:
        system_message (SystemMessage): The system message for the chat agent.
        model (ModelType, optional): The LLM model to use for generating
            responses. (default :obj:`ModelType.GPT_3_5_TURBO`)
        model_config (Any, optional): Configuration options for the LLM model.
            (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
    """

    def __init__(
        self,
        system_message: SystemMessage,
        model: Optional[ModelType] = None,
        model_config: Optional[Any] = None,
        message_window_size: Optional[int] = None,
    ) -> None:
        self.system_message: SystemMessage = system_message
        self.role_name: str = system_message.role_name
        self.role_type: RoleType = system_message.role_type
        self.model: ModelType = (model if model is not None else ModelType.GPT_3_5_TURBO)
        self.model_config: ChatGPTConfig = model_config or ChatGPTConfig()
        self.model_token_limit: int = get_model_token_limit(self.model)
        self.message_window_size: Optional[int] = message_window_size
        # The backend factory receives the config's instance attributes as a dict.
        self.model_backend: ModelBackend = ModelFactory.create(self.model, self.model_config.__dict__)
        # Set by step() when the prompt does not fit under model_token_limit.
        self.terminated: bool = False
        # Set by step() when the first reply's last line starts with "<INFO>".
        self.info: bool = False
        self.init_messages()

    def reset(self) -> List[MessageType]:
        r"""Resets the :obj:`ChatAgent` to its initial state and returns the
        stored messages.

        Clears the ``terminated`` and ``info`` flags and re-initializes the
        stored messages to just the system message.

        Returns:
            List[MessageType]: The stored messages.
        """
        self.terminated = False
        self.info = False
        self.init_messages()
        return self.stored_messages

    def get_info(
        self,
        id: Optional[str],
        usage: Optional[Dict[str, int]],
        termination_reasons: List[str],
        num_tokens: int,
    ) -> Dict[str, Any]:
        r"""Returns a dictionary containing information about the chat session.

        Args:
            id (str, optional): The ID of the backend response, or ``None``
                when no request was sent.
            usage (Dict[str, int], optional): Information about the usage of
                the LLM model, or ``None`` when no request was sent.
            termination_reasons (List[str]): The reasons for the termination of
                the chat session (one finish reason per returned choice).
            num_tokens (int): The number of tokens in the prompt messages.

        Returns:
            Dict[str, Any]: The chat session information, keyed by ``"id"``,
            ``"usage"``, ``"termination_reasons"`` and ``"num_tokens"``.
        """
        return {
            "id": id,
            "usage": usage,
            "termination_reasons": termination_reasons,
            "num_tokens": num_tokens,
        }

    def init_messages(self) -> None:
        r"""Initializes the stored messages list with the initial system
        message.
        """
        self.stored_messages: List[MessageType] = [self.system_message]

    def update_messages(self, message: ChatMessage) -> List[MessageType]:
        r"""Updates the stored messages list with a new message.

        Args:
            message (ChatMessage): The new message to add to the stored
                messages.

        Returns:
            List[MessageType]: The updated stored messages (the same list
            object as :attr:`stored_messages`).
        """
        self.stored_messages.append(message)
        return self.stored_messages

    @retry(wait=wait_exponential(min=5, max=60), stop=stop_after_attempt(5))
    @openai_api_key_required
    def step(
        self,
        input_message: ChatMessage,
    ) -> ChatAgentResponse:
        r"""Performs a single step in the chat session by generating a response
        to the input message.

        The input message is appended to the stored history. If generating the
        response raises, that append is undone before the exception propagates.
        As a result, the retries performed by the ``@retry`` decorator do not
        store the same input more than once.

        Args:
            input_message (ChatMessage): The input message to the agent.

        Returns:
            ChatAgentResponse: A struct
                containing the output messages, a boolean indicating whether
                the chat session has terminated, and information about the chat
                session.
        """
        num_stored = len(self.stored_messages)
        messages = self.update_messages(input_message)
        try:
            return self._generate(self._apply_window(messages))
        except Exception:
            # tenacity re-invokes step() with the same input after a failure;
            # drop what this attempt appended so the history is not duplicated.
            del self.stored_messages[num_stored:]
            raise

    def _apply_window(self, messages: List[MessageType]) -> List[MessageType]:
        r"""Trims ``messages`` to the system message plus the newest
        ``message_window_size`` messages, when a window is configured."""
        window = self.message_window_size
        if window is None or len(messages) <= window:
            return messages
        # Slice from an explicit start index: ``messages[-0:]`` would select
        # the whole history (system message included) when window == 0.
        return [self.system_message] + messages[len(messages) - window:]

    def _generate(self, messages: List[MessageType]) -> ChatAgentResponse:
        r"""Sends ``messages`` to the backend (if they fit under the token
        limit) and wraps the reply in a :obj:`ChatAgentResponse`."""
        openai_messages = [message.to_openai_message() for message in messages]
        num_tokens = num_tokens_from_messages(openai_messages, self.model)

        output_messages: List[ChatMessage]
        info: Dict[str, Any]

        if num_tokens < self.model_token_limit:
            response = self.model_backend.run(messages=openai_messages)
            if not isinstance(response, dict):
                raise RuntimeError("OpenAI returned unexpected struct")
            # NOTE(review): every key of choice["message"] is passed to
            # ChatMessage as a keyword argument — confirm ChatMessage accepts
            # all keys the backend may return.
            output_messages = [
                ChatMessage(role_name=self.role_name, role_type=self.role_type,
                            meta_dict=dict(), **dict(choice["message"]))
                for choice in response["choices"]
            ]
            info = self.get_info(
                response["id"],
                response["usage"],
                [str(choice["finish_reason"]) for choice in response["choices"]],
                num_tokens,
            )

            # Only the last line of the first reply is checked for the marker;
            # tolerate an empty choice list or a reply whose content is None.
            if output_messages:
                last_line = (output_messages[0].content or "").split("\n")[-1]
                if last_line.startswith("<INFO>"):
                    self.info = True
        else:
            # The prompt is too long to send: give up on this session.
            self.terminated = True
            output_messages = []
            info = self.get_info(
                None,
                None,
                ["max_tokens_exceeded_by_camel"],
                num_tokens,
            )

        return ChatAgentResponse(output_messages, self.terminated, info)

    def __repr__(self) -> str:
        r"""Returns a string representation of the :obj:`ChatAgent`.

        Returns:
            str: The string representation of the :obj:`ChatAgent`.
        """
        return f"ChatAgent({self.role_name}, {self.role_type}, {self.model})"
|
|