# Spaces: Runtime error
import hmac
import os
import re
import time
import warnings

import gradio as gr
import pandas as pd
import psycopg2
import sqlparse
from openai import OpenAI

# Local imports keep their original relative order: later star-imports may
# shadow names from earlier ones.
from persistStorage import saveLog, getAllLogFilesPaths, getNewCsvFilePath, removeAllCsvFiles
from config import *
from constants import *
from utils import *
from gptManager import ChatgptManager
from queryHelperManager import QueryHelper
from queryHelperManagerCoT import QueryHelperChainOfThought
# Never truncate DataFrames when they are rendered to text for the UI.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
# Suppress every warning raised by the libraries used below.
warnings.filterwarnings("ignore")

# ---------------------------------------------------------------------------
# Database connection and schema metadata
# ---------------------------------------------------------------------------
dbCreds = DataWrapper(DB_CREDS_DATA)
dbEngine = DbEngine(dbCreds)
tablesAndCols = getAllTablesInfo(dbEngine, SCHEMA_NAME)

# TODO: getAllTablesInfo does not currently return the flags table, so its
# columns are registered by hand. Remove this override once that is fixed.
tablesAndCols['tbl_d_product_style_flags'] = [
    "product_id",
    "contemp_style_flag", "trad_style_flag", "country_style_flag", "trans_style_flag",
    "mc_style_flag", "farm_style_flag", "wi_style_flag", "iron_style_flag",
    "crystal_style_flag", "coast_style_flag", "rustic_style_flag", "ind_style_flag",
    "glam_style_flag", "ac_style_flag", "kids_style_flag", "asian_style_flag",
    "tiff_style_flag", "trop_style_flag", "um_style_flag",
    "sw_style_flag", "themed_style_flag", "west_style_flag",
    "style", "sku#",
]

metadataLayout = MetaDataLayout(schemaName=SCHEMA_NAME, allTablesAndCols=tablesAndCols)
# Start with the configured default tables/columns selected.
metadataLayout.setSelection(DEFAULT_TABLES_COLS)
selectedTablesAndCols = metadataLayout.getSelectedTablesAndCols()

# ---------------------------------------------------------------------------
# GPT clients and the two query helpers (plain and chain-of-thought)
# ---------------------------------------------------------------------------
openAIClient = OpenAI(api_key=OPENAI_API_KEY)
gptInstanceForTableCols = ChatgptManager(openAIClient, model=GPT_MODEL)
gptInstanceForQuery = ChatgptManager(openAIClient, model=GPT_MODEL)

# Keyword arguments common to both helper constructors.
_commonHelperKwargs = dict(
    schemaName=SCHEMA_NAME,
    platform=PLATFORM,
    metadataLayout=metadataLayout,
    sampleDataRows=SAMPLE_ROW_MAX,
    gptSampleRows=GPT_SAMPLE_ROWS,
    dbEngine=dbEngine,
    getSampleDataForTablesAndCols=getSampleDataForTablesAndCols,
)

queryHelper = QueryHelper(
    gptInstanceForTableCols=gptInstanceForTableCols,
    gptInstanceForQuery=gptInstanceForQuery,
    **_commonHelperKwargs,
)

