anumaurya114exp commited on
Commit
37bd4dd
·
1 Parent(s): 0145029

Overwrite queryHelper with queryHelper2

Browse files
Files changed (7) hide show
  1. app.py +284 -203
  2. config.py +2 -0
  3. configProd.py +19 -0
  4. constants.py +33 -0
  5. gptManager.py +58 -0
  6. requirements.txt +3 -2
  7. utils.py +111 -0
app.py CHANGED
@@ -1,222 +1,258 @@
1
- # !pip install gradio
2
- # !pip install openai
3
- # import openai
4
-
5
- import gradio
6
- import pandas as pd
7
- import psycopg2
8
-
9
  import pandas as pd
10
- import openai
11
-
12
- import sqlite3
13
  import psycopg2
14
  import time
15
  import gradio as gr
16
  import sqlparse
 
17
  import os
 
 
18
 
19
- #EA_key
20
- openai.api_key = os.getenv("api_key")
 
 
 
21
 
22
  pd.set_option('display.max_columns', None)
23
  pd.set_option('display.max_rows', None)
24
 
25
- #database credential
26
- db_name = os.getenv("db_name")
27
- user_db = os.getenv("user_db")
28
- pwd_db = os.getenv("pwd_db")
29
- host_db = os.getenv("host_db")
30
- port_db = os.getenv("port_db")
31
-
32
- conn = psycopg2.connect(database=db_name, user = user_db, password = pwd_db, host = host_db, port = port_db)
33
-
34
- # sql="select master_customer_id, c.gender,c.city_name,c.state_name, c.zip_code,product_name,department,class,category,d.date_value,s.city_name as store_city,s.state_name as store_state,s.zip_code as store_zip,s.store_name,s.opened_dt,s.closed_dt, f.transaction_amt,ch.type from oyster_demo.tbl_d_customer c,oyster_demo.tbl_d_product p,oyster_demo.tbl_f_sales f,oyster_demo.tbl_d_date d, oyster_demo.tbl_d_store s,oyster_demo.tbl_d_channel ch where p.product_id=f.product_id and c.customer_id=f.customer_id and d.date_id=f.date_id and s.store_id=f.store_id and ch.channel_id=f.channel_id"
35
- sql2="""select * from lpdatamart.tbl_d_customer limit 10000"""
36
- sql3="""select * from lpdatamart.tbl_d_product limit 1000"""
37
- sql4="""select * from lpdatamart.tbl_f_sales limit 10000"""
38
- # sql5="""select * from lpdatamart.tbl_d_time limit 10000"""
39
- sql6="""select * from lpdatamart.tbl_d_store limit 10000"""
40
- sql7="""select * from lpdatamart.tbl_d_channel limit 10000"""
41
- sql8="""select * from lpdatamart.tbl_d_lineaction_code limit 10000"""
42
- sql9 = """select * from lpdatamart.tbl_d_calendar limit 10000"""
43
-
44
- df_customer = pd.read_sql_query(sql2, con=conn)
45
- df_product = pd.read_sql_query(sql3, con=conn)
46
- df_sales = pd.read_sql_query(sql4, con=conn)
47
- # df_time = pd.read_sql_query(sql5, con=conn)
48
- df_store = pd.read_sql_query(sql6, con=conn)
49
- df_channel = pd.read_sql_query(sql7, con=conn)
50
- df_lineaction = pd.read_sql_query(sql8, con=conn)
51
- df_calendar = pd.read_sql_query(sql9, con=conn)
52
-
53
-
54
- conn.close()
55
- df_customer.head(2)
56
-
57
- customer_col=['customer_id','customer_type', 'first_name', 'middle_name', 'household_name', 'last_name', 'personal_email', 'city', 'state', 'zip_code', 'address1', 'country', 'gender', 'phone_number', 'reward_number']
58
- product_col=['product_id', 'product_name', 'product_price', 'department', 'class', 'discount', 'category', 'department_desc', 'department_type', 'product_type', 'manufacturer', 'color']
59
- sales_col = ['store_id', 'customer_id', 'channel_id', 'product_id', 'time_id', 'date_id','order_id', 'line_action', 'discount_amount', 'shipping_amount','transaction_date', 'transaction_amount', 'transaction_type', 'qty_sold']
60
- # time_col = ['time_id', 'hour', 'minute', 'second', 'am_pm']
61
- store_col = ['store_id', 'store_number', 'store_name', 'store_designation', 'store_longitude', 'store_latitude', 'store_manager_name', 'zip_code', 'state_code', 'city', 'street_number', 'street_name', 'store_region', 'store_type', 'address1','sublocationcode', 'channel', 'company_flag', 'kiosk_physical_store', 'sublocation_code']
62
- channel_col = ['channel_id', 'channel_name', 'channel_code']
63
- lineaction_col = ['line_action_code', 'line_action_code_desc', 'load_date', 'catgory', 'sales_type']
64
- calendar_col = ['date_id','calendar_date','calendar_month','day_of_week','calendar_week_number','calendar_month_number','calendar_quarter_number','day_of_month','day_of_quarter','day_of_the_year','us_holiday','lp_holiday','work_day','year','ad_week','ad_week_year','ad_month','lp_day','lp_week','lp_month','lp_year','lp_quarter','event_day']
65
-
66
- df_customer=df_customer[customer_col]
67
- df_product=df_product[product_col]
68
- df_sales=df_sales[sales_col]
69
- # df_time = df_time[time_col]
70
- df_store = df_store[store_col]
71
- df_channel = df_channel[channel_col]
72
- df_lineaction = df_lineaction[lineaction_col]
73
- df_calendar = df_calendar[calendar_col]
74
-
75
- # df = pd.read_csv('/content/drive/MyDrive/tbl_m_querygen.csv')
76
-
77
-
78
- import sqlite3
79
- import openai
80
-
81
- # Connect to SQLite database
82
- conn1 = sqlite3.connect('chatgpt.db')
83
- cursor1 = conn1.cursor()
84
-
85
- # Connect to SQLite database
86
- conn2 = sqlite3.connect('chatgpt.db')
87
- cursor2 = conn2.cursor()
88
-
89
- # Connect to SQLite database
90
- conn3 = sqlite3.connect('chatgpt.db')
91
- cursor3 = conn3.cursor()
92
-
93
- # Connect to SQLite database
94
- conn4 = sqlite3.connect('chatgpt.db')
95
- cursor4 = conn4.cursor()
96
-
97
- # Connect to SQLite database
98
- conn5 = sqlite3.connect('chatgpt.db')
99
- cursor5 = conn5.cursor()
100
-
101
- # Connect to SQLite database
102
- conn5 = sqlite3.connect('chatgpt.db')
103
- cursor5 = conn5.cursor()
104
-
105
- # Connect to SQLite database
106
- conn6 = sqlite3.connect('chatgpt.db')
107
- cursor6 = conn6.cursor()
108
-
109
- # Connect to SQLite database
110
- conn7 = sqlite3.connect('chatgpt.db')
111
- cursor7 = conn7.cursor()
112
-
113
- # Connect to SQLite database
114
- conn8 = sqlite3.connect('chatgpt.db')
115
- cursor8 = conn8.cursor()
116
-
117
-
118
- # openai.api_key = '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'
119
-
120
- # Insert DataFrame into SQLite database
121
- df_customer.to_sql('tbl_d_customer', conn1, if_exists='replace', index=False)
122
- df_product.to_sql('tbl_d_product', conn2, if_exists='replace', index=False)
123
- df_sales.to_sql('tbl_f_sales', conn3, if_exists='replace', index=False)
124
- # df_time.to_sql('tbl_d_time', conn4, if_exists='replace', index=False)
125
- df_store.to_sql('tbl_d_store', conn5, if_exists='replace', index=False)
126
- df_channel.to_sql('tbl_d_channel', conn6, if_exists='replace', index=False)
127
- df_lineaction.to_sql('tbl_d_lineaction_code', conn7, if_exists='replace', index=False)
128
- df_calendar.to_sql('tbl_d_calendar', conn8, if_exists ='replace',index=False)
129
-
130
- # Function to get table columns from SQLite database
131
- def get_table_columns(table_name1, table_name2):
132
- cursor1.execute("PRAGMA table_info({})".format(table_name1))
133
- columns1 = cursor1.fetchall()
134
- # print(columns)
135
-
136
- cursor2.execute("PRAGMA table_info({})".format(table_name2))
137
- columns2 = cursor2.fetchall()
138
-
139
- return [column[1] for column in columns1], [column[1] for column in columns2]
140
-
141
- table_name1 = 'tbl_d_customer'
142
- table_name2 = 'tbl_d_product'
143
- table_name3 = 'tbl_f_sales'
144
-
145
- # table_name4 = 'tbl_d_time'
146
- table_name5 = 'tbl_d_store'
147
- table_name6 = 'tbl_d_channel'
148
- table_name7 = 'tbl_d_lineaction_code'
149
- table_name8 = 'tbl_d_calendar'
150
-
151
- columns1,columns2 = get_table_columns(table_name1,table_name2)
152
-
153
-
154
 
