# Tools/database/utils/database.py
# Uploaded by ZackBradshaw using huggingface_hub (commit e67043b, verified).
import json
import logging
import os
import re
from enum import IntEnum

import psycopg2
import pymysql
class DataType(IntEnum):
    """Coarse column-type families used to decide which aggregates apply."""

    VALUE = 0  # numeric columns
    TIME = 1  # date / time columns
    CHAR = 2  # everything else (strings and unrecognized types)


# Aggregate functions allowed per type family, keyed by the plain int value
# that transfer_field_type returns.  Members are referenced directly (not via
# another member, e.g. DataType.VALUE.CHAR) and by ``.value`` so every key is
# an int.
AGGREGATE_CONSTRAINTS = {
    DataType.VALUE.value: ["count", "max", "min", "avg", "sum"],
    DataType.CHAR.value: ["count", "max", "min"],
    DataType.TIME.value: ["count", "max", "min"],
}

# Base type names per server, compared after normalization (lower case,
# parenthesized arguments removed, whitespace collapsed).
_MYSQL_VALUE_TYPES = frozenset(
    {"int", "tinyint", "smallint", "mediumint", "bigint", "float", "double", "decimal"}
)
_MYSQL_TIME_TYPES = frozenset({"date", "time", "year", "datetime", "timestamp"})
_PGSQL_VALUE_TYPES = frozenset(
    {"smallint", "integer", "bigint", "numeric", "decimal", "real", "double precision"}
)
_PGSQL_TIME_TYPES = frozenset(
    {
        "date",
        "time",
        "time without time zone",
        "time with time zone",
        "timestamp",
        "timestamp without time zone",
        "timestamp with time zone",
    }
)


def transfer_field_type(database_type, server):
    """Classify a column's SQL type name into a DataType family.

    Args:
        database_type: the column type as reported by the server, e.g.
            ``"int(11) unsigned"`` (MySQL) or ``"timestamp without time zone"``
            (PostgreSQL).  Case and length/precision arguments are ignored.
            ``None`` is classified as CHAR.
        server: ``"mysql"`` or ``"postgresql"``.

    Returns:
        ``DataType.VALUE.value``, ``DataType.TIME.value`` or
        ``DataType.CHAR.value`` (plain ints); unrecognized types are CHAR.

    Raises:
        ValueError: if ``server`` is neither ``"mysql"`` nor ``"postgresql"``.
    """
    if server == "mysql":
        value_types, time_types = _MYSQL_VALUE_TYPES, _MYSQL_TIME_TYPES
    elif server == "postgresql":
        value_types, time_types = _PGSQL_VALUE_TYPES, _PGSQL_TIME_TYPES
    else:
        raise ValueError("unsupported database server: {!r}".format(server))

    if database_type is None:
        return DataType.CHAR.value

    # Drop "(11)", "(10,2)", "(3)" ... and collapse whitespace so that
    # "timestamp(3) without time zone" matches "timestamp without time zone".
    normalized = " ".join(
        re.sub(r"\(.*?\)", " ", str(database_type).lower()).split()
    )
    if server == "mysql":
        # MySQL appends attributes such as "unsigned"/"zerofill" after the
        # type keyword; only the leading keyword names the type.
        normalized = normalized.split(" ", 1)[0]

    if normalized in value_types:
        return DataType.VALUE.value
    if normalized in time_types:
        return DataType.TIME.value
    return DataType.CHAR.value
class DBArgs:
    """Connection settings for a MySQL or PostgreSQL database.

    Args:
        dbtype: ``"mysql"`` selects the MySQL JDBC driver/URL prefix; any other
            value gets the PostgreSQL ones.
        config: mapping providing ``host``, ``port``, ``user``, ``password``
            and ``dbname`` (a missing key raises KeyError).
        dbname: overrides ``config["dbname"]`` when truthy.
    """

    def __init__(self, dbtype, config, dbname=None):
        self.dbtype = dbtype
        # Both engines read exactly the same keys from ``config``.
        self.host = config["host"]
        self.port = config["port"]
        self.user = config["user"]
        self.password = config["password"]
        self.dbname = dbname or config["dbname"]
        # Only the JDBC driver class and URL scheme differ per engine.
        self.driver, self.jdbc = (
            ("com.mysql.jdbc.Driver", "jdbc:mysql://")
            if dbtype == "mysql"
            else ("org.postgresql.Driver", "jdbc:postgresql://")
        )
