anumaurya114exp committed on
Commit
1dda07c
·
1 Parent(s): 4df2beb

added bug fixes, small features : see description

Browse files

In run query: tell the user that results are limited to 5 rows, and show the total number of results (like pandas)
bug fix: always add schema to table names in query
Improve logging: log the complete GPT response and the extracted SQL separately.
Manage History without sending systemPrompt each time.

Files changed (5) hide show
  1. app.py +29 -125
  2. gptManager.py +18 -4
  3. persistStorage.py +18 -3
  4. queryHelperManager.py +95 -0
  5. utils.py +35 -2
app.py CHANGED
@@ -9,18 +9,18 @@ import os
9
  import warnings
10
 
11
 
12
- from persistStorage import saveLog
13
  from config import *
14
  from constants import *
15
  from utils import *
16
  from gptManager import ChatgptManager
17
- # from queryHelper import QueryHelper
18
 
19
  logsDir = os.getenv("HF_HOME", "/data")
20
 
21
 
22
  pd.set_option('display.max_columns', None)
23
- pd.set_option('display.max_rows', None)
24
 
25
  # Filter out all warning messages
26
  warnings.filterwarnings("ignore")
@@ -47,96 +47,6 @@ def getSampleDataForTablesAndCols(dbEngine, schemaName, tablesAndCols, maxRows):
47
  print(f"couldn't read table data. Table: {table}")
48
  return data
49
 
50
- class QueryHelper:
51
- def __init__(self, gptInstance, dbEngine, schemaName,
52
- platform, metadataLayout, sampleDataRows,
53
- gptSampleRows, getSampleDataForTablesAndCols):
54
- self.gptInstance = gptInstance
55
- self.schemaName = schemaName
56
- self.platform = platform
57
- self.metadataLayout = metadataLayout
58
- self.sampleDataRows = sampleDataRows
59
- self.gptSampleRows = gptSampleRows
60
- self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols
61
- self.dbEngine = dbEngine
62
- self._onMetadataChange()
63
-
64
- def _onMetadataChange(self):
65
- metadataLayout = self.metadataLayout
66
- sampleDataRows = self.sampleDataRows
67
- dbEngine = self.dbEngine
68
- schemaName = self.schemaName
69
-
70
- selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols()
71
- self.sampleData = self.getSampleDataForTablesAndCols(dbEngine=dbEngine,schemaName=schemaName,
72
- tablesAndCols=selectedTablesAndCols, maxRows=sampleDataRows)
73
-
74
- def getMetadata(self):
75
- return self.metadataLayout
76
-
77
- def updateMetadata(self, metadataLayout):
78
- self.metadataLayout = metadataLayout
79
- self._onMetadataChange()
80
-
81
- def modifySqlQueryEnteredByUser(self, userSqlQuery):
82
- platform = self.platform
83
- userPrompt = f"Please correct the following sql query, also it has to be run on {platform}. sql query is \n {userSqlQuery}."
84
- systemPrompt = ""
85
- modifiedSql = self.gptInstance.getResponseForUserInput(userPrompt, systemPrompt)
86
- return modifiedSql
87
-
88
- def filteredSampleDataForProspects(self, prospectTablesAndCols):
89
- sampleData = self.sampleData
90
- filteredData = {}
91
- for table in prospectTablesAndCols.keys():
92
- # filteredData[table] = sampleData[table][prospectTablesAndCols[table]]
93
- #take all columns of prospects
94
- filteredData[table] = sampleData[table]
95
- return filteredData
96
-
97
- def getQueryForUserInput(self, userInput):
98
- gptSampleRows = self.gptSampleRows
99
- selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
100
- prospectTablesAndCols = self.getProspectiveTablesAndCols(userInput, selectedTablesAndCols)
101
- print("getting prospects", prospectTablesAndCols)
102
- prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols)
103
- systemPromptForQueryGeneration = self.getSystemPromptForQueryGeneration(prospectTablesData, gptSampleRows=gptSampleRows)
104
- queryByGpt = self.gptInstance.getResponseForUserInput(userInput, systemPromptForQueryGeneration)
105
- return queryByGpt
106
-
107
- def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols):
108
- schemaName = self.schemaName
109
-
110
- systemPromptForProspectColumns = self.getSystemPromptForProspectColumns(selectedTablesAndCols)
111
- prospectiveTablesColsText = self.gptInstance.getResponseForUserInput(userInput, systemPromptForProspectColumns)
112
- prospectTablesAndCols = {}
113
- for table in selectedTablesAndCols.keys():
114
- if table in prospectiveTablesColsText:
115
- prospectTablesAndCols[table] = []
116
- for column in selectedTablesAndCols[table]:
117
- if column in prospectiveTablesColsText:
118
- prospectTablesAndCols[table].append(column)
119
- return prospectTablesAndCols
120
-
121
- def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows):
122
- schemaName = self.schemaName
123
- platform = self.platform
124
- prompt = f"""Given an input text, generate the corresponding SQL query for given details. Schema Name is {schemaName}. And sql platform is {platform}.\n following is sample data"""
125
- for idx, tableName in enumerate(prospectTablesData.keys(), start=1):
126
- prompt += f"table name is {tableName}, table data is {prospectTablesData[tableName].head(gptSampleRows)}"
127
- prompt += "XXXX"
128
- return prompt.replace("\n"," ").replace("\\"," ").replace(" "," ").replace("XXXX", " ")
129
-
130
- def getSystemPromptForProspectColumns(self, selectedTablesAndCols):
131
- schemaName = self.schemaName
132
- platform = self.platform
133
-
134
- prompt = f"""Given an input text, User wants to know which all tables and columns would be possibily to have the desired data. Output them as json. Schema Name is {schemaName}. And sql platform is {platform}.\n"""
135
- for idx, tableName in enumerate(selectedTablesAndCols.keys(), start=1):
136
- prompt += f"table name {tableName} {', '.join(selectedTablesAndCols[tableName])}"
137
- prompt += "XXXX"
138
- return prompt.replace("\n"," ").replace("\\"," ").replace(" "," ").replace("XXXX", " ")
139
-
140
 