155
- # Function to generate SQL query from input text using ChatGPT
156
- def generate_sql_query(text):
157
- # prompt = """You are a ChatGPT language model that can generate SQL queries. Please provide a natural language input text, and I will generate the corresponding SQL query and Answer the provided question if possible for you.The table name is {} and the following data:\n {} and corresponding columns are {}.\nInput: {}\nSQL Query:""".format(table_name,read_csv, columns,text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
- messages.append({"role": "user", "content": text})
160
- # print(prompt)
161
- request = openai.ChatCompletion.create(
162
- model="gpt-4",
163
- messages=messages
164
- )
165
- print(request)
166
- sql_query = request['choices'][0]['message']['content']
167
- return sql_query
168
-
169
- text = "for female customer who did a transaction of more than 100 dollars in year 2020 please write sql query ?"
170
-
171
- schema_name = 'lpdatamart'
172
- prompt = """Given an input text, and You will generate the corresponding SQL query. The schema name is {}. The first table name is {} and the following data:\n {}. The second table name is {} and the following data for second table:\n {}. The third table name is {} and the following data for third table:\n {}. The fourth table name is {} and the following data for fourth table:\n {}. The fifth table name is {} and the following data for fifth table:\n {}. The sixth table name is {} and the following data for sixth table:\n {}. The seventh table name is {} and the following data for seventh table:\n {} \n""".format(schema_name,table_name1,df_customer.loc[:5], table_name2, df_product.loc[:5], table_name3, df_sales.loc[:5], table_name5, df_store.loc[:5], table_name6, df_channel.loc[:5],table_name7, df_lineaction.loc[:5], table_name8, df_calendar.loc[:5])
173
- messages = [{"role": "system", "content": prompt}]
174
-
175
- sql_query=generate_sql_query(text)
176
- print("Generated SQL query: ",sql_query)
177
-
178
- # prompt = """Given an input text, and You will generate the corresponding SQL query. The first table name is {} and the following data:\n {}. The second table name is {} and the following data for second table:\n {}. The third table name is {} and the following data for third table:\n {}.\n""".format(table_name1,df2.loc[:5], table_name2, df3.loc[:5], table_name3, df4.loc[:5])
179
- prompt = """Given an input text, and You will generate the corresponding SQL query. The schema name is {}. The first table name is {} and the following data:\n {}. The second table name is {} and the following data for second table:\n {}. The third table name is {} and the following data for third table:\n {}. The fourth table name is {} and the following data for fourth table:\n {}. The fifth table name is {} and the following data for fifth table:\n {}. The sixth table name is {} and the following data for sixth table:\n {}. The seventh table name is {} and the following data for seventh table:\n {} \n""".format(schema_name,table_name1,df_customer.loc[:5], table_name2, df_product.loc[:5], table_name3, df_sales.loc[:5], table_name5, df_store.loc[:5], table_name6, df_channel.loc[:5],table_name7, df_lineaction.loc[:5], table_name8, df_calendar.loc[:5])
180
- messages = [{"role": "system", "content": prompt}]
181
-
182
- import time
183
- import gradio as gr
184
- def CustomChatGPT(user_inp):
185
- messages.append({"role": "user", "content": user_inp})
186
- response = openai.ChatCompletion.create(
187
- model = "gpt-4",
188
- messages = messages
189
- )
190
- ChatGPT_reply = response["choices"][0]["message"]["content"]
191
- messages.append({"role": "assistant", "content": ChatGPT_reply})
192
- return ChatGPT_reply
193
-
194
- def respond(message, chat_history):
195
- bot_message = CustomChatGPT(message)
196
- chat_history.append((message, bot_message))
197
- time.sleep(2)
198
- return "", chat_history
199
-
200
- # to test the generated sql query
201
- def test_Sql(sql):
202
  sql=sql.replace(';', '')
