# QueryHelper / utils.py
# Last commit 4cb09a6 by anumaurya114exp: "added keep alive args" (5.37 kB)
import psycopg2
import re
import pandas as pd
from persistStorage import retrieveTablesDataFromLocalDb, saveTablesDataToLocalDB
class DataWrapper:
    """Attribute bag whose fields are created dynamically.

    A list supplies attribute names, each initialised to ``None``; a dict
    supplies attribute names together with their values. Any other input
    leaves the instance without attributes.
    """

    def __init__(self, data):
        if isinstance(data, dict):
            initial = data
        elif isinstance(data, list):
            initial = dict.fromkeys(data)
        else:
            initial = {}
        vars(self).update(initial)

    def addKey(self, key, val=None):
        """Create or overwrite attribute ``key`` with ``val``."""
        vars(self)[key] = val

    def __repr__(self):
        # Shown exactly like the underlying attribute dictionary.
        return repr(vars(self))
class MetaDataLayout:
    """Holds every table/column of one schema plus a caller-chosen subset.

    The state lives in ``self.datalayout`` with the keys ``"schema"``,
    ``"selectedTables"`` and ``"allTables"``.
    """

    def __init__(self, schemaName, allTablesAndCols):
        self.schemaName = schemaName
        self.datalayout = {
            "schema": self.schemaName,
            "selectedTables": {},
            "allTables": allTablesAndCols,
        }

    def setSelection(self, tablesAndCols):
        """
        tablesAndCols : {"table1": ["col1", "col2"], "table2": ["cola", "colb"]}

        Tables that are not part of the schema are reported on stdout and
        skipped; existing selections for other tables are kept.
        """
        layout = self.datalayout
        for name in tablesAndCols:
            if name not in layout['allTables'].keys():
                print(f"Table {name} doesn't exists in the schema")
                continue
            layout['selectedTables'][name] = tablesAndCols[name]

    def resetSelection(self):
        """Forget the current selection (a fresh empty dict is installed)."""
        self.datalayout['selectedTables'] = {}

    def getSelectedTablesAndCols(self):
        """Return the mapping of selected table name -> selected columns."""
        return self.datalayout['selectedTables']

    def getAllTablesCols(self):
        """Return the mapping of every known table name -> its columns."""
        return self.datalayout['allTables']
class DbEngine:
    """Lazily-connected wrapper around a single psycopg2 connection.

    ``dbCreds`` must expose ``database``, ``user``, ``password``, ``host``
    and ``port`` attributes; they are handed to ``psycopg2.connect`` as-is.
    """

    def __init__(self, dbCreds):
        self.dbCreds = dbCreds
        self._connection = None  # opened on first use by connect()

    def _isOpen(self):
        """Return True when a connection object exists and is not closed."""
        return self._connection is not None and self._connection.closed == 0

    def connect(self):
        """Open the connection unless a usable one already exists."""
        if self._isOpen():
            return
        dbCreds = self.dbCreds
        # libpq TCP keepalive settings so idle connections are not silently
        # dropped by intermediaries.
        keepaliveKwargs = {
            "keepalives": 1,
            "keepalives_idle": 100,
            "keepalives_interval": 5,
            "keepalives_count": 5,
        }
        self._connection = psycopg2.connect(
            database=dbCreds.database,
            user=dbCreds.user,
            password=dbCreds.password,
            host=dbCreds.host,
            port=dbCreds.port,
            **keepaliveKwargs,
        )

    def getConnection(self):
        """Return an open connection, (re)connecting first when necessary."""
        self.connect()
        return self._connection

    def disconnect(self):
        """Close the connection if it is currently open."""
        if self._isOpen():
            self._connection.close()

    def _rollbackQuietly(self):
        """Abort the current transaction so the connection stays reusable."""
        if not self._isOpen():
            return
        try:
            self._connection.rollback()
        except Exception:
            # The caller is already propagating the original failure, which
            # is more informative than a secondary rollback error.
            pass

    def execute_query(self, query):
        """Execute ``query`` and return all rows from ``cursor.fetchall()``.

        Raises:
            Exception: wrapping any connection/execution error (the original
                error is chained as ``__cause__``).
        """
        try:
            self.connect()
            with self._connection.cursor() as cursor:
                cursor.execute(query)
                return cursor.fetchall()
        except Exception as e:
            # A failed statement leaves the transaction aborted; without a
            # rollback every later query on this connection would fail too.
            self._rollbackQuietly()
            raise Exception(e) from e
