File size: 3,642 Bytes
b5dbcf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
689c24b
b5dbcf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gradio as gr
from gradio import *

from run import *

# Sample table bundled with the demo; it seeds the editable grid and the
# default widget values below.
_szse_csv_path = os.path.join(main_path, "data/df1.csv")
szse_summary_df = pd.read_csv(_szse_csv_path)

# Title of the "table QA (edit data)" tab; also the key into default_val_dict.
tableqa_ = "数据表问答(编辑数据)"

# Initial values for the widgets of each tab, keyed by tab title.
default_val_dict = {
    tableqa_: {
        # Example question (roughly: "average market cap where EPS > 0 and
        # weekly change > 5?").
        "tqa_question": "EPS大于0且周涨跌大于5的平均市值是多少?",
        "tqa_header": szse_summary_df.columns.tolist(),
        "tqa_rows": szse_summary_df.values.tolist(),
        "tqa_data_path": _szse_csv_path,
        # Answer shown in the output JSON widget before the button is pressed;
        # presumably the model's answer to the example question — verify.
        "tqa_answer": {
            "sql_query": "SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5",
            "cnt_num": 2,
            "conclusion": [57.645],
        },
    },
}

def tableqa_layer(post_data):
    """Validate a JSON-encoded table question and answer it with ``single_table_pred``.

    Args:
        post_data: mapping with three string fields:
            ``"question"`` -- the natural-language question;
            ``"table_header"`` -- JSON text encoding a list of column names;
            ``"table_rows"`` -- JSON text encoding a list of rows, each a list
            with one cell per column.

    Returns:
        Whatever ``single_table_pred(question, df)`` returns for the decoded
        table (the caller in this file treats it as a dict).

    Raises:
        KeyError: if one of the three fields is missing from ``post_data``.
        TypeError: if a field is not a string, or its decoded JSON is not a list.
        ValueError: if the JSON text is malformed, or a row is not a list
            whose length equals the number of header columns.
    """
    question = post_data["question"]
    table_rows = post_data["table_rows"]
    table_header = post_data["table_header"]

    # Explicit raises rather than `assert`: asserts are stripped under
    # `python -O`, which would silently disable this validation.
    for field, value in (
        ("question", question),
        ("table_rows", table_rows),
        ("table_header", table_header),
    ):
        if not isinstance(value, str):
            raise TypeError(f"{field} must be a str, got {type(value).__name__}")

    table_rows = json.loads(table_rows)
    table_header = json.loads(table_header)

    for field, value in (("table_rows", table_rows), ("table_header", table_header)):
        if not isinstance(value, list):
            raise TypeError(
                f"{field} must decode to a list, got {type(value).__name__}"
            )

    # Check every row, not just the first, so a ragged table is rejected here
    # with a clear message instead of failing somewhere inside pandas.
    if table_header:
        width = len(table_header)
        for index, row in enumerate(table_rows):
            if not isinstance(row, list) or len(row) != width:
                raise ValueError(
                    f"row {index} must be a list of {width} cells to match the header"
                )

    df = pd.DataFrame(table_rows, columns=table_header)
    return single_table_pred(question, df)

def _is_blank_cell(cell):
    """Return True if a table cell holds no user-entered value.

    Blank means ``None``, a string that is empty or whitespace-only, or a
    float NaN. Unlike ``bool(cell)``, this keeps ``0``, ``0.0`` and ``False``
    as real values.
    """
    if cell is None:
        return True
    if isinstance(cell, str):
        return not cell.strip()
    # NaN is the only float that compares unequal to itself.
    return isinstance(cell, float) and cell != cell


def run_tableqa(*args):
    """Gradio callback: answer a question against the edited table.

    Args:
        *args: exactly two positional values, ``(question, data)``, in the
            order of the button's ``inputs`` list: the question string and
            the table value (anything exposing ``.columns.tolist()`` and
            ``.values.tolist()``; presumably a pandas DataFrame from the
            Dataframe widget).

    Returns:
        The prediction from ``tableqa_layer``. Every value exposing
        ``.tolist()`` (numpy/pandas results) is converted to a plain list or
        scalar, and the dict is round-tripped through JSON so it holds only
        JSON-native types.
    """
    question, data = args
    header = data.columns.tolist()
    # Drop rows in which every cell is blank, such as the spare empty row of
    # the editable grid. Rows of zeros are genuine data and are kept.
    rows = [
        row
        for row in data.values.tolist()
        if not all(_is_blank_cell(cell) for cell in row)
    ]

    resp = tableqa_layer(
        {
            "question": question,
            "table_header": json.dumps(header),
            "table_rows": json.dumps(rows),
        }
    )
    # numpy arrays/scalars are not JSON-serializable; `.tolist()` turns them
    # (and pandas Series) into native Python lists/scalars.
    resp = {
        key: value.tolist() if hasattr(value, "tolist") else value
        for key, value in resp.items()
    }
    return json.loads(json.dumps(resp))

# Gradio UI: a single table-QA tab with an editable grid, a question box and
# a JSON answer panel.
demo = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")

with demo:
    gr.Markdown("")
    gr.Markdown("This _example_ was **derived** from <br/><b><h4>[https://github.com/svjack/tableQA-Chinese](https://github.com/svjack/tableQA-Chinese)</h4></b>\n")
    with gr.Tabs():
        # ---- Table QA ----
        with gr.TabItem("数据表问答(TableQA)"):
            with gr.Tabs():
                with gr.TabItem(tableqa_):
                    _tqa_defaults = default_val_dict[tableqa_]

                    tqa_question = gr.Textbox(
                        _tqa_defaults["tqa_question"],
                        label="问句:(输入)",
                    )

                    # row_count is one more than the sample rows, presumably
                    # to leave a blank row for the user to fill in.
                    tqa_data = gr.Dataframe(
                        headers=_tqa_defaults["tqa_header"],
                        value=_tqa_defaults["tqa_rows"],
                        row_count=len(_tqa_defaults["tqa_rows"]) + 1,
                    )

                    tqa_answer = gr.JSON(
                        _tqa_defaults["tqa_answer"],
                        label="问句:(输出)",
                    )

                    tqa_button = gr.Button("得到答案")

                tqa_button.click(
                    run_tableqa,
                    inputs=[tqa_question, tqa_data],
                    outputs=tqa_answer,
                )

# Only start the server when run as a script, so importing this module
# (e.g. for tests or reload tooling) does not block on launch().
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")