203
- sql = sql + ' ' + 'limit 5'
 
204
  sql = str(sql)
205
  sql = sqlparse.format(sql, reindent=True, keyword_case='upper')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
- conn = psycopg2.connect(database=db_name, user = user_db, password = pwd_db, host = host_db, port = port_db)
208
- df = pd.read_sql_query(sql, con=conn)
209
- conn.close()
210
- return pd.DataFrame(df)
211
-
212
- admin = os.getenv("admin")
213
- paswd = os.getenv("paswd")
214
 
215
- def same_auth(username, password):
216
- if username == admin and password == paswd:
217
- return 1
218
 
219
  with gr.Blocks() as demo:
 
 
220
  with gr.Tab("Query Helper"):
221
  gr.Markdown("""<h1><center> Query Helper</center></h1>""")
222
  chatbot = gr.Chatbot()
@@ -224,12 +260,57 @@ with gr.Blocks() as demo:
224
  clear = gr.ClearButton([msg, chatbot])
225
  msg.submit(respond, [msg, chatbot], [msg, chatbot])
226
 
 
227
  with gr.Tab("Run Query"):
228
- # gr.Markdown("""<h1><center> Run Query </center></h1>""")
229
  text_input = gr.Textbox(label = 'Input SQL Query', placeholder="Write your SQL query here ...")
230
  text_output = gr.Textbox(label = 'Result')
231
  text_button = gr.Button("RUN QUERY")
232
  clear = gr.ClearButton([text_input, text_output])
233
- text_button.click(test_Sql, inputs=text_input, outputs=text_output)
234
-
235
- demo.launch(share=True, auth=same_auth)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
 
 
 
 
 
 
 
2
  import pandas as pd
 
 
 
3
  import psycopg2
4
  import time
5
  import gradio as gr
6
  import sqlparse
7
+ import re
8
  import os
9
+ import warnings
10
+
11
 
12
+ from config import *
13
+ from constants import *
14
+ from utils import *
15
+ from gptManager import ChatgptManager
16
+ # from queryHelper import QueryHelper
17
 
18
  pd.set_option('display.max_columns', None)
19
  pd.set_option('display.max_rows', None)
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # Filter out all warning messages
23
+ warnings.filterwarnings("ignore")
24
+
25
+ dbCreds = DataWrapper(DB_CREDS_DATA)
26
+ dbEngine = DbEngine(dbCreds)
27
+ dbEngine.connect()
28
+
29
+ tablesAndCols = getAllTablesInfo(dbEngine, SCHEMA_NAME)
30
+ metadataLayout = MetaDataLayout(schemaName=SCHEMA_NAME, allTablesAndCols=tablesAndCols)
31
+ metadataLayout.setSelection(DEFAULT_TABLES_COLS)
32
+
33
+ selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols()
34
+
35
def getSampleDataForTablesAndCols(dbEngine, schemaName, tablesAndCols, maxRows):
    """Read up to ``maxRows`` sample rows from every listed table.

    Args:
        dbEngine: Object exposing an open DB-API connection as ``.connection``.
        schemaName: Schema used to qualify each table name in the query.
        tablesAndCols: Mapping whose keys are the table names to sample. The
            values are not used here; every column is selected.
        maxRows: Value placed in the ``LIMIT`` clause of each query.

    Returns:
        dict mapping table name -> ``pandas.DataFrame``. A table that cannot
        be read is reported on stdout and left out of the result.
    """
    data = {}
    conn = dbEngine.connection
    for table in tablesAndCols:
        # Identifiers cannot be bound as query parameters, so the statement is
        # built as text from the schema/table names supplied by the caller.
        sqlQuery = f"""select * from {schemaName}.{table} limit {maxRows}"""
        try:
            data[table] = pd.read_sql_query(sqlQuery, con=conn)
        except Exception as e:
            print(e)
            print(f"couldn't read table data. Table: {table}")
            # A failed statement can leave the connection inside an aborted
            # transaction; roll back so the remaining tables can still be read.
            rollback = getattr(conn, "rollback", None)
            if callable(rollback):
                try:
                    rollback()
                except Exception as rollbackError:
                    print(f"couldn't roll back after failed read: {rollbackError}")
    return data
