svjack commited on
Commit
6be3652
1 Parent(s): b5dbcf3

Delete Untitled.ipynb

Browse files
Files changed (1) hide show
  1. Untitled.ipynb +0 -761
Untitled.ipynb DELETED
@@ -1,761 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "id": "404d2feb-8a5c-4a6e-9012-bd88edd0b5bb",
7
- "metadata": {},
8
- "outputs": [
9
- {
10
- "name": "stdout",
11
- "output_type": "stream",
12
- "text": [
13
- "data\t\t __pycache__\t run.py\t\t Untitled.ipynb\n",
14
- "JointBERT-master requirements.txt tableQA_single_table.py\n"
15
- ]
16
- }
17
- ],
18
- "source": [
19
- "!ls"
20
- ]
21
- },
22
- {
23
- "cell_type": "code",
24
- "execution_count": 1,
25
- "id": "b729ace0-35b0-4ba3-b44f-b907deabf4fd",
26
- "metadata": {},
27
- "outputs": [],
28
- "source": [
29
- "import gradio as gr\n",
30
- "from gradio import *"
31
- ]
32
- },
33
- {
34
- "cell_type": "code",
35
- "execution_count": 2,
36
- "id": "b1c9b471-9078-4413-b8e5-ed21c0f858fc",
37
- "metadata": {},
38
- "outputs": [],
39
- "source": [
40
- "from run import *"
41
- ]
42
- },
43
- {
44
- "cell_type": "code",
45
- "execution_count": 3,
46
- "id": "151526b1-8ae0-4c5f-86e0-de3afd80cd73",
47
- "metadata": {},
48
- "outputs": [],
49
- "source": [
50
- "szse_summary_df = pd.read_csv(os.path.join(main_path ,\"data/df1.csv\"))\n",
51
- "tableqa_ = \"数据表问答(编辑数据)\""
52
- ]
53
- },
54
- {
55
- "cell_type": "code",
56
- "execution_count": 4,
57
- "id": "9a5982e8-fb75-4a90-bca9-80a1b394e691",
58
- "metadata": {},
59
- "outputs": [],
60
- "source": [
61
- "default_val_dict = {\n",
62
- " tableqa_ :{\n",
63
- " \"tqa_question\": \"EPS大于0且周涨跌大于5的平均市值是多少?\",\n",
64
- " \"tqa_header\": szse_summary_df.columns.tolist(),\n",
65
- " \"tqa_rows\": szse_summary_df.values.tolist(),\n",
66
- " \"tqa_data_path\": os.path.join(main_path ,\"data/df1.csv\"),\n",
67
- " \"tqa_answer\": {\n",
68
- " \"sql_query\": \"SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5\",\n",
69
- " \"cnt_num\": 2,\n",
70
- " \"conclusion\": [57.645]\n",
71
- " }\n",
72
- " }\n",
73
- "}"
74
- ]
75
- },
76
- {
77
- "cell_type": "code",
78
- "execution_count": 5,
79
- "id": "10098d86-3b56-4fff-a2fd-1eca344cc114",
80
- "metadata": {},
81
- "outputs": [],
82
- "source": [
83
- "###default_val_dict"
84
- ]
85
- },
86
- {
87
- "cell_type": "code",
88
- "execution_count": 6,
89
- "id": "c5eb6ca8-0317-4e63-8407-7f3033164dad",
90
- "metadata": {},
91
- "outputs": [],
92
- "source": [
93
- "def tableqa_layer(post_data):\n",
94
- " question = post_data[\"question\"]\n",
95
- " table_rows = post_data[\"table_rows\"]\n",
96
- " table_header = post_data[\"table_header\"]\n",
97
- " assert all(map(lambda x: type(x) == type(\"\"), [question, table_rows, table_header]))\n",
98
- " table_rows = json.loads(table_rows)\n",
99
- " table_header = json.loads(table_header)\n",
100
- "\n",
101
- " assert all(map(lambda x: type(x) == type([]), [table_rows, table_header]))\n",
102
- " if bool(table_rows) and bool(table_header):\n",
103
- " assert len(table_header) == len(table_rows[0])\n",
104
- " df = pd.DataFrame(table_rows, columns = table_header)\n",
105
- " conclusion = single_table_pred(question, df)\n",
106
- " return conclusion"
107
- ]
108
- },
109
- {
110
- "cell_type": "code",
111
- "execution_count": 7,
112
- "id": "dab2d0e0-3e09-465d-9a02-af17d79cc8ea",
113
- "metadata": {},
114
- "outputs": [],
115
- "source": [
116
- "def run_tableqa(*input):\n",
117
- " question, data = input\n",
118
- " header = data.columns.tolist()\n",
119
- " rows = data.values.tolist()\n",
120
- "\n",
121
- " rows = list(filter(lambda x: any(map(lambda xx: bool(xx), x)), rows))\n",
122
- "\n",
123
- " assert all(map(lambda x: type(x) == type([]), [header, rows]))\n",
124
- " header = json.dumps(header)\n",
125
- " rows = json.dumps(rows)\n",
126
- "\n",
127
- " assert all(map(lambda x: type(x) == type(\"\"), [question, header, rows]))\n",
128
- " \n",
129
- " resp = tableqa_layer(\n",
130
- " {\n",
131
- " \"question\": question,\n",
132
- " \"table_header\": header,\n",
133
- " \"table_rows\": rows\n",
134
- " }\n",
135
- " )\n",
136
- " if \"cnt_num\" in resp:\n",
137
- " if hasattr(resp[\"cnt_num\"], \"tolist\"):\n",
138
- " resp[\"cnt_num\"] = resp[\"cnt_num\"].tolist()\n",
139
- " if \"conclusion\" in resp:\n",
140
- " if hasattr(resp[\"conclusion\"], \"tolist\"):\n",
141
- " resp[\"conclusion\"] = resp[\"conclusion\"].tolist()\n",
142
- " '''\n",
143
- " import pickle as pkl\n",
144
- " with open(\"resp.pkl\", \"wb\") as f:\n",
145
- " pkl.dump(resp, f)\n",
146
- " print(resp)\n",
147
- " '''\n",
148
- " resp = json.loads(json.dumps(resp))\n",
149
- " return resp"
150
- ]
151
- },
152
- {
153
- "cell_type": "code",
154
- "execution_count": 8,
155
- "id": "62b96be0-d491-42f7-9655-0f65f38e1d75",
156
- "metadata": {},
157
- "outputs": [],
158
- "source": [
159
- "###\n",
160
- "###np.asarray(2).tolist()"
161
- ]
162
- },
163
- {
164
- "cell_type": "code",
165
- "execution_count": 9,
166
- "id": "b63dc5ae-49af-4fa7-abf0-681fdb37e2e2",
167
- "metadata": {},
168
- "outputs": [],
169
- "source": [
170
- "###json.loads(json.dumps(default_val_dict))"
171
- ]
172
- },
173
- {
174
- "cell_type": "code",
175
- "execution_count": 10,
176
- "id": "41abba16-830e-4755-a783-92fef780d3e5",
177
- "metadata": {},
178
- "outputs": [],
179
- "source": [
180
- "demo = gr.Blocks(css=\".container { max-width: 800px; margin: auto; }\")\n",
181
- "\n",
182
- "with demo:\n",
183
- " gr.Markdown(\"\")\n",
184
- " with gr.Tabs():\n",
185
- " #### tableqa\n",
186
- " with gr.TabItem(\"数据表问答(TableQA)\"):\n",
187
- "\n",
188
- " with gr.Tabs():\n",
189
- " with gr.TabItem(tableqa_):\n",
190
- " tqa_question = gr.Textbox(\n",
191
- " default_val_dict[tableqa_][\"tqa_question\"],\n",
192
- " label = \"问句:(输入)\"\n",
193
- " )\n",
194
- "\n",
195
- " tqa_data = gr.Dataframe(\n",
196
- " headers=default_val_dict[tableqa_][\"tqa_header\"],\n",
197
- " value=default_val_dict[tableqa_][\"tqa_rows\"],\n",
198
- " row_count = len(default_val_dict[tableqa_][\"tqa_rows\"]) + 1\n",
199
- " )\n",
200
- "\n",
201
- " tqa_answer = JSON(\n",
202
- " default_val_dict[tableqa_][\"tqa_answer\"],\n",
203
- " label = \"问句:(输出)\"\n",
204
- " )\n",
205
- "\n",
206
- " tqa_button = gr.Button(\"得到答案\")\n",
207
- "\n",
208
- " tqa_button.click(run_tableqa, inputs=[\n",
209
- " tqa_question,\n",
210
- " tqa_data\n",
211
- " ], outputs=tqa_answer)\n",
212
- "\n"
213
- ]
214
- },
215
- {
216
- "cell_type": "code",
217
- "execution_count": 11,
218
- "id": "68a25da4-a514-45c4-8d49-264210b14f87",
219
- "metadata": {},
220
- "outputs": [
221
- {
222
- "name": "stdout",
223
- "output_type": "stream",
224
- "text": [
225
- "Running on local URL: http://172.16.56.206:7860\n",
226
- "Running on public URL: https://c7bd449621083f2690.gradio.live\n",
227
- "\n",
228
- "This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces\n"
229
- ]
230
- },
231
- {
232
- "data": {
233
- "text/html": [
234
- "<div><iframe src=\"https://c7bd449621083f2690.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
235
- ],
236
- "text/plain": [
237
- "<IPython.core.display.HTML object>"
238
- ]
239
- },
240
- "metadata": {},
241
- "output_type": "display_data"
242
- },
243
- {
244
- "data": {
245
- "text/plain": []
246
- },
247
- "execution_count": 11,
248
- "metadata": {},
249
- "output_type": "execute_result"
250
- },
251
- {
252
- "name": "stderr",
253
- "output_type": "stream",
254
- "text": [
255
- "Building prefix dict from the default dictionary ...\n",
256
- "Loading model from cache /tmp/jieba.cache\n",
257
- "Loading model cost 0.680 seconds.\n",
258
- "Prefix dict has been built successfully.\n"
259
- ]
260
- }
261
- ],
262
- "source": [
263
- "demo.launch(server_name=\"172.16.56.206\" ,share = True)"
264
- ]
265
- },
266
- {
267
- "cell_type": "code",
268
- "execution_count": null,
269
- "id": "2bcddf9e-b269-4d74-a4c7-c1ad294e2408",
270
- "metadata": {},
271
- "outputs": [],
272
- "source": []
273
- },
274
- {
275
- "cell_type": "code",
276
- "execution_count": null,
277
- "id": "3d7c4e4f-0abb-43a1-aa40-3fe05aeb8e2e",
278
- "metadata": {},
279
- "outputs": [],
280
- "source": []
281
- },
282
- {
283
- "cell_type": "code",
284
- "execution_count": null,
285
- "id": "35169744-0f5e-48be-b844-31eb35f0ce69",
286
- "metadata": {},
287
- "outputs": [],
288
- "source": []
289
- },
290
- {
291
- "cell_type": "code",
292
- "execution_count": null,
293
- "id": "66693900-ffeb-473e-b411-465fd33b2ea9",
294
- "metadata": {},
295
- "outputs": [],
296
- "source": []
297
- },
298
- {
299
- "cell_type": "code",
300
- "execution_count": 13,
301
- "id": "e1494b23-b335-441d-a572-928bf99f20d8",
302
- "metadata": {},
303
- "outputs": [],
304
- "source": [
305
- "import pickle as pkl"
306
- ]
307
- },
308
- {
309
- "cell_type": "code",
310
- "execution_count": 14,
311
- "id": "aed88b91-f662-4f2a-b62b-c8eb062d145d",
312
- "metadata": {},
313
- "outputs": [],
314
- "source": [
315
- "with open(\"resp.pkl\", \"rb\") as f:\n",
316
- " resp = pkl.load(f)"
317
- ]
318
- },
319
- {
320
- "cell_type": "code",
321
- "execution_count": 19,
322
- "id": "0329da8a-e16c-42f8-8efb-2bac80902285",
323
- "metadata": {},
324
- "outputs": [
325
- {
326
- "data": {
327
- "text/plain": [
328
- "{'sql_query': 'SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5',\n",
329
- " 'cnt_num': 2,\n",
330
- " 'conclusion': [57.645]}"
331
- ]
332
- },
333
- "execution_count": 19,
334
- "metadata": {},
335
- "output_type": "execute_result"
336
- }
337
- ],
338
- "source": [
339
- "resp"
340
- ]
341
- },
342
- {
343
- "cell_type": "code",
344
- "execution_count": 20,
345
- "id": "6241edc4-35da-427b-8fa7-356f7c20a547",
346
- "metadata": {},
347
- "outputs": [
348
- {
349
- "data": {
350
- "text/plain": [
351
- "numpy.int64"
352
- ]
353
- },
354
- "execution_count": 20,
355
- "metadata": {},
356
- "output_type": "execute_result"
357
- }
358
- ],
359
- "source": [
360
- "type(resp[\"cnt_num\"])"
361
- ]
362
- },
363
- {
364
- "cell_type": "code",
365
- "execution_count": 21,
366
- "id": "0fe9930d-fc27-4840-a8dd-ed91d44a922b",
367
- "metadata": {},
368
- "outputs": [],
369
- "source": [
370
- "if \"cnt_num\" in resp:\n",
371
- " if hasattr(resp[\"cnt_num\"], \"tolist\"):\n",
372
- " resp[\"cnt_num\"] = resp[\"cnt_num\"].tolist()\n",
373
- "if \"conclusion\" in resp:\n",
374
- " if hasattr(resp[\"conclusion\"], \"tolist\"):\n",
375
- " resp[\"conclusion\"] = resp[\"conclusion\"].tolist()"
376
- ]
377
- },
378
- {
379
- "cell_type": "code",
380
- "execution_count": 22,
381
- "id": "1e23ef44-acef-4aaa-9d36-0ccd97184d92",
382
- "metadata": {},
383
- "outputs": [
384
- {
385
- "data": {
386
- "text/plain": [
387
- "'{\"sql_query\": \"SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5\", \"cnt_num\": 2, \"conclusion\": [57.645]}'"
388
- ]
389
- },
390
- "execution_count": 22,
391
- "metadata": {},
392
- "output_type": "execute_result"
393
- }
394
- ],
395
- "source": [
396
- "json.dumps(resp)"
397
- ]
398
- },
399
- {
400
- "cell_type": "code",
401
- "execution_count": null,
402
- "id": "2523825d-2483-4149-ae16-b2b5f4177359",
403
- "metadata": {},
404
- "outputs": [],
405
- "source": []
406
- },
407
- {
408
- "cell_type": "code",
409
- "execution_count": null,
410
- "id": "574628ab-e413-4dfe-9d29-133d0a79c80e",
411
- "metadata": {},
412
- "outputs": [],
413
- "source": []
414
- },
415
- {
416
- "cell_type": "code",
417
- "execution_count": 1,
418
- "id": "921bd428-a667-429c-8142-6d9aef9c9fb6",
419
- "metadata": {},
420
- "outputs": [
421
- {
422
- "data": {
423
- "application/vnd.jupyter.widget-view+json": {
424
- "model_id": "791fee503c3f4aefb4802442d050b1ef",
425
- "version_major": 2,
426
- "version_minor": 0
427
- },
428
- "text/plain": [
429
- "Downloading: 0%| | 0.00/110k [00:00<?, ?B/s]"
430
- ]
431
- },
432
- "metadata": {},
433
- "output_type": "display_data"
434
- }
435
- ],
436
- "source": [
437
- "from tableQA_single_table import *"
438
- ]
439
- },
440
- {
441
- "cell_type": "code",
442
- "execution_count": null,
443
- "id": "a5890d29-fc99-4bf9-aa61-d55c1989aea9",
444
- "metadata": {},
445
- "outputs": [],
446
- "source": []
447
- },
448
- {
449
- "cell_type": "code",
450
- "execution_count": null,
451
- "id": "0ac20004-cfa9-4259-b324-b2ea4fd9f8e5",
452
- "metadata": {},
453
- "outputs": [],
454
- "source": []
455
- },
456
- {
457
- "cell_type": "code",
458
- "execution_count": null,
459
- "id": "5362b7f7-7d3d-4d2a-b3a3-ead266fad51e",
460
- "metadata": {},
461
- "outputs": [],
462
- "source": []
463
- },
464
- {
465
- "cell_type": "code",
466
- "execution_count": null,
467
- "id": "5cc6eb17-f50d-46ac-b516-311c24b62376",
468
- "metadata": {},
469
- "outputs": [],
470
- "source": []
471
- },
472
- {
473
- "cell_type": "code",
474
- "execution_count": 4,
475
- "id": "b89e0bde-662d-48c6-801f-60bedf873846",
476
- "metadata": {},
477
- "outputs": [],
478
- "source": [
479
- "import json\n",
480
- "import os\n",
481
- "import sys"
482
- ]
483
- },
484
- {
485
- "cell_type": "code",
486
- "execution_count": 5,
487
- "id": "60682115-d2f4-417f-9347-f5feb4261750",
488
- "metadata": {},
489
- "outputs": [],
490
- "source": [
491
- "def run_sql_query(s, df):\n",
492
- " conn = sqlite3.connect(\":memory:\")\n",
493
- "\n",
494
- " assert isinstance(df, pd.DataFrame)\n",
495
- " question_column = s.question_column\n",
496
- " if question_column is None:\n",
497
- " return {\n",
498
- " \"sql_query\": \"\",\n",
499
- " \"cnt_num\": 0,\n",
500
- " \"conclusion\": []\n",
501
- " }\n",
502
- " total_conds_filtered = s.total_conds_filtered\n",
503
- " agg_pred = s.agg_pred\n",
504
- " conn_pred = s.conn_pred\n",
505
- " sql_format = \"SELECT {} FROM {} {}\"\n",
506
- " header = df.columns.tolist()\n",
507
- " if len(header) > len(set(header)):\n",
508
- " req = []\n",
509
- " have_req = set([])\n",
510
- " idx = 0\n",
511
- " for h in header:\n",
512
- " if h in have_req:\n",
513
- " idx += 1\n",
514
- " req.append(\"{}_{}\".format(h, idx))\n",
515
- " else:\n",
516
- " req.append(h)\n",
517
- " have_req.add(h)\n",
518
- " header = req\n",
519
- " def format_right(val):\n",
520
- " val = str(val)\n",
521
- " is_string = True\n",
522
- " try:\n",
523
- " literal_eval(val)\n",
524
- " is_string = False\n",
525
- " except:\n",
526
- " pass\n",
527
- " if is_string:\n",
528
- " return \"'{}'\".format(val)\n",
529
- " else:\n",
530
- " return val\n",
531
- " #ic(question_column, header)\n",
532
- " assert question_column in header\n",
533
- " assert all(map(lambda t3: t3[0] in header, total_conds_filtered))\n",
534
- " assert len(header) == len(set(header))\n",
535
- " index_header_mapping = dict(enumerate(header))\n",
536
- " header_index_mapping = dict(map(lambda t2: (t2[1], t2[0]) ,index_header_mapping.items()))\n",
537
- " assert len(index_header_mapping) == len(header_index_mapping)\n",
538
- " df_saved = df.copy()\n",
539
- " df_saved.columns = list(map(lambda idx: \"col_{}\".format(idx), range(len(header))))\n",
540
- " df_saved.to_sql(\"Mem_Table\", conn, if_exists = \"replace\", index = False)\n",
541
- " question_column_idx = header.index(question_column)\n",
542
- " sql_question_column = \"col_{}\".format(question_column_idx)\n",
543
- " sql_total_conds_filtered = list(map(lambda t3: (\"col_{}\".format(header.index(t3[0])), t3[1], format_right(t3[2])), total_conds_filtered))\n",
544
- " sql_agg_pred = agg_pred\n",
545
- " if sql_agg_pred.strip():\n",
546
- " sql_agg_pred = \"{}()\".format(sql_agg_pred)\n",
547
- " else:\n",
548
- " sql_agg_pred = \"()\"\n",
549
- " sql_agg_pred = sql_agg_pred.replace(\"()\", \"({})\")\n",
550
- " sql_conn_pred = conn_pred\n",
551
- " if sql_conn_pred.strip():\n",
552
- " pass\n",
553
- " else:\n",
554
- " sql_conn_pred = \"\"\n",
555
- " #sql_where_string = \"\" if not (sql_total_conds_filtered and sql_conn_pred) else \"WHERE {}\".format(\" {} \".format(sql_conn_pred).join(map(lambda t3: \"{} {} {}\".format(t3[0],\"=\" if t3[1] == \"==\" else t3[1], t3[2]), sql_total_conds_filtered)))\n",
556
- " sql_where_string = \"\" if not (sql_total_conds_filtered) else \"WHERE {}\".format(\" {} \".format(sql_conn_pred if sql_conn_pred else \"and\").join(map(lambda t3: \"{} {} {}\".format(t3[0],\"=\" if t3[1] == \"==\" else t3[1], t3[2]), sql_total_conds_filtered)))\n",
557
- " #ic(sql_total_conds_filtered, sql_conn_pred, sql_where_string, s)\n",
558
- " sql_query = sql_format.format(sql_agg_pred.format(sql_question_column), \"Mem_Table\", sql_where_string)\n",
559
- " cnt_sql_query = sql_format.format(\"COUNT(*)\", \"Mem_Table\", sql_where_string).strip()\n",
560
- " #ic(cnt_sql_query)\n",
561
- " cnt_num = pd.read_sql(cnt_sql_query, conn).values.reshape((-1,))[0]\n",
562
- " if cnt_num == 0:\n",
563
- " return {\n",
564
- " \"sql_query\": sql_query,\n",
565
- " \"cnt_num\": 0,\n",
566
- " \"conclusion\": []\n",
567
- " }\n",
568
- " query_conclusion_list = pd.read_sql(sql_query, conn).values.reshape((-1,)).tolist()\n",
569
- " return {\n",
570
- " \"sql_query\": sql_query,\n",
571
- " \"cnt_num\": cnt_num,\n",
572
- " \"conclusion\": query_conclusion_list\n",
573
- " }\n",
574
- "\n",
575
- "#save_conn = sqlite3.connect(\":memory:\")\n",
576
- "def single_table_pred(question, pd_df):\n",
577
- " assert type(question) == type(\"\")\n",
578
- " assert isinstance(pd_df, pd.DataFrame)\n",
579
- " qs_df = pd.DataFrame([[question]], columns = [\"question\"])\n",
580
- "\n",
581
- " #print(\"pd_df :\")\n",
582
- " #print(pd_df)\n",
583
- "\n",
584
- " tableqa_df = full_before_cat_decomp(pd_df, qs_df, only_req_columns=False)\n",
585
- "\n",
586
- " #print(\"tableqa_df :\")\n",
587
- " #print(tableqa_df)\n",
588
- "\n",
589
- " assert tableqa_df.shape[0] == 1\n",
590
- " #sql_query_dict = run_sql_query(tableqa_df.iloc[0], pd_df, save_conn)\n",
591
- " sql_query_dict = run_sql_query(tableqa_df.iloc[0], pd_df)\n",
592
- " return sql_query_dict\n"
593
- ]
594
- },
595
- {
596
- "cell_type": "code",
597
- "execution_count": 22,
598
- "id": "a85ca3ee-3d78-4605-aad1-17e29f557c77",
599
- "metadata": {},
600
- "outputs": [],
601
- "source": [
602
- "szse_summary_df = pd.read_csv(os.path.join(main_path ,\"data/df1.csv\"))"
603
- ]
604
- },
605
- {
606
- "cell_type": "code",
607
- "execution_count": 25,
608
- "id": "090a2750-5c6c-454f-bb0d-d9eabfb00721",
609
- "metadata": {},
610
- "outputs": [],
611
- "source": [
612
- "data = {\n",
613
- " \"tqa_question\": \"EPS大于0且周涨跌大于5的平均市值是多少?\",\n",
614
- " \"tqa_header\": szse_summary_df.columns.tolist(),\n",
615
- " \"tqa_rows\": szse_summary_df.values.tolist(),\n",
616
- " \"tqa_data_path\": os.path.join(main_path ,\"data/df1.csv\"),\n",
617
- " \"tqa_answer\": {\n",
618
- " \"sql_query\": \"SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5\",\n",
619
- " \"cnt_num\": 2,\n",
620
- " \"conclusion\": [57.645]\n",
621
- " }\n",
622
- "}"
623
- ]
624
- },
625
- {
626
- "cell_type": "code",
627
- "execution_count": 28,
628
- "id": "b9bbaf84-53f5-4fac-82f9-3cdb2a6a355a",
629
- "metadata": {},
630
- "outputs": [],
631
- "source": [
632
- "pd_df = pd.DataFrame(data[\"tqa_rows\"], columns = data[\"tqa_header\"])\n",
633
- "question = data[\"tqa_question\"]"
634
- ]
635
- },
636
- {
637
- "cell_type": "code",
638
- "execution_count": 29,
639
- "id": "92fd092e-58ec-4951-850d-dc1cba07adea",
640
- "metadata": {},
641
- "outputs": [
642
- {
643
- "data": {
644
- "text/plain": [
645
- "{'sql_query': 'SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5',\n",
646
- " 'cnt_num': 2,\n",
647
- " 'conclusion': [57.645]}"
648
- ]
649
- },
650
- "execution_count": 29,
651
- "metadata": {},
652
- "output_type": "execute_result"
653
- }
654
- ],
655
- "source": [
656
- "single_table_pred(question, pd_df)"
657
- ]
658
- },
659
- {
660
- "cell_type": "code",
661
- "execution_count": null,
662
- "id": "27cfe48f-49df-4e95-b896-6ad499851920",
663
- "metadata": {},
664
- "outputs": [],
665
- "source": []
666
- },
667
- {
668
- "cell_type": "code",
669
- "execution_count": null,
670
- "id": "89cddb81-0b82-4f30-b553-60d1c9a9b007",
671
- "metadata": {},
672
- "outputs": [],
673
- "source": []
674
- },
675
- {
676
- "cell_type": "code",
677
- "execution_count": 8,
678
- "id": "892fa614-7b9e-440e-bac1-7b4c68641fbb",
679
- "metadata": {},
680
- "outputs": [],
681
- "source": [
682
- "data = {\n",
683
- " \"question\": \"翔的出生地是什么?\",\n",
684
- " \"table_header\": json.dumps(\n",
685
- " [\"姓名\", \"年龄\", \"出生地\"]\n",
686
- " ),\n",
687
- " \"table_rows\": json.dumps(\n",
688
- " [\n",
689
- " [\"王翔\", 31, \"宁波\"],\n",
690
- " [\"王雨\", 12, \"翔雨宾馆\"]\n",
691
- " ]\n",
692
- " )\n",
693
- " }"
694
- ]
695
- },
696
- {
697
- "cell_type": "code",
698
- "execution_count": 20,
699
- "id": "261de835-fe0b-4c4f-ac0b-70d0f2026f4d",
700
- "metadata": {},
701
- "outputs": [],
702
- "source": [
703
- "question = data[\"question\"]\n",
704
- "question = \"年龄是王翔的是什么?\"\n",
705
- "pd_df = pd.DataFrame(json.loads(data[\"table_rows\"]), columns = json.loads(data[\"table_header\"]))"
706
- ]
707
- },
708
- {
709
- "cell_type": "code",
710
- "execution_count": 21,
711
- "id": "fe1cb949-049f-47f1-97cc-37f34a369bc6",
712
- "metadata": {},
713
- "outputs": [
714
- {
715
- "data": {
716
- "text/plain": [
717
- "{'sql_query': 'SELECT (col_1) FROM Mem_Table ',\n",
718
- " 'cnt_num': 2,\n",
719
- " 'conclusion': [31, 12]}"
720
- ]
721
- },
722
- "execution_count": 21,
723
- "metadata": {},
724
- "output_type": "execute_result"
725
- }
726
- ],
727
- "source": [
728
- "single_table_pred(question, pd_df)"
729
- ]
730
- },
731
- {
732
- "cell_type": "code",
733
- "execution_count": null,
734
- "id": "f4f42ab2-6ab7-405a-b4f0-2e925878ed77",
735
- "metadata": {},
736
- "outputs": [],
737
- "source": []
738
- }
739
- ],
740
- "metadata": {
741
- "kernelspec": {
742
- "display_name": "Python 3 (ipykernel)",
743
- "language": "python",
744
- "name": "python3"
745
- },
746
- "language_info": {
747
- "codemirror_mode": {
748
- "name": "ipython",
749
- "version": 3
750
- },
751
- "file_extension": ".py",
752
- "mimetype": "text/x-python",
753
- "name": "python",
754
- "nbconvert_exporter": "python",
755
- "pygments_lexer": "ipython3",
756
- "version": "3.7.10"
757
- }
758
- },
759
- "nbformat": 4,
760
- "nbformat_minor": 5
761
- }