# AgentFlows/ReAct.py
import copy
import time
from typing import Any, Dict, List, Optional, Tuple, Union
import hydra
from pydantic import root_validator
from langchain import LLMChain, PromptTemplate
from langchain.agents import AgentExecutor, BaseMultiActionAgent, ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.chat_models import ChatOpenAI
from langchain.schema import (
AgentAction,
AgentFinish,
OutputParserException,
)
from flows.base_flows import Flow, CompositeFlow, GenericLCTool
from flows.messages import OutputMessage, UpdateMessage_Generic
from flows.utils.caching_utils import flow_run_cache
class GenericZeroShotAgent(ZeroShotAgent):
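    """ZeroShotAgent variant whose prompt builder accepts a mapping of tool
    names to Flow instances instead of a list of LangChain BaseTools."""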
@classmethod
def create_prompt(
cls,
tools: Dict[str, Flow],
prefix: str = PREFIX,
suffix: str = SUFFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
Args:
tools: List of tools the agent will have access to, used to format the
prompt.
prefix: String to put before the list of tools.
suffix: String to put after the list of tools.
input_variables: List of input variables the final prompt will expect.
Returns:
A PromptTemplate with the template assembled from the pieces here.
"""
# tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
# tool_names = ", ".join([tool.name for tool in tools])
tool_strings = "\n".join([f"{tool_name}: {tool.flow_config['description']}" for tool_name, tool in tools.items()])
tool_names = ", ".join(tools.keys())
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)
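

# Illustrative sketch (not part of the flow): with the default PREFIX/SUFFIX and a
# hypothetical "search" subflow whose config description is "Search the web",
#
#   GenericZeroShotAgent.create_prompt(tools={"search": search_flow})
#
# assembles a template of roughly this shape (pieces joined by blank lines):
#
#   <PREFIX>
#
#   search: Search the web
#
#   <FORMAT_INSTRUCTIONS formatted with tool_names="search">
#
#   <SUFFIX containing {input} and {agent_scratchpad}>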
class GenericAgentExecutor(AgentExecutor):
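    """AgentExecutor variant whose tools are a Dict[str, Flow] rather than a
    list of LangChain BaseTool instances."""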
tools: Dict[str, Flow]
@root_validator()
def validate_tools(cls, values: Dict) -> Dict:
"""Validate that tools are compatible with agent."""
agent = values["agent"]
tools = values["tools"]
allowed_tools = agent.get_allowed_tools()
if allowed_tools is not None:
if set(allowed_tools) != set(tools.keys()):
raise ValueError(
f"Allowed tools ({allowed_tools}) different than "
f"provided tools ({tools.keys()})"
)
return values
@root_validator()
def validate_return_direct_tool(cls, values: Dict) -> Dict:
"""Validate that tools are compatible with agent."""
agent = values["agent"]
tools = values["tools"]
if isinstance(agent, BaseMultiActionAgent):
            for tool in tools.values():
if tool.flow_config["return_direct"]:
raise ValueError(
"Tools that have `return_direct=True` are not allowed "
"in multi-action agents"
)
return values
def _get_tool_return(
self, next_step_output: Tuple[AgentAction, str]
) -> Optional[AgentFinish]:
"""Check if the tool is a returning tool."""
agent_action, observation = next_step_output
        # Invalid tools won't be in the `self.tools` mapping, so we return None.
if agent_action.tool in self.tools:
if self.tools[agent_action.tool].flow_config["return_direct"]:
return AgentFinish(
{self.agent.return_values[0]: observation},
"",
)
return None
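
# Illustrative sketch (hypothetical names): if the agent picks a subflow whose
# config has return_direct=True, the executor short-circuits the loop, e.g.
#
#   executor._get_tool_return((AgentAction("lookup", "42", ""), "The answer is 42"))
#   # -> AgentFinish({"output": "The answer is 42"}, "")
#
# assuming the agent's `return_values` is the usual ["output"].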
class ReActFlow(CompositeFlow):
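    """CompositeFlow implementing a ReAct-style loop: an LLM plans the next
    action, the chosen subflow (tool) is executed, and the observation is fed
    back until the agent finishes or an iteration/time limit is reached."""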
EXCEPTION_FLOW_CONFIG = {
"_target_": "flows.base_flows.GenericLCTool.instantiate_from_config",
"config": {
"name": "_Exception",
"description": "Exception tool",
"tool_type": "exception",
"input_keys": ["query"],
"output_keys": ["raw_response"],
"verbose": False,
"clear_flow_namespace_on_run_end": False,
"input_data_transformations": [],
"output_data_transformations": [],
"keep_raw_response": True
}
}
INVALID_FLOW_CONFIG = {
"_target_": "flows.base_flows.GenericLCTool.instantiate_from_config",
"config": {
"name": "invalid_tool",
"description": "Called when tool name is invalid.",
"tool_type": "invalid",
"input_keys": ["tool_name"],
"output_keys": ["raw_response"],
"verbose": False,
"clear_flow_namespace_on_run_end": False,
"input_data_transformations": [],
"output_data_transformations": [],
"keep_raw_response": True
}
}
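    # These two configs mirror LangChain's ExceptionTool and InvalidTool: they are
    # instantiated as GenericLCTool subflows in `_set_up_necessary_subflows` and
    # called when the LLM output cannot be parsed or names an unknown tool.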
SUPPORTS_CACHING: bool = True
    api_keys: Optional[Dict[str, str]]
    backend: Optional[GenericAgentExecutor]
react_prompt_template: PromptTemplate
exception_flow: GenericLCTool
invalid_flow: GenericLCTool
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_keys = None
self.backend = None
self.react_prompt_template = GenericZeroShotAgent.create_prompt(
tools=self.subflows,
**self.flow_config.get("prompt_config", {})
)
self._set_up_necessary_subflows()
def set_up_flow_state(self):
super().set_up_flow_state()
self.flow_state["intermediate_steps"]: List[Tuple[AgentAction, str]] = []
def _set_up_necessary_subflows(self):
self.exception_flow = hydra.utils.instantiate(
self.EXCEPTION_FLOW_CONFIG, _convert_="partial", _recursive_=False
)
self.invalid_flow = hydra.utils.instantiate(
self.INVALID_FLOW_CONFIG, _convert_="partial", _recursive_=False
)
def _get_prompt_message(self, input_data: Dict[str, Any]) -> str:
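        """Render the ReAct prompt for logging, keeping a placeholder for the agent scratchpad."""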
data = copy.deepcopy(input_data)
data["agent_scratchpad"] = "{agent_scratchpad}" # dummy value for agent scratchpad
return self.react_prompt_template.format(**data)
@staticmethod
def get_raw_response(output: OutputMessage) -> str:
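        """Extract the raw string response for the first output key of a subflow's OutputMessage."""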
key = output.data["output_keys"][0]
return output.data["output_data"]["raw_response"][key]
    def _take_next_step(
        self,
        # Compared to AgentExecutor._take_next_step, the name_to_tool_map,
        # color_mapping and run_manager arguments are dropped: tools are looked
        # up directly in self.subflows and callbacks are not used.
        inputs: Dict[str, str],
        intermediate_steps: List[Tuple[AgentAction, str]],
        private_keys: Optional[List[str]] = [],
        keys_to_ignore_for_hash: Optional[List[str]] = []
    ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
"""Take a single step in the thought-action-observation loop.
Override this to take control of how the agent makes and acts on choices.
"""
try:
# Call the LLM to see what to do.
output = self.backend.agent.plan(
intermediate_steps,
# callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
except OutputParserException as e:
if isinstance(self.backend.handle_parsing_errors, bool):
raise_error = not self.backend.handle_parsing_errors
else:
raise_error = False
if raise_error:
raise e
text = str(e)
if isinstance(self.backend.handle_parsing_errors, bool):
if e.send_to_llm:
observation = str(e.observation)
text = str(e.llm_output)
else:
observation = "Invalid or incomplete response"
elif isinstance(self.backend.handle_parsing_errors, str):
observation = self.backend.handle_parsing_errors
elif callable(self.backend.handle_parsing_errors):
observation = self.backend.handle_parsing_errors(e)
else:
raise ValueError("Got unexpected type of `handle_parsing_errors`")
output = AgentAction("_Exception", observation, text)
            # Instead of LangChain's ExceptionTool, the parsing error is routed
            # through the `_Exception` GenericLCTool subflow below.
self._state_update_dict({"query": output.tool_input})
tool_output = self._call_flow_from_state(
self.exception_flow,
private_keys=private_keys,
keys_to_ignore_for_hash=keys_to_ignore_for_hash,
search_class_namespace_for_inputs=False
)
observation = self.get_raw_response(tool_output)
return [(output, observation)]
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
return output
actions: List[AgentAction]
if isinstance(output, AgentAction):
actions = [output]
else:
actions = output
result = []
for agent_action in actions:
# if run_manager:
# run_manager.on_agent_action(agent_action, color="green")
# Otherwise we lookup the tool
if agent_action.tool in self.subflows:
tool = self.subflows[agent_action.tool]
if isinstance(agent_action.tool_input, dict):
self._state_update_dict(agent_action.tool_input)
else:
self._state_update_dict({tool.flow_config["input_keys"][0]:agent_action.tool_input})
tool_output = self._call_flow_from_state(
tool,
private_keys=private_keys,
keys_to_ignore_for_hash=keys_to_ignore_for_hash,
search_class_namespace_for_inputs=False
)
observation = self.get_raw_response(tool_output)
                # The tool is executed as a Flow above; LangChain's tool.run,
                # color handling and callbacks are not needed here.
else:
                # Unknown tool name: route through the `invalid_tool` subflow
                # instead of LangChain's InvalidTool.
self._state_update_dict({"tool_name": agent_action.tool})
tool_output = self._call_flow_from_state(
self.invalid_flow,
private_keys=private_keys,
keys_to_ignore_for_hash=keys_to_ignore_for_hash,
search_class_namespace_for_inputs=False
)
observation = self.get_raw_response(tool_output)
result.append((agent_action, observation))
return result
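
    # Illustrative sketch (hypothetical values): one iteration typically yields
    #
    #   [(AgentAction(tool="search", tool_input="population of Bern", log="..."),
    #     "Bern has roughly 134,000 inhabitants")]
    #
    # and these tuples are accumulated in flow_state["intermediate_steps"] by _run.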
def _run(
self,
input_data: Dict[str, Any],
private_keys: Optional[List[str]] = [],
keys_to_ignore_for_hash: Optional[List[str]] = []
) -> str:
"""Run text through and get agent response."""
        # Tools are looked up directly in self.subflows, so no name-to-tool map
        # or color mapping is needed here.
self.flow_state["intermediate_steps"] = []
intermediate_steps = self.flow_state["intermediate_steps"]
# Let's start tracking the number of iterations and time elapsed
iterations = 0
time_elapsed = 0.0
start_time = time.time()
# We now enter the agent loop (until it returns something).
while self.backend._should_continue(iterations, time_elapsed):
            next_step_output = self._take_next_step(
                input_data,
                intermediate_steps,
                private_keys,
                keys_to_ignore_for_hash
            )
if isinstance(next_step_output, AgentFinish):
# TODO: f"{self.backend.agent.llm_prefix} {next_step_output.log}"
return next_step_output.return_values["output"]
intermediate_steps.extend(next_step_output)
            # TODO: log each (action, observation) pair using the agent's
            # llm_prefix and observation_prefix.
if len(next_step_output) == 1:
next_step_action = next_step_output[0]
# See if tool should return directly
tool_return = self.backend._get_tool_return(next_step_action)
if tool_return is not None:
# same as the observation
return tool_return.return_values["output"]
iterations += 1
time_elapsed = time.time() - start_time
output = self.backend.agent.return_stopped_response(
self.backend.early_stopping_method, intermediate_steps, **input_data
)
return output.return_values["output"]
@flow_run_cache()
def run(
self,
input_data: Dict[str, Any],
private_keys: Optional[List[str]] = [],
keys_to_ignore_for_hash: Optional[List[str]] = []
) -> Dict[str, Any]:
self.api_keys = input_data["api_keys"]
del input_data["api_keys"]
llm = ChatOpenAI(
model_name=self.flow_config["model_name"],
openai_api_key=self.api_keys["openai"],
**self.flow_config["generation_parameters"],
)
llm_chain = LLMChain(llm=llm, prompt=self.react_prompt_template)
agent = GenericZeroShotAgent(llm_chain=llm_chain, allowed_tools=list(self.subflows.keys()))
self.backend = GenericAgentExecutor.from_agent_and_tools(
agent=agent,
tools=self.subflows,
max_iterations=self.flow_config.get("max_iterations", 15),
max_execution_time=self.flow_config.get("max_execution_time")
)
data = {k: input_data[k] for k in self.get_input_keys(input_data)}
# TODO
# prompt = UpdateMessage_Generic(
# created_by=self.flow_config["name"],
# updated_flow=self.flow_config["name"],
# content=self._get_prompt_message(data)
# )
# self._log_message(prompt)
output = self._run(data, private_keys, keys_to_ignore_for_hash)
return {input_data["output_keys"][0]: output}