Spaces:
Runtime error
Runtime error
import json | |
from typing import Any, List, Optional, Sequence, Tuple | |
from langchain.agents.agent import Agent | |
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX | |
from langchain.callbacks.base import BaseCallbackManager | |
from langchain.chains.llm import LLMChain | |
from langchain.prompts.base import BasePromptTemplate | |
from langchain.prompts.chat import ( | |
ChatPromptTemplate, | |
HumanMessagePromptTemplate, | |
SystemMessagePromptTemplate, | |
) | |
from langchain.schema import AgentAction, BaseLanguageModel | |
from langchain.tools import BaseTool | |
# Marker the model writes to signal that it has produced its final answer.
FINAL_ANSWER_ACTION = "Final Answer:"
class ChatAgent(Agent):
    """Agent driven by a chat model through a system + human message prompt.

    The model is expected to either write ``FINAL_ANSWER_ACTION`` followed by
    the answer, or a fenced code block containing a JSON object with
    ``"action"`` and ``"action_input"`` keys.
    """

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        return "Observation: "

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the llm call with."""
        return "Thought:"

    def _construct_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> str:
        """Render the previous steps, framed as work the user has not seen.

        Args:
            intermediate_steps: ``(action, observation)`` pairs taken so far.

        Returns:
            The empty string when the base scratchpad is empty; otherwise the
            base scratchpad prefixed with an explanatory sentence.

        Raises:
            ValueError: If the base implementation returns a non-string.
        """
        agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
        # The human message template ("{input}\n\n{agent_scratchpad}")
        # interpolates the scratchpad as text, so only strings are accepted.
        if not isinstance(agent_scratchpad, str):
            raise ValueError("agent_scratchpad should be of type string.")
        if not agent_scratchpad:
            return agent_scratchpad
        return (
            f"This was your previous work "
            f"(but I haven't seen any of it! I only see what "
            f"you return as final answer):\n{agent_scratchpad}"
        )

    def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
        """Parse a model completion into a ``(tool_name, tool_input)`` pair.

        Args:
            text: Raw completion produced by the model.

        Returns:
            ``("Final Answer", answer)`` when ``FINAL_ANSWER_ACTION`` occurs in
            the text (the answer is everything after its last occurrence,
            stripped). Otherwise, the ``action`` and ``action_input`` fields of
            the JSON object found in the first fenced code block.

        Raises:
            ValueError: If there is no final answer and no fenced block holding
                a JSON object with both required keys.
        """
        if FINAL_ANSWER_ACTION in text:
            return "Final Answer", text.split(FINAL_ANSWER_ACTION)[-1].strip()
        parts = text.split("```")
        # Need both an opening and a closing fence around at least one block.
        if len(parts) < 3:
            raise ValueError(f"Could not parse LLM output: {text}")
        action = parts[1].strip()
        # Chat models frequently tag the fence with a language (```json).
        if action.startswith("json"):
            action = action[len("json"):].strip()
        try:
            response = json.loads(action)
            return response["action"], response["action_input"]
        except (ValueError, KeyError, TypeError) as e:
            # ValueError covers json.JSONDecodeError; KeyError a missing key;
            # TypeError a JSON value that is not an object (list, str, ...).
            raise ValueError(f"Could not parse LLM output: {text}") from e

    @property
    def _stop(self) -> List[str]:
        """Stop sequences: halt generation when the model starts an observation."""
        return ["Observation:"]

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
    ) -> BasePromptTemplate:
        """Build the chat prompt (system instructions + human turn).

        Args:
            tools: Tools whose ``name: description`` lines are listed in the
                system message.
            prefix: Text placed before the tool list.
            suffix: Text placed after the format instructions.
            format_instructions: Template formatted with ``tool_names`` (a
                comma-separated list of tool names) before being embedded.
            input_variables: Prompt input variables; defaults to
                ``["input", "agent_scratchpad"]``.

        Returns:
            A ``ChatPromptTemplate`` with one system and one human message.
        """
        tool_strings = "\n".join(f"{tool.name}: {tool.description}" for tool in tools)
        tool_names = ", ".join(tool.name for tool in tools)
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            HumanMessagePromptTemplate.from_template("{input}\n\n{agent_scratchpad}"),
        ]
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools.

        Args:
            llm: Language model wrapped in the agent's ``LLMChain``.
            tools: Tools the agent may call; their names become
                ``allowed_tools``.
            callback_manager: Optional callback manager passed to the chain.
            prefix, suffix, format_instructions, input_variables: Forwarded
                to :meth:`create_prompt`.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A new instance of ``cls``.
        """
        cls._validate_tools(tools)
        prompt = cls.create_prompt(
            tools,
            prefix=prefix,
            suffix=suffix,
            format_instructions=format_instructions,
            input_variables=input_variables,
        )
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        return cls(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)

    @property
    def _agent_type(self) -> str:
        """Not defined for this agent; always raises ``ValueError``."""
        raise ValueError("ChatAgent does not define an agent type.")