jfeng1115's picture
init commit
58d33f0
raw
history blame
5.59 kB
"""An agent designed to hold a conversation in addition to using tools."""
from __future__ import annotations
import json
from typing import Any, List, Optional, Sequence, Tuple
from langchain.agents.agent import Agent
from langchain.agents.conversational_chat.prompt import (
FORMAT_INSTRUCTIONS,
PREFIX,
SUFFIX,
TEMPLATE_TOOL_RESPONSE,
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
)
from langchain.schema import (
AgentAction,
AIMessage,
BaseLanguageModel,
BaseMessage,
BaseOutputParser,
HumanMessage,
)
from langchain.tools.base import BaseTool
class AgentOutputParser(BaseOutputParser):
def get_format_instructions(self) -> str:
return FORMAT_INSTRUCTIONS
def parse(self, text: str) -> Any:
cleaned_output = text.strip()
if "```json" in cleaned_output:
_, cleaned_output = cleaned_output.split("```json")
if "```" in cleaned_output:
cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[len("```json") :]
if cleaned_output.startswith("```"):
cleaned_output = cleaned_output[len("```") :]
if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip()
response = json.loads(cleaned_output)
return {"action": response["action"], "action_input": response["action_input"]}
class ConversationalChatAgent(Agent):
"""An agent designed to hold a conversation in addition to using tools."""
output_parser: BaseOutputParser
@property
def _agent_type(self) -> str:
raise NotImplementedError
@property
def observation_prefix(self) -> str:
"""Prefix to append the observation with."""
return "Observation: "
@property
def llm_prefix(self) -> str:
"""Prefix to append the llm call with."""
return "Thought:"
@classmethod
def create_prompt(
cls,
tools: Sequence[BaseTool],
system_message: str = PREFIX,
human_message: str = SUFFIX,
input_variables: Optional[List[str]] = None,
output_parser: Optional[BaseOutputParser] = None,
) -> BasePromptTemplate:
tool_strings = "\n".join(
[f"> {tool.name}: {tool.description}" for tool in tools]
)
tool_names = ", ".join([tool.name for tool in tools])
_output_parser = output_parser or AgentOutputParser()
format_instructions = human_message.format(
format_instructions=_output_parser.get_format_instructions()
)
final_prompt = format_instructions.format(
tool_names=tool_names, tools=tool_strings
)
if input_variables is None:
input_variables = ["input", "chat_history", "agent_scratchpad"]
messages = [
SystemMessagePromptTemplate.from_template(system_message),
MessagesPlaceholder(variable_name="chat_history"),
HumanMessagePromptTemplate.from_template(final_prompt),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
def _extract_tool_and_input(self, llm_output: str) -> Optional[Tuple[str, str]]:
try:
response = self.output_parser.parse(llm_output)
return response["action"], response["action_input"]
except Exception:
raise ValueError(f"Could not parse LLM output: {llm_output}")
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> List[BaseMessage]:
"""Construct the scratchpad that lets the agent continue its thought process."""
thoughts: List[BaseMessage] = []
for action, observation in intermediate_steps:
thoughts.append(AIMessage(content=action.log))
human_message = HumanMessage(
content=TEMPLATE_TOOL_RESPONSE.format(observation=observation)
)
thoughts.append(human_message)
return thoughts
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
system_message: str = PREFIX,
human_message: str = SUFFIX,
input_variables: Optional[List[str]] = None,
output_parser: Optional[BaseOutputParser] = None,
**kwargs: Any,
) -> Agent:
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
_output_parser = output_parser or AgentOutputParser()
prompt = cls.create_prompt(
tools,
system_message=system_message,
human_message=human_message,
input_variables=input_variables,
output_parser=_output_parser,
)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
**kwargs,
)