Zulelee's picture
Upload 254 files
5e9cd1d verified
raw
history blame
No virus
8.23 kB
"""
This file is a version of the original glm3_agent.py file from the langchain repo, modified for ChatGLM3-6B.
"""
from __future__ import annotations
import json
import logging
from typing import Any, List, Sequence, Tuple, Optional, Union
from pydantic.schema import model_schema
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents.agent import Agent
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain.agents.agent import AgentOutputParser
from langchain.output_parsers import OutputFixingParser
from langchain.pydantic_v1 import Field
from langchain.schema import AgentAction, AgentFinish, OutputParserException, BasePromptTemplate
from langchain.agents.agent import AgentExecutor
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.base import BaseTool
# Default value of the ``human_message_template`` argument of
# StructuredGLM3ChatAgent.from_llm_and_tools: the user input, a blank line,
# then the agent scratchpad.
HUMAN_MESSAGE_TEMPLATE = "{input}" "\n\n" "{agent_scratchpad}"
# Logger named after this module.
logger = logging.getLogger(__name__)
class StructuredChatOutputParserWithRetries(AgentOutputParser):
    """Output parser for the structured chat agent, adapted to ChatGLM3-6B replies.

    A reply containing ``tool_call`` is treated as a tool invocation. The text
    before the first ``` fence is the tool name, and the ``key=value`` pairs
    between the first "(" and the last ")" become the tool arguments. Any other
    reply is treated as the final answer. In both cases the result is
    re-encoded as a fenced ``Action:`` JSON block. That block is handed to
    ``output_fixing_parser`` when it is set, otherwise to ``base_parser``.
    """

    # The base parser to use.
    base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
    # The output fixing parser to use; takes precedence over base_parser when set.
    output_fixing_parser: Optional[OutputFixingParser] = None

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Parse a raw ChatGLM3 completion into an agent action or final answer.

        Args:
            text: Raw model output.

        Returns:
            Whatever the delegated parser returns for the re-encoded action:
            an ``AgentAction`` for a tool call, or an ``AgentFinish`` for
            "Final Answer".

        Raises:
            OutputParserException: if the delegated parser fails.
        """
        special_tokens = ["Action:", "<|observation|>"]
        # Keep only the text before the earliest marker (all of it if none occurs).
        first_index = min(
            text.find(token) if token in text else len(text)
            for token in special_tokens
        )
        text = text[:first_index]

        if "tool_call" in text:
            action_json = self._tool_call_to_json(text)
        else:
            action_json = {"action": "Final Answer", "action_input": text}

        # json.dumps leaves "`" unescaped. A "```" inside a value (e.g. a
        # Markdown code block in the answer) would close the fenced block early
        # when the base parser extracts it, producing invalid JSON. "\u0060" is
        # the JSON escape for "`" and decodes back to the same character.
        action_payload = json.dumps(action_json, ensure_ascii=False).replace(
            "`", "\\u0060"
        )
        action_str = f"\nAction:\n```\n{action_payload}\n```"
        try:
            if self.output_fixing_parser is not None:
                return self.output_fixing_parser.parse(action_str)
            return self.base_parser.parse(action_str)
        except Exception as e:
            raise OutputParserException(f"Could not parse LLM output: {text}") from e

    @staticmethod
    def _tool_call_to_json(text: str) -> dict:
        """Build ``{"action": name, "action_input": {arg: value}}`` from a tool-call reply.

        Argument values are returned as strings with surrounding quotes
        stripped. Pairs are split on ",", so values containing commas are
        still not supported.
        """
        name_end = text.find("```")
        if name_end == -1:
            # No code fence: take everything before the tool_call expression,
            # instead of text[:-1] (which silently drops the last character).
            name_end = text.find("tool_call")
        action = text[:name_end].strip()

        params_str = text[text.find("(") + 1:text.rfind(")")]
        params = {}
        for param in params_str.split(","):
            if "=" not in param:
                continue
            # maxsplit=1 keeps any "=" that belongs to the value itself.
            key, value = param.split("=", 1)
            params[key.strip()] = value.strip().strip("'\"")
        return {"action": action, "action_input": params}

    @property
    def _type(self) -> str:
        return "structured_chat_ChatGLM3_6b_with_retries"
class StructuredGLM3ChatAgent(Agent):
    """Structured chat agent adapted to ChatGLM3-6B.

    ``create_prompt`` renders the tool list, user input and scratchpad into a
    single system message. Generation stops at the ``<|observation|>`` token.
    """

    # Parses raw model output into AgentAction / AgentFinish.
    output_parser: AgentOutputParser = Field(
        default_factory=StructuredChatOutputParserWithRetries
    )

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the ChatGLM3-6B observation with."""
        return "Observation:"

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the llm call with."""
        return "Thought:"

    def _construct_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> str:
        """Render previous steps, prefixed with a note that the user has not seen them.

        Returns the empty scratchpad unchanged when there are no previous steps.

        Raises:
            ValueError: if the base implementation does not return a string.
        """
        agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
        if not isinstance(agent_scratchpad, str):
            raise ValueError("agent_scratchpad should be of type string.")
        if not agent_scratchpad:
            return agent_scratchpad
        return (
            "This was your previous work "
            "(but I haven't seen any of it! I only see what "
            f"you return as final answer):\n{agent_scratchpad}"
        )

    @classmethod
    def _get_default_output_parser(
        cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
    ) -> AgentOutputParser:
        """Return the ChatGLM3 output parser used when none is supplied."""
        # NOTE(review): the parser declares no ``llm`` field, so this argument
        # presumably has no effect (depends on the pydantic ``extra`` config).
        return StructuredChatOutputParserWithRetries(llm=llm)

    @property
    def _stop(self) -> List[str]:
        # Stop generation at ChatGLM3's "<|observation|>" token (presumably
        # emitted when the model waits for a tool result).
        return ["<|observation|>"]

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prompt: Optional[str] = None,
        input_variables: Optional[List[str]] = None,
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        """Build the chat prompt by filling ``prompt`` with the tool descriptions.

        Args:
            tools: Tools to describe. Each contributes one
                ``"name: description, args: {...}"`` line.
            prompt: ``str.format`` template. It may reference ``tool_names``,
                ``tools``, ``history`` (always filled with "None"), ``input`` and
                ``agent_scratchpad``. The last two are re-emitted as
                ``{input}``/``{agent_scratchpad}`` variables of the result.
            input_variables: Prompt input variables; defaults to
                ``["input", "agent_scratchpad"]``.
            memory_prompts: Extra message templates appended after the system
                message.

        Returns:
            A ``ChatPromptTemplate`` with one system message plus ``memory_prompts``.

        Raises:
            ValueError: if ``prompt`` is None.
        """
        if prompt is None:
            raise ValueError(
                "StructuredGLM3ChatAgent.create_prompt requires a prompt template string"
            )
        tool_names = []
        tool_lines = []
        for tool in tools:
            tool_schema = model_schema(tool.args_schema) if tool.args_schema else {}
            tool_names.append(tool.name)
            tool_lines.append(
                f"{tool.name}: {tool.description}, "
                f"args: {tool_schema.get('properties', {})}"
            )
        # The filled-in text is parsed again as a template by from_template()
        # below. Literal braces from the schema reprs must therefore be doubled.
        # NOTE(review): the purpose of backslash-escaping single quotes is
        # unclear; it is kept as-is.
        formatted_tools = (
            "\n".join(tool_lines)
            .replace("'", "\\'")
            .replace("{", "{{")
            .replace("}", "}}")
        )
        template = prompt.format(
            tool_names=tool_names,
            tools=formatted_tools,
            history="None",
            input="{input}",
            agent_scratchpad="{agent_scratchpad}",
        )
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *(memory_prompts or []),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        prompt: Optional[str] = None,
        callback_manager: Optional[BaseCallbackManager] = None,
        output_parser: Optional[AgentOutputParser] = None,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        input_variables: Optional[List[str]] = None,
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools.

        Args:
            llm: Language model wrapped in the agent's ``LLMChain``.
            tools: Tools the agent may call (validated, and used as ``allowed_tools``).
            prompt: Prompt template string; see ``create_prompt``.
            callback_manager: Passed to the ``LLMChain``.
            output_parser: Parser to use; defaults to ``_get_default_output_parser``.
            human_message_template: Accepted for signature compatibility but
                not used. ``create_prompt`` builds only a system message.
            input_variables: Forwarded to ``create_prompt``.
            memory_prompts: Forwarded to ``create_prompt``.
            **kwargs: Forwarded to the agent constructor.
        """
        cls._validate_tools(tools)
        chat_prompt = cls.create_prompt(
            tools,
            prompt=prompt,
            input_variables=input_variables,
            memory_prompts=memory_prompts,
        )
        llm_chain = LLMChain(
            llm=llm,
            prompt=chat_prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        _output_parser = output_parser or cls._get_default_output_parser(llm=llm)
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )

    @property
    def _agent_type(self) -> str:
        # This agent has no registered agent type. LangChain's base-class
        # contract is NotImplementedError, which its dict() catches (treating
        # the type as absent). A ValueError would escape dict() instead.
        raise NotImplementedError
