from typing import Dict, Iterator, List, Optional, Union
from .agent_types import AgentType
from .llm import LLM
from .output_parser import OutputParser, get_output_parser
from .output_wrapper import display
from .prompt import PromptGenerator, get_prompt_generator
from .retrieve import KnowledgeRetrieval, ToolRetrieval
from .tools import TOOL_INFO_LIST
class AgentExecutor:
def custom_run_init(self,
task: str,
remote: bool = False,
print_info: bool = False,
                        append_files: Optional[list] = None) -> tuple:
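        """Stepwise variant of `run`: perform only the setup phase (tool and
        knowledge retrieval, prompt initialization) and return the loop state
        as a tuple, leaving the LLM/tool loop to the caller via
        `custom_gene_prompt` and `custom_parse_llm`.
        """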
tool_list = self.retrieve_tools(task)
knowledge_list = self.get_knowledge(task)
        self.prompt_generator.init_prompt(
            task, tool_list, knowledge_list, append_files=append_files or [])
function_list = self.prompt_generator.get_function_list(tool_list)
llm_result, exec_result = '', ''
idx = 0
final_res = []
return tool_list, knowledge_list, function_list, llm_result, exec_result, idx, final_res, remote, print_info
def custom_gene_prompt(self, llm_result, exec_result, idx):
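        """Advance the round counter and build the next LLM input from the
        previous LLM output and tool execution result.
        """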
idx += 1
# generate prompt and call llm
llm_artifacts = self.prompt_generator.generate(
llm_result, exec_result)
return llm_artifacts, idx
def custom_parse_llm(self, llm_artifacts, llm_result, idx, final_res, remote, print_info):
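        """Parse one LLM response and execute the named tool, if any.
        Returns a single-element list whose dict is keyed by the outcome:
        'end_res' (final result), 'exec_result' (tool output or an error
        message), or 'no_stop' (updated loop state for the next round).
        """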
if print_info:
print(f'|LLM inputs in round {idx}: {llm_artifacts}')
# parse and get tool name and arguments
try:
action, action_args = self.output_parser.parse_response(
llm_result)
except ValueError as e:
return [{'exec_result': f'{e}'}]
if action is None:
            # in chat mode, the final result of the last instruction should be appended to the prompt history
_ = self.prompt_generator.generate(llm_result, '')
# for summarize
# display(llm_result, {}, idx, self.agent_type)
return [{'end_res': final_res}]
if action in self.available_tool_list:
action_args = self.parse_action_args(action_args)
tool = self.tool_list[action]
# TODO @wenmeng.zwm remove this hack logic for image generation
if action == 'image_gen' and self.seed:
action_args['seed'] = self.seed
try:
exec_result = tool(**action_args, remote=remote)
if print_info:
print(f'|exec_result: {exec_result}')
# parse exec result and store result to agent state
final_res.append(exec_result)
self.parse_exec_result(exec_result)
except Exception as e:
exec_result = f'Action call error: {action}: {action_args}. \n Error message: {e}'
return [{'exec_result': exec_result}]
else:
exec_result = f"Unknown action: '{action}'. "
return [{'exec_result': exec_result}]
# display result
# display(llm_result, exec_result, idx, self.agent_type)
return [{'no_stop': {'llm_result': llm_result, 'exec_result': exec_result, 'idx': idx, 'final_res': final_res}}]
    def __init__(self,
                 llm: LLM,
                 tool_cfg: Optional[Dict] = None,
                 agent_type: AgentType = AgentType.DEFAULT,
                 additional_tool_list: Optional[Dict] = None,
                 prompt_generator: Optional[PromptGenerator] = None,
                 output_parser: Optional[OutputParser] = None,
                 tool_retrieval: Optional[Union[bool, ToolRetrieval]] = True,
                 knowledge_retrieval: Optional[KnowledgeRetrieval] = None):
"""
the core class of ms agent. It is responsible for the interaction between user, llm and tools,
and return the execution result to user.
Args:
llm (LLM): llm model, can be load from local or a remote server.
tool_cfg (Optional[Dict]): cfg of default tools
agent_type (AgentType, optional): agent type. Defaults to AgentType.DEFAULT, decide which type of agent
reasoning type to use
additional_tool_list (Optional[Dict], optional): user-defined additional tool list. Defaults to {}.
prompt_generator (Optional[PromptGenerator], optional): this module is responsible for generating prompt
according to interaction result. Defaults to use MSPromptGenerator.
output_parser (Optional[OutputParser], optional): this module is responsible for parsing output of llm
to executable actions. Defaults to use MsOutputParser.
tool_retrieval (Optional[Union[bool, ToolRetrieval]], optional): Retrieve related tools by input task,
since most of the tools may be useless for LLM in specific task.
If it is bool type and is True, will use default tool_retrieval. Defaults to True.
knowledge_retrieval (Optional[KnowledgeRetrieval], optional): If user want to use extra knowledge,
this component can be used to retrieve related knowledge. Defaults to None.
"""
self.llm = llm
self.agent_type = agent_type
self.llm.set_agent_type(agent_type)
self.prompt_generator = prompt_generator or get_prompt_generator(
agent_type)
self.output_parser = output_parser or get_output_parser(agent_type)
self._init_tools(tool_cfg, additional_tool_list)
if isinstance(tool_retrieval, bool) and tool_retrieval:
tool_retrieval = ToolRetrieval()
self.tool_retrieval = tool_retrieval
if self.tool_retrieval:
self.tool_retrieval.construct(
[str(t) for t in self.tool_list.values()])
self.knowledge_retrieval = knowledge_retrieval
self.reset()
self.seed = None
    def _init_tools(self,
                    tool_cfg: Optional[Dict] = None,
                    additional_tool_list: Optional[Dict] = None):
"""init tool list of agent. We provide a default tool list, which is initialized by a cfg file.
user can also provide user-defined tools by additional_tool_list.
The key of additional_tool_list is tool name, and the value is corresponding object.
Args:
tool_cfg (Dict): default tool cfg.
additional_tool_list (Dict, optional): user-defined tools. Defaults to {}.
"""
        tool_cfg = tool_cfg or {}
        additional_tool_list = additional_tool_list or {}
        self.tool_list = {}
        tool_info_list = {**TOOL_INFO_LIST, **additional_tool_list}
        from . import tools as tools_module
for tool_name in tool_cfg.keys():
if tool_cfg[tool_name].get('use', False):
                assert tool_name in tool_info_list, f'Invalid tool name: {tool_name}, ' \
                    f'available ones are: {list(tool_info_list.keys())}'
tool_class_name = tool_info_list[tool_name]
tool_class = getattr(tools_module, tool_class_name)
tool_name = tool_class.name
self.tool_list[tool_name] = tool_class(tool_cfg)
        self.tool_list = {**self.tool_list, **additional_tool_list}
        self.set_available_tools(self.tool_list.keys())
def set_available_tools(self, available_tool_list):
# TODO @wenmeng.zwm refine tool init
for t in available_tool_list:
if t not in self.tool_list:
                raise ValueError(
                    f'Unsupported tool found: {t}, please check; valid ones: {list(self.tool_list.keys())}'
                )
self.available_tool_list = {
k: self.tool_list[k]
for k in available_tool_list
}
    def retrieve_tools(self, query: str) -> List:
        """Retrieve the tools related to the query and set them as available.
        Args:
            query (str): the query.
        """
        if self.tool_retrieval:
            retrieve_tools = self.tool_retrieval.retrieve(query)
            self.set_available_tools(available_tool_list=retrieve_tools.keys())
        return list(self.available_tool_list.values())
    def get_knowledge(self, query: str) -> List[str]:
        """Retrieve knowledge related to the query.
        Args:
            query (str): the query.
        """
return self.knowledge_retrieval.retrieve(
query) if self.knowledge_retrieval else []
    def run(self,
            task: str,
            remote: bool = False,
            print_info: bool = False,
            append_files: Optional[list] = None) -> List[Dict]:
""" use llm and tools to execute task given by user
Args:
task (str): concrete task
remote (bool, optional): whether to execute tool in remote mode. Defaults to False.
print_info (bool, optional): whether to print prompt info. Defaults to False.
Returns:
List[Dict]: execute result. One task may need to interact with llm multiple times,
so a list of dict is returned. Each dict contains the result of one interaction.
"""
# retrieve tools
tool_list = self.retrieve_tools(task)
knowledge_list = self.get_knowledge(task)
        self.prompt_generator.init_prompt(
            task, tool_list, knowledge_list, append_files=append_files or [])
function_list = self.prompt_generator.get_function_list(tool_list)
llm_result, exec_result = '', ''
idx = 0
final_res = []
while True:
idx += 1
# generate prompt and call llm
llm_artifacts = self.prompt_generator.generate(
llm_result, exec_result)
try:
llm_result = self.llm.generate(llm_artifacts, function_list)
except RuntimeError as e:
return [{'exec_result': str(e)}]
if print_info:
print(f'|LLM inputs in round {idx}: {llm_artifacts}')
# parse and get tool name and arguments
try:
action, action_args = self.output_parser.parse_response(
llm_result)
except ValueError as e:
return [{'exec_result': f'{e}'}]
if action is None:
                # in chat mode, the final result of the last instruction should be appended to the prompt history
_ = self.prompt_generator.generate(llm_result, '')
# for summarize
display(llm_result, {}, idx, self.agent_type)
return final_res
if action in self.available_tool_list:
action_args = self.parse_action_args(action_args)
tool = self.tool_list[action]
# TODO @wenmeng.zwm remove this hack logic for image generation
if action == 'image_gen' and self.seed:
action_args['seed'] = self.seed
try:
exec_result = tool(**action_args, remote=remote)
if print_info:
print(f'|exec_result: {exec_result}')
# parse exec result and store result to agent state
final_res.append(exec_result)
self.parse_exec_result(exec_result)
except Exception as e:
exec_result = f'Action call error: {action}: {action_args}. \n Error message: {e}'
return [{'exec_result': exec_result}]
else:
exec_result = f"Unknown action: '{action}'. "
return [{'exec_result': exec_result}]
# display result
display(llm_result, exec_result, idx, self.agent_type)
    def stream_run(self,
                   task: str,
                   remote: bool = True,
                   print_info: bool = False,
                   append_files: Optional[list] = None) -> Iterator[Dict]:
"""this is a stream version of run, which can be used in scenario like gradio.
It will yield the result of each interaction, so that the caller can display the result
Args:
task (str): concrete task
remote (bool, optional): whether to execute tool in remote mode. Defaults to True.
print_info (bool, optional): whether to print prompt info. Defaults to False.
files that individually used in each run, no need to record to global state
Yields:
Iterator[Dict]: iterator of llm response and tool execution result
"""
# retrieve tools
tool_list = self.retrieve_tools(task)
knowledge_list = self.get_knowledge(task)
        self.prompt_generator.init_prompt(
            task,
            tool_list,
            knowledge_list,
            append_files=append_files or [],
        )
function_list = self.prompt_generator.get_function_list(tool_list)
llm_result, exec_result = '', ''
idx = 0
while True:
idx += 1
llm_artifacts = self.prompt_generator.generate(
llm_result, exec_result)
if print_info:
print(f'|LLM inputs in round {idx}:\n{llm_artifacts}')
llm_result = ''
try:
for s in self.llm.stream_generate(llm_artifacts,
function_list):
llm_result += s
yield {'llm_text': s}
except RuntimeError:
s = self.llm.generate(llm_artifacts)
llm_result += s
yield {'llm_text': s}
except Exception as e:
yield {'llm_text': str(e)}
# parse and get tool name and arguments
try:
action, action_args = self.output_parser.parse_response(
llm_result)
except ValueError as e:
yield {'exec_result': f'{e}'}
return
if action is None:
                # in chat mode, the final result of the last instruction should be appended to the prompt history
_ = self.prompt_generator.generate(llm_result, '')
yield {'is_final': True}
return
if action in self.available_tool_list:
                # yield 'Observation: ' promptly, as the end-of-action-input marker
yield {'llm_text': 'Observation: '}
action_args = self.parse_action_args(action_args)
tool = self.tool_list[action]
# TODO @wenmeng.zwm remove this hack logic for image generation
if action == 'image_gen' and self.seed:
action_args['seed'] = self.seed
try:
exec_result = tool(**action_args, remote=remote)
yield {'exec_result': exec_result}
# parse exec result and update state
self.parse_exec_result(exec_result)
except Exception as e:
exec_result = f'Action call error: {action}: {action_args}. \n Error message: {e}'
yield {'exec_result': exec_result}
self.prompt_generator.reset()
return
else:
exec_result = f"Unknown action: '{action}'. "
yield {'exec_result': exec_result}
self.prompt_generator.reset()
return
def reset(self):
"""
clear history and agent state
"""
self.prompt_generator.reset()
self.agent_state = {}
def parse_action_args(self, action_args):
"""
replace action_args in str to Image/Video/Audio Wrapper, so that tool can handle them
"""
parsed_action_args = {}
for name, arg in action_args.items():
try:
true_arg = self.agent_state.get(arg, arg)
except Exception as e:
print(f'Error when parsing action args: {e}, using fall back')
true_arg = arg
parsed_action_args[name] = true_arg
return parsed_action_args
def parse_exec_result(self, exec_result, *args, **kwargs):
"""
update exec result to agent state.
key is the str representation of the result.
"""
for k, v in exec_result.items():
self.agent_state[str(v)] = v
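
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library). `MyLLM` and the task string
# are hypothetical placeholders; the stepwise driver below mirrors the loop
# that `run` performs internally, under the assumption that the caller wants
# to own the LLM call in each round.
#
# if __name__ == '__main__':
#     llm = MyLLM()  # hypothetical LLM subclass, local or remote
#     agent = AgentExecutor(llm)
#
#     # One-shot execution: the agent loops over LLM and tool calls internally
#     # and returns one dict per interaction.
#     results = agent.run('draw a cat', remote=False, print_info=True)
#
#     # Stepwise execution via the custom_* API: the caller drives the loop.
#     (tool_list, knowledge_list, function_list, llm_result, exec_result,
#      idx, final_res, remote, print_info) = agent.custom_run_init('draw a cat')
#     while True:
#         llm_artifacts, idx = agent.custom_gene_prompt(llm_result, exec_result, idx)
#         llm_result = llm.generate(llm_artifacts, function_list)
#         step = agent.custom_parse_llm(llm_artifacts, llm_result, idx,
#                                       final_res, remote, print_info)[0]
#         if 'no_stop' not in step:
#             break  # step holds 'end_res' or 'exec_result' (an error)
#         state = step['no_stop']
#         llm_result, exec_result = state['llm_result'], state['exec_result']
#         idx, final_res = state['idx'], state['final_res']
# ---------------------------------------------------------------------------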