|
import gradio as gr |
|
from gradio import * |
|
|
|
from run import * |
|
|
|
# Location of the sample table shown in the demo; built once and reused below.
_tqa_csv_path = os.path.join(main_path, "data/df1.csv")
szse_summary_df = pd.read_csv(_tqa_csv_path)

# Title of the editable table-QA tab; also the key into ``default_val_dict``.
tableqa_ = "数据表问答(编辑数据)"

# Hard-coded answer for the default question; used as the initial content of
# the answer widget.
_tqa_example_answer = {
    "sql_query": "SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5",
    "cnt_num": 2,
    "conclusion": [57.645],
}

# Initial widget values, keyed by tab title.
default_val_dict = {
    tableqa_: dict(
        tqa_question="EPS大于0且周涨跌大于5的平均市值是多少?",
        tqa_header=szse_summary_df.columns.tolist(),
        tqa_rows=szse_summary_df.values.tolist(),
        tqa_data_path=_tqa_csv_path,
        tqa_answer=_tqa_example_answer,
    ),
}
|
|
|
def tableqa_layer(post_data):
    """Answer a natural-language question against a table sent as JSON strings.

    The payload mimics an HTTP request body: every field arrives as a string,
    and the table parts are JSON-decoded here before being handed to
    ``single_table_pred``.

    Args:
        post_data: Mapping with the keys ``"question"`` (str),
            ``"table_header"`` (JSON-encoded list of column names) and
            ``"table_rows"`` (JSON-encoded list of row lists).

    Returns:
        Whatever ``single_table_pred(question, df)`` returns for the decoded
        table.

    Raises:
        KeyError: if a required key is missing from ``post_data``.
        TypeError: if a field is not a string, a decoded table part is not a
            JSON list, or a row is not a list.
        ValueError: if a table part is not valid JSON, or a row's length does
            not match the header length.
    """
    question = post_data["question"]
    table_rows = post_data["table_rows"]
    table_header = post_data["table_header"]

    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    for name, value in (
        ("question", question),
        ("table_rows", table_rows),
        ("table_header", table_header),
    ):
        if not isinstance(value, str):
            raise TypeError(f"{name} must be a str, got {type(value).__name__}")

    # json.JSONDecodeError is a ValueError subclass.
    table_rows = json.loads(table_rows)
    table_header = json.loads(table_header)

    for name, value in (("table_rows", table_rows), ("table_header", table_header)):
        if not isinstance(value, list):
            raise TypeError(
                f"{name} must decode to a JSON list, got {type(value).__name__}"
            )

    # Width checks only make sense when both parts are present; an empty
    # table is passed through as an (empty) DataFrame.
    if table_rows and table_header:
        width = len(table_header)
        for index, row in enumerate(table_rows):
            if not isinstance(row, list):
                raise TypeError(
                    f"row {index} must be a list, got {type(row).__name__}"
                )
            if len(row) != width:
                raise ValueError(
                    f"row {index} has {len(row)} cells but the header has {width}"
                )

    df = pd.DataFrame(table_rows, columns=table_header)
    return single_table_pred(question, df)
|
|
|
def _is_blank_cell(cell):
    """Return True if a table cell holds no user data.

    ``None``, float NaN, and empty or whitespace-only strings count as blank.
    Numbers, including 0, and booleans, including False, are real values.
    """
    import math  # local import keeps this helper self-contained

    if cell is None:
        return True
    if isinstance(cell, float):
        return math.isnan(cell)
    if isinstance(cell, str):
        return not cell.strip()
    return False


def _row_has_content(row):
    """Return True if at least one cell of ``row`` is not blank."""
    return not all(_is_blank_cell(cell) for cell in row)


def run_tableqa(*inputs):
    """Gradio click handler: answer the question against the edited table.

    Args:
        *inputs: Exactly two positional values — the question text and the
            table, which is accessed via ``.columns`` and ``.values``
            (i.e. used as a pandas DataFrame).

    Returns:
        The response from ``tableqa_layer``, round-tripped through JSON so it
        contains only JSON-native types for the JSON output widget.
    """
    question, data = inputs

    header = data.columns.tolist()
    # The editable grid can contain empty rows (the layout adds a spare one);
    # drop rows where every cell is blank so they are not treated as data.
    # Using _is_blank_cell instead of bool() keeps rows of zeros and drops
    # NaN-only or whitespace-only rows.
    rows = [row for row in data.values.tolist() if _row_has_content(row)]

    resp = tableqa_layer(
        {
            "question": question,
            "table_header": json.dumps(header),
            "table_rows": json.dumps(rows),
        }
    )

    # Objects exposing ``tolist`` (e.g. numpy scalars/arrays) are not JSON
    # serialisable; convert them to native Python values.
    for key in ("cnt_num", "conclusion"):
        if key in resp and hasattr(resp[key], "tolist"):
            resp[key] = resp[key].tolist()

    # Round-trip through JSON to guarantee a JSON-native structure.
    return json.loads(json.dumps(resp))
|
|
|
# Page layout: a "TableQA" tab containing one editable question/table tab
# whose button sends both inputs to run_tableqa and shows the JSON answer.
demo = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")

with demo:
    gr.Markdown("")
    gr.Markdown(
        "This _example_ was **derived** from <br/><b><h4>[https://github.com/svjack/tableQA-Chinese](https://github.com/svjack/tableQA-Chinese)</h4></b>\n"
    )
    with gr.Tabs():
        with gr.TabItem("数据表问答(TableQA)"):
            with gr.Tabs():
                with gr.TabItem(tableqa_):
                    _tqa_defaults = default_val_dict[tableqa_]

                    tqa_question = gr.Textbox(
                        _tqa_defaults["tqa_question"],
                        label="问句:(输入)",
                    )

                    # One extra row beyond the sample data, so the user has a
                    # blank row to start adding data.
                    tqa_data = gr.Dataframe(
                        headers=_tqa_defaults["tqa_header"],
                        value=_tqa_defaults["tqa_rows"],
                        row_count=len(_tqa_defaults["tqa_rows"]) + 1,
                    )

                    # gr.JSON (same component as the star-imported JSON) for
                    # consistency with the other gr.* calls.
                    tqa_answer = gr.JSON(
                        _tqa_defaults["tqa_answer"],
                        label="问句:(输出)",
                    )

                    tqa_button = gr.Button("得到答案")

    tqa_button.click(
        run_tableqa,
        inputs=[tqa_question, tqa_data],
        outputs=tqa_answer,
    )

# Binding to 0.0.0.0 makes the server reachable on all network interfaces.
demo.launch(server_name="0.0.0.0")
|
|