Spaces:
Running
Running
import copy | |
import warnings | |
from collections import OrderedDict, UserDict, UserList, abc | |
from functools import wraps | |
from itertools import chain, repeat | |
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union | |
from lagent.agents.aggregator import DefaultAggregator | |
from lagent.hooks import Hook, RemovableHandle | |
from lagent.llms import BaseLLM | |
from lagent.memory import Memory, MemoryManager | |
from lagent.prompts.parsers import StrParser | |
from lagent.prompts.prompt_template import PromptTemplate | |
from lagent.schema import AgentMessage | |
from lagent.utils import create_object | |
class Agent:
    """Basic unit of the system.

    An agent communicates with the LLM, manages its memory, and aggregates
    the stored messages into the LLM input (optionally parsing the output).
    It can be extended with hooks that run before and after each call.
    Agents assigned as attributes are registered as sub-agents in
    ``_agents`` so that memory can be saved, loaded and reset across the
    whole agent tree.

    Args:
        llm (Union[BaseLLM, Dict]): The language model used by the agent, or
            a config dict for ``create_object``.
        template (Union[PromptTemplate, str, dict, List[dict]]): The template
            used to format the messages.
        memory (Dict): Config of the memory used by the agent. A falsy value
            disables memory.
        output_format (Optional[Dict]): The output parser used by the agent,
            or its config.
        aggregator (Dict): Config of the aggregator used by the agent.
        name (Optional[str]): The name of the agent. Defaults to the class
            name.
        description (Optional[str]): The description of the agent.
        hooks (Optional[Union[List[Dict], Dict]]): A single hook or a list of
            hooks, each given as an instance or a config dict.
    """

    def __init__(
        self,
        llm: Union[BaseLLM, Dict] = None,
        template: Union[PromptTemplate, str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Optional[Dict] = None,
        aggregator: Dict = dict(type=DefaultAggregator),
        name: Optional[str] = None,
        description: Optional[str] = None,
        hooks: Optional[Union[List[Dict], Dict]] = None,
    ):
        self.name = name or self.__class__.__name__
        self.llm: BaseLLM = create_object(llm)
        # A falsy memory config disables memory (containers pass None).
        self.memory: MemoryManager = MemoryManager(memory) if memory else None
        self.output_format: StrParser = create_object(output_format)
        self.template = template
        self.description = description
        self.aggregator: DefaultAggregator = create_object(aggregator)
        self._hooks: Dict[int, Hook] = OrderedDict()
        if hooks:
            # A single hook (config dict or instance) is accepted as well as
            # a list; iterating a bare dict would otherwise yield its keys.
            if isinstance(hooks, (abc.Mapping, Hook)):
                hooks = [hooks]
            for hook in hooks:
                self.register_hook(create_object(hook))

    def update_memory(self, message, session_id=0):
        """Add ``message`` to the memory of ``session_id`` if memory is on."""
        if self.memory:
            self.memory.add(message, session_id=session_id)

    def __call__(
        self,
        *message: Union[str, AgentMessage, List[AgentMessage]],
        session_id=0,
        **kwargs,
    ) -> AgentMessage:
        """Run one turn of the agent.

        Strings are wrapped as user messages. The inputs are passed through
        the ``before_agent`` hooks and stored in memory, then ``forward`` is
        called. Its result (wrapped in an ``AgentMessage`` if needed) is
        stored in memory and passed through the ``after_agent`` hooks. A
        truthy hook return value replaces the message being processed.

        Returns:
            AgentMessage: The response message.
        """
        # Inputs are deep-copied so hooks and memory do not share objects
        # with the caller.
        message = [
            AgentMessage(sender='user', content=m)
            if isinstance(m, str) else copy.deepcopy(m) for m in message
        ]
        for hook in self._hooks.values():
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
        response_message = self.forward(
            *message, session_id=session_id, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.name,
                content=response_message,
            )
        self.update_memory(response_message, session_id=session_id)
        # Hooks and the caller get a copy of the stored response.
        response_message = copy.deepcopy(response_message)
        for hook in self._hooks.values():
            result = hook.after_agent(self, response_message, session_id)
            if result:
                response_message = result
        return response_message

    def forward(self,
                *message: AgentMessage,
                session_id=0,
                **kwargs) -> Union[AgentMessage, str]:
        """Query the LLM with the aggregated memory of ``session_id``.

        ``message`` itself is not read: when called through ``__call__`` it
        has already been added to memory, which is what the aggregator
        formats. Memory must therefore be enabled.

        Returns:
            An ``AgentMessage`` holding the raw response and its parsed form
            when ``output_format`` is set, otherwise the raw LLM response.
        """
        formatted_messages = self.aggregator.aggregate(
            self.memory.get(session_id),
            self.name,
            self.output_format,
            self.template,
        )
        llm_response = self.llm.chat(formatted_messages, **kwargs)
        if self.output_format:
            parsed = self.output_format.parse_response(llm_response)
            return AgentMessage(
                sender=self.name,
                content=llm_response,
                formatted=parsed,
            )
        return llm_response

    def _drop_stale_agent(self, name: str) -> None:
        """Unregister sub-agent ``name`` if it is the current attribute value.

        The identity check avoids removing container entries that merely
        share a name with an unrelated attribute.
        """
        _agents = getattr(self, '_agents', None)
        if (isinstance(_agents, dict) and name in _agents
                and _agents[name] is getattr(self, name, None)):
            del _agents[name]

    def __setattr__(self, __name: str, __value: Any) -> None:
        # Agents assigned as attributes are tracked as sub-agents; replacing
        # such an attribute with a non-agent unregisters it again.
        if isinstance(__value, Agent):
            _agents = getattr(self, '_agents', OrderedDict())
            _agents[__name] = __value
            super().__setattr__('_agents', _agents)
        else:
            self._drop_stale_agent(__name)
        super().__setattr__(__name, __value)

    def __delattr__(self, __name: str) -> None:
        # Deleting a sub-agent attribute also unregisters it.
        self._drop_stale_agent(__name)
        super().__delattr__(__name)

    def _iter_memory_agents(self):
        """Yield ``(key, agent)`` for this agent and every sub-agent with memory.

        Sub-agents are discovered through ``_agents``, which covers both
        agents assigned as attributes and items registered directly in
        ``_agents`` by container subclasses. ``key`` is the dotted path to
        the agent followed by ``'memory'``. The traversal is depth-first, in
        registration order.
        """
        stack = [('', self)]
        while stack:
            prefix, node = stack.pop()
            if node.memory is not None:
                yield prefix + 'memory', node
            for name, child in reversed(getattr(node, '_agents', {}).items()):
                stack.append((prefix + name + '.', child))

    def state_dict(self, session_id=0):
        """Collect the memory of ``session_id`` from the whole agent tree.

        Returns:
            Dict: Maps keys such as ``'memory'`` or ``'sub.memory'`` to what
            ``save()`` returns for that memory (``[]`` when the memory is
            missing or empty).
        """
        state_dict = {}
        for key, node in self._iter_memory_agents():
            if session_id not in node.memory.memory_map:
                warnings.warn(f'No session id {session_id} in {key}')
            memory = node.memory.get(session_id)
            state_dict[key] = memory and memory.save() or []
        return state_dict

    def load_state_dict(self, state_dict: Dict, session_id=0):
        """Load memory produced by ``state_dict`` into session ``session_id``.

        Agents are located with the same traversal ``state_dict`` uses, so
        every key it emits (including ones for children stored only in
        ``_agents``) can be resolved.

        Raises:
            KeyError: If a memory key of this agent tree is missing from
                ``state_dict``. Unused extra keys only trigger a warning.
        """
        nodes = OrderedDict(self._iter_memory_agents())
        missing_keys = set(nodes) - set(state_dict)
        if missing_keys:
            raise KeyError(f'Missing keys: {missing_keys}')
        extra_keys = set(state_dict) - set(nodes)
        if extra_keys:
            warnings.warn(f'Mismatch keys which are not used: {extra_keys}')
        for key, node in nodes.items():
            if session_id not in node.memory.memory_map:
                node.memory.create_instance(session_id)
            node.memory.memory_map[session_id].load(state_dict[key] or [])

    def register_hook(self, hook: Callable):
        """Register ``hook`` and return the ``RemovableHandle`` created for it."""
        handle = RemovableHandle(self._hooks)
        self._hooks[handle.id] = hook
        return handle

    def reset(self,
              session_id=0,
              keypath: Optional[str] = None,
              recursive: bool = False):
        """Reset the memory of ``session_id``.

        Args:
            session_id: The session to reset.
            keypath (Optional[str]): Dotted path of a sub-agent (e.g.
                ``'a.b'``) to reset instead of this agent.
            recursive (bool): Also reset all sub-agents. Cannot be combined
                with ``keypath``.

        Raises:
            KeyError: If ``keypath`` names a sub-agent that does not exist.
        """
        assert not (keypath and
                    recursive), 'keypath and recursive can\'t be used together'
        if keypath:
            keys, agent = keypath.split('.'), self
            for key in keys:
                agents = getattr(agent, '_agents', {})
                if key not in agents:
                    raise KeyError(f'No sub-agent named {key} in {agent}')
                agent = agents[key]
            agent.reset(session_id, recursive=False)
        else:
            if self.memory:
                self.memory.reset(session_id=session_id)
            if recursive:
                for agent in getattr(self, '_agents', {}).values():
                    agent.reset(session_id, recursive=True)

    def __repr__(self):
        """Nested, indented representation of the agent tree."""

        def _rcsv_repr(agent, n_indent=1):
            res = agent.__class__.__name__ + (f"(name='{agent.name}')"
                                              if agent.name else '')
            modules = [
                f"{n_indent * ' '}({name}): {_rcsv_repr(child, n_indent + 1)}"
                for name, child in getattr(agent, '_agents', {}).items()
            ]
            if modules:
                res += '(\n' + '\n'.join(
                    modules) + f'\n{(n_indent - 1) * " "})'
            elif not res.endswith(')'):
                res += '()'
            return res

        return _rcsv_repr(self)
class AsyncAgent(Agent):
    """Asynchronous variant of :class:`Agent`.

    ``__call__`` and ``forward`` are coroutines and the LLM's ``chat`` is
    awaited. Memory updates and hooks run synchronously, exactly as in
    :class:`Agent`.
    """

    async def __call__(self,
                       *message: Union[str, AgentMessage, List[AgentMessage]],
                       session_id=0,
                       **kwargs) -> AgentMessage:
        """Run one turn of the agent asynchronously.

        Strings are wrapped as user messages. The inputs pass through the
        ``before_agent`` hooks and are stored in memory, then ``forward`` is
        awaited. Its result (wrapped in an ``AgentMessage`` if needed) is
        stored and passed through the ``after_agent`` hooks. A truthy hook
        return value replaces the message being processed.

        Returns:
            AgentMessage: The response message.
        """
        # Inputs are deep-copied so hooks and memory do not share objects
        # with the caller.
        message = [
            AgentMessage(sender='user', content=m)
            if isinstance(m, str) else copy.deepcopy(m) for m in message
        ]
        for hook in self._hooks.values():
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
        response_message = await self.forward(
            *message, session_id=session_id, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.name,
                content=response_message,
            )
        self.update_memory(response_message, session_id=session_id)
        # Hooks and the caller get a copy of the stored response.
        response_message = copy.deepcopy(response_message)
        for hook in self._hooks.values():
            result = hook.after_agent(self, response_message, session_id)
            if result:
                response_message = result
        return response_message

    async def forward(self,
                      *message: AgentMessage,
                      session_id=0,
                      **kwargs) -> Union[AgentMessage, str]:
        """Query the LLM with the aggregated memory of ``session_id``.

        ``message`` itself is not read; the aggregator formats the memory,
        to which ``__call__`` has already added the inputs.

        Returns:
            An ``AgentMessage`` holding the raw response and its parsed form
            when ``output_format`` is set, otherwise the raw LLM response.
        """
        formatted_messages = self.aggregator.aggregate(
            self.memory.get(session_id),
            self.name,
            self.output_format,
            self.template,
        )
        # NOTE(review): unlike Agent.forward, session_id is also passed to
        # the LLM (positionally); confirm it matches the async chat signature.
        llm_response = await self.llm.chat(formatted_messages, session_id,
                                           **kwargs)
        if self.output_format:
            parsed = self.output_format.parse_response(llm_response)
            return AgentMessage(
                sender=self.name,
                content=llm_response,
                formatted=parsed,
            )
        return llm_response
class Sequential(Agent):
    """Agent container that forwards messages to each agent in the order
    they are added; each agent's response is the next agent's input.

    Args:
        *agents: The sub-agents, either as separate positional arguments
            (optionally as ``(name, agent)`` tuples) or as a single iterable:
            a sequence of agents or ``(name, agent)`` pairs, or a mapping
            from name to agent. Unnamed agents are keyed by position.
        **kwargs: Forwarded to :class:`Agent`.

    Raises:
        ValueError: If no agent is provided.
    """

    def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs):
        super().__init__(**kwargs)
        self._agents = OrderedDict()
        if not agents:
            raise ValueError('At least one agent should be provided')
        # Only a *sole* non-agent iterable is treated as the collection of
        # agents. Unwrapping agents[0] unconditionally dropped any further
        # positional agents and mistook varargs ``(name, agent)`` tuples for
        # a container.
        if (len(agents) == 1 and isinstance(agents[0], Iterable)
                and not isinstance(agents[0], Agent)):
            agents = agents[0]
        for key, agent in enumerate(agents):
            if isinstance(agents, Mapping):
                key, agent = agent, agents[agent]
            elif isinstance(agent, tuple):
                key, agent = agent
            self.add_agent(key, agent)
        # Checked after collection so empty generators are rejected too; an
        # empty container would make ``forward`` loop forever.
        if not self._agents:
            raise ValueError('At least one agent should be provided')

    def add_agent(self, name: str, agent: Union[Agent, AsyncAgent]):
        """Register ``agent`` under ``str(name)``."""
        assert isinstance(
            agent, (Agent, AsyncAgent
                    )), f'{type(agent)} is not an Agent or AsyncAgent subclass'
        self._agents[str(name)] = agent

    def forward(self,
                *message: AgentMessage,
                session_id=0,
                exit_at: Optional[int] = None,
                **kwargs) -> AgentMessage:
        """Pipe ``message`` through the sub-agents.

        Args:
            exit_at (Optional[int]): Number of agent calls minus one. Defaults
                to ``len(self) - 1`` (one pass). Larger values cycle through
                the agents again from the start.
        """
        assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
        if exit_at is None:
            exit_at = len(self) - 1
        iterator = chain.from_iterable(repeat(self._agents.values()))
        for _ in range(exit_at + 1):
            agent = next(iterator)
            if isinstance(message, AgentMessage):
                message = (message, )
            message = agent(*message, session_id=session_id, **kwargs)
        return message

    def __getitem__(self, key):
        """Return the agent stored under ``str(key)``; negative ints count
        from the end."""
        if isinstance(key, int) and key < 0:
            assert key >= -len(self), 'index out of range'
            key = len(self) + key
        return self._agents[str(key)]

    def __len__(self):
        return len(self._agents)
