File size: 3,695 Bytes
f8b0fab
 
 
 
 
 
91cf0b9
 
 
 
 
ecd6946
df399f2
ecd6946
91cf0b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f422003
91cf0b9
 
 
 
 
 
f8b0fab
9eb98fa
 
 
 
 
 
 
b75450f
9eb98fa
b75450f
 
 
 
9eb98fa
b75450f
 
9eb98fa
91cf0b9
 
 
611b5d4
91cf0b9
 
13fcfb3
6ea428d
611b5d4
91cf0b9
f8b0fab
91cf0b9
9eb98fa
 
 
91cf0b9
 
13fcfb3
91cf0b9
 
 
 
9eb98fa
91cf0b9
 
 
b5dc7c4
6ed7387
91cf0b9
 
 
f8b0fab
 
 
 
 
 
 
 
0f27ac5
f8b0fab
23496d5
0419f76
22c6298
 
0f9e25e
0f27ac5
 
0419f76
0f9e25e
22c6298
0f9e25e
 
22c6298
0f9e25e
 
 
f8b0fab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from langchain import SQLDatabaseChain
from langchain.sql_database import SQLDatabase
from langchain.llms.openai import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.prompts.prompt import PromptTemplate

# Module-level chat model handed to SQLDatabaseChain by answer_question.
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    temperature=0,
    verbose=True,
)

# Tables that get_db always includes alongside the user-requested ones
# when it opens nba.db.
DEFAULT_TABLES = [
    "Active Players",
    "Team_Per_Game_Statistics_2022_23",
    "Team_Totals_Statistics_2022_23",
    "Player_Total_Statistics_2022_23",
    "Player_Per_Game_Statistics_2022_23",
]

def get_prompt():
    """Build the prompt template passed to the SQL database chain.

    The template has three placeholders: ``dialect``, ``table_info`` and
    ``input``. Continuation lines of the template text keep the four-space
    indentation they have in this source file.

    Returns:
        A ``PromptTemplate`` declaring ``input``, ``table_info`` and
        ``dialect`` as its input variables.
    """
    template_text = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
    Use the following format:

    Question: "Question here"
    SQLQuery: "SQL Query to run"
    SQLResult: "Result of the SQLQuery"

    Answer: "Final answer here"

    Only use the following tables:

    {table_info}

    Question: {input}"""

    return PromptTemplate(
        template=template_text,
        input_variables=["input", "table_info", "dialect"],
    )

def check_query(query):
    """Classify a user query and extract the big-database request, if any.

    A big-database request consists of two sections separated by one blank
    line::

        ### Query
        <question on a single line>

        ### Tables
        <table name>
        <table name>

    Args:
        query: Raw text entered by the user.

    Returns:
        ``'small'`` when ``query`` does not start with ``### Query``.
        ``'error'`` when it does, but the rest does not follow the layout
        above (no blank-line separator, no question line, an empty question,
        or a second section that does not start with ``### Tables``).
        Otherwise a dict with ``'q'`` (the question, whitespace-trimmed) and
        ``'tables'`` (the table names, whitespace-trimmed, blank lines
        dropped).
    """
    if not query.startswith("### Query"):
        return 'small'

    # Normalise Windows line endings so the blank-line separator is found.
    sections = query.replace('\r\n', '\n').split('\n\n')
    if len(sections) < 2:
        # "### Query" header without a following "### Tables" section.
        return 'error'

    q_lines = sections[0].split('\n')
    t_lines = sections[1].split('\n')
    if len(q_lines) < 2 or not t_lines[0].startswith("### Tables"):
        return 'error'

    question = q_lines[1].strip()
    if not question:
        return 'error'

    query_params = {
        # Skip blank lines (e.g. from a trailing newline) so no empty table
        # name is passed on.
        'tables': [name.strip() for name in t_lines[1:] if name.strip()],
        'q': question,
    }
    # Echo the parsed request to stdout.
    print(query_params)
    return query_params

def get_db(q, tables):
    """Open the SQLite database to use for a request.

    Args:
        q: The user's question. Not used here; kept so existing callers
            keep working.
        tables: Table names requested by the user. An empty (or missing)
            list selects the small database.

    Returns:
        A ``SQLDatabase`` over ``nba_small.db`` when ``tables`` is empty,
        otherwise one over ``nba.db`` restricted to the requested tables
        plus ``DEFAULT_TABLES``. Both include two sample rows per table in
        the table info.
    """
    if not tables:
        return SQLDatabase.from_uri("sqlite:///nba_small.db",
                                    sample_rows_in_table_info=2)

    # Build a fresh list instead of extending `tables` in place so the
    # caller's list is left untouched; dict.fromkeys removes duplicates
    # (e.g. a default table the user also listed) while keeping order.
    wanted = list(dict.fromkeys([*tables, *DEFAULT_TABLES]))
    return SQLDatabase.from_uri("sqlite:///nba.db",
                                include_tables=wanted,
                                sample_rows_in_table_info=2)
def answer_question(query):
    """Answer a question by running an LLM-driven SQL chain over the NBA data.

    The input is classified with ``check_query``. A plain question is run
    against the database returned by ``get_db`` for an empty table list; a
    well-formed ``### Query`` / ``### Tables`` request is run against the
    database returned for the listed tables.

    Args:
        query: Raw text entered by the user.

    Returns:
        The ``'result'`` entry of the chain output, or an error string when
        ``check_query`` reports a malformed request.
    """
    prompt = get_prompt()
    parsed = check_query(query)

    if parsed == 'error':
        return 'ERROR: Wrong format for getting the big db schema'
    if isinstance(parsed, dict):
        q, tables = parsed['q'], parsed['tables']
    if parsed == 'small':
        q, tables = query, []

    db = get_db(q, tables)
    chain = SQLDatabaseChain.from_llm(
        llm,
        db,
        prompt=prompt,
        verbose=True,
        return_intermediate_steps=True,
        # use_query_checker=True
    )
    return chain(q)['result']

if __name__ == "__main__":
    import gradio as gr

    # Launch a Gradio UI around answer_question: one multi-line input box
    # for the query and one text box for the response.
    # gr.Textbox is used for both because the gr.inputs / gr.outputs
    # namespaces are deprecated in Gradio 3 and removed in Gradio 4.
    gr.Interface(
        answer_question,
        [
            gr.Textbox(lines=10, label="Query"),
        ],
        gr.Textbox(label="Response"),
        title="Ask NBA Stats",
        description=""" Ask NBA Stats is a tool that lets you ask a question with
                        the NBA SQL tables as a reference
                        
                        Ask a simple question to use the small database

                        If you would like to access the large DB use format
                        
                        ### Query
                        single line query
                        
                        ### Tables
                        tables to access line by line
                        table1
                        table2""",
    ).launch()