141
  openAIClient = OpenAI(api_key=OPENAI_API_KEY)
142
  gptInstance = ChatgptManager(openAIClient, model=GPT_MODEL)
@@ -155,49 +65,51 @@ def checkAuth(username, password):
155
  return False
156
 
157
 
 
158
  # Function to save history of chat
159
  def respond(message, chatHistory):
160
  """gpt response handler for gradio ui"""
161
  global queryHelper
162
  try:
163
- botMessage = queryHelper.getQueryForUserInput(message)
164
  except Exception as e:
165
  errorMessage = {"function":"queryHelper.getQueryForUserInput","error":str(e), "userInput":message}
166
  saveLog(errorMessage, 'error')
167
-
168
- logMessage = {"userInput":message, "queryGenerated":botMessage}
169
  saveLog(logMessage)
170
  chatHistory.append((message, botMessage))
171
  time.sleep(2)
172
  return "", chatHistory
173
 
174
- # Function to test the generated sql query
175
- def isDataQuery(sql_query):
176
- upper_query = sql_query.upper()
177
-
178
- dml_keywords = ['INSERT', 'UPDATE', 'DELETE', 'MERGE']
179
- for keyword in dml_keywords:
180
- if re.search(fr'\b{keyword}\b', upper_query):
181
- return False # Found a DML keyword, indicating modification
182
-
183
- # If no DML keywords are found, it's likely a data query
184
- return True
185
-
186
- def testSQL(sql):
187
- global dbEngine, queryHelper
188
 
 
189
  sql=sql.replace(';', '')
 
190
  if ('limit' in sql[-15:].lower())==False:
191
  sql = sql + ' ' + 'limit 5'
 
 
 
