import logging

from langchain import SQLDatabaseChain
from langchain.sql_database import SQLDatabase
from langchain.llms.openai import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.prompts.prompt import PromptTemplate

logger = logging.getLogger(__name__)

# Shared chat model used by every SQLDatabaseChain built in answer_question().
# temperature=0 keeps SQL generation as repeatable as possible.
llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", verbose=True)

# Tables always exposed when the large database (nba.db) is used, in addition
# to whatever tables the user lists under "### Tables".
DEFAULT_TABLES = [
    'Active Players',
    'Team_Per_Game_Statistics_2022_23',
    "Team_Totals_Statistics_2022_23",
    "Player_Total_Statistics_2022_23",
    "Player_Per_Game_Statistics_2022_23",
]

# Section markers of the "large database" request format (see check_query).
_QUERY_HEADER = "### Query"
_TABLES_HEADER = "### Tables"


def get_prompt():
    """Build the prompt template used by the SQL chain.

    Returns:
        A PromptTemplate expecting the variables ``input`` (the user's
        question), ``table_info`` (schema/sample rows of the exposed tables)
        and ``dialect`` (the SQL dialect name).
    """
    _DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}

Question: {input}"""
    PROMPT = PromptTemplate(
        input_variables=["input", "table_info", "dialect"],
        template=_DEFAULT_TEMPLATE,
    )
    return PROMPT


def check_query(query):
    """Classify raw UI text and parse it when it uses the large-DB format.

    The large-database format is::

        ### Query
        <question>

        ### Tables
        <table name>
        <table name>

    Args:
        query: Raw text entered by the user.

    Returns:
        ``'small'`` if the text does not start with ``### Query`` (a plain
        question for the small database); ``'error'`` if it starts with
        ``### Query`` but lacks a ``### Tables`` section or any question
        text; otherwise a dict ``{'q': <question>, 'tables': [<names>]}``.
    """
    text = query.strip()
    if not text.startswith(_QUERY_HEADER):
        return 'small'

    # Strip every line and ignore blanks so trailing newlines, stray spaces
    # or extra blank lines never become bogus table names such as ''.
    lines = [line.strip() for line in text.splitlines()]
    tables_idx = next(
        (i for i, line in enumerate(lines) if line.startswith(_TABLES_HEADER)),
        None,
    )
    if tables_idx is None:
        return 'error'

    # Everything between the two headers is the question (normally one line).
    question_lines = [line for line in lines[1:tables_idx] if line]
    if not question_lines:
        return 'error'

    query_params = {
        'tables': [line for line in lines[tables_idx + 1:] if line],
        'q': ' '.join(question_lines),
    }
    logger.debug("Parsed large-DB query: %s", query_params)
    return query_params


def get_db(q, tables):
    """Open the SQLite database appropriate for the request.

    Args:
        q: The question text. Currently unused; kept for interface
            compatibility.
        tables: Extra table names requested by the user. Not modified.

    Returns:
        ``SQLDatabase`` over ``nba_small.db`` when ``tables`` is empty;
        otherwise over ``nba.db`` restricted to ``tables`` plus
        ``DEFAULT_TABLES`` (duplicates removed, order preserved).
    """
    if not tables:
        return SQLDatabase.from_uri("sqlite:///nba_small.db", sample_rows_in_table_info=2)

    # Build a new list instead of extending the caller's list in place.
    include_tables = list(dict.fromkeys([*tables, *DEFAULT_TABLES]))
    return SQLDatabase.from_uri(
        "sqlite:///nba.db",
        include_tables=include_tables,
        sample_rows_in_table_info=2,
    )


def answer_question(query):
    """Answer a natural-language question by running an LLM-generated SQL query.

    Args:
        query: Raw UI text; either a plain question or the large-DB format
            described in ``check_query``.

    Returns:
        The chain's ``'result'`` string, or an error message when the
        large-DB format is malformed.
    """
    query_check = check_query(query)
    if query_check == 'error':
        return 'ERROR: Wrong format for getting the big db schema'

    if query_check == 'small':
        q, tables = query, []
    else:
        q, tables = query_check['q'], query_check['tables']

    db = get_db(q, tables)
    db_chain = SQLDatabaseChain.from_llm(
        llm,
        db,
        prompt=get_prompt(),
        verbose=True,
        return_intermediate_steps=True,
        # use_query_checker=True is left disabled here.
    )
    result = db_chain(q)
    return result['result']


if __name__ == "__main__":
    import gradio as gr

    # The description is rendered as Markdown, so the request format is
    # fenced as code; otherwise the "###" lines would render as headings.
    gr.Interface(
        fn=answer_question,
        inputs=gr.Textbox(lines=10, label="Query"),
        outputs=gr.Textbox(label="Response"),
        title="Ask NBA Stats",
        description="""Ask NBA Stats is a tool that lets you ask a question with the NBA SQL tables as a reference

Ask a simple question to use the small database

If you would like to access the large DB use format

```
### Query
single line query

### Tables
tables to access line by line
table1
table2
```""",
    ).launch()