from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.schema import AgentAction, AgentFinish
from langchain.agents.agent import AgentExecutor
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from langchain.input import get_color_mapping
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.tools.base import BaseTool
from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX
from langchain.agents.agent_toolkits.sql.toolkit import SimpleSQLDatabaseToolkit
from langchain.prompts import PromptTemplate
from langchain.output_parsers.pydantic import SQLOutput
from langchain.output_parsers import PydanticOutputParser


class SQLAgentExecutor(AgentExecutor):
    """A MRKL chain that uses SQL to store data.

    In addition to running the agent loop, ``_call`` records the tool input
    of every ``query_sql_db`` action in ``sqlPad``.
    """

    # Tool inputs of every "query_sql_db" action seen by ``_call``.
    # ``_call`` only appends to this list; it never clears it.
    sqlPad = []
    # Not read or written by any code visible in this file.
    state = ""

    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        """Run the agent loop and return the final output.

        The loop keeps taking steps while ``self._should_continue`` allows
        it. It ends early as soon as a step yields an ``AgentFinish``. If
        the loop is exhausted without a finish, the agent's stopped response
        (per ``self.early_stopping_method``) is returned instead.

        Args:
            inputs: Chain inputs. They are forwarded to every step and, on
                early stop, to ``return_stopped_response`` as keyword
                arguments.

        Returns:
            Whatever ``self._return`` produces for the final output and the
            accumulated intermediate steps.
        """
        self.agent.prepare_for_new_call()
        # Construct a mapping of tool name to tool for easy lookup
        name_to_tool_map = {tool.name: tool for tool in self.tools}
        # We construct a mapping from each tool to a color, used for logging.
        color_mapping = get_color_mapping(
            [tool.name for tool in self.tools], excluded_colors=["green"]
        )
        intermediate_steps: List[Tuple[AgentAction, str]] = []
        # Let's start tracking the iterations the agent has gone through
        iterations = 0
        # We now enter the agent loop (until it returns something).
        while self._should_continue(iterations):
            next_step_output = self._take_next_step(
                name_to_tool_map, color_mapping, inputs, intermediate_steps
            )
            if isinstance(next_step_output, AgentFinish):
                return self._return(next_step_output, intermediate_steps)

            # Not finished: the step is an (action, observation) pair. Only
            # the action is inspected here; the observation travels along
            # inside ``next_step_output`` when it is appended below.
            agent_action, _ = next_step_output
            if agent_action.tool == "query_sql_db":
                self.sqlPad.append(agent_action.tool_input)
                # NOTE(review): writes straight to stdout on every SQL query
                # step; looks like leftover debugging output - confirm before
                # relying on it.
                print(self.sqlPad)
            intermediate_steps.append(next_step_output)
            iterations += 1

        output = self.agent.return_stopped_response(
            self.early_stopping_method, intermediate_steps, **inputs
        )
        return self._return(output, intermediate_steps)


class SQLZeroShotAgent(ZeroShotAgent):
    """Zero-shot agent whose prompt carries ``SQLOutput`` format instructions."""

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = SQL_PREFIX,
        suffix: str = SQL_SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
    ) -> PromptTemplate:
        """Build the prompt template for the SQL agent.

        The template is the following sections joined by blank lines:
        ``prefix``, one ``name: description`` line per tool, the format
        instructions (with ``tool_names`` filled in), a placeholder for the
        output parser's format instructions, and ``suffix``.

        Args:
            tools: Tools whose names and descriptions are listed in the prompt.
            prefix: Text placed before the tool list.
            suffix: Text placed at the end of the template.
            format_instructions: Format string with a ``tool_names`` field.
            input_variables: Template input variables; defaults to
                ``["input", "agent_scratchpad"]`` when ``None``.

        Returns:
            A ``PromptTemplate`` with a ``PydanticOutputParser`` for
            ``SQLOutput`` attached as its output parser.
        """
        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
        tool_names = ", ".join([tool.name for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        output_parser = PydanticOutputParser(pydantic_object=SQLOutput)
        # The parser's instructions are supplied through ``partial_variables``
        # instead of being pasted into the template text; the template only
        # holds this placeholder. Presumably this keeps any literal braces in
        # those instructions from being read as template fields - confirm.
        output_format_instructions = "{output_format_instructions}"
        template = "\n\n".join(
            [
                prefix,
                tool_strings,
                format_instructions,
                output_format_instructions,
                suffix,
            ]
        )
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return PromptTemplate(
            template=template,
            output_parser=output_parser,
            input_variables=input_variables,
            partial_variables={
                "output_format_instructions": output_parser.get_format_instructions()
            },
        )