import copy
import time
from typing import Any, Dict, List, Optional, Tuple, Union

import hydra
from pydantic import root_validator

from langchain import LLMChain, PromptTemplate
from langchain.agents import AgentExecutor, BaseMultiActionAgent, ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.chat_models import ChatOpenAI
from langchain.schema import (
    AgentAction,
    AgentFinish,
    OutputParserException,
)

from flows.base_flows import Flow, CompositeFlow, GenericLCTool
from flows.messages import OutputMessage, UpdateMessage_Generic
from flows.utils.caching_utils import flow_run_cache


class GenericZeroShotAgent(ZeroShotAgent):
    @classmethod
    def create_prompt(
        cls,
        tools: Dict[str, Flow],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
    ) -> PromptTemplate:
        """Create prompt in the style of the zero-shot agent.

        Args:
            tools: Mapping from tool name to the Flow implementing the tool,
                used to format the prompt.
            prefix: String to put before the list of tools.
            suffix: String to put after the list of tools.
            format_instructions: Instructions on the expected response format,
                formatted with the available tool names.
            input_variables: List of input variables the final prompt will expect.

        Returns:
            A PromptTemplate with the template assembled from the pieces here.
        """
        # tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
        # tool_names = ", ".join([tool.name for tool in tools])
        tool_strings = "\n".join(
            [f"{tool_name}: {tool.flow_config['description']}" for tool_name, tool in tools.items()]
        )
        tool_names = ", ".join(tools.keys())
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return PromptTemplate(template=template, input_variables=input_variables)


class GenericAgentExecutor(AgentExecutor):
    tools: Dict[str, Flow]

    @root_validator()
    def validate_tools(cls, values: Dict) -> Dict:
        """Validate that the provided tools match the agent's allowed tools."""
        agent = values["agent"]
        tools = values["tools"]
        allowed_tools = agent.get_allowed_tools()
        if allowed_tools is not None:
            if set(allowed_tools) != set(tools.keys()):
                raise ValueError(
                    f"Allowed tools ({allowed_tools}) different than "
                    f"provided tools ({list(tools.keys())})"
                )
        return values

    @root_validator()
    def validate_return_direct_tool(cls, values: Dict) -> Dict:
        """Validate that no tool sets `return_direct=True` when the agent is multi-action."""
        agent = values["agent"]
        tools = values["tools"]
        if isinstance(agent, BaseMultiActionAgent):
            for tool in tools.values():
                if tool.flow_config["return_direct"]:
                    raise ValueError(
                        "Tools that have `return_direct=True` are not allowed "
                        "in multi-action agents"
                    )
        return values

    def _get_tool_return(
        self, next_step_output: Tuple[AgentAction, str]
    ) -> Optional[AgentFinish]:
        """Check if the tool is a returning tool."""
        agent_action, observation = next_step_output
        # name_to_tool_map = {tool.name: tool for tool in self.tools}
        # Invalid tools won't be in the map, so we return None.
        if agent_action.tool in self.tools:
            if self.tools[agent_action.tool].flow_config["return_direct"]:
                return AgentFinish(
                    {self.agent.return_values[0]: observation},
                    "",
                )
        return None
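
# Illustrative sketch (not part of the library): how `GenericZeroShotAgent.create_prompt`
# renders a name->Flow mapping into the ReAct prompt. `_DummySearchFlow` is a hypothetical
# stand-in for a tool Flow; only `flow_config["description"]` is read from each tool here.
#
#     class _DummySearchFlow:
#         flow_config = {"description": "Searches the web for a query."}
#
#     prompt = GenericZeroShotAgent.create_prompt(tools={"search": _DummySearchFlow()})
#     # The template lists "search: Searches the web for a query." between the MRKL
#     # PREFIX and FORMAT_INSTRUCTIONS, and expects ["input", "agent_scratchpad"].
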
class ReActFlow(CompositeFlow):
    EXCEPTION_FLOW_CONFIG = {
        "_target_": "flows.base_flows.GenericLCTool.instantiate_from_config",
        "config": {
            "name": "_Exception",
            "description": "Exception tool",
            "tool_type": "exception",
            "input_keys": ["query"],
            "output_keys": ["raw_response"],
            "verbose": False,
            "clear_flow_namespace_on_run_end": False,
            "input_data_transformations": [],
            "output_data_transformations": [],
            "keep_raw_response": True
        }
    }

    INVALID_FLOW_CONFIG = {
        "_target_": "flows.base_flows.GenericLCTool.instantiate_from_config",
        "config": {
            "name": "invalid_tool",
            "description": "Called when tool name is invalid.",
            "tool_type": "invalid",
            "input_keys": ["tool_name"],
            "output_keys": ["raw_response"],
            "verbose": False,
            "clear_flow_namespace_on_run_end": False,
            "input_data_transformations": [],
            "output_data_transformations": [],
            "keep_raw_response": True
        }
    }

    SUPPORTS_CACHING: bool = True

    api_keys: Optional[Dict[str, str]]
    backend: Optional[GenericAgentExecutor]
    react_prompt_template: PromptTemplate
    exception_flow: GenericLCTool
    invalid_flow: GenericLCTool

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The API keys and the backend only become available at run time (cf. `run`).
        self.api_keys = None
        self.backend = None
        self.react_prompt_template = GenericZeroShotAgent.create_prompt(
            tools=self.subflows,
            **self.flow_config.get("prompt_config", {})
        )
        self._set_up_necessary_subflows()

    def set_up_flow_state(self):
        super().set_up_flow_state()
        self.flow_state["intermediate_steps"]: List[Tuple[AgentAction, str]] = []

    def _set_up_necessary_subflows(self):
        self.exception_flow = hydra.utils.instantiate(
            self.EXCEPTION_FLOW_CONFIG, _convert_="partial", _recursive_=False
        )
        self.invalid_flow = hydra.utils.instantiate(
            self.INVALID_FLOW_CONFIG, _convert_="partial", _recursive_=False
        )

    def _get_prompt_message(self, input_data: Dict[str, Any]) -> str:
        data = copy.deepcopy(input_data)
        data["agent_scratchpad"] = "{agent_scratchpad}"  # dummy value for the agent scratchpad
        return self.react_prompt_template.format(**data)

    @staticmethod
    def get_raw_response(output: OutputMessage) -> str:
        key = output.data["output_keys"][0]
        return output.data["output_data"]["raw_response"][key]
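
    # Illustrative note (sketch, not executed): each element of
    # `flow_state["intermediate_steps"]` is an (AgentAction, observation) pair that
    # the underlying ZeroShotAgent serializes into the "agent_scratchpad" prompt
    # variable on the next `plan()` call. For example:
    #
    #     steps = [(
    #         AgentAction(
    #             tool="search",
    #             tool_input="capital of Switzerland",
    #             log="Thought: I should look this up.\nAction: search\n"
    #                 "Action Input: capital of Switzerland",
    #         ),
    #         "Bern",
    #     )]
    #     # -> scratchpad ends with "...Observation: Bern\nThought:"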
    def _take_next_step(
        self,
        # name_to_tool_map: Dict[str, BaseTool],
        # color_mapping: Dict[str, str],
        inputs: Dict[str, str],
        intermediate_steps: List[Tuple[AgentAction, str]],
        # run_manager: Optional[CallbackManagerForChainRun] = None,
        # input_data: Dict[str, Any],
        private_keys: Optional[List[str]] = [],
        keys_to_ignore_for_hash: Optional[List[str]] = []
    ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
        """Take a single step in the thought-action-observation loop.

        Override this to take control of how the agent makes and acts on choices.
        """
        try:
            # Call the LLM to see what to do.
            output = self.backend.agent.plan(
                intermediate_steps,
                # callbacks=run_manager.get_child() if run_manager else None,
                **inputs,
            )
        except OutputParserException as e:
            if isinstance(self.backend.handle_parsing_errors, bool):
                raise_error = not self.backend.handle_parsing_errors
            else:
                raise_error = False
            if raise_error:
                raise e

            text = str(e)
            if isinstance(self.backend.handle_parsing_errors, bool):
                if e.send_to_llm:
                    observation = str(e.observation)
                    text = str(e.llm_output)
                else:
                    observation = "Invalid or incomplete response"
            elif isinstance(self.backend.handle_parsing_errors, str):
                observation = self.backend.handle_parsing_errors
            elif callable(self.backend.handle_parsing_errors):
                observation = self.backend.handle_parsing_errors(e)
            else:
                raise ValueError("Got unexpected type of `handle_parsing_errors`")
            output = AgentAction("_Exception", observation, text)

            # if run_manager:
            #     run_manager.on_agent_action(output, color="green")
            # tool_run_kwargs = self.backend.agent.tool_run_logging_kwargs()
            # observation = ExceptionTool().run(
            #     output.tool_input,
            #     verbose=self.verbose,
            #     color=None,
            #     callbacks=run_manager.get_child() if run_manager else None,
            #     **tool_run_kwargs,
            # )
            self._state_update_dict({"query": output.tool_input})
            tool_output = self._call_flow_from_state(
                self.exception_flow,
                private_keys=private_keys,
                keys_to_ignore_for_hash=keys_to_ignore_for_hash,
                search_class_namespace_for_inputs=False
            )
            observation = self.get_raw_response(tool_output)
            return [(output, observation)]

        # If the tool chosen is the finishing tool, then we end and return.
        if isinstance(output, AgentFinish):
            return output

        actions: List[AgentAction]
        if isinstance(output, AgentAction):
            actions = [output]
        else:
            actions = output

        result = []
        for agent_action in actions:
            # if run_manager:
            #     run_manager.on_agent_action(agent_action, color="green")
            # Otherwise we look up the tool
            if agent_action.tool in self.subflows:
                tool = self.subflows[agent_action.tool]

                if isinstance(agent_action.tool_input, dict):
                    self._state_update_dict(agent_action.tool_input)
                else:
                    self._state_update_dict({tool.flow_config["input_keys"][0]: agent_action.tool_input})

                tool_output = self._call_flow_from_state(
                    tool,
                    private_keys=private_keys,
                    keys_to_ignore_for_hash=keys_to_ignore_for_hash,
                    search_class_namespace_for_inputs=False
                )
                observation = self.get_raw_response(tool_output)

                # return_direct = tool.return_direct
                # color = color_mapping[agent_action.tool]
                # tool_run_kwargs = self.backend.agent.tool_run_logging_kwargs()
                # if return_direct:
                #     tool_run_kwargs["llm_prefix"] = ""
                # We then call the tool on the tool input to get an observation
                # observation = tool.run(
                #     agent_action.tool_input,
                #     verbose=self.verbose,
                #     color=color,
                #     callbacks=run_manager.get_child() if run_manager else None,
                #     **tool_run_kwargs,
                # )
            else:
                # tool_run_kwargs = self.backend.agent.tool_run_logging_kwargs()
                # observation = InvalidTool().run(
                #     agent_action.tool,
                #     verbose=self.verbose,
                #     color=None,
                #     callbacks=run_manager.get_child() if run_manager else None,
                #     **tool_run_kwargs,
                # )
                self._state_update_dict({"tool_name": agent_action.tool})
                tool_output = self._call_flow_from_state(
                    self.invalid_flow,
                    private_keys=private_keys,
                    keys_to_ignore_for_hash=keys_to_ignore_for_hash,
                    search_class_namespace_for_inputs=False
                )
                observation = self.get_raw_response(tool_output)

            result.append((agent_action, observation))
        return result
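
    # Illustrative dispatch summary (sketch): `_take_next_step` returns either an
    # AgentFinish (the LLM produced a final answer) or a list of
    # (AgentAction, observation) pairs, where each observation is the raw response
    # of the matching subflow (or of `invalid_flow` / `exception_flow` for an
    # unknown tool name or an unparsable LLM response, respectively). E.g., with a
    # registered "search" subflow:
    #
    #     [(AgentAction("search", "capital of Switzerland", "Thought: ..."), "Bern")]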
get agent response.""" # Construct a mapping of tool name to tool for easy lookup # name_to_tool_map = {tool.name: tool for tool in self.tools} # We construct a mapping from each tool to a color, used for logging. # color_mapping = get_color_mapping( # [tool.name for tool in self.tools], excluded_colors=["green", "red"] # ) self.flow_state["intermediate_steps"] = [] intermediate_steps = self.flow_state["intermediate_steps"] # Let's start tracking the number of iterations and time elapsed iterations = 0 time_elapsed = 0.0 start_time = time.time() # We now enter the agent loop (until it returns something). while self.backend._should_continue(iterations, time_elapsed): # next_step_output = self._take_next_step( # name_to_tool_map, # color_mapping, # inputs, # intermediate_steps, # run_manager=run_manager, # ) next_step_output = self._take_next_step( input_data, intermediate_steps, private_keys, keys_to_ignore_for_hash ) if isinstance(next_step_output, AgentFinish): # TODO: f"{self.backend.agent.llm_prefix} {next_step_output.log}" return next_step_output.return_values["output"] intermediate_steps.extend(next_step_output) for act, obs in next_step_output: pass # TODO # f"{self.backend.agent.llm_prefix} {act.log}" # f"{self.backend.agent.observation_prefix}{obs}" if len(next_step_output) == 1: next_step_action = next_step_output[0] # See if tool should return directly tool_return = self.backend._get_tool_return(next_step_action) if tool_return is not None: # same as the observation return tool_return.return_values["output"] iterations += 1 time_elapsed = time.time() - start_time output = self.backend.agent.return_stopped_response( self.backend.early_stopping_method, intermediate_steps, **input_data ) return output.return_values["output"] @flow_run_cache() def run( self, input_data: Dict[str, Any], private_keys: Optional[List[str]] = [], keys_to_ignore_for_hash: Optional[List[str]] = [] ) -> Dict[str, Any]: self.api_keys = input_data["api_keys"] del input_data["api_keys"] llm = ChatOpenAI( model_name=self.flow_config["model_name"], openai_api_key=self.api_keys["openai"], **self.flow_config["generation_parameters"], ) llm_chain = LLMChain(llm=llm, prompt=self.react_prompt_template) agent = GenericZeroShotAgent(llm_chain=llm_chain, allowed_tools=list(self.subflows.keys())) self.backend = GenericAgentExecutor.from_agent_and_tools( agent=agent, tools=self.subflows, max_iterations=self.flow_config.get("max_iterations", 15), max_execution_time=self.flow_config.get("max_execution_time") ) data = {k: input_data[k] for k in self.get_input_keys(input_data)} # TODO # prompt = UpdateMessage_Generic( # created_by=self.flow_config["name"], # updated_flow=self.flow_config["name"], # content=self._get_prompt_message(data) # ) # self._log_message(prompt) output = self._run(data, private_keys, keys_to_ignore_for_hash) return {input_data["output_keys"][0]: output}
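
# Illustrative end-to-end usage (sketch; `react_flow_config` is a hypothetical config
# whose fields mirror the keys read above: "model_name", "generation_parameters",
# "max_iterations", "max_execution_time", and an optional "prompt_config"):
#
#     react_flow = ReActFlow.instantiate_from_config(react_flow_config)
#     outputs = react_flow.run(
#         input_data={
#             "input": "What is the capital of Switzerland?",
#             "output_keys": ["answer"],
#             "api_keys": {"openai": "<OPENAI_API_KEY>"},  # popped before the LLM call
#         }
#     )
#     print(outputs["answer"])  # `run` keys the result by input_data["output_keys"][0]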