svjack committed
Commit b5dbcf3 • 1 Parent(s): 22834b4

Upload . with huggingface_hub

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
  2. .ipynb_checkpoints/requirements-checkpoint.txt +20 -0
  3. .ipynb_checkpoints/run-checkpoint.py +127 -0
  4. .ipynb_checkpoints/tableQA_single_table-checkpoint.py +1295 -0
  5. JointBERT-master/.gitignore +114 -0
  6. JointBERT-master/LICENSE +201 -0
  7. JointBERT-master/README.md +96 -0
  8. JointBERT-master/__pycache__/data_loader.cpython-37.pyc +0 -0
  9. JointBERT-master/__pycache__/data_loader.cpython-38.pyc +0 -0
  10. JointBERT-master/__pycache__/main.cpython-37.pyc +0 -0
  11. JointBERT-master/__pycache__/main.cpython-38.pyc +0 -0
  12. JointBERT-master/__pycache__/predict.cpython-37.pyc +0 -0
  13. JointBERT-master/__pycache__/predict.cpython-38.pyc +0 -0
  14. JointBERT-master/__pycache__/trainer.cpython-37.pyc +0 -0
  15. JointBERT-master/__pycache__/trainer.cpython-38.pyc +0 -0
  16. JointBERT-master/__pycache__/utils.cpython-37.pyc +0 -0
  17. JointBERT-master/__pycache__/utils.cpython-38.pyc +0 -0
  18. JointBERT-master/data/atis/dev/label +500 -0
  19. JointBERT-master/data/atis/dev/seq.in +500 -0
  20. JointBERT-master/data/atis/dev/seq.out +500 -0
  21. JointBERT-master/data/atis/intent_label.txt +22 -0
  22. JointBERT-master/data/atis/slot_label.txt +122 -0
  23. JointBERT-master/data/atis/test/label +893 -0
  24. JointBERT-master/data/atis/test/seq.in +893 -0
  25. JointBERT-master/data/atis/test/seq.out +893 -0
  26. JointBERT-master/data/atis/train/label +4478 -0
  27. JointBERT-master/data/atis/train/seq.in +0 -0
  28. JointBERT-master/data/atis/train/seq.out +0 -0
  29. JointBERT-master/data/snips/dev/label +700 -0
  30. JointBERT-master/data/snips/dev/seq.in +700 -0
  31. JointBERT-master/data/snips/dev/seq.out +700 -0
  32. JointBERT-master/data/snips/intent_label.txt +8 -0
  33. JointBERT-master/data/snips/slot_label.txt +74 -0
  34. JointBERT-master/data/snips/test/label +700 -0
  35. JointBERT-master/data/snips/test/seq.in +700 -0
  36. JointBERT-master/data/snips/test/seq.out +700 -0
  37. JointBERT-master/data/snips/train/label +0 -0
  38. JointBERT-master/data/snips/train/seq.in +0 -0
  39. JointBERT-master/data/snips/train/seq.out +0 -0
  40. JointBERT-master/data/vocab_process.py +48 -0
  41. JointBERT-master/data_loader.py +255 -0
  42. JointBERT-master/main.py +72 -0
  43. JointBERT-master/model/__init__.py +3 -0
  44. JointBERT-master/model/__pycache__/__init__.cpython-37.pyc +0 -0
  45. JointBERT-master/model/__pycache__/__init__.cpython-38.pyc +0 -0
  46. JointBERT-master/model/__pycache__/modeling_jointalbert.cpython-37.pyc +0 -0
  47. JointBERT-master/model/__pycache__/modeling_jointalbert.cpython-38.pyc +0 -0
  48. JointBERT-master/model/__pycache__/modeling_jointbert.cpython-37.pyc +0 -0
  49. JointBERT-master/model/__pycache__/modeling_jointbert.cpython-38.pyc +0 -0
  50. JointBERT-master/model/__pycache__/modeling_jointdistilbert.cpython-37.pyc +0 -0
.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
+ {
+ "cells": [],
+ "metadata": {},
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
.ipynb_checkpoints/requirements-checkpoint.txt ADDED
@@ -0,0 +1,20 @@
+ torch==1.6.0
+ #dask[dataframe]
+ #dask[distributed]
+ #keybert
+ #bertopic
+ jieba
+ seaborn
+ sqlite_utils
+ sqlitefts
+ icecream
+ protobuf
+ #snorkel
+ pyarrow
+ transformers==3.0.2
+ seqeval==0.0.12
+ pytorch-crf==0.7.2
+ rank_bm25
+ nltk
+ gradio
+
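These pins target the legacy stack that JointBERT-master and the tableQA code expect (PyTorch 1.6, transformers 3.0.2). A small sanity-check sketch, assuming an environment built from this file (illustrative only, not part of the commit):

import torch
import transformers

# Fail fast if a different stack was installed than the one pinned above.
assert torch.__version__.startswith("1.6"), torch.__version__
assert transformers.__version__ == "3.0.2", transformers.__version__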
.ipynb_checkpoints/run-checkpoint.py ADDED
@@ -0,0 +1,127 @@
+ from tableQA_single_table import *
+
+ import json
+ import os
+ import sys
+
+ def run_sql_query(s, df):
+     conn = sqlite3.connect(":memory:")
+
+     assert isinstance(df, pd.DataFrame)
+     question_column = s.question_column
+     if question_column is None:
+         return {
+             "sql_query": "",
+             "cnt_num": 0,
+             "conclusion": []
+         }
+     total_conds_filtered = s.total_conds_filtered
+     agg_pred = s.agg_pred
+     conn_pred = s.conn_pred
+     sql_format = "SELECT {} FROM {} {}"
+     header = df.columns.tolist()
+     if len(header) > len(set(header)):
+         req = []
+         have_req = set([])
+         idx = 0
+         for h in header:
+             if h in have_req:
+                 idx += 1
+                 req.append("{}_{}".format(h, idx))
+             else:
+                 req.append(h)
+                 have_req.add(h)
+         header = req
+     def format_right(val):
+         val = str(val)
+         is_string = True
+         try:
+             literal_eval(val)
+             is_string = False
+         except:
+             pass
+         if is_string:
+             return "'{}'".format(val)
+         else:
+             return val
+     #ic(question_column, header)
+     assert question_column in header
+     assert all(map(lambda t3: t3[0] in header, total_conds_filtered))
+     assert len(header) == len(set(header))
+     index_header_mapping = dict(enumerate(header))
+     header_index_mapping = dict(map(lambda t2: (t2[1], t2[0]) ,index_header_mapping.items()))
+     assert len(index_header_mapping) == len(header_index_mapping)
+     df_saved = df.copy()
+     df_saved.columns = list(map(lambda idx: "col_{}".format(idx), range(len(header))))
+     df_saved.to_sql("Mem_Table", conn, if_exists = "replace", index = False)
+     question_column_idx = header.index(question_column)
+     sql_question_column = "col_{}".format(question_column_idx)
+     sql_total_conds_filtered = list(map(lambda t3: ("col_{}".format(header.index(t3[0])), t3[1], format_right(t3[2])), total_conds_filtered))
+     sql_agg_pred = agg_pred
+     if sql_agg_pred.strip():
+         sql_agg_pred = "{}()".format(sql_agg_pred)
+     else:
+         sql_agg_pred = "()"
+     sql_agg_pred = sql_agg_pred.replace("()", "({})")
+     sql_conn_pred = conn_pred
+     if sql_conn_pred.strip():
+         pass
+     else:
+         sql_conn_pred = ""
+     #sql_where_string = "" if not (sql_total_conds_filtered and sql_conn_pred) else "WHERE {}".format(" {} ".format(sql_conn_pred).join(map(lambda t3: "{} {} {}".format(t3[0],"=" if t3[1] == "==" else t3[1], t3[2]), sql_total_conds_filtered)))
+     sql_where_string = "" if not (sql_total_conds_filtered) else "WHERE {}".format(" {} ".format(sql_conn_pred if sql_conn_pred else "and").join(map(lambda t3: "{} {} {}".format(t3[0],"=" if t3[1] == "==" else t3[1], t3[2]), sql_total_conds_filtered)))
+     #ic(sql_total_conds_filtered, sql_conn_pred, sql_where_string, s)
+     sql_query = sql_format.format(sql_agg_pred.format(sql_question_column), "Mem_Table", sql_where_string)
+     cnt_sql_query = sql_format.format("COUNT(*)", "Mem_Table", sql_where_string).strip()
+     #ic(cnt_sql_query)
+     cnt_num = pd.read_sql(cnt_sql_query, conn).values.reshape((-1,))[0]
+     if cnt_num == 0:
+         return {
+             "sql_query": sql_query,
+             "cnt_num": 0,
+             "conclusion": []
+         }
+     query_conclusion_list = pd.read_sql(sql_query, conn).values.reshape((-1,)).tolist()
+     return {
+         "sql_query": sql_query,
+         "cnt_num": cnt_num,
+         "conclusion": query_conclusion_list
+     }
+
+ #save_conn = sqlite3.connect(":memory:")
+ def single_table_pred(question, pd_df):
+     assert type(question) == type("")
+     assert isinstance(pd_df, pd.DataFrame)
+     qs_df = pd.DataFrame([[question]], columns = ["question"])
+
+     #print("pd_df :")
+     #print(pd_df)
+
+     tableqa_df = full_before_cat_decomp(pd_df, qs_df, only_req_columns=False)
+
+     #print("tableqa_df :")
+     #print(tableqa_df)
+
+     assert tableqa_df.shape[0] == 1
+     #sql_query_dict = run_sql_query(tableqa_df.iloc[0], pd_df, save_conn)
+     sql_query_dict = run_sql_query(tableqa_df.iloc[0], pd_df)
+     return sql_query_dict
+
+
+ if __name__ == "__main__":
+     szse_summary_df = pd.read_csv(os.path.join(main_path ,"data/df1.csv"))
+     data = {
+         "tqa_question": "EPS大于0且周涨跌大于5的平均市值是多少?",
+         "tqa_header": szse_summary_df.columns.tolist(),
+         "tqa_rows": szse_summary_df.values.tolist(),
+         "tqa_data_path": os.path.join(main_path ,"data/df1.csv"),
+         "tqa_answer": {
+             "sql_query": "SELECT AVG(col_4) FROM Mem_Table WHERE col_5 > 0 and col_3 > 5",
+             "cnt_num": 2,
+             "conclusion": [57.645]
+         }
+     }
+
+     pd_df = pd.DataFrame(data["tqa_rows"], columns = data["tqa_header"])
+     question = data["tqa_question"]
+     single_table_pred(question, pd_df)
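For orientation: run_sql_query materializes the table in SQLite as Mem_Table with columns renamed col_0..col_N and fills in the SELECT/WHERE slots predicted for a single question. A minimal usage sketch, assuming the definitions above have been loaded and the JointBERT weights referenced by tableQA_single_table are available under data/bert; the toy DataFrame and the outputs in the comments are illustrative only, not from the commit:

import pandas as pd

# Toy table; the committed example reads data/df1.csv instead.
df = pd.DataFrame({
    "名称": ["A", "B", "C"],
    "EPS": [0.5, -0.2, 1.1],
    "周涨跌": [6.0, 2.0, 8.0],
    "市值": [55.0, 30.0, 60.0],
})
result = single_table_pred("EPS大于0且周涨跌大于5的平均市值是多少?", df)
print(result["sql_query"])   # e.g. SELECT AVG(col_3) FROM Mem_Table WHERE col_1 > 0 and col_2 > 5
print(result["cnt_num"], result["conclusion"])  # e.g. 2 [57.5]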
.ipynb_checkpoints/tableQA_single_table-checkpoint.py ADDED
@@ -0,0 +1,1295 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+ #### env base_cp
4
+
5
+ #main_path = "/Users/svjack/temp/gradio_prj/tableQA-Chinese-main"
6
+ #main_path = "/User/tableQA-Chinese-main"
7
+ #main_path = "/temp/tableQA-Chinese-main"
8
+ main_path = "."
9
+
10
+ import pandas as pd
11
+ import numpy as np
12
+ import os
13
+ import ast
14
+ import re
15
+ import json
16
+ from icecream import ic
17
+ from copy import deepcopy
18
+ from itertools import product, combinations
19
+
20
+
21
+ import pandas as pd
22
+ import os
23
+ import sys
24
+ from pyarrow.filesystem import LocalFileSystem
25
+ from functools import reduce
26
+ import nltk
27
+ from nltk import pos_tag, word_tokenize
28
+ from collections import namedtuple
29
+ from ast import literal_eval
30
+
31
+ from torch.nn import functional
32
+ import numpy as np
33
+ import torch
34
+ from torch import nn
35
+ from torch.nn import init
36
+ from torch.nn.utils import rnn as rnn_utils
37
+ import math
38
+
39
+ from icecream import ic
40
+ import seaborn as sns
41
+
42
+ import matplotlib.pyplot as plt
43
+
44
+ import shutil
45
+
46
+ #from keybert import KeyBERT
47
+ #from bertopic import BERTopic
48
+
49
+
50
+ import sqlite3
51
+ import sqlite_utils
52
+ from icecream import ic
53
+ import jieba
54
+ import pandas as pd
55
+ import urllib.request
56
+ from urllib.parse import quote
57
+ from time import sleep
58
+ import json
59
+ import os
60
+ from collections import defaultdict
61
+ import re
62
+ from functools import reduce, partial
63
+
64
+ #### used in this condition extract in training.
65
+ op_sql_dict = {0:">", 1:"<", 2:"==", 3:"!="}
66
+ #### used by clf for intension inference
67
+ agg_sql_dict = {0:"", 1:"AVG", 2:"MAX", 3:"MIN", 4:"COUNT", 5:"SUM"}
68
+ #### final to combine them (one for 0, and multi for 1 2)
69
+ conn_sql_dict = {0:"", 1:"and", 2:"or"}
70
+
71
+ #### kws and time pattern defination
72
+ and_kws = ("且", "而且", "并且", "和", "当中", "同时")
73
+ or_kws = ("或", "或者",)
74
+ conn_kws = and_kws + or_kws
75
+
76
+ pattern_list = [u"[年月\.\-\d]+", u"[年月\d]+", u"[年个月\d]+", u"[年月日\d]+"]
77
+
78
+ time_kws = ("什么时候", "时间", "时候")
79
+
80
+ sum_count_high_kws = ('多少个', '有几个', '总共') + ('总和','一共',) + ("总数",)
81
+ mean_kws = ('平均数', '均值', '平均值', '平均')
82
+ max_kws = ('最大', '最多', '最大值', '最高')
83
+ min_kws = ('最少', '最小值', '最小', '最低')
84
+ sum_count_low_kws = ('个', '总共') + ('总和','加','总','一共','和',) + ("哪些", "查", "数量", "数") + ("几",) + ('多少', "多大") + ("总数",)
85
+ max_special_kws = ("以上", "大于")
86
+ min_special_kws = ("以下", "小于")
87
+
88
+ qst_kws = ("多少", "什么", "多大", "哪些", "怎么", "情况", "那些", "哪个")
89
+
90
+ only_kws_columns = {"城市": "=="}
91
+
92
+ ##### jointbert predict model init start
93
+ #jointbert_path = "../../featurize/JointBERT"
94
+ #jointbert_path = "/Users/svjack/temp/gradio_prj/tableQA-Chinese-main/JointBERT-master"
95
+ jointbert_path = os.path.join(main_path, "JointBERT-master")
96
+ sys.path.append(jointbert_path)
97
+
98
+
99
+ from model.modeling_jointbert import JointBERT
100
+ from model.modeling_jointbert import *
101
+ from trainer import *
102
+ from main import *
103
+ from data_loader import *
104
+
105
+
106
+ pred_parser = argparse.ArgumentParser()
107
+
108
+ pred_parser.add_argument("--input_file", default="conds_pred/seq.in", type=str, help="Input file for prediction")
109
+ pred_parser.add_argument("--output_file", default="conds_pred/sample_pred_out.txt", type=str, help="Output file for prediction")
110
+ #pred_parser.add_argument("--model_dir", default="bert", type=str, help="Path to save, load model")
111
+ pred_parser.add_argument("--model_dir", default= os.path.join(main_path ,"data/bert"), type=str, help="Path to save, load model")
112
+ #pred_parser.add_argument("--model_dir", default= os.path.join(main_path ,"JBert_Zh_Condition_Extractor"), type=str, help="Path to save, load model")
113
+
114
+
115
+ pred_parser.add_argument("--batch_size", default=32, type=int, help="Batch size for prediction")
116
+ pred_parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
117
+
118
+
119
+ pred_parser_config_dict = dict(map(lambda item:(item.option_strings[0].replace("--", ""), item.default) ,pred_parser.__dict__["_actions"]))
120
+ pred_parser_config_dict = dict(filter(lambda t2: t2[0] != "-h", pred_parser_config_dict.items()))
121
+
122
+ pred_parser_namedtuple = namedtuple("pred_parser_config", pred_parser_config_dict.keys())
123
+ for k, v in pred_parser_config_dict.items():
124
+ if type(v) == type(""):
125
+ exec("pred_parser_namedtuple.{}='{}'".format(k, v))
126
+ else:
127
+ exec("pred_parser_namedtuple.{}={}".format(k, v))
128
+
129
+
130
+ from predict import *
131
+
132
+
133
+ pred_config = pred_parser_namedtuple
134
+ args = get_args(pred_config)
135
+ device = get_device(pred_config)
136
+
137
+ args_parser_namedtuple = namedtuple("args_config", args.keys())
138
+ for k, v in args.items():
139
+ if type(v) == type(""):
140
+ exec("args_parser_namedtuple.{}='{}'".format(k, v))
141
+ else:
142
+ exec("args_parser_namedtuple.{}={}".format(k, v))
143
+
144
+
145
+ args = args_parser_namedtuple
146
+
147
+ #args.data_dir = "/Users/svjack/temp/gradio_prj/tableQA-Chinese-main/data"
148
+ args.data_dir = os.path.join(main_path, "data")
149
+
150
+ '''
151
+ pred_model = MODEL_CLASSES["bert"][1].from_pretrained(args.model_dir,
152
+ args=args,
153
+ intent_label_lst=get_intent_labels(args),
154
+ slot_label_lst=get_slot_labels(args))
155
+ '''
156
+ pred_model = MODEL_CLASSES["bert"][1].from_pretrained(
157
+ os.path.join(main_path, "data/bert")
158
+ ,
159
+ args=args,
160
+ intent_label_lst=get_intent_labels(args),
161
+ slot_label_lst=get_slot_labels(args))
162
+
163
+ pred_model.to(device)
164
+ pred_model.eval()
165
+
166
+ intent_label_lst = get_intent_labels(args)
167
+ slot_label_lst = get_slot_labels(args)
168
+ pad_token_label_id = args.ignore_index
169
+ tokenizer = load_tokenizer(args)
170
+ ## jointbert predict model init end
171
+
172
+
173
+ ###### one sent conds decomp start
174
+ def predict_single_sent(question):
175
+ text = " ".join(list(question))
176
+ batch = convert_input_file_to_tensor_dataset([text.split(" ")], pred_config, args, tokenizer, pad_token_label_id).tensors
177
+ batch = tuple(t.to(device) for t in batch)
178
+ inputs = {"input_ids": batch[0],
179
+ "attention_mask": batch[1],
180
+ "intent_label_ids": None,
181
+ "slot_labels_ids": None}
182
+ inputs["token_type_ids"] = batch[2]
183
+ outputs = pred_model(**inputs)
184
+ _, (intent_logits, slot_logits) = outputs[:2]
185
+ intent_preds = intent_logits.detach().cpu().numpy()
186
+ slot_preds = slot_logits.detach().cpu().numpy()
187
+ intent_preds = np.argmax(intent_preds, axis=1)
188
+ slot_preds = np.argmax(slot_preds, axis=2)
189
+ all_slot_label_mask = batch[3].detach().cpu().numpy()
190
+ slot_label_map = {i: label for i, label in enumerate(slot_label_lst)}
191
+ slot_preds_list = [[] for _ in range(slot_preds.shape[0])]
192
+ for i in range(slot_preds.shape[0]):
193
+ for j in range(slot_preds.shape[1]):
194
+ if all_slot_label_mask[i, j] != pad_token_label_id:
195
+ slot_preds_list[i].append(slot_label_map[slot_preds[i][j]])
196
+ pred_l = []
197
+ for words, slot_preds, intent_pred in zip([text.split(" ")], slot_preds_list, intent_preds):
198
+ line = ""
199
+ for word, pred in zip(words, slot_preds):
200
+ if pred == 'O':
201
+ line = line + word + " "
202
+ else:
203
+ line = line + "[{}:{}] ".format(word, pred)
204
+ pred_l.append((line, intent_label_lst[intent_pred]))
205
+ return pred_l[0]
206
+
207
+
208
+ ###@@ conn_kws = ["且", "或", "或者", "和"]
209
+ '''
210
+ and_kws = ("且", "而且", "并且", "和", "当中", "同时")
211
+ or_kws = ("或", "或者",)
212
+ conn_kws = and_kws + or_kws
213
+ '''
214
+ #conn_kws = ("且", "或", "或者", "和") + ("而且", "并且", "当中")
215
+ #### some algorithm use in it.
216
+ def recurrent_extract(question):
217
+ def filter_relation(text):
218
+ #kws = ["且", "或", "或者", "和"]
219
+ kws = conn_kws
220
+ req = text
221
+ for kw in sorted(kws, key= lambda x: len(x))[::-1]:
222
+ req = req.replace(kw, "")
223
+ return req
224
+ def produce_plain_text(text):
225
+ ##### replace tag string from text
226
+ kws = ["[", "]", " ", ":B-HEADER", ":I-HEADER", ":B-VALUE", ":I-VALUE"]
227
+ plain_text = text
228
+ for kw in kws:
229
+ plain_text = plain_text.replace(kw, "")
230
+ return plain_text
231
+ def find_min_commmon_strings(c):
232
+ ##### {"jack", "ja", "ss", "sss", "ps", ""} -> {"ja", "ss", "ps"}
233
+ common_strings = list(filter(lambda x: type(x) == type("") ,
234
+ map(lambda t2: t2[0]
235
+ if t2[0] in t2[1]
236
+ else (t2[1]
237
+ if t2[1] in t2[0]
238
+ else (t2[0], t2[1])),combinations(c, 2))))
239
+ req = set([])
240
+ while c:
241
+ ele = c.pop()
242
+ if all(map(lambda cc: cc not in ele, common_strings)):
243
+ req.add(ele)
244
+ req = req.union(set(common_strings))
245
+ return set(filter(lambda x: x, req))
246
+ def extract_scope(scope_text):
247
+ def find_max_in(plain_text ,b_chars, i_chars):
248
+ chars = "".join(b_chars + i_chars)
249
+ while chars and chars not in plain_text:
250
+ chars = chars[:-1]
251
+ return chars
252
+ b_header_chars = re.findall(r"([\w\W]):B-HEADER", scope_text)
253
+ i_header_chars = re.findall(r"([\w\W]):I-HEADER", scope_text)
254
+ b_value_chars = re.findall(r"([\w\W]):B-VALUE", scope_text)
255
+ i_value_chars = re.findall(r"([\w\W]):I-VALUE", scope_text)
256
+ if len(b_header_chars) != 1 or len(b_value_chars) != 1:
257
+ return None
258
+ plain_text = produce_plain_text(scope_text)
259
+ header = find_max_in(plain_text, b_header_chars, i_header_chars)
260
+ value = find_max_in(plain_text, b_value_chars, i_value_chars)
261
+ if (not header) or (not value):
262
+ return None
263
+ return (header, value)
264
+ def find_scope(text):
265
+ start_index = text.find("[")
266
+ end_index = text.rfind("]")
267
+ if start_index == -1 or end_index == -1:
268
+ return text
269
+ scope_text = text[start_index: end_index + 1]
270
+ res_text = filter_relation(text.replace(scope_text, "")).replace(" ", "").strip()
271
+ return (scope_text, res_text)
272
+ def produce_all_attribute_remove(req):
273
+ if not req:
274
+ return None
275
+ string_or_t2 = find_scope(req[-1][0])
276
+ assert type(string_or_t2) in [type(""), type((1,))]
277
+ if type(string_or_t2) == type(""):
278
+ return string_or_t2
279
+ else:
280
+ return string_or_t2[-1]
281
+ def extract_all_attribute(req):
282
+ extract_list = list(map(lambda t2: (t2[0][0], t2[1], t2[0][1]) ,
283
+ filter(lambda x: x[0] ,
284
+ map(lambda tt2_t2: (extract_scope(tt2_t2[0][0]), tt2_t2[1]) ,
285
+ filter(lambda t2_t2: "HEADER" in t2_t2[0][0] and "VALUE" in t2_t2[0][0] ,
286
+ filter(lambda string_or_t2_t2: type(string_or_t2_t2[0]) == type((1,)),
287
+ map(lambda tttt2: (find_scope(tttt2[0]), tttt2[1]),
288
+ req)))))))
289
+ return extract_list
290
+ def extract_attributes_relation_string(plain_text, all_attributes, res):
291
+ if not all_attributes:
292
+ return plain_text.replace(res if res else "", "")
293
+ def replace_by_one_l_r(text ,t3):
294
+ l, _, r = t3
295
+ ##### produce multi l, r to satisfy string contrain problem
296
+ l0, l1 = l, l
297
+ r0, r1 = r, r
298
+ while l0 and l0 not in text:
299
+ l0 = l0[:-1]
300
+ while l1 and l1 not in text:
301
+ l1 = l1[1:]
302
+ while r0 and r0 not in text:
303
+ r0 = r0[:-1]
304
+ while r1 and r1 not in text:
305
+ r1 = r1[1:]
306
+ if not l or not r:
307
+ return text
308
+
309
+ conclusion = set([])
310
+ for l_, r_ in product([l0, l1], [r0, r1]):
311
+ l_r_conclusion = re.findall("({}.*?{})".format(l_, r_), text)
312
+ r_l_conclusion = re.findall("({}.*?{})".format(r_, l_), text)
313
+ conclusion = conclusion.union(set(l_r_conclusion + r_l_conclusion))
314
+
315
+ ##### because use produce multi must choose the shortest elements from them
316
+ ## to prevent "relation word" also be replaced.
317
+ conclusion_filtered = find_min_commmon_strings(conclusion)
318
+
319
+ conclusion = conclusion_filtered
320
+ req_text = text
321
+ for c in conclusion:
322
+ req_text = req_text.replace(c, "")
323
+ return req_text
324
+ req_text_ = plain_text
325
+ for t3 in all_attributes:
326
+ req_text_ = replace_by_one_l_r(req_text_, t3)
327
+ return req_text_.replace(res, "")
328
+ req = []
329
+ t2 = predict_single_sent(question)
330
+ req.append(t2)
331
+ while "[" in t2[0]:
332
+ scope = find_scope(t2[0])
333
+ if type(scope) == type(""):
334
+ break
335
+ else:
336
+ assert type(scope) == type((1,))
337
+ scope_text, res_text = scope
338
+ #ic(req)
339
+ t2 = predict_single_sent(res_text)
340
+ req.append(t2)
341
+ req = list(filter(lambda tt2: "HEADER" in tt2[0] and "VALUE" in tt2[0] , req))
342
+ res = produce_all_attribute_remove(req)
343
+ #ic(req)
344
+ all_attributes = extract_all_attribute(req)
345
+ # plain_text = produce_plain_text(scope_text)
346
+
347
+ return all_attributes, res, extract_attributes_relation_string(produce_plain_text(req[0][0] if req else ""), all_attributes, res)
348
+
349
+
350
+ def rec_more_time(decomp):
351
+ assert type(decomp) == type((1,)) and len(decomp) == 3
352
+ assert not decomp[0]
353
+ res, relation_string = decomp[1:]
354
+ new_decomp = recurrent_extract(relation_string)
355
+ #### stop if rec not help by new_decomp[1] != decomp[1]
356
+ if not new_decomp[0] and new_decomp[1] != decomp[1]:
357
+ return rec_more_time(new_decomp)
358
+ return (new_decomp[0], res, new_decomp[1])
359
+ ### one sent conds decomp end
360
+
361
+
362
+ ##### data source start
363
+ #train_path = "../TableQA/TableQA/train"
364
+ #train_path = "/Users/svjack/temp/gradio_prj/tableQA-Chinese-main/data/TableQA-master/train"
365
+ train_path = os.path.join(main_path, "data/TableQA-master/train")
366
+ def data_loader(table_json_path = os.path.join(train_path ,"train.tables.json"),
367
+ json_path = os.path.join(train_path ,"train.json"),
368
+ req_table_num = 1):
369
+ assert os.path.exists(table_json_path)
370
+ assert os.path.exists(json_path)
371
+ json_df = pd.read_json(json_path, lines = True)
372
+ all_tables = pd.read_json(table_json_path, lines = True)
373
+ if req_table_num is not None:
374
+ assert type(req_table_num) == type(0) and req_table_num > 0 and req_table_num <= all_tables.shape[0]
375
+ else:
376
+ req_table_num = all_tables.shape[0]
377
+ for i in range(req_table_num):
378
+ #one_table = all_tables.iloc[i]["table"]
379
+ #one_table_df = pd.read_sql("select * from `{}`".format(one_table), train_tables_dump_engine)
380
+ one_table_s = all_tables.iloc[i]
381
+ one_table_df = pd.DataFrame(one_table_s["rows"], columns = one_table_s["header"])
382
+ yield one_table_df, json_df[json_df["table_id"] == one_table_s["id"]]
383
+ ## data source end
384
+
385
+
386
+ ###### string toolkit start
387
+ def findMaxSubString(str1, str2):
388
+ """
389
+ """
390
+ maxSub = 0
391
+ maxSubString = ""
392
+
393
+ str1_len = len(str1)
394
+ str2_len = len(str2)
395
+
396
+ for i in range(str1_len):
397
+ str1_pos = i
398
+ for j in range(str2_len):
399
+ str2_pos = j
400
+ str1_pos = i
401
+ if str1[str1_pos] != str2[str2_pos]:
402
+ continue
403
+ else:
404
+ while (str1_pos < str1_len) and (str2_pos < str2_len):
405
+ if str1[str1_pos] == str2[str2_pos]:
406
+ str1_pos = str1_pos + 1
407
+ str2_pos = str2_pos + 1
408
+ else:
409
+ break
410
+
411
+ sub_len = str2_pos - j
412
+ if maxSub < sub_len:
413
+ maxSub = sub_len
414
+ maxSubString = str2[j:str2_pos]
415
+ return maxSubString
416
+
417
+
418
+ def find_min_commmon_strings(c):
419
+ ##### {"jack", "ja", "ss", "sss", "ps", ""} -> {"ja", "ss", "ps"}
420
+ common_strings = list(filter(lambda x: type(x) == type("") ,
421
+ map(lambda t2: t2[0]
422
+ if t2[0] in t2[1]
423
+ else (t2[1]
424
+ if t2[1] in t2[0]
425
+ else (t2[0], t2[1])),combinations(c, 2))))
426
+ req = set([])
427
+ while c:
428
+ ele = c.pop()
429
+ if all(map(lambda cc: cc not in ele, common_strings)):
430
+ req.add(ele)
431
+ req = req.union(set(common_strings))
432
+ return set(filter(lambda x: x, req))
433
+ ## string toolkit end
434
+
435
+
436
+
437
+ ###### datetime column match start
438
+ #### only use object dtype to extract
439
+ def time_template_extractor(rows_filtered, pattern = u"[年月\.\-\d]+"):
440
+ #re_words = re.compile(u"[年月\.\-\d]+")
441
+ re_words = re.compile(pattern)
442
+ nest_collection = pd.DataFrame(rows_filtered).applymap(lambda x: tuple(sorted(list(re.findall(re_words, x))))).values.tolist()
443
+ def flatten_collection(c):
444
+ if not c:
445
+ return c
446
+ if type(c[0]) == type(""):
447
+ return c
448
+ else:
449
+ c = list(c)
450
+ return flatten_collection(reduce(lambda a, b: a + b, map(list ,c)))
451
+ return flatten_collection(nest_collection)
452
+
453
+ ###@@ pattern_list
454
+ #pattern_list = [u"[年月\.\-\d]+", u"[年月\d]+", u"[年个月\d]+", u"[年月日\d]+"]
455
+
456
+ def justify_column_as_datetime(df, threshold = 0.8, time_template_extractor = lambda x: x):
457
+ object_columns = list(map(lambda tt2: tt2[0] ,filter(lambda t2: t2[1].name == "object" ,dict(df.dtypes).items())))
458
+ time_columns = []
459
+ for col in object_columns:
460
+ input_ = df[[col]].applymap(lambda x: "~" if type(x) != type("") else x)
461
+ output_ = time_template_extractor(input_.values.tolist())
462
+ input_ = input_.iloc[:, 0].values.tolist()
463
+ time_evidence_cnt = sum(map(lambda t2: t2[0].strip() == t2[1].strip() and t2[0] and t2[1] and t2[0] != "~" and t2[1] != "~",zip(input_, output_)))
464
+ if time_evidence_cnt > 0 and time_evidence_cnt / df.shape[0] >= threshold:
465
+ #### use evidence ratio because may have some noise in data
466
+ time_columns.append(col)
467
+ return time_columns
468
+
469
+ def justify_column_as_datetime_reduce(df, threshold = 0.8, time_template_extractor_list = list(map(lambda p: partial(time_template_extractor, pattern = p), pattern_list))):
470
+ return sorted(reduce(lambda a, b: a.union(b) ,map(lambda func: set(justify_column_as_datetime(df, threshold, func)), time_template_extractor_list)))
471
+ ## datetime column match end
472
+
473
+ ##### choose question column have a reduce function call below (choose_res_by_kws)
474
+ ##### this is a tiny first version
475
+ ###@@ time_kws = ("什么时候", "时间", "时候")
476
+ #time_kws = ("什么时候", "时间", "时候")
477
+ #####
478
+ def choose_question_column(decomp, header, df):
479
+ assert type(decomp) == type((1,)) and type(header) == type([])
480
+
481
+ time_columns = justify_column_as_datetime_reduce(df)
482
+ _, res, _ = decomp
483
+
484
+ if type(res) != type(""):
485
+ return None
486
+
487
+ #ic(res)
488
+ ##### should add time kws to it.
489
+ #time_kws = ("什么时候", "时间", "时候")
490
+ if any(map(lambda t_kw: t_kw in res, time_kws)):
491
+ if len(time_columns) == 1:
492
+ return time_columns[0]
493
+ else:
494
+ '''
495
+ return sorted(map(lambda t_col: (t_col ,len(findMaxSubString(t_col, res)) / len(t_col)), time_columns),
496
+ key = lambda t2: t2[1])[::-1][0][0]
497
+ '''
498
+ sort_list = sorted(map(lambda t_col: (t_col ,len(findMaxSubString(t_col, res)) / len(t_col)), time_columns),
499
+ key = lambda t2: t2[1])[::-1]
500
+ if sort_list:
501
+ if sort_list[0]:
502
+ return sort_list[0][0]
503
+ return None
504
+
505
+ c_res_common_dict = dict(filter(lambda t2: t2[1] ,map(lambda c: (c ,findMaxSubString(c, res)), header)))
506
+ common_ratio_c_dict = dict(map(lambda t2: (t2[0], len(t2[1]) / len(t2[0])), c_res_common_dict.items()))
507
+ common_ratio_res_dict = dict(map(lambda t2: (t2[0], len(t2[1]) / len(res)), c_res_common_dict.items()))
508
+ #ic(decomp)
509
+ #ic(common_ratio_c_dict)
510
+ #ic(common_ratio_res_dict)
511
+
512
+ if not common_ratio_c_dict or not common_ratio_res_dict:
513
+ return None
514
+
515
+ dict_0_max_key = sorted(common_ratio_c_dict.items(), key = lambda t2: t2[1])[::-1][0][0]
516
+ dict_1_max_key = sorted(common_ratio_res_dict.items(), key = lambda t2: t2[1])[::-1][0][0]
517
+ return dict_0_max_key if dict_0_max_key == dict_1_max_key else None
518
+
519
+
520
+ ##### agg-classifier start
521
+ '''
522
+ sum_count_high_kws = ('多少个', '有几个', '总共') + ('总和','一共',) + ("总数",)
523
+ mean_kws = ('平均数', '均值', '平均值', '平均')
524
+ max_kws = ('最大', '最多', '最大值', '最高')
525
+ min_kws = ('最少', '最小值', '最小', '最低')
526
+ sum_count_low_kws = ('个', '总共') + ('总和','加','总','一共','和',) + ("哪些", "查", "数量", "数") + ("几",) + ('多少', "多大") + ("总数",)
527
+ max_special_kws = ("以上", "大于")
528
+ min_special_kws = ("以下", "小于")
529
+ '''
530
+
531
+ ###@@ sum_count_high_kws = ('多少个', '有几个', '总共') + ('总和','一共',) + ("总数",)
532
+ ###@@ mean_kws = ('平均数', '均值', '平均值', '平均')
533
+ ###@@ max_kws = ('最大', '最多', '最大值', '最高')
534
+ ###@@ min_kws = ('最少', '最小值', '最小', '最低')
535
+ ###@@ sum_count_low_kws = ('个', '总共') + ('总和','加','总','一共','和',) + ("哪些", "查", "数量", "数") + ("几",) + ('多少', "多大") + ("总数",)
536
+ ###@@ max_special_kws = ("以上", "大于")
537
+ ###@@ min_special_kws = ("以下", "小于")
538
+
539
+ def simple_label_func(s, drop_header = True):
540
+ text_tokens =s.question_cut
541
+ header = list(map(lambda x: x[:x.find("(")] if (not x.startswith("(") and x.endswith(")")) else x ,s.header.split(",")))
542
+
543
+ #### not contain samples may not match in fuzzy-match, special column mapping in finance,
544
+ ### or "3" to "三"
545
+ '''
546
+ fit_collection = ('多少个', '有几个', '总共') + ('总和','一共',) + ('平均数', '均值', '平均值', '平均') + ('最大', '最多', '最大值', '最高') + ('最少', '最小值', '最小', '最低')
547
+
548
+ '''
549
+ fit_collection = sum_count_high_kws + mean_kws + max_kws + min_kws
550
+ fit_header = []
551
+ for c in header:
552
+ for kw in fit_collection:
553
+ if kw in c:
554
+ start_idx = c.find(kw)
555
+ end_idx = start_idx + len(kw)
556
+ fit_header.append(c[start_idx: end_idx])
557
+
558
+ if not drop_header:
559
+ header = []
560
+ fit_header = []
561
+
562
+ input_ = "".join(text_tokens)
563
+ for c in header + fit_header:
564
+ if c in fit_collection:
565
+ continue
566
+ input_ = input_.replace(c, "")
567
+ c0, c1 = c, c
568
+ while c0 and c0 not in fit_collection and len(c0) >= 4:
569
+ c0 = c0[1:]
570
+ if c0 in fit_collection:
571
+ break
572
+ input_ = input_.replace(c0, "")
573
+ while c1 and c1 not in fit_collection and len(c1) >= 4:
574
+ c1 = c1[:-1]
575
+ if c1 in fit_collection:
576
+ break
577
+ input_ = input_.replace(c1, "")
578
+
579
+ #ic(input_)
580
+ text_tokens = list(jieba.cut(input_))
581
+
582
+ #cat_6_collection_high_level = ('多少个', '有几个', '总共') + ('总和','一共',) + ("哪些", "查", "数量")
583
+ #cat_6_collection_high_level = ('多少个', '有几个', '总共') + ('总和','一共',)
584
+ ##### 高置信度部分 (作为是否构成使用特殊规则的判断标准)
585
+ #### case 2 部分 (高置信度有效匹配)
586
+ #cat_6_collection_high_level = ('多少个', '有几个', '总共') + ('总和','一共',)
587
+ #cat_6_collection_high_level = ('多少个', '有几个', '总共') + ('总和','一共',) + ("总数",)
588
+ cat_6_collection_high_level = sum_count_high_kws
589
+ if any(map(lambda high_level_token: high_level_token in "".join(text_tokens), cat_6_collection_high_level)):
590
+ return 6
591
+
592
+ #### 够深 够宽 规则部分, change order by header, if header have kws in , lower order
593
+ if any(map(lambda kw: kw in text_tokens, mean_kws)):
594
+ return 1
595
+ if any(map(lambda kw: kw in text_tokens, max_kws)):
596
+ return 2
597
+ if any(map(lambda kw: kw in text_tokens, min_kws)):
598
+ return 3
599
+
600
+ ##### 低置信度部分
601
+ #### case 2 部分 (低置信度尾部匹配)
602
+ cat_6_collection = sum_count_low_kws
603
+ if any(map(lambda kw: kw in text_tokens, cat_6_collection)):
604
+ return 6
605
+ if any(map(lambda token: "几" in token, text_tokens)):
606
+ return 6
607
+
608
+ #### special case 部分
609
+ if any(map(lambda kw: kw in text_tokens, max_special_kws)):
610
+ return 2
611
+ if any(map(lambda kw: kw in text_tokens, min_special_kws)):
612
+ return 3
613
+
614
+ #### 无效匹配
615
+ return 0
616
+
617
+
618
+ def simple_special_func(s, drop_header = True):
619
+ text_tokens =s.question_cut
620
+ header = list(map(lambda x: x[:x.find("(")] if (not x.startswith("(") and x.endswith(")")) else x ,s.header.split(",")))
621
+
622
+ #### not contain samples may not match in fuzzy-match, special column mapping in finance,
623
+ ### or "3" to "三"
624
+ fit_collection = sum_count_high_kws + mean_kws + max_kws + min_kws
625
+ fit_header = []
626
+ for c in header:
627
+ for kw in fit_collection:
628
+ if kw in c:
629
+ start_idx = c.find(kw)
630
+ end_idx = start_idx + len(kw)
631
+ fit_header.append(c[start_idx: end_idx])
632
+
633
+ input_ = "".join(text_tokens)
634
+ if not drop_header:
635
+ header = []
636
+ fit_header = []
637
+
638
+ for c in header + fit_header:
639
+ if c in fit_collection:
640
+ continue
641
+ input_ = input_.replace(c, "")
642
+ c0, c1 = c, c
643
+ while c0 and c0 not in fit_collection and len(c0) >= 4:
644
+ c0 = c0[1:]
645
+ if c0 in fit_collection:
646
+ break
647
+ input_ = input_.replace(c0, "")
648
+ while c1 and c1 not in fit_collection and len(c1) >= 4:
649
+ c1 = c1[:-1]
650
+ if c1 in fit_collection:
651
+ break
652
+ input_ = input_.replace(c1, "")
653
+
654
+ #ic(input_)
655
+ text_tokens = list(jieba.cut(input_))
656
+ #ic(text_tokens)
657
+
658
+ #cat_6_collection_high_level = ('多少个', '有几个', '总共') + ('总和','一共',) + ("哪些", "查", "数量")
659
+ #cat_6_collection_high_level = ('多少个', '有几个', '总共') + ('总和','一共',)
660
+ #### case 2 部分 (高置信度有效匹配)
661
+ cat_6_collection_high_level = sum_count_high_kws
662
+ if any(map(lambda high_level_token: high_level_token in "".join(text_tokens), cat_6_collection_high_level)):
663
+ return 6
664
+
665
+ #### 够深 够宽 规则部分, change order by header, if header have kws in , lower order
666
+ if any(map(lambda kw: kw in text_tokens, mean_kws)):
667
+ return 1
668
+ if any(map(lambda kw: kw in text_tokens, max_kws)):
669
+ return 2
670
+ if any(map(lambda kw: kw in text_tokens, min_kws)):
671
+ return 3
672
+
673
+ return 0
674
+
675
+
676
+ def simple_total_label_func(s):
677
+ is_special = False if simple_special_func(s) == 0 else True
678
+ if not is_special:
679
+ return 0
680
+ return simple_label_func(s)
681
+ ## agg-classifier end
682
+
683
+
684
+ ##### main block of process start
685
+ def split_by_cond(s, extract_return = True):
686
+ def recurrent_extract_cond(text):
687
+ #return np.asarray(recurrent_extract(text)[0])
688
+ #return recurrent_extract(text)[0]
689
+ return np.asarray(list(recurrent_extract(text)[0]))
690
+
691
+ question = s["question"]
692
+ res = s["rec_decomp"][1]
693
+ if question is None:
694
+ question = ""
695
+ if res is None:
696
+ res = ""
697
+
698
+ common_res = findMaxSubString(question, res)
699
+ #cond_kws = ("或", "而且", "并且", "当中")
700
+ #cond_kws = ("或", "而且" "并且" "当中")
701
+ cond_kws = conn_kws
702
+ condition_part = question.replace(common_res, "")
703
+ fit_kws = set([])
704
+ for kw in cond_kws:
705
+ if kw in condition_part and not condition_part.startswith(kw) and not condition_part.endswith(kw):
706
+ fit_kws.add(kw)
707
+ if not fit_kws:
708
+ will_return = ([condition_part.replace(" ", "") + " " + common_res], "")
709
+ if extract_return:
710
+ #return (list(map(recurrent_extract_cond, will_return[0])), will_return[1])
711
+ will_return = (np.asarray(list(map(recurrent_extract_cond, will_return[0]))) , will_return[1])
712
+ #will_return = (np.concatenate(list(filter(lambda x: x.size ,map(np.asarray ,will_return[0].tolist()))), axis = 0), will_return[1])
713
+ #will_return = (np.concatenate(list(map(np.asarray ,will_return[0].tolist())), axis = 0), will_return[1])
714
+ input_ = list(filter(lambda x: x.size ,map(np.asarray ,will_return[0].tolist())))
715
+ if input_:
716
+ will_return = (np.concatenate(input_, axis = 0), will_return[1])
717
+ else:
718
+ will_return = (np.empty((0, 3)), will_return[1])
719
+
720
+ will_return = will_return[0].reshape((-1, 3)) if will_return[0].size else np.empty((0, 3))
721
+ return will_return
722
+
723
+ fit_kw = sorted(fit_kws, key = len)[::-1][0]
724
+ condition_part_splits = condition_part.split(fit_kw)
725
+ #if fit_kw in ("或",):
726
+ if fit_kw in or_kws:
727
+ fit_kw = "or"
728
+ #elif fit_kw in ("而且", "并且", "当中",):
729
+ elif fit_kw in and_kws:
730
+ fit_kw = "and"
731
+ else:
732
+ fit_kw = ""
733
+
734
+ will_return = (list(map(lambda cond_: cond_.replace(" ", "") + " " + common_res, condition_part_splits)), fit_kw)
735
+ if extract_return:
736
+ #return (list(map(recurrent_extract_cond, will_return[0])), will_return[1])
737
+ will_return = (np.asarray(list(map(recurrent_extract_cond, will_return[0]))), will_return[1])
738
+ #ic(will_return[0])
739
+ #will_return = (np.concatenate(list(map(np.asarray ,will_return[0].tolist())), axis = 0), will_return[1])
740
+ input_ = list(filter(lambda x: x.size ,map(np.asarray ,will_return[0].tolist())))
741
+ if input_:
742
+ will_return = (np.concatenate(input_, axis = 0), will_return[1])
743
+ else:
744
+ will_return = (np.empty((0, 3)), will_return[1])
745
+ #ic(will_return[0])
746
+ will_return = will_return[0].reshape((-1, 3)) if will_return[0].size else np.empty((0, 3))
747
+
748
+ return will_return
749
+
750
+
751
+
752
+ def filter_total_conds(s, df, condition_fit_num = 0):
753
+ assert condition_fit_num >= 0 and type(condition_fit_num) == type(0)
754
+ df = df.copy()
755
+
756
+ #### some col not as float with only "None" as cell, also transform them into float
757
+ df = df.applymap(lambda x: np.nan if x in ["None", None, "/"] else x)
758
+ def justify_column_as_float(s):
759
+ if "float" in str(s.dtype):
760
+ return True
761
+ if all(s.map(type).map(lambda tx: "float" in str(tx))):
762
+ return True
763
+ return False
764
+
765
+ float_cols = list(map(lambda tt2: tt2[0],filter(lambda t2: t2[1] ,df.apply(justify_column_as_float, axis = 0).to_dict().items())))
766
+ for f_col in float_cols:
767
+ df[f_col] = df[f_col].astype(np.float64)
768
+ ###
769
+
770
+ header = df.columns.tolist()
771
+ units_cols = filter(lambda c: "(" in c and c.endswith(")"), df.columns.tolist())
772
+ if not float_cols:
773
+ float_discribe_df = pd.DataFrame()
774
+ else:
775
+ float_discribe_df = df[float_cols].describe()
776
+
777
+ def call_eval(val):
778
+ try:
779
+ return literal_eval(val)
780
+ except:
781
+ return val
782
+
783
+ #### find condition column same as question_column
784
+ def find_cond_col(res, header):
785
+ #ic(res, header)
786
+ c_res_common_dict = dict(filter(lambda t2: t2[1] ,map(lambda c: (c ,findMaxSubString(c, res)), header)))
787
+ #ic(c_res_common_dict)
788
+ common_ratio_c_dict = dict(map(lambda t2: (t2[0], len(t2[1]) / len(t2[0])), c_res_common_dict.items()))
789
+ common_ratio_res_dict = dict(map(lambda t2: (t2[0], len(t2[1]) / len(res)), c_res_common_dict.items()))
790
+
791
+ if not common_ratio_c_dict or not common_ratio_res_dict:
792
+ return None
793
+
794
+ dict_0_max_key = sorted(common_ratio_c_dict.items(), key = lambda t2: t2[1])[::-1][0][0]
795
+ dict_1_max_key = sorted(common_ratio_res_dict.items(), key = lambda t2: t2[1])[::-1][0][0]
796
+ return dict_0_max_key if dict_0_max_key == dict_1_max_key else None
797
+ ###
798
+
799
+ #### type comptatible in column type and value type, and fit_num match
800
+ def filter_cond_col(cond_t3):
801
+ assert type(cond_t3) == type((1,)) and len(cond_t3) == 3
802
+ col, _, value = cond_t3
803
+
804
+ if type(value) == type(""):
805
+ value = call_eval(value)
806
+
807
+ if col not in df.columns.tolist():
808
+ return False
809
+
810
+ #### type key value comp
811
+ if col in float_cols and type(value) not in (type(0), type(0.0)):
812
+ return False
813
+
814
+ if col not in float_cols and type(value) in (type(0), type(0.0)):
815
+ return False
816
+
817
+ #### string value not in corr column
818
+ if col not in float_cols and type(value) not in (type(0), type(0.0)):
819
+ if type(value) == type(""):
820
+ if value not in df[col].tolist():
821
+ return False
822
+
823
+ if type(value) in (type(0), type(0.0)):
824
+ if col in float_discribe_df.columns.tolist():
825
+ if condition_fit_num > 0:
826
+ if value >= float_discribe_df[col].loc["min"] and value <= float_discribe_df[col].loc["max"]:
827
+ return True
828
+ else:
829
+ return False
830
+ else:
831
+ assert condition_fit_num == 0
832
+ return True
833
+
834
+ if condition_fit_num > 0:
835
+ if value in df[col].tolist():
836
+ return True
837
+ else:
838
+ return False
839
+ else:
840
+ assert condition_fit_num == 0
841
+ return True
842
+
843
+ return True
844
+ ###
845
+
846
+ #### condtions with same column may have conflict, choose the nearest one by stats in float or
847
+ ### common_string as find_cond_col do.
848
+ def same_column_cond_filter(cond_list, sort_stats = "mean"):
849
+ #ic(cond_list)
850
+ if len(cond_list) == len(set(map(lambda t3: t3[0] ,cond_list))):
851
+ return cond_list
852
+
853
+ req = defaultdict(list)
854
+ for t3 in cond_list:
855
+ req[t3[0]].append(t3[1:])
856
+
857
+ def t2_list_sort(col_name ,t2_list):
858
+ if not t2_list:
859
+ return None
860
+ t2 = None
861
+ if col_name in float_cols:
862
+ t2 = sorted(t2_list, key = lambda t2: np.abs(t2[1] - float_discribe_df[col_name].loc[sort_stats]))[0]
863
+ else:
864
+ if all(map(lambda t2: type(t2[1]) == type("") ,t2_list)):
865
+ col_val_cnt_df = df[col_name].value_counts().reset_index()
866
+ col_val_cnt_df.columns = ["val", "cnt"]
867
+ #col_val_cnt_df["val"].map(lambda x: sorted(filter(lambda tt2: tt2[-1] ,map(lambda t2: (t2 ,len(findMaxSubString(x, t2[1]))), t2_list)),
868
+ # key = lambda ttt2: -1 * ttt2[-1])[0])
869
+
870
+ t2_list_map_to_column_val = list(filter(lambda x: x[1] ,map(lambda t2: (t2[0] ,find_cond_col(t2[1], list(set(col_val_cnt_df["val"].values.tolist())))), t2_list)))
871
+ if t2_list_map_to_column_val:
872
+ #### return max length fit val in column
873
+ t2 = sorted(t2_list_map_to_column_val, key = lambda t2: -1 * len(t2[1]))[0]
874
+ if t2 is None and t2_list:
875
+ t2 = t2_list[0]
876
+ return t2
877
+
878
+ cond_list_filtered = list(map(lambda ttt2: (ttt2[0], ttt2[1][0], ttt2[1][1]) ,
879
+ filter(lambda tt2: tt2[1] ,map(lambda t2: (t2[0] ,t2_list_sort(t2[0], t2[1])), req.items()))))
880
+
881
+ return cond_list_filtered
882
+ ###
883
+
884
+ total_conds_map_to_column = list(map(lambda t3: (find_cond_col(t3[0], header), t3[1], t3[2]), s["total_conds"]))
885
+ total_conds_map_to_column_filtered = list(filter(filter_cond_col, total_conds_map_to_column))
886
+ total_conds_map_to_column_filtered = sorted(set(map(lambda t3: (t3[0], t3[1], call_eval(t3[2]) if type(t3[2]) == type("") else t3[2]), total_conds_map_to_column_filtered)))
887
+ #ic(total_conds_map_to_column_filtered)
888
+
889
+ cp_cond_list = list(filter(lambda t3: t3[1] in (">", "<"), total_conds_map_to_column_filtered))
890
+ eq_cond_list = list(filter(lambda t3: t3[1] in ("==", "!="), total_conds_map_to_column_filtered))
891
+
892
+ cp_cond_list_filtered = same_column_cond_filter(cp_cond_list)
893
+
894
+ #total_conds_map_to_column_filtered = same_column_cond_filter(total_conds_map_to_column_filtered)
895
+ return cp_cond_list_filtered + eq_cond_list
896
+ #return total_conds_map_to_column_filtered
897
+
898
+ ###@@ only_kws_columns = {"城市": "=="}
899
+
900
+ #### this function only use to cond can not extract by JointBert,
901
+ ### may because not contain column string in question such as "城市" or difficult to extract kw
902
+ ### define kw column as all cells in series are string type.
903
+ ### this function support config relation to column and if future
904
+ ### want to auto extract relation, this may can be done by head string or tail string by edit pattern "\w?{}\w?"
905
+ ### "是" or "不是" can be extract in this manner.
906
+ def augment_kw_in_question(question_df, df, only_kws_in_string = []):
907
+ #### keep only_kws_in_string empty to maintain all condition
908
+ question_df = question_df.copy()
909
+ #df = df.copy()
910
+
911
+ def call_eval(val):
912
+ try:
913
+ return literal_eval(val)
914
+ except:
915
+ return val
916
+
917
+ df = df.copy()
918
+
919
+ df = df.applymap(call_eval)
920
+
921
+ #### some col not as float with only "None" as cell, also transform them into float
922
+ df = df.applymap(lambda x: np.nan if x in ["None", None, "/"] else x)
923
+ def justify_column_as_float(s):
924
+ if "float" in str(s.dtype):
925
+ return True
926
+ if all(s.map(type).map(lambda tx: "float" in str(tx))):
927
+ return True
928
+ return False
929
+
930
+ float_cols = list(map(lambda tt2: tt2[0],filter(lambda t2: t2[1] ,df.apply(justify_column_as_float, axis = 0).to_dict().items())))
931
+ #obj_cols = set(df.columns.tolist()).difference(set(float_cols))
932
+
933
+ def justify_column_as_kw(s):
934
+ if all(s.map(type).map(lambda tx: "str" in str(tx))):
935
+ return True
936
+ return False
937
+
938
+ obj_cols = list(map(lambda tt2: tt2[0],filter(lambda t2: t2[1] ,df.apply(justify_column_as_kw, axis = 0).to_dict().items())))
939
+ obj_cols = list(set(obj_cols).difference(set(float_cols)))
940
+ if only_kws_columns:
941
+ obj_cols = list(set(obj_cols).intersection(set(only_kws_columns.keys())))
942
+
943
+ #replace_format = "{}是{}"
944
+ #kw_augmented_df = pd.DataFrame(df[obj_cols].apply(lambda s: list(map(lambda val :(val,replace_format.format(s.name, val)), s.tolist())), axis = 0).values.tolist())
945
+ #kw_augmented_df.columns = obj_cols
946
+ kw_augmented_df = df[obj_cols].copy()
947
+ #ic(kw_augmented_df)
948
+
949
+ def extract_question_kws(question):
950
+ if not kw_augmented_df.size:
951
+ return []
952
+ req = defaultdict(set)
953
+ for ridx, r in kw_augmented_df.iterrows():
954
+ for k, v in dict(r).items():
955
+ if v in question:
956
+ pattern = "\w?{}\w?".format(v)
957
+ all_match = re.findall(pattern, question)
958
+ #req = req.union(set(all_match))
959
+ #req[v] = req[v].union(set(all_match))
960
+ key = "{}~{}".format(k, v)
961
+ req[key] = req[key].union(set(all_match))
962
+ #ic(k, v)
963
+ #question = question.replace(v[0], v[1])
964
+ #return question.replace(replace_format.format("","") * 2, replace_format.format("",""))
965
+ #req = list(req)
966
+ if only_kws_in_string:
967
+ req = list(map(lambda tt2: tt2[0] ,filter(lambda t2: sum(map(lambda kw: sum(map(lambda t: kw in t ,t2[1])), only_kws_in_string)), req.items())))
968
+ else:
969
+ req = list(set(req.keys()))
970
+
971
+ def req_to_t3(req_string, relation = "=="):
972
+ assert "~" in req_string
973
+ left, right = req_string.split("~")
974
+ if left in only_kws_columns:
975
+ relation = only_kws_columns[left]
976
+ return (left, relation, right)
977
+
978
+ if not req:
979
+ return []
980
+
981
+ return list(map(req_to_t3, req))
982
+
983
+ #return req
984
+
985
+ question_df["question_kw_conds"] = question_df["question"].map(extract_question_kws)
986
+ return question_df
987
+
988
+
989
+ def choose_question_column_by_rm_conds(s, df):
990
+ question = s.question
991
+ total_conds_filtered = s.total_conds_filtered
992
+ #cond_kws = ("或", "而且", "并且", "当中")
993
+ cond_kws = conn_kws
994
+ stopwords = ("是", )
995
+ #ic(total_conds_filtered)
996
+ def construct_res(question):
997
+ for k, _, v in total_conds_filtered:
998
+ if "(" in k:
999
+ k = k[:k.find("(")]
1000
+ #ic(k)
1001
+ question = question.replace(str(k), "")
1002
+ question = question.replace(str(v), "")
1003
+ for w in cond_kws + stopwords:
1004
+ question = question.replace(w, "")
1005
+ return question
1006
+
1007
+ res = construct_res(question)
1008
+ decomp = (None, res, None)
1009
+ return choose_question_column(decomp, df.columns.tolist(), df)
1010
+
1011
+
1012
+ def split_qst_by_kw(question, kw = "的"):
1013
+ return question.split(kw)
1014
+
1015
+ #qst_kws = ("多少", "什么", "多大", "哪些", "怎么", "情况", "那些", "哪个")
1016
+ ###@@ qst_kws = ("多少", "什么", "多大", "哪些", "怎么", "情况", "那些", "哪个")
1017
+ def choose_res_by_kws(question):
1018
+ #kws = ["的","多少", '是']
1019
+ question = question.replace(" ", "")
1020
+ #kws = ["的","或者","或", "且","并且","同时"]
1021
+ kws = ("的",) + conn_kws
1022
+ kws = list(kws)
1023
+ def qst_kw_filter(text):
1024
+ #qst_kws = ("多少", "什么", "多大", "哪些", "怎么", "情况", "那些", "哪个")
1025
+ if any(map(lambda kw: kw in text, qst_kws)):
1026
+ return True
1027
+ return False
1028
+
1029
+ kws_cp = deepcopy(kws)
1030
+ qst_c = set(question.split(","))
1031
+ while kws:
1032
+ kw = kws.pop()
1033
+ qst_c = qst_c.union(set(filter(qst_kw_filter ,reduce(lambda a, b: a + b,map(lambda q: split_qst_by_kw(q, kw), qst_c))))
1034
+ )
1035
+ #print("-" * 10)
1036
+ #print(sorted(filter(lambda x: x and (x not in kws_cp) ,qst_c), key = len))
1037
+ #print(sorted(filter(lambda x: x and (x not in kws_cp) and qst_kw_filter(x) ,qst_c), key = len))
1038
+ #### final choose if or not
1039
+ return sorted(filter(lambda x: x and (x not in kws_cp) and qst_kw_filter(x) ,qst_c), key = len)
1040
+ #return sorted(filter(lambda x: x and (x not in kws_cp) and True ,qst_c), key = len)
1041
+
1042
+
1043
+ def cat6_to_45_by_column_type(s, df):
1044
+ agg_pred = s.agg_pred
1045
+ if agg_pred != 6:
1046
+ return agg_pred
1047
+ question_column = s.question_column
1048
+
1049
+ def call_eval(val):
1050
+ try:
1051
+ return literal_eval(val)
1052
+ except:
1053
+ return val
1054
+
1055
+ df = df.copy()
1056
+
1057
+ df = df.applymap(call_eval)
1058
+
1059
+ #### some col not as float with only "None" as cell, also transform them into float
1060
+ df = df.applymap(lambda x: np.nan if x in ["None", None, "/"] else x)
1061
+ def justify_column_as_float(s):
1062
+ if "float" in str(s.dtype):
1063
+ return True
1064
+ if all(s.map(type).map(lambda tx: "float" in str(tx))):
1065
+ return True
1066
+ return False
1067
+
1068
+ float_cols = list(map(lambda tt2: tt2[0],filter(lambda t2: t2[1] ,df.apply(justify_column_as_float, axis = 0).to_dict().items())))
1069
+ #obj_cols = set(df.columns.tolist()).difference(set(float_cols))
1070
+
1071
+ def justify_column_as_kw(s):
1072
+ if all(s.map(type).map(lambda tx: "str" in str(tx))):
1073
+ return True
1074
+ return False
1075
+
1076
+ #obj_cols = list(map(lambda tt2: tt2[0],filter(lambda t2: t2[1] ,df.apply(justify_column_as_kw, axis = 0).to_dict().items())))
1077
+ obj_cols = df.columns.tolist()
1078
+ obj_cols = list(set(obj_cols).difference(set(float_cols)))
1079
+
1080
+ #ic(obj_cols, float_cols, df.columns.tolist())
1081
+ assert len(obj_cols) + len(float_cols) == df.shape[1]
1082
+
1083
+ if question_column in obj_cols:
1084
+ return 4
1085
+ elif question_column in float_cols:
1086
+ return 5
1087
+ else:
1088
+ return 0
1089
+
1090
+
1091
+ def full_before_cat_decomp(df, question_df, only_req_columns = False):
1092
+ df, question_df = df.copy(), question_df.copy()
1093
+ first_train_question_extract_df = pd.DataFrame(question_df["question"].map(lambda question: (question, recurrent_extract(question))).tolist())
1094
+ first_train_question_extract_df.columns = ["question", "decomp"]
1095
+
1096
+ first_train_question_extract_df = augment_kw_in_question(first_train_question_extract_df, df)
1097
+
1098
+ first_train_question_extract_df["rec_decomp"] = first_train_question_extract_df["decomp"].map(lambda decomp: decomp if decomp[0] else rec_more_time(decomp))
1099
+ #return first_train_question_extract_df.copy()
1100
+ first_train_question_extract_df["question_cut"] = first_train_question_extract_df["rec_decomp"].map(lambda t3: jieba.cut(t3[1]) if t3[1] is not None else []).map(list)
1101
+ first_train_question_extract_df["header"] = ",".join(df.columns.tolist())
1102
+ first_train_question_extract_df["question_column_res"] = first_train_question_extract_df["rec_decomp"].map(lambda decomp: choose_question_column(decomp, df.columns.tolist(), df))
1103
+
1104
+ #### agg
1105
+ first_train_question_extract_df["agg_res_pred"] = first_train_question_extract_df.apply(simple_total_label_func, axis = 1)
1106
+ first_train_question_extract_df["question_cut"] = first_train_question_extract_df["question"].map(jieba.cut).map(list)
1107
+ first_train_question_extract_df["agg_qst_pred"] = first_train_question_extract_df.apply(simple_total_label_func, axis = 1)
1108
+ ### if agg_res_pred and agg_qst_pred have conflict use max, to prevent from empty agg with errorous question column
1109
+ ### but this "max" can also be replaced by measure the performance of decomp part, and choose the best one
1110
+ ### or we can use agg_qst_pred with high balanced_score as 0.86+ in imbalanced dataset.
1111
+ ### which operation to use need some discussion.
1112
+ ### (balanced_accuracy_score(lookup_df["sql"], lookup_df["agg_pred"]),
1113
+ ### balanced_accuracy_score(lookup_df["sql"], lookup_df["agg_res_pred"]),
1114
+ ### balanced_accuracy_score(lookup_df