| | import copy |
| | import warnings |
| | from collections import OrderedDict, UserDict, UserList, abc |
| | from functools import wraps |
| | from itertools import chain, repeat |
| | from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union |
| |
|
| | from lagent.agents.aggregator import DefaultAggregator |
| | from lagent.hooks import Hook, RemovableHandle |
| | from lagent.llms import BaseLLM |
| | from lagent.memory import Memory, MemoryManager |
| | from lagent.prompts.parsers import StrParser |
| | from lagent.prompts.prompt_template import PromptTemplate |
| | from lagent.schema import AgentMessage |
| | from lagent.utils import create_object |
| |
|
| |
|
class Agent:
    """Agent is the basic unit of the system.

    It is responsible for communicating with the LLM, managing the memory,
    and handling the message aggregation and parsing. It can also be
    extended with hooks.

    Assigning another ``Agent`` to an attribute registers it as a sub-agent
    in ``self._agents``; ``state_dict``, ``load_state_dict``, ``reset`` and
    ``__repr__`` walk that registry.

    Args:
        llm (Union[BaseLLM, Dict]): The language model used by the agent.
        template (Union[PromptTemplate, str]): The template used to format the
            messages.
        memory (Dict): The memory used by the agent.
        output_format (Dict): The output format used by the agent.
        aggregator (Dict): The aggregator used by the agent.
        name (Optional[str]): The name of the agent. Defaults to the class
            name.
        description (Optional[str]): The description of the agent.
        hooks (Optional[Union[List[Dict], Dict]]): The hooks used by the
            agent, either a list of hook configs or a single hook config.

    Returns:
        AgentMessage: The response message produced by calling the agent.
    """

    def __init__(
        self,
        llm: Union[BaseLLM, Dict] = None,
        template: Union[PromptTemplate, str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Optional[Dict] = None,
        aggregator: Dict = dict(type=DefaultAggregator),
        name: Optional[str] = None,
        description: Optional[str] = None,
        hooks: Optional[Union[List[Dict], Dict]] = None,
    ):
        # NOTE(review): the ``memory`` and ``aggregator`` defaults are dicts
        # shared by every call; this relies on ``MemoryManager`` and
        # ``create_object`` not mutating their argument -- confirm.
        self.name = name or self.__class__.__name__
        self.llm: BaseLLM = create_object(llm)
        self.memory: MemoryManager = MemoryManager(memory) if memory else None
        self.output_format: StrParser = create_object(output_format)
        self.template = template
        self.description = description
        self.aggregator: DefaultAggregator = create_object(aggregator)
        self._hooks: Dict[int, Hook] = OrderedDict()
        if hooks:
            # A single hook config is a dict; iterating it directly would
            # walk its keys and register strings instead of a hook.
            if isinstance(hooks, dict):
                hooks = [hooks]
            for hook in hooks:
                self.register_hook(create_object(hook))

    def update_memory(self, message, session_id=0):
        """Add ``message`` to the memory of ``session_id``.

        Does nothing when the agent has no memory.
        """
        if self.memory:
            self.memory.add(message, session_id=session_id)

    def __call__(
        self,
        *message: Union[str, AgentMessage, List[AgentMessage]],
        session_id=0,
        **kwargs,
    ) -> AgentMessage:
        """Run one turn: hooks, memory update, ``forward``, hooks.

        Args:
            *message: Incoming messages. Strings are wrapped into an
                ``AgentMessage`` sent by ``'user'``; anything else is
                deep-copied so the caller's objects are not modified.
            session_id: Session whose memory is read and updated.
            **kwargs: Passed through to ``forward``.

        Returns:
            AgentMessage: A deep copy of the response stored in memory,
            or whatever an ``after_agent`` hook returned in its place.
        """
        message = [
            AgentMessage(sender='user', content=m)
            if isinstance(m, str) else copy.deepcopy(m) for m in message
        ]
        for hook in self._hooks.values():
            # A hook may replace the messages by returning a truthy value.
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
        response_message = self.forward(
            *message, session_id=session_id, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.name,
                content=response_message,
            )
        self.update_memory(response_message, session_id=session_id)
        # Hooks and the caller get a copy, not the object kept in memory.
        response_message = copy.deepcopy(response_message)
        for hook in self._hooks.values():
            result = hook.after_agent(self, response_message, session_id)
            if result:
                response_message = result
        return response_message

    def forward(self,
                *message: AgentMessage,
                session_id=0,
                **kwargs) -> Union[AgentMessage, str]:
        """Aggregate the session memory, query the LLM and parse the reply.

        The positional ``message`` arguments are not read here; the prompt
        is built from ``self.memory.get(session_id)``.

        Returns:
            The raw LLM response when no ``output_format`` is configured,
            otherwise an ``AgentMessage`` that also carries the parsed
            response in ``formatted``.
        """
        formatted_messages = self.aggregator.aggregate(
            self.memory.get(session_id),
            self.name,
            self.output_format,
            self.template,
        )
        llm_response = self.llm.chat(formatted_messages, **kwargs)
        if not self.output_format:
            return llm_response
        parsed = self.output_format.parse_response(llm_response)
        return AgentMessage(
            sender=self.name,
            content=llm_response,
            formatted=parsed,
        )

    def __setattr__(self, __name: str, __value: Any) -> None:
        """Set an attribute and keep the sub-agent registry in sync.

        ``Agent`` values are recorded in ``self._agents`` under the
        attribute name. Rebinding such an attribute to a non-agent value
        removes the entry, so the registry never points at an agent that is
        no longer reachable through the attribute.
        """
        registry = getattr(self, '_agents', None)
        if isinstance(__value, Agent):
            if registry is None:
                registry = OrderedDict()
                super().__setattr__('_agents', registry)
            registry[__name] = __value
        elif (registry is not None and registry.get(__name) is not None
              and registry[__name] is self.__dict__.get(__name)):
            # Only drop the entry when the attribute being overwritten IS
            # the registered agent. Containers key ``_agents`` by item name,
            # which may coincide with an unrelated attribute name.
            del registry[__name]
        super().__setattr__(__name, __value)

    def state_dict(self, session_id=0):
        """Collect the saved memory of this agent and all sub-agents.

        Args:
            session_id: Session whose memory is exported.

        Returns:
            dict: Maps dotted paths ending in ``'memory'`` (``'memory'``,
            ``'child.memory'``, ...) to the saved memory, or ``[]`` when
            there is nothing to save. Agents without memory are omitted.
        """
        state_dict, stack = {}, [('', self)]
        while stack:
            prefix, node = stack.pop()
            key = prefix + 'memory'
            if node.memory is not None:
                if session_id not in node.memory.memory_map:
                    warnings.warn(f'No session id {session_id} in {key}')
                memory = node.memory.get(session_id)
                saved = memory.save() if memory else None
                state_dict[key] = saved or []
            if hasattr(node, '_agents'):
                # Pushed in reverse so the stack pops children in
                # registration order.
                for name, value in reversed(node._agents.items()):
                    stack.append((prefix + name + '.', value))
        return state_dict

    def load_state_dict(self, state_dict: Dict, session_id=0):
        """Load memories produced by ``state_dict`` into ``session_id``.

        Args:
            state_dict: Mapping of dotted memory paths to saved memory.
            session_id: Session to load into; created when missing.

        Raises:
            KeyError: If ``state_dict`` lacks a key this agent tree expects.

        Keys in ``state_dict`` that this agent tree does not use only
        trigger a warning.
        """
        # Only the keys are used. ``state_dict()`` runs with its default
        # session, so it may warn about that session being absent.
        _state_dict = self.state_dict()
        missing_keys = set(_state_dict) - set(state_dict)
        if missing_keys:
            raise KeyError(f'Missing keys: {missing_keys}')
        extra_keys = set(state_dict) - set(_state_dict)
        if extra_keys:
            warnings.warn(f'Mismatch keys which are not used: {extra_keys}')
        for key in _state_dict:
            obj = self
            # Walk every path segment except the trailing 'memory'.
            for attr in key.split('.')[:-1]:
                if isinstance(obj, AgentList):
                    assert attr.isdigit()
                    obj = obj[int(attr)]
                elif isinstance(obj, AgentDict):
                    obj = obj[attr]
                else:
                    obj = getattr(obj, attr)
            if obj.memory is not None:
                if session_id not in obj.memory.memory_map:
                    obj.memory.create_instance(session_id)
                obj.memory.memory_map[session_id].load(state_dict[key] or [])

    def register_hook(self, hook: Callable):
        """Register ``hook`` and return a handle stored under ``handle.id``."""
        handle = RemovableHandle(self._hooks)
        self._hooks[handle.id] = hook
        return handle

    def reset(self,
              session_id=0,
              keypath: Optional[str] = None,
              recursive: bool = False):
        """Reset the memory of this agent or of one of its sub-agents.

        Args:
            session_id: Session whose memory is reset.
            keypath: Dotted path of sub-agent names; when given, only that
                sub-agent is reset (non-recursively).
            recursive: Also reset every sub-agent below this one.

        Raises:
            KeyError: If a segment of ``keypath`` is not a sub-agent.
        """
        assert not (keypath and
                    recursive), 'keypath and recursive can\'t be used together'
        if keypath:
            agent = self
            for key in keypath.split('.'):
                agents = getattr(agent, '_agents', {})
                if key not in agents:
                    raise KeyError(f'No sub-agent named {key} in {agent}')
                agent = agents[key]
            agent.reset(session_id, recursive=False)
            return
        if self.memory:
            self.memory.reset(session_id=session_id)
        if recursive:
            for agent in getattr(self, '_agents', {}).values():
                agent.reset(session_id, recursive=True)

    def __repr__(self):
        """Return a nested, one-sub-agent-per-line view of the agent tree."""

        def _rcsv_repr(agent, n_indent=1):
            res = agent.__class__.__name__ + (f"(name='{agent.name}')"
                                              if agent.name else '')
            pad = n_indent * ' '
            modules = [
                f'{pad}({name}): {_rcsv_repr(child, n_indent + 1)}'
                for name, child in getattr(agent, '_agents', {}).items()
            ]
            if modules:
                res += '(\n' + '\n'.join(
                    modules) + f'\n{(n_indent - 1) * " "})'
            elif not res.endswith(')'):
                res += '()'
            return res

        return _rcsv_repr(self)
| |
|
| |
|
class AsyncAgent(Agent):
    """Asynchronous counterpart of ``Agent``.

    ``__call__`` and ``forward`` are coroutines and the LLM's ``chat`` is
    awaited. Hooks are still invoked synchronously.
    """

    async def __call__(self,
                       *message: Union[str, AgentMessage, List[AgentMessage]],
                       session_id=0,
                       **kwargs) -> AgentMessage:
        """Run one turn: hooks, memory update, awaited ``forward``, hooks.

        Args:
            *message: Incoming messages. Strings are wrapped into an
                ``AgentMessage`` sent by ``'user'``; anything else is
                deep-copied so the caller's objects are not modified.
            session_id: Session whose memory is read and updated.
            **kwargs: Passed through to ``forward``.

        Returns:
            AgentMessage: A deep copy of the response stored in memory,
            or whatever an ``after_agent`` hook returned in its place.
        """
        # ``Union`` instead of ``X | Y``: annotations are evaluated when the
        # function is defined, and ``|`` between classes needs Python 3.10+.
        message = [
            AgentMessage(sender='user', content=m)
            if isinstance(m, str) else copy.deepcopy(m) for m in message
        ]
        for hook in self._hooks.values():
            # A hook may replace the messages by returning a truthy value.
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
        response_message = await self.forward(
            *message, session_id=session_id, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.name,
                content=response_message,
            )
        self.update_memory(response_message, session_id=session_id)
        # Hooks and the caller get a copy, not the object kept in memory.
        response_message = copy.deepcopy(response_message)
        for hook in self._hooks.values():
            result = hook.after_agent(self, response_message, session_id)
            if result:
                response_message = result
        return response_message

    async def forward(self,
                      *message: AgentMessage,
                      session_id=0,
                      **kwargs) -> Union[AgentMessage, str]:
        """Aggregate the session memory, await the LLM and parse the reply.

        The positional ``message`` arguments are not read here; the prompt
        is built from ``self.memory.get(session_id)``. ``session_id`` is
        also passed on to ``self.llm.chat``.

        Returns:
            The raw LLM response when no ``output_format`` is configured,
            otherwise an ``AgentMessage`` that also carries the parsed
            response in ``formatted``.
        """
        formatted_messages = self.aggregator.aggregate(
            self.memory.get(session_id),
            self.name,
            self.output_format,
            self.template,
        )
        llm_response = await self.llm.chat(formatted_messages, session_id,
                                           **kwargs)
        if not self.output_format:
            return llm_response
        parsed = self.output_format.parse_response(llm_response)
        return AgentMessage(
            sender=self.name,
            content=llm_response,
            formatted=parsed,
        )
| |
|
| |
|
class Sequential(Agent):
    """Agent container that pipes messages through its members in order.

    Members are stored in ``self._agents`` under string names and are
    visited in insertion order; the output of one member is the input of
    the next.
    """

    def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs):
        """Build the pipeline.

        Args:
            *agents: Either the member agents themselves, or one iterable
                of agents / ``(name, agent)`` tuples, or one mapping of
                name to agent. Unnamed members are named by position.
            **kwargs: Forwarded to ``Agent.__init__``.

        Raises:
            ValueError: If no agent is provided.
        """
        super().__init__(**kwargs)
        self._agents = OrderedDict()
        if not agents:
            raise ValueError('At least one agent should be provided')
        head = agents[0]
        # A lone non-agent iterable is the collection of members itself.
        if isinstance(head, Iterable) and not isinstance(head, Agent):
            if not head:
                raise ValueError('At least one agent should be provided')
            agents = head
        for position, entry in enumerate(agents):
            if isinstance(agents, Mapping):
                label, member = entry, agents[entry]
            elif isinstance(entry, tuple):
                label, member = entry
            else:
                label, member = position, entry
            self.add_agent(label, member)

    def add_agent(self, name: str, agent: Union[Agent, AsyncAgent]):
        """Append ``agent`` to the pipeline under ``str(name)``."""
        assert isinstance(
            agent, (Agent, AsyncAgent
                    )), f'{type(agent)} is not an Agent or AsyncAgent subclass'
        self._agents[str(name)] = agent

    def forward(self,
                *message: AgentMessage,
                session_id=0,
                exit_at: Optional[int] = None,
                **kwargs) -> AgentMessage:
        """Feed ``message`` through the members one after another.

        Args:
            *message: Input for the first member.
            session_id: Passed to every member call.
            exit_at: Zero-based index of the last step to run. Defaults to
                the last member; larger values wrap around to the first
                member again.
            **kwargs: Passed to every member call.

        Returns:
            The output of the last member that ran.
        """
        assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
        steps = len(self) if exit_at is None else exit_at + 1
        # Endless round-robin over the members; ``zip`` stops it after
        # ``steps`` calls.
        round_robin = chain.from_iterable(repeat(self._agents.values()))
        for _, agent in zip(range(steps), round_robin):
            if isinstance(message, AgentMessage):
                message = (message, )
            message = agent(*message, session_id=session_id, **kwargs)
        return message

    def __getitem__(self, key):
        """Return the member named ``str(key)``; negative ints count from the end."""
        if isinstance(key, int) and key < 0:
            assert key >= -len(self), 'index out of range'
            key += len(self)
        return self._agents[str(key)]

    def __len__(self):
        """Return the number of members."""
        return len(self._agents)
