"""An agent designed to hold a conversation in addition to using tools."""
from __future__ import annotations

import json
from typing import Any, List, Optional, Sequence, Tuple

from langchain.agents.agent import Agent
from langchain.agents.conversational_chat.prompt import (
    FORMAT_INSTRUCTIONS,
    PREFIX,
    SUFFIX,
    TEMPLATE_TOOL_RESPONSE,
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
)
from langchain.schema import (
    AgentAction,
    AIMessage,
    BaseLanguageModel,
    BaseMessage,
    BaseOutputParser,
    HumanMessage,
)
from langchain.tools.base import BaseTool


class AgentOutputParser(BaseOutputParser):
    """Parse the LLM's JSON reply into an action / action-input mapping."""

    def get_format_instructions(self) -> str:
        """Return the format instructions embedded into the prompt."""
        return FORMAT_INSTRUCTIONS

    def parse(self, text: str) -> Any:
        """Extract the JSON payload from ``text`` and decode it.

        The payload may be bare JSON or wrapped in a Markdown code fence,
        either tagged as json or a plain triple-backtick fence. Text before
        a json-tagged fence and anything after the closing fence is
        discarded.

        Args:
            text: Raw LLM output.

        Returns:
            A dict with exactly the keys ``"action"`` and ``"action_input"``
            taken from the decoded JSON object.

        Raises:
            json.JSONDecodeError: If the extracted payload is not valid JSON.
            KeyError: If the decoded object lacks ``"action"`` or
                ``"action_input"``.
        """
        cleaned_output = text.strip()
        if "```json" in cleaned_output:
            # Keep what follows the FIRST json fence. maxsplit=1 avoids the
            # "too many values to unpack" failure on repeated fences.
            cleaned_output = cleaned_output.split("```json", 1)[1]
        elif cleaned_output.startswith("```"):
            # Plain fence with no language tag: drop the opening marker so
            # the closing-fence split below keeps the payload, not "".
            cleaned_output = cleaned_output[len("```") :]
        if "```" in cleaned_output:
            # Cut at the first remaining fence; trailing commentary or
            # additional fenced blocks are ignored.
            cleaned_output = cleaned_output.split("```", 1)[0]
        cleaned_output = cleaned_output.strip()
        response = json.loads(cleaned_output)
        return {"action": response["action"], "action_input": response["action_input"]}


class ConversationalChatAgent(Agent):
    """An agent designed to hold a conversation in addition to using tools."""

    # Parser used by _extract_tool_and_input to decode the LLM output.
    output_parser: BaseOutputParser

    @property
    def _agent_type(self) -> str:
        """Agent type identifier; not implemented for this agent.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        return "Observation: "

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the llm call with."""
        return "Thought:"

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        system_message: str = PREFIX,
        human_message: str = SUFFIX,
        input_variables: Optional[List[str]] = None,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> BasePromptTemplate:
        """Build the chat prompt template used by the agent.

        Args:
            tools: Tools whose names and descriptions are listed in the
                human message.
            system_message: Template text for the system message.
            human_message: Template text for the human message. It is
                formatted twice: first with ``format_instructions``, then
                with ``tool_names`` and ``tools``.
            input_variables: Input variables of the resulting prompt;
                defaults to ``["input", "chat_history", "agent_scratchpad"]``.
            output_parser: Parser supplying the format instructions; a
                default ``AgentOutputParser`` is used when omitted.

        Returns:
            A ``ChatPromptTemplate`` made of the system message, the chat
            history placeholder, the human message and the agent scratchpad
            placeholder, in that order.
        """
        tool_strings = "\n".join(
            f"> {tool.name}: {tool.description}" for tool in tools
        )
        tool_names = ", ".join(tool.name for tool in tools)
        _output_parser = output_parser or AgentOutputParser()
        # First pass: inject the parser's format instructions.
        format_instructions = human_message.format(
            format_instructions=_output_parser.get_format_instructions()
        )
        # Second pass: fill in the tool listing on the result of the first.
        final_prompt = format_instructions.format(
            tool_names=tool_names, tools=tool_strings
        )
        if input_variables is None:
            input_variables = ["input", "chat_history", "agent_scratchpad"]
        messages = [
            SystemMessagePromptTemplate.from_template(system_message),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanMessagePromptTemplate.from_template(final_prompt),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    def _extract_tool_and_input(self, llm_output: str) -> Optional[Tuple[str, str]]:
        """Parse ``llm_output`` into an ``(action, action_input)`` pair.

        Args:
            llm_output: Raw text produced by the LLM.

        Returns:
            The ``"action"`` and ``"action_input"`` values from the parser.

        Raises:
            ValueError: If the parser fails for any reason; the original
                exception is attached as ``__cause__``.
        """
        try:
            response = self.output_parser.parse(llm_output)
            return response["action"], response["action_input"]
        except Exception as err:
            # Chain the cause so the underlying decode/key error stays
            # visible in the traceback.
            raise ValueError(f"Could not parse LLM output: {llm_output}") from err

    def _construct_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> List[BaseMessage]:
        """Construct the scratchpad that lets the agent continue its thought process.

        Each step contributes two messages: an ``AIMessage`` carrying the
        action's log, followed by a ``HumanMessage`` with the observation
        rendered through ``TEMPLATE_TOOL_RESPONSE``.
        """
        thoughts: List[BaseMessage] = []
        for action, observation in intermediate_steps:
            thoughts.append(AIMessage(content=action.log))
            human_message = HumanMessage(
                content=TEMPLATE_TOOL_RESPONSE.format(observation=observation)
            )
            thoughts.append(human_message)
        return thoughts

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        system_message: str = PREFIX,
        human_message: str = SUFFIX,
        input_variables: Optional[List[str]] = None,
        output_parser: Optional[BaseOutputParser] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools.

        Args:
            llm: Language model driving the agent.
            tools: Tools the agent may use; checked with
                ``cls._validate_tools`` before anything else is built.
            callback_manager: Optional callback manager passed to the
                ``LLMChain``.
            system_message: Forwarded to ``create_prompt``.
            human_message: Forwarded to ``create_prompt``.
            input_variables: Forwarded to ``create_prompt``.
            output_parser: Parser used both for the prompt's format
                instructions and by the agent; defaults to
                ``AgentOutputParser``.
            **kwargs: Extra keyword arguments passed to the class
                constructor.

        Returns:
            An instance of ``cls`` wired to the LLM chain, tool names and
            output parser.
        """
        cls._validate_tools(tools)
        # Resolve once so the prompt and the agent share the same parser.
        _output_parser = output_parser or AgentOutputParser()
        prompt = cls.create_prompt(
            tools,
            system_message=system_message,
            human_message=human_message,
            input_variables=input_variables,
            output_parser=_output_parser,
        )
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )