xusenlin committed
Commit 6e00f90
Parent: 86a0650

Upload 9 files

added_tokens.json ADDED
@@ -0,0 +1,3 @@
+ {
+   "[UNK]": 39979
+ }
config.json CHANGED
@@ -1,19 +1,30 @@
  {
-   "attention_probs_dropout_prob": 0.1,
-   "hidden_act": "gelu",
-   "hidden_dropout_prob": 0.1,
-   "hidden_size": 768,
-   "initializer_range": 0.02,
-   "max_position_embeddings": 2048,
-   "num_attention_heads": 12,
-   "num_hidden_layers": 12,
-   "task_type_vocab_size": 3,
-   "type_vocab_size": 4,
-   "use_task_id": true,
-   "vocab_size": 40000,
-   "architectures": [
-     "UIE"
-   ],
-   "layer_norm_eps": 1e-12,
-   "intermediate_size": 3072
- }
+   "_name_or_path": "uie_base_pytorch",
+   "architectures": [
+     "UIEModel"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "auto_map": {
+     "AutoModel": "modeling_uie.UIEModel"
+   },
+   "classifier_dropout": null,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 2048,
+   "model_type": "ernie",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "position_embedding_type": "absolute",
+   "task_type_vocab_size": 3,
+   "torch_dtype": "float32",
+   "transformers_version": "4.39.1",
+   "type_vocab_size": 4,
+   "use_cache": true,
+   "use_task_id": true,
+   "vocab_size": 40000
+ }
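
The key addition is `auto_map`, which points `AutoModel` at the custom `modeling_uie.UIEModel` class shipped in this commit. A minimal loading sketch (the local path `uie_base_pytorch` is illustrative; `trust_remote_code=True` is what allows the custom class to be imported):

```python
from transformers import AutoModel, AutoTokenizer

path = "uie_base_pytorch"  # hypothetical local checkpoint directory
tokenizer = AutoTokenizer.from_pretrained(path)
# trust_remote_code=True lets transformers import modeling_uie.UIEModel,
# the class registered under "auto_map" in config.json.
model = AutoModel.from_pretrained(path, trust_remote_code=True)
```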
decode_utils.py ADDED
@@ -0,0 +1,574 @@
+ import logging
+ import math
+ import re
+ from typing import (
+     List,
+     Union,
+     Any,
+     Optional,
+ )
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from tqdm import tqdm
+ from transformers import PreTrainedTokenizer
+
+ logger = logging.getLogger(__name__)
+
+
+ def get_id_and_prob(spans, offset_map):
+     prompt_length = 0
+     for i in range(1, len(offset_map)):
+         if offset_map[i] != [0, 0]:
+             prompt_length += 1
+         else:
+             break
+
+     for i in range(1, prompt_length + 1):
+         offset_map[i][0] -= (prompt_length + 1)
+         offset_map[i][1] -= (prompt_length + 1)
+
+     sentence_id = []
+     prob = []
+     for start, end in spans:
+         prob.append(start[1] * end[1])
+         sentence_id.append(
+             (offset_map[start[0]][0], offset_map[end[0]][1]))
+     return sentence_id, prob
+
+
+ def get_span(start_ids, end_ids, with_prob=False):
+     """
+     Get the span set from the start and end position lists.
+     Args:
+         start_ids (List[int]/List[tuple]): The start index list.
+         end_ids (List[int]/List[tuple]): The end index list.
+         with_prob (bool): If True, each element of start_ids and end_ids is a tuple like (index, probability).
+     Returns:
+         set: The span set without overlapping; every id can only be used once.
+     """
+     if with_prob:
+         start_ids = sorted(start_ids, key=lambda x: x[0])
+         end_ids = sorted(end_ids, key=lambda x: x[0])
+     else:
+         start_ids = sorted(start_ids)
+         end_ids = sorted(end_ids)
+
+     start_pointer = 0
+     end_pointer = 0
+     len_start = len(start_ids)
+     len_end = len(end_ids)
+     couple_dict = {}
+
+     # Pair the start/end token ids of each span (nearest match, assuming spans do not overlap).
+     while start_pointer < len_start and end_pointer < len_end:
+         if with_prob:
+             start_id = start_ids[start_pointer][0]
+             end_id = end_ids[end_pointer][0]
+         else:
+             start_id = start_ids[start_pointer]
+             end_id = end_ids[end_pointer]
+
+         if start_id == end_id:
+             couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
+             start_pointer += 1
+             end_pointer += 1
+             continue
+
+         if start_id < end_id:
+             couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
+             start_pointer += 1
+             continue
+
+         if start_id > end_id:
+             end_pointer += 1
+             continue
+
+     result = [(couple_dict[end], end) for end in couple_dict]
+     result = set(result)
+     return result
+
+
+ def get_bool_ids_greater_than(probs, limit=0.5, return_prob=False):
+     """
+     Get indices along the last dimension of a probability array whose values exceed a threshold.
+     Args:
+         probs (List[List[float]]): The input probability arrays.
+         limit (float): The probability threshold.
+         return_prob (bool): Whether to return the probability as well.
+     Returns:
+         List[List[int]]: The indices along the last dimension that meet the condition.
+     """
+     probs = np.array(probs)
+     dim_len = len(probs.shape)
+     if dim_len > 1:
+         result = []
+         for p in probs:
+             result.append(get_bool_ids_greater_than(p, limit, return_prob))
+         return result
+     else:
+         result = []
+         for i, p in enumerate(probs):
+             if p > limit:
+                 if return_prob:
+                     result.append((i, p))
+                 else:
+                     result.append(i)
+         return result
+
+
+ def dbc2sbc(s):
+     # Normalize full-width characters to their half-width equivalents.
+     rs = ""
+     for char in s:
+         code = ord(char)
+         if code == 0x3000:
+             code = 0x0020
+         else:
+             code -= 0xfee0
+         if not (0x0021 <= code <= 0x7e):
+             rs += char
+             continue
+         rs += chr(code)
+     return rs
+
+
+ def cut_chinese_sent(para):
+     """
+     Cut Chinese text into sentences more precisely; see
+     "https://blog.csdn.net/blmoistawinde/article/details/82379256".
+     """
+     para = re.sub(r'([。!?\?])([^”’])', r'\1\n\2', para)
+     para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para)
+     para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para)
+     para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para)
+     para = para.rstrip()
+     return para.split("\n")
+
+
+ def auto_splitter(input_texts, max_text_len, split_sentence=False):
+     """
+     Split the raw texts automatically for model inference.
+     Args:
+         input_texts (List[str]): Input raw texts.
+         max_text_len (int): Cutting length.
+         split_sentence (bool): If True, sentence-level splitting is performed first.
+     Returns:
+         short_input_texts (List[str]): The short input texts for model inference.
+         input_mapping (dict): Mapping between raw texts and short input texts.
+     """
+     input_mapping = {}
+     short_input_texts = []
+     cnt_short = 0
+     for cnt_org, text in enumerate(input_texts):
+         sens = cut_chinese_sent(text) if split_sentence else [text]
+         for sen in sens:
+             lens = len(sen)
+             if lens <= max_text_len:
+                 short_input_texts.append(sen)
+                 if cnt_org in input_mapping:
+                     input_mapping[cnt_org].append(cnt_short)
+                 else:
+                     input_mapping[cnt_org] = [cnt_short]
+                 cnt_short += 1
+             else:
+                 temp_text_list = [sen[i: i + max_text_len] for i in range(0, lens, max_text_len)]
+                 short_input_texts.extend(temp_text_list)
+                 short_idx = cnt_short
+                 cnt_short += math.ceil(lens / max_text_len)
+                 temp_text_id = [short_idx + i for i in range(cnt_short - short_idx)]
+                 if cnt_org in input_mapping:
+                     input_mapping[cnt_org].extend(temp_text_id)
+                 else:
+                     input_mapping[cnt_org] = temp_text_id
+     return short_input_texts, input_mapping
+
+
+ class UIEDecoder(nn.Module):
+
+     keys_to_ignore_on_gpu = ["offset_mapping", "texts"]
+
+     @torch.inference_mode()
+     def predict(
+         self,
+         tokenizer: PreTrainedTokenizer,
+         texts: Union[List[str], str],
+         schema: Optional[Any] = None,
+         batch_size: int = 64,
+         max_length: int = 512,
+         split_sentence: bool = False,
+         position_prob: float = 0.5,
+         is_english: bool = False,
+         disable_tqdm: bool = True,
+     ) -> List[Any]:
+         self.eval()
+         self.tokenizer = tokenizer
+         self.is_english = is_english
+         if schema is not None:
+             self.set_schema(schema)
+
+         if isinstance(texts, str):
+             texts = [texts]
+         return self._multi_stage_predict(
+             texts, batch_size, max_length, split_sentence, position_prob, disable_tqdm
+         )
+
+     def set_schema(self, schema):
+         if isinstance(schema, (dict, str)):
+             schema = [schema]
+         self._schema_tree = self._build_tree(schema)
+
+     def _multi_stage_predict(
+         self,
+         texts: List[str],
+         batch_size: int = 64,
+         max_length: int = 512,
+         split_sentence: bool = False,
+         position_prob: float = 0.5,
+         disable_tqdm: bool = True,
+     ) -> List[Any]:
+         """ Traverse the schema tree and do multi-stage prediction. """
+         results = [{} for _ in range(len(texts))]
+         if len(texts) < 1 or self._schema_tree is None:
+             return results
+
+         schema_list = self._schema_tree.children[:]
+         while len(schema_list) > 0:
+             node = schema_list.pop(0)
+             examples = []
+             input_map = {}
+             cnt = 0
+             idx = 0
+             if not node.prefix:
+                 for data in texts:
+                     examples.append({"text": data, "prompt": dbc2sbc(node.name)})
+                     input_map[cnt] = [idx]
+                     idx += 1
+                     cnt += 1
+             else:
+                 for pre, data in zip(node.prefix, texts):
+                     if len(pre) == 0:
+                         input_map[cnt] = []
+                     else:
+                         for p in pre:
+                             if self.is_english:
+                                 if re.search(r'\[.*?\]$', node.name):
+                                     prompt_prefix = node.name[:node.name.find("[", 1)].strip()
+                                     cls_options = re.search(r'\[.*?\]$', node.name).group()
+                                     # e.g. "Sentiment classification of xxx [positive, negative]"
+                                     prompt = prompt_prefix + p + " " + cls_options
+                                 else:
+                                     prompt = node.name + p
+                             else:
+                                 prompt = p + node.name
+                             examples.append(
+                                 {
+                                     "text": data,
+                                     "prompt": dbc2sbc(prompt)
+                                 }
+                             )
+                         input_map[cnt] = [i + idx for i in range(len(pre))]
+                         idx += len(pre)
+                     cnt += 1
+
+             result_list = self._single_stage_predict(
+                 examples, batch_size, max_length, split_sentence, position_prob, disable_tqdm
+             ) if examples else []
+             if not node.parent_relations:
+                 relations = [[] for _ in range(len(texts))]
+                 for k, v in input_map.items():
+                     for idx in v:
+                         if len(result_list[idx]) == 0:
+                             continue
+                         if node.name not in results[k].keys():
+                             results[k][node.name] = result_list[idx]
+                         else:
+                             results[k][node.name].extend(result_list[idx])
+                     if node.name in results[k].keys():
+                         relations[k].extend(results[k][node.name])
+             else:
+                 relations = node.parent_relations
+                 for k, v in input_map.items():
+                     for i in range(len(v)):
+                         if len(result_list[v[i]]) == 0:
+                             continue
+                         if "relations" not in relations[k][i].keys():
+                             relations[k][i]["relations"] = {node.name: result_list[v[i]]}
+                         elif node.name not in relations[k][i]["relations"].keys():
+                             relations[k][i]["relations"][node.name] = result_list[v[i]]
+                         else:
+                             relations[k][i]["relations"][node.name].extend(result_list[v[i]])
+
+                 new_relations = [[] for _ in range(len(texts))]
+                 for i in range(len(relations)):
+                     for j in range(len(relations[i])):
+                         if "relations" in relations[i][j].keys() and node.name in relations[i][j]["relations"].keys():
+                             for k in range(len(relations[i][j]["relations"][node.name])):
+                                 new_relations[i].append(relations[i][j]["relations"][node.name][k])
+                 relations = new_relations
+
+             prefix = [[] for _ in range(len(texts))]
+             for k, v in input_map.items():
+                 for idx in v:
+                     for i in range(len(result_list[idx])):
+                         if self.is_english:
+                             prefix[k].append(" of " + result_list[idx][i]["text"])
+                         else:
+                             prefix[k].append(result_list[idx][i]["text"] + "的")
+
+             for child in node.children:
+                 child.prefix = prefix
+                 child.parent_relations = relations
+                 schema_list.append(child)
+
+         return results
+
+     def _convert_ids_to_results(self, examples, sentence_ids, probs):
+         """ Convert ids to raw text in a single stage. """
+         results = []
+         for example, sentence_id, prob in zip(examples, sentence_ids, probs):
+             if len(sentence_id) == 0:
+                 results.append([])
+                 continue
+             result_list = []
+             text = example["text"]
+             prompt = example["prompt"]
+             for i in range(len(sentence_id)):
+                 start, end = sentence_id[i]
+                 if start < 0 and end >= 0:
+                     continue
+                 if end < 0:
+                     start += len(prompt) + 1
+                     end += len(prompt) + 1
+                     result = {"text": prompt[start: end], "probability": prob[i]}
+                 else:
+                     result = {"text": text[start: end], "start": start, "end": end, "probability": prob[i]}
+                 result_list.append(result)
+             results.append(result_list)
+         return results
+
+     def _auto_splitter(self, input_texts, max_text_len, split_sentence=False):
+         """
+         Split the raw texts automatically for model inference.
+         Args:
+             input_texts (List[str]): Input raw texts.
+             max_text_len (int): Cutting length.
+             split_sentence (bool): If True, sentence-level splitting is performed first.
+         Returns:
+             short_input_texts (List[str]): The short input texts for model inference.
+             input_mapping (dict): Mapping between raw texts and short input texts.
+         """
+         input_mapping = {}
+         short_input_texts = []
+         cnt_short = 0
+         for cnt_org, text in enumerate(input_texts):
+             sens = cut_chinese_sent(text) if split_sentence else [text]
+             for sen in sens:
+                 lens = len(sen)
+                 if lens <= max_text_len:
+                     short_input_texts.append(sen)
+                     if cnt_org in input_mapping:
+                         input_mapping[cnt_org].append(cnt_short)
+                     else:
+                         input_mapping[cnt_org] = [cnt_short]
+                     cnt_short += 1
+                 else:
+                     temp_text_list = [sen[i: i + max_text_len] for i in range(0, lens, max_text_len)]
+                     short_input_texts.extend(temp_text_list)
+                     short_idx = cnt_short
+                     cnt_short += math.ceil(lens / max_text_len)
+                     temp_text_id = [short_idx + i for i in range(cnt_short - short_idx)]
+                     if cnt_org in input_mapping:
+                         input_mapping[cnt_org].extend(temp_text_id)
+                     else:
+                         input_mapping[cnt_org] = temp_text_id
+         return short_input_texts, input_mapping
+
+     def _single_stage_predict(
+         self,
+         inputs: List[dict],
+         batch_size: int = 64,
+         max_length: int = 512,
+         split_sentence: bool = False,
+         position_prob: float = 0.5,
+         disable_tqdm: bool = True,
+     ):
+         input_texts = []
+         prompts = []
+         for i in range(len(inputs)):
+             input_texts.append(inputs[i]["text"])
+             prompts.append(inputs[i]["prompt"])
+         # The max predict length should exclude the longest prompt and the special tokens.
+         max_predict_len = max_length - max(len(p) for p in prompts) - 3
+
+         short_input_texts, input_mapping = self._auto_splitter(
+             input_texts, max_predict_len, split_sentence=split_sentence
+         )
+
+         short_texts_prompts = []
+         for k, v in input_mapping.items():
+             short_texts_prompts.extend([prompts[k] for _ in range(len(v))])
+         short_inputs = [
+             {
+                 "text": short_input_texts[i],
+                 "prompt": short_texts_prompts[i]
+             }
+             for i in range(len(short_input_texts))
+         ]
+
+         encoded_inputs = self.tokenizer(
+             text=short_texts_prompts,
+             text_pair=short_input_texts,
+             stride=2,
+             truncation=True,
+             max_length=max_length,
+             padding="longest",
+             add_special_tokens=True,
+             return_offsets_mapping=True,
+             return_tensors="np")
+         offset_maps = encoded_inputs["offset_mapping"]
+
+         start_prob_concat, end_prob_concat = [], []
+         if disable_tqdm:
+             batch_iterator = range(0, len(short_input_texts), batch_size)
+         else:
+             batch_iterator = tqdm(range(0, len(short_input_texts), batch_size), desc="Predicting", unit="batch")
+         for batch_start in batch_iterator:
+             batch = {
+                 key: np.array(value[batch_start: batch_start + batch_size], dtype="int64")
+                 for key, value in encoded_inputs.items() if key not in self.keys_to_ignore_on_gpu
+             }
+
+             for k, v in batch.items():
+                 # Build the tensor directly on the model's device.
+                 batch[k] = torch.as_tensor(v, dtype=torch.long, device=self.device)
+
+             outputs = self(**batch)
+             start_prob, end_prob = outputs[0], outputs[1]
+             if self.device != torch.device("cpu"):
+                 start_prob, end_prob = start_prob.cpu(), end_prob.cpu()
+             start_prob_concat.append(start_prob.detach().numpy())
+             end_prob_concat.append(end_prob.detach().numpy())
+
+         start_prob_concat = np.concatenate(start_prob_concat)
+         end_prob_concat = np.concatenate(end_prob_concat)
+
+         start_ids_list = get_bool_ids_greater_than(start_prob_concat, limit=position_prob, return_prob=True)
+         end_ids_list = get_bool_ids_greater_than(end_prob_concat, limit=position_prob, return_prob=True)
+
+         input_ids = encoded_inputs['input_ids'].tolist()
+         sentence_ids, probs = [], []
+         for start_ids, end_ids, ids, offset_map in zip(start_ids_list, end_ids_list, input_ids, offset_maps):
+             span_list = get_span(start_ids, end_ids, with_prob=True)
+             sentence_id, prob = get_id_and_prob(span_list, offset_map.tolist())
+             sentence_ids.append(sentence_id)
+             probs.append(prob)
+
+         results = self._convert_ids_to_results(short_inputs, sentence_ids, probs)
+         results = self._auto_joiner(results, short_input_texts, input_mapping)
+         return results
+
+     def _auto_joiner(self, short_results, short_inputs, input_mapping):
+         concat_results = []
+         is_cls_task = False
+         for short_result in short_results:
+             if not short_result:
+                 continue
+             elif 'start' not in short_result[0].keys() and 'end' not in short_result[0].keys():
+                 is_cls_task = True
+                 break
+             else:
+                 break
+         for k, vs in input_mapping.items():
+             single_results = []
+             if is_cls_task:
+                 cls_options = {}
+                 for v in vs:
+                     if len(short_results[v]) == 0:
+                         continue
+                     if short_results[v][0]['text'] in cls_options:
+                         cls_options[short_results[v][0]["text"]][0] += 1
+                         cls_options[short_results[v][0]["text"]][1] += short_results[v][0]["probability"]
+                     else:
+                         cls_options[short_results[v][0]["text"]] = [1, short_results[v][0]["probability"]]
+                 if cls_options:
+                     cls_res, cls_info = max(cls_options.items(), key=lambda x: x[1])
+                     concat_results.append(
+                         [
+                             {"text": cls_res, "probability": cls_info[1] / cls_info[0]}
+                         ]
+                     )
+                 else:
+                     concat_results.append([])
+             else:
+                 offset = 0
+                 for v in vs:
+                     if v == 0:
+                         single_results = short_results[v]
+                         offset += len(short_inputs[v])
+                     else:
+                         for i in range(len(short_results[v])):
+                             if 'start' not in short_results[v][i] or 'end' not in short_results[v][i]:
+                                 continue
+                             short_results[v][i]["start"] += offset
+                             short_results[v][i]["end"] += offset
+                         offset += len(short_inputs[v])
+                         single_results.extend(short_results[v])
+                 concat_results.append(single_results)
+         return concat_results
+
+     @classmethod
+     def _build_tree(cls, schema, name='root'):
+         """
+         Build the schema tree.
+         """
+         schema_tree = SchemaTree(name)
+         for s in schema:
+             if isinstance(s, str):
+                 schema_tree.add_child(SchemaTree(s))
+             elif isinstance(s, dict):
+                 for k, v in s.items():
+                     if isinstance(v, str):
+                         child = [v]
+                     elif isinstance(v, list):
+                         child = v
+                     else:
+                         raise TypeError(
+                             f"Invalid schema, the value of each key:value pair should be a list or a string, "
+                             f"but {type(v)} was received.")
+                     schema_tree.add_child(cls._build_tree(child, name=k))
+             else:
+                 raise TypeError(f"Invalid schema, element should be a string or a dict, but {type(s)} was received.")
+
+         return schema_tree
+
+
+ class SchemaTree(object):
+     """
+     Implementation of SchemaTree.
+     """
+
+     def __init__(self, name='root', children=None):
+         self.name = name
+         self.children = []
+         self.prefix = None
+         self.parent_relations = None
+         if children is not None:
+             for child in children:
+                 self.add_child(child)
+
+     def __repr__(self):
+         return self.name
+
+     def add_child(self, node):
+         assert isinstance(
+             node, SchemaTree
+         ), "The children of a node should be instances of SchemaTree."
+         self.children.append(node)
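
`UIEDecoder` is mixed into the model class in `modeling_uie.py`, so the schema-driven pipeline above is reachable through `predict`. A usage sketch, assuming `model` and `tokenizer` were loaded as in the config example; the schema and input sentence are illustrative:

```python
# "时间" (time) and "选手" (athlete) are top-level prompts; the nested
# "所属球队" (team) is asked as a relation prompt built from each extracted
# 选手 span plus "的", exactly as _multi_stage_predict constructs prefixes.
schema = ["时间", {"选手": ["所属球队"]}]
results = model.predict(
    tokenizer,
    "2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌夺冠!",
    schema=schema,
    position_prob=0.5,  # keep positions whose start/end probability exceeds 0.5
)
print(results)  # one dict per input text, e.g. {"时间": [...], "选手": [...]}
```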
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3479e5df1444559754f8d5369270d5a15cf40a9e54bbcb5d06ee800888b68fe7
+ size 471809912
modeling_uie.py ADDED
@@ -0,0 +1,162 @@
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ from transformers import ErnieModel, ErniePreTrainedModel, PretrainedConfig
+ from transformers.file_utils import ModelOutput
+
+ from .decode_utils import UIEDecoder
+
+
+ @dataclass
+ class UIEModelOutput(ModelOutput):
+     """
+     Output class for outputs of UIE.
+     loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions` and `end_positions` are provided):
+         Total span extraction loss, the sum of the cross-entropy losses for the start and end positions.
+     start_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+         Span-start scores (after Sigmoid).
+     end_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+         Span-end scores (after Sigmoid).
+     hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+         Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+         one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+         Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+     attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+         Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+         sequence_length)`.
+         Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+         heads.
+     """
+     loss: Optional[torch.FloatTensor] = None
+     start_prob: torch.FloatTensor = None
+     end_prob: torch.FloatTensor = None
+     start_positions: torch.FloatTensor = None
+     end_positions: torch.FloatTensor = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+ class UIEModel(ErniePreTrainedModel, UIEDecoder):
+     """
+     UIE model based on the ERNIE model.
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+     heads, etc.)
+     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
+     usage and behavior.
+     Parameters:
+         config ([`PretrainedConfig`]): Model configuration class with all the parameters of the model.
+             Initializing with a config file does not load the weights associated with the model, only the
+             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+     """
+
+     def __init__(self, config: PretrainedConfig):
+         super(UIEModel, self).__init__(config)
+         self.encoder = ErnieModel(config)
+         self.config = config
+         hidden_size = self.config.hidden_size
+
+         self.linear_start = nn.Linear(hidden_size, 1)
+         self.linear_end = nn.Linear(hidden_size, 1)
+         self.sigmoid = nn.Sigmoid()
+
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         start_positions: Optional[torch.Tensor] = None,
+         end_positions: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+     ) -> UIEModelOutput:
+         """
+         Args:
+             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                 Indices of input sequence tokens in the vocabulary.
+                 Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                 [`PreTrainedTokenizer.__call__`] for details.
+                 [What are input IDs?](../glossary#input-ids)
+             attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+                 - 1 for tokens that are **not masked**,
+                 - 0 for tokens that are **masked**.
+                 [What are attention masks?](../glossary#attention-mask)
+             token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Segment token indices to indicate first and second portions of the inputs. Indices are selected in
+                 `[0, 1]`:
+                 - 0 corresponds to a *sentence A* token,
+                 - 1 corresponds to a *sentence B* token.
+                 [What are token type IDs?](../glossary#token-type-ids)
+             position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Indices of positions of each input sequence token in the position embeddings. Selected in the range
+                 `[0, config.max_position_embeddings - 1]`.
+                 [What are position IDs?](../glossary#position-ids)
+             head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+                 Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+                 - 1 indicates the head is **not masked**,
+                 - 0 indicates the head is **masked**.
+             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
+                 representation. This is useful if you want more control over how to convert `input_ids` indices
+                 into associated vectors than the model's internal embedding lookup matrix.
+             start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                 Labels for the position (index) of the start of the labelled span, used to compute the token
+                 classification loss. Positions are clamped to the length of the sequence (`sequence_length`).
+                 Positions outside of the sequence are not taken into account for computing the loss.
+             end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                 Labels for the position (index) of the end of the labelled span, used to compute the token
+                 classification loss. Positions are clamped to the length of the sequence (`sequence_length`).
+                 Positions outside of the sequence are not taken into account for computing the loss.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned
+                 tensors for more detail.
+         """
+         outputs = self.encoder(
+             input_ids=input_ids,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+         )
+         sequence_output = outputs[0]
+
+         start_logits = self.linear_start(sequence_output)
+         start_logits = torch.squeeze(start_logits, -1)
+         start_prob = self.sigmoid(start_logits)
+
+         end_logits = self.linear_end(sequence_output)
+         end_logits = torch.squeeze(end_logits, -1)
+         end_prob = self.sigmoid(end_logits)
+
+         total_loss = None
+         if start_positions is not None and end_positions is not None:
+             loss_fct = nn.BCELoss()
+             start_loss = loss_fct(start_prob, start_positions)
+             end_loss = loss_fct(end_prob, end_positions)
+             total_loss = (start_loss + end_loss) / 2.0
+
+         return UIEModelOutput(
+             loss=total_loss,
+             start_prob=start_prob,
+             end_prob=end_prob,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
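
The extraction head is just two linear probes plus a sigmoid over the encoder's sequence output, so calling the model directly yields per-token start/end probabilities. A minimal forward-pass sketch under the same loading assumptions as above:

```python
import torch

# Encode a (prompt, text) pair the same way _single_stage_predict does.
inputs = tokenizer(text=["时间"], text_pair=["2月8日决赛举行"], return_tensors="pt")
with torch.inference_mode():
    out = model(
        input_ids=inputs["input_ids"],
        token_type_ids=inputs["token_type_ids"],
        attention_mask=inputs["attention_mask"],
    )
# Both tensors have shape (batch_size, sequence_length), with values in (0, 1).
print(out.start_prob.shape, out.end_prob.shape)
```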
special_tokens_map.json CHANGED
@@ -1 +1,7 @@
- {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json CHANGED
@@ -1 +1,57 @@
- {"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenizer_class": "BertTokenizer"}
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "39979": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_basic_tokenize": true,
+   "do_lower_case": true,
+   "mask_token": "[MASK]",
+   "model_max_length": 1000000000000000019884624838656,
+   "never_split": null,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "unk_token": "[UNK]"
+ }
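
The new tokenizer config pins `BertTokenizer` with lowercasing and Chinese-character splitting, and registers `[UNK]` at id 39979, matching `added_tokens.json`. A quick sanity check, reusing the hypothetical local path from the earlier examples:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("uie_base_pytorch")
print(tok.convert_tokens_to_ids("[UNK]"))  # expected: 39979
```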
vocab.txt CHANGED
@@ -12082,7 +12082,6 @@ _
 
 
 
- $
  {
  }
 
@@ -18003,7 +18002,7 @@ $
  π
 
 
- /$
+ $
 
 
  °