# The chain-of-thought helper gets its own OpenAI client.
openAIClient2 = OpenAI(api_key=OPENAI_API_KEY)
gptInstanceForCoT = ChatgptManager(openAIClient2, model=GPT_MODEL)
queryHelperCot = QueryHelperChainOfThought(
    gptInstanceForCoT=gptInstanceForCoT,
    **_commonHelperKwargs,
)
def checkAuth(username, password):
    """Gradio ``auth`` callback: accept only the configured admin credentials.

    Compares the submitted credentials with ``ADMIN`` / ``PASSWD`` (names that
    come from the star-imported configuration modules). For string values the
    comparison is constant-time, so response timing does not reveal how much of
    a guess was correct.

    Args:
        username: username typed on the Gradio login page.
        password: password typed on the Gradio login page.

    Returns:
        True when both username and password match, otherwise False.
    """
    if not all(isinstance(v, str) for v in (username, password, ADMIN, PASSWD)):
        # Non-string values (e.g. an unset or numeric config entry): keep the
        # plain equality semantics so such values never accidentally match.
        return username == ADMIN and password == PASSWD
    # compare_digest on bytes accepts non-ASCII input; both checks are always
    # evaluated so timing does not depend on which one failed.
    userOk = hmac.compare_digest(username.encode("utf-8"), ADMIN.encode("utf-8"))
    passOk = hmac.compare_digest(password.encode("utf-8"), PASSWD.encode("utf-8"))
    return userOk and passOk
def respond(message, chatHistory):
    """Gradio handler for the "Query Helper" chat tab.

    Asks ``queryHelper`` to turn ``message`` into SQL. It logs the request, the
    SQL extracted from the reply and the full reply, then appends the exchange
    to the chat history.

    Args:
        message: the user's question from the textbox.
        chatHistory: the Chatbot's list of ``(user, bot)`` pairs; it is
            appended to in place.

    Returns:
        ``("", chatHistory)``, which clears the textbox and refreshes the chat.

    Raises:
        ValueError: if the query helper fails. The failure is logged first and
            the original exception is chained as the cause.
    """
    try:
        botMessage = queryHelper.getQueryForUserInput(message)
    except Exception as e:
        errorMessage = {"function":"queryHelper.getQueryForUserInput","error":str(e), "userInput":message}
        saveLog(errorMessage, 'error')
        # Chain the cause so the real traceback is not lost.
        raise ValueError(str(e)) from e
    queryGenerated = extractSqlFromGptResponse(botMessage)
    logMessage = {"userInput":message, "queryGenerated":queryGenerated, "completeGptResponse":botMessage, "function":"queryHelper.getQueryForUserInput"}
    saveLog(logMessage)
    chatHistory.append((message, botMessage))
    return "", chatHistory
def respondCoT(message, chatHistory):
    """Gradio handler for the "Query Helper CoT" chat tab.

    Asks the chain-of-thought helper ``queryHelperCot`` to answer ``message``,
    logs the request together with the complete reply (no SQL is extracted
    here), and appends the exchange to the chat history.

    Args:
        message: the user's question from the textbox.
        chatHistory: the Chatbot's list of ``(user, bot)`` pairs; it is
            appended to in place.

    Returns:
        ``("", chatHistory)``, which clears the textbox and refreshes the chat.

    Raises:
        ValueError: if the helper fails. The failure is logged first and the
            original exception is chained as the cause.
    """
    try:
        botMessage = queryHelperCot.getQueryForUserInputCoT(message)
    except Exception as e:
        # Log under the method that was actually called.
        errorMessage = {"function":"queryHelperCot.getQueryForUserInputCoT","error":str(e), "userInput":message}
        saveLog(errorMessage, 'error')
        raise ValueError(str(e)) from e
    logMessage = {"userInput":message, "completeGptResponse":botMessage, "function":"queryHelperCot.getQueryForUserInputCoT"}
    saveLog(logMessage)
    chatHistory.append((message, botMessage))
    return "", chatHistory
# Matches a LIMIT clause that ends the statement, optionally with an OFFSET
# before or after it: "limit 20", "LIMIT ALL", "limit 20 offset 40",
# "offset 40 limit 20". Anchored at the end so a LIMIT inside a subquery, or a
# word such as 'delimiter', does not count.
_TRAILING_LIMIT_RE = re.compile(
    r"\b(?:limit\s+(?:\d+|all)(?:\s+offset\s+\d+)?"
    r"|offset\s+\d+\s+limit\s+(?:\d+|all))\s*$",
    re.IGNORECASE,
)