192
  sql = str(sql)
193
  sql = sqlparse.format(sql, reindent=True, keyword_case='upper')
194
- print(sql)
 
 
 
 
 
195
  if not isDataQuery(sql):
196
  return "Sorry not allowed to run. As the query modifies the data."
197
  try:
198
  conn = dbEngine.connection
199
- df = pd.read_sql_query(sql, con=conn)
200
- return pd.DataFrame(df)
201
  except Exception as e:
202
  errorMessage = {"function":"testSQL","error":str(e), "userInput":sql}
203
  saveLog(errorMessage, 'error')
@@ -206,6 +118,8 @@ def testSQL(sql):
206
 
207
  prompt = f"Please correct the following sql query, also it has to be run on {PLATFORM}. sql query is \n {sql}. the error occured is {str(e)}."
208
  modifiedSql = queryHelper.modifySqlQueryEnteredByUser(prompt)
 
 
209
  return f"The query you entered throws some error. Here is modified version. Please try this.\n {modifiedSql}"
210
 
211
 
@@ -246,7 +160,7 @@ def onSelectedColumnsChange(*tableBoxes):
246
 
247
  def onResetToDefaultSelection():
248
  global queryHelper
249
- tablesSelected = list(DefaultTablesAndCols.keys())
250
  tableBoxes = []
251
  allTablesList = list(metadataLayout.getAllTablesCols().keys())
252
  for i in range(len(allTablesList)):
@@ -256,21 +170,11 @@ def onResetToDefaultSelection():
256
  tableBoxes.append(gr.Textbox(f"Textbox {allTablesList[i]}", visible=False, label=f"{allTablesList[i]}"))
257
 
258
  metadataLayout.resetSelection()
259
- metadataLayout.setSelection(DefaultTablesAndCols)
260
  queryHelper.updateMetadata(metadataLayout)
261
 
262
  return tableBoxes
263
 
264
- def getAllLogFilesPaths():
265
- global logsDir
266
- # Save processed data to temporary file
267
-
268
- logFiles = [file for file in os.listdir(logsDir) if 'log' in file.lower()]
269
- print(logFiles,"avaiable logs")
270
-
271
- downloadableFilesPaths = [os.path.join(os.path.abspath(logsDir), logFilePath) for logFilePath in logFiles]
272
- return downloadableFilesPaths
273
-
274
  def onSyncLogsWithDataDir():
275
  downloadableFilesPaths = getAllLogFilesPaths()
276
  fileComponent = gr.File(downloadableFilesPaths, file_count='multiple')
 
9
  import warnings
10
 
11
 
12
+ from persistStorage import saveLog, getAllLogFilesPaths
13
  from config import *
14
  from constants import *
15
  from utils import *
16
  from gptManager import ChatgptManager
17
+ from queryHelperManager import QueryHelper
18
 
19
  logsDir = os.getenv("HF_HOME", "/data")
20
 
21
 
22
  pd.set_option('display.max_columns', None)
23
+ pd.set_option('display.max_rows', 10)
24
 
25
  # Filter out all warning messages
26
  warnings.filterwarnings("ignore")
 
47
  print(f"couldn't read table data. Table: {table}")
48
  return data
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  openAIClient = OpenAI(api_key=OPENAI_API_KEY)
52
  gptInstance = ChatgptManager(openAIClient, model=GPT_MODEL)
 
65
  return False
66
 
67
 
