# Binder / utils / sql / extraction_from_sql.py
# Timothyxxx - Init - f6f97d8
import argparse
import json
from utils.sql.process_sql import (
tokenize, CLAUSE_KEYWORDS, WHERE_OPS, COND_OPS, UNIT_OPS, AGG_OPS,
JOIN_KEYWORDS, ORDER_OPS, skip_semicolon, SQL_OPS)
# WHERE operators that the template extractors below copy into the template
# verbatim; every other WHERE operator is emitted as "[WHERE_OP]".
KEPT_WHERE_OP = ('not', 'in', 'exists')
def parse_table_unit(toks, start_idx, tables_with_alias):
    """
    Consume one table reference, either ``name`` or ``name as alias``.

    When an alias is present it is stored in ``tables_with_alias`` as
    ``alias -> name``.

    :returns next idx, table name
    """
    table_name = toks[start_idx]
    has_alias = start_idx + 1 < len(toks) and toks[start_idx + 1] == "as"
    if not has_alias:
        return start_idx + 1, table_name
    # "name as alias" spans three tokens.
    tables_with_alias[toks[start_idx + 2]] = table_name
    return start_idx + 3, table_name
def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    Parse one column token, record it in ``schema`` and normalize it in place.

    ``*`` is skipped untouched. A qualified token ``alias.col`` is resolved
    through ``tables_with_alias``. An unqualified token is attributed to the
    only default table, or, when there are several, to the default table whose
    name is textually closest to the token.

    Side effects: appends the column to ``schema[table]`` (creating the list
    if needed) and rewrites ``toks[start_idx]`` to ``"<table>.<column>"``.

    :param toks: token list, modified in place
    :param start_idx: index of the column token
    :param tables_with_alias: mapping alias -> table name
    :param schema: mapping table name -> list of column names, modified in place
    :param default_tables: table names to attribute unqualified columns to
    :returns next idx
    :raises KeyError: the qualifier of ``alias.col`` is not a known alias
    :raises ValueError: a qualified token contains more than one "."
    :raises AssertionError: unqualified column and no default tables
    """
    # Local import: only needed to disambiguate unqualified columns.
    import difflib

    tok = toks[start_idx]
    if tok == "*":
        return start_idx + 1

    if '.' in tok:  # composite token: "<alias>.<column>"
        alias, col = tok.split('.')
        table = tables_with_alias[alias]
    else:
        assert default_tables is not None and len(default_tables) > 0, \
            "Default tables should not be None or empty"
        col = tok
        if len(default_tables) == 1:
            table = default_tables[0]
        else:
            # Pick the closest table name, compared case-insensitively. The
            # original table name (not a lower-cased copy) is returned so it
            # matches the key registered for the FROM clause. Ties go to the
            # first table listed.
            lowered = tok.lower()
            table = max(
                default_tables,
                key=lambda name: difflib.SequenceMatcher(None, lowered, name.lower()).ratio(),
            )

    schema.setdefault(table, []).append(col)
    # Normalize the token so later consumers see the fully qualified column.
    toks[start_idx] = "{}.{}".format(table, col)
    return start_idx + 1
def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None, end_idx=None):
    """
    Parse a column unit: an optional "(", then either ``agg ( [distinct] col )``
    or ``[distinct] col``, then the matching ")" for the non-aggregate form.

    The column itself is handled by ``parse_col``, which records it in
    ``schema`` and normalizes the token in place.

    :param end_idx: optional exclusive upper bound (absolute index into
        ``toks``) used by the bounds assertions of the aggregate form
    :returns next idx
    :raises AssertionError: an expected parenthesis is missing or lies at or
        beyond the bound
    """
    idx = start_idx
    # The bound must be an absolute index because ``idx`` is absolute; using
    # the length of the slice toks[start_idx:end_idx] would compare an
    # absolute position against a relative length.
    bound = len(toks) if end_idx is None else min(end_idx, len(toks))

    is_block = False
    if toks[idx] == '(':
        is_block = True
        idx += 1

    if toks[idx] in AGG_OPS:
        idx += 1
        assert idx < bound and toks[idx] == '('
        idx += 1
        if toks[idx] == "distinct":
            idx += 1
        idx = parse_col(toks, idx, tables_with_alias, schema, default_tables)
        assert idx < bound and toks[idx] == ')'
        idx += 1
        # The aggregate form returns here; an enclosing "(" opened above is
        # left for the caller to close.
        return idx

    if toks[idx] == "distinct":
        idx += 1
    idx = parse_col(toks, idx, tables_with_alias, schema, default_tables)

    if is_block:
        assert toks[idx] == ')'
        idx += 1  # skip ')'
    return idx
def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    Parse a value unit: one column unit, optionally followed by a token from
    UNIT_OPS and a second column unit, the whole optionally parenthesized.

    :returns next idx
    """
    total = len(toks)
    wrapped = toks[start_idx] == '('
    pos = start_idx + 1 if wrapped else start_idx

    pos = parse_col_unit(toks, pos, tables_with_alias, schema, default_tables)
    if pos < total and toks[pos] in UNIT_OPS:
        # Skip the operator token and parse the right-hand column unit.
        pos = parse_col_unit(toks, pos + 1, tables_with_alias, schema, default_tables)

    if wrapped:
        assert toks[pos] == ')'
        pos += 1  # skip ')'
    return pos