class AsyncSequential(Sequential, AsyncAgent):
    """Asynchronous :class:`Sequential`.

    Sub-agents may be asynchronous or synchronous (``add_agent`` accepts
    both); a sub-agent's result is awaited only when it is awaitable.
    """

    async def forward(self,
                      *message: AgentMessage,
                      session_id=0,
                      exit_at: Optional[int] = None,
                      **kwargs) -> AgentMessage:
        """Pipe ``message`` through the sub-agents.

        Args:
            exit_at (Optional[int]): Number of agent calls minus one. Defaults
                to ``len(self) - 1`` (one pass). Larger values cycle through
                the agents again from the start.
        """
        import inspect

        assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
        if exit_at is None:
            exit_at = len(self) - 1
        iterator = chain.from_iterable(repeat(self._agents.values()))
        for _ in range(exit_at + 1):
            agent = next(iterator)
            if isinstance(message, AgentMessage):
                message = (message, )
            message = agent(*message, session_id=session_id, **kwargs)
            # Synchronous sub-agents return their response directly;
            # awaiting it unconditionally raised TypeError.
            if inspect.isawaitable(message):
                message = await message
        return message
class AgentContainerMixin:
    """Keep ``_agents`` in sync with the ``data`` of a list/dict container.

    When a subclass is created, each listed container API present on it is
    wrapped. After the call, the wrapper validates the items and mirrors
    them into an ``OrderedDict`` assigned to ``self._agents`` under
    ``str(key)``. Items must be ``Agent``/``AsyncAgent`` instances; mapping
    keys must be strings without ``'.'``. On a validation failure, ``data``
    is rolled back to its pre-call contents and the error is raised.
    """

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)

        def wrap_api(func):
            @wraps(func)
            def wrapped_func(self, *args, **kwargs):
                # Snapshot for rollback; None when ``data`` does not exist
                # yet (e.g. inside ``__init__``).
                backup = self.data.copy() if hasattr(self, 'data') else None

                def _rollback():
                    if backup is None:
                        self.data.clear()
                    else:
                        self.data = backup

                ret = func(self, *args, **kwargs)
                is_mapping = isinstance(self.data, abc.Mapping)
                items = self.data.items() if is_mapping else enumerate(
                    self.data)
                agents = OrderedDict()
                for k, item in items:
                    if is_mapping and not isinstance(k, str):
                        _rollback()
                        raise KeyError(
                            f'agent name should be a string, got {type(k)}')
                    if isinstance(k, str) and '.' in k:
                        _rollback()
                        raise KeyError(
                            f'agent name can\'t contain ".", got {k}')
                    if not isinstance(item, (Agent, AsyncAgent)):
                        _rollback()
                        raise TypeError(
                            f'{type(item)} is not an Agent or AsyncAgent subclass'
                        )
                    agents[str(k)] = item
                self._agents = agents
                return ret

            # Marker so subclasses of a container do not wrap the inherited,
            # already-wrapped methods a second time.
            wrapped_func._agent_container_api = True
            return wrapped_func

        # Container APIs that may change the contents (or build a new
        # container from them).
        for method in [
                'append', 'sort', 'reverse', 'pop', 'clear', 'update',
                'insert', 'extend', 'remove', '__init__', '__setitem__',
                '__delitem__', '__add__', '__iadd__', '__radd__', '__mul__',
                '__imul__', '__rmul__'
        ]:
            if not hasattr(cls, method):
                continue
            func = getattr(cls, method)
            if getattr(func, '_agent_container_api', False):
                continue
            setattr(cls, method, wrap_api(func))