def initialize_glm3_agent(
    tools: Sequence[BaseTool],
    llm: BaseLanguageModel,
    prompt: str = None,
    memory: Optional[ConversationBufferWindowMemory] = None,
    agent_kwargs: Optional[dict] = None,
    *,
    tags: Optional[Sequence[str]] = None,
    **kwargs: Any,
) -> AgentExecutor:
    """Build an ``AgentExecutor`` driving a ``StructuredGLM3ChatAgent``.

    Args:
        tools: Tools available to the agent. They are given to both the agent
            and the executor.
        llm: Language model backing the agent.
        prompt: Prompt template string forwarded to ``from_llm_and_tools``.
        memory: Optional conversation memory attached to the executor.
        agent_kwargs: Extra keyword arguments for ``from_llm_and_tools``.
        tags: Optional executor tags, copied into a new list.
        **kwargs: Forwarded to ``AgentExecutor.from_agent_and_tools``.

    Returns:
        The configured executor.
    """
    executor_tags = [] if not tags else list(tags)
    extra_agent_kwargs = agent_kwargs if agent_kwargs else {}
    glm3_agent = StructuredGLM3ChatAgent.from_llm_and_tools(
        llm=llm, tools=tools, prompt=prompt, **extra_agent_kwargs
    )
    return AgentExecutor.from_agent_and_tools(
        agent=glm3_agent, tools=tools, memory=memory, tags=executor_tags, **kwargs
    )