"""ReAct-style agents that alternate between tool selection and execution.

The module defines a synchronous (:class:`ReAct`) and an asynchronous
(:class:`AsyncReAct`) agent. Both wrap an inner "select" agent, which asks the
LLM what to do next, and an action executor, which runs the chosen tool.
"""
import json
from typing import Callable, Dict, List, Union

from pydantic import BaseModel, Field

from lagent.actions import ActionExecutor, AsyncActionExecutor, BaseAction
from lagent.agents.agent import Agent, AsyncAgent
from lagent.agents.aggregator import DefaultAggregator
from lagent.hooks import ActionPreprocessor
from lagent.llms import BaseLLM
from lagent.memory import Memory
from lagent.prompts.parsers.json_parser import JSONParser
from lagent.prompts.prompt_template import PromptTemplate
from lagent.schema import AgentMessage
from lagent.utils import create_object

# NOTE(review): both templates appear on a single line in the reviewed copy of
# this file; confirm whether line breaks between the sections were intended.
select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括: {action_info} {output_format} 开始!"""

output_format_template = """如果使用工具请遵循以下格式回复: {function_format} 如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复 {finish_format}"""


def _render_select_prompt(template, output_format, executor):
    """Build the prompt for the select agent and resolve the output parser.

    Args:
        template: Prompt template exposing ``format(action_info=...,
            output_format=...)``, or ``None`` to fall back to
            ``select_action_template``.
        output_format: Parser instance exposing ``format_instruction()``, or a
            config dict that ``create_object`` turns into one.
        executor: Action executor whose ``description()`` is JSON-serialised
            into the prompt.

    Returns:
        A ``(rendered_template, parser)`` tuple.
    """
    if template is None:
        # The signature advertises ``None`` as the default; without this
        # fallback ``template.format`` would raise AttributeError.
        template = select_action_template
    if isinstance(output_format, dict):
        # The default value is a config dict, which has no
        # ``format_instruction`` method; instantiate it before use.
        output_format = create_object(output_format)
    rendered = template.format(
        action_info=json.dumps(executor.description()),
        output_format=output_format.format_instruction())
    return rendered, output_format


class ReAct(Agent):
    """Synchronous agent looping over "select an action" and "run it".

    Args:
        llm: Model (or its config) handed to the inner select agent.
        actions: Action(s) handed to the ``ActionExecutor``.
        template: Prompt template for the select agent; it is formatted with
            ``action_info`` and ``output_format``. ``None`` selects
            ``select_action_template``.
        memory: Memory config handed to the select agent.
        output_format: Parser (or its config) that supplies the format
            instruction and is handed to the select agent.
        aggregator: Aggregator config handed to the select agent.
        hooks: Hook configs handed to both the executor and the select agent.
        finish_condition: Predicate evaluated on every select-agent reply;
            when it returns true the reply is returned from ``forward``.
        max_turn: Maximum number of select/execute iterations.
        **kwargs: Forwarded to ``Agent.__init__``.
    """

    # NOTE: the dict/list defaults below are shared between calls; this
    # constructor only reads them and never mutates them in place.
    # NOTE(review): the default ``finish_condition`` runs ``in`` against both
    # ``m.content`` and ``m.formatted`` -- presumably neither is ever None;
    # confirm against the parser in use.
    def __init__(self,
                 llm: Union[BaseLLM, Dict],
                 actions: Union[BaseAction, List[BaseAction]],
                 template: Union[PromptTemplate, str] = None,
                 memory: Dict = dict(type=Memory),
                 output_format: Dict = dict(type=JSONParser),
                 aggregator: Dict = dict(type=DefaultAggregator),
                 hooks: List = [dict(type=ActionPreprocessor)],
                 finish_condition: Callable[[AgentMessage], bool] = lambda m:
                 'conclusion' in m.content or 'conclusion' in m.formatted,
                 max_turn: int = 5,
                 **kwargs):
        self.max_turn = max_turn
        self.finish_condition = finish_condition
        self.actions: ActionExecutor = create_object(
            dict(
                type=ActionExecutor,
                actions=actions,
                hooks=hooks,
            ))
        prompt, output_format = _render_select_prompt(
            template, output_format, self.actions)
        self.select_agent = create_object(
            dict(
                type=Agent,
                llm=llm,
                template=prompt,
                output_format=output_format,
                memory=memory,
                aggregator=aggregator,
                hooks=hooks,
            ))
        # Kept after the attribute assignments, as in the original ordering.
        super().__init__(**kwargs)

    def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
        """Run up to ``max_turn`` select/execute rounds.

        Returns the select agent's reply as soon as ``finish_condition``
        accepts it; otherwise returns the executor's last output (or the
        input message when ``max_turn`` is 0).
        """
        for _ in range(self.max_turn):
            message = self.select_agent(message)
            if self.finish_condition(message):
                return message
            message = self.actions(message)
        return message


class AsyncReAct(AsyncAgent):
    """Asynchronous counterpart of :class:`ReAct`.

    Takes the same arguments as :class:`ReAct` but builds an
    ``AsyncActionExecutor`` and an ``AsyncAgent`` and awaits both in
    ``forward``.
    """

    # NOTE: the dict/list defaults below are shared between calls; this
    # constructor only reads them and never mutates them in place.
    def __init__(self,
                 llm: Union[BaseLLM, Dict],
                 actions: Union[BaseAction, List[BaseAction]],
                 template: Union[PromptTemplate, str] = None,
                 memory: Dict = dict(type=Memory),
                 output_format: Dict = dict(type=JSONParser),
                 aggregator: Dict = dict(type=DefaultAggregator),
                 hooks: List = [dict(type=ActionPreprocessor)],
                 finish_condition: Callable[[AgentMessage], bool] = lambda m:
                 'conclusion' in m.content or 'conclusion' in m.formatted,
                 max_turn: int = 5,
                 **kwargs):
        self.max_turn = max_turn
        self.finish_condition = finish_condition
        self.actions: AsyncActionExecutor = create_object(
            dict(
                type=AsyncActionExecutor,
                actions=actions,
                hooks=hooks,
            ))
        prompt, output_format = _render_select_prompt(
            template, output_format, self.actions)
        self.select_agent = create_object(
            dict(
                type=AsyncAgent,
                llm=llm,
                template=prompt,
                output_format=output_format,
                memory=memory,
                aggregator=aggregator,
                hooks=hooks,
            ))
        # Kept after the attribute assignments, as in the original ordering.
        super().__init__(**kwargs)

    async def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
        """Await up to ``max_turn`` select/execute rounds.

        Returns the select agent's reply as soon as ``finish_condition``
        accepts it; otherwise returns the executor's last output (or the
        input message when ``max_turn`` is 0).
        """
        for _ in range(self.max_turn):
            message = await self.select_agent(message)
            if self.finish_condition(message):
                return message
            message = await self.actions(message)
        return message


if __name__ == '__main__':
    # Demo: ask a GPT-backed ReAct agent two Python arithmetic questions.
    from lagent.llms import GPTAPI

    class ActionCall(BaseModel):
        # Name of the function to call and its arguments.
        name: str = Field(description='调用的函数名称')
        parameters: Dict = Field(description='调用函数的参数')

    class ActionFormat(BaseModel):
        # Reply schema used when the model wants to call a tool.
        thought_process: str = Field(
            description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。')
        action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。')

    class FinishFormat(BaseModel):
        # Reply schema used for the final answer; its ``conclusion`` key is
        # what the default ``finish_condition`` looks for.
        thought_process: str = Field(
            description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。')
        conclusion: str = Field(description='总结当前的搜索结果,回答问题。')

    prompt_template = PromptTemplate(select_action_template)
    output_format = JSONParser(
        output_format_template,
        function_format=ActionFormat,
        finish_format=FinishFormat)

    llm = dict(
        type=GPTAPI,
        model_type='gpt-4o-2024-05-13',
        key=None,
        max_new_tokens=4096,
        proxies=dict(),
        retry=1000)

    agent = ReAct(
        llm=llm,
        template=prompt_template,
        output_format=output_format,
        aggregator=dict(type='DefaultAggregator'),
        actions=[dict(type='PythonInterpreter')],
    )
    response = agent(
        AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5'))
    print(response)
    response = agent(AgentMessage(sender='user', content=' 2 ** 5 呢'))
    print(response)