jfeng1115's picture
init commit
58d33f0
raw
history blame contribute delete
No virus
3.79 kB
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.schema import AgentAction, AgentFinish
from langchain.agents.agent import AgentExecutor
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from langchain.input import get_color_mapping
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.tools.base import BaseTool
from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX
from langchain.agents.agent_toolkits.sql.toolkit import SimpleSQLDatabaseToolkit
from langchain.prompts import PromptTemplate
from langchain.output_parsers.pydantic import SQLOutput
from langchain.output_parsers import PydanticOutputParser
class SQLAgentExecutor(AgentExecutor):
    """Agent executor that records the SQL queries the agent runs.

    Runs the same step loop as ``AgentExecutor._call``. In addition, each
    time the agent invokes the ``query_sql_db`` tool, that action's
    ``tool_input`` is appended to ``sqlPad`` and the whole list is printed.
    """

    # Inputs passed to the ``query_sql_db`` tool, in call order. Entries are
    # appended on every call and never cleared in this class, so they
    # accumulate across runs on the same executor.
    # NOTE(review): this list literal is a class-level default. If the base
    # class is a pydantic model it is copied per instance; otherwise it would
    # be shared by all instances -- confirm.
    sqlPad = []
    # Not read or written anywhere in this class; presumably used by
    # callers -- verify before removing.
    state = ""

    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        """Run the agent loop until it finishes or ``_should_continue`` stops it.

        Args:
            inputs: Chain inputs, forwarded to every agent step and to
                ``return_stopped_response``.

        Returns:
            ``self._return(...)`` applied either to the agent's
            ``AgentFinish`` or to the early-stopping response.
        """
        self.agent.prepare_for_new_call()
        # Tool name -> tool, so the tool named by an action can be looked up.
        name_to_tool_map = {tool.name: tool for tool in self.tools}
        # Tool name -> color used for logging; "green" is excluded.
        color_mapping = get_color_mapping(
            [tool.name for tool in self.tools], excluded_colors=["green"]
        )
        intermediate_steps: List[Tuple[AgentAction, str]] = []
        iterations = 0
        while self._should_continue(iterations):
            next_step_output = self._take_next_step(
                name_to_tool_map, color_mapping, inputs, intermediate_steps
            )
            if isinstance(next_step_output, AgentFinish):
                return self._return(next_step_output, intermediate_steps)
            agent_action, _observation = next_step_output
            # Record every SQL query the agent sends to the database tool.
            if agent_action.tool == "query_sql_db":
                self.sqlPad.append(agent_action.tool_input)
                print(self.sqlPad)
            intermediate_steps.append(next_step_output)
            iterations += 1
        # The loop ended without an AgentFinish (``_should_continue`` returned
        # False), so ask the agent for a stopped response.
        output = self.agent.return_stopped_response(
            self.early_stopping_method, intermediate_steps, **inputs
        )
        return self._return(output, intermediate_steps)
class SQLZeroShotAgent(ZeroShotAgent):
    """Zero-shot agent whose prompt also requests output shaped like ``SQLOutput``."""

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = SQL_PREFIX,
        suffix: str = SQL_SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
    ) -> PromptTemplate:
        """Assemble the agent's prompt template.

        The template consists of these sections, separated by blank lines:

        1. ``prefix``;
        2. one ``name: description`` line per tool;
        3. ``format_instructions`` with ``{tool_names}`` filled in;
        4. an ``{output_format_instructions}`` placeholder;
        5. ``suffix``.

        The placeholder is pre-filled through ``partial_variables`` with the
        format instructions of a ``PydanticOutputParser`` for ``SQLOutput``.
        That same parser is attached to the returned template.

        Args:
            tools: Tools whose names and descriptions are listed in the prompt.
            prefix: Text placed before the tool list.
            suffix: Text placed after the output-format section.
            format_instructions: Instructions containing a ``{tool_names}``
                placeholder.
            input_variables: Template input names; defaults to
                ``["input", "agent_scratchpad"]`` when None.

        Returns:
            The configured ``PromptTemplate``.
        """
        descriptions = "\n".join(f"{t.name}: {t.description}" for t in tools)
        names = ", ".join(t.name for t in tools)
        filled_instructions = format_instructions.format(tool_names=names)
        parser = PydanticOutputParser(pydantic_object=SQLOutput)
        sections = (
            prefix,
            descriptions,
            filled_instructions,
            # Left as a placeholder; filled by partial_variables below.
            "{output_format_instructions}",
            suffix,
        )
        variables = (
            ["input", "agent_scratchpad"] if input_variables is None else input_variables
        )
        partials = {"output_format_instructions": parser.get_format_instructions()}
        return PromptTemplate(
            template="\n\n".join(sections),
            output_parser=parser,
            input_variables=variables,
            partial_variables=partials,
        )