def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    Parse the right-hand side of a condition, optionally parenthesized.

    Handles, in this order:
      * a sub-query (token ``select``), delegated to ``parse_sql``;
      * a string literal (token containing a double quote), replaced in
        ``toks`` by ``"_str_value_"``;
      * a numeric literal (token accepted by ``float``), replaced in ``toks``
        by ``"_num_value_"``;
      * anything else, treated as a column unit that extends up to the next
        ",", ")", "and", clause keyword or join keyword.

    :returns next idx
    """
    idx = start_idx
    len_ = len(toks)

    is_block = False
    if toks[idx] == '(':
        is_block = True
        idx += 1

    if toks[idx] == 'select':
        idx = parse_sql(toks, idx, schema)
    elif "\"" in toks[idx]:  # token is a string value
        toks[idx] = "_str_value_"
        idx += 1
    else:
        try:
            float(toks[idx])
        except (TypeError, ValueError):
            # Not a number, so treat the span as a column unit.
            end_idx = idx
            while end_idx < len_ and toks[end_idx] not in (',', ')', 'and') \
                    and toks[end_idx] not in CLAUSE_KEYWORDS \
                    and toks[end_idx] not in JOIN_KEYWORDS:
                end_idx += 1
            # Parsing starts at start_idx (not idx) so that parse_col_unit
            # sees the optional "(" itself; its return value is ignored and
            # the scan position is forced to end_idx.
            parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables, end_idx=end_idx)
            idx = end_idx
        else:
            toks[idx] = "_num_value_"
            idx += 1

    if is_block:
        assert toks[idx] == ')'
        idx += 1
    return idx
def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    Parse a chain of conditions ``val_unit [not] op value [and value]``
    joined by tokens from COND_OPS.

    Parsing stops at a clause keyword, a join keyword, ")" or ";", or at the
    end of the tokens.

    :returns next idx
    :raises AssertionError: the token after the value unit (and optional
        ``not``) is missing or is not in WHERE_OPS, or ``between`` is not
        followed by ``and``
    """
    idx = start_idx
    len_ = len(toks)

    while idx < len_:
        idx = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)

        if idx < len_ and toks[idx] == 'not':
            idx += 1

        # The message argument is bounds-checked so a truncated condition
        # reports this error instead of an IndexError raised while formatting.
        assert idx < len_ and toks[idx] in WHERE_OPS, \
            "Error condition: idx: {}, tok: {}".format(idx, toks[idx] if idx < len_ else None)
        op = toks[idx]
        idx += 1

        if op == 'between':  # between..and... special case: dual values
            idx = parse_value(toks, idx, tables_with_alias, schema, default_tables)
            assert toks[idx] == 'and'
            idx += 1
            idx = parse_value(toks, idx, tables_with_alias, schema, default_tables)
        else:  # normal case: single value
            idx = parse_value(toks, idx, tables_with_alias, schema, default_tables)

        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")
                           or toks[idx] in JOIN_KEYWORDS):
            break

        if idx < len_ and toks[idx] in COND_OPS:
            idx += 1  # skip and/or
    return idx
