File size: 4,833 Bytes
b5dbcf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from tableQA_single_table import *

import json
import os
import sys

def _dedupe_header(header):
    """Return *header* with repeated names made unique by a numeric suffix.

    The first occurrence of a name is kept unchanged; every later repeat
    becomes ``"<name>_<k>"`` where ``k`` is a single counter shared by all
    repeats, e.g. ``["a", "b", "a", "b"]`` -> ``["a", "b", "a_1", "b_2"]``.
    A header without repeats is returned unchanged (as a new list).
    """
    seen = set()
    unique = []
    repeats = 0
    for name in header:
        if name in seen:
            repeats += 1
            unique.append("{}_{}".format(name, repeats))
        else:
            unique.append(name)
            seen.add(name)
    return unique


def _format_sql_literal(val):
    """Render a condition value as the right-hand side of a SQL comparison.

    The value is stringified. If ``literal_eval`` can parse that text (numbers,
    already-quoted strings, ...) it is emitted verbatim; otherwise it is
    treated as a plain string and wrapped in single quotes, with embedded
    single quotes doubled so the literal stays valid SQL.
    """
    text = str(val)
    try:
        literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a Python literal -> plain string; escape ' as '' for SQL.
        return "'{}'".format(text.replace("'", "''"))
    return text


def run_sql_query(s, df):
    """Build and run the SQL query described by *s* against *df*.

    *df* is copied into an in-memory SQLite table named ``Mem_Table`` whose
    columns are renamed ``col_0 .. col_{n-1}`` by position, so arbitrary
    header text never has to be quoted inside the SQL.

    Args:
        s: Object exposing the attributes ``question_column`` (header name to
            select, or ``None``), ``total_conds_filtered`` (iterable of
            ``(column, operator, value)`` items), ``agg_pred`` (aggregate
            function name, blank for none) and ``conn_pred`` (connector placed
            between conditions, blank means ``"and"``).
        df: ``pandas.DataFrame`` holding the table to query.

    Returns:
        dict with keys ``"sql_query"`` (the SELECT statement), ``"cnt_num"``
        (number of rows matching the WHERE clause, as ``int``) and
        ``"conclusion"`` (flat list of the query's result values). When
        ``question_column`` is ``None`` the query is empty and nothing is run;
        when no row matches, ``"conclusion"`` is ``[]``.

    Raises:
        AssertionError: if *df* is not a DataFrame, or the question column or
            a condition column is not in the (de-duplicated) header.
    """
    assert isinstance(df, pd.DataFrame)
    question_column = s.question_column
    if question_column is None:
        return {
            "sql_query": "",
            "cnt_num": 0,
            "conclusion": []
        }
    total_conds_filtered = s.total_conds_filtered
    agg_pred = s.agg_pred
    conn_pred = s.conn_pred

    header = _dedupe_header(df.columns.tolist())
    assert question_column in header
    assert all(cond[0] in header for cond in total_conds_filtered)
    assert len(header) == len(set(header))

    # Header names are unique here, so name -> positional alias is one-to-one.
    column_alias = {name: "col_{}".format(pos) for pos, name in enumerate(header)}

    # NOTE(review): the operator and conn_pred are interpolated verbatim; only
    # the column names (aliased) and the string values (quoted) are sanitised.
    # They must come from a trusted source.
    conditions = [
        "{} {} {}".format(
            column_alias[cond[0]],
            "=" if cond[1] == "==" else cond[1],
            _format_sql_literal(cond[2]),
        )
        for cond in total_conds_filtered
    ]
    connector = conn_pred if conn_pred.strip() else "and"
    sql_where_string = ""
    if conditions:
        sql_where_string = "WHERE {}".format(" {} ".format(connector).join(conditions))

    # A blank aggregate yields a bare parenthesised column: "(col_3)".
    select_expr = "{}({})".format(
        agg_pred if agg_pred.strip() else "",
        column_alias[question_column],
    )
    sql_format = "SELECT {} FROM {} {}"
    sql_query = sql_format.format(select_expr, "Mem_Table", sql_where_string)
    cnt_sql_query = sql_format.format("COUNT(*)", "Mem_Table", sql_where_string).strip()

    df_saved = df.copy()
    df_saved.columns = ["col_{}".format(pos) for pos in range(len(header))]

    conn = sqlite3.connect(":memory:")
    try:
        df_saved.to_sql("Mem_Table", conn, if_exists="replace", index=False)
        # int(): plain Python int instead of numpy.int64, matching the other
        # return paths and keeping the result JSON-serialisable.
        cnt_num = int(pd.read_sql(cnt_sql_query, conn).values.reshape((-1,))[0])
        if cnt_num == 0:
            return {
                "sql_query": sql_query,
                "cnt_num": 0,
                "conclusion": []
            }
        query_conclusion_list = pd.read_sql(sql_query, conn).values.reshape((-1,)).tolist()
    finally:
        # Always release the in-memory database, including on errors.
        conn.close()
    return {
        "sql_query": sql_query,
        "cnt_num": cnt_num,
        "conclusion": query_conclusion_list
    }

#save_conn = sqlite3.connect(":memory:")
def single_table_pred(question, pd_df):
    """Answer a natural-language *question* about the table *pd_df*.

    The question is wrapped in a one-row DataFrame (column ``"question"``),
    passed together with the table to ``full_before_cat_decomp``, and the
    single row that comes back is handed to ``run_sql_query``.

    Args:
        question: The question text (``str`` or a ``str`` subclass).
        pd_df: ``pandas.DataFrame`` holding the table being queried.

    Returns:
        The dict produced by ``run_sql_query`` (keys ``"sql_query"``,
        ``"cnt_num"`` and ``"conclusion"``).

    Raises:
        AssertionError: if an argument has the wrong type, or
            ``full_before_cat_decomp`` does not return exactly one row.
    """
    assert isinstance(question, str)
    assert isinstance(pd_df, pd.DataFrame)
    qs_df = pd.DataFrame([[question]], columns=["question"])

    # NOTE(review): full_before_cat_decomp is defined outside this file; it
    # presumably yields one prediction row per question — confirm there.
    tableqa_df = full_before_cat_decomp(pd_df, qs_df, only_req_columns=False)

    # Exactly one question went in, so exactly one row is expected back.
    assert tableqa_df.shape[0] == 1
    return run_sql_query(tableqa_df.iloc[0], pd_df)


if __name__ == "__main__":
    # Demo run: load the sample table shipped with the project and ask one
    # question about it. `main_path` is provided by the star import above.
    csv_path = os.path.join(main_path ,"data/df1.csv")
    szse_summary_df = pd.read_csv(csv_path)

    # Reference answer for the question below; kept alongside the sample for
    # comparison but not checked by this script.
    reference_answer = {
        "sql_query": "SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5",
        "cnt_num": 2,
        "conclusion": [57.645],
    }

    data = {}
    data["tqa_question"] = "EPS大于0且周涨跌大于5的平均市值是多少?"
    data["tqa_header"] = szse_summary_df.columns.tolist()
    data["tqa_rows"] = szse_summary_df.values.tolist()
    data["tqa_data_path"] = csv_path
    data["tqa_answer"] = reference_answer

    # Rebuild the table from the header/rows payload, then run the question.
    question = data["tqa_question"]
    pd_df = pd.DataFrame(data["tqa_rows"], columns=data["tqa_header"])
    single_table_pred(question, pd_df)