46
+
47
class QueryHelper:
    """Turn a natural-language request into SQL with the help of a GPT client.

    The helper keeps a cache of sample rows for the currently selected tables
    and uses it to build the system prompts sent to ``gptInstance``.

    Args:
        gptInstance: Object with ``getResponseForUserInput(userInput, systemPrompt)``.
        dbEngine: Passed through unchanged to ``getSampleDataForTablesAndCols``.
        schemaName: Schema name mentioned in the prompts and used for sampling.
        platform: SQL platform name mentioned in the prompts.
        metadataLayout: Object with ``getSelectedTablesAndCols()``.
        sampleDataRows: Row limit used when caching sample data.
        gptSampleRows: Number of cached rows per table shown to the model.
        getSampleDataForTablesAndCols: Callable accepting the keyword arguments
            ``dbEngine``, ``schemaName``, ``tablesAndCols`` and ``maxRows`` and
            returning a mapping of table name -> sample data.
    """

    def __init__(self, gptInstance, dbEngine, schemaName,
                 platform, metadataLayout, sampleDataRows,
                 gptSampleRows, getSampleDataForTablesAndCols):
        self.gptInstance = gptInstance
        self.schemaName = schemaName
        self.platform = platform
        self.metadataLayout = metadataLayout
        self.sampleDataRows = sampleDataRows
        self.gptSampleRows = gptSampleRows
        self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
        self.dbEngine = dbEngine
        self._onMetadataChange()

    def _onMetadataChange(self):
        """Reload the cached sample data for the currently selected tables."""
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        self.sampleData = self.getSampleDataForTablesAndCols(
            dbEngine=self.dbEngine, schemaName=self.schemaName,
            tablesAndCols=selectedTablesAndCols, maxRows=self.sampleDataRows)

    @staticmethod
    def _flattenPrompt(prompt):
        """Put a prompt on one line and turn the XXXX markers into spaces."""
        # NOTE(review): the third replace is a no-op as written in the visible
        # source; presumably it was meant to collapse double spaces - confirm.
        return prompt.replace("\n", " ").replace("\\", " ").replace(" ", " ").replace("XXXX", " ")

    def getMetadata(self):
        """Return the metadata layout currently in use."""
        return self.metadataLayout

    def updateMetadata(self, metadataLayout):
        """Replace the metadata layout and reload the sample data cache."""
        self.metadataLayout = metadataLayout
        self._onMetadataChange()

    def modifySqlQueryEnteredByUser(self, userSqlQuery):
        """Ask the model to correct ``userSqlQuery`` for the configured platform.

        Returns:
            The model's reply, unmodified.
        """
        platform = self.platform
        userPrompt = f"Please correct the following sql query, also it has to be run on {platform}. sql query is \n {userSqlQuery}."
        systemPrompt = ""
        return self.gptInstance.getResponseForUserInput(userPrompt, systemPrompt)

    def filteredSampleDataForProspects(self, prospectTablesAndCols):
        """Return the cached sample data for the prospect tables.

        All columns of a prospect table are kept. Tables with no cached sample
        (for example because loading them failed) are skipped instead of
        raising ``KeyError``.
        """
        sampleData = self.sampleData
        return {table: sampleData[table]
                for table in prospectTablesAndCols
                if table in sampleData}

    def getQueryForUserInput(self, userInput):
        """Generate a SQL query for ``userInput``.

        Two model calls are made: the first narrows the selected tables down
        to prospects, the second generates the query with sample rows of those
        prospects in the system prompt.
        """
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        prospectTablesAndCols = self.getProspectiveTablesAndCols(userInput, selectedTablesAndCols)
        print("getting prospects", prospectTablesAndCols)
        prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols)
        systemPromptForQueryGeneration = self.getSystemPromptForQueryGeneration(
            prospectTablesData, gptSampleRows=self.gptSampleRows)
        return self.gptInstance.getResponseForUserInput(userInput, systemPromptForQueryGeneration)

    def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols):
        """Ask the model which selected tables/columns are relevant.

        The reply is not parsed; a table or column counts as a prospect when
        its name occurs anywhere in the reply text.

        Returns:
            dict mapping each mentioned table to the list of its selected
            columns that are also mentioned.
        """
        systemPromptForProspectColumns = self.getSystemPromptForProspectColumns(selectedTablesAndCols)
        prospectiveTablesColsText = self.gptInstance.getResponseForUserInput(
            userInput, systemPromptForProspectColumns)
        prospectTablesAndCols = {}
        for table, columns in selectedTablesAndCols.items():
            if table in prospectiveTablesColsText:
                prospectTablesAndCols[table] = [
                    column for column in columns if column in prospectiveTablesColsText
                ]
        return prospectTablesAndCols

    def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows):
        """Build the system prompt listing sample rows of each prospect table."""
        schemaName = self.schemaName
        platform = self.platform
        # The trailing space keeps "sample data" apart from the first table entry.
        prompt = f"""Given an input text, generate the corresponding SQL query for given details. Schema Name is {schemaName}. And sql platform is {platform}.\n following is sample data """
        for tableName, tableData in prospectTablesData.items():
            prompt += f"table name is {tableName}, table data is {tableData.head(gptSampleRows)}"
            prompt += "XXXX"
        return self._flattenPrompt(prompt)

    def getSystemPromptForProspectColumns(self, selectedTablesAndCols):
        """Build the system prompt listing every selected table and its columns."""
        schemaName = self.schemaName
        platform = self.platform

        prompt = f"""Given an input text, User wants to know which all tables and columns would be possibily to have the desired data. Output them as json. Schema Name is {schemaName}. And sql platform is {platform}.\n"""
        for tableName, columns in selectedTablesAndCols.items():
            prompt += f"table name {tableName} {', '.join(columns)}"
            prompt += "XXXX"
        return self._flattenPrompt(prompt)
136
+
137
+
138
+ openAIClient = OpenAI(api_key=OPENAI_API_KEY)
139
+ gptInstance = ChatgptManager(openAIClient, model=GPT_MODEL)
140
+ queryHelper = QueryHelper(gptInstance=gptInstance,
141
+ schemaName=SCHEMA_NAME,platform=PLATFORM,
142
+ metadataLayout=metadataLayout,
143
+ sampleDataRows=SAMPLE_ROW_MAX,
144
+ gptSampleRows=GPT_SAMPLE_ROWS,
145
+ dbEngine=dbEngine,
146
+ getSampleDataForTablesAndCols=getSampleDataForTablesAndCols)
147
+
148
def checkAuth(username, password):
    """Return True when the submitted login matches the configured one.

    The expected credentials are the module-level ``ADMIN`` and ``PASSWD``.
    Login is refused when either of them is not a string (for example when it
    was never configured), and the comparison is done in constant time so the
    response time does not reveal how much of a guess was correct.

    Args:
        username: User name submitted through the login form.
        password: Password submitted through the login form.

    Returns:
        bool: True only when both values match.
    """
    global ADMIN, PASSWD
    import hmac  # local import: keeps this function self-contained

    submitted = (username, password)
    expected = (ADMIN, PASSWD)
    if not all(isinstance(value, str) for value in submitted + expected):
        return False

    # compare_digest only accepts ASCII str, so compare the UTF-8 bytes.
    userMatches = hmac.compare_digest(username.encode("utf-8"), ADMIN.encode("utf-8"))
    passwordMatches = hmac.compare_digest(password.encode("utf-8"), PASSWD.encode("utf-8"))
    return userMatches and passwordMatches
