Spaces:
Running
Running
import json | |
from typing import Dict, Iterator, List, Optional | |
from agent.actions.base import Action | |
from agent.tools import call_plugin | |
TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}""" | |
PROMPT_REACT = """Answer the following questions as best you can. You have access to the following tools: | |
{tool_descs} | |
Use the following format: | |
Question: the input question you must answer | |
Thought: you should always think about what to do | |
Action: the action to take, should be one of [{tool_names}] | |
Action Input: the input to the action | |
Observation: the result of the action | |
(this Thought/Action/Action Input/Observation can be repeated zero or more times) | |
Thought: I now know the final answer | |
Final Answer: the final answer to the original input question | |
Begin! | |
Question: {query}""" | |
def _build_react_instruction(query: str, functions: List[Dict]): | |
tool_descs = [] | |
tool_names = [] | |
for info in functions: | |
tool_descs.append( | |
TOOL_DESC.format( | |
name_for_model=info['name_for_model'], | |
name_for_human=info['name_for_human'], | |
description_for_model=info['description_for_model'], | |
parameters=json.dumps(info['parameters'], ensure_ascii=False), | |
)) | |
tool_names.append(info['name_for_model']) | |
tool_descs = '\n\n'.join(tool_descs) | |
tool_names = ','.join(tool_names) | |
prompt = PROMPT_REACT.format(tool_descs=tool_descs, | |
tool_names=tool_names, | |
query=query) | |
return prompt | |
def _parse_last_action(text): | |
plugin_name, plugin_args = '', '' | |
i = text.rfind('\nAction:') | |
j = text.rfind('\nAction Input:') | |
k = text.rfind('\nObservation:') | |
if 0 <= i < j: # If the text has `Action` and `Action input`, | |
if k < j: # but does not contain `Observation`, | |
# then it is likely that `Observation` is ommited by the LLM, | |
# because the output text may have discarded the stop word. | |
text = text.rstrip() + '\nObservation:' # Add it back. | |
k = text.rfind('\nObservation:') | |
plugin_name = text[i + len('\nAction:'):j].strip() | |
plugin_args = text[j + len('\nAction Input:'):k].strip() | |
text = text[:k] # Discard '\nObservation:'. | |
return plugin_name, plugin_args, text | |
# TODO: When to put an parameter (such as history) in __init__()? When to put it in run()? | |
class ReAct(Action): | |
def _run(self, | |
user_request, | |
functions: List[Dict] = None, | |
history: Optional[List[Dict]] = None, | |
lang: str = 'en') -> Iterator[str]: | |
functions = functions or [] | |
prompt = _build_react_instruction(user_request, functions) | |
messages = [] | |
if history: | |
assert history[-1][ | |
'role'] != 'user', 'The history should not include the latest user query.' | |
messages.extend(history) | |
messages.append({'role': 'user', 'content': prompt}) | |
max_turn = 5 | |
while True and max_turn > 0: | |
max_turn -= 1 | |
output = self.llm.chat( | |
messages=messages, | |
stream=False, # TODO: | |
stop=['Observation:', 'Observation:\n'], | |
) | |
action, action_input, output = _parse_last_action(output) | |
if messages[-1]['content'].endswith('\nThought:'): | |
if not output.startswith(' '): | |
output = ' ' + output | |
else: | |
if not output.startswith('\n'): | |
output = '\n' + output | |
yield output | |
if action: | |
observation = call_plugin(action, action_input) | |
observation = f'\nObservation: {observation}\nThought:' | |
yield observation | |
messages[-1]['content'] += output + observation | |
else: | |
break | |