Spaces:
Runtime error
Runtime error
from langchain.agents.mrkl.base import ZeroShotAgent | |
from langchain.schema import AgentAction, AgentFinish | |
from langchain.agents.agent import AgentExecutor | |
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union | |
from langchain.input import get_color_mapping | |
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS | |
from langchain.callbacks.base import BaseCallbackManager | |
from langchain.chains.llm import LLMChain | |
from langchain.llms.base import BaseLLM | |
from langchain.tools.base import BaseTool | |
from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX | |
from langchain.agents.agent_toolkits.sql.toolkit import SimpleSQLDatabaseToolkit | |
from langchain.prompts import PromptTemplate | |
from langchain.output_parsers.pydantic import SQLOutput | |
from langchain.output_parsers import PydanticOutputParser | |
class SQLAgentExecutor(AgentExecutor):
    """Agent executor that records the input of every ``query_sql_db`` tool call.

    The agent loop is the usual step-until-finish loop. Additionally, whenever
    the agent invokes the ``query_sql_db`` tool, the tool input is appended to
    ``sqlPad`` and the current contents of ``sqlPad`` are printed.
    """

    # Inputs passed to the ``query_sql_db`` tool, in the order they were issued.
    # NOTE(review): nothing in this class clears the list, so entries accumulate
    # across successive calls on the same executor — confirm that is intended.
    sqlPad: List[Any] = []
    # Not read or written anywhere in this class; kept for compatibility.
    state: str = ""

    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        """Run the agent loop, recording SQL tool inputs along the way.

        Steps the agent until it produces an ``AgentFinish`` or
        ``self._should_continue`` stops the loop.

        Args:
            inputs: Chain inputs. They are forwarded to every agent step and,
                if the loop is stopped, to ``return_stopped_response``.

        Returns:
            The result of ``self._return`` for the final agent output together
            with the intermediate ``(AgentAction, observation)`` steps.
        """
        self.agent.prepare_for_new_call()
        # Tool name -> tool, so the action chosen by the agent can be resolved.
        name_to_tool_map = {tool.name: tool for tool in self.tools}
        # Tool name -> logging colour (green is explicitly excluded).
        color_mapping = get_color_mapping(
            [tool.name for tool in self.tools], excluded_colors=["green"]
        )
        intermediate_steps: List[Tuple[AgentAction, str]] = []
        iterations = 0
        while self._should_continue(iterations):
            next_step_output = self._take_next_step(
                name_to_tool_map, color_mapping, inputs, intermediate_steps
            )
            if isinstance(next_step_output, AgentFinish):
                return self._return(next_step_output, intermediate_steps)

            agent_action, _observation = next_step_output
            if agent_action.tool == "query_sql_db":
                # Keep every SQL statement the agent sends to the database.
                self.sqlPad.append(agent_action.tool_input)
                print(self.sqlPad)
            intermediate_steps.append(next_step_output)
            iterations += 1

        # The loop ended without an AgentFinish (``_should_continue`` returned
        # False), so ask the agent for a response according to the
        # configured early-stopping method.
        output = self.agent.return_stopped_response(
            self.early_stopping_method, intermediate_steps, **inputs
        )
        return self._return(output, intermediate_steps)
class SQLZeroShotAgent(ZeroShotAgent):
    """Zero-shot agent whose prompt uses the SQL toolkit prefix/suffix.

    It also embeds the format instructions of a ``PydanticOutputParser`` for
    ``SQLOutput`` and attaches that parser to the returned prompt.
    """

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = SQL_PREFIX,
        suffix: str = SQL_SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
    ) -> PromptTemplate:
        """Build the agent prompt from the given tools.

        The template consists of the following sections, joined by blank lines:

        1. ``prefix``.
        2. One ``name: description`` line per tool.
        3. ``format_instructions``, with ``{tool_names}`` filled in.
        4. The output parser's format instructions.
        5. ``suffix``.

        Args:
            tools: Tools the agent may use.
            prefix: Text placed before the tool list.
            suffix: Text placed at the end of the template.
            format_instructions: Template text containing a ``{tool_names}``
                placeholder.
            input_variables: Prompt input variables. Defaults to
                ``["input", "agent_scratchpad"]``.

        Returns:
            A ``PromptTemplate`` with a ``PydanticOutputParser`` for
            ``SQLOutput`` attached as its output parser.
        """
        tool_strings = "\n".join(f"{tool.name}: {tool.description}" for tool in tools)
        tool_names = ", ".join(tool.name for tool in tools)
        format_instructions = format_instructions.format(tool_names=tool_names)
        output_parser = PydanticOutputParser(pydantic_object=SQLOutput)
        # Kept as a placeholder and filled through ``partial_variables`` below.
        # The parser's instructions are then substituted as a value, so any
        # braces they contain are not parsed as template fields.
        output_format_instructions = "{output_format_instructions}"
        template = "\n\n".join(
            [
                prefix,
                tool_strings,
                format_instructions,
                output_format_instructions,
                suffix,
            ]
        )
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return PromptTemplate(
            template=template,
            output_parser=output_parser,
            input_variables=input_variables,
            partial_variables={
                "output_format_instructions": output_parser.get_format_instructions()
            },
        )