153
+
154
+
155
+ # Chat handler: generates a SQL query for the user's message and appends the exchange to the chat history
156
def respond(message, chatHistory):
    """gpt response handler for gradio ui

    Asks the module-level ``queryHelper`` for a query matching ``message`` and
    appends the exchange to ``chatHistory``. A failure while generating the
    query is shown as the bot reply instead of aborting the handler.

    Args:
        message: Text entered by the user.
        chatHistory: List of ``(userMessage, botMessage)`` tuples; it is
            modified in place.

    Returns:
        tuple: ``("", chatHistory)`` - the empty string clears the input box.
    """
    global queryHelper
    try:
        botMessage = queryHelper.getQueryForUserInput(message)
    except Exception as e:
        print(f"Error while generating a query for the message. Error: {e}")
        botMessage = f"Sorry, the query could not be generated. Error: {e}"
    chatHistory.append((message, botMessage))
    return "", chatHistory
163
+
164
+ # Helper that checks whether a SQL query only reads data (used before running a query)
165
def isDataQuery(sql_query):
    """Return True when ``sql_query`` contains no data- or schema-changing keyword.

    This is a keyword scan, not a SQL parser: a query is rejected as soon as
    one of the blocked keywords appears as a whole word anywhere in the text,
    including inside a string literal. Identifiers that merely contain a
    keyword (``update_date``, ``created_at``) are not affected.

    Args:
        sql_query: SQL text to inspect.

    Returns:
        bool: False if a blocked keyword is present, True otherwise.
    """
    blocked_keywords = (
        # DML
        'INSERT', 'UPDATE', 'DELETE', 'MERGE',
        # DDL and permission changes modify the database just as well
        'DROP', 'TRUNCATE', 'ALTER', 'CREATE', 'GRANT', 'REVOKE',
    )
    pattern = r'\b(?:' + '|'.join(blocked_keywords) + r')\b'
    return re.search(pattern, sql_query.upper()) is None
175
+
176
def testSQL(sql):
    """Run a user-entered SQL query and return its result or a suggested fix.

    Semicolons are removed, ``limit 5`` is appended unless the last 15
    characters already mention a limit, and the text is re-formatted with
    ``sqlparse``. Queries rejected by ``isDataQuery`` are not executed. When
    execution fails, the transaction is rolled back and the model is asked
    for a corrected query.

    Args:
        sql: SQL text entered by the user.

    Returns:
        A ``pandas.DataFrame`` with the query result on success, otherwise a
        message string.
    """
    global dbEngine, queryHelper

    # Strip first so trailing whitespace cannot hide an existing limit clause.
    sql = str(sql).replace(';', '').strip()
    if not sql:
        return "Please enter a SQL query to run."
    if 'limit' not in sql[-15:].lower():
        sql = sql + ' ' + 'limit 5'
    sql = sqlparse.format(sql, reindent=True, keyword_case='upper')
    print(sql)
    if not isDataQuery(sql):
        return "Sorry not allowed to run. As the query modifies the data."

    conn = dbEngine.connection
    try:
        return pd.read_sql_query(sql, con=conn)
    except Exception as e:
        print(f"Error occured during running the query {sql}.\n and the error is {str(e)}")

        # A failed statement can leave the connection inside an aborted
        # transaction; roll back so later queries are not rejected.
        rollback = getattr(conn, "rollback", None)
        if callable(rollback):
            try:
                rollback()
            except Exception as rollbackError:
                print(f"couldn't roll back after failed query: {rollbackError}")

        # modifySqlQueryEnteredByUser adds the "please correct" instruction and
        # the platform itself, so only the query and the error are passed in.
        try:
            modifiedSql = queryHelper.modifySqlQueryEnteredByUser(
                f"{sql}. the error occured is {str(e)}")
        except Exception as correctionError:
            print(f"couldn't get a corrected query: {correctionError}")
            return f"The query you entered throws some error.\n {str(e)}"
        return f"The query you entered throws some error. Here is modified version. Please try this.\n {modifiedSql}"
197
+
198
+
199
def onSelectedTablesChange(tablesSelected):
    """Build one component per known table, visible only for selected tables.

    Args:
        tablesSelected: Collection of the table names currently selected.

    Returns:
        list: One ``gr.Textbox`` per table known to the metadata layout, in
        the layout's table order.
    """
    global queryHelper
    print(f"Selected tables : {tablesSelected}")
    layout = queryHelper.getMetadata()
    return [
        gr.Textbox(f"Textbox {tableName}",
                   visible=(tableName in tablesSelected),
                   label=f"{tableName}")
        for tableName in list(layout.getAllTablesCols())
    ]
212
+
213
def onSelectedColumnsChange(*tableBoxes):
    """Store the column selection coming from the UI and reload sample data.

    Args:
        *tableBoxes: One value per known table, in the metadata layout's table
            order. Only non-empty lists are treated as a column selection.

    Returns:
        str: Status text for the result box.
    """
    global queryHelper
    layout = queryHelper.getMetadata()
    tableNames = list(layout.getAllTablesCols().keys())
    print("Getting selected tables and columns from gradio")
    chosen = {
        tableName: boxValue
        for boxValue, tableName in zip(tableBoxes, tableNames)
        if isinstance(boxValue, list) and len(boxValue) != 0
    }

    layout.setSelection(tablesAndCols=chosen)
    print("metadata updated")
    print("Updating queryHelper state, and sample data")
    queryHelper.updateMetadata(layout)
    return "Columns udpated"
233
+
234
def onResetToDefaultSelection():
    """Restore the default table/column selection and reload sample data.

    The defaults come from ``DEFAULT_TABLES_COLS``; the metadata layout is
    taken from ``queryHelper`` like the other selection handlers do.

    Returns:
        list: One ``gr.Textbox`` per table known to the metadata layout,
        visible only for the default tables.
    """
    global queryHelper
    layout = queryHelper.getMetadata()
    defaultTables = set(DEFAULT_TABLES_COLS)
    tableBoxes = [
        gr.Textbox(f"Textbox {tableName}",
                   visible=(tableName in defaultTables),
                   label=f"{tableName}")
        for tableName in layout.getAllTablesCols().keys()
    ]

    layout.resetSelection()
    layout.setSelection(DEFAULT_TABLES_COLS)
    queryHelper.updateMetadata(layout)

    return tableBoxes
