"""Pydantic response models and an output parser for an LLM-driven SQL agent."""

import json
import re
from typing import Any, List, Tuple

from pydantic import BaseModel, ValidationError, Field

from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS
from langchain.schema import BaseOutputParser, OutputParserException

# Greedy match from the first "{" to the last "}" in the completion.
# DOTALL lets "." span newlines so multi-line JSON objects are captured.
_JSON_CANDIDATE_RE = re.compile(r"\{.*\}", re.DOTALL)

# Top-level schema keys that are stripped before the schema is shown to the model.
_EXTRANEOUS_SCHEMA_KEYS = ("title", "type")


# Final structured answer: the SQL query, its output columns and a chart type.
# NOTE: documented with comments rather than a docstring on purpose -- pydantic
# copies a model's docstring into its JSON schema ("description"), which would
# change the text produced by `get_format_instructions`.
class SQLOutput(BaseModel):
    sql_query: str = Field(description="sql query to get the final answer")
    column_names: List[str] = Field(description="column names of the sql query output")
    # Disabled field, kept for reference:
    # query_result: List[Tuple[str]] = Field(description="the sql query's output, each tuple is a row of the output,"
    #                                                    "should match exactly the last observation's data")
    chart_type: str = Field(description="the best chart type to visualize the sql query output,"
                                        "should be one of ['bar', 'line', 'pie', 'table'], "
                                        "use line for timeseries data, "
                                        "if there are more than 3 column names use table data, "
                                        "use pie for percentage data")


# Intermediate reasoning step: a thought, an optional clarification question
# for the user, and a plan.
# NOTE: comments instead of a docstring for the same schema reason as SQLOutput.
class SQLThink(BaseModel):
    thought: str = Field(description="think to get the final answer, you should always think about what to do")
    clarification: str = Field(description="clarification question to the user if the analytics question is not clear")
    plan: str = Field(description="plan to get the final answer, you should always plan before you take action")


class PydanticOutputParser(BaseOutputParser):
    """Parse a text completion into an instance of ``pydantic_object``.

    ``pydantic_object`` is the pydantic model *class* (for example
    ``SQLOutput`` or ``SQLThink``) that completions are validated against.
    """

    pydantic_object: Any

    def parse(self, text: str) -> BaseModel:
        """Extract the first JSON object from ``text`` and validate it.

        Args:
            text: Raw completion text; may contain prose around the JSON.

        Returns:
            The result of ``pydantic_object.parse_obj`` on the decoded JSON.

        Raises:
            OutputParserException: If no valid JSON object is found or the
                decoded object fails model validation. The underlying
                ``JSONDecodeError`` / ``ValidationError`` is chained as the
                cause.
        """
        match = _JSON_CANDIDATE_RE.search(text.strip())
        # With no candidate, the empty string makes json.loads raise
        # JSONDecodeError, so "no JSON" and "bad JSON" share one error path.
        json_str = match.group() if match else ""

        try:
            json_object = json.loads(json_str)
            return self.pydantic_object.parse_obj(json_object)
        except (json.JSONDecodeError, ValidationError) as e:
            name = self.pydantic_object.__name__
            msg = f"Failed to parse {name} from completion {text}. Got: {e}"
            raise OutputParserException(msg) from e

    def get_format_instructions(self) -> str:
        """Return prompt instructions embedding the model's JSON schema.

        The top-level ``title`` and ``type`` keys are left out of the
        embedded schema.

        Returns:
            ``PYDANTIC_FORMAT_INSTRUCTIONS`` formatted with the reduced schema
            serialized as JSON.
        """
        schema = self.pydantic_object.schema()

        # Build a filtered copy instead of deleting keys in place: the dict
        # returned by schema() may be cached and shared by pydantic, so
        # mutating it would alter the model's schema for every other caller.
        reduced_schema = {
            key: value
            for key, value in schema.items()
            if key not in _EXTRANEOUS_SCHEMA_KEYS
        }

        # json.dumps ensures the schema in the prompt is well-formed JSON
        # with double quotes.
        schema_str = json.dumps(reduced_schema)

        return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)