DeepLearning101 committed on
Commit
d131d1a
1 Parent(s): 1f544f7

Upload 2 files

models/span_extraction/global_pointer.py ADDED
@@ -0,0 +1,482 @@
# -*- coding: utf-8 -*-
# @Time : 2022/4/21 5:30 PM
# @Author : JianingWang
# @File : global_pointer.py
from typing import Optional
import torch
import torch.nn as nn
from dataclasses import dataclass
from transformers import MegatronBertModel, MegatronBertPreTrainedModel
from transformers.file_utils import ModelOutput
from transformers.models.bert import BertPreTrainedModel, BertModel
from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaPreTrainedModel
from roformer import RoFormerPreTrainedModel, RoFormerModel


class RawGlobalPointer(nn.Module):
    def __init__(self, encoder, ent_type_size, inner_dim, RoPE=True):
        # encoder: the pretrained backbone, e.g. RoBERTa-Large
        # ent_type_size: number of entity types
        # inner_dim: pointer head size, e.g. 64
        super().__init__()
        self.encoder = encoder
        self.ent_type_size = ent_type_size
        self.inner_dim = inner_dim
        self.hidden_size = encoder.config.hidden_size
        self.dense = nn.Linear(self.hidden_size, self.ent_type_size * self.inner_dim * 2)

        self.RoPE = RoPE

    def sinusoidal_position_embedding(self, batch_size, seq_len, output_dim):
        position_ids = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(-1)

        indices = torch.arange(0, output_dim // 2, dtype=torch.float)
        indices = torch.pow(10000, -2 * indices / output_dim)
        embeddings = position_ids * indices
        embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        embeddings = embeddings.repeat((batch_size, *([1] * len(embeddings.shape))))
        embeddings = torch.reshape(embeddings, (batch_size, seq_len, output_dim))
        embeddings = embeddings.to(self.device)
        return embeddings

    def forward(self, input_ids, attention_mask, token_type_ids):
        self.device = input_ids.device

        context_outputs = self.encoder(input_ids, attention_mask, token_type_ids)
        # last_hidden_state: (batch_size, seq_len, hidden_size)
        last_hidden_state = context_outputs[0]

        batch_size = last_hidden_state.size()[0]
        seq_len = last_hidden_state.size()[1]

        outputs = self.dense(last_hidden_state)
        outputs = torch.split(outputs, self.inner_dim * 2, dim=-1)
        outputs = torch.stack(outputs, dim=-2)
        qw, kw = outputs[..., :self.inner_dim], outputs[..., self.inner_dim:]
        if self.RoPE:
            # pos_emb: (batch_size, seq_len, inner_dim)
            pos_emb = self.sinusoidal_position_embedding(batch_size, seq_len, self.inner_dim)
            cos_pos = pos_emb[..., None, 1::2].repeat_interleave(2, dim=-1)
            sin_pos = pos_emb[..., None, ::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], -1)
            qw2 = qw2.reshape(qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], -1)
            kw2 = kw2.reshape(kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        # logits: (batch_size, ent_type_size, seq_len, seq_len)
        logits = torch.einsum("bmhd,bnhd->bhmn", qw, kw)

        # padding mask
        pad_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, self.ent_type_size, seq_len, seq_len)
        logits = logits * pad_mask - (1 - pad_mask) * 1e12

        # mask out the lower triangle (spans must satisfy start <= end)
        mask = torch.tril(torch.ones_like(logits), -1)
        logits = logits - mask * 1e12

        return logits / self.inner_dim ** 0.5
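

# Illustrative usage sketch (added for clarity, not part of the original file):
# drives RawGlobalPointer with a HuggingFace encoder. The checkpoint name and
# the entity-type count below are placeholder assumptions; with an untrained
# pointer head the decoded spans are meaningless until the model is fine-tuned.
def _example_raw_global_pointer():
    from transformers import BertModel, BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")
    encoder = BertModel.from_pretrained("bert-base-chinese")
    model = RawGlobalPointer(encoder, ent_type_size=9, inner_dim=64).eval()

    batch = tokenizer(["北京是中国的首都"], return_tensors="pt")
    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"], batch["token_type_ids"])
    # logits: (batch, ent_type_size, seq_len, seq_len); cells > 0 are predicted spans
    for b, t, start, end in (logits > 0).nonzero().tolist():
        print(f"sample {b}: entity type {t}, token span [{start}, {end}]")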


class SinusoidalPositionEmbedding(nn.Module):
    """Sin-Cos positional embedding."""

    def __init__(
            self, output_dim, merge_mode="add", custom_position_ids=False):
        super(SinusoidalPositionEmbedding, self).__init__()
        self.output_dim = output_dim
        self.merge_mode = merge_mode
        self.custom_position_ids = custom_position_ids

    def forward(self, inputs):
        if self.custom_position_ids:
            inputs, position_ids = inputs
            seq_len = inputs.shape[1]
            position_ids = position_ids.type(torch.float)
        else:
            input_shape = inputs.shape
            batch_size, seq_len = input_shape[0], input_shape[1]
            position_ids = torch.arange(seq_len).type(torch.float)[None]
        indices = torch.arange(self.output_dim // 2).type(torch.float)
        indices = torch.pow(10000.0, -2 * indices / self.output_dim)
        embeddings = torch.einsum("bn,d->bnd", position_ids, indices)
        embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        embeddings = torch.reshape(embeddings, (-1, seq_len, self.output_dim))
        if self.merge_mode == "add":
            return inputs + embeddings.to(inputs.device)
        elif self.merge_mode == "mul":
            return inputs * (embeddings + 1.0).to(inputs.device)
        elif self.merge_mode == "zero":
            return embeddings.to(inputs.device)
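

# Illustrative check (not part of the original file): in "zero" merge mode the
# module returns raw position codes whose last dimension interleaves
# [sin_0, cos_0, sin_1, cos_1, ...]; the GlobalPointer heads below therefore
# read sin at even indices and cos at odd indices, then repeat_interleave each
# value twice so it lines up with the paired (even, odd) feature channels.
def _example_sinusoidal_layout():
    pos = SinusoidalPositionEmbedding(output_dim=8, merge_mode="zero")(torch.zeros(1, 4, 8))
    sin_part, cos_part = pos[..., ::2], pos[..., 1::2]
    cos_pos = cos_part.repeat_interleave(2, dim=-1)  # e.g. [0.34, 0.90] -> [0.34, 0.34, 0.90, 0.90]
    print(pos.shape, sin_part.shape, cos_pos.shape)  # (1, 4, 8) (1, 4, 4) (1, 4, 8)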


def multilabel_categorical_crossentropy(y_pred, y_true):
    y_pred = (1 - 2 * y_true) * y_pred  # flip sign: -1 for positive classes, +1 for negative classes
    y_pred_neg = y_pred - y_true * 1e12  # mask the logits of positive classes
    y_pred_pos = y_pred - (1 - y_true) * 1e12  # mask the logits of negative classes
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return (neg_loss + pos_loss).mean()
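

# Illustrative toy example (not part of the original file): this loss, from
# Su Jianlin's multilabel categorical cross-entropy, uses 0 as an implicit
# threshold, pushing positive-class logits above 0 and negative-class logits
# below 0. Here the targets are classes 1 and 3.
def _example_multilabel_ce():
    y_true = torch.tensor([[0.0, 1.0, 0.0, 1.0]])
    good = torch.tensor([[-5.0, 5.0, -5.0, 5.0]])  # positives above 0, negatives below
    bad = -good                                    # every class on the wrong side
    print(multilabel_categorical_crossentropy(good, y_true))  # ~0.03 (low)
    print(multilabel_categorical_crossentropy(bad, y_true))   # ~11.4 (high)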


def multilabel_categorical_crossentropy2(y_pred, y_true):
    # Variant of the loss above that masks with -inf instead of subtracting 1e12.
    y_pred = (1 - 2 * y_true) * y_pred  # flip sign: -1 for positive classes, +1 for negative classes
    y_pred_neg = y_pred.clone()
    y_pred_pos = y_pred.clone()
    y_pred_neg[y_true > 0] -= float("inf")  # mask the logits of positive classes
    y_pred_pos[y_true < 1] -= float("inf")  # mask the logits of negative classes
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return (neg_loss + pos_loss).mean()


@dataclass
class GlobalPointerOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    topk_probs: torch.FloatTensor = None
    topk_indices: torch.IntTensor = None


class BertForEffiGlobalPointer(BertPreTrainedModel):
    def __init__(self, config):
        # config.inner_dim: pointer head size, e.g. 64
        # config.ent_type_size: number of entity types
        super().__init__(config)
        self.bert = BertModel(config)
        self.ent_type_size = config.ent_type_size
        self.inner_dim = config.inner_dim
        self.hidden_size = config.hidden_size
        self.RoPE = config.RoPE

        self.dense_1 = nn.Linear(self.hidden_size, self.inner_dim * 2)
        self.dense_2 = nn.Linear(self.hidden_size, self.ent_type_size * 2)  # the original dense_2 was (inner_dim * 2, ent_type_size * 2)

    def sequence_masking(self, x, mask, value="-inf", axis=None):
        if mask is None:
            return x
        else:
            if value == "-inf":
                value = -1e12
            elif value == "inf":
                value = 1e12
            assert axis > 0, "axis must be greater than 0"
            for _ in range(axis - 1):
                mask = torch.unsqueeze(mask, 1)
            for _ in range(x.ndim - mask.ndim):
                mask = torch.unsqueeze(mask, mask.ndim)
            return x * mask + value * (1 - mask)

    def add_mask_tril(self, logits, mask):
        if mask.dtype != logits.dtype:
            mask = mask.type(logits.dtype)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 2)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 1)
        # mask out the lower triangle (spans must satisfy start <= end)
        mask = torch.tril(torch.ones_like(logits), diagonal=-1)
        logits = logits - mask * 1e12
        return logits

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None, short_labels=None):
        context_outputs = self.bert(input_ids, attention_mask, token_type_ids)
        last_hidden_state = context_outputs.last_hidden_state  # [bz, seq_len, hidden_dim]
        outputs = self.dense_1(last_hidden_state)  # [bz, seq_len, 2 * inner_dim]
        qw, kw = outputs[..., ::2], outputs[..., 1::2]  # query/key vectors from the even/odd channels of the last dimension
        batch_size = input_ids.shape[0]
        if self.RoPE:
            pos = SinusoidalPositionEmbedding(self.inner_dim, "zero")(outputs)
            cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1)  # e.g. [0.34, 0.90] -> [0.34, 0.34, 0.90, 0.90]
            sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
            qw2 = torch.reshape(qw2, qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], 3)
            kw2 = torch.reshape(kw2, kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        logits = torch.einsum("bmd,bnd->bmn", qw, kw) / self.inner_dim ** 0.5
        bias = torch.einsum("bnh->bhn", self.dense_2(last_hidden_state)) / 2
        logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None]  # logits[:, None] adds the entity-type dimension
        loss = None

        mask = torch.triu(attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1))  # upper-triangular validity mask
        if labels is not None:
            y_pred = logits - (1 - mask.unsqueeze(1)) * 1e12
            y_true = labels.view(input_ids.shape[0] * self.ent_type_size, -1)
            y_pred = y_pred.view(input_ids.shape[0] * self.ent_type_size, -1)
            loss = multilabel_categorical_crossentropy(y_pred, y_true)

        with torch.no_grad():
            prob = torch.sigmoid(logits) * mask.unsqueeze(1)
            topk = torch.topk(prob.view(batch_size, self.ent_type_size, -1), 50, dim=-1)

        return GlobalPointerOutput(
            loss=loss,
            topk_probs=topk.values,
            topk_indices=topk.indices
        )
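

# Illustrative decoding helper (not part of the original file): GlobalPointerOutput
# carries top-k probabilities over the flattened (seq_len x seq_len) span grid of
# each entity type, so an index decodes as (index // seq_len, index % seq_len).
# The 0.5 threshold is a placeholder assumption.
def decode_topk_spans(topk_probs, topk_indices, seq_len, threshold=0.5):
    spans = []  # tuples of (batch_idx, type_idx, start, end, prob)
    for b in range(topk_probs.size(0)):
        for t in range(topk_probs.size(1)):
            for prob, idx in zip(topk_probs[b, t].tolist(), topk_indices[b, t].tolist()):
                if prob < threshold:
                    continue
                start, end = divmod(idx, seq_len)
                spans.append((b, t, start, end, prob))
    return spans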


class RobertaForEffiGlobalPointer(RobertaPreTrainedModel):
    def __init__(self, config):
        # config.inner_dim: pointer head size, e.g. 64
        # config.ent_type_size: number of entity types
        super().__init__(config)
        self.roberta = RobertaModel(config)
        self.ent_type_size = config.ent_type_size
        self.inner_dim = config.inner_dim
        self.hidden_size = config.hidden_size
        self.RoPE = config.RoPE

        self.dense_1 = nn.Linear(self.hidden_size, self.inner_dim * 2)
        self.dense_2 = nn.Linear(self.hidden_size, self.ent_type_size * 2)  # the original dense_2 was (inner_dim * 2, ent_type_size * 2)

    def sequence_masking(self, x, mask, value="-inf", axis=None):
        if mask is None:
            return x
        else:
            if value == "-inf":
                value = -1e12
            elif value == "inf":
                value = 1e12
            assert axis > 0, "axis must be greater than 0"
            for _ in range(axis - 1):
                mask = torch.unsqueeze(mask, 1)
            for _ in range(x.ndim - mask.ndim):
                mask = torch.unsqueeze(mask, mask.ndim)
            return x * mask + value * (1 - mask)

    def add_mask_tril(self, logits, mask):
        if mask.dtype != logits.dtype:
            mask = mask.type(logits.dtype)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 2)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 1)
        # mask out the lower triangle (spans must satisfy start <= end)
        mask = torch.tril(torch.ones_like(logits), diagonal=-1)
        logits = logits - mask * 1e12
        return logits

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None, short_labels=None):
        context_outputs = self.roberta(input_ids, attention_mask, token_type_ids)
        last_hidden_state = context_outputs.last_hidden_state  # [bz, seq_len, hidden_dim]
        outputs = self.dense_1(last_hidden_state)  # [bz, seq_len, 2 * inner_dim]
        qw, kw = outputs[..., ::2], outputs[..., 1::2]  # query/key vectors from the even/odd channels of the last dimension
        batch_size = input_ids.shape[0]
        if self.RoPE:
            pos = SinusoidalPositionEmbedding(self.inner_dim, "zero")(outputs)
            cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1)  # e.g. [0.34, 0.90] -> [0.34, 0.34, 0.90, 0.90]
            sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
            qw2 = torch.reshape(qw2, qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], 3)
            kw2 = torch.reshape(kw2, kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        logits = torch.einsum("bmd,bnd->bmn", qw, kw) / self.inner_dim ** 0.5
        bias = torch.einsum("bnh->bhn", self.dense_2(last_hidden_state)) / 2
        logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None]  # logits[:, None] adds the entity-type dimension
        loss = None

        mask = torch.triu(attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1))  # upper-triangular validity mask
        if labels is not None:
            y_pred = logits - (1 - mask.unsqueeze(1)) * 1e12
            y_true = labels.view(input_ids.shape[0] * self.ent_type_size, -1)
            y_pred = y_pred.view(input_ids.shape[0] * self.ent_type_size, -1)
            loss = multilabel_categorical_crossentropy(y_pred, y_true)

        with torch.no_grad():
            prob = torch.sigmoid(logits) * mask.unsqueeze(1)
            topk = torch.topk(prob.view(batch_size, self.ent_type_size, -1), 50, dim=-1)

        return GlobalPointerOutput(
            loss=loss,
            topk_probs=topk.values,
            topk_indices=topk.indices
        )


class RoformerForEffiGlobalPointer(RoFormerPreTrainedModel):
    def __init__(self, config):
        # config.inner_dim: pointer head size, e.g. 64
        # config.ent_type_size: number of entity types
        super().__init__(config)
        self.roformer = RoFormerModel(config)
        self.ent_type_size = config.ent_type_size
        self.inner_dim = config.inner_dim
        self.hidden_size = config.hidden_size
        self.RoPE = config.RoPE

        self.dense_1 = nn.Linear(self.hidden_size, self.inner_dim * 2)
        self.dense_2 = nn.Linear(self.hidden_size, self.ent_type_size * 2)  # the original dense_2 was (inner_dim * 2, ent_type_size * 2)

    def sequence_masking(self, x, mask, value="-inf", axis=None):
        if mask is None:
            return x
        else:
            if value == "-inf":
                value = -1e12
            elif value == "inf":
                value = 1e12
            assert axis > 0, "axis must be greater than 0"
            for _ in range(axis - 1):
                mask = torch.unsqueeze(mask, 1)
            for _ in range(x.ndim - mask.ndim):
                mask = torch.unsqueeze(mask, mask.ndim)
            return x * mask + value * (1 - mask)

    def add_mask_tril(self, logits, mask):
        if mask.dtype != logits.dtype:
            mask = mask.type(logits.dtype)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 2)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 1)
        # mask out the lower triangle (spans must satisfy start <= end)
        mask = torch.tril(torch.ones_like(logits), diagonal=-1)
        logits = logits - mask * 1e12
        return logits

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None, short_labels=None):
        context_outputs = self.roformer(input_ids, attention_mask, token_type_ids)
        last_hidden_state = context_outputs.last_hidden_state  # [bz, seq_len, hidden_dim]
        outputs = self.dense_1(last_hidden_state)  # [bz, seq_len, 2 * inner_dim]
        qw, kw = outputs[..., ::2], outputs[..., 1::2]  # query/key vectors from the even/odd channels of the last dimension
        batch_size = input_ids.shape[0]
        if self.RoPE:
            pos = SinusoidalPositionEmbedding(self.inner_dim, "zero")(outputs)
            cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1)  # e.g. [0.34, 0.90] -> [0.34, 0.34, 0.90, 0.90]
            sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
            qw2 = torch.reshape(qw2, qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], 3)
            kw2 = torch.reshape(kw2, kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        logits = torch.einsum("bmd,bnd->bmn", qw, kw) / self.inner_dim ** 0.5
        bias = torch.einsum("bnh->bhn", self.dense_2(last_hidden_state)) / 2
        logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None]  # logits[:, None] adds the entity-type dimension
        loss = None

        mask = torch.triu(attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1))  # upper-triangular validity mask
        if labels is not None:
            y_pred = logits - (1 - mask.unsqueeze(1)) * 1e12
            y_true = labels.view(input_ids.shape[0] * self.ent_type_size, -1)
            y_pred = y_pred.view(input_ids.shape[0] * self.ent_type_size, -1)
            loss = multilabel_categorical_crossentropy(y_pred, y_true)

        with torch.no_grad():
            prob = torch.sigmoid(logits) * mask.unsqueeze(1)
            topk = torch.topk(prob.view(batch_size, self.ent_type_size, -1), 50, dim=-1)

        return GlobalPointerOutput(
            loss=loss,
            topk_probs=topk.values,
            topk_indices=topk.indices
        )


class MegatronForEffiGlobalPointer(MegatronBertPreTrainedModel):
    def __init__(self, config):
        # config.inner_dim: pointer head size, e.g. 64
        # config.ent_type_size: number of entity types
        super().__init__(config)
        self.bert = MegatronBertModel(config)
        self.ent_type_size = config.ent_type_size
        self.inner_dim = config.inner_dim
        self.hidden_size = config.hidden_size
        self.RoPE = config.RoPE

        self.dense_1 = nn.Linear(self.hidden_size, self.inner_dim * 2)
        self.dense_2 = nn.Linear(self.hidden_size, self.ent_type_size * 2)  # the original dense_2 was (inner_dim * 2, ent_type_size * 2)

    def sequence_masking(self, x, mask, value="-inf", axis=None):
        if mask is None:
            return x
        else:
            if value == "-inf":
                value = -1e12
            elif value == "inf":
                value = 1e12
            assert axis > 0, "axis must be greater than 0"
            for _ in range(axis - 1):
                mask = torch.unsqueeze(mask, 1)
            for _ in range(x.ndim - mask.ndim):
                mask = torch.unsqueeze(mask, mask.ndim)
            return x * mask + value * (1 - mask)

    def add_mask_tril(self, logits, mask):
        if mask.dtype != logits.dtype:
            mask = mask.type(logits.dtype)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 2)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 1)
        # mask out the lower triangle (spans must satisfy start <= end)
        mask = torch.tril(torch.ones_like(logits), diagonal=-1)
        logits = logits - mask * 1e12
        return logits

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None, short_labels=None):
        context_outputs = self.bert(input_ids, attention_mask, token_type_ids)
        last_hidden_state = context_outputs.last_hidden_state  # [bz, seq_len, hidden_dim]
        outputs = self.dense_1(last_hidden_state)  # [bz, seq_len, 2 * inner_dim]
        qw, kw = outputs[..., ::2], outputs[..., 1::2]  # query/key vectors from the even/odd channels of the last dimension
        batch_size = input_ids.shape[0]
        if self.RoPE:
            pos = SinusoidalPositionEmbedding(self.inner_dim, "zero")(outputs)
            cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1)  # e.g. [0.34, 0.90] -> [0.34, 0.34, 0.90, 0.90]
            sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
            qw2 = torch.reshape(qw2, qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], 3)
            kw2 = torch.reshape(kw2, kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        logits = torch.einsum("bmd,bnd->bmn", qw, kw) / self.inner_dim ** 0.5
        bias = torch.einsum("bnh->bhn", self.dense_2(last_hidden_state)) / 2
        logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None]  # logits[:, None] adds the entity-type dimension
        loss = None

        mask = torch.triu(attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1))  # upper-triangular validity mask
        if labels is not None:
            y_pred = logits - (1 - mask.unsqueeze(1)) * 1e12
            y_true = labels.view(input_ids.shape[0] * self.ent_type_size, -1)
            y_pred = y_pred.view(input_ids.shape[0] * self.ent_type_size, -1)
            loss = multilabel_categorical_crossentropy(y_pred, y_true)

        with torch.no_grad():
            prob = torch.sigmoid(logits) * mask.unsqueeze(1)
            topk = torch.topk(prob.view(batch_size, self.ent_type_size, -1), 50, dim=-1)

        return GlobalPointerOutput(
            loss=loss,
            topk_probs=topk.values,
            topk_indices=topk.indices
        )

models/span_extraction/span_for_ner.py ADDED
@@ -0,0 +1,252 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel
from transformers.models.albert.modeling_albert import AlbertPreTrainedModel, AlbertModel
from transformers.models.megatron_bert.modeling_megatron_bert import MegatronBertPreTrainedModel, MegatronBertModel
from models.basic_modules.linears import PoolerEndLogits, PoolerStartLogits
from torch.nn import CrossEntropyLoss
from loss.focal_loss import FocalLoss
from loss.label_smoothing import LabelSmoothingCrossEntropy


class BertSpanForNer(BertPreTrainedModel):
    def __init__(self, config):
        super(BertSpanForNer, self).__init__(config)
        self.soft_label = config.soft_label
        self.num_labels = config.num_labels
        self.loss_type = config.loss_type
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.start_fc = PoolerStartLogits(config.hidden_size, self.num_labels)
        if self.soft_label:
            self.end_fc = PoolerEndLogits(config.hidden_size + self.num_labels, self.num_labels)
        else:
            self.end_fc = PoolerEndLogits(config.hidden_size + 1, self.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        start_logits = self.start_fc(sequence_output)
        if start_positions is not None and self.training:
            if self.soft_label:
                batch_size = input_ids.size(0)
                seq_len = input_ids.size(1)
                label_logits = torch.FloatTensor(batch_size, seq_len, self.num_labels)
                label_logits.zero_()
                label_logits = label_logits.to(input_ids.device)
                label_logits.scatter_(2, start_positions.unsqueeze(2), 1)
            else:
                label_logits = start_positions.unsqueeze(2).float()
        else:
            label_logits = F.softmax(start_logits, -1)
            if not self.soft_label:
                label_logits = torch.argmax(label_logits, -1).unsqueeze(2).float()
        end_logits = self.end_fc(sequence_output, label_logits)
        outputs = (start_logits, end_logits,) + outputs[2:]

        if start_positions is not None and end_positions is not None:
            assert self.loss_type in ["lsr", "focal", "ce"]
            if self.loss_type == "lsr":
                loss_fct = LabelSmoothingCrossEntropy()
            elif self.loss_type == "focal":
                loss_fct = FocalLoss()
            else:
                loss_fct = CrossEntropyLoss()
            start_logits = start_logits.view(-1, self.num_labels)
            end_logits = end_logits.view(-1, self.num_labels)
            active_loss = attention_mask.view(-1) == 1
            active_start_logits = start_logits[active_loss]
            active_end_logits = end_logits[active_loss]

            active_start_labels = start_positions.view(-1)[active_loss]
            active_end_labels = end_positions.view(-1)[active_loss]

            start_loss = loss_fct(active_start_logits, active_start_labels)
            end_loss = loss_fct(active_end_logits, active_end_labels)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs
        return outputs
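

# Illustrative decoding sketch (not part of the original file): at inference the
# span models return (start_logits, end_logits) of shape [bz, seq_len, num_labels].
# A common greedy decode, assumed here, pairs each non-O start with the nearest
# following end position predicted with the same label (label 0 is assumed "O").
def extract_spans(start_logits, end_logits, max_span_len=20):
    # expects a single example, i.e. batch size 1
    start_pred = torch.argmax(start_logits, -1).squeeze(0).tolist()
    end_pred = torch.argmax(end_logits, -1).squeeze(0).tolist()
    spans = []  # tuples of (label, start, end)
    for i, s_label in enumerate(start_pred):
        if s_label == 0:
            continue
        for j in range(i, min(i + max_span_len, len(end_pred))):
            if end_pred[j] == s_label:
                spans.append((s_label, i, j))
                break
    return spans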


class RobertaSpanForNer(RobertaPreTrainedModel):
    def __init__(self, config):
        super(RobertaSpanForNer, self).__init__(config)
        self.soft_label = config.soft_label
        self.num_labels = config.num_labels
        self.loss_type = config.loss_type
        self.roberta = RobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.start_fc = PoolerStartLogits(config.hidden_size, self.num_labels)
        if self.soft_label:
            self.end_fc = PoolerEndLogits(config.hidden_size + self.num_labels, self.num_labels)
        else:
            self.end_fc = PoolerEndLogits(config.hidden_size + 1, self.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        start_logits = self.start_fc(sequence_output)
        if start_positions is not None and self.training:
            if self.soft_label:
                batch_size = input_ids.size(0)
                seq_len = input_ids.size(1)
                label_logits = torch.FloatTensor(batch_size, seq_len, self.num_labels)
                label_logits.zero_()
                label_logits = label_logits.to(input_ids.device)
                label_logits.scatter_(2, start_positions.unsqueeze(2), 1)
            else:
                label_logits = start_positions.unsqueeze(2).float()
        else:
            label_logits = F.softmax(start_logits, -1)
            if not self.soft_label:
                label_logits = torch.argmax(label_logits, -1).unsqueeze(2).float()
        end_logits = self.end_fc(sequence_output, label_logits)
        outputs = (start_logits, end_logits,) + outputs[2:]

        if start_positions is not None and end_positions is not None:
            assert self.loss_type in ["lsr", "focal", "ce"]
            if self.loss_type == "lsr":
                loss_fct = LabelSmoothingCrossEntropy()
            elif self.loss_type == "focal":
                loss_fct = FocalLoss()
            else:
                loss_fct = CrossEntropyLoss()
            start_logits = start_logits.view(-1, self.num_labels)
            end_logits = end_logits.view(-1, self.num_labels)
            active_loss = attention_mask.view(-1) == 1
            active_start_logits = start_logits[active_loss]
            active_end_logits = end_logits[active_loss]

            active_start_labels = start_positions.view(-1)[active_loss]
            active_end_labels = end_positions.view(-1)[active_loss]

            start_loss = loss_fct(active_start_logits, active_start_labels)
            end_loss = loss_fct(active_end_logits, active_end_labels)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs
        return outputs


class AlbertSpanForNer(AlbertPreTrainedModel):
    def __init__(self, config):
        super(AlbertSpanForNer, self).__init__(config)
        self.soft_label = config.soft_label
        self.num_labels = config.num_labels
        self.loss_type = config.loss_type
        self.bert = AlbertModel(config)  # attribute kept as "bert" to match the original checkpoint naming
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.start_fc = PoolerStartLogits(config.hidden_size, self.num_labels)
        if self.soft_label:
            self.end_fc = PoolerEndLogits(config.hidden_size + self.num_labels, self.num_labels)
        else:
            self.end_fc = PoolerEndLogits(config.hidden_size + 1, self.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        start_logits = self.start_fc(sequence_output)
        if start_positions is not None and self.training:
            if self.soft_label:
                batch_size = input_ids.size(0)
                seq_len = input_ids.size(1)
                label_logits = torch.FloatTensor(batch_size, seq_len, self.num_labels)
                label_logits.zero_()
                label_logits = label_logits.to(input_ids.device)
                label_logits.scatter_(2, start_positions.unsqueeze(2), 1)
            else:
                label_logits = start_positions.unsqueeze(2).float()
        else:
            label_logits = F.softmax(start_logits, -1)
            if not self.soft_label:
                label_logits = torch.argmax(label_logits, -1).unsqueeze(2).float()
        end_logits = self.end_fc(sequence_output, label_logits)
        outputs = (start_logits, end_logits,) + outputs[2:]

        if start_positions is not None and end_positions is not None:
            assert self.loss_type in ["lsr", "focal", "ce"]
            if self.loss_type == "lsr":
                loss_fct = LabelSmoothingCrossEntropy()
            elif self.loss_type == "focal":
                loss_fct = FocalLoss()
            else:
                loss_fct = CrossEntropyLoss()
            start_logits = start_logits.view(-1, self.num_labels)
            end_logits = end_logits.view(-1, self.num_labels)
            active_loss = attention_mask.view(-1) == 1
            active_start_logits = start_logits[active_loss]
            active_start_labels = start_positions.view(-1)[active_loss]
            active_end_logits = end_logits[active_loss]
            active_end_labels = end_positions.view(-1)[active_loss]

            start_loss = loss_fct(active_start_logits, active_start_labels)
            end_loss = loss_fct(active_end_logits, active_end_labels)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs
        return outputs


class MegatronBertSpanForNer(MegatronBertPreTrainedModel):
    def __init__(self, config):
        super(MegatronBertSpanForNer, self).__init__(config)
        self.soft_label = True  # hard-coded; config.soft_label is not read here
        self.num_labels = config.num_labels
        self.loss_type = "ce"  # hard-coded; config.loss_type is not read here
        self.bert = MegatronBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.start_fc = PoolerStartLogits(config.hidden_size, self.num_labels)
        if self.soft_label:
            self.end_fc = PoolerEndLogits(config.hidden_size + self.num_labels, self.num_labels)
        else:
            self.end_fc = PoolerEndLogits(config.hidden_size + 1, self.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        start_logits = self.start_fc(sequence_output)
        if start_positions is not None and self.training:
            if self.soft_label:
                batch_size = input_ids.size(0)
                seq_len = input_ids.size(1)
                label_logits = torch.FloatTensor(batch_size, seq_len, self.num_labels)
                label_logits.zero_()
                label_logits = label_logits.to(input_ids.device)
                label_logits.scatter_(2, start_positions.unsqueeze(2), 1)
            else:
                label_logits = start_positions.unsqueeze(2).float()
        else:
            label_logits = F.softmax(start_logits, -1)
            if not self.soft_label:
                label_logits = torch.argmax(label_logits, -1).unsqueeze(2).float()
        end_logits = self.end_fc(sequence_output, label_logits)
        outputs = (start_logits, end_logits,) + outputs[2:]

        if start_positions is not None and end_positions is not None:
            assert self.loss_type in ["lsr", "focal", "ce"]
            if self.loss_type == "lsr":
                loss_fct = LabelSmoothingCrossEntropy()
            elif self.loss_type == "focal":
                loss_fct = FocalLoss()
            else:
                loss_fct = CrossEntropyLoss()
            start_logits = start_logits.view(-1, self.num_labels)
            end_logits = end_logits.view(-1, self.num_labels)
            active_loss = attention_mask.view(-1) == 1
            active_start_logits = start_logits[active_loss]
            active_end_logits = end_logits[active_loss]

            active_start_labels = start_positions.view(-1)[active_loss]
            active_end_labels = end_positions.view(-1)[active_loss]

            start_loss = loss_fct(active_start_logits, active_start_labels)
            end_loss = loss_fct(active_end_logits, active_end_labels)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs
        return outputs