250
 
 
 
 
 
 
 
 
251
 
 
 
 
252
 
253
  with gr.Blocks() as demo:
254
+
255
+ # screen 1 : Chatbot for question answering to generate sql query from user input in english
256
  with gr.Tab("Query Helper"):
257
  gr.Markdown("""<h1><center> Query Helper</center></h1>""")
258
  chatbot = gr.Chatbot()
 
260
  clear = gr.ClearButton([msg, chatbot])
261
  msg.submit(respond, [msg, chatbot], [msg, chatbot])
262
 
263
+ # screen 2 : To run sql query against database
264
  with gr.Tab("Run Query"):
265
+ gr.Markdown("""<h1><center> Run Query </center></h1>""")
266
  text_input = gr.Textbox(label = 'Input SQL Query', placeholder="Write your SQL query here ...")
267
  text_output = gr.Textbox(label = 'Result')
268
  text_button = gr.Button("RUN QUERY")
269
  clear = gr.ClearButton([text_input, text_output])
270
+ text_button.click(testSQL, inputs=text_input, outputs=text_output)
271
+ # screen 3 : To set creds, schema, tables and columns
272
+ with gr.Tab("Setup"):
273
+ gr.Markdown("""<h1><center> Run Query </center></h1>""")
274
+ text_input = gr.Textbox(label = 'schema name', value= SCHEMA_NAME)
275
+ allTablesAndCols = queryHelper.getMetadata().getAllTablesCols()
276
+ selectedTablesAndCols = queryHelper.getMetadata().getSelectedTablesAndCols()
277
+ allTablesList = list(allTablesAndCols.keys())
278
+ selectedTablesList = list(selectedTablesAndCols.keys())
279
+
280
+ dropDown = gr.Dropdown(
281
+ allTablesList, value=selectedTablesList, multiselect=True, label="Selected Tables", info="Select Tables from available tables of the schema"
282
+ )
283
+
284
+ refreshTables = gr.Button("Refresh selected tables")
285
+
286
+ tableBoxes = []
287
+
288
+ for i in range(len(allTablesList)):
289
+ if allTablesList[i] in selectedTablesList:
290
+ columnsDropDown = gr.Dropdown(
291
+ allTablesAndCols[allTablesList[i]],visible=True,value=selectedTablesAndCols.get(allTablesList[i],None), multiselect=True, label=allTablesList[i], info="Select columns of a table"
292
+ )
293
+ #tableBoxes[allTables[i]] = columnsDropDown
294
+ tableBoxes.append(columnsDropDown)
295
+ else:
296
+ columnsDropDown = gr.Dropdown(
297
+ allTablesAndCols[allTablesList[i]], visible=False, value=None, multiselect=True, label=allTablesList[i], info="Select columns of a table"
298
+ )
299
+ #tableBoxes[allTables[i]] = columnsDropDown
300
+ tableBoxes.append(columnsDropDown)
301
+
302
+ refreshTables.click(onSelectedTablesChange, inputs=dropDown, outputs=tableBoxes)
303
+
304
+
305
+
306
+
307
+ columnsTextBox = gr.Textbox(label = 'Result')
308
+ refreshColumns = gr.Button("Refresh selected columns and Reload Data")
309
+ refreshColumns.click(onSelectedColumnsChange, inputs=tableBoxes, outputs=columnsTextBox)
310
+
311
+ resetToDefaultSelection = gr.Button("Reset to Default")
312
+ resetToDefaultSelection.click(onResetToDefaultSelection, inputs=None, outputs=tableBoxes)
313
+
314
+
315
+ demo.launch(share=True, debug=True, ssl_verify=False, auth=checkAuth)
316
+ dbEngine.connect()
config.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # from configLocal import *
2
+ from configProd import *
configProd.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ OPENAI_API_KEY = os.getenv("api_key")
3
+
4
+ #database credential
5
+ dbName = os.getenv("db_name")
6
+ userDB = os.getenv("user_db")
7
+ pwdDB = os.getenv("pwd_db")
8
+ host = os.getenv("host_db")
9
+ port = os.getenv("port_db")
10
+ GPT_MODEL = "gpt-4"
11
+ # GPT_MODEL = "gpt-3.5-turbo-1106"
12
+
13
+
14
+
15
+ #gradio login
16
+ ADMIN = os.getenv("admin")
17
+ PASSWD = os.getenv("paswd")
18
+
19
+ DB_CREDS_DATA = ({"database":dbName, "user":userDB, "password":pwdDB, "host":host, "port":port})
constants.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __all__ = ["SCHEMA_NAME", "GPT_SAMPLE_ROWS", "PLATFORM", "SAMPLE_ROW_MAX", "DEFAULT_TABLES_COLS", "QUERY_TIMEOUT"]
2
+
3
+ #Constants
4
+ SCHEMA_NAME = "lpdatamart"
5
+ GPT_SAMPLE_ROWS = 5
6
+ PLATFORM = "Amazon Redshift"
7
+ SAMPLE_ROW_MAX = 50
8
+ QUERY_TIMEOUT = 20 #timeout in seconds
9
+
10
+
11
+ # list down the desired column
12
+
13
+ customer_col=['customer_id','customer_type', 'first_name', 'middle_name', 'household_name', 'last_name', 'personal_email', 'city', 'state', 'zip_code', 'address1', 'country', 'gender', 'phone_number', 'reward_number']
14
+ product_col=['product_id', 'product_name', 'product_price', 'department', 'class', 'discount', 'category', 'department_desc', 'department_type', 'product_type', 'manufacturer', 'color']
15
+ sales_col = ['store_id', 'customer_id', 'channel_id', 'product_id', 'time_id', 'date_id','order_id', 'line_action', 'discount_amount', 'shipping_amount','transaction_date', 'transaction_amount', 'transaction_type', 'qty_sold']
16
+ store_col = ['store_id', 'store_number', 'store_name', 'store_designation', 'store_longitude', 'store_latitude', 'store_manager_name', 'zip_code', 'state_code', 'city', 'street_number', 'street_name', 'store_region', 'store_type', 'address1','sublocationcode', 'channel', 'company_flag', 'kiosk_physical_store', 'sublocation_code']
17
+ channel_col = ['channel_id', 'channel_name', 'channel_code']
18
+ lineaction_col = ['line_action_code', 'line_action_code_desc', 'load_date', 'catgory', 'sales_type']
19
+ calendar_col = ['date_id','calendar_date','calendar_month','day_of_week','calendar_week_number','calendar_month_number','calendar_quarter_number','day_of_month','day_of_quarter','day_of_the_year','us_holiday','lp_holiday','work_day','year','ad_week','ad_week_year','ad_month','lp_day','lp_week','lp_month','lp_year','lp_quarter','event_day']
20
+
21
+ browse_col = ['cookie_id', 'session_id', 'customer_id', 'email_key', 'reward_number', 'date_id', 'time_id', 'category_id', 'browse_action_id', 'product_id', 'style_id', 'order_id']
22
+ time_col = ['time_id', 'time_of_day']
23
+ browse_action_col = ["browse_action_id", "browse_action"]
24
+ browse_category_col = ['category_id', 'category_code', 'category']
25
+ style_col = ["sku", "style", "source_file", "load_date"]
26
+ email_col = ['event_id', 'customer_id', 'time_id', 'date_id', 'email_key']
27
+ event_col = ['event_id', 'event_type', 'event_description', 'event_detail', 'start_date', 'end_date', 'event_code', 'event_category']
28
+
29
+
30
+ DEFAULT_TABLES_COLS = {"tbl_d_customer":customer_col, "tbl_d_product":product_col, "tbl_f_sales":sales_col,
31
+ "tbl_d_store":store_col, "tbl_d_channel":channel_col, "tbl_d_lineaction_code":lineaction_col,
32
+ "tbl_d_calendar":calendar_col, 'tbl_f_browse':browse_col, 'tbl_d_time': time_col, 'tbl_d_browse_action': browse_action_col,
33
+ 'tbl_d_browse_category':browse_category_col, 'tbl_d_style':style_col, 'tbl_f_emailing': email_col, 'tbl_d_event':event_col}
gptManager.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
class ChatgptManager:
    """Small wrapper around a chat-completions client.

    Every call to ``getResponseForUserInput`` starts a fresh conversation
    made of one system message and one user message.

    Args:
        openAIClient: Client exposing ``chat.completions.create(...)``.
        model: Model name sent with every request.
        tokenLimit: Upper bound for the size of a conversation. The size is
            estimated by counting words, not real tokens.
    """

    def __init__(self, openAIClient, model="gpt-3.5-turbo-1106", tokenLimit=8000):
        self.client = openAIClient
        self.tokenLimit = tokenLimit
        self.model = model
        # Initialised here so the token helpers work before the first request.
        self.messages = []

    def getResponseForUserInput(self, userInput, systemPrompt):
        """Send ``systemPrompt`` and ``userInput`` and return the reply text.

        Raises:
            ValueError: If the system prompt alone, or the system prompt plus
                the user input, exceeds ``tokenLimit``.
        """
        self.messages = []

        systemMessage = {"role": "system", "content": systemPrompt}
        if self.isTokeLimitExceeding(systemMessage):
            raise ValueError("System Prompt Too long.")
        self.messages.append(systemMessage)

        userMessage = {"role": "user", "content": userInput}
        if self.isTokeLimitExceeding(userMessage):
            raise ValueError("Token Limit exceeding. With user input")
        self.messages.append(userMessage)

        completion = self.client.chat.completions.create(
            model=self.model,
            messages=self.messages,
            temperature=0,
        )

        gptResponse = completion.choices[0].message.content

        self.messages.append({"role": "assistant", "content": gptResponse})
        return gptResponse

    def isTokeLimitExceeding(self, newMessage=None, truncate=True, throwError=True):
        """Return True if the conversation plus ``newMessage`` is over the limit.

        ``truncate`` and ``throwError`` are accepted but not used.
        """
        return self.getTokenCount(newMessage=newMessage) > self.tokenLimit

    def getTokenCount(self, newMessage=None):
        """Token count including new Message

        The count is the number of words (``\\b\\w+\\b`` matches) in the
        contents of all stored messages plus ``newMessage`` when given.
        """
        messages = self.messages[:]
        if newMessage is not None:
            messages.append(newMessage)

        # A message without content (None) counts as empty text.
        combinedContent = " ".join(msg["content"] or "" for msg in messages)
        return len(re.findall(r'\b\w+\b', combinedContent))
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  pandas
2
  psycopg2
