|
from tableQA_single_table import * |
|
|
|
import json |
|
import os |
|
import sys |
|
|
|
def _dedupe_header(header):
    """Return a copy of *header* in which every name is unique.

    The first occurrence of a name is kept unchanged. Each later repeat becomes
    ``"{name}_{k}"``, where ``k`` is a counter shared across the whole header.
    Candidates that clash with any original column name, or with a name
    already generated, are skipped. This keeps the result duplicate-free even
    for headers such as ``["a", "a", "a_1"]``, which become
    ``["a", "a_2", "a_1"]``.
    """
    if len(header) == len(set(header)):
        return list(header)
    taken = set(header)  # originals + generated names that must not be reused
    seen = set()  # originals encountered so far while scanning left to right
    deduped = []
    counter = 0
    for name in header:
        if name not in seen:
            seen.add(name)
            deduped.append(name)
            continue
        counter += 1
        candidate = "{}_{}".format(name, counter)
        while candidate in taken:
            counter += 1
            candidate = "{}_{}".format(name, counter)
        taken.add(candidate)
        deduped.append(candidate)
    return deduped


def _format_sql_value(val):
    """Render a condition value as an SQLite literal.

    The value is converted with ``str``. If ``literal_eval`` accepts the text
    (e.g. numbers), it is emitted verbatim. Otherwise it is treated as a
    string: wrapped in single quotes, with embedded single quotes doubled as
    SQL requires (so ``O'Brien`` becomes ``'O''Brien'``).
    """
    text = str(val)
    try:
        literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # These are the documented failure modes of ast.literal_eval.
        return "'{}'".format(text.replace("'", "''"))
    return text


def run_sql_query(s, df):
    """Execute a predicted single-table query against *df* in in-memory SQLite.

    Args:
        s: Object exposing the prediction as attributes. In
            ``single_table_pred`` this is one row of the DataFrame returned by
            ``full_before_cat_decomp``. The attributes are:
            ``question_column``, the column name to select (``None`` means
            there is nothing to query); ``total_conds_filtered``, an iterable
            of ``(column, operator, value)`` triples, where the operator
            ``"=="`` is rewritten to ``"="``; ``agg_pred``, an aggregate
            function name, where empty or ``None`` means no aggregate; and
            ``conn_pred``, the connector between conditions, where empty or
            ``None`` means ``and``.
        df: The table to query.

    Returns:
        dict with three keys. ``sql_query`` is the generated SQL; columns are
        aliased ``col_<position>`` and the table is ``Mem_Table``.
        ``cnt_num`` is the number of rows matching the WHERE clause, as a
        Python int. ``conclusion`` is the list of values returned by
        ``sql_query`` (empty when no row matches).

    Raises:
        AssertionError: if *df* is not a DataFrame, or a referenced column is
            not in the header.
    """
    assert isinstance(df, pd.DataFrame)

    question_column = s.question_column
    if question_column is None:
        return {
            "sql_query": "",
            "cnt_num": 0,
            "conclusion": [],
        }

    total_conds_filtered = s.total_conds_filtered
    agg_pred = (s.agg_pred or "").strip()
    conn_pred = (s.conn_pred or "").strip()

    # Duplicate names are disambiguated so header.index() below is well-defined;
    # the SQL itself only uses positional col_<i> aliases.
    header = _dedupe_header(df.columns.tolist())
    assert question_column in header
    assert all(cond[0] in header for cond in total_conds_filtered)
    assert len(header) == len(set(header))

    column_alias = ["col_{}".format(i) for i in range(len(header))]
    sql_question_column = column_alias[header.index(question_column)]
    # NOTE(review): operators and the connector come from the model and are
    # inserted verbatim; only values are quoted/escaped.
    sql_conds = [
        (
            column_alias[header.index(cond[0])],
            "=" if cond[1] == "==" else cond[1],
            _format_sql_value(cond[2]),
        )
        for cond in total_conds_filtered
    ]

    if agg_pred:
        select_expr = "{}({})".format(agg_pred, sql_question_column)
    else:
        select_expr = "({})".format(sql_question_column)

    if sql_conds:
        connector = " {} ".format(conn_pred or "and")
        sql_where_string = "WHERE " + connector.join(
            "{} {} {}".format(*cond) for cond in sql_conds
        )
    else:
        sql_where_string = ""

    sql_format = "SELECT {} FROM {} {}"
    sql_query = sql_format.format(select_expr, "Mem_Table", sql_where_string)
    cnt_sql_query = sql_format.format("COUNT(*)", "Mem_Table", sql_where_string).strip()

    df_saved = df.copy()
    df_saved.columns = column_alias

    conn = sqlite3.connect(":memory:")
    try:
        df_saved.to_sql("Mem_Table", conn, if_exists="replace", index=False)
        # int(): the raw value is a NumPy integer, which json cannot serialize.
        cnt_num = int(pd.read_sql(cnt_sql_query, conn).values.reshape((-1,))[0])
        if cnt_num == 0:
            return {
                "sql_query": sql_query,
                "cnt_num": 0,
                "conclusion": [],
            }
        query_conclusion_list = pd.read_sql(sql_query, conn).values.reshape((-1,)).tolist()
    finally:
        conn.close()

    return {
        "sql_query": sql_query,
        "cnt_num": cnt_num,
        "conclusion": query_conclusion_list,
    }