class AgentList(Agent, UserList, AgentContainerMixin):
    """List-like container of agents.

    Behaves like a ``UserList``. ``AgentContainerMixin`` wraps the mutating
    list methods so that every item is checked to be an ``Agent`` or
    ``AsyncAgent`` and is mirrored into ``_agents`` under its index (as a
    string).

    Args:
        agents (Optional[Iterable[Union[Agent, AsyncAgent]]]): Initial items.
    """

    def __init__(self,
                 agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None):
        # The container itself keeps no memory; only its items do.
        Agent.__init__(self, memory=None)
        UserList.__init__(self, agents)
        # Replace the default class-name name; presumably so ``repr`` shows
        # just the type name.
        self.name = None
class AgentDict(Agent, UserDict, AgentContainerMixin):
    """Dict-like container of agents keyed by name.

    Behaves like a ``UserDict``. ``AgentContainerMixin`` wraps the mutating
    dict methods so that keys must be strings without ``'.'``, values must be
    ``Agent`` or ``AsyncAgent`` instances, and the items are mirrored into
    ``_agents``.

    Args:
        agents (Optional[Mapping[str, Union[Agent, AsyncAgent]]]): Initial
            items.
    """

    def __init__(self,
                 agents: Optional[Mapping[str, Union[Agent,
                                                     AsyncAgent]]] = None):
        # The container itself keeps no memory; only its items do.
        Agent.__init__(self, memory=None)
        UserDict.__init__(self, agents)
        # Replace the default class-name name; presumably so ``repr`` shows
        # just the type name.
        self.name = None