def preProcessSQL(sql):
    """Normalize user SQL before execution and cap unbounded result sets.

    Steps:
      1. Remove every ``;`` so a single statement is sent.
      2. If the statement does not already end with a LIMIT clause, append
         ``limit 5`` and return a disclaimer explaining the truncation.
      3. Pretty-print with ``sqlparse`` (re-indented, upper-case keywords).

    Args:
        sql: raw SQL text from the UI.

    Returns:
        ``(formattedSql, disclaimer)``. ``disclaimer`` is ``""`` unless a limit
        was appended.
    """
    sql = sql.replace(';', '')
    disclaimerOutputStripping = ""
    # The previous "'limit' in last 15 chars" heuristic missed
    # "limit 20 offset 40" and long limits (then appended a second LIMIT,
    # producing invalid SQL), and falsely matched words like 'delimiter'.
    if not _TRAILING_LIMIT_RE.search(sql):
        sql = sql + ' ' + 'limit 5'
        disclaimerOutputStripping = (
            "Results are stripped to show only top 5 rows.\n"
            "Please add your custom limit to get extend result.\n"
            "eg\n select * from schema.table limit 20\n\n"
        )
    sql = sqlparse.format(sql, reindent=True, keyword_case='upper')
    return sql, disclaimerOutputStripping
def getResultFilePath(sql):
    """Run a read-only SQL query and return its result as a downloadable CSV.

    Wired to the "Get result as csv" button, which passes the SQL textbox value
    as ``sql``. The previous version took no parameter and read an undefined
    ``sql``, so the button always failed.

    Args:
        sql: SQL text entered by the user. It goes through ``preProcessSQL``,
            so ``limit 5`` is appended when no trailing LIMIT is present.

    Returns:
        A ``gr.File`` component for a freshly written CSV. If the query fails,
        the CSV is empty and the error is logged via ``saveLog``.

    Raises:
        gr.Error: if ``isDataQuery`` rejects the query as data-modifying.
    """
    sql, _disclaimer = preProcessSQL(sql=sql)
    if not isDataQuery(sql):
        # The output is a gr.File, so a plain string would be treated as a file
        # path. Report the refusal as a UI error instead.
        raise gr.Error("Sorry not allowed to run. As the query modifies the data.")
    dbEngine2 = None
    try:
        dbEngine2 = DbEngine(dbCreds)
        dbEngine2.connect()
        df = pd.read_sql_query(sql, con=dbEngine2.getConnection())
    except Exception as e:
        # Best effort: record the failure and still hand back an (empty) CSV.
        errorMessage = {"function":"getResultFilePath","error":str(e), "userInput":sql}
        saveLog(errorMessage, 'error')
        df = pd.DataFrame()
    finally:
        # dbEngine2 stays None if the DbEngine constructor itself raised.
        if dbEngine2 is not None:
            dbEngine2.disconnect()
    # Keep only the latest result file around.
    removeAllCsvFiles()
    csvFilePath = getNewCsvFilePath()
    df.to_csv(csvFilePath, index=False)
    return gr.File(csvFilePath)
def testSQL(sql):
    """Run a read-only SQL query and return its result as text for the UI.

    Args:
        sql: SQL text entered by the user. It goes through ``preProcessSQL``,
            which may append ``limit 5`` and return a disclaimer.

    Returns:
        The disclaimer (if any) followed by the stringified DataFrame. If the
        query is rejected by ``isDataQuery`` or raises, a human-readable
        message is returned instead; query errors are also logged via
        ``saveLog``.
    """
    sql, disclaimerOutputStripping = preProcessSQL(sql=sql)
    if not isDataQuery(sql):
        return "Sorry not allowed to run. As the query modifies the data."
    dbEngine2 = None
    try:
        dbEngine2 = DbEngine(dbCreds)
        dbEngine2.connect()
        df = pd.read_sql_query(sql, con=dbEngine2.getConnection())
    except Exception as e:
        errorMessage = {"function":"testSQL","error":str(e), "userInput":sql}
        saveLog(errorMessage, 'error')
        print(f"Error occured during running the query {sql}.\n and the error is {str(e)}")
        return f"The query you entered throws some error. Here is the error.\n {str(e)}"
    finally:
        # Always release the connection. dbEngine2 is None when the
        # DbEngine constructor raised (previously a NameError here).
        if dbEngine2 is not None:
            dbEngine2.disconnect()
    return disclaimerOutputStripping + str(df)
