Spaces:
Runtime error
Runtime error
"""An agent designed to hold a conversation in addition to using tools.""" | |
from __future__ import annotations | |
import json | |
from typing import Any, List, Optional, Sequence, Tuple | |
from langchain.agents.agent import Agent | |
from langchain.agents.conversational_chat.prompt import ( | |
FORMAT_INSTRUCTIONS, | |
PREFIX, | |
SUFFIX, | |
TEMPLATE_TOOL_RESPONSE, | |
) | |
from langchain.callbacks.base import BaseCallbackManager | |
from langchain.chains import LLMChain | |
from langchain.prompts.base import BasePromptTemplate | |
from langchain.prompts.chat import ( | |
ChatPromptTemplate, | |
HumanMessagePromptTemplate, | |
MessagesPlaceholder, | |
SystemMessagePromptTemplate, | |
) | |
from langchain.schema import ( | |
AgentAction, | |
AIMessage, | |
BaseLanguageModel, | |
BaseMessage, | |
BaseOutputParser, | |
HumanMessage, | |
) | |
from langchain.tools.base import BaseTool | |
class AgentOutputParser(BaseOutputParser):
    """Parse an LLM reply holding a JSON ``{"action": ..., "action_input": ...}`` blob.

    The JSON may be wrapped in a markdown code fence (```` ```json ```` or a bare
    ```` ``` ````); the fence is removed before decoding.
    """

    def get_format_instructions(self) -> str:
        """Return the ``FORMAT_INSTRUCTIONS`` text shown to the LLM."""
        return FORMAT_INSTRUCTIONS

    def parse(self, text: str) -> Any:
        """Extract the chosen action and its input from ``text``.

        Args:
            text: Raw LLM output, optionally wrapped in a markdown code fence.

        Returns:
            A dict with exactly the keys ``"action"`` and ``"action_input"``.

        Raises:
            json.JSONDecodeError: if the unfenced text is not valid JSON.
            KeyError: if the decoded object lacks ``action`` or ``action_input``.
        """
        cleaned_output = text.strip()
        # partition() instead of split-and-unpack: an opening+closing bare fence,
        # or more than one ```json fence, yields >2 pieces and made the previous
        # two-name unpacking raise ValueError.
        if "```json" in cleaned_output:
            # Keep only what follows the first ```json fence.
            _, _, cleaned_output = cleaned_output.partition("```json")
        elif cleaned_output.startswith("```"):
            # Bare opening fence with no language tag.
            cleaned_output = cleaned_output[len("```"):]
        # Drop the closing fence and anything after it (no-op if there is none).
        cleaned_output, _, _ = cleaned_output.partition("```")
        response = json.loads(cleaned_output.strip())
        return {"action": response["action"], "action_input": response["action_input"]}
class ConversationalChatAgent(Agent):
    """An agent designed to hold a conversation in addition to using tools."""

    # Parser used to turn the LLM's reply into an (action, action_input) pair.
    output_parser: BaseOutputParser

    @property
    def _agent_type(self) -> str:
        """Agent-type identifier; not implemented for this agent."""
        raise NotImplementedError

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        return "Observation: "

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the llm call with."""
        return "Thought:"

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        system_message: str = PREFIX,
        human_message: str = SUFFIX,
        input_variables: Optional[List[str]] = None,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> BasePromptTemplate:
        """Build the chat prompt: system message, history, human turn, scratchpad.

        Args:
            tools: Tools whose names and descriptions are listed in the prompt.
            system_message: Template for the system message.
            human_message: Template for the human turn. It is ``str.format``-ed
                twice (first with ``format_instructions``, then with
                ``tool_names``/``tools``) before becoming a message template, so
                braces meant to survive into the final template must be escaped
                for both passes.
            input_variables: Prompt input variables; defaults to
                ``["input", "chat_history", "agent_scratchpad"]``.
            output_parser: Supplies the format instructions; defaults to
                ``AgentOutputParser()``.

        Returns:
            A ``ChatPromptTemplate`` over the assembled messages.
        """
        tool_strings = "\n".join(
            [f"> {tool.name}: {tool.description}" for tool in tools]
        )
        tool_names = ", ".join([tool.name for tool in tools])
        _output_parser = output_parser or AgentOutputParser()
        # First pass: splice in the parser's format instructions.
        human_with_instructions = human_message.format(
            format_instructions=_output_parser.get_format_instructions()
        )
        # Second pass: fill tool names/descriptions (they may appear inside the
        # format instructions inserted above).
        final_prompt = human_with_instructions.format(
            tool_names=tool_names, tools=tool_strings
        )
        if input_variables is None:
            input_variables = ["input", "chat_history", "agent_scratchpad"]
        messages = [
            SystemMessagePromptTemplate.from_template(system_message),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanMessagePromptTemplate.from_template(final_prompt),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    def _extract_tool_and_input(self, llm_output: str) -> Optional[Tuple[str, str]]:
        """Return ``(action, action_input)`` parsed from ``llm_output``.

        Raises:
            ValueError: if the output parser fails for any reason; the original
                exception is chained as ``__cause__``.
        """
        try:
            response = self.output_parser.parse(llm_output)
            return response["action"], response["action_input"]
        except Exception as e:
            # Chain the cause so the underlying parse/key error stays visible.
            raise ValueError(f"Could not parse LLM output: {llm_output}") from e

    def _construct_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> List[BaseMessage]:
        """Construct the scratchpad that lets the agent continue its thought process.

        Each step becomes two messages: the agent's own log as an ``AIMessage``,
        then the tool observation wrapped in ``TEMPLATE_TOOL_RESPONSE`` as a
        ``HumanMessage``.
        """
        thoughts: List[BaseMessage] = []
        for action, observation in intermediate_steps:
            thoughts.append(AIMessage(content=action.log))
            human_message = HumanMessage(
                content=TEMPLATE_TOOL_RESPONSE.format(observation=observation)
            )
            thoughts.append(human_message)
        return thoughts

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        system_message: str = PREFIX,
        human_message: str = SUFFIX,
        input_variables: Optional[List[str]] = None,
        output_parser: Optional[BaseOutputParser] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools.

        The same output parser (``AgentOutputParser()`` when none is given) is
        used both to render the prompt's format instructions and to parse the
        LLM's replies. Extra ``kwargs`` are forwarded to the agent constructor.
        """
        cls._validate_tools(tools)
        _output_parser = output_parser or AgentOutputParser()
        prompt = cls.create_prompt(
            tools,
            system_message=system_message,
            human_message=human_message,
            input_variables=input_variables,
            output_parser=_output_parser,
        )
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )