han liu committed on
Commit
ff78ef7
1 Parent(s): ca2a245
app.py ADDED
@@ -0,0 +1,182 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ ## @Author: liuhan(liuhan@idea.edu.cn)
+ ## @Created: 2022/12/28 11:24:43
+ # coding=utf-8
+ # Copyright 2021 The IDEA Authors. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import List, Dict
+ from logging import basicConfig
+ import json
+ import os
+ import numpy as np
+ from transformers import AutoTokenizer
+ import argparse
+ import copy
+ import streamlit as st
+ import time
+
+
+
+ from models import BagualuIEModel, BagualuIEExtractModel
+
+
+ class BagualuIEPipelines:
+     def __init__(self, args: argparse.Namespace) -> None:
+         self.args = args
+         # load model
+         self.model = BagualuIEModel.from_pretrained(args.pretrained_model_root)
+
+
+         # get tokenizer
+         added_token = [f"[unused{i + 1}]" for i in range(99)]
+         self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_root,
+                                                        additional_special_tokens=added_token)
+
+     def predict(self, test_data: List[dict], cuda: bool = True) -> List[dict]:
+         """ predict
+
+         Args:
+             test_data (List[dict]): test data
+             cuda (bool, optional): cuda. Defaults to True.
+
+         Returns:
+             List[dict]: result
+         """
+         result = []
+         if cuda:
+             self.model = self.model.cuda()
+         self.model.eval()
+
+         batch_size = self.args.batch_size
+         extract_model = BagualuIEExtractModel(self.tokenizer, self.args)
+
+         for i in range(0, len(test_data), batch_size):
+             batch_data = test_data[i: i + batch_size]
+             batch_result = extract_model.extract(batch_data, self.model, cuda)
+             result.extend(batch_result)
+         return result
+
+
+ @st.experimental_memo()
+ def load_model(model_path):
+     parser = argparse.ArgumentParser()
+
+     # pipeline arguments
+     group_parser = parser.add_argument_group("pipelines args")
+     group_parser.add_argument("--pretrained_model_root", default="", type=str)
+     group_parser.add_argument("--load_checkpoints_path", default="", type=str)
+
+     group_parser.add_argument("--threshold_ent", default=0.3, type=float)
+     group_parser.add_argument("--threshold_rel", default=0.3, type=float)
+     group_parser.add_argument("--entity_multi_label", action="store_true", default=True)
+     group_parser.add_argument("--relation_multi_label", action="store_true", default=True)
+
+
+     # data model arguments
+     group_parser = parser.add_argument_group("data_model")
+     group_parser.add_argument("--batch_size", default=4, type=int)
+     group_parser.add_argument("--max_length", default=512, type=int)
+     # pytorch_lightning.Trainer arguments
+     args = parser.parse_args()
+     args.pretrained_model_root = model_path
+
+     model = BagualuIEPipelines(args)
+     return model
+
+ def main():
+
+     # model = load_model('/cognitive_comp/liuhan/pretrained/uniex_macbert_base_v7.1/')
+     model = load_model('IDEA-CCNL/Erlangshen-BERT-120M-IE-Chinese')
+
+     #
+
+     st.subheader("Erlangshen-BERT-120M-IE-Chinese Zero-shot 体验")
+
+
+
+     st.markdown("""
+     Erlangshen-BERT-120M-IE-Chinese是以110M参数的base模型为底座,基于大规模信息抽取数据进行预训练后的模型,
+     通过统一的抽取架构设计,可支持few-shot、zero-shot场景下的实体识别、关系三元组抽取任务。
+     更多信息见https://github.com/IDEA-CCNL/GTS-Engine/tree/main
+     模型效果见https://huggingface.co/IDEA-CCNL/Erlangshen-BERT-120M-IE-Chinese
+     """)
+
+     st.info("Please input the following information to experience Bagualu-IE「请输入以下信息开始体验 Bagualu-IE...」")
+     model_type = st.selectbox('Select task type「选择任务类型」',['Named Entity Recognition「命名实体识别」','Relation Extraction「关系抽取」'])
+     if '命名实体识别' in model_type:
+         example = st.selectbox('Example', ['Example: 人物信息', 'Example: 财经新闻'])
+     else:
+         example = st.selectbox('Example', ['Example: 雇佣关系', 'Example: 影视关系'])
+     form = st.form("参数设置")
+     if '命名实体识别' in model_type:
+         if '人物信息' in example:
+             sentences = form.text_area(
+                 "Please input the context「请输入句子」",
+                 "姚明,男,汉族,无党派人士,前中国职业篮球运动员。")
+             choice = form.text_input("Please input the choice「请输入抽取实体名称,用中文;分割」", "姓名;性别;民族;运动项目;政治面貌")
+         else:
+             sentences = form.text_area(
+                 "Please input the context「请输入句子」",
+                 "寒流吹响华尔街,摩根士丹利、高盛、瑞信三大银行裁员合计超过8千人")
+             choice = form.text_input("Please input the choice「请输入抽取实体名称,用中文;分割」", "裁员单位;裁员人数")
+
+     else:
+         if '雇佣关系' in example:
+             sentences = form.text_area(
+                 "Please input the context「请输入句子」",
+                 "东阳市企业家协会六届一次会员大会上,横店集团董事长、总裁徐永安当选为东阳市企业家协会会长。")
+             choice = form.text_input("Please input the choice「请输入抽取关系名称,用中文;分割(头实体类型|关系|尾实体类型)」", "企业|董事长|人物")
+         else:
+             sentences = form.text_area(
+                 "Please input the context「请输入句子」",
+                 "《傲骨贤妻第六季》是一套美国法律剧情电视连续剧,2014年9月29日在CBS上首播。")
+             choice = form.text_input("Please input the choice「请输入抽取关系名称,用中文;分割(头实体类型|关系|尾实体类型)」", "影视作品|上映时间|时间")
+
+     form.form_submit_button("Submit「点击一下,开始预测!」")
+
+
+     if '命名实体识别' in model_type:
+         data = [{"task": '实体识别',
+                  "text": sentences,
+                  "entity_list": [],
+                  "choice": choice.split(';'),
+                  }]
+     else:
+         choice = [one.split('|') for one in choice.split(';')]
+         data = [{"task": '关系抽取',
+                  "text": sentences,
+                  "entity_list": [],
+                  "choice": choice,
+                  }]
+
+
+     start = time.time()
+     # is_cuda= True if torch.cuda.is_available() else False
+     # result = model.predict(data, cuda=is_cuda)
+
+     # st.success(f"Prediction is successful, consumes {str(time.time()-start)} seconds")
+     # st.json(result[0])
+
+     rs = model.predict(data, False)
+     st.success(f"Prediction is successful, consumes {str(time.time() - start)} seconds")
+     st.json(rs[0])
+
+
+
+
+
+ if __name__ == "__main__":
+     main()
+
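For reference, predict() consumes and returns plain dicts. Below is a minimal standalone sketch of the I/O format that main() builds (not part of the commit; the score value is illustrative, and the output keys follow models/extract_model.py and dataloaders/item_decoder.py):

# Illustrative sketch only: the NER input item built in main() and the result shape.
data = [{
    "task": "实体识别",                # or "关系抽取"
    "text": "姚明,男,汉族,无党派人士,前中国职业篮球运动员。",
    "entity_list": [],
    "choice": ["姓名", "性别", "民族"],  # for 关系抽取: [["企业", "董事长", "人物"], ...]
}]
# result = model.predict(data, cuda=False)
# result[0]["entity_list"] -> [{"entity_text": "姚明", "entity_type": "姓名",
#                               "score": 0.97, "entity_index": [0, 2]}, ...]
# result[0]["spo_list"]    -> [] for NER; for 关系抽取 each item holds
#                             {"predicate", "score", "subject", "object"}.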
dataloaders/__init__.py ADDED
File without changes
dataloaders/dataset_utils.py ADDED
@@ -0,0 +1,57 @@
+ # coding=utf-8
+ # Copyright 2021 The IDEA Authors. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import collections
+ from typing import List, Dict, Tuple
+
+
+ def get_choice(spo_choice: list) -> tuple:
+     """ Extract the relations and entity types from the relation schema
+
+     Args:
+         spo_choice (list): relation schema
+
+     Returns:
+         tuple:
+             choice_ent (list)
+             choice_rel (list)
+             choice_head (list)
+             choice_tail (list)
+             entity2rel (dict)
+     """
+     choice_head = []
+     choice_tail = []
+     choice_ent = []
+     choice_rel = []
+     entity2rel = collections.defaultdict(list)  # "subject|object" -> [relation]
+
+     for head, rel, tail in spo_choice:
+
+         if head not in choice_head:
+             choice_head.append(head)
+         if tail not in choice_tail:
+             choice_tail.append(tail)
+
+         if head not in choice_ent:
+             choice_ent.append(head)
+         if tail not in choice_ent:
+             choice_ent.append(tail)
+
+         if rel not in choice_rel:
+             choice_rel.append(rel)
+
+         entity2rel[head, tail].append(rel)
+
+     return choice_ent, choice_rel, choice_head, choice_tail, entity2rel
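As a quick illustration of how this flattening behaves, here is a standalone sketch (not part of the commit; the schema literals reuse the relation examples from app.py above):

# Standalone sketch of get_choice on an app.py-style relation schema.
from dataloaders.dataset_utils import get_choice

schema = [["企业", "董事长", "人物"], ["影视作品", "上映时间", "时间"]]
ents, rels, heads, tails, e2r = get_choice(schema)
print(ents)       # ['企业', '人物', '影视作品', '时间']
print(rels)       # ['董事长', '上映时间']
print(dict(e2r))  # {('企业', '人物'): ['董事长'], ('影视作品', '时间'): ['上映时间']}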
dataloaders/item_decoder.py ADDED
@@ -0,0 +1,320 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The IDEA Authors. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # from collections import defaultdict
17
+ from typing import List, Tuple, Dict
18
+ import argparse
19
+ import numpy as np
20
+ from transformers import PreTrainedTokenizer
21
+
22
+ from .item_encoder import entity_based_tokenize, get_entity_indices
23
+ from .dataset_utils import get_choice
24
+
25
+
26
+ class ItemDecoder(object):
27
+ """ Decoder
28
+
29
+ Args:
30
+ tokenizer (PreTrainedTokenizer): tokenizer
31
+ args (TrainingArgumentsIEStd): arguments
32
+ """
33
+ def __init__(self,
34
+ tokenizer: PreTrainedTokenizer,
35
+ args: argparse.Namespace) -> None:
36
+ self.tokenizer = tokenizer
37
+ self.max_length = args.max_length
38
+ self.threshold_entity = args.threshold_ent
39
+ self.threshold_rel = args.threshold_rel
40
+ self.entity_multi_label = args.entity_multi_label
41
+ self.relation_multi_label = args.relation_multi_label
42
+
43
+ def extract_entity_index(self,
44
+ entity_logits: np.ndarray,
45
+ ) -> List[Tuple[int, int]]:
46
+ """ extract entity index
47
+
48
+ Args:
49
+ entity_logits (np.ndarray): entity_logits
50
+
51
+ Returns:
52
+ List[Tuple[int, int]]: result
53
+ """
54
+
55
+ l, _, d = entity_logits.shape
56
+ result = []
57
+ for i in range(l):
58
+ for j in range(i, l):
59
+ if self.entity_multi_label:
60
+ for k in range(d):
61
+ entity_score = float(entity_logits[i, j, k])
62
+ if entity_score > self.threshold_entity:
63
+ result.append((i, j, k, entity_score))
64
+
65
+ else:
66
+ k = np.argmax(entity_logits[i, j])
67
+ entity_score = float(entity_logits[i, j, k])
68
+ if entity_score > self.threshold_entity:
69
+ result.append((i, j, k, entity_score))
70
+
71
+ return result
72
+
73
+ @staticmethod
74
+ def extract_entity(text: str,
75
+ entity_idx: List[int],
76
+ entity_type: str,
77
+ entity_score: float,
78
+ text_start_id: int,
79
+ offset_mapping: List[List[int]]) -> dict:
80
+ """ extract entity
81
+
82
+ Args:
83
+ text (str): text
84
+ entity_idx (List[int]): entity indices
85
+ entity_type (str): entity type
86
+ entity_score (float): entity score
87
+ text_start_id (int): text_start_id
88
+ offset_mapping (List[List[int]]): offset mapping
89
+
90
+ Returns:
91
+ dict: entity
92
+ """
93
+ entity_start, entity_end = entity_idx[0] - text_start_id, entity_idx[1] - text_start_id
94
+
95
+ start_split = offset_mapping[entity_start] if 0 <= entity_start < len(offset_mapping) else []
96
+ end_split = offset_mapping[entity_end] if 0 <= entity_end < len(offset_mapping) else []
97
+
98
+ if not start_split or not end_split:
99
+ return None
100
+
101
+ start_idx, end_idx = start_split[0], end_split[-1]
102
+ entity_text = text[start_idx: end_idx]
103
+
104
+ if not entity_text:
105
+ return None
106
+
107
+ entity = {
108
+ "entity_text": entity_text,
109
+ "entity_type": entity_type,
110
+ "score": entity_score,
111
+ "entity_index": [start_idx, end_idx]
112
+ }
113
+
114
+ return entity
115
+
116
+ def decode_ner(self,
117
+ text: str,
118
+ choice: List[str],
119
+ sample_span_logits: np.ndarray,
120
+ offset_mapping: List[List[int]]
121
+ ) -> List[dict]:
122
+ """ NER decode
123
+
124
+ Args:
125
+ text (str): text
126
+ choice (List[str]): choice
127
+ sample_span_logits (np.ndarray): sample span_logits
128
+ offset_mapping (List[List[int]]): offset mapping
129
+
130
+
131
+ Returns:
132
+ List[dict]: decoded entity list
133
+ """
134
+ entity_list = []
135
+
136
+ entity_idx_list = self.extract_entity_index(sample_span_logits)
137
+
138
+ for entity_start, entity_end, entity_type_idx, entity_score in entity_idx_list:
139
+
140
+ entity = self.extract_entity(text,
141
+ [entity_start, entity_end],
142
+ choice[entity_type_idx],
143
+ entity_score,
144
+ text_start_id=1,
145
+ offset_mapping=offset_mapping)
146
+
147
+ if entity is None:
148
+ continue
149
+
150
+ if entity not in entity_list:
151
+ entity_list.append(entity)
152
+
153
+ return entity_list
154
+
155
+ def decode_spo(self,
156
+ text: str,
157
+ choice: List[List[str]],
158
+ sample_span_logits: np.ndarray,
159
+ offset_mapping: List[List[int]]) -> tuple:
160
+ """ SPO decode
161
+
162
+ Args:
163
+ text (str): text
164
+ choice (List[List[str]]): choice
165
+ sample_span_logits (np.ndarray): sample span_logits
166
+ offset_mapping (List[List[int]): offset mapping
167
+
168
+ Returns:
169
+ List[dict]: decoded spo list
170
+ List[dict]: decoded entity list
171
+ """
172
+ spo_list = []
173
+ entity_list = []
174
+
175
+ choice_ent, choice_rel, choice_head, choice_tail, entity2rel = get_choice(choice)
176
+
177
+ entity_logits = sample_span_logits[:, :, : len(choice_ent)] # (seq_len, seq_len, num_entity)
178
+ relation_logits = sample_span_logits[:, :, len(choice_ent): ] # (seq_len, seq_len, num_relation)
179
+
180
+ entity_idx_list = self.extract_entity_index(entity_logits)
181
+
182
+ head_list = []
183
+ tail_list = []
184
+ for entity_start, entity_end, entity_type_idx, entity_score in entity_idx_list:
185
+
186
+ entity_type = choice_ent[entity_type_idx]
187
+
188
+ entity = self.extract_entity(text,
189
+ [entity_start, entity_end],
190
+ entity_type,
191
+ entity_score,
192
+ text_start_id=1,
193
+ offset_mapping=offset_mapping)
194
+
195
+ if entity is None:
196
+ continue
197
+
198
+ if entity_type in choice_head:
199
+ head_list.append((entity_start, entity_end, entity_type, entity))
200
+ if entity_type in choice_tail:
201
+ tail_list.append((entity_start, entity_end, entity_type, entity))
202
+
203
+ for head_start, head_end, subject_type, subject_dict in head_list:
204
+ for tail_start, tail_end, object_type, object_dict in tail_list:
205
+
206
+ if subject_dict == object_dict:
207
+ continue
208
+
209
+ if (subject_type, object_type) not in entity2rel.keys():
210
+ continue
211
+
212
+ relation_candidates = entity2rel[subject_type, object_type]
213
+ rel_idx = [choice_rel.index(r) for r in relation_candidates]
214
+
215
+ so_rel_logits = relation_logits[:, :, rel_idx]
216
+
217
+ if self.relation_multi_label:
218
+ for idx, predicate in enumerate(relation_candidates):
219
+ rel_score = so_rel_logits[head_start, tail_start, idx] + \
220
+ so_rel_logits[head_end, tail_end, idx]
221
+ predicate_score = float(rel_score / 2)
222
+
223
+ if predicate_score <= self.threshold_rel:
224
+ continue
225
+
226
+ if subject_dict not in entity_list:
227
+ entity_list.append(subject_dict)
228
+ if object_dict not in entity_list:
229
+ entity_list.append(object_dict)
230
+
231
+ spo = {
232
+ "predicate": predicate,
233
+ "score": predicate_score,
234
+ "subject": subject_dict,
235
+ "object": object_dict,
236
+ }
237
+
238
+ if spo not in spo_list:
239
+ spo_list.append(spo)
240
+
241
+ else:
242
+
243
+ hh_idx = np.argmax(so_rel_logits[head_start, head_end])
244
+ tt_idx = np.argmax(so_rel_logits[tail_start, tail_end])
245
+ hh_score = so_rel_logits[head_start, tail_start, hh_idx] + so_rel_logits[head_end, tail_end, hh_idx]
246
+ tt_score = so_rel_logits[head_start, tail_start, tt_idx] + so_rel_logits[head_end, tail_end, tt_idx]
247
+
248
+ predicate = relation_candidates[hh_idx] if hh_score > tt_score else relation_candidates[tt_idx]
249
+
250
+ predicate_score = float(max(hh_score, tt_score) / 2)
251
+
252
+ if predicate_score <= self.threshold_rel:
253
+ continue
254
+
255
+ if subject_dict not in entity_list:
256
+ entity_list.append(subject_dict)
257
+ if object_dict not in entity_list:
258
+ entity_list.append(object_dict)
259
+
260
+ spo = {
261
+ "predicate": predicate,
262
+ "score": predicate_score,
263
+ "subject": subject_dict,
264
+ "object": object_dict,
265
+ }
266
+
267
+ if spo not in spo_list:
268
+ spo_list.append(spo)
269
+
270
+ return spo_list, entity_list
271
+
272
+ def decode(self,
273
+ item: Dict,
274
+ span_logits: np.ndarray,
275
+ label_mask: np.ndarray,
276
+ ):
277
+ """ decode
278
+
279
+ Args:
280
+ task (str): task name
281
+ choice (list): choice
282
+ text (str): text
283
+ span_logits (np.ndarray): sample span_logits
284
+ label_mask (np.ndarray): label_mask
285
+
286
+ Raises:
287
+ NotImplementedError: raised if task name is not supported
288
+
289
+ Returns:
290
+ List[dict]: decoded entity list
291
+ List[dict]: decoded spo list
292
+ """
293
+ task, choice, text = item["task"], item["choice"], item["text"]
294
+ entity_indices = get_entity_indices(item.get("entity_list", []), item.get("spo_list", []))
295
+ _, offset_mapping = entity_based_tokenize(text, self.tokenizer, entity_indices,
296
+ return_offsets_mapping=True)
297
+
298
+ assert span_logits.shape == label_mask.shape
299
+
300
+ span_logits = span_logits + (label_mask - 1) * 100000
301
+
302
+ spo_list = []
303
+ entity_list = []
304
+
305
+ if task in {"实体识别", "抽取任务"}:
306
+ entity_list = self.decode_ner(text,
307
+ choice,
308
+ span_logits,
309
+ offset_mapping)
310
+
311
+ elif task in {"关系抽取"}:
312
+ spo_list, entity_list = self.decode_spo(text,
313
+ choice,
314
+ span_logits,
315
+ offset_mapping)
316
+
317
+ else:
318
+ raise NotImplementedError
319
+
320
+ return entity_list, spo_list
dataloaders/item_encoder.py ADDED
@@ -0,0 +1,534 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The IDEA Authors. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # pylint: disable=no-member
17
+
18
+ from typing import List, Tuple, Dict, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ from transformers import PreTrainedTokenizer
24
+
25
+ from .dataset_utils import get_choice
26
+
27
+
28
+ def get_entity_indices(entity_list: List[dict], spo_list: List[dict]) -> List[List[int]]:
29
+ """ 获取样本中包含的实体位置信息
30
+
31
+ Args:
32
+ entity_list (List[dict]): 实体列表
33
+ spo_list (List[dict]): 三元组列表
34
+
35
+ Returns:
36
+ List[List[int]]: 实体位置信息
37
+ """
38
+ entity_indices = []
39
+
40
+ # 实体中的实体位置
41
+ for entity in entity_list:
42
+ entity_index = entity["entity_index"]
43
+ entity_indices.append(entity_index)
44
+
45
+ # 三元组中的实体位置
46
+ for spo in spo_list:
47
+ sub_idx = spo["subject"]["entity_index"]
48
+ obj_idx = spo["object"]["entity_index"]
49
+ entity_indices.append(sub_idx)
50
+ entity_indices.append(obj_idx)
51
+
52
+ return entity_indices
53
+
54
+
55
+ def entity_based_tokenize(text: str,
56
+ tokenizer: PreTrainedTokenizer,
57
+ enitity_indices: List[Tuple[int, int]],
58
+ max_len: int = -1,
59
+ return_offsets_mapping: bool = False) \
60
+ -> Union[List[int], Tuple[List[int], List[Tuple[int, int]]]]:
61
+ """ 基于实体位置信息的编码,确保实体为连续1到多个token的合并,同时利用预训练模型词根信息
62
+
63
+ Args:
64
+ text (str): 文本
65
+ tokenizer (PreTrainedTokenizer): tokenizer
66
+ enitity_indices (List[Tuple[int, int]]): 实体位置信息
67
+ max_len (int, optional): 长度限制. Defaults to -1.
68
+ return_offsets_mapping (bool, optional): 是否返回offsets_mapping. Defaults to False.
69
+
70
+ Returns:
71
+ Union[List[int], Tuple[List[int], List[Tuple[int, int]]]]: 编码id
72
+ """
73
+ # 根据实体位置遍历出需要对文本进行切割的点
74
+ split_points = sorted(list({i for idx in enitity_indices for i in idx} | {0, len(text)}))
75
+ # 对文本进行切割
76
+ text_parts = []
77
+ for i in range(0, len(split_points) - 1):
78
+ text_parts.append(text[split_points[i]: split_points[i + 1]])
79
+
80
+ # 对切割后的文本进行编码
81
+ bias = 0
82
+ text_ids = []
83
+ offset_mapping = []
84
+ for part in text_parts:
85
+
86
+ part_encoded = tokenizer(part, add_special_tokens=False, return_offsets_mapping=True)
87
+ part_ids, part_mapping = part_encoded["input_ids"], part_encoded["offset_mapping"]
88
+
89
+ text_ids.extend(part_ids)
90
+ for start, end in part_mapping:
91
+ offset_mapping.append((start + bias, end + bias))
92
+
93
+ bias += len(part)
94
+
95
+ if max_len > 0:
96
+ text_ids = text_ids[: max_len]
97
+
98
+ # 是否返回offsets_mapping
99
+ if return_offsets_mapping:
100
+ return text_ids, offset_mapping
101
+ return text_ids
102
+
103
+
104
+ class ItemEncoder(object):
105
+ """ Item Encoder
106
+
107
+ Args:
108
+ tokenizer (PreTrainedTokenizer): tokenizer
109
+ max_length (int): max length
110
+ """
111
+ def __init__(self, tokenizer: PreTrainedTokenizer, max_length: int) -> None:
112
+ self.tokenizer = tokenizer
113
+ self.max_length = max_length
114
+
115
+ def search_index(self,
116
+ entity_idx: List[int],
117
+ offset_mapping: List[Tuple[int, int]],
118
+ bias: int = 0) -> Tuple[int, int]:
119
+ """ 查找实体在tokens中的索引
120
+
121
+ Args:
122
+ entity_idx (List[int]): entity index
123
+ offset_mapping (List[Tuple[int, int]]): text
124
+ bias (int): bias
125
+
126
+ Returns:
127
+ Tuple[int]: (start_idx, end_idx)
128
+ """
129
+ entity_start, entity_end = entity_idx
130
+ start_idx, end_idx = -1, -1
131
+
132
+ for token_idx, (start, end) in enumerate(offset_mapping):
133
+ if start == entity_start:
134
+ start_idx = token_idx
135
+ if end == entity_end:
136
+ end_idx = token_idx
137
+ assert start_idx >= 0 and end_idx >= 0
138
+
139
+ return start_idx + bias, end_idx + bias
140
+
141
+ @staticmethod
142
+ def get_position_ids(text_len: int,
143
+ ent_ranges: List,
144
+ rel_ranges: List) -> np.ndarray:
145
+ """ 获取position_ids
146
+
147
+ Args:
148
+ text_len (int): input length
149
+ ent_ranges (List[List[int, int]]): each entity ranges idx
150
+ rel_ranges (List[List[int, int]]): each relation ranges idx.
151
+
152
+ Returns:
153
+ np.ndarray: position_ids
154
+ """
155
+ # 一切从0开始算position,@liuhan
156
+ text_pos_ids = list(range(text_len))
157
+
158
+ ent_pos_ids, rel_pos_ids = [], []
159
+ for s, e in ent_ranges:
160
+ ent_pos_ids.extend(list(range(e - s)))
161
+ for s, e in rel_ranges:
162
+ rel_pos_ids.extend(list(range(e - s)))
163
+ position_ids = text_pos_ids + ent_pos_ids + rel_pos_ids
164
+
165
+ return position_ids
166
+
167
+ @staticmethod
168
+ def get_att_mask(input_len: int,
169
+ ent_ranges: List,
170
+ rel_ranges: List= None,
171
+ choice_ent: List[str] = None,
172
+ choice_rel: List[str] = None,
173
+ entity2rel: dict = None,
174
+ full_attent: bool = False) -> np.ndarray:
175
+ """ 获取att_mask,不同choice之间的attention_mask置零
176
+
177
+ Args:
178
+ input_len (int): input length
179
+ ent_ranges (List[List[int, int]]): each entity ranges idx
180
+ rel_ranges (List[List[int, int]]): each relation ranges idx. Defaults to None.
181
+ choice_ent (List[str], optional): choice entity. Defaults to None.
182
+ choice_rel (List[str], optional): choice relation. Defaults to None.
183
+ entity2rel (dict, optional): entity to relations. Defaults to None.
184
+ full_attent (bool, optional): is full attention or not. Defaults to None.
185
+ Returns:
186
+ np.ndarray: attention mask
187
+ """
188
+
189
+ # attention_mask.shape = (input_len, input_len)
190
+ attention_mask = np.ones((input_len, input_len))
191
+ if full_attent and not rel_ranges: # full-attention且没有关系情况下,返回全1
192
+ return attention_mask
193
+
194
+ # input_ids: [CLS] text [SEP] [unused1] ent1 [unused2] rel1 [unused3] event1
195
+ text_len = ent_ranges[0][0] # text长度
196
+ # 将text-实体之间的attention置零,text看不到实体,不受传入的entity个数、顺序影响 @liuhan
197
+ attention_mask[:text_len, text_len:] = 0
198
+
199
+ # 将实体-实体、实体关系之间的attention_mask置零
200
+ attention_mask[text_len:, text_len: ] = 0
201
+
202
+ # 将每个实体与自己的attention_mask置一
203
+ for s, e in ent_ranges:
204
+ attention_mask[s: e, s: e] = 1
205
+
206
+ # 没有关系的话,直接返回
207
+ if not rel_ranges:
208
+ return attention_mask
209
+
210
+ # 处理有关系情况
211
+
212
+ # 关系自身attention_mask置1
213
+ for s, e in rel_ranges:
214
+ attention_mask[s: e, s: e] = 1
215
+
216
+ # 将有关联的实体-关系置一
217
+ for head_tail, relations in entity2rel.items():
218
+ for entity_type in head_tail:
219
+ ent_idx = choice_ent.index(entity_type)
220
+ ent_s, _ = ent_ranges[ent_idx] # ent_s, ent_e
221
+ for relation_type in relations:
222
+ rel_idx = choice_rel.index(relation_type)
223
+ rel_s, rel_e = rel_ranges[rel_idx]
224
+ attention_mask[rel_s: rel_e, ent_s] = 1 # 关系只看实体第一个的[unused1]
225
+
226
+ if full_attent: # full-attention且有关系情况下,让文本能看见关系
227
+ for s, e in rel_ranges:
228
+ attention_mask[: text_len, s: e] = 1
229
+
230
+ return attention_mask
231
+
232
+ def encode(self,
233
+ text: str,
234
+ task_name: str,
235
+ choice: List[str],
236
+ entity_list: List[dict],
237
+ spo_list: List[dict],
238
+ full_attent: bool = False,
239
+ with_label: bool = True) -> Dict[str, torch.Tensor]:
240
+ """ encode
241
+
242
+ Args:
243
+ text (str): text
244
+ task_name (str): task name
245
+ choice (List[str]): choice
246
+ entity_list (List[dict]): entity list
247
+ spo_list (List[dict]): spo list
248
+ full_attent (bool): full attention
249
+ with_label (bool): encoded with label. Defaults to True.
250
+
251
+ Returns:
252
+ Dict[str, torch.Tensor]: encoded
253
+ """
254
+ choice_ent, choice_rel, entity2rel = choice, [], {}
255
+ if isinstance(choice, list):
256
+ if isinstance(choice[0], list): # 关系抽取 & 实体识别
257
+ choice_ent, choice_rel, _, _, entity2rel = get_choice(choice)
258
+ elif isinstance(choice, dict):
259
+ # 事件类型
260
+ raise ValueError('event extract not supported now!')
261
+ else:
262
+ raise NotImplementedError
263
+
264
+ input_ids = []
265
+ text_ids = [] # text部分id
266
+ ent_ids = [] # entity部分id
267
+ rel_ids = [] # relation部分id
268
+ entity_labels_idx = []
269
+ relation_labels_idx = []
270
+
271
+ sep_ids = self.tokenizer.encode("[SEP]", add_special_tokens=False) # [SEP]的编码
272
+ cls_ids = self.tokenizer.encode("[CLS]", add_special_tokens=False) # [CLS]的编码
273
+ entity_op_ids = self.tokenizer.encode("[unused1]", add_special_tokens=False) # [unused1]的编码
274
+ relation_op_ids = self.tokenizer.encode("[unused2]", add_special_tokens=False) # [unused2]的编码
275
+
276
+ # 任务名称的编码
277
+ task_ids = self.tokenizer.encode(task_name, add_special_tokens=False)
278
+
279
+ # 实体标签的编码
280
+ for c in choice_ent:
281
+ c_ids = self.tokenizer.encode(c, add_special_tokens=False)[: self.max_length]
282
+ ent_ids += entity_op_ids + c_ids
283
+
284
+ # 关系标签的编码
285
+ for c in choice_rel:
286
+ c_ids = self.tokenizer.encode(c, add_special_tokens=False)[: self.max_length]
287
+ rel_ids += relation_op_ids + c_ids
288
+
289
+ # text的编码
290
+ entity_indices = get_entity_indices(entity_list, spo_list)
291
+ text_max_len = self.max_length - len(task_ids) - 3
292
+ text_ids, offset_mapping = entity_based_tokenize(text, self.tokenizer, entity_indices,
293
+ max_len=text_max_len,
294
+ return_offsets_mapping=True)
295
+ text_ids = cls_ids + text_ids + sep_ids
296
+
297
+ input_ids = text_ids + task_ids + sep_ids + ent_ids + rel_ids
298
+
299
+ token_type_ids = [0] * len(text_ids) + [0] * (len(task_ids) + 1) + \
300
+ [1] * len(ent_ids) + [1] * len(rel_ids)
301
+
302
+ entity_labels_idx = [i for i, id_ in enumerate(input_ids) if id_ == entity_op_ids[0]]
303
+ relation_labels_idx = [i for i, id_ in enumerate(input_ids) if id_ == relation_op_ids[0]]
304
+
305
+ ent_ranges = [] # 每个实体的起始范围
306
+ for i in range(len(entity_labels_idx) - 1):
307
+ ent_ranges.append([entity_labels_idx[i], entity_labels_idx[i + 1]])
308
+ if not relation_labels_idx:
309
+ ent_ranges.append([entity_labels_idx[-1], len(input_ids)])
310
+ else:
311
+ ent_ranges.append([entity_labels_idx[-1], relation_labels_idx[0]])
312
+ assert len(ent_ranges) == len(choice_ent)
313
+
314
+ rel_ranges = [] # 每个关系的起始范围
315
+ for i in range(len(relation_labels_idx) - 1):
316
+ rel_ranges.append([relation_labels_idx[i], relation_labels_idx[i + 1]])
317
+ if relation_labels_idx:
318
+ rel_ranges.append([relation_labels_idx[-1], len(input_ids)])
319
+ assert len(rel_ranges) == len(choice_rel)
320
+
321
+ # 所有unused的位置
322
+ label_token_idx = entity_labels_idx + relation_labels_idx
323
+ task_num_labels = len(label_token_idx)
324
+ input_len = len(input_ids)
325
+ text_len = len(text_ids)
326
+
327
+ # 计算mask
328
+ attention_mask = self.get_att_mask(input_len,
329
+ ent_ranges,
330
+ rel_ranges,
331
+ choice_ent,
332
+ choice_rel,
333
+ entity2rel,
334
+ full_attent)
335
+ # 计算label-mask
336
+ label_mask = np.ones((text_len, text_len, task_num_labels))
337
+ for i in range(text_len):
338
+ for j in range(text_len):
339
+ if j < i:
340
+ for l in range(len(entity_labels_idx)):
341
+ # entity部分的下三角可mask
342
+ label_mask[i, j, l] = 0
343
+
344
+ # 计算position_ids
345
+ position_ids = self.get_position_ids(len(text_ids) + len(task_ids) + 1,
346
+ ent_ranges,
347
+ rel_ranges)
348
+
349
+ assert len(input_ids) == len(position_ids) == len(token_type_ids)
350
+
351
+ if not with_label:
352
+ return {
353
+ "input_ids": torch.tensor(input_ids).long(),
354
+ "attention_mask": torch.tensor(attention_mask).float(),
355
+ "position_ids": torch.tensor(position_ids).long(),
356
+ "token_type_ids": torch.tensor(token_type_ids).long(),
357
+ "label_token_idx": torch.tensor(label_token_idx).long(),
358
+ "label_mask": torch.tensor(label_mask).float(),
359
+ "text_len": torch.tensor(text_len).long(),
360
+ "ent_ranges": ent_ranges,
361
+ "rel_ranges": rel_ranges,
362
+ }
363
+
364
+ # 输入的span_labels,只保留text部分
365
+ span_labels = np.zeros((text_len, text_len, task_num_labels))
366
+
367
+ # 将实体转成span
368
+ for entity in entity_list:
369
+
370
+ entity_type = entity["entity_type"]
371
+ entity_index = entity["entity_index"]
372
+
373
+ start_idx, end_idx = self.search_index(entity_index, offset_mapping, 1)
374
+
375
+ if start_idx < text_len and end_idx < text_len:
376
+ ent_label = choice_ent.index(entity_type)
377
+ span_labels[start_idx, end_idx, ent_label] = 1
378
+
379
+ # 将三元组转成span
380
+ for spo in spo_list:
381
+
382
+ sub_idx = spo["subject"]["entity_index"]
383
+ obj_idx = spo["object"]["entity_index"]
384
+
385
+ # 获取头实体、尾实体的开始、结束index
386
+ sub_start_idx, sub_end_idx = self.search_index(sub_idx, offset_mapping, 1)
387
+ obj_start_idx, obj_end_idx = self.search_index(obj_idx, offset_mapping, 1)
388
+ # 实体label置1
389
+ if sub_start_idx < text_len and sub_end_idx < text_len:
390
+ sub_label = choice_ent.index(spo["subject"]["entity_type"])
391
+ span_labels[sub_start_idx, sub_end_idx, sub_label] = 1
392
+
393
+ if obj_start_idx < text_len and obj_end_idx < text_len:
394
+ obj_label = choice_ent.index(spo["object"]["entity_type"])
395
+ span_labels[obj_start_idx, obj_end_idx, obj_label] = 1
396
+
397
+ # 有关系的sub/obj实体的start/end在realtion对应的label置1
398
+ if spo["predicate"] in choice_rel:
399
+ pre_label = choice_rel.index(spo["predicate"]) + len(choice_ent)
400
+ if sub_start_idx < text_len and obj_start_idx < text_len:
401
+ span_labels[sub_start_idx, obj_start_idx, pre_label] = 1
402
+ if sub_end_idx < text_len and obj_end_idx < text_len:
403
+ span_labels[sub_end_idx, obj_end_idx, pre_label] = 1
404
+
405
+ return {
406
+ "input_ids": torch.tensor(input_ids).long(),
407
+ "attention_mask": torch.tensor(attention_mask).float(),
408
+ "position_ids": torch.tensor(position_ids).long(),
409
+ "token_type_ids": torch.tensor(token_type_ids).long(),
410
+ "label_token_idx": torch.tensor(label_token_idx).long(),
411
+ "span_labels": torch.tensor(span_labels).float(),
412
+ "label_mask": torch.tensor(label_mask).float(),
413
+ "text_len": torch.tensor(text_len).long(),
414
+ "ent_ranges": ent_ranges,
415
+ "rel_ranges": rel_ranges,
416
+ }
417
+
418
+ def encode_item(self, item: dict, with_label: bool = True) -> Dict[str, torch.Tensor]: # pylint: disable=unused-argument
419
+ """ encode
420
+
421
+ Args:
422
+ item (dict): item
423
+ with_label (bool): encoded with label. Defaults to True.
424
+
425
+ Returns:
426
+ Dict[str, torch.Tensor]: encoded
427
+ """
428
+ return self.encode(text=item["text"],
429
+ task_name=item["task"],
430
+ choice=item["choice"],
431
+ entity_list=item.get("entity_list", []),
432
+ spo_list=item.get("spo_list", []),
433
+ full_attent=item.get('full_attent', False),
434
+ with_label=with_label)
435
+
436
+ @staticmethod
437
+ def collate(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
438
+ """
439
+ Aggregate a batch data.
440
+ batch = [ins1_dict, ins2_dict, ..., insN_dict]
441
+ batch_data = {"sentence":[ins1_sentence, ins2_sentence...],
442
+ "input_ids":[ins1_input_ids, ins2_input_ids...], ...}
443
+ """
444
+ input_ids = nn.utils.rnn.pad_sequence(
445
+ sequences=[encoded["input_ids"] for encoded in batch],
446
+ batch_first=True,
447
+ padding_value=0)
448
+
449
+ label_token_idx = nn.utils.rnn.pad_sequence(
450
+ sequences=[encoded["label_token_idx"] for encoded in batch],
451
+ batch_first=True,
452
+ padding_value=0)
453
+
454
+ token_type_ids = nn.utils.rnn.pad_sequence(
455
+ sequences=[encoded["token_type_ids"] for encoded in batch],
456
+ batch_first=True,
457
+ padding_value=0)
458
+
459
+ position_ids = nn.utils.rnn.pad_sequence(
460
+ sequences=[encoded["position_ids"] for encoded in batch],
461
+ batch_first=True,
462
+ padding_value=0)
463
+
464
+ text_len = torch.tensor([encoded["text_len"] for encoded in batch]).long()
465
+ max_text_len = text_len.max()
466
+
467
+ batch_size, batch_max_length = input_ids.shape
468
+ _, batch_max_labels = label_token_idx.shape
469
+
470
+ attention_mask = torch.zeros((batch_size, batch_max_length, batch_max_length))
471
+ label_mask = torch.zeros((batch_size,
472
+ batch_max_length,
473
+ batch_max_length,
474
+ batch_max_labels))
475
+ for i, encoded in enumerate(batch):
476
+ input_len = encoded["attention_mask"].shape[0]
477
+ attention_mask[i, :input_len, :input_len] = encoded["attention_mask"]
478
+ _, cur_text_len, label_len = encoded['label_mask'].shape
479
+ label_mask[i, :cur_text_len, :cur_text_len, :label_len] = encoded['label_mask']
480
+ label_mask = label_mask[:, :max_text_len, :max_text_len, :]
481
+
482
+ batch_data = {
483
+ "input_ids": input_ids,
484
+ "attention_mask": attention_mask,
485
+ "position_ids": position_ids,
486
+ "token_type_ids": token_type_ids,
487
+ "label_token_idx": label_token_idx,
488
+ "label_mask": label_mask,
489
+ 'text_len': text_len
490
+ }
491
+
492
+ if "span_labels" in batch[0].keys():
493
+ span_labels = torch.zeros((batch_size,
494
+ batch_max_length,
495
+ batch_max_length,
496
+ batch_max_labels))
497
+ for i, encoded in enumerate(batch):
498
+ input_len, _, sample_num_labels = encoded["span_labels"].shape
499
+ span_labels[i, :input_len, :input_len, :sample_num_labels] = encoded["span_labels"]
500
+ batch_data["span_labels"] = span_labels[:, :max_text_len, :max_text_len, :]
501
+
502
+ return batch_data
503
+
504
+ @staticmethod
505
+ def collate_expand(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
506
+ """
507
+ Aggregate a batch data and expand to full attention
508
+ batch = [ins1_dict, ins2_dict, ..., insN_dict]
509
+ batch_data = {"sentence":[ins1_sentence, ins2_sentence...],
510
+ "input_ids":[ins1_input_ids, ins2_input_ids...], ...}
511
+ """
512
+ mask_atten_batch = ItemEncoder.collate(batch)
513
+ full_atten_batch = ItemEncoder.collate(batch)
514
+ # 对full_atten_batch进行改造
515
+ atten_mask = full_atten_batch['attention_mask']
516
+ b, _, _ = atten_mask.size()
517
+ for i in range(b):
518
+ ent_ranges, rel_ranges = batch[i]['ent_ranges'], batch[i]['rel_ranges']
519
+ text_len = ent_ranges[0][0] # text长度
520
+
521
+ if not rel_ranges:
522
+ assert len(ent_ranges) == 1, 'ent_ranges:%s' % ent_ranges
523
+ s, e = ent_ranges[0]
524
+ atten_mask[i, : text_len, s: e] = 1
525
+ else:
526
+ assert len(rel_ranges) == 1 and len(ent_ranges) <= 2, \
527
+ 'ent_ranges:%s, rel_ranges:%s' % (ent_ranges, rel_ranges)
528
+ s, e = rel_ranges[0]
529
+ atten_mask[i, : text_len, s: e] = 1
530
+ full_atten_batch['attention_mask'] = atten_mask
531
+ collate_batch = {}
532
+ for key, value in mask_atten_batch.items():
533
+ collate_batch[key] = torch.cat((value, full_atten_batch[key]), 0)
534
+ return collate_batch
models/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .model import BagualuIEModel
+ from .extract_model import BagualuIEExtractModel
models/extract_model.py ADDED
@@ -0,0 +1,71 @@
+ # coding=utf-8
+ # Copyright 2021 The IDEA Authors. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List
+ import copy
+
+ from transformers import PreTrainedTokenizer
+ import argparse
+ from dataloaders.item_encoder import ItemEncoder
+ from dataloaders.item_decoder import ItemDecoder
+ from .model import BagualuIEModel
+
+
+ class BagualuIEExtractModel(object):
+     """ BagualuIEExtractModel
+
+     Args:
+         tokenizer (PreTrainedTokenizer): tokenizer
+         args (argparse.Namespace): arguments
+     """
+     def __init__(self,
+                  tokenizer: PreTrainedTokenizer,
+                  args: argparse.Namespace) -> None:
+         self.encoder = ItemEncoder(tokenizer, args.max_length)
+         self.decoder = ItemDecoder(tokenizer, args)
+
+     def extract(self, batch_data: List[dict], model: BagualuIEModel, use_cuda: bool) -> List[dict]:
+         """ extract
+
+         Args:
+             batch_data (List[dict]): batch of data
+             model (BagualuIEModel): model
+
+         Returns:
+             List[dict]: batch of data
+         """
+         if use_cuda:
+             model = model.cuda()
+         model.eval()
+
+         batch_data = copy.deepcopy(batch_data)
+         batch = [self.encoder.encode_item(item, with_label=False) for item in batch_data]
+         batch = self.encoder.collate(batch)
+         if use_cuda:
+             batch = {k: v.cuda() for k, v in batch.items()}
+
+         span_logits = model(**batch).cpu().detach().numpy()
+         label_mask = batch["label_mask"].cpu().detach().numpy()
+
+         for i, item in enumerate(batch_data):
+
+             entity_list, spo_list = self.decoder.decode(item,
+                                                         span_logits[i],
+                                                         label_mask[i])
+
+             item["spo_list"] = spo_list
+             item["entity_list"] = entity_list
+
+         return batch_data
models/model.py ADDED
@@ -0,0 +1,156 @@
+ # coding=utf-8
+ # Copyright 2021 The IDEA Authors. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: disable=no-member
+
+ import torch
+ from torch import nn, Tensor
+ from transformers import BertPreTrainedModel, BertModel, BertConfig
+
+
+ class Triaffine(nn.Module):
+     """ Triaffine module
+
+     Args:
+         triaffine_hidden_size (int): Triaffine module hidden size
+     """
+     def __init__(self, triaffine_hidden_size: int) -> None:
+         super().__init__()
+
+         self.triaffine_hidden_size = triaffine_hidden_size
+
+         self.weight_start_end = nn.Parameter(
+             torch.zeros(triaffine_hidden_size,
+                         triaffine_hidden_size,
+                         triaffine_hidden_size))
+
+         nn.init.normal_(self.weight_start_end, mean=0, std=0.1)
+
+     def forward(self,
+                 start_logits: Tensor,
+                 end_logits: Tensor,
+                 cls_logits: Tensor) -> Tensor:
+         """forward
+
+         Args:
+             start_logits (Tensor): start logits
+             end_logits (Tensor): end logits
+             cls_logits (Tensor): cls logits
+
+         Returns:
+             Tensor: span_logits
+         """
+         start_end_logits = torch.einsum("bxi,ioj,byj->bxyo",
+                                         start_logits,
+                                         self.weight_start_end,
+                                         end_logits)
+
+         span_logits = torch.einsum("bxyo,bzo->bxyz",
+                                    start_end_logits,
+                                    cls_logits)
+
+         return span_logits
+
+
+ class MLPLayer(nn.Module):
+     """MLP layer
+
+     Args:
+         input_size (int): input size
+         output_size (int): output size
+     """
+     def __init__(self, input_size: int, output_size: int) -> None:
+         super().__init__()
+         self.linear = nn.Linear(in_features=input_size, out_features=output_size)
+         self.act = nn.GELU()
+
+     def forward(self, x: Tensor) -> Tensor:  # pylint: disable=invalid-name
+         """ forward
+
+         Args:
+             x (Tensor): input
+
+         Returns:
+             Tensor: output
+         """
+         x = self.linear(x)
+         x = self.act(x)
+         return x
+
+
+ class BagualuIEModel(BertPreTrainedModel):
+     """ BagualuIEModel
+
+     Args:
+         config (BertConfig): config
+     """
+     def __init__(self, config: BertConfig) -> None:
+         super().__init__(config)
+         self.bert = BertModel(config)
+         self.config = config
+
+         self.triaffine_hidden_size = 128
+
+         self.mlp_start = MLPLayer(self.config.hidden_size,
+                                   self.triaffine_hidden_size)
+         self.mlp_end = MLPLayer(self.config.hidden_size,
+                                 self.triaffine_hidden_size)
+         self.mlp_cls = MLPLayer(self.config.hidden_size,
+                                 self.triaffine_hidden_size)
+
+         self.triaffine = Triaffine(self.triaffine_hidden_size)
+
+     def forward(self,  # pylint: disable=unused-argument
+                 input_ids: Tensor,
+                 attention_mask: Tensor,
+                 position_ids: Tensor,
+                 token_type_ids: Tensor,
+                 text_len: Tensor,
+                 label_token_idx: Tensor,
+                 **kwargs) -> Tensor:
+         """ forward
+
+         Args:
+             input_ids (Tensor): input_ids
+             attention_mask (Tensor): attention_mask
+             position_ids (Tensor): position_ids
+             token_type_ids (Tensor): token_type_ids
+             text_len (Tensor): text length
+             label_token_idx (Tensor): label_token_idx
+
+         Returns:
+             Tensor: span logits
+         """
+
+         # bert forward
+         hidden_states = self.bert(input_ids=input_ids,
+                                   attention_mask=attention_mask,
+                                   position_ids=position_ids,
+                                   token_type_ids=token_type_ids,
+                                   output_hidden_states=True)[0]  # (bsz, seq, dim)
+
+         max_text_len = text_len.max()
+
+         # gather the start/end (text part) and cls (label token) hidden states
+         hidden_start_end = hidden_states[:, :max_text_len, :]  # text-part representation
+         hidden_cls = hidden_states.gather(1, label_token_idx.unsqueeze(-1)
+                                           .repeat(1, 1, self.config.hidden_size))  # (bsz, task, dim)
+
+         # Triaffine
+         span_logits = self.triaffine(self.mlp_start(hidden_start_end),
+                                      self.mlp_end(hidden_start_end),
+                                      self.mlp_cls(hidden_cls)).sigmoid()
+
+         return span_logits
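To make the tensor bookkeeping concrete, here is a small standalone shape check of the Triaffine scorer above (not part of the commit; the toy batch size, text length, label count and the single shared MLP are illustrative, whereas BagualuIEModel uses three separate start/end/cls MLPs):

# Standalone shape check for the Triaffine span scorer.
import torch
from models.model import MLPLayer, Triaffine

bsz, text_len, num_labels, dim = 2, 16, 5, 768
mlp = MLPLayer(dim, 128)   # one MLP is enough for a shape check
tri = Triaffine(128)

text_hidden = torch.randn(bsz, text_len, dim)     # start/end representations (text tokens)
label_hidden = torch.randn(bsz, num_labels, dim)  # one vector per [unused*] label token

span_logits = tri(mlp(text_hidden), mlp(text_hidden), mlp(label_hidden)).sigmoid()
print(span_logits.shape)  # torch.Size([2, 16, 16, 5]) -> (bsz, start, end, label)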