Khachatur Mirijanyan commited on
Commit
91cf0b9
1 Parent(s): ea56acd

First Chain Check

Browse files
Files changed (1) hide show
  1. app.py +55 -7
app.py CHANGED
@@ -4,6 +4,36 @@ from langchain.llms.openai import OpenAI
4
  from langchain.chat_models import ChatOpenAI
5
  from langchain.prompts.prompt import PromptTemplate
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  def check_query(query):
9
  if query.startswith("### Query"):
@@ -22,18 +52,36 @@ def check_query(query):
22
  return 'error'
23
  return 'small'
24
 
 
 
 
 
 
 
 
 
 
 
25
  def answer_question(query):
 
26
  query_check = check_query(query)
27
- if isinstance(query_check, dict):
28
- return("BIG TABLE")
29
- if query_check == 'small':
30
- return('SMALL TABLE')
31
  if query_check == 'error':
32
  return('ERROR: Wrong format for getting the big db schema')
33
-
34
-
 
 
 
 
 
35
 
36
- return("DONE: " + query)
 
 
 
 
 
 
37
 
38
  if __name__ == "__main__":
39
  import gradio as gr
 
4
  from langchain.chat_models import ChatOpenAI
5
  from langchain.prompts.prompt import PromptTemplate
6
 
7
+ llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo", verbose=True)
8
+
9
+ DEFAULT_TABLES = [
10
+ 'Active Players',
11
+ 'Team_Per_Game_Statistics_2022_23',
12
+ 'Team_Totals_Statistics_2022_23',
13
+ 'Player_Total_Statistics_2022_23',
14
+ 'Player_Per_Game_Statistics_2022_23'
15
+ ]
16
+
17
+ def get_prompt():
18
+ _DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
19
+ Use the following format:
20
+
21
+ Question: "Question here"
22
+ SQLQuery: "SQL Query to run"
23
+ SQLResult: "Result of the SQLQuery"
24
+
25
+ Answer: "Final answer here"
26
+
27
+ Only use the following tables:
28
+
29
+ {table_info}
30
+
31
+ Question: {input}"""
32
+
33
+ PROMPT = PromptTemplate(
34
+ input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE
35
+ )
36
+ return PROMPT
37
 
38
  def check_query(query):
39
  if query.startswith("### Query"):
 
52
  return 'error'
53
  return 'small'
54
 
55
+ def get_db(q, tables):
56
+ if len(tables) == 0:
57
+ db = SQLDatabase.from_uri("sqlite:///nba_small.db",
58
+ sample_rows_in_table_info=3)
59
+ else:
60
+ tables.extend(DEFAULT_TABLES)
61
+ db = SQLDatabase.from_uri("sqlite:///nba_small.db",
62
+ include_tables = tables
63
+ sample_rows_in_table_info=3)
64
+ return db
65
  def answer_question(query):
66
+ PROMPT = get_prompt()
67
  query_check = check_query(query)
 
 
 
 
68
  if query_check == 'error':
69
  return('ERROR: Wrong format for getting the big db schema')
70
+ if isinstance(query_check, dict):
71
+ q = query_check['q']
72
+ tables = query_check
73
+ if query_check == 'small':
74
+ q = query
75
+ tables = []
76
+ db = get_db(q, tables)
77
 
78
+ db_chain = SQLDatabaseChain.from_llm(llm, db,
79
+ prompt=PROMPT,
80
+ verbose=True,
81
+ return_intermediate_steps=True
82
+ )
83
+ result = db_chain(q)
84
+ return result['result']
85
 
86
  if __name__ == "__main__":
87
  import gradio as gr