def onSelectedTablesChange(tablesSelected):
    """Show a column picker for each table ticked in the table dropdown.

    Returns one ``gr.Dropdown`` per known table, in ``getAllTablesCols()``
    order, so the list lines up with the ``tableBoxes`` outputs. Only tables in
    ``tablesSelected`` are visible. Each picker is pre-filled with the columns
    currently selected for its table, if any.
    """
    print(f"Selected tables : {tablesSelected}")
    layout = queryHelper.getMetadata()
    columnsByTable = layout.getAllTablesCols()
    currentSelection = layout.getSelectedTablesAndCols()
    return [
        gr.Dropdown(
            columns,
            visible=tableName in tablesSelected,
            value=currentSelection.get(tableName, None),
            multiselect=True,
            label=tableName,
            info="Select columns of a table",
        )
        for tableName, columns in columnsByTable.items()
    ]
def onSelectedColumnsChange(*tableBoxes):
    """Apply the column pickers' selections to the query helper's metadata.

    ``tableBoxes`` receives the current value of each column dropdown, in the
    same order as ``getAllTablesCols()``. Tables whose value is a non-empty
    list are kept with those columns; all others are left out of the new
    selection. The updated layout is then pushed back through
    ``queryHelper.updateMetadata``.

    Returns:
        A status message for the result textbox.
    """
    layout = queryHelper.getMetadata()
    allTablesList = list(layout.getAllTablesCols().keys())
    print("Getting selected tables and columns from gradio")
    # Unset pickers arrive as None or [] and are skipped.
    newSelection = {
        table: cols
        for cols, table in zip(tableBoxes, allTablesList)
        if isinstance(cols, list) and len(cols) != 0
    }
    layout.setSelection(tablesAndCols=newSelection)
    print("metadata updated")
    print("Updating queryHelper state, and sample data")
    queryHelper.updateMetadata(layout)
    return "Columns updated"
def onResetToDefaultSelection():
    """Restore the default table/column selection and rebuild the column pickers.

    The default is ``DEFAULT_TABLES_COLS``, the same selection applied at
    startup. Returns one ``gr.Dropdown`` per known table (in
    ``getAllTablesCols()`` order, matching the ``tableBoxes`` outputs). Only
    tables in the restored selection are visible.
    """
    layout = queryHelper.getMetadata()
    # Not the module-level `tablesAndCols`: that holds every table and every
    # column, so using it would select everything instead of the default.
    layout.setSelection(tablesAndCols=DEFAULT_TABLES_COLS)
    queryHelper.updateMetadata(layout)
    # Re-read the layout after updateMetadata rather than reusing our copy.
    layout = queryHelper.getMetadata()
    allTablesAndCols = layout.getAllTablesCols()
    selectedTablesAndCols = layout.getSelectedTablesAndCols()
    return [
        gr.Dropdown(
            columns,
            visible=table in selectedTablesAndCols,
            value=selectedTablesAndCols.get(table, None),
            multiselect=True,
            label=table,
            info="Select columns of a table",
        )
        for table, columns in allTablesAndCols.items()
    ]
def onSyncLogsWithDataDir():
    """Return a multi-file component listing the paths from ``getAllLogFilesPaths()``."""
    return gr.File(getAllLogFilesPaths(), file_count='multiple')
