Spaces:
Runtime error
Runtime error
File size: 3,794 Bytes
58d33f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.schema import AgentAction, AgentFinish
from langchain.agents.agent import AgentExecutor
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from langchain.input import get_color_mapping
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.tools.base import BaseTool
from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX
from langchain.agents.agent_toolkits.sql.toolkit import SimpleSQLDatabaseToolkit
from langchain.prompts import PromptTemplate
from langchain.output_parsers.pydantic import SQLOutput
from langchain.output_parsers import PydanticOutputParser
class SQLAgentExecutor(AgentExecutor):
    """A MRKL chain that uses SQL to store data.

    Runs the agent loop and, whenever the agent picks the ``query_sql_db``
    tool, records that action's tool input in ``sqlPad`` and prints the
    accumulated list.
    """

    # Tool inputs of every ``query_sql_db`` action this executor has seen.
    # NOTE(review): nothing in this class ever clears this list, so queries
    # from earlier calls accumulate across calls — confirm that is intended.
    sqlPad: List[Any] = []
    # Not read or written anywhere in this class; presumably used elsewhere —
    # verify before removing.
    state: str = ""

    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        """Run the agent loop until it finishes or ``_should_continue`` stops it.

        Args:
            inputs: Chain inputs, passed to the agent on every step and to
                ``return_stopped_response`` if the loop is cut short.

        Returns:
            The output dict produced by ``self._return`` for either the
            agent's ``AgentFinish`` or the stopped response.
        """
        self.agent.prepare_for_new_call()
        # Construct a mapping of tool name to tool for easy lookup.
        name_to_tool_map = {tool.name: tool for tool in self.tools}
        # We construct a mapping from each tool to a color, used for logging.
        color_mapping = get_color_mapping(
            [tool.name for tool in self.tools], excluded_colors=["green"]
        )
        intermediate_steps: List[Tuple[AgentAction, str]] = []
        # Track how many iterations the agent has gone through.
        iterations = 0
        # Enter the agent loop (until it returns something or hits the limit).
        while self._should_continue(iterations):
            next_step_output = self._take_next_step(
                name_to_tool_map, color_mapping, inputs, intermediate_steps
            )
            if isinstance(next_step_output, AgentFinish):
                return self._return(next_step_output, intermediate_steps)
            agent_action, _observation = next_step_output
            if agent_action.tool == "query_sql_db":
                # Record the SQL the agent sent to the database.
                self.sqlPad.append(agent_action.tool_input)
                print(self.sqlPad)
            intermediate_steps.append(next_step_output)
            iterations += 1
        # The loop ended without an AgentFinish: ask the agent for a
        # stopped response using the configured early-stopping method.
        output = self.agent.return_stopped_response(
            self.early_stopping_method, intermediate_steps, **inputs
        )
        return self._return(output, intermediate_steps)
class SQLZeroShotAgent(ZeroShotAgent):
    """Zero-shot agent whose prompt also carries Pydantic output-format instructions.

    Only ``create_prompt`` is overridden. It defaults to the SQL toolkit's
    prefix/suffix and appends the format instructions produced by a
    ``PydanticOutputParser`` for ``SQLOutput``.
    NOTE(review): ``SQLOutput`` is imported from
    ``langchain.output_parsers.pydantic`` at the top of this file — confirm
    that module actually defines it.
    """

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = SQL_PREFIX,
        suffix: str = SQL_SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
    ) -> PromptTemplate:
        """Assemble the agent prompt template.

        The template is these sections joined by blank lines: ``prefix``;
        one ``name: description`` line per tool; ``format_instructions``
        with ``{tool_names}`` filled in; an ``{output_format_instructions}``
        slot; and ``suffix``.

        Args:
            tools: Tools the agent may call.
            prefix: Text placed before the tool list.
            suffix: Text placed after all instructions.
            format_instructions: Template containing a ``{tool_names}`` slot.
            input_variables: Prompt variables; ``None`` means
                ``["input", "agent_scratchpad"]``.

        Returns:
            A ``PromptTemplate`` whose ``output_format_instructions`` slot is
            pre-filled as a partial variable and whose output parser is a
            ``PydanticOutputParser`` over ``SQLOutput``.
        """
        described_tools = "\n".join(
            f"{tool.name}: {tool.description}" for tool in tools
        )
        joined_names = ", ".join(tool.name for tool in tools)
        filled_instructions = format_instructions.format(tool_names=joined_names)
        parser = PydanticOutputParser(pydantic_object=SQLOutput)
        sections = [
            prefix,
            described_tools,
            filled_instructions,
            "{output_format_instructions}",
            suffix,
        ]
        variables = (
            ["input", "agent_scratchpad"]
            if input_variables is None
            else input_variables
        )
        return PromptTemplate(
            template="\n\n".join(sections),
            output_parser=parser,
            input_variables=variables,
            partial_variables={
                "output_format_instructions": parser.get_format_instructions()
            },
        )
|