# Provenance (scraped from a Hugging Face file page): uploaded by jfeng1115,
# commit 58d33f0 ("init commit"), 2.73 kB.
import json
import re
from typing import Any, List, Tuple
from pydantic import BaseModel, ValidationError, Field
from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS
from langchain.schema import BaseOutputParser, OutputParserException
# Structured final answer of the text-to-SQL flow: the query, its output
# columns, and a suggested visualization.
# Documented with comments rather than a class docstring on purpose: pydantic
# copies a class docstring into the JSON schema as "description", and
# PydanticOutputParser.get_format_instructions only strips "title"/"type"
# before rendering that schema into the prompt.
class SQLOutput(BaseModel):
    # The SQL statement that produces the final answer.
    sql_query: str = Field(description="sql query to get the final answer")
    # Column headers of the query's result set.
    column_names: List[str] = Field(description="column names of the sql query output")
    # NOTE: disabled field, kept for reference.
    # query_result: List[Tuple[str]] = Field(description="the sql query's output, each tuple is a row of the output,"
    #                                                    "should match exactly the last observation's data")
    # Free-form string; the allowed values are only requested of the LLM via
    # the description, not enforced by validation.
    chart_type: str = Field(
        description="the best chart type to visualize the sql query output, "
        "should be one of ['bar', 'line', 'pie', 'table'], "
        "use line for timeseries data, "
        "if there are more than 3 column names use table data, "
        "use pie for percentage data"
    )
# Intermediate reasoning step requested from the LLM before it writes SQL.
# Documented with comments rather than a class docstring on purpose: pydantic
# copies a class docstring into the JSON schema as "description", which would
# change the format instructions sent to the LLM.
class SQLThink(BaseModel):
    # The model's reasoning about how to reach the answer.
    thought: str = Field(description="think to get the final answer, you should always think about what to do")
    # Question to ask the user when the analytics request is ambiguous.
    # NOTE(review): declared as a required str, so the LLM must emit a value
    # even when no clarification is needed — confirm that is intended.
    clarification: str = Field(description="clarification question to the user if the analytics question is not clear")
    # Step-by-step plan produced before taking any action.
    plan: str = Field(description="plan to get the final answer, you should always plan before you take action")
class PydanticOutputParser(BaseOutputParser):
    """Parse an LLM completion into an instance of ``pydantic_object``.

    The completion may contain prose around the JSON; the span from the first
    ``{`` to the last ``}`` is treated as the JSON payload.
    """

    # The pydantic model class used to validate the extracted JSON.
    pydantic_object: Any

    def parse(self, text: str) -> BaseModel:
        """Extract the JSON object embedded in *text* and validate it.

        Args:
            text: Raw completion text returned by the LLM.

        Returns:
            An instance of ``self.pydantic_object`` built from the JSON.

        Raises:
            OutputParserException: if no JSON object can be decoded from
                *text* or the decoded object fails model validation.
        """
        # Greedy search for the 1st JSON candidate: from the first "{" to the
        # last "}" (DOTALL lets it span newlines), so nested objects are kept
        # whole. If no braces are found, json_str stays "" and json.loads
        # raises JSONDecodeError, which is reported below.
        match = re.search(r"\{.*\}", text.strip(), re.DOTALL)
        json_str = match.group() if match else ""
        try:
            json_object = json.loads(json_str)
            return self.pydantic_object.parse_obj(json_object)
        except (json.JSONDecodeError, ValidationError) as e:
            name = self.pydantic_object.__name__
            msg = f"Failed to parse {name} from completion {text}. Got: {e}"
            # Chain the cause so the original decode/validation error is kept.
            raise OutputParserException(msg) from e

    def get_format_instructions(self) -> str:
        """Return prompt text describing the JSON the LLM must produce.

        Returns:
            ``PYDANTIC_FORMAT_INSTRUCTIONS`` with the model's JSON schema
            (minus the top-level ``title`` and ``type`` keys) substituted in.
        """
        # Copy before removing keys: pydantic v1 caches the dict returned by
        # ``schema()``, so deleting from it in place would corrupt the
        # model's cached schema for every later caller.
        reduced_schema = dict(self.pydantic_object.schema())
        # Remove extraneous fields.
        reduced_schema.pop("title", None)
        reduced_schema.pop("type", None)
        # Ensure json in context is well-formed with double quotes.
        schema_str = json.dumps(reduced_schema)
        return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)