3
- openai
4
- sqlparse
 
 
1
  pandas
2
  psycopg2
3
+ openai==1.3.5
4
+ sqlparse
5
+ gradio==3.50.1
utils.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import psycopg2
2
+
3
+
4
class DataWrapper:
    """Expose keys of a list or dict as instance attributes.

    A list creates one attribute per element, each set to ``None``; a dict
    creates one attribute per key with the matching value. Any other input
    leaves the instance without attributes.
    """

    def __init__(self, data):
        attributes = vars(self)
        if isinstance(data, list):
            attributes.update(dict.fromkeys(data))
        elif isinstance(data, dict):
            attributes.update(data)

    def addKey(self, key, val=None):
        """Add (or overwrite) the attribute ``key`` with value ``val``."""
        vars(self)[key] = val

    def __repr__(self):
        """Return the repr of the underlying attribute dictionary."""
        return repr(vars(self))
17
+
18
class MetaDataLayout:
    """Track which tables/columns of a schema are currently selected."""

    def __init__(self, schemaName, allTablesAndCols):
        """Build the layout for ``schemaName`` with an empty selection.

        Args:
            schemaName: Name stored under the ``"schema"`` key.
            allTablesAndCols: Mapping of every table name to its columns.
        """
        self.schemaName = schemaName
        self.datalayout = dict(
            schema=self.schemaName,
            selectedTables={},
            allTables=allTablesAndCols,
        )

    def setSelection(self, tablesAndCols):
        """Record the selected columns for each known table.

        Args:
            tablesAndCols: Mapping such as
                ``{"table1": ["col1", "col2"], "table2": ["cola", "colb"]}``.
                Tables missing from the schema are reported and skipped.
        """
        layout = self.datalayout
        selected = layout['selectedTables']
        for tableName in tablesAndCols:
            if tableName not in layout['allTables'].keys():
                print(f"Table {tableName} doesn't exists in the schema")
                continue
            selected[tableName] = tablesAndCols[tableName]
        self.datalayout = layout

    def resetSelection(self):
        """Replace the current selection with a new empty mapping."""
        self.datalayout['selectedTables'] = {}

    def getSelectedTablesAndCols(self):
        """Return the mapping of selected tables to their selected columns."""
        return self.datalayout['selectedTables']

    def getAllTablesCols(self):
        """Return the mapping of every table in the schema to its columns."""
        return self.datalayout['allTables']
