jfeng1115's picture
init commit
58d33f0
import json
import re
from typing import Any, List, Tuple
from pydantic import BaseModel, ValidationError, Field
from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS
from langchain.schema import BaseOutputParser, OutputParserException
# Structured final answer: the SQL query, the column names of its output, and
# a suggested chart type for visualizing that output.
# NOTE: documented with comments rather than a docstring on purpose -- pydantic
# copies a model docstring into the JSON schema's "description", which would
# change any format instructions generated from this model's schema.
class SQLOutput(BaseModel):
    sql_query: str = Field(description="sql query to get the final answer")
    column_names: List[str] = Field(description="column names of the sql query output")
    # Typed as plain `str`: values outside the listed chart types are not
    # rejected at validation time. The description text is part of the schema,
    # so each concatenated segment must end with ", " to read correctly.
    chart_type: str = Field(description="the best chart type to visualize the sql query output, "
                                        "should be one of ['bar', 'line', 'pie', 'table'], "
                                        "use line for timeseries data, "
                                        "if there are more than 3 column names use table data, "
                                        "use pie for percentage data")
# Intermediate reasoning record with three required text fields: the thought,
# a clarification question for the user, and the plan.
# NOTE: comments instead of a docstring -- pydantic would copy a docstring into
# the JSON schema's "description" and alter schema-derived prompt text.
class SQLThink(BaseModel):
    # `...` marks each field as explicitly required (same as omitting a default).
    thought: str = Field(
        ...,
        description="think to get the final answer, you should always think about what to do",
    )
    clarification: str = Field(
        ...,
        description="clarification question to the user if the analytics question is not clear",
    )
    plan: str = Field(
        ...,
        description="plan to get the final answer, you should always plan before you take action",
    )
class PydanticOutputParser(BaseOutputParser):
    """Parse an LLM completion into an instance of ``pydantic_object``.

    The span from the first ``{`` to the last ``}`` in the completion is
    decoded as JSON and validated against ``pydantic_object``.
    """

    # The pydantic model class to validate against; it must provide
    # ``parse_obj``, ``schema`` and ``__name__``.
    pydantic_object: Any

    def parse(self, text: str) -> BaseModel:
        """Extract the JSON object embedded in ``text`` and validate it.

        Args:
            text: Raw LLM completion, possibly with prose around the JSON.

        Returns:
            An instance of ``pydantic_object`` built from the extracted JSON.

        Raises:
            OutputParserException: If no JSON can be decoded from ``text`` or
                the decoded value fails validation against ``pydantic_object``.
        """
        try:
            # Greedy match: first "{" through last "}", so nested objects are
            # captured whole; DOTALL lets ".*" span newlines. Raw string avoids
            # the invalid "\{" escape warning of a plain string literal.
            match = re.search(r"\{.*\}", text.strip(), re.DOTALL)
            # No candidate -> json.loads("") raises JSONDecodeError, reported
            # below like any other malformed completion.
            json_str = match.group() if match else ""
            json_object = json.loads(json_str)
            return self.pydantic_object.parse_obj(json_object)
        except (json.JSONDecodeError, ValidationError) as e:
            name = self.pydantic_object.__name__
            msg = f"Failed to parse {name} from completion {text}. Got: {e}"
            raise OutputParserException(msg) from e

    def get_format_instructions(self) -> str:
        """Return prompt text describing the JSON the LLM must produce.

        Returns:
            ``PYDANTIC_FORMAT_INSTRUCTIONS`` with the model's JSON schema
            (minus the top-level ``title`` and ``type`` keys) substituted in.
        """
        # Build a filtered copy rather than deleting keys in place: pydantic v1
        # caches the dict returned by ``schema()``, so in-place deletion would
        # strip "title"/"type" from every later ``schema()`` result as well.
        reduced_schema = {
            key: value
            for key, value in self.pydantic_object.schema().items()
            if key not in ("title", "type")
        }
        # Ensure json in context is well-formed with double quotes.
        schema_str = json.dumps(reduced_schema)
        return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)