def executeQuery(dbEngine, query):
    """Run ``query`` via ``dbEngine.execute_query`` and return its result."""
    return dbEngine.execute_query(query)
def executeColumnsQuery(dbEngine, columnQuery):
    """Execute ``columnQuery`` and return the column names of its result.

    Args:
        dbEngine: engine exposing ``getConnection()``.
        columnQuery: SQL text whose result-set columns are wanted.

    Returns:
        List of column names taken from ``cursor.description``; an empty
        list when the statement produced no result set.

    Raises:
        Whatever the driver raises on failure; the transaction is rolled
        back first so the connection remains usable.
    """
    connection = dbEngine.getConnection()
    try:
        with connection.cursor() as cursor:
            cursor.execute(columnQuery)
            description = cursor.description
    except Exception:
        try:
            connection.rollback()
        except Exception:
            # Keep the original error; a rollback failure adds nothing.
            pass
        raise
    if description is None:
        # No result set (e.g. a non-SELECT statement): nothing to list.
        return []
    return [column[0] for column in description]
def closeDbEngine(dbEngine):
    """Ask ``dbEngine`` to disconnect; returns nothing."""
    dbEngine.disconnect()
def getAllTablesInfo(dbEngine, schemaName):
    """Return ``{table_name: [column, ...]}`` for every table in a schema.

    Table names are read from ``information_schema.tables``; the columns of
    each table are discovered by running a ``LIMIT 0`` select against it.

    Args:
        dbEngine: engine accepted by ``executeQuery``/``executeColumnsQuery``.
        schemaName: name of the schema to inspect.
    """

    def quoteIdentifier(name):
        # Double-quote identifiers so mixed-case, reserved-word or
        # special-character names are referenced verbatim.
        return '"' + str(name).replace('"', '""') + '"'

    # The query helpers only accept SQL text, so the schema name has to be
    # embedded as a literal. Doubling single quotes is a best-effort escape.
    # NOTE(review): switch to bound parameters if the helpers gain support.
    schemaLiteral = str(schemaName).replace("'", "''")
    allTablesQuery = (
        "SELECT table_name FROM information_schema.tables\n"
        f"WHERE table_schema = '{schemaLiteral}'"
    )

    quotedSchema = quoteIdentifier(schemaName)
    tablesAndCols = {}
    for row in executeQuery(dbEngine, allTablesQuery):
        tableName = row[0]
        # LIMIT 0 fetches no rows but still exposes the column metadata.
        columnsQuery = (
            f"Select * FROM {quotedSchema}.{quoteIdentifier(tableName)} LIMIT 0"
        )
        tablesAndCols[tableName] = executeColumnsQuery(dbEngine, columnsQuery)
    return tablesAndCols
