File size: 3,695 Bytes
f8b0fab 91cf0b9 ecd6946 df399f2 ecd6946 91cf0b9 f422003 91cf0b9 f8b0fab 9eb98fa b75450f 9eb98fa b75450f 9eb98fa b75450f 9eb98fa 91cf0b9 611b5d4 91cf0b9 13fcfb3 6ea428d 611b5d4 91cf0b9 f8b0fab 91cf0b9 9eb98fa 91cf0b9 13fcfb3 91cf0b9 9eb98fa 91cf0b9 b5dc7c4 6ed7387 91cf0b9 f8b0fab 0f27ac5 f8b0fab 23496d5 0419f76 22c6298 0f9e25e 0f27ac5 0419f76 0f9e25e 22c6298 0f9e25e 22c6298 0f9e25e f8b0fab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
from langchain import SQLDatabaseChain
from langchain.sql_database import SQLDatabase
from langchain.llms.openai import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.prompts.prompt import PromptTemplate
# Chat model shared by every request handled by answer_question.
# verbose=True turns on LangChain's verbose output for this model.
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    temperature=0,
    verbose=True,
)
# Tables that get_db always exposes, in addition to the user-requested ones,
# when a question is routed to the large database (nba.db).
DEFAULT_TABLES = [
    "Active Players",
    "Team_Per_Game_Statistics_2022_23",
    "Team_Totals_Statistics_2022_23",
    "Player_Total_Statistics_2022_23",
    "Player_Per_Game_Statistics_2022_23",
]
def get_prompt():
    """Build the prompt template handed to the SQL database chain.

    The template asks the model to write a ``{dialect}`` query that uses only
    the tables described in ``{table_info}``, run it, and answer the
    question ``{input}`` in a Question / SQLQuery / SQLResult / Answer
    layout.

    Returns:
        A ``PromptTemplate`` whose input variables are ``input``,
        ``table_info`` and ``dialect``.
    """
    # Continuation lines stay at column 0: they are part of the prompt text.
    template = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:
Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"
Only use the following tables:
{table_info}
Question: {input}"""
    variables = ["input", "table_info", "dialect"]
    return PromptTemplate(input_variables=variables, template=template)
def check_query(query):
    """Classify a request from the UI and parse the large-database format.

    The large-database format is a "### Query" section followed, after a
    blank line, by a "### Tables" section::

        ### Query
        single line query

        ### Tables
        table1
        table2

    Args:
        query: Raw text typed by the user.

    Returns:
        * ``'small'`` if ``query`` does not start with ``"### Query"``.
        * ``'error'`` if it does but the text does not follow the format
          above (no blank-line separator, no query line, or the second
          section does not start with ``"### Tables"``).
        * Otherwise a dict ``{'tables': [...], 'q': str}``. ``q`` is the
          first line after ``### Query``. ``tables`` holds the lines after
          ``### Tables``, stripped of surrounding whitespace, with blank
          lines dropped. The dict is also printed.
    """
    if not query.startswith("### Query"):
        return 'small'

    sections = query.split('\n\n')
    # Without the blank-line separator there is only one section; the old
    # code raised IndexError here instead of reporting a format error.
    if len(sections) < 2:
        return 'error'
    q_text, t_text = sections[0], sections[1]
    if not t_text.startswith("### Tables"):
        return 'error'

    q_lines = q_text.split('\n')
    # "### Query" with nothing after it: no question to ask.
    if len(q_lines) < 2:
        return 'error'

    # Strip and drop empty lines so a trailing newline or padded names do not
    # produce table names that do not exist in the database.
    tables = [line.strip() for line in t_text.split('\n')[1:]]
    query_params = {
        'tables': [name for name in tables if name],
        'q': q_lines[1],
    }
    print(query_params)
    return query_params
def get_db(q, tables):
    """Open the SQLite database that will be used to answer a question.

    Args:
        q: The user's question. Not used here; kept for interface
            compatibility.
        tables: Table names requested by the user (may be empty). The list
            is not modified.

    Returns:
        A ``SQLDatabase``.
        * With no requested tables, it is all of ``nba_small.db``.
        * Otherwise it is ``nba.db``, restricted to the requested tables
          plus ``DEFAULT_TABLES``.
        * Both are opened with ``sample_rows_in_table_info=2``.
    """
    if not tables:
        return SQLDatabase.from_uri(
            "sqlite:///nba_small.db",
            sample_rows_in_table_info=2,
        )
    # Build a new list rather than calling tables.extend(...). Extending
    # mutated the caller's list, and would grow DEFAULT_TABLES itself if it
    # were passed in. Defaults already requested are not repeated.
    include_tables = list(tables) + [
        name for name in DEFAULT_TABLES if name not in tables
    ]
    return SQLDatabase.from_uri(
        "sqlite:///nba.db",
        include_tables=include_tables,
        sample_rows_in_table_info=2,
    )
def answer_question(query):
    """Answer a user's question by running an LLM-driven SQL chain.

    ``check_query`` decides the route.
    * For ``'small'``, the whole text is the question and no tables are
      requested.
    * For a parsed dict, its ``'q'`` and ``'tables'`` entries are used.
    * For ``'error'``, an error message is returned without running the
      chain.

    Args:
        query: Raw text from the UI.

    Returns:
        The chain's ``'result'`` string, or an error message string.
    """
    prompt = get_prompt()
    parsed = check_query(query)
    if parsed == 'error':
        return 'ERROR: Wrong format for getting the big db schema'

    if isinstance(parsed, dict):
        question, tables = parsed['q'], parsed['tables']
    elif parsed == 'small':
        question, tables = query, []

    database = get_db(question, tables)
    chain = SQLDatabaseChain.from_llm(
        llm,
        database,
        prompt=prompt,
        verbose=True,
        return_intermediate_steps=True,
        # use_query_checker=True
    )
    outcome = chain(question)
    return outcome['result']
if __name__ == "__main__":
    # Gradio is only needed when the app is launched directly.
    import gradio as gr

    # The description is rendered as Markdown. The format example is fenced
    # so "### Query" shows literally instead of as a heading. It also shows
    # the blank line that check_query needs between the two sections.
    # Continuation lines stay at column 0 so Markdown sees no indentation.
    gr.Interface(
        answer_question,
        # gr.inputs / gr.outputs are deprecated in Gradio 3 and removed in 4.
        [gr.Textbox(lines=10, label="Query")],
        gr.Textbox(label="Response"),
        title="Ask NBA Stats",
        description="""Ask NBA Stats is a tool that lets you ask a question with
the NBA SQL tables as a reference.

Ask a simple question to use the small database.

If you would like to access the large DB, use this format (one blank line
between the two sections, one table name per line):

```
### Query
single line query

### Tables
table1
table2
```""",
    ).launch()