import json
import logging
import os
from enum import IntEnum

import psycopg2
import pymysql


class DataType(IntEnum):
    """Coarse category of a column's SQL type."""

    VALUE = 0  # numeric types
    TIME = 1  # date / time types
    CHAR = 2  # everything else


# Aggregate functions allowed for each DataType category. Keys are the plain
# int values returned by transfer_field_type(); because IntEnum members hash
# and compare equal to their int value, DataType members also work as keys.
AGGREGATE_CONSTRAINTS = {
    DataType.VALUE.value: ["count", "max", "min", "avg", "sum"],
    DataType.CHAR.value: ["count", "max", "min"],
    DataType.TIME.value: ["count", "max", "min"],
}


def transfer_field_type(database_type, server):
    """Classify a column type name into a DataType category.

    :param database_type: type name reported by the server. For MySQL the
        name is lower-cased and any "(...)" suffix (e.g. "int(11)") is
        stripped before matching; PostgreSQL names are matched as given.
    :param server: "mysql" or "postgresql".
    :return: DataType.VALUE.value for numeric types, DataType.TIME.value for
        date/time types, otherwise DataType.CHAR.value.
    :raises ValueError: if ``server`` is not one of the supported servers.
    """
    if server == "mysql":
        value_types = (
            "int",
            "tinyint",
            "smallint",
            "mediumint",
            "bigint",
            "float",
            "double",
            "decimal",
        )
        time_types = ("date", "time", "year", "datetime", "timestamp")
        database_type = database_type.lower().split("(")[0]
    elif server == "postgresql":
        value_types = ("integer", "numeric")
        time_types = ("date",)
    else:
        # Fail loudly with a clear message instead of an IndexError.
        raise ValueError("unsupported server: {!r}".format(server))

    if database_type in value_types:
        return DataType.VALUE.value
    if database_type in time_types:
        return DataType.TIME.value
    return DataType.CHAR.value


class DBArgs(object):
    """Connection settings for a MySQL or PostgreSQL database.

    Any ``dbtype`` other than "mysql" is configured as PostgreSQL.

    :param dbtype: "mysql" or (by default) PostgreSQL.
    :param config: mapping with "host", "port", "user", "password" and, when
        ``dbname`` is not given, "dbname".
    :param dbname: optional database name overriding config["dbname"].
    """

    def __init__(self, dbtype, config, dbname=None):
        self.dbtype = dbtype
        self.host = config["host"]
        self.port = config["port"]
        self.user = config["user"]
        self.password = config["password"]
        # config["dbname"] is only read when no explicit dbname is passed.
        self.dbname = dbname if dbname else config["dbname"]
        if self.dbtype == "mysql":
            self.driver = "com.mysql.jdbc.Driver"
            self.jdbc = "jdbc:mysql://"
        else:
            self.driver = "org.postgresql.Driver"
            self.jdbc = "jdbc:postgresql://"


