from tableQA_single_table import *
import json
import os
import sys


def _dedupe_header(header):
    """Return *header* with repeated column names made unique.

    The first occurrence of a name is kept unchanged; every later repeat gets
    a ``_<n>`` suffix, where ``n`` is ONE counter shared by all duplicated
    names, e.g. ``["a", "a", "b", "b"]`` -> ``["a", "a_1", "b", "b_2"]``.
    A header without duplicates is returned unchanged (as a new list).
    """
    deduped = []
    seen = set()
    dup_count = 0
    for name in header:
        if name in seen:
            dup_count += 1
            deduped.append("{}_{}".format(name, dup_count))
        else:
            deduped.append(name)
            seen.add(name)
    return deduped


def _format_right(val):
    """Render a condition value as an SQL literal.

    The value is stringified first. If that string parses as a Python literal
    (e.g. ``5``, ``-0.3``) it is inlined verbatim; otherwise it is emitted as
    a single-quoted SQL string with embedded single quotes doubled, so text
    such as ``O'Brien`` cannot close the literal early and break (or inject
    into) the generated query.
    """
    val = str(val)
    try:
        literal_eval(val)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Not a Python literal -> treat as text; '' is SQL's escaped quote.
        return "'{}'".format(val.replace("'", "''"))
    return val


def run_sql_query(s, df):
    """Build and execute the SQL for one parsed question against *df*.

    *df* is loaded into an in-memory SQLite table ``Mem_Table`` whose columns
    are renamed to ``col_0``, ``col_1``, ... (by position), and the query is
    generated against those positional names.

    Args:
        s: One parsed-question record (``single_table_pred`` passes a row of a
            DataFrame) exposing these attributes:
            ``question_column`` - header name of the column to select, or
            ``None`` (then an empty result is returned without querying);
            ``total_conds_filtered`` - iterable of ``(column, op, value)``
            items; ``op == "=="`` is written as ``=``, any other ``op`` is
            inserted verbatim;
            ``agg_pred`` - aggregate function name wrapped around the selected
            column, or a blank string for none;
            ``conn_pred`` - word joining the conditions; blank means ``and``.
            Column names refer to the de-duplicated header (see
            ``_dedupe_header``).
        df: The table to query.

    Returns:
        dict with keys ``"sql_query"`` (the SELECT that was run, or would
        have been), ``"cnt_num"`` (``int`` number of rows matching the WHERE
        clause) and ``"conclusion"`` (flattened list of the SELECT's result
        values; empty when no row matches).

    Raises:
        AssertionError: if *df* is not a DataFrame, or ``question_column`` /
            a condition column is not in the de-duplicated header.
    """
    assert isinstance(df, pd.DataFrame)

    question_column = s.question_column
    if question_column is None:
        return {"sql_query": "", "cnt_num": 0, "conclusion": []}

    total_conds_filtered = s.total_conds_filtered
    agg_pred = s.agg_pred
    conn_pred = s.conn_pred

    sql_format = "SELECT {} FROM {} {}"
    table_name = "Mem_Table"

    header = _dedupe_header(df.columns.tolist())

    assert question_column in header
    assert all(cond[0] in header for cond in total_conds_filtered)
    assert len(header) == len(set(header))

    header_index_mapping = {name: idx for idx, name in enumerate(header)}

    sql_question_column = "col_{}".format(header_index_mapping[question_column])
    sql_conds = [
        ("col_{}".format(header_index_mapping[cond[0]]), cond[1], _format_right(cond[2]))
        for cond in total_conds_filtered
    ]

    # A blank aggregation yields "(col_N)", i.e. a plain column selection.
    select_expr = "{}({})".format(
        agg_pred if agg_pred.strip() else "", sql_question_column
    )

    # NOTE(review): op is inserted verbatim; it is assumed to be a comparison
    # operator produced by the upstream parser.
    joiner = " {} ".format(conn_pred if conn_pred.strip() else "and")
    where_clause = ""
    if sql_conds:
        where_clause = "WHERE " + joiner.join(
            "{} {} {}".format(col, "=" if op == "==" else op, value)
            for col, op, value in sql_conds
        )

    sql_query = sql_format.format(select_expr, table_name, where_clause)
    cnt_sql_query = sql_format.format("COUNT(*)", table_name, where_clause).strip()

    # Positional column names avoid quoting/escaping arbitrary header text.
    df_saved = df.copy()
    df_saved.columns = ["col_{}".format(idx) for idx in range(len(header))]

    conn = sqlite3.connect(":memory:")
    try:
        df_saved.to_sql(table_name, conn, if_exists="replace", index=False)
        cnt_num = int(pd.read_sql(cnt_sql_query, conn).values.reshape((-1,))[0])
        if cnt_num == 0:
            return {"sql_query": sql_query, "cnt_num": 0, "conclusion": []}
        conclusion = pd.read_sql(sql_query, conn).values.reshape((-1,)).tolist()
    finally:
        # sqlite3's connection context manager only commits/rolls back; it
        # does not close, so close explicitly.
        conn.close()

    return {"sql_query": sql_query, "cnt_num": cnt_num, "conclusion": conclusion}


def single_table_pred(question, pd_df):
    """Answer a natural-language *question* over the table *pd_df*.

    The question is wrapped in a one-row DataFrame and parsed by
    ``full_before_cat_decomp`` (imported from ``tableQA_single_table``). The
    single resulting row is handed to ``run_sql_query``, whose result dict is
    returned unchanged.

    Raises:
        AssertionError: if *question* is not a ``str``, *pd_df* is not a
            DataFrame, or the parser does not return exactly one row.
    """
    assert isinstance(question, str)
    assert isinstance(pd_df, pd.DataFrame)
    qs_df = pd.DataFrame([[question]], columns=["question"])
    tableqa_df = full_before_cat_decomp(pd_df, qs_df, only_req_columns=False)
    assert tableqa_df.shape[0] == 1
    return run_sql_query(tableqa_df.iloc[0], pd_df)


if __name__ == "__main__":
    # Demo: run one question against data/df1.csv under main_path (imported
    # from tableQA_single_table) and show the prediction next to the
    # hand-written expected answer.
    data_path = os.path.join(main_path, "data/df1.csv")
    szse_summary_df = pd.read_csv(data_path)
    data = {
        # "What is the average market cap where EPS > 0 and the weekly
        # change > 5?"
        "tqa_question": "EPS大于0且周涨跌大于5的平均市值是多少?",
        "tqa_header": szse_summary_df.columns.tolist(),
        "tqa_rows": szse_summary_df.values.tolist(),
        "tqa_data_path": data_path,
        "tqa_answer": {
            "sql_query": "SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5",
            "cnt_num": 2,
            "conclusion": [57.645],
        },
    }
    pd_df = pd.DataFrame(data["tqa_rows"], columns=data["tqa_header"])
    question = data["tqa_question"]
    result = single_table_pred(question, pd_df)
    print("predicted:", result)
    print("expected: ", data["tqa_answer"])