| |
|
| |
|
class AsyncSequential(Sequential, AsyncAgent):
    """Asynchronous ``Sequential``: every member call is awaited."""

    async def forward(self,
                      *message: AgentMessage,
                      session_id=0,
                      exit_at: Optional[int] = None,
                      **kwargs) -> AgentMessage:
        """Feed ``message`` through the members, awaiting each in turn.

        Args:
            *message: Input for the first member.
            session_id: Passed to every member call.
            exit_at: Zero-based index of the last step to run. Defaults to
                the last member; larger values wrap around to the first
                member again.
            **kwargs: Passed to every member call.

        Returns:
            The output of the last member that ran.
        """
        assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
        steps = len(self) if exit_at is None else exit_at + 1
        # Endless round-robin over the members; ``zip`` stops it after
        # ``steps`` calls.
        round_robin = chain.from_iterable(repeat(self._agents.values()))
        for _, agent in zip(range(steps), round_robin):
            if isinstance(message, AgentMessage):
                message = (message, )
            message = await agent(*message, session_id=session_id, **kwargs)
        return message
| |
|
| |
|
class AgentContainerMixin:
    """Keep ``_agents`` in sync with the ``data`` of a list/dict container.

    When a subclass is created, its mutating container methods are wrapped
    so that, after each call, the contents of ``self.data`` are validated
    and mirrored into ``self._agents`` (keyed by the stringified index or
    key). If validation fails, ``self.data`` is restored to the snapshot
    taken before the call and the error is raised.
    """

    def __init_subclass__(cls, **kwargs):
        """Wrap the subclass's mutating methods with validation.

        Extra class keyword arguments are forwarded up the MRO.
        """
        super().__init_subclass__(**kwargs)

        def wrap_api(func):

            @wraps(func)
            def wrapped_func(self, *args, **kwargs):
                # ``data`` does not exist yet while ``__init__`` is running.
                snapshot = self.data.copy() if hasattr(self, 'data') else None

                def _restore():
                    if snapshot is None:
                        self.data.clear()
                    else:
                        self.data = snapshot

                ret = func(self, *args, **kwargs)
                is_mapping = isinstance(self.data, abc.Mapping)
                entries = (
                    self.data.items() if is_mapping else enumerate(self.data))
                agents = OrderedDict()
                for k, item in entries:
                    if is_mapping and not isinstance(k, str):
                        error = KeyError(
                            f'agent name should be a string, got {type(k)}')
                    elif isinstance(k, str) and '.' in k:
                        # '.' separates path segments in state-dict keys.
                        error = KeyError(
                            f'agent name can\'t contain ".", got {k}')
                    elif not isinstance(item, (Agent, AsyncAgent)):
                        error = TypeError(
                            f'{type(item)} is not an Agent or AsyncAgent subclass'
                        )
                    else:
                        agents[str(k)] = item
                        continue
                    _restore()
                    raise error
                self._agents = agents
                return ret

            # Marks the wrapper so subclasses of an already wrapped
            # container do not wrap the inherited method a second time.
            wrapped_func._agent_container_wrapped = True
            return wrapped_func

        for method in [
                'append', 'sort', 'reverse', 'pop', 'clear', 'update',
                'insert', 'extend', 'remove', '__init__', '__setitem__',
                '__delitem__', '__add__', '__iadd__', '__radd__', '__mul__',
                '__imul__', '__rmul__'
        ]:
            if not hasattr(cls, method):
                continue
            impl = getattr(cls, method)
            if getattr(impl, '_agent_container_wrapped', False):
                continue
            setattr(cls, method, wrap_api(impl))
| |
|
| |
|
class AgentList(Agent, UserList, AgentContainerMixin):
    """List-like container of agents that is itself an ``Agent`` node.

    The container holds no memory of its own (``memory=None``) and carries
    no name.
    """

    def __init__(
        self,
        agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None,
    ):
        """Create the container, optionally filled from ``agents``."""
        # Agent state first, then the list payload.
        Agent.__init__(self, memory=None)
        UserList.__init__(self, agents)
        # Overrides the class-name default assigned by ``Agent.__init__``.
        self.name = None
| |
|
| |
|
class AgentDict(Agent, UserDict, AgentContainerMixin):
    """Dict-like container of agents that is itself an ``Agent`` node.

    The container holds no memory of its own (``memory=None``) and carries
    no name.
    """

    def __init__(
        self,
        agents: Optional[Mapping[str, Union[Agent, AsyncAgent]]] = None,
    ):
        """Create the container, optionally filled from ``agents``."""
        # Agent state first, then the mapping payload.
        Agent.__init__(self, memory=None)
        UserDict.__init__(self, agents)
        # Overrides the class-name default assigned by ``Agent.__init__``.
        self.name = None
| |
|