class Database:
    """Wrapper around a pymysql (MySQL) or psycopg2 (PostgreSQL) connection.

    :param args: a DBArgs-like object (dbtype, host, port, user, password,
        dbname).
    :param timeout: per-statement timeout in seconds; <= 0 disables it.
    """

    def __init__(self, args, timeout=-1):
        self.args = args
        # Kept so that the reconnect in execute_sql re-applies the timeout.
        self.timeout = timeout
        self.conn = self.resetConn(timeout)

    def resetConn(self, timeout=-1):
        """Open and return a new connection for ``self.args``.

        :param timeout: seconds; only applied when > 0. pymysql rejects
            non-positive timeout values, so they are not passed at all.
        :return: a new pymysql or psycopg2 connection.
        """
        if self.args.dbtype == "mysql":
            timeout_kwargs = {}
            if timeout > 0:
                timeout_kwargs = {
                    "connect_timeout": timeout,
                    "read_timeout": timeout,
                    "write_timeout": timeout,
                }
            return pymysql.connect(
                host=self.args.host,
                user=self.args.user,
                password=self.args.password,
                database=self.args.dbname,
                port=int(self.args.port),
                charset="utf8",
                **timeout_kwargs,
            )

        pg_kwargs = {}
        if timeout > 0:
            # Server-side statement timeout for the whole session.
            pg_kwargs["options"] = "-c statement_timeout={}s".format(timeout)
        return psycopg2.connect(
            database=self.args.dbname,
            user=self.args.user,
            password=self.args.password,
            host=self.args.host,
            port=self.args.port,
            **pg_kwargs,
        )

    def _close_conn_quietly(self):
        """Close the current connection, if any, logging (not raising) errors."""
        conn = getattr(self, "conn", None)
        if conn is None:
            return
        try:
            conn.close()
        except Exception as error:
            logging.debug("closing previous connection failed: %s", error)

    def _rollback_quietly(self):
        """Roll back the current transaction, logging (not raising) errors."""
        try:
            self.conn.rollback()
        except Exception as error:
            logging.debug("rollback after failed statement failed: %s", error)

    def execute_sql(self, sql):
        """Execute ``sql`` on a freshly opened connection, retrying on error.

        The current connection is closed and replaced before the statement
        runs. The statement is attempted up to three times; after a failed
        attempt the transaction is rolled back, because on PostgreSQL a
        failed statement aborts the transaction and would make every retry
        fail as well.

        NOTE(review): since every call reconnects, session-local state such
        as HypoPG hypothetical indexes created by ``simulate_index`` does not
        survive into the next call — confirm this is intended.

        :param sql: SQL text to execute.
        :return: ``(1, rows)`` where ``rows`` comes from ``fetchall()`` on
            success, or ``(0, "")`` if every attempt failed.
        """
        max_attempts = 3  # retry times
        self._close_conn_quietly()
        self.conn = self.resetConn(getattr(self, "timeout", -1))
        cur = self.conn.cursor()
        try:
            for _ in range(max_attempts):
                try:
                    cur.execute(sql)
                except Exception as error:
                    logging.debug(
                        "database %s, return flag %s, execute sql %s\n",
                        self.args.dbname,
                        0,
                        sql,
                    )
                    logging.debug("SQL attempt failed: %s", error)
                    self._rollback_quietly()
                    continue
                rows = cur.fetchall()
                logging.debug(
                    "database %s, return flag %s, execute sql %s\n",
                    self.args.dbname,
                    1,
                    sql,
                )
                return 1, rows
        finally:
            cur.close()

        print("SQL Execution Fatal!!")
        return 0, ""

    def pgsql_results(self, sql):
        """Return the rows produced by ``sql``, or "" if it failed."""
        try:
            success, res = self.execute_sql(sql)
        except Exception as error:
            logging.error("pgsql_results Exception: %s", error)
            return ""
        return res if success == 1 else ""

    def pgsql_cost_estimation(self, sql):
        """Return the PostgreSQL optimizer's "Total Cost" for ``sql``.

        Runs ``EXPLAIN (FORMAT JSON)`` (the query is not executed).

        :return: the plan's total cost, or 0 on failure (errors are logged).
        """
        try:
            success, res = self.execute_sql("explain (FORMAT JSON) " + sql)
            if success != 1:
                logging.error("pgsql_cost_estimation Fails!")
                return 0
            # First row, first column: the decoded JSON plan list.
            return res[0][0][0]["Plan"]["Total Cost"]
        except Exception as error:
            logging.error("pgsql_cost_estimation Exception: %s", error)
            return 0

    def pgsql_actual_time(self, sql):
        """Return the measured "Actual Total Time" of ``sql`` on PostgreSQL.

        Uses ``EXPLAIN (FORMAT JSON, analyze)``, which actually executes the
        statement.

        :return: the actual total time, or -1 on failure.
        """
        try:
            success, res = self.execute_sql("explain (FORMAT JSON, analyze) " + sql)
            if success != 1:
                return -1
            return res[0][0][0]["Plan"]["Actual Total Time"]
        except Exception as error:
            logging.error("pgsql_actual_time Exception: %s", error)
            return -1

    def mysql_cost_estimation(self, sql):
        """Return the summed MySQL ``query_cost`` for ``sql``, or -1 on failure."""
        try:
            success, res = self.execute_sql("explain format=json " + sql)
            if success != 1:
                return -1
            total_cost = self.get_mysql_total_cost(0, json.loads(res[0][0]))
            return float(total_cost)
        except Exception as error:
            logging.error("mysql_cost_estimation Exception: %s", error)
            return -1

    def get_mysql_total_cost(self, total_cost, res):
        """Add every ``query_cost`` found in a MySQL EXPLAIN JSON tree.

        A dict containing "query_cost" contributes that value and is not
        searched further; other dicts and lists are searched recursively;
        any other value contributes nothing.

        :param total_cost: running total to add to.
        :param res: decoded JSON node (dict, list or scalar).
        :return: ``total_cost`` plus the costs found under ``res``.
        """
        if isinstance(res, dict):
            if "query_cost" in res:
                total_cost += float(res["query_cost"])
            else:
                for key in res:
                    total_cost += self.get_mysql_total_cost(0, res[key])
        elif isinstance(res, list):
            for item in res:
                total_cost += self.get_mysql_total_cost(0, item)
        return total_cost

    def get_tables(self):
        """Dispatch to the MySQL or PostgreSQL table listing.

        NOTE(review): mysql_get_tables / pgsql_get_tables are not defined in
        this file as seen here — confirm they exist elsewhere.
        """
        if self.args.dbtype == "mysql":
            return self.mysql_get_tables()
        return self.pgsql_get_tables()

    def cost_estimation(self, sql):
        """Return the optimizer's estimated cost of ``sql`` for this dbtype."""
        if self.args.dbtype == "mysql":
            return self.mysql_cost_estimation(sql)
        return self.pgsql_cost_estimation(sql)

    def compute_table_schema(self):
        """Return ``{table_name: [field_name, ...]}`` for the public schema.

        Only PostgreSQL is handled; for any other dbtype nothing is queried
        and None is returned. If listing the tables fails, an error is logged
        and 0 is returned. A table whose column query fails maps to [].
        """
        if self.args.dbtype != "postgresql":
            return None

        sql = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
        success, tables = self.execute_sql(sql)
        if success != 1:
            logging.error("compute_table_schema Fails!")
            return 0

        schema = {}
        for table_info in tables:
            table_name = table_info[0]
            # Double embedded quotes so the catalog name is a valid literal.
            literal_name = str(table_name).replace("'", "''")
            sql = (
                "SELECT column_name, data_type FROM information_schema.columns "
                "WHERE table_schema = 'public' AND table_name = '{}';".format(literal_name)
            )
            success, columns = self.execute_sql(sql)
            if success != 1:
                logging.warning("compute_table_schema: no columns read for %s", table_name)
                columns = []
            schema[table_name] = [str(col[0]) for col in columns]
        return schema

    def simulate_index(self, index):
        """Create a HypoPG hypothetical index from its definition ``index``.

        NOTE(review): ``index`` is interpolated into an E'...' literal
        unescaped; a definition containing quotes or backslashes will break
        the statement.

        :return: the ``(flag, rows)`` tuple from execute_sql.
        """
        statement = "SELECT * FROM hypopg_create_index(E'{}');".format(index)
        return self.execute_sql(statement)

    def drop_simulated_index(self, oid):
        """Drop the HypoPG hypothetical index with the given ``oid``.

        :raises AssertionError: if the statement failed or hypopg_drop_index
            did not return true.
        """
        statement = f"select * from hypopg_drop_index({oid})"
        success, rows = self.execute_sql(statement)
        # execute_sql returns (flag, rows); the function's boolean result is
        # in the first column of the first row, not in the flag.
        dropped = success == 1 and bool(rows) and rows[0][0] is True
        if not dropped:
            raise AssertionError(f"Could not drop simulated index with oid = {oid}.")