File size: 3,794 Bytes
58d33f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import logging
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.sql.prompt import SQL_PREFIX, SQL_SUFFIX
from langchain.agents.agent_toolkits.sql.toolkit import SimpleSQLDatabaseToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping
from langchain.llms.base import BaseLLM
from langchain.output_parsers import PydanticOutputParser
from langchain.output_parsers.pydantic import SQLOutput
from langchain.prompts import PromptTemplate
from langchain.schema import AgentAction, AgentFinish
from langchain.tools.base import BaseTool


class SQLAgentExecutor(AgentExecutor):
    """Agent executor that records the SQL queries its agent issues.

    Every action addressed to the ``query_sql_db`` tool has its
    ``tool_input`` appended to ``sqlPad``. The loop runs agent steps until
    the agent finishes or ``_should_continue`` stops it. In the latter case
    it asks the agent for a stopped response.
    """

    # Inputs sent to the ``query_sql_db`` tool, in the order they were issued.
    # Never cleared by this class, so it accumulates across calls on the same
    # instance.
    # NOTE(review): a mutable class-level default; this is only safe if the
    # AgentExecutor base copies field defaults per instance (pydantic v1 does)
    # — confirm for the installed version.
    sqlPad = []
    # Not read or written anywhere in this class; kept so existing callers
    # that set or read it keep working.
    state = ""

    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        """Run the agent loop on ``inputs``, recording SQL tool inputs.

        Args:
            inputs: Chain inputs, forwarded to every agent step and to
                ``return_stopped_response``.

        Returns:
            The output dict produced by ``self._return``.
        """
        self.agent.prepare_for_new_call()
        # Construct a mapping of tool name to tool for easy lookup
        name_to_tool_map = {tool.name: tool for tool in self.tools}
        # We construct a mapping from each tool to a color, used for logging.
        color_mapping = get_color_mapping(
            [tool.name for tool in self.tools], excluded_colors=["green"]
        )
        intermediate_steps: List[Tuple[AgentAction, str]] = []
        # Let's start tracking the iterations the agent has gone through
        iterations = 0
        # NOTE(review): this loop does not check tools' ``return_direct``;
        # confirm whether the installed AgentExecutor expects that.
        while self._should_continue(iterations):
            next_step_output = self._take_next_step(
                name_to_tool_map, color_mapping, inputs, intermediate_steps
            )
            if isinstance(next_step_output, AgentFinish):
                return self._return(next_step_output, intermediate_steps)

            # Unpacking (rather than indexing) keeps the original check that
            # a non-finish step is an (action, observation) pair.
            agent_action, _observation = next_step_output
            if agent_action.tool == "query_sql_db":
                self.sqlPad.append(agent_action.tool_input)
                # Debug trace goes through logging instead of an
                # unconditional print to stdout, so callers control output.
                logging.getLogger(__name__).debug(
                    "SQL queries recorded so far: %s", self.sqlPad
                )

            intermediate_steps.append(next_step_output)
            iterations += 1
        output = self.agent.return_stopped_response(
            self.early_stopping_method, intermediate_steps, **inputs
        )
        return self._return(output, intermediate_steps)


class SQLZeroShotAgent(ZeroShotAgent):
    """Zero-shot MRKL agent whose prompt defaults to the SQL toolkit texts.

    The generated prompt also embeds the format instructions of a
    ``PydanticOutputParser`` for ``SQLOutput`` and carries that parser.
    """

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = SQL_PREFIX,
        suffix: str = SQL_SUFFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
    ) -> PromptTemplate:
        """Assemble the agent prompt from ``tools`` and the given texts.

        The template has five sections joined by blank lines:

        - the prefix;
        - one ``name: description`` line per tool;
        - ``format_instructions`` with ``{tool_names}`` filled in;
        - an ``{output_format_instructions}`` placeholder;
        - the suffix.

        The placeholder is bound as a partial variable to the parser's
        format instructions.

        Args:
            tools: Tools to describe in the prompt.
            prefix: Text placed first in the template.
            suffix: Text placed last in the template.
            format_instructions: Template text containing a ``{tool_names}``
                field.
            input_variables: Prompt input variables; defaults to
                ``["input", "agent_scratchpad"]`` when omitted.

        Returns:
            A ``PromptTemplate`` carrying the ``SQLOutput`` output parser.
        """
        described_tools = "\n".join(
            f"{tool.name}: {tool.description}" for tool in tools
        )
        joined_names = ", ".join(tool.name for tool in tools)
        filled_instructions = format_instructions.format(tool_names=joined_names)
        sql_parser = PydanticOutputParser(pydantic_object=SQLOutput)

        sections = [
            prefix,
            described_tools,
            filled_instructions,
            # Left as a literal placeholder; filled via partial_variables below.
            "{output_format_instructions}",
            suffix,
        ]
        variables = (
            ["input", "agent_scratchpad"]
            if input_variables is None
            else input_variables
        )
        partials = {
            "output_format_instructions": sql_parser.get_format_instructions()
        }
        return PromptTemplate(
            template="\n\n".join(sections),
            output_parser=sql_parser,
            input_variables=variables,
            partial_variables=partials,
        )