class Database:
    """Wrapper around a MySQL (PyMySQL) or PostgreSQL (psycopg2) connection.

    ``args`` must provide ``dbtype``, ``host``, ``port``, ``user``,
    ``password`` and ``dbname`` (e.g. a DBArgs instance).  Any ``dbtype``
    other than ``"mysql"`` is treated as PostgreSQL when connecting.
    """

    def __init__(self, args, timeout=-1):
        """Store the settings and open an initial connection.

        Args:
            args: connection settings (see class docstring).
            timeout: timeout in seconds applied to every connection this
                object opens, including the reconnects made by execute_sql;
                values <= 0 disable it.
        """
        self.args = args
        self.timeout = timeout
        self.conn = self.resetConn(timeout)

    def resetConn(self, timeout=-1):
        """Open and return a new connection (does not assign ``self.conn``).

        Args:
            timeout: seconds.  When > 0 it becomes PyMySQL's connect/read/write
                timeout, or PostgreSQL's ``statement_timeout``.  Values <= 0
                mean "no explicit timeout".
        """
        if self.args.dbtype == "mysql":
            timeouts = {}
            if timeout > 0:
                # PyMySQL rejects non-positive timeouts with ValueError, so
                # they are only passed when a real timeout was requested.
                timeouts = {
                    "connect_timeout": timeout,
                    "read_timeout": timeout,
                    "write_timeout": timeout,
                }
            return pymysql.connect(
                host=self.args.host,
                user=self.args.user,
                password=self.args.password,  # "passwd" is a deprecated alias
                database=self.args.dbname,
                port=int(self.args.port),
                charset="utf8",
                **timeouts,
            )
        extra = {}
        if timeout > 0:
            extra["options"] = "-c statement_timeout={}s".format(timeout)
        return psycopg2.connect(
            database=self.args.dbname,
            user=self.args.user,
            password=self.args.password,
            host=self.args.host,
            port=self.args.port,
            **extra,
        )

    def _reconnect(self):
        """Close the current connection (best effort) and open a fresh one."""
        old = getattr(self, "conn", None)
        if old is not None:
            try:
                old.close()
            except Exception as err:  # e.g. already closed / broken socket
                logging.debug("closing previous connection failed: %s", err)
        self.conn = self.resetConn(getattr(self, "timeout", -1))

    def execute_sql(self, sql):
        """Run ``sql`` on a fresh connection, making up to three attempts.

        Returns:
            ``(1, rows)`` with the ``fetchall()`` result on success, or
            ``(0, "")`` if every attempt raised while executing.  Errors while
            connecting, or raised by ``fetchall()``, propagate to the caller.
        """
        attempts = 3  # retry times
        for _ in range(attempts):
            # A new connection per attempt closes the previous one (it used
            # to leak) and keeps a retry out of whatever transaction state the
            # failed statement left behind.
            self._reconnect()
            cur = self.conn.cursor()
            try:
                cur.execute(sql)
            except Exception as err:
                logging.debug(
                    "database %s, return flag %s, execute sql %s\n",
                    self.args.dbname,
                    0,
                    sql,
                )
                logging.debug("execute_sql attempt failed: %s", err)
                continue
            else:
                rows = cur.fetchall()
                logging.debug(
                    "database %s, return flag %s, execute sql %s\n",
                    self.args.dbname,
                    1,
                    sql,
                )
                return 1, rows
            finally:
                cur.close()
        print("SQL Execution Fatal!!")
        return 0, ""

    def pgsql_results(self, sql):
        """Return the rows produced by ``sql``, or ``"<fail>"`` on any error."""
        try:
            success, res = self.execute_sql(sql)
        except Exception as error:
            logging.error("pgsql_results Exception: %s", error)
            return "<fail>"
        return res if success == 1 else "<fail>"

    def pgsql_cost_estimation(self, sql):
        """Return the planner's ``Total Cost`` for ``sql``, or 0 on failure."""
        try:
            success, res = self.execute_sql("explain (FORMAT JSON) " + sql)
            if success == 1:
                # res[0][0] is the parsed JSON plan list: [{"Plan": {...}}].
                return res[0][0][0]["Plan"]["Total Cost"]
            logging.error("pgsql_cost_estimation Fails!")
            return 0
        except Exception as error:
            logging.error("pgsql_cost_estimation Exception: %s", error)
            return 0

    def pgsql_actual_time(self, sql):
        """Run ``EXPLAIN ANALYZE`` on ``sql`` and return ``Actual Total Time``.

        Returns -1 on failure.  Note that ANALYZE really executes ``sql``.
        """
        try:
            success, res = self.execute_sql("explain (FORMAT JSON, analyze) " + sql)
            if success == 1:
                return res[0][0][0]["Plan"]["Actual Total Time"]
            return -1
        except Exception as error:
            logging.error("pgsql_actual_time Exception: %s", error)
            return -1

    def mysql_cost_estimation(self, sql):
        """Return the summed ``query_cost`` from MySQL's JSON EXPLAIN, or -1."""
        try:
            success, res = self.execute_sql("explain format=json " + sql)
            if success == 1:
                total_cost = self.get_mysql_total_cost(0, json.loads(res[0][0]))
                return float(total_cost)
            return -1
        except Exception as error:
            logging.error("mysql_cost_estimation Exception: %s", error)
            return -1

    def get_mysql_total_cost(self, total_cost, res):
        """Add every ``query_cost`` found in the EXPLAIN JSON tree to ``total_cost``.

        A dict holding ``query_cost`` contributes that value and its other keys
        are not descended into; other dicts and lists are searched
        recursively; scalars contribute nothing.
        """
        if isinstance(res, dict):
            if "query_cost" in res:
                total_cost += float(res["query_cost"])
            else:
                for value in res.values():
                    total_cost += self.get_mysql_total_cost(0, value)
        elif isinstance(res, list):
            for item in res:
                total_cost += self.get_mysql_total_cost(0, item)
        return total_cost

    def get_tables(self):
        """Dispatch to the engine-specific table listing.

        NOTE(review): ``mysql_get_tables`` / ``pgsql_get_tables`` are not
        defined in this class; presumably a subclass supplies them, otherwise
        this raises AttributeError — confirm.
        """
        if self.args.dbtype == "mysql":
            return self.mysql_get_tables()
        return self.pgsql_get_tables()

    def cost_estimation(self, sql):
        """Return the optimizer's estimated cost of ``sql`` for this engine."""
        if self.args.dbtype == "mysql":
            return self.mysql_cost_estimation(sql)
        return self.pgsql_cost_estimation(sql)

    def compute_table_schema(self):
        """Map every table in the ``public`` schema to its column names.

        Only implemented for ``args.dbtype == "postgresql"``.

        Returns:
            ``{table_name: [column_name, ...]}`` with columns in declaration
            order; ``0`` if the table list could not be read; ``None`` for any
            other database type.
        """
        if self.args.dbtype != "postgresql":
            return None
        success, tables = self.execute_sql(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
        )
        if success != 1:
            logging.error("compute_table_schema Fails!")
            return 0
        schema = {}
        for row in tables:
            table_name = row[0]
            # Double any single quote so the name stays a well-formed literal.
            quoted = table_name.replace("'", "''")
            sql = (
                "SELECT column_name, data_type FROM information_schema.columns "
                "WHERE table_schema = 'public' AND table_name = '"
                + quoted
                + "' ORDER BY ordinal_position;"
            )
            # A failed lookup returns (0, ""), which yields an empty list.
            _, columns = self.execute_sql(sql)
            schema[table_name] = [str(col[0]) for col in columns]
        return schema

    def simulate_index(self, index):
        """Create a hypopg hypothetical index from a CREATE INDEX statement.

        Returns execute_sql's ``(flag, rows)`` tuple.

        NOTE(review): hypopg keeps hypothetical indexes per session, while
        execute_sql opens a new connection for every call, so the index may
        not be visible to later calls — confirm the intended workflow.
        """
        # Escape for an E'...' literal: backslashes first, then single quotes,
        # so the statement text cannot terminate the string early.
        literal = "{}".format(index).replace("\\", "\\\\").replace("'", "''")
        statement = "SELECT * FROM hypopg_create_index(E'{}');".format(literal)
        return self.execute_sql(statement)

    def drop_simulated_index(self, oid):
        """Drop the hypopg hypothetical index identified by ``oid``.

        Raises:
            AssertionError: if the statement failed or hypopg did not report
                the index as dropped (raised explicitly, so ``-O`` keeps it).
        """
        statement = f"select * from hypopg_drop_index({int(oid)})"
        success, rows = self.execute_sql(statement)
        # execute_sql returns (flag, rows); hypopg_drop_index returns one
        # boolean.  The old check tested the int flag ``is True``, which
        # could never succeed.
        dropped = success == 1 and len(rows) > 0 and rows[0][0] is True
        if not dropped:
            raise AssertionError(f"Could not drop simulated index with oid = {oid}.")