def parse_from(toks, start_idx, schema):
    """
    Parse the FROM clause that starts at the first ``from`` token at or after
    ``start_idx``.

    Each plain table is registered as a key of ``schema`` (with an empty
    column list if new); sub-queries are delegated to ``parse_sql`` and ON
    conditions to ``parse_condition``.

    :returns next idx, list of table names, mapping alias -> table name
    :raises AssertionError: no ``from`` token at or after ``start_idx``
    """
    assert 'from' in toks[start_idx:], "'from' not found"

    alias_map = {}
    tables = []
    total = len(toks)
    pos = toks.index('from', start_idx) + 1

    while pos < total:
        in_parens = toks[pos] == '('
        if in_parens:
            pos += 1

        if toks[pos] == 'select':
            pos = parse_sql(toks, pos, schema)
        else:
            if toks[pos] == 'join':
                pos += 1  # skip join
            pos, table_name = parse_table_unit(toks, pos, alias_map)
            tables.append(table_name)
            # Make sure the table shows up in the schema even without columns.
            schema.setdefault(table_name, [])

        if pos < total and toks[pos] == "on":
            # Skip "on" and parse the join condition.
            pos = parse_condition(toks, pos + 1, alias_map, schema, tables)

        if in_parens:
            assert toks[pos] == ')'
            pos += 1

        if pos < total and (toks[pos] in CLAUSE_KEYWORDS or toks[pos] in (")", ";")):
            break

    return pos, tables, alias_map
