# NBA-Stats / app.py
# Author: Khachatur Mirijanyan
# Commit b5dc7c4: "Add query checker to db chain"
from langchain import SQLDatabaseChain
from langchain.sql_database import SQLDatabase
from langchain.llms.openai import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.prompts.prompt import PromptTemplate
# Chat model shared by every SQLDatabaseChain that answer_question builds.
# temperature=0 minimises sampling randomness when generating SQL.
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    temperature=0,
    verbose=True,
)

# Tables that get_db always exposes from nba.db, in addition to whatever
# tables the user lists in a "### Tables" request.
DEFAULT_TABLES = [
    'Active Players',
    'Team_Per_Game_Statistics_2022_23',
    'Team_Totals_Statistics_2022_23',
    'Player_Total_Statistics_2022_23',
    'Player_Per_Game_Statistics_2022_23',
]
def get_prompt():
    """Return the prompt template used by the SQL database chain.

    The template has three placeholders, ``{dialect}``, ``{table_info}`` and
    ``{input}``. It asks the model to write a query, look at its result, and
    reply in a Question / SQLQuery / SQLResult / Answer layout.
    """
    # One entry per template line; joined with "\n" and no trailing newline.
    template_lines = [
        "Given an input question, first create a syntactically correct "
        "{dialect} query to run, then look at the results of the query and "
        "return the answer.",
        "Use the following format:",
        'Question: "Question here"',
        'SQLQuery: "SQL Query to run"',
        'SQLResult: "Result of the SQLQuery"',
        'Answer: "Final answer here"',
        "Only use the following tables:",
        "{table_info}",
        "Question: {input}",
    ]
    return PromptTemplate(
        template="\n".join(template_lines),
        input_variables=["input", "table_info", "dialect"],
    )
def check_query(query):
    """Classify the user's input and parse large-DB requests.

    A large-DB request looks like::

        ### Query
        <question>

        ### Tables
        <table name>
        <table name>

    Args:
        query: Raw text typed by the user.

    Returns:
        The result depends on the input:

        - ``'small'`` if the text does not start with ``### Query``.
        - ``'error'`` if it starts with ``### Query`` but has no question line
          or no ``### Tables`` section.
        - Otherwise, a dict with two keys. ``'tables'`` holds the listed table
          names (whitespace-stripped, blank lines dropped, possibly empty).
          ``'q'`` holds the question; non-blank question lines are joined
          with a space.
    """
    # Normalise Windows line endings and surrounding whitespace, so pasted
    # text parses the same way as typed text.
    text = query.replace('\r\n', '\n').strip()
    if not text.startswith("### Query"):
        return 'small'

    # Everything before the "### Tables" marker is the query section. The
    # marker itself is required for the request to be well-formed.
    query_section, marker, tables_section = text.partition("### Tables")
    if not marker:
        return 'error'

    # Drop each section's header line and ignore blank lines. A missing or
    # extra blank separator, or a trailing newline, then makes no difference,
    # and no empty table names are produced.
    question_lines = [
        line.strip() for line in query_section.split('\n')[1:] if line.strip()
    ]
    if not question_lines:
        return 'error'

    query_params = {
        'tables': [
            line.strip()
            for line in tables_section.split('\n')[1:]
            if line.strip()
        ],
        'q': ' '.join(question_lines),
    }
    print(query_params)
    return query_params
def get_db(q, tables):
    """Open the SQLite database that the chain should query.

    Args:
        q: The user's question. Unused here; kept so existing callers keep
            working.
        tables: Table names requested by the user. This list is not modified.

    Returns:
        A ``SQLDatabase`` for one of two files. If ``tables`` is empty, it is
        ``nba_small.db`` with no ``include_tables`` filter. Otherwise it is
        ``nba.db`` restricted to ``tables`` plus ``DEFAULT_TABLES``. Both use
        ``sample_rows_in_table_info=3``.
    """
    if not tables:
        return SQLDatabase.from_uri("sqlite:///nba_small.db",
                                    sample_rows_in_table_info=3)

    # Build a new list rather than extending ``tables`` in place, so the
    # caller's list is left untouched. dict.fromkeys drops duplicates (for
    # example, a user listing one of the default tables) and keeps
    # first-seen order.
    include_tables = list(dict.fromkeys([*tables, *DEFAULT_TABLES]))
    return SQLDatabase.from_uri("sqlite:///nba.db",
                                include_tables=include_tables,
                                sample_rows_in_table_info=3)
def answer_question(query):
    """Answer a natural-language question by running an LLM-generated SQL query.

    Args:
        query: Raw user input. This is either a plain question (small DB) or
            the ``### Query`` / ``### Tables`` format parsed by check_query.

    Returns:
        The chain's ``'result'`` string. If the input format is wrong, or the
        database cannot be opened with the requested tables, it returns a
        string starting with ``ERROR:`` instead.
    """
    query_check = check_query(query)
    if query_check == 'error':
        return 'ERROR: Wrong format for getting the big db schema'

    if isinstance(query_check, dict):
        q = query_check['q']
        tables = query_check['tables']
    else:
        # 'small': the whole input is the question, against the small DB.
        q = query
        tables = []

    try:
        db = get_db(q, tables)
    except ValueError as err:
        # langchain's SQLDatabase raises ValueError when include_tables names
        # a table the database lacks (e.g. a typo in the "### Tables" list).
        # Report it to the user instead of failing the whole request.
        return f'ERROR: {err}'

    db_chain = SQLDatabaseChain.from_llm(
        llm,
        db,
        prompt=get_prompt(),
        verbose=True,
        return_intermediate_steps=True,
        use_query_checker=True,
    )
    result = db_chain(q)
    return result['result']
if __name__ == "__main__":
    import gradio as gr

    # Gradio renders `description` as Markdown. The request format therefore
    # goes in a code block; otherwise "### Query" / "### Tables" would render
    # as headings, and users would not see that they must type the "###".
    description = (
        "Ask NBA Stats is a tool that lets you ask a question with the NBA "
        "SQL tables as a reference.\n\n"
        "Ask a simple question to use the small database.\n\n"
        "If you would like to access the large DB, use the format below, "
        "listing the tables to access one per line:\n\n"
        "```\n"
        "### Query\n"
        "single line query\n"
        "\n"
        "### Tables\n"
        "table1\n"
        "table2\n"
        "```"
    )

    # gr.Textbox replaces gr.inputs.Textbox / gr.outputs.Textbox, which were
    # deprecated in Gradio 3 and removed in Gradio 4.
    gr.Interface(
        answer_question,
        [
            gr.Textbox(lines=10, label="Query"),
        ],
        gr.Textbox(label="Response"),
        title="Ask NBA Stats",
        description=description,
    ).launch()