TableQA-Chinese / run.py
svjack's picture
Upload . with huggingface_hub
b5dbcf3
raw
history blame
No virus
4.83 kB
from tableQA_single_table import *
import json
import os
import sys
def _dedup_header(header):
    """Return a copy of *header* in which every column name is unique.

    The first occurrence of a name is kept verbatim. Later repeats become
    ``"{name}_{k}"``, where ``k`` is one counter shared by all repeats, so
    ``["a", "a", "b", "b"]`` becomes ``["a", "a_1", "b", "b_2"]``. A generated
    name that would clash with a real column, or with an earlier generated
    name, bumps the counter again, so the result never contains duplicates.
    """
    taken = set(header)
    seen = set()
    deduped = []
    idx = 0
    for h in header:
        if h not in seen:
            seen.add(h)
            deduped.append(h)
            continue
        idx += 1
        candidate = "{}_{}".format(h, idx)
        # e.g. header ["a", "a", "a_1"]: "a_1" already exists, so use "a_2".
        while candidate in taken:
            idx += 1
            candidate = "{}_{}".format(h, idx)
        taken.add(candidate)
        deduped.append(candidate)
    return deduped


def _sql_literal(val):
    """Render a condition value as an SQL literal.

    The value is stringified first. If ``ast.literal_eval`` accepts it (for
    example ``5`` or ``"3.2"``), it is emitted verbatim. Anything else is
    emitted as a single-quoted SQL string with embedded single quotes doubled,
    so a value such as ``O'Neil`` cannot terminate the literal early.
    """
    val = str(val)
    try:
        literal_eval(val)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return "'{}'".format(val.replace("'", "''"))
    return val


def run_sql_query(s, df):
    """Execute the structured prediction *s* against *df* in in-memory SQLite.

    *s* must expose ``question_column``, ``total_conds_filtered`` (an iterable
    of ``(column, operator, value)`` triples), ``agg_pred`` and ``conn_pred``.
    It is presumably one row produced by ``full_before_cat_decomp``; confirm
    against the producer.

    *df* is copied into a table ``Mem_Table`` whose columns are renamed to
    ``col_0``, ``col_1``, ... in header order. The generated SQL therefore only
    ever refers to these generated names, never to the original header text.

    Returns:
        A dict with three keys:

        * ``"sql_query"``: the SELECT statement that was run, or ``""`` when
          ``s.question_column`` is None.
        * ``"cnt_num"``: the number of rows matching the WHERE clause, as a
          plain ``int``.
        * ``"conclusion"``: the flattened SELECT result, or ``[]`` when no row
          matches.

    Raises:
        AssertionError: if *df* is not a DataFrame, or the question column or
            a condition column is not in the (de-duplicated) header.
    """
    assert isinstance(df, pd.DataFrame)
    question_column = s.question_column
    if question_column is None:
        return {
            "sql_query": "",
            "cnt_num": 0,
            "conclusion": []
        }
    # Materialise once: the conditions are traversed twice below.
    conds = list(s.total_conds_filtered)
    agg_pred = s.agg_pred
    conn_pred = s.conn_pred
    table = "Mem_Table"
    sql_format = "SELECT {} FROM {} {}"

    header = _dedup_header(df.columns.tolist())
    assert question_column in header
    assert all(cond[0] in header for cond in conds)

    sql_question_column = "col_{}".format(header.index(question_column))
    sql_conds = [
        "{} {} {}".format(
            "col_{}".format(header.index(cond[0])),
            # SQL spells equality "=", whereas the predictor may emit "==".
            "=" if cond[1] == "==" else cond[1],
            _sql_literal(cond[2]),
        )
        for cond in conds
    ]

    # A blank aggregate yields a bare "(col_i)" projection.
    select_expr = "{}({})".format(
        agg_pred if agg_pred.strip() else "", sql_question_column
    )
    # Conditions are joined with the predicted connector, defaulting to "and".
    connector = conn_pred if conn_pred.strip() else "and"
    where_clause = (
        "WHERE {}".format(" {} ".format(connector).join(sql_conds)) if sql_conds else ""
    )
    sql_query = sql_format.format(select_expr, table, where_clause)
    cnt_sql_query = sql_format.format("COUNT(*)", table, where_clause).strip()

    # Copy before renaming so the caller's frame keeps its own column labels.
    df_saved = df.copy()
    df_saved.columns = ["col_{}".format(i) for i in range(len(header))]
    conn = sqlite3.connect(":memory:")
    try:
        df_saved.to_sql(table, conn, if_exists="replace", index=False)
        # int() so the result is JSON-serialisable (read_sql yields a numpy scalar).
        cnt_num = int(pd.read_sql(cnt_sql_query, conn).values.reshape((-1,))[0])
        if cnt_num == 0:
            return {
                "sql_query": sql_query,
                "cnt_num": 0,
                "conclusion": []
            }
        query_conclusion_list = pd.read_sql(sql_query, conn).values.reshape((-1,)).tolist()
    finally:
        conn.close()
    return {
        "sql_query": sql_query,
        "cnt_num": cnt_num,
        "conclusion": query_conclusion_list
    }
#save_conn = sqlite3.connect(":memory:")
def single_table_pred(question, pd_df):
    """Answer *question* against the table *pd_df* and return the SQL result.

    The question is wrapped in a one-row frame with a single ``"question"``
    column. That frame and the table are passed to ``full_before_cat_decomp``,
    which is defined outside this file. Exactly one prediction row must come
    back; it is executed with ``run_sql_query``, whose result dict is returned
    unchanged.

    Raises:
        AssertionError: if *question* is not exactly a ``str``, *pd_df* is not
            a DataFrame, or the decomposition does not yield exactly one row.
    """
    assert type(question) is str
    assert isinstance(pd_df, pd.DataFrame)
    question_frame = pd.DataFrame([[question]], columns=["question"])
    predictions = full_before_cat_decomp(pd_df, question_frame, only_req_columns=False)
    assert predictions.shape[0] == 1
    first_prediction = predictions.iloc[0]
    return run_sql_query(first_prediction, pd_df)
if __name__ == "__main__":
    # Demo: answer one sample question over data/df1.csv and print the
    # prediction next to the reference answer stored with the sample.
    # NOTE(review): `main_path` is not defined in this file; presumably it is
    # provided by the `tableQA_single_table` star import. Confirm.
    data_path = os.path.join(main_path, "data/df1.csv")
    szse_summary_df = pd.read_csv(data_path)
    data = {
        "tqa_question": "EPS大于0且周涨跌大于5的平均市值是多少?",
        "tqa_header": szse_summary_df.columns.tolist(),
        "tqa_rows": szse_summary_df.values.tolist(),
        "tqa_data_path": data_path,
        "tqa_answer": {
            "sql_query": "SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5",
            "cnt_num": 2,
            "conclusion": [57.645]
        }
    }
    pd_df = pd.DataFrame(data["tqa_rows"], columns=data["tqa_header"])
    question = data["tqa_question"]
    prediction = single_table_pred(question, pd_df)
    # Show the prediction beside the reference so the demo is self-checking.
    print("question :", question)
    print("predicted:", prediction)
    print("expected :", data["tqa_answer"])
    print("sql match:", prediction["sql_query"] == data["tqa_answer"]["sql_query"])