DeepLearning101 committed
Commit 6c0ee22
1 Parent(s): fdc4786

Upload 2 files
models/fewshot_learning/span_proto.py ADDED
@@ -0,0 +1,926 @@
+ # -*- coding: utf-8 -*-
+ # @Time : 2022/4/21 5:30 PM
+ # @Author : JianingWang
+ # @File : span_proto.py
+
+ """
+ This code is implemented for the paper "SpanProto: A Two-stage Span-based Prototypical Network for Few-shot Named Entity Recognition".
+ """
+
+ import os
+ from typing import Optional, Union
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from dataclasses import dataclass
+ from transformers.file_utils import ModelOutput
+ from transformers.models.bert import BertPreTrainedModel, BertModel
+
+ class RawGlobalPointer(nn.Module):
+     def __init__(self, encoder, ent_type_size, inner_dim, RoPE=True):
+         # encoder: RoBERTa-Large as encoder
+         # inner_dim: 64
+         # ent_type_size: ent_cls_num
+         super().__init__()
+         self.encoder = encoder
+         self.ent_type_size = ent_type_size
+         self.inner_dim = inner_dim
+         self.hidden_size = encoder.config.hidden_size
+         self.dense = nn.Linear(self.hidden_size, self.ent_type_size * self.inner_dim * 2)
+
+         self.RoPE = RoPE
+
+     def sinusoidal_position_embedding(self, batch_size, seq_len, output_dim):
+         position_ids = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(-1)
+
+         indices = torch.arange(0, output_dim // 2, dtype=torch.float)
+         indices = torch.pow(10000, -2 * indices / output_dim)
+         embeddings = position_ids * indices
+         embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
+         embeddings = embeddings.repeat((batch_size, *([1] * len(embeddings.shape))))
+         embeddings = torch.reshape(embeddings, (batch_size, seq_len, output_dim))
+         embeddings = embeddings.to(self.device)
+         return embeddings
+
+     def forward(self, input_ids, attention_mask, token_type_ids):
+         self.device = input_ids.device
+
+         context_outputs = self.encoder(input_ids, attention_mask, token_type_ids)
+         # last_hidden_state: (batch_size, seq_len, hidden_size)
+         last_hidden_state = context_outputs[0]
+
+         batch_size = last_hidden_state.size()[0]
+         seq_len = last_hidden_state.size()[1]
+
+         outputs = self.dense(last_hidden_state)
+         outputs = torch.split(outputs, self.inner_dim * 2, dim=-1)
+         outputs = torch.stack(outputs, dim=-2)
+         qw, kw = outputs[..., :self.inner_dim], outputs[..., self.inner_dim:]
+         if self.RoPE:
+             # pos_emb: (batch_size, seq_len, inner_dim)
+             pos_emb = self.sinusoidal_position_embedding(batch_size, seq_len, self.inner_dim)
+             cos_pos = pos_emb[..., None, 1::2].repeat_interleave(2, dim=-1)
+             sin_pos = pos_emb[..., None, ::2].repeat_interleave(2, dim=-1)
+             qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], -1)
+             qw2 = qw2.reshape(qw.shape)
+             qw = qw * cos_pos + qw2 * sin_pos
+             kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], -1)
+             kw2 = kw2.reshape(kw.shape)
+             kw = kw * cos_pos + kw2 * sin_pos
+         # logits: (batch_size, ent_type_size, seq_len, seq_len)
+         logits = torch.einsum("bmhd,bnhd->bhmn", qw, kw)
+
+         # padding mask
+         pad_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, self.ent_type_size, seq_len, seq_len)
+         logits = logits * pad_mask - (1 - pad_mask) * 1e12
+
+         # mask out the lower triangle (spans must satisfy start <= end)
+         mask = torch.tril(torch.ones_like(logits), -1)
+         logits = logits - mask * 1e12
+
+         return logits / self.inner_dim ** 0.5
+
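+ # NOTE (illustrative, not from the original commit): RoPE rotates the start query q_i
+ # and the end key k_j by position-dependent angles so that the span score q_i · k_j
+ # depends on the relative offset j - i rather than absolute positions. A minimal,
+ # hypothetical smoke test, assuming any HuggingFace BERT-style encoder:
+ #
+ #     encoder = BertModel.from_pretrained("bert-base-uncased")
+ #     gp = RawGlobalPointer(encoder, ent_type_size=1, inner_dim=64)
+ #     logits = gp(input_ids, attention_mask, token_type_ids)
+ #     # logits: [batch_size, 1, seq_len, seq_len]; logits[b, 0, i, j] scores span (i, j)
+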
+ class SinusoidalPositionEmbedding(nn.Module):
+     """Sinusoidal (sin-cos) position embedding."""
+
+     def __init__(self, output_dim, merge_mode="add", custom_position_ids=False):
+         super(SinusoidalPositionEmbedding, self).__init__()
+         self.output_dim = output_dim
+         self.merge_mode = merge_mode
+         self.custom_position_ids = custom_position_ids
+
+     def forward(self, inputs):
+         if self.custom_position_ids:
+             # unpack first, then read the sequence length from the unpacked tensor
+             inputs, position_ids = inputs
+             seq_len = inputs.shape[1]
+             position_ids = position_ids.type(torch.float)
+         else:
+             input_shape = inputs.shape
+             batch_size, seq_len = input_shape[0], input_shape[1]
+             position_ids = torch.arange(seq_len).type(torch.float)[None]
+         indices = torch.arange(self.output_dim // 2).type(torch.float)
+         indices = torch.pow(10000.0, -2 * indices / self.output_dim)
+         embeddings = torch.einsum("bn,d->bnd", position_ids, indices)
+         embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
+         embeddings = torch.reshape(embeddings, (-1, seq_len, self.output_dim))
+         if self.merge_mode == "add":
+             return inputs + embeddings.to(inputs.device)
+         elif self.merge_mode == "mul":
+             return inputs * (embeddings + 1.0).to(inputs.device)
+         elif self.merge_mode == "zero":
+             return embeddings.to(inputs.device)
+
+
+ def multilabel_categorical_crossentropy(y_pred, y_true):
+     y_pred = (1 - 2 * y_true) * y_pred  # flip the sign of positive-class scores
+     y_pred_neg = y_pred - y_true * 1e12  # mask out the scores of positive classes
+     y_pred_pos = y_pred - (1 - y_true) * 1e12  # mask out the scores of negative classes
+     zeros = torch.zeros_like(y_pred[..., :1])
+     y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
+     y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
+     neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
+     pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
+     return (neg_loss + pos_loss).mean()
+
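+ # The multilabel categorical cross-entropy above (Su Jianlin's GlobalPointer loss)
+ # optimizes, per row, the margin between positive and negative span scores:
+ #     L = log(1 + sum_{i in neg} exp(s_i)) + log(1 + sum_{j in pos} exp(-s_j))
+ # Hypothetical toy check: with scores s = [2.0, -1.0] and labels y = [1, 0],
+ #     L = log(1 + e^{-1.0}) + log(1 + e^{-2.0}) ≈ 0.313 + 0.127 ≈ 0.44
+ #
+ #     y_pred = torch.tensor([[2.0, -1.0]])
+ #     y_true = torch.tensor([[1.0, 0.0]])
+ #     multilabel_categorical_crossentropy(y_pred, y_true)  # ≈ 0.44
+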
+
+ def multilabel_categorical_crossentropy2(y_pred, y_true):
+     y_pred = (1 - 2 * y_true) * y_pred  # flip the sign of positive-class scores
+     y_pred_neg = y_pred.clone()
+     y_pred_pos = y_pred.clone()
+     y_pred_neg[y_true > 0] -= float("inf")  # mask out the scores of positive classes
+     y_pred_pos[y_true < 1] -= float("inf")  # mask out the scores of negative classes
+     zeros = torch.zeros_like(y_pred[..., :1])
+     y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
+     y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
+     neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
+     pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
+     return (neg_loss + pos_loss).mean()
+
+
+ @dataclass
+ class GlobalPointerOutput(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     topk_probs: torch.FloatTensor = None
+     topk_indices: torch.IntTensor = None
+     last_hidden_state: torch.FloatTensor = None
+
+
+ @dataclass
+ class SpanProtoOutput(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     query_spans: list = None
+     proto_logits: list = None
+     topk_probs: torch.FloatTensor = None
+     topk_indices: torch.IntTensor = None
+
+
+ class SpanDetector(BertPreTrainedModel):
+     def __init__(self, config):
+         # encoder: RoBERTa-Large as encoder
+         # inner_dim: 64
+         # ent_type_size: ent_cls_num
+         super().__init__(config)
+         self.bert = BertModel(config)
+         # the detector is class-agnostic, so a single "entity" type is enough
+         self.ent_type_size = 1
+         self.inner_dim = 64
+         self.hidden_size = config.hidden_size
+         self.RoPE = True
+
+         self.dense_1 = nn.Linear(self.hidden_size, self.inner_dim * 2)
+         # the original dense_2 was (inner_dim * 2, ent_type_size * 2)
+         self.dense_2 = nn.Linear(self.hidden_size, self.ent_type_size * 2)
+
+     def sequence_masking(self, x, mask, value="-inf", axis=None):
+         if mask is None:
+             return x
+         else:
+             if value == "-inf":
+                 value = -1e12
+             elif value == "inf":
+                 value = 1e12
+             assert axis > 0, "axis must be greater than 0"
+             for _ in range(axis - 1):
+                 mask = torch.unsqueeze(mask, 1)
+             for _ in range(x.ndim - mask.ndim):
+                 mask = torch.unsqueeze(mask, mask.ndim)
+             return x * mask + value * (1 - mask)
+
+     def add_mask_tril(self, logits, mask):
+         if mask.dtype != logits.dtype:
+             mask = mask.type(logits.dtype)
+         logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 2)
+         logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 1)
+         # mask out the lower triangle (spans must satisfy start <= end)
+         mask = torch.tril(torch.ones_like(logits), diagonal=-1)
+         logits = logits - mask * 1e12
+         return logits
+
+     def forward(self, input_ids, attention_mask, token_type_ids, labels=None, short_labels=None):
+         context_outputs = self.bert(input_ids, attention_mask, token_type_ids)
+         last_hidden_state = context_outputs.last_hidden_state  # [bz, seq_len, hidden_dim]
+         del context_outputs
+         outputs = self.dense_1(last_hidden_state)  # [bz, seq_len, 2*inner_dim]
+         # split the last dimension into even positions (qw) and odd positions (kw)
+         qw, kw = outputs[..., ::2], outputs[..., 1::2]
+         batch_size = input_ids.shape[0]
+         if self.RoPE:  # whether to use RoPE rotary position embedding
+             pos = SinusoidalPositionEmbedding(self.inner_dim, "zero")(outputs)
+             cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1)  # e.g. [0.34, 0.90] -> [0.34, 0.34, 0.90, 0.90]
+             sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)
+             qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
+             qw2 = torch.reshape(qw2, qw.shape)
+             qw = qw * cos_pos + qw2 * sin_pos
+             kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], 3)
+             kw2 = torch.reshape(kw2, kw.shape)
+             kw = kw * cos_pos + kw2 * sin_pos
+         logits = torch.einsum("bmd,bnd->bmn", qw, kw) / self.inner_dim ** 0.5
+         bias = torch.einsum("bnh->bhn", self.dense_2(last_hidden_state)) / 2
+         logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None]  # logits[:, None] adds the type dimension
+         loss = None
+
+         mask = torch.triu(attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1))  # upper-triangular mask
+         if labels is not None:
+             y_pred = logits - (1 - mask.unsqueeze(1)) * 1e12
+             y_true = labels.view(input_ids.shape[0] * self.ent_type_size, -1)
+             y_pred = y_pred.view(input_ids.shape[0] * self.ent_type_size, -1)
+             loss = multilabel_categorical_crossentropy(y_pred, y_true)
+
+         with torch.no_grad():
+             prob = torch.sigmoid(logits) * mask.unsqueeze(1)
+             topk = torch.topk(prob.view(batch_size, self.ent_type_size, -1), 50, dim=-1)
+
+         return GlobalPointerOutput(
+             loss=loss,
+             topk_probs=topk.values,
+             topk_indices=topk.indices,
+             last_hidden_state=last_hidden_state
+         )
+
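+ # The detector is a single-class ("entity or not") GlobalPointer: the score of span
+ # (i, j) decomposes into a RoPE-rotated start/end dot product plus additive start and
+ # end biases from dense_2, i.e. roughly
+ #     score(i, j) = q_i · k_j / sqrt(inner_dim) + b_start(i) + b_end(j)
+ # Training flattens the [bz, 1, seq_len, seq_len] score matrix and applies the
+ # multilabel categorical cross-entropy over all upper-triangular, unpadded cells;
+ # inference keeps the 50 highest-probability cells per sentence (topk_probs /
+ # topk_indices), which get_topk_spans below decodes back into (start, end) pairs.
+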
+
+ class SpanProto(nn.Module):
+     def __init__(self, config):
+         """
+         word_encoder: Sentence encoder
+
+         You need to set self.cost as your own loss function.
+         """
+         nn.Module.__init__(self)
+         self.config = config
+         self.output_dir = "./outputs"
+         self.drop = nn.Dropout()
+         self.global_span_detector = SpanDetector(config=self.config)  # global span detector
+         self.projector = nn.Sequential(  # projector
+             nn.Linear(self.config.hidden_size, self.config.hidden_size),
+             nn.Sigmoid(),
+         )
+         self.tag_embeddings = nn.Embedding(2, self.config.hidden_size)  # tag for labeled / unlabeled span set
+         self.max_length = 64
+         self.margin_distance = 6.0
+         self.global_step = 0
+
+     def predict_result_path(self, path=None):
+         if path is None:
+             predict_dir = os.path.join(
+                 self.output_dir, "{}-{}-{}".format(self.mode, self.num_class, self.num_example), "predict"
+             )
+         else:
+             predict_dir = os.path.join(path, "predict")
+         if not os.path.exists(predict_dir):  # create the directory if it does not exist
+             os.makedirs(predict_dir)
+         return predict_dir
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
+         config = kwargs.pop("config", None)
+         model = SpanProto(config=config)
+         # load the pre-trained BERT weights into the span detector
+         model.global_span_detector = SpanDetector.from_pretrained(
+             pretrained_model_name_or_path,
+             *model_args,
+             **kwargs
+         )
+         # the remaining modules keep their fresh initialization
+         return model
+
+     def __dist__(self, x, y, dim, use_dot=False):
+         # x: [1, class_num, hidden_dim], y: [span_num, 1, hidden_dim]
+         # x - y: [span_num, class_num, hidden_dim]
+         # (x - y)^2.sum(2): [span_num, class_num]
+         if use_dot:
+             return (x * y).sum(dim)
+         else:
+             return -(torch.pow(x - y, 2)).sum(dim)
+
+     def __get_proto__(self, support_emb: torch.Tensor, support_span: list, support_span_type: list, use_tag=False):
+         """
+         support_emb: [n', seq_len, dim]
+         support_span: [n', m, 2] e.g. [[[3, 6], [12, 13]], [[1, 3]], ...]
+         support_span_type: [n', m] e.g. [[2, 1], [5], ...]
+         """
+         prototype = list()  # one prototype per class
+         all_span_embs = list()  # embedding of every support span
+         all_span_tags = list()
+         # loop over classes
+         for tag in range(self.num_class):
+             tag_prototype = list()  # [k, dim]
+             # loop over the sentences of the current episode
+             for emb, span, types in zip(support_emb, support_span, support_span_type):
+                 # emb: [seq_len, dim], span: [m, 2], types: [m]
+                 span = torch.Tensor(span).long().cuda()  # e.g. [[3, 4], [9, 11]]
+                 types = torch.Tensor(types).long().cuda()  # e.g. [1, 4]
+                 # collect the spans of this sentence that belong to class `tag`
+                 try:
+                     tag_span = span[types == tag]  # e.g. span == [[3, 4]], tag == 1
+                     # embed each retrieved span as the sum of its boundary tokens
+                     for (s, e) in tag_span:
+                         tag_emb = emb[s] + emb[e]  # [dim]
+                         tag_prototype.append(tag_emb)
+                         all_span_embs.append(tag_emb)
+                         all_span_tags.append(tag)
+                 except Exception:
+                     # this class has no span in the sentence; fall back to a random vector
+                     tag_prototype.append(torch.randn(support_emb.shape[-1]).cuda())
+             try:
+                 prototype.append(torch.mean(torch.stack(tag_prototype), dim=0))
+             except Exception:
+                 # the class has no span anywhere in the episode
+                 prototype.append(torch.randn(support_emb.shape[-1]).cuda())
+         all_span_embs = torch.stack(all_span_embs).detach().cpu().numpy().tolist()
+
+         return torch.stack(prototype), all_span_embs, all_span_tags  # [num_class + 1, dim]
+
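+ # Prototype computation in a nutshell: a span (s, e) is embedded as emb[s] + emb[e]
+ # (the sum of its boundary token representations), and the prototype of class c is
+ # the mean over the support spans of that class:
+ #     proto_c = (1 / |S_c|) * sum_{(s, e) in S_c} (emb[s] + emb[e])
+ # The random-vector fallback only triggers when a class has no span in the episode.
+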
+
+     def __batch_dist__(self, prototype: torch.Tensor, query_emb: torch.Tensor, query_spans: list, query_span_type: Union[list, None]):
+         """
+         Classify each query span against the prototypes.
+         """
+         # first embed every span of every sentence in the current episode
+         all_logits = list()  # per-sentence logits of all spans in the episode
+         all_types = list()
+         visual_all_types, visual_all_embs = list(), list()  # collected for visualization
+         for emb, span in zip(query_emb, query_spans):  # loop over sentences
+             span_emb = list()  # embeddings of all spans of this sentence [m', dim]
+             try:
+                 for (s, e) in span:  # loop over spans
+                     tag_emb = emb[s] + emb[e]  # [dim]
+                     span_emb.append(tag_emb)
+             except Exception:
+                 span_emb = []
+             if len(span_emb) != 0:
+                 span_emb = torch.stack(span_emb)  # [span_num, dim]
+                 # distance between every span and every prototype
+                 logits = self.__dist__(prototype.unsqueeze(0), span_emb.unsqueeze(1), 2)  # [span_num, num_class]
+                 with torch.no_grad():
+                     # nearest prototype per query span, plus the distance to it
+                     pred_dist, pred_types = torch.max(logits, -1)
+                     pred_dist = torch.pow(-1 * pred_dist, 0.5)
+                     # spans farther than the margin from every prototype are treated as
+                     # unlabeled spans and assigned the extra "other" class
+                     pred_types[pred_dist > self.margin_distance] = self.num_class
+                     pred_types = pred_types.detach().cpu().numpy().tolist()
+
+                 all_logits.append(logits)
+                 all_types.append(pred_types)
+                 visual_all_types.extend(pred_types)
+                 visual_all_embs.extend(span_emb.detach().cpu().numpy().tolist())
+             else:
+                 all_logits.append([])
+                 all_types.append([])
+
+         if query_span_type is not None:
+             # query_span_type: [n', m]
+             try:
+                 all_type = torch.Tensor([t for types in query_span_type for t in types]).long().cuda()  # [span_num]
+                 loss = nn.CrossEntropyLoss()(torch.cat(all_logits, 0), all_type)
+             except Exception:
+                 # fall back to pairing only the sentences whose span and type lists line up
+                 all_logit, all_type = list(), list()
+                 for logits, types in zip(all_logits, query_span_type):
+                     if len(logits) != 0 and len(types) != 0 and len(logits) == len(types):
+                         all_logit.append(logits)
+                         all_type.extend(types)
+                 if len(all_logit) != 0:
+                     all_logit = torch.cat(all_logit, 0)
+                     all_type = torch.Tensor(all_type).long().cuda()
+                     loss = nn.CrossEntropyLoss()(all_logit, all_type)
+                 else:
+                     loss = 0.
+         else:
+             loss = None
+             all_logits = [i.detach().cpu().numpy().tolist() for i in all_logits if len(i) != 0]
+         return loss, all_logits, all_types, visual_all_types, visual_all_embs
+
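+ # Query classification is nearest-prototype under squared Euclidean distance:
+ #     logits[m, c] = -||x_m - proto_c||^2,   pred(m) = argmax_c logits[m, c]
+ # A span whose nearest prototype is farther than margin_distance (6.0) is rejected
+ # and labeled num_class, i.e. the extra "other / false positive" class.
+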
+
+     def __batch_margin__(self, prototype: torch.Tensor, query_emb: torch.Tensor, query_unlabeled_spans: list,
+                          query_labeled_spans: list, query_span_type: list):
+         """
+         Push unlabeled spans away from every prototype (and, in principle, pull
+         labeled spans toward the prototype of their class).
+         """
+
+         # prototype: [num_class, dim], negative: [span_num, dim]
+         # compute the distance between each unlabeled span and each prototype; the goal
+         # is to make every such distance exceed the margin threshold
+         def distance(input1, input2, p=2, eps=1e-6):
+             # compute the p-norm distance
+             norm = torch.pow(torch.abs((input1 - input2 + eps)), p)
+             pnorm = torch.pow(torch.sum(norm, -1), 1.0 / p)
+             return pnorm
+
+         unlabeled_span_emb = list()
+         for emb, span in zip(query_emb, query_unlabeled_spans):  # loop over sentences
+             # embeddings of all unlabeled spans of this sentence [m', dim]
+             for (s, e) in span:  # loop over spans
+                 tag_emb = emb[s] + emb[e]  # [dim]
+                 unlabeled_span_emb.append(tag_emb)
+
+         try:
+             unlabeled_span_emb = torch.stack(unlabeled_span_emb)  # [span_num, dim]
+         except Exception:
+             return 0.
+
+         unlabeled_dist = distance(prototype.unsqueeze(0), unlabeled_span_emb.unsqueeze(1))  # [span_num, num_class]
+         unlabeled_output = torch.maximum(torch.zeros_like(unlabeled_dist), self.margin_distance - unlabeled_dist)
+         return torch.mean(unlabeled_output)
+
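+ # The margin objective above is a one-sided hinge on unlabeled (false-positive) spans:
+ #     L_margin = mean_{u, c} max(0, m - d(u, proto_c)),   m = margin_distance
+ # pushing every unlabeled span at least m away from every prototype, which carves out
+ # the rejection region used at inference time in __batch_dist__.
+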
+
+     def forward(
+         self,
+         episode_ids,
+         support, query,
+         num_class,
+         num_example,
+         mode=None,
+         short_labels=None,
+         stage: str = "train",
+         path: str = None
+     ):
+         """
+         episode_ids: Input of the idx of each episode data. (only list)
+         support: Inputs of the support set.
+         query: Inputs of the query set.
+         num_class: Num of classes
+         K: Num of instances for each class in the support set
+         Q: Num of instances for each class in the query set
+         return: logits, pred
+         """
+         if stage.startswith("train"):
+             self.global_step += 1
+         self.num_class = num_class  # the N in N-way K-shot
+         self.num_example = num_example  # the K in N-way K-shot
+         self.mode = mode  # FewNERD mode=inter/intra
+         self.max_length = support["input_ids"].shape[1]
+         support_inputs, support_attention_masks, support_type_ids = \
+             support["input_ids"], support["attention_mask"], support["token_type_ids"]  # torch, [n, seq_len]
+         query_inputs, query_attention_masks, query_type_ids = \
+             query["input_ids"], query["attention_mask"], query["token_type_ids"]  # torch, [n, seq_len]
+         support_labels = support["labels"]  # torch
+         query_labels = query["labels"]  # torch
+         # global span detector: obtain all mention spans and the detector loss
+         support_detector_outputs = self.global_span_detector(
+             support_inputs, support_attention_masks, support_type_ids, support_labels, short_labels=short_labels
+         )
+         query_detector_outputs = self.global_span_detector(
+             query_inputs, query_attention_masks, query_type_ids, query_labels, short_labels=short_labels
+         )
+         device_id = support_inputs.device.index
+
+         if self.global_step <= 500 and stage == "train":
+             # warm-up: only train the span detector for the first 500 steps
+             return SpanProtoOutput(
+                 loss=support_detector_outputs.loss,
+                 topk_probs=query_detector_outputs.topk_probs,
+                 topk_indices=query_detector_outputs.topk_indices,
+             )
+         # obtain the labeled spans of the support set
+         support_labeled_spans = support["labeled_spans"]  # all labeled spans, list, [n, m, 2]: n sentences, m entity spans, 2 (start / end)
+         support_labeled_types = support["labeled_types"]  # all labeled entity type ids, list, [n, m]
+         query_labeled_spans = query["labeled_spans"]  # all labeled spans, list, [n, m, 2]
+         query_labeled_types = query["labeled_types"]  # all labeled entity type ids, list, [n, m]
+
+         # obtain the predicted spans of the query set from the detector; these include
+         # mention spans that are not labeled in the current n-way k-shot episode
+         query_predict_spans = self.get_topk_spans(
+             query_detector_outputs.topk_probs,
+             query_detector_outputs.topk_indices,
+             query["input_ids"],
+             threshold=0.9 if stage.startswith("train") else 0.95,
+             is_query=True
+         )  # [n, m, 2]
+
+         if stage.startswith("train"):
+             # during training we must know which detected spans are labeled and which
+             # are unlabeled; the unlabeled ones are split off for the margin loss
+             query_unlabeled_spans = self.split_span(
+                 labeled_spans=query_labeled_spans,
+                 labeled_types=query_labeled_types,
+                 predict_spans=query_predict_spans,
+                 stage=stage
+             )  # [n, m, 2]: n sentences, each with a number of spans
+             query_all_spans = query_labeled_spans
+             query_span_types = query_labeled_types
+         else:
+             # during inference, merge labeled and predicted spans directly; at dev/test
+             # time the query spans come entirely from the detector
+             query_unlabeled_spans = None
+             query_all_spans, _ = self.merge_span(
+                 labeled_spans=query_labeled_spans,
+                 labeled_types=query_labeled_types,
+                 predict_spans=query_predict_spans,
+                 stage=stage
+             )  # [n, m, 2]
+             query_span_types = None
+
+         # obtain the representation of each token
+         support_emb, query_emb = support_detector_outputs.last_hidden_state, \
+             query_detector_outputs.last_hidden_state  # [n, seq_len, dim]
+         support_emb, query_emb = self.projector(support_emb), self.projector(query_emb)  # [n, seq_len, dim]
+
+         batch_result = dict()
+         proto_losses = list()  # loss of each episode
+         current_support_num = 0
+         current_query_num = 0
+         typing_loss = None
+         # loop over the episodes in the batch
+         for i, sent_support_num in enumerate(support["sentence_num"]):
+             sent_query_num = query["sentence_num"][i]
+             id_ = episode_ids[i]  # id of the current episode
+
+             # for the support set, build prototypes from the labeled spans only
+             # [n', seq_len, dim]: n' sentences in one episode; support_proto: [num_class + 1, dim]
+             support_proto, all_span_embs, all_span_tags = self.__get_proto__(
+                 support_emb[current_support_num: current_support_num + sent_support_num],  # [n', seq_len, dim]
+                 support_labeled_spans[current_support_num: current_support_num + sent_support_num],  # [n', m]
+                 support_labeled_types[current_support_num: current_support_num + sent_support_num],  # [n', m]
+             )
+
+             # for the labeled spans of the query set, apply standard prototype learning:
+             # embed each span and measure its distance to every prototype
+             proto_loss, proto_logits, all_types, visual_all_types, visual_all_embs = self.__batch_dist__(
+                 support_proto,
+                 query_emb[current_query_num: current_query_num + sent_query_num],  # [n', seq_len, dim]
+                 query_all_spans[current_query_num: current_query_num + sent_query_num],  # [n', m]
+                 query_span_types[current_query_num: current_query_num + sent_query_num] if query_span_types else None,  # [n', m]
+             )
+
+             visual_data = {
+                 "data": all_span_embs + visual_all_embs,
+                 "target": all_span_tags + visual_all_types,
+             }
+
+             # for the unlabeled query spans, push each span away from all prototypes
+             # with the margin loss
+             if stage.startswith("train"):
+                 margin_loss = self.__batch_margin__(
+                     support_proto,
+                     query_emb[current_query_num: current_query_num + sent_query_num],  # [n', seq_len, dim]
+                     query_unlabeled_spans[current_query_num: current_query_num + sent_query_num],  # [n', span_num]
+                     query_all_spans[current_query_num: current_query_num + sent_query_num],
+                     query_span_types[current_query_num: current_query_num + sent_query_num],
+                 )
+
+                 proto_losses.append(proto_loss + margin_loss)
+
+             batch_result[id_] = {
+                 "spans": query_all_spans[current_query_num: current_query_num + sent_query_num],
+                 "types": all_types,
+                 "visualization": visual_data
+             }
+
+             current_query_num += sent_query_num
+             current_support_num += sent_support_num
+         if stage.startswith("train"):
+             typing_loss = torch.mean(torch.stack(proto_losses), dim=-1)
+
+         if not stage.startswith("train"):
+             self.__save_evaluate_predicted_result__(batch_result, device_id=device_id, stage=stage, path=path)
+
+         return SpanProtoOutput(
+             loss=(support_detector_outputs.loss + typing_loss)
+             if stage.startswith("train") else query_detector_outputs.loss,
+         )  # whatever the outer container (list or tuple), the returned logits must wrap a tensor at the innermost level, otherwise HuggingFace's nested_detach will fail
+
+     def __save_evaluate_predicted_result__(self, new_result: dict, device_id: int = 0, stage="dev", path=None):
+         """
+         Save the predicted spans and span types of each batch during forward.
+         new_result / result: {
+             "(id)": { # id-th episode query
+                 "spans": [[[1, 4], [6, 7], xxx], ... ] # [sent_num, span_num, 2]
+                 "types": [[2, 0, xxx], ...] # [sent_num, span_num]
+             },
+             xxx
+         }
+         """
+         # load the predictions accumulated so far for the current task
+         self.predict_dir = self.predict_result_path(path)
+         npy_file_name = os.path.join(self.predict_dir, "{}_predictions_{}.npy".format(stage, device_id))
+         result = dict()
+         if os.path.exists(npy_file_name):
+             result = np.load(npy_file_name, allow_pickle=True)[()]
+         # merge
+         for episode_id, query_res in new_result.items():
+             result[episode_id] = query_res
+         # save
+         np.save(npy_file_name, result, allow_pickle=True)
+
+
+     def get_topk_spans(self, probs, indices, input_ids, threshold=0.60, low_threshold=0.1, is_query=False):
+         """
+         probs: [n, m]
+         indices: [n, m]
+         input_texts: [n, seq_len]
+         is_query: if true, each sentence must recall at least one span
+         """
+         probs = probs.squeeze(1).detach().cpu()  # top-k probabilities [n, m], already sorted in descending order
+         indices = indices.squeeze(1).detach().cpu()  # top-k flat indices [n, m], already sorted in descending order
+         input_ids = input_ids.detach().cpu()
+         predict_span = list()
+         if is_query:
+             low_threshold = 0.0
+         for prob, index, text in zip(probs, indices, input_ids):  # loop over sentences, each with several candidate spans and probabilities
+             threshold_ = threshold
+             index_ids = torch.Tensor([i for i in range(len(index))]).long()
+             span = set()
+             # TODO 1. tune the threshold 2. handle overlapping predicted entities
+             entity_index = index[prob >= low_threshold]
+             index_ids = index_ids[prob >= low_threshold]
+             # lower the threshold dynamically so that the number of recalled spans is
+             # roughly balanced across sentences (a single fixed threshold recalls very
+             # uneven span counts per sentence)
+             while threshold_ >= low_threshold:
+                 for ei, entity in enumerate(entity_index):
+                     p = prob[index_ids[ei]]
+                     if p < threshold_:
+                         # candidates are sorted by probability, so every later candidate
+                         # is also below the threshold; stop early
+                         break
+                     # convert the 1D flat index into a 2D (start, end) index
+                     start_end = np.unravel_index(entity, (self.max_length, self.max_length))
+                     s, e = start_end[0], start_end[1]
+                     ans = text[s: e]
+                     span.add((s, e))
+                 if len(span) <= 3:
+                     # too few spans recalled: lower the threshold and re-screen
+                     threshold_ -= 0.05
+                 else:
+                     break
+             if len(span) == 0:
+                 # no span recalled at all: fall back to [CLS], i.e. (0, 0)
+                 # (the analogue of "unanswerable" in MRC)
+                 span = [[0, 0]]
+             span = [list(i) for i in list(span)]
+             predict_span.append(span)
+         return predict_span
+
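+ # Decoding sketch: the top-k indices address the flattened seq_len x seq_len score
+ # matrix, so np.unravel_index maps them back to 2D coordinates. Hypothetical example
+ # with self.max_length = 64: a flat index of 83 decodes to (83 // 64, 83 % 64) = (1, 19),
+ # i.e. the span starting at token 1 and ending at token 19. The threshold drops in
+ # steps of 0.05 until more than 3 spans are recalled (or low_threshold is reached),
+ # keeping the number of candidate spans per sentence roughly balanced.
+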
+
+     def split_span(self, labeled_spans: list, labeled_types: list, predict_spans: list, stage: str = "train"):
+         """
+         Split the spans predicted by the detector into labeled and unlabeled spans.
+         """
+         def check_similar_span(span1, span2):
+             """
+             Check whether two spans are close, e.g. [12, 16], [11, 16], [13, 15], [12, 17] are close.
+             """
+             # special case, e.g. [12, 12] and [13, 13]
+             if len(span1) == 0 or len(span2) == 0:
+                 return False
+             if span1[0] == span1[1] and span2[0] == span2[1] and abs(span1[0] - span2[0]) == 1:
+                 return False
+             if abs(span1[0] - span2[0]) <= 1 and abs(span1[1] - span2[1]) <= 1:  # start and end each differ by at most 1
+                 return True
+             return False
+
+         num = 0
+         unlabeled_spans = list()
+         for labeled_span, labeled_type, predict_span in zip(labeled_spans, labeled_types, predict_spans):
+             # split the detector's predictions for this sentence into labeled and unlabeled spans
+             unlabeled_span = list()
+             for span in predict_span:  # loop over predicted spans
+                 if span not in labeled_span:  # the span is not labeled, so it is unlabeled
+                     # the boundaries predicted by the global pointer can be fuzzy, so
+                     # discard predictions whose boundaries nearly coincide with a labeled span
+                     is_remove = False
+                     for span_x in labeled_span:  # loop over the labeled spans
+                         is_remove = check_similar_span(span_x, span)  # drop the span if it is too close to an existing one
+                         if is_remove is True:
+                             break
+                     if is_remove is True:
+                         continue
+                     unlabeled_span.append(span)
+             num += len(unlabeled_span)
+             unlabeled_spans.append(unlabeled_span)
+         return unlabeled_spans
+
+
+     def merge_span(self, labeled_spans: list, labeled_types: list, predict_spans: list, stage: str = "train"):
+
+         def check_similar_span(span1, span2):
+             """
+             Check whether two spans are close, e.g. [12, 16], [11, 16], [13, 15], [12, 17] are close.
+             """
+             # special case, e.g. [12, 12] and [13, 13]
+             if len(span1) == 0 or len(span2) == 0:
+                 return False
+             if span1[0] == span1[1] and span2[0] == span2[1] and abs(span1[0] - span2[0]) == 1:
+                 return False
+             if abs(span1[0] - span2[0]) <= 1 and abs(span1[1] - span2[1]) <= 1:  # start and end each differ by at most 1
+                 return True
+             return False
+
+         all_spans, span_types = list(), list()  # [n, m]
+         for labeled_span, labeled_type, predict_span in zip(labeled_spans, labeled_types, predict_spans):
+             # loop over sentences and merge their spans
+             unlabeled_num = 0
+             all_span, span_type = labeled_span, labeled_type  # start from the labeled spans
+             if len(all_span) != len(span_type):
+                 length = min(len(all_span), len(span_type))
+                 all_span, span_type = all_span[: length], span_type[: length]
+             for span in predict_span:  # loop over predicted spans
+                 if span not in all_span:  # the span is not labeled, so it is unlabeled
+                     # the boundaries predicted by the global pointer can be fuzzy, so
+                     # discard predictions whose boundaries nearly coincide with a merged span
+                     is_remove = False
+                     for span_x in all_span:  # loop over the spans already merged
+                         is_remove = check_similar_span(span_x, span)  # drop the span if it is too close to an existing one
+                         if is_remove is True:
+                             break
+                     if is_remove is True:
+                         continue
+                     all_span.append(span)
+                     span_type.append(self.num_class)  # e.g. in a 5-way task the labels are 0..4, so the new "other" label is 5
+                     unlabeled_num += 1
+             if unlabeled_num == 0 and stage.startswith("train"):
+                 # the sentence has no unlabeled span, so negatively sample one to make
+                 # sure the unlabeled set is non-empty
+                 while True:
+                     random_span = np.random.randint(0, 32, 2).tolist()
+                     if abs(random_span[0] - random_span[1]) > 10:
+                         continue
+                     random_span = [random_span[1], random_span[0]] if random_span[0] > random_span[1] else random_span
+                     if random_span in all_span:
+                         continue
+                     all_span.append(random_span)
+                     span_type.append(self.num_class)
+                     break
+
+             all_spans.append(all_span)
+             span_types.append(span_type)
+
+         return all_spans, span_types
models/fewshot_learning/token_proto.py ADDED
@@ -0,0 +1,143 @@
+ # -*- coding: utf-8 -*-
+ # @Time : 2022/4/21 5:30 PM
+ # @Author : JianingWang
+ # @File : token_proto.py
+
+ """
+ This code is implemented for the paper "SpanProto: A Two-stage Span-based Prototypical Network for Few-shot Named Entity Recognition".
+ """
+
+ import os
+ from typing import Optional, Union
+
+ import torch
+ import torch.nn as nn
+ from dataclasses import dataclass
+ from transformers.file_utils import ModelOutput
+
+
+ @dataclass
+ class TokenProtoOutput(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     logits: Optional[torch.FloatTensor] = None
+
+
+ class TokenProto(nn.Module):
+     def __init__(self, config):
+         """
+         word_encoder: Sentence encoder
+
+         You need to set self.cost as your own loss function.
+         """
+         nn.Module.__init__(self)
+         self.config = config
+         self.output_dir = "./outputs"
+         self.drop = nn.Dropout()
+         self.projector = nn.Sequential(  # projector
+             nn.Linear(self.config.hidden_size, self.config.hidden_size),
+             nn.Sigmoid(),
+         )
+         self.tag_embeddings = nn.Embedding(2, self.config.hidden_size)  # tag for labeled / unlabeled span set
+         self.max_length = 64
+         self.margin_distance = 6.0
+         self.global_step = 0
+         # NOTE: assumed defaults, since the original file never sets these attributes.
+         # word_encoder must be assigned externally and return [num_sent, num_tokens, hidden];
+         # dot selects dot-product similarity instead of negative squared Euclidean distance.
+         self.word_encoder = None
+         self.dot = False
+
+     def predict_result_path(self, path=None):
+         if path is None:
+             predict_dir = os.path.join(
+                 self.output_dir, "{}-{}-{}".format(self.mode, self.num_class, self.num_example), "predict"
+             )
+         else:
+             predict_dir = os.path.join(path, "predict")
+         if not os.path.exists(predict_dir):  # create the directory if it does not exist
+             os.makedirs(predict_dir)
+         return predict_dir
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
+         config = kwargs.pop("config", None)
+         model = TokenProto(config=config)
+         return model
+
+     def __dist__(self, x, y, dim):
+         if self.dot:
+             return (x * y).sum(dim)
+         else:
+             return -(torch.pow(x - y, 2)).sum(dim)
+
+     def __batch_dist__(self, S, Q, q_mask):
+         # S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim]
+         assert Q.size()[:2] == q_mask.size()
+         Q = Q[q_mask == 1].view(-1, Q.size(-1))  # [num_of_all_text_tokens, embed_dim]
+         return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2)
+
+     def __get_proto__(self, embedding, tag, mask):
+         proto = []
+         embedding = embedding[mask == 1].view(-1, embedding.size(-1))
+         tag = torch.cat(tag, 0)
+         assert tag.size(0) == embedding.size(0)
+         for label in range(torch.max(tag) + 1):
+             proto.append(torch.mean(embedding[tag == label], 0))
+         proto = torch.stack(proto)
+         return proto, embedding
+
+     def forward(self, support, query):
+         """
+         support: Inputs of the support set.
+         query: Inputs of the query set.
+         N: Num of classes
+         K: Num of instances for each class in the support set
+         Q: Num of instances in the query set
+
+         support/query = {"index": [], "word": [], "mask": [], "label": [], "sentence_num": [], "text_mask": []}
+         """
+         # feed the support set and the query set through the encoder to embed each sample
+         support_emb = self.word_encoder(support["word"], support["mask"])  # [num_sent, number_of_tokens, 768]
+         query_emb = self.word_encoder(query["word"], query["mask"])  # [num_sent, number_of_tokens, 768]
+         support_emb = self.drop(support_emb)
+         query_emb = self.drop(query_emb)
+
+         # Prototypical Networks
+         logits = []
+         current_support_num = 0
+         current_query_num = 0
+         assert support_emb.size()[:2] == support["mask"].size()
+         assert query_emb.size()[:2] == query["mask"].size()
+
+         for i, sent_support_num in enumerate(support["sentence_num"]):  # loop over the sampled N-way K-shot episodes
+             sent_query_num = query["sentence_num"][i]
+             # calculate the prototype of each class; a batch holds several episodes, so
+             # current_support_num:current_support_num+sent_support_num selects the range
+             # of sentences that belongs to the current N-way K-shot episode
+             support_proto, embedding = self.__get_proto__(
+                 support_emb[current_support_num:current_support_num + sent_support_num],
+                 support["label"][current_support_num:current_support_num + sent_support_num],
+                 support["text_mask"][current_support_num: current_support_num + sent_support_num])
+             # calculate the distance to each prototype
+             logits.append(self.__batch_dist__(
+                 support_proto,
+                 query_emb[current_query_num:current_query_num + sent_query_num],
+                 query["text_mask"][current_query_num: current_query_num + sent_query_num]))  # [num_of_query_tokens, class_num]
+             current_query_num += sent_query_num
+             current_support_num += sent_support_num
+         logits = torch.cat(logits, 0)  # for each query token, the score of belonging to each support class
+         _, pred = torch.max(logits, 1)  # predict the class of the nearest prototype
+
+         return TokenProtoOutput(
+             logits=logits
+         )  # whatever the outer container (list or tuple), the returned logits must wrap a tensor at the innermost level, otherwise HuggingFace's nested_detach will fail
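+
+ # Contrast with SpanProto above: TokenProto is a token-level prototypical baseline
+ # (it resembles the ProtoBERT baseline from the Few-NERD toolkit, judging by the
+ # word/mask/text_mask/sentence_num input schema). Prototypes are means of token
+ # embeddings per tag, and every query token is classified independently by nearest
+ # prototype, whereas SpanProto first extracts whole spans with the global detector
+ # and classifies one vector per span.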