# Page-scrape residue, not part of the module: "Spaces: Runtime error".
"""Attempt to implement MRKL systems as described in arxiv.org/pdf/2205.00445.pdf."""
from __future__ import annotations | |
import re | |
from typing import Any, Callable, List, NamedTuple, Optional, Sequence, Tuple | |
from langchain.agents.agent import Agent, AgentExecutor | |
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX | |
from langchain.agents.tools import Tool | |
from langchain.callbacks.base import BaseCallbackManager | |
from langchain.chains import LLMChain | |
from langchain.llms.base import BaseLLM | |
from langchain.prompts import PromptTemplate | |
from langchain.tools.base import BaseTool | |
# Marker searched for in the LLM output; when present, everything after its
# last occurrence is returned as the final answer instead of a tool call.
FINAL_ANSWER_ACTION = "Final Answer:"
class ChainConfig(NamedTuple):
    """Describe one action (tool) that a MRKL chain can dispatch to.

    Attributes:
        action_name: Name of the action.
        action: Callable that is invoked to perform the action.
        action_description: Description of the action.
    """

    # Name under which the action is registered.
    action_name: str
    # Function invoked when the action is selected.
    action: Callable
    # Human-readable explanation of what the action is for.
    action_description: str
def get_action_and_input(llm_output: str) -> Tuple[str, str]:
    """Extract the action name and its input from raw LLM output.

    Note: a custom prompt for the ZeroShotAgent must produce output that
    satisfies the regex used here: a string starting with "Action:" followed
    by a string starting with "Action Input:", separated by a newline.

    Args:
        llm_output: Text generated by the LLM.

    Returns:
        A ``(action, action_input)`` pair. When the final-answer marker is
        present the action is ``"Final Answer"`` and the input is the stripped
        text after the last occurrence of the marker.

    Raises:
        ValueError: If the output contains neither the final-answer marker
            nor an ``Action: ... Action Input: ...`` pair.
    """
    # The final-answer marker is checked first, so it wins over any action
    # that may also appear in the text.
    if FINAL_ANSWER_ACTION in llm_output:
        answer = llm_output.rpartition(FINAL_ANSWER_ACTION)[-1]
        return "Final Answer", answer.strip()

    # DOTALL lets the action input span several lines.
    parsed = re.search(
        r"Action: (.*?)[\n]*Action Input: (.*)", llm_output, re.DOTALL
    )
    if parsed is None:
        raise ValueError(f"Could not parse LLM output: `{llm_output}`")

    tool_name, raw_input = parsed.group(1), parsed.group(2)
    # Only spaces are trimmed from the input, then surrounding double quotes.
    return tool_name.strip(), raw_input.strip(" ").strip('"')
class ZeroShotAgent(Agent):
    """Agent for the MRKL chain."""

    # NOTE(review): the three accessors below are restored as properties to
    # match LangChain's ``Agent`` base class, which declares them as
    # properties -- confirm against the installed LangChain version.
    @property
    def _agent_type(self) -> str:
        """Return Identifier of agent type."""
        return "zero-shot-react-description"

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        return "Observation: "

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the llm call with."""
        return "Thought:"

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
    ) -> PromptTemplate:
        """Create prompt in the style of the zero shot agent.

        Args:
            tools: List of tools the agent will have access to, used to format
                the prompt.
            prefix: String to put before the list of tools.
            suffix: String to put after the list of tools.
            format_instructions: Format string whose ``{tool_names}``
                placeholder is filled with the comma-separated tool names.
            input_variables: List of input variables the final prompt will
                expect. Defaults to ``["input", "agent_scratchpad"]``.

        Returns:
            A PromptTemplate with the template assembled from the pieces here.
        """
        tool_strings = "\n".join(f"{tool.name}: {tool.description}" for tool in tools)
        tool_names = ", ".join(tool.name for tool in tools)
        format_instructions = format_instructions.format(tool_names=tool_names)
        # Sections are separated by a blank line in the final template.
        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return PromptTemplate(template=template, input_variables=input_variables)

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLLM,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools.

        Args:
            llm: The LLM that drives the agent.
            tools: Tools the agent may call; each must have a description.
            callback_manager: Optional callback manager passed to the LLMChain.
            prefix: Forwarded to ``create_prompt``.
            suffix: Forwarded to ``create_prompt``.
            format_instructions: Forwarded to ``create_prompt``.
            input_variables: Forwarded to ``create_prompt``.
            **kwargs: Extra keyword arguments passed to the agent constructor.

        Returns:
            An agent whose allowed tools are the names of ``tools``.

        Raises:
            ValueError: If any tool has no description.
        """
        cls._validate_tools(tools)
        prompt = cls.create_prompt(
            tools,
            prefix=prefix,
            suffix=suffix,
            format_instructions=format_instructions,
            input_variables=input_variables,
        )
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        return cls(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
        """Ensure every tool carries a description.

        Raises:
            ValueError: If a tool's ``description`` is ``None``.
        """
        for tool in tools:
            if tool.description is None:
                raise ValueError(
                    f"Got a tool {tool.name} without a description. For this agent, "
                    f"a description must always be provided."
                )

    def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
        """Parse ``text`` into an (action, action input) pair."""
        return get_action_and_input(text)
class MRKLChain(AgentExecutor):
    """Chain that implements the MRKL system.

    Example:
        .. code-block:: python

            from langchain import OpenAI, MRKLChain
            from langchain.chains.mrkl.base import ChainConfig

            llm = OpenAI(temperature=0)
            chains = [...]
            mrkl = MRKLChain.from_chains(llm=llm, chains=chains)
    """

    @classmethod
    def from_chains(
        cls, llm: BaseLLM, chains: List[ChainConfig], **kwargs: Any
    ) -> AgentExecutor:
        """User friendly way to initialize the MRKL chain.

        This is intended to be an easy way to get up and running with the
        MRKL chain.

        Args:
            llm: The LLM to use as the agent LLM.
            chains: The chains the MRKL system has access to.
            **kwargs: parameters to be passed to initialization.

        Returns:
            An initialized MRKL chain.

        Example:
            .. code-block:: python

                from langchain import LLMMathChain, OpenAI, SerpAPIWrapper, MRKLChain
                from langchain.chains.mrkl.base import ChainConfig

                llm = OpenAI(temperature=0)
                search = SerpAPIWrapper()
                llm_math_chain = LLMMathChain(llm=llm)
                chains = [
                    ChainConfig(
                        action_name="Search",
                        action=search.search,
                        action_description="useful for searching"
                    ),
                    ChainConfig(
                        action_name="Calculator",
                        action=llm_math_chain.run,
                        action_description="useful for doing math"
                    )
                ]
                mrkl = MRKLChain.from_chains(llm, chains)
        """
        # Wrap each ChainConfig in a Tool so the agent can call it by name.
        tools = [
            Tool(
                name=c.action_name,
                func=c.action,
                description=c.action_description,
            )
            for c in chains
        ]
        agent = ZeroShotAgent.from_llm_and_tools(llm, tools)
        return cls(agent=agent, tools=tools, **kwargs)