with gr.Blocks() as demo:
    # Screen 1: chatbot that turns an English question into a SQL query.
    with gr.Tab("Query Helper"):
        gr.Markdown("""<h1><center> Query Helper</center></h1>""")
        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        clear = gr.ClearButton([msg, chatbot])
        msg.submit(respond, [msg, chatbot], [msg, chatbot])

    # Screen 1b: the same chat flow backed by the chain-of-thought helper.
    with gr.Tab("Query Helper CoT"):
        gr.Markdown("""<h1><center> Query Helper CoT</center></h1>""")
        chatbot = gr.Chatbot()
        msg = gr.Textbox()
        clear = gr.ClearButton([msg, chatbot])
        msg.submit(respondCoT, [msg, chatbot], [msg, chatbot])

    # Screen 2: run a SQL query against the database.
    with gr.Tab("Run Query"):
        gr.Markdown("""<h1><center> Run Query </center></h1>""")
        text_input = gr.Textbox(label = 'Input SQL Query', placeholder="Write your SQL query here ...")
        text_output = gr.Textbox(label = 'Result')
        text_button = gr.Button("RUN QUERY")
        clear = gr.ClearButton([text_input, text_output])
        text_button.click(testSQL, inputs=text_input, outputs=text_output)
        csvFileComponent = gr.File([], file_count='multiple')
        downloadCsv = gr.Button("Get result as csv")
        # The SQL textbox value is passed to getResultFilePath as its `sql` argument.
        downloadCsv.click(getResultFilePath, inputs=text_input, outputs=csvFileComponent)

    # Screen 3: choose which tables and columns are used.
    with gr.Tab("Setup"):
        gr.Markdown("""<h1><center> Setup </center></h1>""")
        # NOTE: rebinds `text_input`. The Run Query wiring above already holds
        # the SQL textbox, so it is unaffected. This box is not connected to
        # any handler below.
        text_input = gr.Textbox(label = 'schema name', value= SCHEMA_NAME)
        allTablesAndCols = queryHelper.getMetadata().getAllTablesCols()
        selectedTablesAndCols = queryHelper.getMetadata().getSelectedTablesAndCols()
        allTablesList = list(allTablesAndCols.keys())
        selectedTablesList = list(selectedTablesAndCols.keys())
        dropDown = gr.Dropdown(
            allTablesList, value=selectedTablesList, multiselect=True, label="Selected Tables", info="Select Tables from available tables of the schema"
        )
        refreshTables = gr.Button("Refresh selected tables")
        # One column picker per table, in getAllTablesCols() order. The
        # handlers below rely on this order to map picker values back to tables.
        tableBoxes = []
        for tableName in allTablesList:
            columnsDropDown = gr.Dropdown(
                allTablesAndCols[tableName],
                visible=tableName in selectedTablesList,
                # Tables outside the current selection have no entry, so .get -> None.
                value=selectedTablesAndCols.get(tableName, None),
                multiselect=True,
                label=tableName,
                info="Select columns of a table",
            )
            tableBoxes.append(columnsDropDown)
        refreshTables.click(onSelectedTablesChange, inputs=dropDown, outputs=tableBoxes)
        columnsTextBox = gr.Textbox(label = 'Result')
        refreshColumns = gr.Button("Refresh selected columns and Reload Data")
        refreshColumns.click(onSelectedColumnsChange, inputs=tableBoxes, outputs=columnsTextBox)
        resetToDefaultSelection = gr.Button("Reset to Default")
        resetToDefaultSelection.click(onResetToDefaultSelection, inputs=None, outputs=tableBoxes)

    # Screen 4: download the log files.
    with gr.Tab("Log-files"):
        downloadableFilesPaths = getAllLogFilesPaths()
        fileComponent = gr.File(downloadableFilesPaths, file_count='multiple')
        refreshLogs = gr.Button("Sync Log files from /data")
        refreshLogs.click(onSyncLogsWithDataDir, inputs=None, outputs=fileComponent)

# checkAuth is used as the login callback for the whole app.
demo.launch(share=True, debug=True, ssl_verify=False, auth=checkAuth)