Spaces:
Sleeping
Sleeping
import gradio as gr | |
from gradio import * | |
from run import * | |
# Sample table shown in the demo. `main_path` is not defined in this file;
# presumably it is provided by `from run import *` — confirm.
szse_summary_path = os.path.join(main_path, "data/df1.csv")
szse_summary_df = pd.read_csv(szse_summary_path)

# Title of the table-QA tab; also the key into `default_val_dict`.
tableqa_ = "数据表问答(编辑数据)"

# Initial widget values for the table-QA tab: the question textbox, the
# editable table (header + rows), the data path, and a sample answer that the
# JSON output box displays before any query is run.
default_val_dict = {
    tableqa_: {
        "tqa_question": "EPS大于0且周涨跌大于5的平均市值是多少?",
        "tqa_header": szse_summary_df.columns.tolist(),
        "tqa_rows": szse_summary_df.values.tolist(),
        # Reuse the single path computed above instead of rebuilding it.
        "tqa_data_path": szse_summary_path,
        "tqa_answer": {
            "sql_query": "SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5",
            "cnt_num": 2,
            "conclusion": [57.645],
        },
    }
}
def tableqa_layer(post_data):
    """Answer a question against a table that arrives as JSON-encoded strings.

    Args:
        post_data: Mapping with three string entries:
            ``"question"``     -- the natural-language question;
            ``"table_header"`` -- JSON-encoded list of column names;
            ``"table_rows"``   -- JSON-encoded list of rows (lists of cells).

    Returns:
        Whatever ``single_table_pred(question, df)`` returns for the
        DataFrame rebuilt from the header and rows.

    Raises:
        KeyError: if one of the three keys is missing.
        TypeError: if an entry is not a string, or header/rows do not decode
            to lists.
        ValueError: if the JSON is malformed (``json.JSONDecodeError``), or
            the header width differs from the first row's width.
    """
    question = post_data["question"]
    table_rows = post_data["table_rows"]
    table_header = post_data["table_header"]

    # Explicit checks rather than `assert`, so validation survives `python -O`.
    if not all(isinstance(x, str) for x in (question, table_rows, table_header)):
        raise TypeError("question, table_rows and table_header must all be strings")

    table_rows = json.loads(table_rows)
    table_header = json.loads(table_header)
    if not (isinstance(table_rows, list) and isinstance(table_header, list)):
        raise TypeError("table_rows and table_header must decode to JSON lists")

    # Only checkable when both are non-empty; the first row stands in for all.
    if table_rows and table_header and len(table_header) != len(table_rows[0]):
        raise ValueError(
            "table header has %d columns but the first row has %d cells"
            % (len(table_header), len(table_rows[0]))
        )

    # Built unconditionally so an empty table never leaves `df` unbound.
    # NOTE(review): whether single_table_pred handles an empty frame is unverified.
    df = pd.DataFrame(table_rows, columns=table_header)
    return single_table_pred(question, df)
def _is_blank_cell(cell):
    """Return True if a table cell carries no value: None, NaN, or a blank string."""
    import math

    if cell is None:
        return True
    if isinstance(cell, str):
        return not cell.strip()
    return isinstance(cell, float) and math.isnan(cell)


def run_tableqa(*args):
    """Click handler: answer the question over the edited table.

    Args:
        *args: Exactly two positional values ``(question, data)``: the question
            string and the table from the editable Dataframe component (used
            via ``.columns`` and ``.values``, i.e. a pandas DataFrame).

    Returns:
        The response of ``tableqa_layer`` (presumably a dict), after a JSON
        round trip so it contains only JSON-native types.

    Raises:
        ValueError: if not exactly two positional arguments are given.
        TypeError: if ``question`` is not a string.
    """
    question, data = args
    if not isinstance(question, str):
        raise TypeError("question must be a string")

    # Drop rows with no real content (e.g. empty rows left in the editor).
    # Falsy-but-real values such as 0 or False are kept; the previous
    # `any(map(bool, row))` filter silently discarded all-zero rows.
    rows = [
        row
        for row in data.values.tolist()
        if not all(_is_blank_cell(cell) for cell in row)
    ]

    resp = tableqa_layer(
        {
            "question": question,
            "table_header": json.dumps(data.columns.tolist()),
            "table_rows": json.dumps(rows),
        }
    )

    # Array-like results (anything exposing .tolist()) are converted to lists
    # so they can be serialized below.
    for key in ("cnt_num", "conclusion"):
        if key in resp and hasattr(resp[key], "tolist"):
            resp[key] = resp[key].tolist()

    # Round-trip through JSON so the output component receives only
    # JSON-native types.
    return json.loads(json.dumps(resp))
# Gradio UI: a question textbox, an editable table, and a JSON answer box,
# wired to `run_tableqa` through a button.
demo = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")
with demo:
    gr.Markdown("")
    gr.Markdown(
        "This _example_ was **derived** from <br/><b><h4>[https://github.com/svjack/tableQA-Chinese](https://github.com/svjack/tableQA-Chinese)</h4></b>\n"
    )
    with gr.Tabs():
        # Table question answering (outer tab).
        with gr.TabItem("数据表问答(TableQA)"):
            with gr.Tabs():
                with gr.TabItem(tableqa_):
                    _tqa_defaults = default_val_dict[tableqa_]

                    # Question input (label: "Question (input)").
                    tqa_question = gr.Textbox(
                        _tqa_defaults["tqa_question"],
                        label="问句:(输入)",
                    )
                    # Editable table pre-filled with the sample data; it has
                    # one row more than the data, presumably a blank row for
                    # user additions.
                    tqa_data = gr.Dataframe(
                        headers=_tqa_defaults["tqa_header"],
                        value=_tqa_defaults["tqa_rows"],
                        row_count=len(_tqa_defaults["tqa_rows"]) + 1,
                    )
                    # Answer output (label: "Question (output)"); shows the
                    # sample answer until a query runs. Uses gr.JSON rather
                    # than relying on `from gradio import *`.
                    tqa_answer = gr.JSON(
                        _tqa_defaults["tqa_answer"],
                        label="问句:(输出)",
                    )
                    # "Get answer" button.
                    tqa_button = gr.Button("得到答案")
                    tqa_button.click(
                        run_tableqa,
                        inputs=[tqa_question, tqa_data],
                        outputs=tqa_answer,
                    )

# Listen on all network interfaces.
demo.launch(server_name="0.0.0.0")