68
+
69
# Function to save history of chat
def respond(message, chatHistory):
    """gpt response handler for gradio ui

    Generates a SQL query for *message* via the global queryHelper, logs the
    interaction, and appends the exchange to the gradio chat history.
    Returns ("", chatHistory) so the input textbox is cleared.
    """
    global queryHelper
    try:
        botMessage, prospectTablesAndCols = queryHelper.getQueryForUserInput(message, chatHistory)
    except Exception as e:
        errorMessage = {"function":"queryHelper.getQueryForUserInput","error":str(e), "userInput":message}
        saveLog(errorMessage, 'error')
        # Bug fix: previously execution fell through here with botMessage and
        # prospectTablesAndCols undefined, raising NameError on any GPT
        # failure. Surface a friendly message in the chat and return early.
        chatHistory.append((message, "Sorry, something went wrong while generating the query. Please try again."))
        return "", chatHistory
    queryGenerated = extractSqlFromGptResponse(botMessage)
    # Log both the raw GPT response and the SQL extracted from it.
    logMessage = {"userInput":message, "tablesColsSelectedByGpt":str(prospectTablesAndCols) , "queryGenerated":queryGenerated, "completeGptResponse":botMessage}
    saveLog(logMessage)
    chatHistory.append((message, botMessage))
    time.sleep(2)  # brief pause so the UI update is visible
    return "", chatHistory
84
 
85
def preProcessGptQueryReponse(gptResponse, metadataLayout: MetaDataLayout):
    """Post-process a GPT SQL response: qualify bare table names with the schema."""
    return addSchemaToTableInSQL(
        gptResponse,
        schemaName=metadataLayout.schemaName,
        tablesList=metadataLayout.getAllTablesCols().keys(),
    )
 
 
 
 
 
 
 
 
 
90
 
91
def preProcessSQL(sql):
    """Normalise a SQL string before execution.

    Strips semicolons, appends ``limit 5`` when no LIMIT clause appears near
    the end of the query, and pretty-prints via sqlparse.

    Returns a tuple (formattedSql, disclaimer) where *disclaimer* is
    non-empty only when the default limit was injected.
    """
    sql = sql.replace(';', '')
    disclaimerOutputStripping = ""
    # Only inspect the tail of the query so a LIMIT in e.g. a subquery does
    # not count; idiom fix: 'not in' instead of '(...) == False'.
    if 'limit' not in sql[-15:].lower():
        sql = sql + ' ' + 'limit 5'
        disclaimerOutputStripping = """Results are stripped to show only top 5 rows.
Please add your custom limit to get extend result.
eg\n select * from schema.table limit 20\n\n"""
    sql = str(sql)
    sql = sqlparse.format(sql, reindent=True, keyword_case='upper')
    return sql, disclaimerOutputStripping
102
+
103
def testSQL(sql):
    # Run a (user- or GPT-supplied) SQL query read-only against the global
    # DB engine and return the rendered result. On failure, ask GPT for a
    # corrected query instead.
    global dbEngine, queryHelper

    # Normalise + inject 'limit 5' when missing; the disclaimer tells the
    # user when results were capped.
    sql, disclaimerOutputStripping = preProcessSQL(sql=sql)
    # Refuse anything containing data-modifying keywords.
    if not isDataQuery(sql):
        return "Sorry not allowed to run. As the query modifies the data."
    try:
        conn = dbEngine.connection
        df = pd.read_sql_query(sql, con=conn)
        # str(DataFrame) includes the "[n rows x m columns]" footer — the
        # pandas-style total-result count mentioned in the commit notes.
        return disclaimerOutputStripping + str(pd.DataFrame(df))
    except Exception as e:
        errorMessage = {"function":"testSQL","error":str(e), "userInput":sql}
        saveLog(errorMessage, 'error')

        # Ask GPT to repair the failing query and log both versions.
        prompt = f"Please correct the following sql query, also it has to be run on {PLATFORM}. sql query is \n {sql}. the error occured is {str(e)}."
        modifiedSql = queryHelper.modifySqlQueryEnteredByUser(prompt)
        logMessage = {"function":"queryHelper.modifySqlQueryEnteredByUser", "sqlQuery":sql, "modifiedSQLQuery":modifiedSql}
        saveLog(logMessage, 'info')
        return f"The query you entered throws some error. Here is modified version. Please try this.\n {modifiedSql}"
124
 
125
 
 
160
 
161
  def onResetToDefaultSelection():
162
  global queryHelper