def getSampleDataForTablesAndCols(dbEngine, schemaName, tablesAndCols, maxRows):
    """Return up to ``maxRows`` sample rows per table as DataFrames.

    The local cache is consulted first via ``retrieveTablesDataFromLocalDb``;
    any non-empty cache result is returned untouched. Otherwise each table is
    read from the database, the result is stored with
    ``saveTablesDataToLocalDB`` and returned.

    Args:
        dbEngine: engine exposing ``getConnection()``.
        schemaName: schema that contains the tables.
        tablesAndCols: mapping whose keys are the table names to sample.
        maxRows: row limit applied to every table query.

    Returns:
        ``{table_name: DataFrame}``; a table that could not be read maps to
        an empty DataFrame.
    """
    data = retrieveTablesDataFromLocalDb(list(tablesAndCols.keys()))
    if data != {}:
        return data

    conn = dbEngine.getConnection()
    print("Didn't find any cache/valid cache.")
    print("Getting data from aws redshift")
    for table in tablesAndCols:
        sqlQuery = f"""select * from {schemaName}.{table} limit {maxRows}"""
        try:
            data[table] = pd.read_sql_query(sqlQuery, con=conn)
        except Exception as err:
            # Best effort: one unreadable table must not abort the whole run.
            # Only Exception is caught so Ctrl-C / SystemExit still work,
            # and the reason is reported instead of being dropped.
            print(f"couldn't read table data. Table: {table}. Error: {err}")
            data[table] = pd.DataFrame({})
    saveTablesDataToLocalDB(data)
    return data
# Function to test the generated sql query
def isDataQuery(sql_query):
    """Return True when ``sql_query`` contains no data-modifying DML keyword.

    The check is a case-insensitive whole-word search for INSERT, UPDATE,
    DELETE and MERGE anywhere in the text.
    """
    modifying_words = ('INSERT', 'UPDATE', 'DELETE', 'MERGE')
    normalized = sql_query.upper()
    # Any whole-word hit means the statement may modify data.
    return not any(
        re.search(fr'\b{word}\b', normalized) for word in modifying_words
    )
def extractSqlFromGptResponse(gptReponse):
    """Return the contents of the first ```sql fenced block in the response.

    The fence tag is matched case-insensitively and may be followed by
    spaces/tabs and either LF or CRLF, since model output is not consistent
    about these details.

    Args:
        gptReponse: full text returned by the model.

    Returns:
        The text between the opening fence line and the closing fence,
        unmodified, or ``""`` when no sql fence is present.
    """
    sqlPattern = re.compile(
        r"```sql[ \t]*\r?\n(.*?)```",
        re.DOTALL | re.IGNORECASE,
    )
    match = sqlPattern.search(gptReponse)
    return match.group(1) if match else ""
def addSchemaToTableInSQL(sqlQuery, schemaName, tablesList):
    """Prefix bare table names in ``sqlQuery`` with ``schemaName.``.

    This is a textual heuristic, not a SQL parser. A table name is rewritten
    when it stands alone: preceded by start-of-text, whitespace, ``(`` or
    ``,`` and followed by end-of-text, whitespace, ``;``, ``,`` or ``)``.
    Names that are already qualified (``schema.table``) or used as a column
    qualifier (``table.col``) are left untouched. Matching is
    case-insensitive; the replacement uses the spelling from ``tablesList``.

    Args:
        sqlQuery: SQL text to rewrite.
        schemaName: schema to prepend.
        tablesList: iterable of table names known to the schema.

    Returns:
        The rewritten SQL text.
    """
    for table in tablesList:
        qualifiedName = f'{schemaName}.{table}'
        pattern = re.compile(
            rf'(?<![^\s(,]){re.escape(table)}(?=[\s;,)]|$)',
            re.IGNORECASE,
        )
        # A callable replacement inserts the text literally, so backslashes
        # in names are not parsed as regex escapes.
        sqlQuery = pattern.sub(
            lambda _match, name=qualifiedName: name, sqlQuery
        )
    return sqlQuery
def preProcessGptQueryReponse(gptResponse, metadataLayout: MetaDataLayout):
    """Schema-qualify every known table name found in ``gptResponse``.

    The schema name and the set of known tables both come from
    ``metadataLayout``; the rewriting itself is done by
    ``addSchemaToTableInSQL``.
    """
    targetSchema = metadataLayout.schemaName
    knownTables = metadataLayout.getAllTablesCols().keys()
    return addSchemaToTableInSQL(
        gptResponse,
        schemaName=targetSchema,
        tablesList=knownTables,
    )