ff_li
目录调整
f67d239
import json
from typing import Dict, Iterator, List, Optional
from agent.actions.base import Action
from agent.tools import call_plugin
# Template describing a single tool to the model. It is filled via str.format()
# in _build_react_instruction with the fields: name_for_model, name_for_human,
# description_for_model and parameters (the latter as a JSON-encoded string).
TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
# ReAct prompt template sent as the user message. It is filled via str.format()
# in _build_react_instruction with:
#   tool_descs - the rendered TOOL_DESC entries,
#   tool_names - the comma-separated tool names offered as valid actions,
#   query      - the user's question.
# The 'Action:' / 'Action Input:' / 'Observation:' labels in this text are the
# same markers that _parse_last_action searches for in the model output.
PROMPT_REACT = """Answer the following questions as best you can. You have access to the following tools:
{tool_descs}
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!
Question: {query}"""
def _build_react_instruction(query: str, functions: List[Dict]):
    """Render the ReAct prompt for ``query`` and the given tools.

    Args:
        query: The question substituted into the ``query`` slot of
            ``PROMPT_REACT``.
        functions: Tool specifications. Each dict must provide the keys
            ``'name_for_model'``, ``'name_for_human'``,
            ``'description_for_model'`` and ``'parameters'``; a missing key
            raises ``KeyError``.

    Returns:
        ``PROMPT_REACT`` formatted with the tool descriptions (joined by a
        blank line), the tool names (joined by commas) and the query.
    """
    # Walk the tool list exactly once, pairing each rendered description
    # with the name the model should use to invoke the tool.
    rendered = [(
        TOOL_DESC.format(
            name_for_model=spec['name_for_model'],
            name_for_human=spec['name_for_human'],
            description_for_model=spec['description_for_model'],
            # ensure_ascii=False keeps non-ASCII text readable in the prompt.
            parameters=json.dumps(spec['parameters'], ensure_ascii=False),
        ),
        spec['name_for_model'],
    ) for spec in functions]

    return PROMPT_REACT.format(
        tool_descs='\n\n'.join(desc for desc, _ in rendered),
        tool_names=','.join(name for _, name in rendered),
        query=query,
    )
def _parse_last_action(text):
plugin_name, plugin_args = '', ''
i = text.rfind('\nAction:')
j = text.rfind('\nAction Input:')
k = text.rfind('\nObservation:')
if 0 <= i < j: # If the text has `Action` and `Action input`,
if k < j: # but does not contain `Observation`,
# then it is likely that `Observation` is ommited by the LLM,
# because the output text may have discarded the stop word.
text = text.rstrip() + '\nObservation:' # Add it back.
k = text.rfind('\nObservation:')
plugin_name = text[i + len('\nAction:'):j].strip()
plugin_args = text[j + len('\nAction Input:'):k].strip()
text = text[:k] # Discard '\nObservation:'.
return plugin_name, plugin_args, text
# TODO: When should a parameter (such as history) go in __init__()? When should it go in run()?
class ReAct(Action):
    """Action that answers a request with a ReAct-style tool-calling loop.

    Each turn asks the LLM for a continuation, parses the last
    ``Action`` / ``Action Input`` pair from it, runs the named plugin and
    appends the observation to the prompt before the next turn.
    """

    def _run(self,
             user_request,
             functions: List[Dict] = None,
             history: Optional[List[Dict]] = None,
             lang: str = 'en') -> Iterator[str]:
        """Run the ReAct loop and yield the generated text piece by piece.

        Args:
            user_request: The user's question, embedded in the ReAct prompt.
            functions: Tool specifications passed to
                ``_build_react_instruction``; ``None`` is treated as no tools.
            history: Earlier messages placed before the new user message.
                Its last entry must not have role ``'user'``.
            lang: Not used by this method.

        Yields:
            The model output for each turn (prefixed with a space or a
            newline so it concatenates cleanly), followed, when a tool was
            called, by the ``'\\nObservation: ...\\nThought:'`` text.

        Note:
            Relies on ``self.llm.chat(messages=..., stream=..., stop=...)``
            returning the generated text as a string; ``self.llm`` is
            presumably provided by ``Action`` (confirm against the base
            class).
        """
        prompt = _build_react_instruction(user_request, functions or [])

        messages = []
        if history:
            assert history[-1]['role'] != 'user', \
                'The history should not include the latest user query.'
            messages.extend(history)
        messages.append({'role': 'user', 'content': prompt})

        # Hard cap on LLM round-trips so a model that keeps requesting
        # tools cannot loop forever.
        for _ in range(5):
            output = self.llm.chat(
                messages=messages,
                stream=False,  # TODO: (left open by the original author)
                stop=['Observation:', 'Observation:\n'],
            )
            action, action_input, output = _parse_last_action(output)

            # After a tool call the prompt ends with '\nThought:', so the
            # continuation should start with a space; otherwise it should
            # start on a new line.
            if messages[-1]['content'].endswith('\nThought:'):
                separator = ' '
            else:
                separator = '\n'
            if not output.startswith(separator):
                output = separator + output
            yield output

            if not action:
                # No tool requested: the model has produced its final text.
                break

            observation = call_plugin(action, action_input)
            observation = f'\nObservation: {observation}\nThought:'
            yield observation
            # Fold this turn into the user message so the next call
            # continues from the trailing 'Thought:'.
            messages[-1]['content'] += output + observation