49
+
50
+
51
+
52
+
53
class DbEngine:
    """Own a single lazily opened psycopg2 connection built from ``dbCreds``."""

    def __init__(self, dbCreds):
        """Remember the credentials; no connection is opened yet.

        Args:
            dbCreds: Object with ``database``, ``user``, ``password``,
                ``host`` and ``port`` attributes.
        """
        self.dbCreds = dbCreds
        self.connection = None

    def connect(self):
        """Open the connection unless a live one already exists."""
        if self.connection is not None and self.connection.closed == 0:
            return
        dbCreds = self.dbCreds
        self.connection = psycopg2.connect(
            database=dbCreds.database,
            user=dbCreds.user,
            password=dbCreds.password,
            host=dbCreds.host,
            port=dbCreds.port,
        )

    def disconnect(self):
        """Close the connection if it is currently open."""
        if self.connection is not None and self.connection.closed == 0:
            self.connection.close()

    def execute_query(self, query):
        """Run ``query`` and return all fetched rows.

        The connection is opened on demand, so calling this before
        ``connect()`` (or after ``disconnect()``) no longer fails on a
        missing connection.

        Raises:
            Exception: Whatever the driver raised; the transaction is rolled
                back first so later queries on this connection still work.
        """
        self.connect()
        try:
            with self.connection.cursor() as cursor:
                cursor.execute(query)
                return cursor.fetchall()
        except Exception:
            # A failed statement leaves the transaction aborted until it is
            # rolled back; skip the rollback if the connection itself died.
            if self.connection.closed == 0:
                self.connection.rollback()
            raise
74
+
75
+
76
def executeQuery(dbEngine, query):
    """Run ``query`` through ``dbEngine.execute_query`` and return its result."""
    return dbEngine.execute_query(query)
79
+
80
def executeColumnsQuery(dbEngine, columnQuery):
    """Execute ``columnQuery`` and return the result's column names.

    The names are the first element of each entry in the cursor's
    ``description`` after the query has run.
    """
    names = []
    with dbEngine.connection.cursor() as cur:
        cur.execute(columnQuery)
        for column in cur.description:
            names.append(column[0])
    return names
85
+
86
def closeDbEngine(dbEngine):
    """Disconnect ``dbEngine`` by calling its ``disconnect`` method."""
    dbEngine.disconnect()
88
+
89
def getAllTablesInfo(dbEngine, schemaName):
    """Return ``{table_name: [column names]}`` for every table in a schema.

    Table names are read from ``information_schema.tables``; the columns of
    each table come from the description of a zero-row ``SELECT *``.

    Args:
        dbEngine: Engine passed on to ``executeQuery`` / ``executeColumnsQuery``.
        schemaName: Schema to inspect, matched exactly against ``table_schema``.
    """

    def quoteIdentifier(name):
        # Double-quoted identifiers keep mixed-case and special-character
        # names intact; embedded quotes are escaped by doubling them.
        return '"' + str(name).replace('"', '""') + '"'

    # Escape single quotes so the schema name cannot terminate the literal.
    schemaLiteral = str(schemaName).replace("'", "''")
    allTablesQuery = f"""SELECT table_name FROM information_schema.tables
    WHERE table_schema = '{schemaLiteral}'"""

    tablesAndCols = {}
    for table in executeQuery(dbEngine, allTablesQuery):
        tableName = table[0]
        columnsQuery = (
            f"SELECT * FROM {quoteIdentifier(schemaName)}."
            f"{quoteIdentifier(tableName)} LIMIT 0"
        )
        tablesAndCols[tableName] = executeColumnsQuery(dbEngine, columnsQuery)
    return tablesAndCols
100
+
101
def getSampleDataForTablesAndCols(dbEngine, schemaName, tablesAndCols, maxRows):
    """Read up to ``maxRows`` rows from each table into a DataFrame.

    Args:
        dbEngine: Object whose ``connection`` attribute is an open DB-API
            connection.
        schemaName: Schema that contains the tables.
        tablesAndCols: Mapping keyed by table name; only the keys are used,
            every column of each table is read.
        maxRows: Row limit applied to each table.

    Returns:
        ``{table_name: DataFrame}`` for every table that could be read.
        Tables that fail are reported on stdout and left out.
    """
    # Imported locally because this module's top-level imports do not include
    # pandas; without it every read raised NameError, which the former bare
    # ``except`` hid, so the result was always empty.
    import pandas as pd

    def quoteIdentifier(name):
        # Double-quoted identifiers keep mixed-case and special-character
        # names intact; embedded quotes are escaped by doubling them.
        return '"' + str(name).replace('"', '""') + '"'

    data = {}
    conn = dbEngine.connection
    for table in tablesAndCols:
        sqlQuery = (
            f"select * from {quoteIdentifier(schemaName)}."
            f"{quoteIdentifier(table)} limit {maxRows}"
        )
        try:
            data[table] = pd.read_sql_query(sqlQuery, con=conn)
        except Exception as err:
            # Best effort: one unreadable table must not abort the others,
            # but the reason is reported instead of being discarded.
            print(f"couldn't read table data. Table: {table}. Error: {err}")
    return data