justinsiow committed on
Commit
1e712af
1 Parent(s): 1428afe

Uploaded Utils, Pycache and Python Files

__pycache__/schema_filter.cpython-38.pyc ADDED
Binary file (11 kB).
 
eval_mode.py ADDED
@@ -0,0 +1,182 @@
+ from schema_filter import filter_func, SchemaItemClassifierInference
+
+ # In eval mode, the sql field does not need to be provided
+ data = {
+     "text": "Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.",
+     "sql": "",
+     "schema": {
+         "schema_items": [
+             {
+                 "table_name": "lists",
+                 "table_comment": "",
+                 "column_names": [
+                     "user_id",
+                     "list_id",
+                     "list_title",
+                     "list_movie_number",
+                     "list_update_timestamp_utc",
+                     "list_creation_timestamp_utc",
+                     "list_followers",
+                     "list_url",
+                     "list_comments",
+                     "list_description",
+                     "list_cover_image_url",
+                     "list_first_image_url",
+                     "list_second_image_url",
+                     "list_third_image_url"
+                 ],
+                 "column_comments": [
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     ""
+                 ]
+             },
+             {
+                 "table_name": "movies",
+                 "table_comment": "",
+                 "column_names": [
+                     "movie_id",
+                     "movie_title",
+                     "movie_release_year",
+                     "movie_url",
+                     "movie_title_language",
+                     "movie_popularity",
+                     "movie_image_url",
+                     "director_id",
+                     "director_name",
+                     "director_url"
+                 ],
+                 "column_comments": [
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     ""
+                 ]
+             },
+             {
+                 "table_name": "ratings_users",
+                 "table_comment": "",
+                 "column_names": [
+                     "user_id",
+                     "rating_date_utc",
+                     "user_trialist",
+                     "user_subscriber",
+                     "user_avatar_image_url",
+                     "user_cover_image_url",
+                     "user_eligible_for_trial",
+                     "user_has_payment_method"
+                 ],
+                 "column_comments": [
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     ""
+                 ]
+             },
+             {
+                 "table_name": "lists_users",
+                 "table_comment": "",
+                 "column_names": [
+                     "user_id",
+                     "list_id",
+                     "list_update_date_utc",
+                     "list_creation_date_utc",
+                     "user_trialist",
+                     "user_subscriber",
+                     "user_avatar_image_url",
+                     "user_cover_image_url",
+                     "user_eligible_for_trial",
+                     "user_has_payment_method"
+                 ],
+                 "column_comments": [
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     ""
+                 ]
+             },
+             {
+                 "table_name": "ratings",
+                 "table_comment": "",
+                 "column_names": [
+                     "movie_id",
+                     "rating_id",
+                     "rating_url",
+                     "rating_score",
+                     "rating_timestamp_utc",
+                     "critic",
+                     "critic_likes",
+                     "critic_comments",
+                     "user_id",
+                     "user_trialist",
+                     "user_subscriber",
+                     "user_eligible_for_trial",
+                     "user_has_payment_method"
+                 ],
+                 "column_comments": [
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     ""
+                 ]
+             }
+         ]
+     }
+ }
+
+ dataset = [data]
+
+ # keep at most 7 tables from the database
+ num_top_k_tables = 7
+ # for each kept table, keep at most 10 of its columns, so the input prompt contains at most 7 * 10 = 70 columns
+ num_top_k_columns = 10
+
+ # load the schema item classifier model
+ sic = SchemaItemClassifierInference("sic_merged")
+
+ # For evaluation data, we load the trained classifier to score each table and column against the user question
+ dataset = filter_func(
+     dataset = dataset,
+     dataset_type = "eval",
+     sic = sic,
+     num_top_k_tables = num_top_k_tables,
+     num_top_k_columns = num_top_k_columns
+ )
+
+ print(dataset)
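A minimal sketch of inspecting the filtered result returned above (assuming filter_func has just run; the keys follow the filtered_schema dict built in schema_filter.py below):

    for sample in dataset:
        # each surviving schema item keeps only its top-ranked columns
        kept = {item["table_name"]: item["column_names"] for item in sample["schema"]["schema_items"]}
        print(kept)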
schema_filter.py ADDED
@@ -0,0 +1,339 @@
+ import numpy as np
+ import random
+ import torch
+
+ from tqdm import tqdm
+ from transformers import AutoTokenizer
+ from utils.classifier_model import SchemaItemClassifier
+ from transformers.trainer_utils import set_seed
+
+ def prepare_inputs_and_labels(sample, tokenizer):
+     table_names = [table["table_name"] for table in sample["schema"]["schema_items"]]
+     column_names = [table["column_names"] for table in sample["schema"]["schema_items"]]
+     column_num_in_each_table = [len(table["column_names"]) for table in sample["schema"]["schema_items"]]
+
+     # `column_name_word_indices` and `table_name_word_indices` record the word indices of each column and table in `input_words`; each element is an integer
+     column_name_word_indices, table_name_word_indices = [], []
+
+     input_words = [sample["text"]]
+     for table_id, table_name in enumerate(table_names):
+         input_words.append("|")
+         input_words.append(table_name)
+         table_name_word_indices.append(len(input_words) - 1)
+         input_words.append(":")
+
+         for column_name in column_names[table_id]:
+             input_words.append(column_name)
+             column_name_word_indices.append(len(input_words) - 1)
+             input_words.append(",")
+
+     # remove the last ","
+     input_words = input_words[:-1]
+
+     tokenized_inputs = tokenizer(
+         input_words,
+         return_tensors="pt",
+         is_split_into_words = True,
+         padding = "max_length",
+         max_length = 512,
+         truncation = True
+     )
+
+     # after tokenizing, one table name or column name may be split into multiple tokens (i.e., sub-words)
+     # `column_name_token_indices` and `table_name_token_indices` record the token indices of each column and table in `input_ids`; each element is a list of integers
+     column_name_token_indices, table_name_token_indices = [], []
+     word_indices = tokenized_inputs.word_ids(batch_index = 0)
+
+     # obtain token indices of each column in `input_ids`
+     for column_name_word_index in column_name_word_indices:
+         column_name_token_indices.append([token_id for token_id, word_index in enumerate(word_indices) if column_name_word_index == word_index])
+
+     # obtain token indices of each table in `input_ids`
+     for table_name_word_index in table_name_word_indices:
+         table_name_token_indices.append([token_id for token_id, word_index in enumerate(word_indices) if table_name_word_index == word_index])
+
+     encoder_input_ids = tokenized_inputs["input_ids"]
+     encoder_input_attention_mask = tokenized_inputs["attention_mask"]
+
+     # print("\n".join(tokenizer.batch_decode(encoder_input_ids, skip_special_tokens = True)))
+
+     if torch.cuda.is_available():
+         encoder_input_ids = encoder_input_ids.cuda()
+         encoder_input_attention_mask = encoder_input_attention_mask.cuda()
+
+     return encoder_input_ids, encoder_input_attention_mask, \
+         column_name_token_indices, table_name_token_indices, column_num_in_each_table
+
+ def get_schema(tables_and_columns):
+     schema_items = []
+     table_names = list(dict.fromkeys([t for t, c in tables_and_columns]))
+     for table_name in table_names:
+         schema_items.append(
+             {
+                 "table_name": table_name,
+                 "column_names": [c for t, c in tables_and_columns if t == table_name]
+             }
+         )
+
+     return {"schema_items": schema_items}
+
+ def get_sequence_length(text, tables_and_columns, tokenizer):
+     table_names = [t for t, c in tables_and_columns]
+     # deduplicate `table_names` while preserving order
+     table_names = list(dict.fromkeys(table_names))
+
+     column_names = []
+     for table_name in table_names:
+         column_names.append([c for t, c in tables_and_columns if t == table_name])
+
+     input_words = [text]
+     for table_id, table_name in enumerate(table_names):
+         input_words.append("|")
+         input_words.append(table_name)
+         input_words.append(":")
+         for column_name in column_names[table_id]:
+             input_words.append(column_name)
+             input_words.append(",")
+     # remove the last ","
+     input_words = input_words[:-1]
+
+     tokenized_inputs = tokenizer(input_words, is_split_into_words = True)
+
+     return len(tokenized_inputs["input_ids"])
+
+ # handle extremely long schema sequences
+ def split_sample(sample, tokenizer):
+     text = sample["text"]
+
+     table_names = []
+     column_names = []
+     for table in sample["schema"]["schema_items"]:
+         table_names.append(table["table_name"] + " ( " + table["table_comment"] + " ) " \
+             if table["table_comment"] != "" else table["table_name"])
+         column_names.append([column_name + " ( " + column_comment + " ) " \
+             if column_comment != "" else column_name \
+             for column_name, column_comment in zip(table["column_names"], table["column_comments"])])
+
+     splitted_samples = []
+     recorded_tables_and_columns = []
+
+     for table_idx, table_name in enumerate(table_names):
+         for column_name in column_names[table_idx]:
+             if get_sequence_length(text, recorded_tables_and_columns + [[table_name, column_name]], tokenizer) < 500:
+                 recorded_tables_and_columns.append([table_name, column_name])
+             else:
+                 splitted_samples.append(
+                     {
+                         "text": text,
+                         "schema": get_schema(recorded_tables_and_columns)
+                     }
+                 )
+                 recorded_tables_and_columns = [[table_name, column_name]]
+
+     splitted_samples.append(
+         {
+             "text": text,
+             "schema": get_schema(recorded_tables_and_columns)
+         }
+     )
+
+     return splitted_samples
+
+ def merge_pred_results(sample, pred_results):
+     # table_names = [table["table_name"] for table in sample["schema"]["schema_items"]]
+     # column_names = [table["column_names"] for table in sample["schema"]["schema_items"]]
+     table_names = []
+     column_names = []
+     for table in sample["schema"]["schema_items"]:
+         table_names.append(table["table_name"] + " ( " + table["table_comment"] + " ) " \
+             if table["table_comment"] != "" else table["table_name"])
+         column_names.append([column_name + " ( " + column_comment + " ) " \
+             if column_comment != "" else column_name \
+             for column_name, column_comment in zip(table["column_names"], table["column_comments"])])
+
+     merged_results = []
+     for table_id, table_name in enumerate(table_names):
+         table_prob = 0
+         column_probs = []
+         for result_dict in pred_results:
+             if table_name in result_dict:
+                 if table_prob < result_dict[table_name]["table_prob"]:
+                     table_prob = result_dict[table_name]["table_prob"]
+                 column_probs += result_dict[table_name]["column_probs"]
+
+         merged_results.append(
+             {
+                 "table_name": table_name,
+                 "table_prob": table_prob,
+                 "column_names": column_names[table_id],
+                 "column_probs": column_probs
+             }
+         )
+
+     return merged_results
+
+ def filter_func(dataset, dataset_type, sic, num_top_k_tables = 5, num_top_k_columns = 5):
+     for data in tqdm(dataset, desc = "filtering schema items for the dataset"):
+         filtered_schema = dict()
+         filtered_schema["schema_items"] = []
+
+         table_names = [table["table_name"] for table in data["schema"]["schema_items"]]
+         table_comments = [table["table_comment"] for table in data["schema"]["schema_items"]]
+         column_names = [table["column_names"] for table in data["schema"]["schema_items"]]
+         column_comments = [table["column_comments"] for table in data["schema"]["schema_items"]]
+
+         if dataset_type == "eval":
+             # predict a relevance score for each table and column
+             pred_results = sic.predict(data)
+             # keep the top-k tables for each database and the top-k columns for each kept table
+             table_probs = [pred_result["table_prob"] for pred_result in pred_results]
+             table_indices = np.argsort(-np.array(table_probs), kind="stable")[:num_top_k_tables].tolist()
+         elif dataset_type == "train":
+             table_indices = [table_idx for table_idx, table_label in enumerate(data["table_labels"]) if table_label == 1]
+             if len(table_indices) < num_top_k_tables:
+                 unused_table_indices = [table_idx for table_idx, table_label in enumerate(data["table_labels"]) if table_label == 0]
+                 table_indices += random.sample(unused_table_indices, min(len(unused_table_indices), num_top_k_tables - len(table_indices)))
+             random.shuffle(table_indices)
+
+         for table_idx in table_indices:
+             if dataset_type == "eval":
+                 column_probs = pred_results[table_idx]["column_probs"]
+                 column_indices = np.argsort(-np.array(column_probs), kind="stable")[:num_top_k_columns].tolist()
+             elif dataset_type == "train":
+                 column_indices = [column_idx for column_idx, column_label in enumerate(data["column_labels"][table_idx]) if column_label == 1]
+                 if len(column_indices) < num_top_k_columns:
+                     unused_column_indices = [column_idx for column_idx, column_label in enumerate(data["column_labels"][table_idx]) if column_label == 0]
+                     column_indices += random.sample(unused_column_indices, min(len(unused_column_indices), num_top_k_columns - len(column_indices)))
+                 random.shuffle(column_indices)
+
+             filtered_schema["schema_items"].append(
+                 {
+                     "table_name": table_names[table_idx],
+                     "table_comment": table_comments[table_idx],
+                     "column_names": [column_names[table_idx][column_idx] for column_idx in column_indices],
+                     "column_comments": [column_comments[table_idx][column_idx] for column_idx in column_indices]
+                 }
+             )
+
+         # replace the old schema with the filtered schema
+         data["schema"] = filtered_schema
+
+         if dataset_type == "train":
+             del data["table_labels"]
+             del data["column_labels"]
+
+     return dataset
+
+ def lista_contains_listb(lista, listb):
+     for b in listb:
+         if b not in lista:
+             return 0
+
+     return 1
+
+ class SchemaItemClassifierInference():
+     def __init__(self, model_save_path):
+         set_seed(42)
+         # load tokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(model_save_path, add_prefix_space = True)
+         # initialize model
+         self.model = SchemaItemClassifier(model_save_path, "test")
+         # load fine-tuned params
+         self.model.load_state_dict(torch.load(model_save_path + "/dense_classifier.pt", map_location=torch.device('cpu')), strict=False)
+         if torch.cuda.is_available():
+             self.model = self.model.cuda()
+         self.model.eval()
+
+     def predict_one(self, sample):
+         encoder_input_ids, encoder_input_attention_mask, column_name_token_indices, \
+             table_name_token_indices, column_num_in_each_table = prepare_inputs_and_labels(sample, self.tokenizer)
+
+         with torch.no_grad():
+             model_outputs = self.model(
+                 encoder_input_ids,
+                 encoder_input_attention_mask,
+                 [column_name_token_indices],
+                 [table_name_token_indices],
+                 [column_num_in_each_table]
+             )
+
+         table_logits = model_outputs["batch_table_name_cls_logits"][0]
+         table_pred_probs = torch.nn.functional.softmax(table_logits, dim = 1)[:, 1].cpu().tolist()
+
+         column_logits = model_outputs["batch_column_info_cls_logits"][0]
+         column_pred_probs = torch.nn.functional.softmax(column_logits, dim = 1)[:, 1].cpu().tolist()
+
+         splitted_column_pred_probs = []
+         # split predicted column probs into each table
+         for table_id, column_num in enumerate(column_num_in_each_table):
+             splitted_column_pred_probs.append(column_pred_probs[sum(column_num_in_each_table[:table_id]): sum(column_num_in_each_table[:table_id]) + column_num])
+         column_pred_probs = splitted_column_pred_probs
+
+         result_dict = dict()
+         for table_idx, table in enumerate(sample["schema"]["schema_items"]):
+             result_dict[table["table_name"]] = {
+                 "table_name": table["table_name"],
+                 "table_prob": table_pred_probs[table_idx],
+                 "column_names": table["column_names"],
+                 "column_probs": column_pred_probs[table_idx],
+             }
+
+         return result_dict
+
+     def predict(self, test_sample):
+         splitted_samples = split_sample(test_sample, self.tokenizer)
+         pred_results = []
+         for splitted_sample in splitted_samples:
+             pred_results.append(self.predict_one(splitted_sample))
+
+         return merge_pred_results(test_sample, pred_results)
+
+     def evaluate_coverage(self, dataset):
+         max_k = 100
+         total_num_for_table_coverage, total_num_for_column_coverage = 0, 0
+         table_coverage_results = [0]*max_k
+         column_coverage_results = [0]*max_k
+
+         for data in dataset:
+             indices_of_used_tables = [idx for idx, label in enumerate(data["table_labels"]) if label == 1]
+             pred_results = self.predict(data)
+             # print(pred_results)
+             table_probs = [res["table_prob"] for res in pred_results]
+             for k in range(max_k):
+                 indices_of_top_k_tables = np.argsort(-np.array(table_probs), kind="stable")[:k+1].tolist()
+                 if lista_contains_listb(indices_of_top_k_tables, indices_of_used_tables):
+                     table_coverage_results[k] += 1
+             total_num_for_table_coverage += 1
+
+             for table_idx in range(len(data["table_labels"])):
+                 indices_of_used_columns = [idx for idx, label in enumerate(data["column_labels"][table_idx]) if label == 1]
+                 if len(indices_of_used_columns) == 0:
+                     continue
+                 column_probs = pred_results[table_idx]["column_probs"]
+                 for k in range(max_k):
+                     indices_of_top_k_columns = np.argsort(-np.array(column_probs), kind="stable")[:k+1].tolist()
+                     if lista_contains_listb(indices_of_top_k_columns, indices_of_used_columns):
+                         column_coverage_results[k] += 1
+
+                 total_num_for_column_coverage += 1
+
+                 indices_of_top_10_columns = np.argsort(-np.array(column_probs), kind="stable")[:10].tolist()
+                 if lista_contains_listb(indices_of_top_10_columns, indices_of_used_columns) == 0:
+                     print(pred_results[table_idx])
+                     print(data["column_labels"][table_idx])
+                     print(data["question"])
+
+         print(total_num_for_table_coverage)
+         print(table_coverage_results)
+         print(total_num_for_column_coverage)
+         print(column_coverage_results)
+
+ if __name__ == "__main__":
+     dataset_name = "bird_with_evidence"
+     # dataset_name = "bird"
+     # dataset_name = "spider"
+     sic = SchemaItemClassifierInference("sic_ckpts/sic_{}".format(dataset_name))
+     import json
+     dataset = json.load(open("./data/sft_eval_{}_text2sql.json".format(dataset_name)))
+
+     sic.evaluate_coverage(dataset)
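As a quick illustration of the serialization that prepare_inputs_and_labels and get_sequence_length build, the question and schema are flattened into a single word sequence of the form "question | table1 : col1 , col2 | table2 : ...". A toy example (illustrative values, not from this commit):

    sample = {
        "text": "How many movies were released in 1945?",
        "schema": {"schema_items": [
            {"table_name": "movies", "column_names": ["movie_id", "movie_release_year"]}
        ]}
    }
    # prepare_inputs_and_labels would build the word list:
    # ["How many movies were released in 1945?", "|", "movies", ":", "movie_id", ",", "movie_release_year"]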
training_mode.py ADDED
@@ -0,0 +1,194 @@
+ from schema_filter import filter_func
+
+ data = {
+     "text": "Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.",
+     "sql": "SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1",
+     "schema": {
+         "schema_items": [
+             {
+                 "table_name": "lists",
+                 "table_comment": "",
+                 "column_names": [
+                     "user_id",
+                     "list_id",
+                     "list_title",
+                     "list_movie_number",
+                     "list_update_timestamp_utc",
+                     "list_creation_timestamp_utc",
+                     "list_followers",
+                     "list_url",
+                     "list_comments",
+                     "list_description",
+                     "list_cover_image_url",
+                     "list_first_image_url",
+                     "list_second_image_url",
+                     "list_third_image_url"
+                 ],
+                 "column_comments": [
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     ""
+                 ]
+             },
+             {
+                 "table_name": "movies",
+                 "table_comment": "",
+                 "column_names": [
+                     "movie_id",
+                     "movie_title",
+                     "movie_release_year",
+                     "movie_url",
+                     "movie_title_language",
+                     "movie_popularity",
+                     "movie_image_url",
+                     "director_id",
+                     "director_name",
+                     "director_url"
+                 ],
+                 "column_comments": [
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     ""
+                 ]
+             },
+             {
+                 "table_name": "ratings_users",
+                 "table_comment": "",
+                 "column_names": [
+                     "user_id",
+                     "rating_date_utc",
+                     "user_trialist",
+                     "user_subscriber",
+                     "user_avatar_image_url",
+                     "user_cover_image_url",
+                     "user_eligible_for_trial",
+                     "user_has_payment_method"
+                 ],
+                 "column_comments": [
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     ""
+                 ]
+             },
+             {
+                 "table_name": "lists_users",
+                 "table_comment": "",
+                 "column_names": [
+                     "user_id",
+                     "list_id",
+                     "list_update_date_utc",
+                     "list_creation_date_utc",
+                     "user_trialist",
+                     "user_subscriber",
+                     "user_avatar_image_url",
+                     "user_cover_image_url",
+                     "user_eligible_for_trial",
+                     "user_has_payment_method"
+                 ],
+                 "column_comments": [
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     ""
+                 ]
+             },
+             {
+                 "table_name": "ratings",
+                 "table_comment": "",
+                 "column_names": [
+                     "movie_id",
+                     "rating_id",
+                     "rating_url",
+                     "rating_score",
+                     "rating_timestamp_utc",
+                     "critic",
+                     "critic_likes",
+                     "critic_comments",
+                     "user_id",
+                     "user_trialist",
+                     "user_subscriber",
+                     "user_eligible_for_trial",
+                     "user_has_payment_method"
+                 ],
+                 "column_comments": [
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     "",
+                     ""
+                 ]
+             }
+         ]
+     }
+ }
+
+ def find_used_tables_and_columns(dataset):
+     for data in dataset:
+         sql = data["sql"].lower()
+         data["table_labels"] = []
+         data["column_labels"] = []
+
+         for table_info in data["schema"]["schema_items"]:
+             table_name = table_info["table_name"]
+             data["table_labels"].append(1 if table_name.lower() in sql else 0)
+             data["column_labels"].append([1 if column_name.lower() in sql else 0 \
+                 for column_name in table_info["column_names"]])
+     return dataset
+
+ dataset = [data]
+
+ # find the tables and columns that are used in the sql
+ dataset = find_used_tables_and_columns(dataset)
+
+ # keep at most 6 tables from the database
+ num_top_k_tables = 6
+ # for each kept table, keep at most 6 of its columns, so the input prompt contains at most 6 * 6 = 36 columns
+ num_top_k_columns = 6
+
+ # For training data, we can simulate the filtering process from the sql alone, so sic (the schema item classifier) can simply be None; no model is needed
+ dataset = filter_func(
+     dataset = dataset,
+     dataset_type = "train",
+     sic = None,
+     num_top_k_tables = num_top_k_tables,
+     num_top_k_columns = num_top_k_columns
+ )
+
+ print(dataset)
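Note that find_used_tables_and_columns labels schema items by simple substring matching against the lowercased SQL, so the labels are heuristic. A sketch of what it yields for the example data above, run before filter_func is applied (expected flags shown for illustration only):

    labeled = find_used_tables_and_columns([data])
    print(labeled[0]["table_labels"])      # e.g. [0, 1, 0, 0, 0], since only "movies" appears in the SQL
    print(labeled[0]["column_labels"][1])  # per-column 0/1 flags for the "movies" table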
utils/__pycache__/classifier_model.cpython-38.pyc ADDED
Binary file (4.01 kB).
 
utils/classifier_model.py ADDED
@@ -0,0 +1,186 @@
+ import torch
+ import torch.nn as nn
+
+ from transformers import AutoConfig, XLMRobertaXLModel
+
+ class SchemaItemClassifier(nn.Module):
+     def __init__(self, model_name_or_path, mode):
+         super(SchemaItemClassifier, self).__init__()
+         if mode in ["eval", "test"]:
+             # load config
+             config = AutoConfig.from_pretrained(model_name_or_path)
+             # randomly initialize model's parameters according to the config
+             self.plm_encoder = XLMRobertaXLModel(config)
+         elif mode == "train":
+             self.plm_encoder = XLMRobertaXLModel.from_pretrained(model_name_or_path)
+         else:
+             raise ValueError()
+
+         self.plm_hidden_size = self.plm_encoder.config.hidden_size
+
+         # column cls head
+         self.column_info_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
+         self.column_info_cls_head_linear2 = nn.Linear(256, 2)
+
+         # column bi-lstm layer
+         self.column_info_bilstm = nn.LSTM(
+             input_size = self.plm_hidden_size,
+             hidden_size = int(self.plm_hidden_size/2),
+             num_layers = 2,
+             dropout = 0,
+             bidirectional = True
+         )
+
+         # linear layer after column bi-lstm layer
+         self.column_info_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)
+
+         # table cls head
+         self.table_name_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
+         self.table_name_cls_head_linear2 = nn.Linear(256, 2)
+
+         # table bi-lstm pooling layer
+         self.table_name_bilstm = nn.LSTM(
+             input_size = self.plm_hidden_size,
+             hidden_size = int(self.plm_hidden_size/2),
+             num_layers = 2,
+             dropout = 0,
+             bidirectional = True
+         )
+         # linear layer after table bi-lstm layer
+         self.table_name_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)
+
+         # activation function
+         self.leakyrelu = nn.LeakyReLU()
+         self.tanh = nn.Tanh()
+
+         # table-column cross-attention layer
+         self.table_column_cross_attention_layer = nn.MultiheadAttention(embed_dim = self.plm_hidden_size, num_heads = 8)
+
+         # dropout function, p=0.2 means randomly set 20% neurons to 0
+         self.dropout = nn.Dropout(p = 0.2)
+
+     def table_column_cross_attention(
+         self,
+         table_name_embeddings_in_one_db,
+         column_info_embeddings_in_one_db,
+         column_number_in_each_table
+     ):
+         table_num = table_name_embeddings_in_one_db.shape[0]
+         table_name_embedding_attn_list = []
+         for table_id in range(table_num):
+             table_name_embedding = table_name_embeddings_in_one_db[[table_id], :]
+             column_info_embeddings_in_one_table = column_info_embeddings_in_one_db[
+                 sum(column_number_in_each_table[:table_id]) : sum(column_number_in_each_table[:table_id+1]), :]
+
+             table_name_embedding_attn, _ = self.table_column_cross_attention_layer(
+                 table_name_embedding,
+                 column_info_embeddings_in_one_table,
+                 column_info_embeddings_in_one_table
+             )
+
+             table_name_embedding_attn_list.append(table_name_embedding_attn)
+
+         # residual connection
+         table_name_embeddings_in_one_db = table_name_embeddings_in_one_db + torch.cat(table_name_embedding_attn_list, dim = 0)
+         # row-wise L2 norm
+         table_name_embeddings_in_one_db = torch.nn.functional.normalize(table_name_embeddings_in_one_db, p=2.0, dim=1)
+
+         return table_name_embeddings_in_one_db
+
+     def table_column_cls(
+         self,
+         encoder_input_ids,
+         encoder_input_attention_mask,
+         batch_aligned_column_info_ids,
+         batch_aligned_table_name_ids,
+         batch_column_number_in_each_table
+     ):
+         batch_size = encoder_input_ids.shape[0]
+
+         encoder_output = self.plm_encoder(
+             input_ids = encoder_input_ids,
+             attention_mask = encoder_input_attention_mask,
+             return_dict = True
+         ) # encoder_output["last_hidden_state"].shape = (batch_size x seq_length x hidden_size)
+
+         batch_table_name_cls_logits, batch_column_info_cls_logits = [], []
+
+         # handle each data in current batch
+         for batch_id in range(batch_size):
+             column_number_in_each_table = batch_column_number_in_each_table[batch_id]
+             sequence_embeddings = encoder_output["last_hidden_state"][batch_id, :, :] # (seq_length x hidden_size)
+
+             # obtain table ids for each table
+             aligned_table_name_ids = batch_aligned_table_name_ids[batch_id]
+             # obtain column ids for each column
+             aligned_column_info_ids = batch_aligned_column_info_ids[batch_id]
+
+             table_name_embedding_list, column_info_embedding_list = [], []
+
+             # obtain table embedding via bi-lstm pooling + a non-linear layer
+             for table_name_ids in aligned_table_name_ids:
+                 table_name_embeddings = sequence_embeddings[table_name_ids, :]
+
+                 # BiLSTM pooling
+                 output_t, (hidden_state_t, cell_state_t) = self.table_name_bilstm(table_name_embeddings)
+                 table_name_embedding = hidden_state_t[-2:, :].view(1, self.plm_hidden_size)
+                 table_name_embedding_list.append(table_name_embedding)
+             table_name_embeddings_in_one_db = torch.cat(table_name_embedding_list, dim = 0)
+             # non-linear mlp layer
+             table_name_embeddings_in_one_db = self.leakyrelu(self.table_name_linear_after_pooling(table_name_embeddings_in_one_db))
+
+             # obtain column embedding via bi-lstm pooling + a non-linear layer
+             for column_info_ids in aligned_column_info_ids:
+                 column_info_embeddings = sequence_embeddings[column_info_ids, :]
+
+                 # BiLSTM pooling
+                 output_c, (hidden_state_c, cell_state_c) = self.column_info_bilstm(column_info_embeddings)
+                 column_info_embedding = hidden_state_c[-2:, :].view(1, self.plm_hidden_size)
+                 column_info_embedding_list.append(column_info_embedding)
+             column_info_embeddings_in_one_db = torch.cat(column_info_embedding_list, dim = 0)
+             # non-linear mlp layer
+             column_info_embeddings_in_one_db = self.leakyrelu(self.column_info_linear_after_pooling(column_info_embeddings_in_one_db))
+
+             # table-column (tc) cross-attention
+             table_name_embeddings_in_one_db = self.table_column_cross_attention(
+                 table_name_embeddings_in_one_db,
+                 column_info_embeddings_in_one_db,
+                 column_number_in_each_table
+             )
+
+             # calculate table 0-1 logits
+             table_name_embeddings_in_one_db = self.table_name_cls_head_linear1(table_name_embeddings_in_one_db)
+             table_name_embeddings_in_one_db = self.dropout(self.leakyrelu(table_name_embeddings_in_one_db))
+             table_name_cls_logits = self.table_name_cls_head_linear2(table_name_embeddings_in_one_db)
+
+             # calculate column 0-1 logits
+             column_info_embeddings_in_one_db = self.column_info_cls_head_linear1(column_info_embeddings_in_one_db)
+             column_info_embeddings_in_one_db = self.dropout(self.leakyrelu(column_info_embeddings_in_one_db))
+             column_info_cls_logits = self.column_info_cls_head_linear2(column_info_embeddings_in_one_db)
+
+             batch_table_name_cls_logits.append(table_name_cls_logits)
+             batch_column_info_cls_logits.append(column_info_cls_logits)
+
+         return batch_table_name_cls_logits, batch_column_info_cls_logits
+
+     def forward(
+         self,
+         encoder_input_ids,
+         encoder_attention_mask,
+         batch_aligned_column_info_ids,
+         batch_aligned_table_name_ids,
+         batch_column_number_in_each_table,
+     ):
+         batch_table_name_cls_logits, batch_column_info_cls_logits \
+             = self.table_column_cls(
+                 encoder_input_ids,
+                 encoder_attention_mask,
+                 batch_aligned_column_info_ids,
+                 batch_aligned_table_name_ids,
+                 batch_column_number_in_each_table
+             )
+
+         return {
+             "batch_table_name_cls_logits" : batch_table_name_cls_logits,
+             "batch_column_info_cls_logits": batch_column_info_cls_logits
+         }
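A small, self-contained sketch of the bi-LSTM pooling used above for table and column name embeddings (made-up hidden size; assumes a PyTorch version that accepts unbatched LSTM inputs, as the code above does):

    import torch
    import torch.nn as nn

    hidden_size = 8                                 # stand-in for plm_hidden_size
    bilstm = nn.LSTM(input_size = hidden_size, hidden_size = hidden_size // 2,
                     num_layers = 2, bidirectional = True)
    token_embeddings = torch.randn(5, hidden_size)  # 5 sub-word embeddings of one name
    _, (hidden_state, _) = bilstm(token_embeddings)
    # concatenate the last layer's forward and backward final states back to hidden_size
    pooled = hidden_state[-2:, :].view(1, hidden_size)
    print(pooled.shape)                             # torch.Size([1, 8])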