Spaces:
Runtime error
Runtime error
| import re | |
| from typing import Any, List, Optional, Sequence, Tuple | |
| from langchain_core.agents import AgentAction | |
| from langchain_core.language_models import BaseLanguageModel | |
| from langchain_core.prompts import BasePromptTemplate | |
| from langchain_core.prompts.chat import ( | |
| ChatPromptTemplate, | |
| HumanMessagePromptTemplate, | |
| SystemMessagePromptTemplate, | |
| ) | |
| from langchain_core.pydantic_v1 import Field | |
| from langchain.agents.agent import Agent, AgentOutputParser | |
| from langchain.agents.structured_chat.output_parser import ( | |
| StructuredChatOutputParserWithRetries, | |
| ) | |
| from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX | |
| from langchain.callbacks.base import BaseCallbackManager | |
| from langchain.chains.llm import LLMChain | |
| from langchain.tools import BaseTool | |
| HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}" | |
| class StructuredChatAgent(Agent): | |
| """Structured Chat Agent.""" | |
| output_parser: AgentOutputParser = Field( | |
| default_factory=StructuredChatOutputParserWithRetries | |
| ) | |
| """Output parser for the agent.""" | |
| def observation_prefix(self) -> str: | |
| """Prefix to append the observation with.""" | |
| return "Observation: " | |
| def llm_prefix(self) -> str: | |
| """Prefix to append the llm call with.""" | |
| return "Thought:" | |
| def _construct_scratchpad( | |
| self, intermediate_steps: List[Tuple[AgentAction, str]] | |
| ) -> str: | |
| agent_scratchpad = super()._construct_scratchpad(intermediate_steps) | |
| if not isinstance(agent_scratchpad, str): | |
| raise ValueError("agent_scratchpad should be of type string.") | |
| if agent_scratchpad: | |
| return ( | |
| f"This was your previous work " | |
| f"(but I haven't seen any of it! I only see what " | |
| f"you return as final answer):\n{agent_scratchpad}" | |
| ) | |
| else: | |
| return agent_scratchpad | |
| def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: | |
| pass | |
| def _get_default_output_parser( | |
| cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any | |
| ) -> AgentOutputParser: | |
| return StructuredChatOutputParserWithRetries.from_llm(llm=llm) | |
| def _stop(self) -> List[str]: | |
| return ["Observation:"] | |
| def create_prompt( | |
| cls, | |
| tools: Sequence[BaseTool], | |
| prefix: str = PREFIX, | |
| suffix: str = SUFFIX, | |
| human_message_template: str = HUMAN_MESSAGE_TEMPLATE, | |
| format_instructions: str = FORMAT_INSTRUCTIONS, | |
| input_variables: Optional[List[str]] = None, | |
| memory_prompts: Optional[List[BasePromptTemplate]] = None, | |
| ) -> BasePromptTemplate: | |
| tool_strings = [] | |
| for tool in tools: | |
| args_schema = re.sub("}", "}}", re.sub("{", "{{", str(tool.args))) | |
| tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}") | |
| formatted_tools = "\n".join(tool_strings) | |
| tool_names = ", ".join([tool.name for tool in tools]) | |
| format_instructions = format_instructions.format(tool_names=tool_names) | |
| template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) | |
| if input_variables is None: | |
| input_variables = ["input", "agent_scratchpad"] | |
| _memory_prompts = memory_prompts or [] | |
| messages = [ | |
| SystemMessagePromptTemplate.from_template(template), | |
| *_memory_prompts, | |
| HumanMessagePromptTemplate.from_template(human_message_template), | |
| ] | |
| return ChatPromptTemplate(input_variables=input_variables, messages=messages) | |
| def from_llm_and_tools( | |
| cls, | |
| llm: BaseLanguageModel, | |
| tools: Sequence[BaseTool], | |
| callback_manager: Optional[BaseCallbackManager] = None, | |
| output_parser: Optional[AgentOutputParser] = None, | |
| prefix: str = PREFIX, | |
| suffix: str = SUFFIX, | |
| human_message_template: str = HUMAN_MESSAGE_TEMPLATE, | |
| format_instructions: str = FORMAT_INSTRUCTIONS, | |
| input_variables: Optional[List[str]] = None, | |
| memory_prompts: Optional[List[BasePromptTemplate]] = None, | |
| **kwargs: Any, | |
| ) -> Agent: | |
| """Construct an agent from an LLM and tools.""" | |
| cls._validate_tools(tools) | |
| prompt = cls.create_prompt( | |
| tools, | |
| prefix=prefix, | |
| suffix=suffix, | |
| human_message_template=human_message_template, | |
| format_instructions=format_instructions, | |
| input_variables=input_variables, | |
| memory_prompts=memory_prompts, | |
| ) | |
| llm_chain = LLMChain( | |
| llm=llm, | |
| prompt=prompt, | |
| callback_manager=callback_manager, | |
| ) | |
| tool_names = [tool.name for tool in tools] | |
| _output_parser = output_parser or cls._get_default_output_parser(llm=llm) | |
| return cls( | |
| llm_chain=llm_chain, | |
| allowed_tools=tool_names, | |
| output_parser=_output_parser, | |
| **kwargs, | |
| ) | |
| def _agent_type(self) -> str: | |
| raise ValueError | |