|
|
|
|
|
def single_table_pred(question, pd_df):
    """Answer a natural-language *question* over the single table *pd_df*.

    The question is wrapped in a one-row DataFrame and passed to
    ``full_before_cat_decomp`` (presumably supplied by the
    ``tableQA_single_table`` star import). Its single output row carries the
    predicted question column, conditions, aggregate and connector, and is
    executed by ``run_sql_query``.

    Returns:
        The dict produced by ``run_sql_query``, with keys ``sql_query``,
        ``cnt_num`` and ``conclusion``.

    Raises:
        AssertionError: if *question* is not a str, *pd_df* is not a
            DataFrame, or the model does not return exactly one row.
    """
    assert isinstance(question, str)
    assert isinstance(pd_df, pd.DataFrame)

    qs_df = pd.DataFrame([[question]], columns=["question"])
    tableqa_df = full_before_cat_decomp(pd_df, qs_df, only_req_columns=False)

    # One question in, so exactly one prediction row is expected back.
    assert tableqa_df.shape[0] == 1

    return run_sql_query(tableqa_df.iloc[0], pd_df)
|
|
|
|
|
if __name__ == "__main__":
    # Demo: answer one question over data/df1.csv and compare the generated SQL
    # with the stored reference answer.
    data_path = os.path.join(main_path, "data/df1.csv")
    szse_summary_df = pd.read_csv(data_path)

    data = {
        # "What is the average market cap where EPS > 0 and the weekly change > 5?"
        "tqa_question": "EPS大于0且周涨跌大于5的平均市值是多少?",
        "tqa_header": szse_summary_df.columns.tolist(),
        "tqa_rows": szse_summary_df.values.tolist(),
        "tqa_data_path": data_path,
        "tqa_answer": {
            "sql_query": "SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5",
            "cnt_num": 2,
            "conclusion": [57.645],
        },
    }

    pd_df = pd.DataFrame(data["tqa_rows"], columns=data["tqa_header"])
    question = data["tqa_question"]
    result = single_table_pred(question, pd_df)

    # cnt_num may be a NumPy integer, which json.dumps cannot serialize.
    printable = dict(result, cnt_num=int(result["cnt_num"]))
    print(json.dumps(printable, ensure_ascii=False, indent=2))
    expected = data["tqa_answer"]
    print("sql_query matches reference:", result["sql_query"] == expected["sql_query"])
|
|