163
+ tablesSelected = list(DEFAULT_TABLES_COLS.keys())
164
  tableBoxes = []
165
  allTablesList = list(metadataLayout.getAllTablesCols().keys())
166
  for i in range(len(allTablesList)):
 
170
  tableBoxes.append(gr.Textbox(f"Textbox {allTablesList[i]}", visible=False, label=f"{allTablesList[i]}"))
171
 
172
  metadataLayout.resetSelection()
173
+ metadataLayout.setSelection(DEFAULT_TABLES_COLS)
174
  queryHelper.updateMetadata(metadataLayout)
175
 
176
  return tableBoxes
177
 
 
 
 
 
 
 
 
 
 
 
178
  def onSyncLogsWithDataDir():
179
  downloadableFilesPaths = getAllLogFilesPaths()
180
  fileComponent = gr.File(downloadableFilesPaths, file_count='multiple')
gptManager.py CHANGED
@@ -5,20 +5,34 @@ class ChatgptManager:
5
  self.client = openAIClient
6
  self.tokenLimit = tokenLimit
7
  self.model = model
 
 
 
 
 
 
 
 
 
 
8
 
9
- def getResponseForUserInput(self, userInput, systemPrompt):
10
- self.messages = []
11
  newMessage = {"role":"system", "content":systemPrompt}
12
  if not self.isTokeLimitExceeding(newMessage):
13
  self.messages.append(newMessage)
14
  else:
15
- raise ValueError("System Prompt Too long.")
 
 
16
 
17
  userMessage = {"role":"user", "content":userInput}
18
  if not self.isTokeLimitExceeding(userMessage):
19
  self.messages.append(userMessage)
20
  else:
21
- raise ValueError("Token Limit exceeding. With user input")
 
 
22
 
23
  # completion = self.client.chat.completions.create(
24
  # model="gpt-3.5-turbo-1106",
 
5
  self.client = openAIClient
6
  self.tokenLimit = tokenLimit
7
  self.model = model
8
+
9
+ def _chatHistoryToGptMessages(self, chatHistory=[]):
10
+ messages = []
11
+ for i in range(len(chatHistory)):
12
+ if i%2==0:
13
+ message = {"role":"user", "content":chatHistory[i]}
14
+ else:
15
+ message = {"role":"assistant", "content": chatHistory[i]}
16
+ messages.append(message)
17
+ return messages
18
 
19
+ def getResponseForUserInput(self, userInput, systemPrompt, chatHistory=[]):
20
+ self.messages = self._chatHistoryToGptMessages(chatHistory[:])
21
  newMessage = {"role":"system", "content":systemPrompt}
22
  if not self.isTokeLimitExceeding(newMessage):
23
  self.messages.append(newMessage)
24
  else:
25
+ if chatHistory==[]:
26
+ raise ValueError("System Prompt Too long.")
27
+ return self.getResponseForUserInput(userInput=userInput, systemPrompt=systemPrompt)
28
 
29
  userMessage = {"role":"user", "content":userInput}
30
  if not self.isTokeLimitExceeding(userMessage):
31
  self.messages.append(userMessage)
32
  else:
33
+ if chatHistory==[]:
34
+ raise ValueError("Token Limit exceeding. With user input")
35
+ return self.getResponseForUserInput(userInput=userInput, systemPrompt=systemPrompt)
36
 
37
  # completion = self.client.chat.completions.create(
38
  # model="gpt-3.5-turbo-1106",
persistStorage.py CHANGED
@@ -6,7 +6,7 @@ import os
6
  from config import HUGGING_FACE_TOKEN
7
  import csv
8
 
9
- logs_dir = os.getenv("HF_HOME", "/data")
10
 
11
  # # Create a new file
12
  # with open(os.path.join(data_dir, "my_data.txt"), "a") as f:
@@ -34,9 +34,24 @@ def append_dict_to_csv(file_path, row_data):
34
  csv_writer.writerow(row_data)
35
 