def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    Parse the SELECT list starting at ``start_idx`` up to the next clause
    keyword, feeding every selected value unit through ``parse_val_unit``.

    :returns next idx
    :raises AssertionError: ``toks[start_idx]`` is not ``select``
    """
    total = len(toks)
    assert toks[start_idx] == 'select', "'select' not found"

    pos = start_idx + 1
    if pos < total and toks[pos] == 'distinct':
        pos += 1

    while pos < total and toks[pos] not in CLAUSE_KEYWORDS:
        if toks[pos] in AGG_OPS:
            # The aggregate name is skipped here; its parenthesized argument
            # is consumed by parse_val_unit.
            pos += 1
        pos = parse_val_unit(toks, pos, tables_with_alias, schema, default_tables)
        if pos < total and toks[pos] == ',':
            pos += 1  # skip ','
    return pos
def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
    """
    Parse an optional WHERE clause.

    :returns next idx (``start_idx`` unchanged when there is no ``where``)
    """
    if start_idx < len(toks) and toks[start_idx] == 'where':
        return parse_condition(toks, start_idx + 1, tables_with_alias, schema, default_tables)
    return start_idx
def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
    """
    Parse an optional GROUP BY clause: a comma-separated list of column units.

    :returns next idx (``start_idx`` unchanged when there is no ``group``)
    :raises AssertionError: ``group`` is not followed by ``by``
    """
    total = len(toks)
    if not (start_idx < total and toks[start_idx] == 'group'):
        return start_idx

    assert toks[start_idx + 1] == 'by'
    pos = start_idx + 2

    while pos < total and toks[pos] not in CLAUSE_KEYWORDS and toks[pos] not in (")", ";"):
        pos = parse_col_unit(toks, pos, tables_with_alias, schema, default_tables)
        if not (pos < total and toks[pos] == ','):
            break
        pos += 1  # skip ','
    return pos
def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
    """
    Parse an optional HAVING clause.

    :returns next idx (``start_idx`` unchanged when there is no ``having``)
    """
    if start_idx < len(toks) and toks[start_idx] == 'having':
        return parse_condition(toks, start_idx + 1, tables_with_alias, schema, default_tables)
    return start_idx
def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
    """
    Parse an optional ORDER BY clause: a comma-separated list of value units,
    each optionally followed by a token from ORDER_OPS.

    :returns next idx (``start_idx`` unchanged when there is no ``order``)
    :raises AssertionError: ``order`` is not followed by ``by``
    """
    total = len(toks)
    if not (start_idx < total and toks[start_idx] == 'order'):
        return start_idx

    assert toks[start_idx + 1] == 'by'
    pos = start_idx + 2

    while pos < total and toks[pos] not in CLAUSE_KEYWORDS and toks[pos] not in (")", ";"):
        pos = parse_val_unit(toks, pos, tables_with_alias, schema, default_tables)
        if pos < total and toks[pos] in ORDER_OPS:
            pos += 1  # skip the ordering direction
        if not (pos < total and toks[pos] == ','):
            break
        pos += 1  # skip ','
    return pos
def parse_limit(toks, start_idx):
    """
    Parse an optional LIMIT clause.

    When ``toks[start_idx]`` is ``limit``, the following token is overwritten
    in place with ``"_limit_value_"`` and both tokens are skipped.

    :returns next idx (``start_idx`` unchanged when there is no ``limit``)
    :raises IndexError: ``limit`` is the last token
    """
    if start_idx >= len(toks) or toks[start_idx] != 'limit':
        return start_idx
    # Mask the limit value rather than assuming any concrete number.
    toks[start_idx + 1] = "_limit_value_"
    return start_idx + 2
def parse_sql(toks, start_idx, schema):
    """
    Parse one (possibly parenthesized) SELECT statement starting at
    ``start_idx``, then recurse into the statement that follows a token from
    SQL_OPS, if any.

    Tables and columns met along the way are recorded in ``schema`` and
    tokens are normalized in place by the clause parsers.

    :returns next idx
    """
    total = len(toks)
    wrapped = toks[start_idx] == '('  # whether this is a block of sql/sub-sql
    select_idx = start_idx + 1 if wrapped else start_idx

    # FROM is parsed before SELECT so tables and aliases are known when the
    # selected columns are resolved.
    pos, default_tables, tables_with_alias = parse_from(toks, start_idx, schema)
    parse_select(toks, select_idx, tables_with_alias, schema, default_tables)

    for clause_parser in (parse_where, parse_group_by, parse_having, parse_order_by):
        pos = clause_parser(toks, pos, tables_with_alias, schema, default_tables)
    pos = parse_limit(toks, pos)

    pos = skip_semicolon(toks, pos)
    if wrapped:
        assert toks[pos] == ')'
        pos += 1  # skip ')'
    pos = skip_semicolon(toks, pos)

    if pos < total and toks[pos] in SQL_OPS:
        # Skip the set operator and parse the statement on its right.
        pos = parse_sql(toks, pos + 1, schema)
    return pos
def extract_schema_from_sql(schema, sql):
    """
    Tokenize ``sql`` and parse it, filling ``schema`` (table name -> list of
    column names) in place.

    :returns the token list as left by the parser (modified in place)
    """
    tokens = tokenize(sql)
    parse_sql(tokens, 0, schema)
    return tokens
def extract_template_from_sql(sql, schema={}):
    """
    Turn ``sql`` into a list of template tokens.

    Clause keywords, aggregate names, punctuation, COND_OPS tokens and the
    operators in KEPT_WHERE_OP are kept. A FROM clause that does not start
    with "(" collapses to "[FROM_PART]". Other WHERE operators become
    "[WHERE_OP]", the token after one becomes "[VALUE]", the token after
    ``limit`` becomes "[LIMIT_VALUE]", ``asc``/``desc`` become
    "[ORDER_DIRECTION]", and every other run of tokens becomes one "[MASK]".

    :param sql: SQL string
    :param schema: unused by this function
    :returns list of template tokens; empty when tokenization fails (a message
        is printed in that case)
    """
    try:
        toks = tokenize(sql)
    except Exception:
        print("Tokenization error for {}".format(sql))
        toks = []

    template = []
    len_ = len(toks)
    idx = 0
    while idx < len_:
        tok = toks[idx]
        # None when nothing has been emitted yet; avoids indexing an empty list.
        last = template[-1] if template else None

        if tok == "from":
            template.append(tok)
            # A FROM clause that opens a sub-query is walked token by token;
            # anything else (including a trailing "from") is collapsed.
            if idx + 1 >= len_ or toks[idx + 1] != "(":
                template.append("[FROM_PART]")
                idx += 1
                while idx < len_ and (toks[idx] not in CLAUSE_KEYWORDS and toks[idx] != ")"):
                    idx += 1
                continue
        elif tok in CLAUSE_KEYWORDS:
            template.append(tok)
        elif tok in AGG_OPS:
            template.append(tok)
        elif tok in [",", "*", "(", ")", "having", "by", "distinct"]:
            template.append(tok)
        elif tok in ["asc", "desc"]:
            template.append("[ORDER_DIRECTION]")
        elif tok in WHERE_OPS:
            if tok in KEPT_WHERE_OP:
                template.append(tok)
            else:
                template.append("[WHERE_OP]")
                if tok == "between":
                    # Skip the first bound and the "and"; the second bound is
                    # then emitted as "[VALUE]".
                    idx += 2
        elif tok in COND_OPS:
            template.append(tok)
        elif last == "[WHERE_OP]":
            template.append("[VALUE]")
        elif last == "limit":
            template.append("[LIMIT_VALUE]")
        elif last != "[MASK]":  # value, schema, join on as
            template.append("[MASK]")
        idx += 1
    return template
def extract_partial_template_from_sql(sql, schema={}):
    """
    Turn ``sql`` into a list of partially masked template tokens.

    Same walk as ``extract_template_from_sql`` except that the tokens of a
    FROM clause and all otherwise unclassified tokens are copied verbatim
    instead of being replaced by "[FROM_PART]" / "[MASK]". Tokenization
    errors are not caught here.

    :param sql: SQL string
    :param schema: unused by this function
    :returns list of template tokens
    """
    toks = tokenize(sql)

    template = []
    len_ = len(toks)
    idx = 0
    while idx < len_:
        tok = toks[idx]
        # None when nothing has been emitted yet; avoids indexing an empty list.
        last = template[-1] if template else None

        if tok == "from":
            template.append(tok)
            # A FROM clause that opens a sub-query is walked token by token;
            # anything else (including a trailing "from") is copied as is.
            if idx + 1 >= len_ or toks[idx + 1] != "(":
                idx += 1
                while idx < len_ and (toks[idx] not in CLAUSE_KEYWORDS and toks[idx] != ")"):
                    template.append(toks[idx])
                    idx += 1
                continue
        elif tok in CLAUSE_KEYWORDS:
            template.append(tok)
        elif tok in AGG_OPS:
            template.append(tok)
        elif tok in [",", "*", "(", ")", "having", "by", "distinct"]:
            template.append(tok)
        elif tok in ["asc", "desc"]:
            template.append("[ORDER_DIRECTION]")
        elif tok in WHERE_OPS:
            if tok in KEPT_WHERE_OP:
                template.append(tok)
            else:
                template.append("[WHERE_OP]")
                if tok == "between":
                    # Skip the first bound and the "and"; the second bound is
                    # then emitted as "[VALUE]".
                    idx += 2
        elif tok in COND_OPS:
            template.append(tok)
        elif last == "[WHERE_OP]":
            template.append("[VALUE]")
        elif last == "limit":
            template.append("[LIMIT_VALUE]")
        else:
            template.append(tok)
        idx += 1
    return template
def is_valid_schema(schema):
    """
    Check that an extracted schema looks sane.

    A schema is rejected when a table name contains "." or equals a clause
    keyword, or when a column name contains ".", a space, or a single or
    double quote.

    :param schema: mapping table name -> list of column names
    :returns True when every table and column passes, else False
    """
    forbidden_in_column = (".", " ", '"', "'")
    for table, columns in schema.items():
        if "." in table or table in CLAUSE_KEYWORDS:
            return False
        for column in columns:
            if any(ch in column for ch in forbidden_in_column):
                return False
    return True
def clean_sql(sql):
    """
    Repair dangling upper-case JOIN keywords in ``sql``.

    Runs of "JOIN JOIN" collapse to a single "JOIN"; then "JOIN WHERE"
    becomes "WHERE" and "JOIN GROUP BY" becomes "GROUP BY". Matching is
    case-sensitive.

    :returns the cleaned SQL string
    """
    while "JOIN JOIN" in sql:
        sql = sql.replace("JOIN JOIN", "JOIN")
    for dangling, keyword in (("JOIN WHERE", "WHERE"), ("JOIN GROUP BY", "GROUP BY")):
        sql = sql.replace(dangling, keyword)
    return sql
if __name__ == "__main__":
    # Command-line entry point.
    #   --task schema_extraction:   extract {table: [columns]} from each SQL
    #   --task template_extraction: group SQL queries by their masked template
    # Input files are read as JSON lines whose SQL is under "sql" or "pred".
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str)
    parser.add_argument("--output_file", type=str)
    parser.add_argument("--mode", type=str, choices=["debug", "verbose", "silent"])
    parser.add_argument("--task", type=str, choices=["template_extraction", "schema_extraction"])
    args = parser.parse_args()

    if args.task == "schema_extraction":
        if args.mode == "debug":
            # Sample queries for manual debugging; only the last assignment
            # is actually parsed.
            sql = "SELECT count(*) FROM games"
            sql = sql + " INTERSECT " + "SELECT sacks, year FROM players"
            sql = sql + " EXCEPT " + 'SELECT T1.year, T1.sacks FROM players AS T1 JOIN tackles AS T2 ON T1.id = T2.player_id WHERE T2.manager = "A" and T2.season NOT IN (SELECT season FROM match WHERE match_name = "IVL" INTERSECT SELECT T1.year, T1.sacks FROM sack AS T1) GROUP BY T1.year, T1.sacks HAVING count(T1.coach) > 10 ORDER BY T2.score LIMIT 5'
            sql = "SELECT T1.pld FROM pld AS T1 JOIN games AS T2 ON T1.crs_code = T2.crs_code JOIN GROUP BY T1.pld WHERE T2.gf = '8' AND T2.gf = '9'"
            sql = 'select * from head where height = "6-0" or height = "6-0" order by height asc'
            schema = {}
            extract_schema_from_sql(schema, sql)
            print(schema, is_valid_schema(schema))
        elif args.mode == "verbose":
            # Output is opened first (as before) and both files are closed
            # when the block exits.
            with open(args.output_file, "w") as fout, open(args.input_file) as fin:
                for line in fin:
                    example = json.loads(line)
                    schema = {}
                    try:
                        sql = example["sql"] if "sql" in example else example["pred"]
                        sql = clean_sql(sql)
                        example["sql"] = sql
                        extract_schema_from_sql(schema, sql)
                    except Exception:
                        # Best effort: examples that cannot be parsed are skipped.
                        continue
                    for table in schema:
                        schema[table] = list(set(schema[table]))
                    if is_valid_schema(schema):
                        example["extracted_schema"] = schema
                        fout.write(json.dumps(example) + "\n")
        # NOTE(review): a second `elif args.mode == "verbose"` branch followed
        # here. It could never run because the branch above has the same
        # condition, so it was removed. It had no error handling around the
        # parser - confirm whether it was meant for another mode.
    elif args.task == "template_extraction":
        if args.mode == "debug":
            sql = "SELECT avg(T1.Votes) FROM seats AS T1 JOIN votes AS T2 ON T1.Seat_ID = T2.Seat_ID WHERE T1.seats BETWEEN 1 AND 2 and T1.Seats = 1 AND T2.Votes = 10"
            print(extract_template_from_sql(sql))
            print(extract_partial_template_from_sql(sql))
        elif args.mode == "verbose":
            with open(args.output_file + ".json", "w") as fout_json, \
                    open(args.output_file + ".txt", "w") as fout_txt, \
                    open(args.output_file + ".low_freq", "w") as low_freq_txt, \
                    open(args.output_file + ".high_freq", "w") as high_freq_txt:
                # template string -> list of SQL queries that produced it
                templates = {}
                with open(args.input_file) as fin:
                    for line in fin:
                        example = json.loads(line)
                        sql = example["sql"] if "sql" in example else example["pred"]
                        if isinstance(sql, list):
                            sql = sql[-1]
                        template = extract_template_from_sql(sql)
                        template_str = " ".join(template)
                        templates.setdefault(template_str, []).append(sql)
                print("{} has template {}".format(args.input_file, len(templates)))

                json.dump(templates, fout_json)
                for template in sorted(templates):
                    # Templates seen more than once are "high frequency".
                    if len(templates[template]) > 1:
                        high_freq_txt.write(template + "\n")
                    else:
                        low_freq_txt.write(template + "\n")
                    fout_txt.write(template + "\n")