36
  def saveLog(message, level='info') -> None:
37
- global logs_dir
 
 
 
38
  current_time = datetime.now(TIMEZONE_OBJ)
39
  message = str(message)
40
- log_file_path = os.path.join(logs_dir, f"{current_time.strftime('%Y-%m')}-log.csv")
41
  data_dict = {"time":str(current_time), "level": level, "message": message}
42
  append_dict_to_csv(log_file_path, data_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from config import HUGGING_FACE_TOKEN
7
  import csv
8
 
9
+ logsDir = os.getenv("HF_HOME", "/data")
10
 
11
  # # Create a new file
12
  # with open(os.path.join(data_dir, "my_data.txt"), "a") as f:
 
34
  csv_writer.writerow(row_data)
35
 
36
def saveLog(message, level='info') -> None:
    """Append one log record to this month's CSV log file.

    Skips logging (with a console notice) when the persistent data
    directory is not mounted.
    """
    global logsDir
    if not os.path.isdir(logsDir):
        print("Log directory/Data Directory not available.")
        return
    now = datetime.now(TIMEZONE_OBJ)
    # One CSV file per calendar month, e.g. "2024-01-log.csv".
    monthlyLogPath = os.path.join(logsDir, f"{now.strftime('%Y-%m')}-log.csv")
    record = {"time": str(now), "level": level, "message": str(message)}
    append_dict_to_csv(monthlyLogPath, record)
46
+
47
def getAllLogFilesPaths():
    """Return absolute paths of all log files in the logs directory.

    Returns an empty list when the directory is missing so callers can
    always iterate the result.
    """
    # No 'global' needed: logsDir is only read, never rebound.
    if not os.path.isdir(logsDir):
        print("Log directory/Data Directory not available.")
        return []
    # 'fileName' instead of 'file' — don't shadow the builtin name.
    logFiles = [fileName for fileName in os.listdir(logsDir) if 'log' in fileName.lower()]
    print(logFiles,"avaiable logs")
    # Hoist abspath out of the comprehension: it is loop-invariant.
    absLogsDir = os.path.abspath(logsDir)
    downloadableFilesPaths = [os.path.join(absLogsDir, logFilePath) for logFilePath in logFiles]
    return downloadableFilesPaths
queryHelperManager.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import re

from gptManager import ChatgptManager
from utils import MetaDataLayout
3
+
4
class QueryHelper:
    """Turns natural-language questions into SQL via a two-step GPT flow.

    Step 1 asks GPT which of the selected tables/columns look relevant;
    step 2 feeds cached sample data for those tables back to GPT to
    generate the final query.
    """

    def __init__(self, gptInstance: "ChatgptManager", dbEngine, schemaName,
                 platform, metadataLayout: "MetaDataLayout", sampleDataRows,
                 gptSampleRows, getSampleDataForTablesAndCols):
        self.gptInstance = gptInstance
        self.schemaName = schemaName
        self.platform = platform
        self.metadataLayout = metadataLayout
        self.sampleDataRows = sampleDataRows  # rows fetched per table for the local cache
        self.gptSampleRows = gptSampleRows    # rows actually shown to GPT in prompts
        self.getSampleDataForTablesAndCols = getSampleDataForTablesAndCols  # injected fetcher callback
        self.dbEngine = dbEngine
        self._onMetadataChange()

    def _onMetadataChange(self):
        """Refresh the cached sample data for the currently selected tables."""
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        self.sampleData = self.getSampleDataForTablesAndCols(
            dbEngine=self.dbEngine, schemaName=self.schemaName,
            tablesAndCols=selectedTablesAndCols, maxRows=self.sampleDataRows)

    def getMetadata(self) -> "MetaDataLayout":
        return self.metadataLayout

    def updateMetadata(self, metadataLayout):
        """Swap in new metadata and refresh the sample-data cache."""
        self.metadataLayout = metadataLayout
        self._onMetadataChange()

    def modifySqlQueryEnteredByUser(self, userSqlQuery):
        """Ask GPT to correct a (possibly failing) user-entered SQL query."""
        userPrompt = f"Please correct the following sql query, also it has to be run on {self.platform}. sql query is \n {userSqlQuery}."
        systemPrompt = ""
        return self.gptInstance.getResponseForUserInput(userPrompt, systemPrompt)

    def filteredSampleDataForProspects(self, prospectTablesAndCols):
        """Return cached sample frames for the prospect tables.

        Intentionally keeps every column of a prospect table (not only the
        prospect columns) so GPT sees more context.
        """
        sampleData = self.sampleData
        return {table: sampleData[table] for table in prospectTablesAndCols}

    def getQueryForUserInput(self, userInput, chatHistory=None):
        """Generate SQL for *userInput*.

        Returns (gptResponse, prospectTablesAndCols); the response has bare
        table names qualified with the schema.
        """
        chatHistory = [] if chatHistory is None else chatHistory
        selectedTablesAndCols = self.metadataLayout.getSelectedTablesAndCols()
        prospectTablesAndCols = self.getProspectiveTablesAndCols(userInput, selectedTablesAndCols, chatHistory)
        print("getting prospects", prospectTablesAndCols)
        prospectTablesData = self.filteredSampleDataForProspects(prospectTablesAndCols)
        systemPromptForQueryGeneration = self.getSystemPromptForQueryGeneration(prospectTablesData, gptSampleRows=self.gptSampleRows)
        queryByGpt = self.gptInstance.getResponseForUserInput(userInput, systemPromptForQueryGeneration, chatHistory)

        # Bug fix: this previously called app.preProcessGptQueryReponse with
        # an undefined local `metadataLayout` — both names are unresolved in
        # this module, so every call raised NameError. Qualify table names
        # with the schema directly instead.
        queryByGpt = self._addSchemaToTables(queryByGpt)
        return queryByGpt, prospectTablesAndCols

    def _addSchemaToTables(self, sqlText):
        """Prefix bare table names with the schema (mirrors utils.addSchemaToTableInSQL)."""
        for table in self.metadataLayout.getAllTablesCols().keys():
            # '.' in the lookbehind keeps already-qualified names untouched.
            pattern = re.compile(rf'(?<![a-zA-Z0-9_.]){re.escape(table)}(?![a-zA-Z0-9_])', re.IGNORECASE)
            sqlText = pattern.sub(f'{self.schemaName}.{table}', sqlText)
        return sqlText

    def getProspectiveTablesAndCols(self, userInput, selectedTablesAndCols, chatHistory=None):
        """Ask GPT which selected tables/columns are relevant to the input.

        The GPT answer is matched by plain substring search: a table/column
        is a prospect whenever its name appears anywhere in the response.
        """
        chatHistory = [] if chatHistory is None else chatHistory
        systemPromptForProspectColumns = self.getSystemPromptForProspectColumns(selectedTablesAndCols)
        prospectiveTablesColsText = self.gptInstance.getResponseForUserInput(userInput, systemPromptForProspectColumns, chatHistory)
        prospectTablesAndCols = {}
        for table, columns in selectedTablesAndCols.items():
            if table in prospectiveTablesColsText:
                prospectTablesAndCols[table] = [col for col in columns if col in prospectiveTablesColsText]
        return prospectTablesAndCols

    def getSystemPromptForQueryGeneration(self, prospectTablesData, gptSampleRows):
        """Build the system prompt carrying sample data for the prospect tables."""
        schemaName = self.schemaName
        platform = self.platform
        prompt = f"""Given an input text, generate the corresponding SQL query for given details. Schema Name is {schemaName}. And sql platform is {platform}.\n following is sample data"""
        for tableName in prospectTablesData.keys():
            prompt += f"table name is {tableName}, table data is {prospectTablesData[tableName].head(gptSampleRows)}"
            prompt += "XXXX"
        # NOTE(review): replace(" ", " ") is a no-op as written — possibly
        # meant to collapse double spaces; confirm before changing.
        return prompt.replace("\n"," ").replace("\\"," ").replace(" "," ").replace("XXXX", " ")

    def getSystemPromptForProspectColumns(self, selectedTablesAndCols):
        """Build the system prompt listing all selected tables and columns."""
        schemaName = self.schemaName
        platform = self.platform

        prompt = f"""Given an input text, User wants to know which all tables and columns would be possibily to have the desired data. Output them as json. Schema Name is {schemaName}. And sql platform is {platform}.\n"""
        for tableName in selectedTablesAndCols.keys():
            prompt += f"table name {tableName} {', '.join(selectedTablesAndCols[tableName])}"
            prompt += "XXXX"
        # NOTE(review): replace(" ", " ") is a no-op as written — see above.
        return prompt.replace("\n"," ").replace("\\"," ").replace(" "," ").replace("XXXX", " ")
95
+
utils.py CHANGED
@@ -1,5 +1,5 @@
1
  import psycopg2
2
-
3
 
4
  class DataWrapper:
5
  def __init__(self, data):
@@ -108,4 +108,37 @@ def getSampleDataForTablesAndCols(dbEngine, schemaName, tablesAndCols, maxRows):
108
  data[table] = pd.read_sql_query(sqlQuery, con=conn)
109
  except:
110
  print(f"couldn't read table data. Table: {table}")
111
- return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import psycopg2
2
+ import re
3
 
4
  class DataWrapper:
5
  def __init__(self, data):
 
108
  data[table] = pd.read_sql_query(sqlQuery, con=conn)
109
  except:
110
  print(f"couldn't read table data. Table: {table}")
111
+ return data
112
+
113
# Function to test whether a SQL query only reads data
def isDataQuery(sql_query):
    """Return True if *sql_query* looks like a read-only query.

    Scans for data- and schema-modifying keywords as whole words,
    case-insensitively; returns False as soon as one is found.
    """
    upper_query = sql_query.upper()

    # DML keywords that change data, plus DDL keywords that change the
    # schema itself — DROP/ALTER/TRUNCATE/... previously slipped through
    # this guard and would have been executed as "data queries".
    unsafe_keywords = ['INSERT', 'UPDATE', 'DELETE', 'MERGE',
                       'DROP', 'ALTER', 'TRUNCATE', 'CREATE', 'GRANT', 'REVOKE']
    # One compiled alternation instead of re-compiling per keyword.
    pattern = re.compile(r'\b(?:' + '|'.join(unsafe_keywords) + r')\b')
    if pattern.search(upper_query):
        return False  # Found a modifying keyword

    # If no modifying keywords are found, it's likely a data query
    return True
124
+
125
def extractSqlFromGptResponse(gptReponse):
    """Pull the SQL statement out of a ```sql ...``` fenced block.

    Returns the text inside the first fenced SQL block, or "" when the
    response contains no such block.
    """
    # Non-greedy match with DOTALL so the SQL may span multiple lines and
    # only the first fenced block is captured.
    fence = re.compile(r"```sql\n(.*?)```", re.DOTALL)
    found = fence.search(gptReponse)
    return found.group(1) if found else ""
137
+
138
def addSchemaToTableInSQL(sqlQuery, schemaName, tablesList):
    """Prefix every bare table name in *sqlQuery* with *schemaName*.

    Table names are matched as whole words, case-insensitively. The '.'
    added to the lookbehind keeps already-qualified names (schema.table)
    from being prefixed a second time — the old pattern produced
    'schema.schema.table' when GPT had already qualified the name.
    """
    for table in tablesList:
        pattern = re.compile(rf'(?<![a-zA-Z0-9_.]){re.escape(table)}(?![a-zA-Z0-9_])', re.IGNORECASE)
        replacement = f'{schemaName}.{table}'
        sqlQuery = pattern.sub(replacement, sqlQuery)
    return sqlQuery