Spaces:
Sleeping
Sleeping
DeepLearning101
commited on
Commit
•
6c0ee22
1
Parent(s):
fdc4786
Upload 2 files
Browse files
models/fewshot_learning/span_proto.py
ADDED
@@ -0,0 +1,926 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2022/4/21 5:30 下午
|
3 |
+
# @Author : JianingWang
|
4 |
+
# @File : span_proto.py
|
5 |
+
|
6 |
+
"""
|
7 |
+
This code is implemented for the paper ""SpanProto: A Two-stage Span-based Prototypical Network for Few-shot Named Entity Recognition""
|
8 |
+
"""
|
9 |
+
|
10 |
+
import os
|
11 |
+
from typing import Optional
|
12 |
+
import torch
|
13 |
+
import numpy as np
|
14 |
+
import torch.nn as nn
|
15 |
+
from typing import Union
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from torch.nn import BCEWithLogitsLoss
|
18 |
+
from transformers import MegatronBertModel, MegatronBertPreTrainedModel
|
19 |
+
from transformers.file_utils import ModelOutput
|
20 |
+
from transformers.models.bert import BertPreTrainedModel, BertModel
|
21 |
+
|
22 |
+
# NOTE(review): removed leftover debug statements that ran at import time:
#     a = torch.nn.Embedding(10, 20)
#     a.parameters
# They allocated an unused 10x20 embedding table on module load, and the bare
# `a.parameters` attribute access was a no-op. Nothing else in this file
# references `a`.
|
24 |
+
|
25 |
+
class RawGlobalPointer(nn.Module):
    """Original GlobalPointer head.

    Scores every (start, end) token pair of a sentence for each of
    ``ent_type_size`` entity types on top of a provided encoder, optionally
    mixing RoPE rotary position encodings into the query/key projections so
    the pair score is position-aware.
    """

    def __init__(self, encoder, ent_type_size, inner_dim, RoPE=True):
        # encoder: e.g. RoBERTa-Large used as the sentence encoder
        # inner_dim: per-type query/key dimension (e.g. 64)
        # ent_type_size: number of entity classes
        super().__init__()
        self.encoder = encoder
        self.ent_type_size = ent_type_size
        self.inner_dim = inner_dim
        self.hidden_size = encoder.config.hidden_size
        # projects each token to ent_type_size * (query ++ key) vectors
        self.dense = nn.Linear(self.hidden_size, self.ent_type_size * self.inner_dim * 2)

        self.RoPE = RoPE

    def sinusoidal_position_embedding(self, batch_size, seq_len, output_dim):
        """Return [batch_size, seq_len, output_dim] sin/cos position encodings.

        Interleaved layout per position: (sin(p*w0), cos(p*w0), sin(p*w1), ...).
        """
        position_ids = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(-1)

        indices = torch.arange(0, output_dim // 2, dtype=torch.float)
        indices = torch.pow(10000, -2 * indices / output_dim)
        embeddings = position_ids * indices
        embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        embeddings = embeddings.repeat((batch_size, *([1] * len(embeddings.shape))))
        embeddings = torch.reshape(embeddings, (batch_size, seq_len, output_dim))
        # NOTE: relies on self.device having been set by forward() before this call
        embeddings = embeddings.to(self.device)
        return embeddings

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return span logits of shape (batch, ent_type_size, seq_len, seq_len)."""
        self.device = input_ids.device

        context_outputs = self.encoder(input_ids, attention_mask, token_type_ids)
        # last_hidden_state: (batch_size, seq_len, hidden_size)
        last_hidden_state = context_outputs[0]

        batch_size = last_hidden_state.size()[0]
        seq_len = last_hidden_state.size()[1]

        # split the projection into ent_type_size chunks of 2*inner_dim,
        # then separate each chunk into a query half and a key half
        outputs = self.dense(last_hidden_state)
        outputs = torch.split(outputs, self.inner_dim * 2, dim=-1)
        outputs = torch.stack(outputs, dim=-2)
        qw, kw = outputs[..., :self.inner_dim], outputs[..., self.inner_dim:]
        if self.RoPE:
            # pos_emb: (batch_size, seq_len, inner_dim)
            pos_emb = self.sinusoidal_position_embedding(batch_size, seq_len, self.inner_dim)
            # duplicate each sin/cos entry so it lines up with (even, odd) feature pairs,
            # e.g. [0.34, 0.90] -> [0.34, 0.34, 0.90, 0.90]
            cos_pos = pos_emb[..., None, 1::2].repeat_interleave(2, dim=-1)
            sin_pos = pos_emb[..., None, ::2].repeat_interleave(2, dim=-1)
            # standard RoPE rotation: (q_even, q_odd) -> (q_even*cos - q_odd*sin, q_odd*cos + q_even*sin)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], -1)
            qw2 = qw2.reshape(qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], -1)
            kw2 = kw2.reshape(kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        # logits: (batch_size, ent_type_size, seq_len, seq_len)
        logits = torch.einsum("bmhd,bnhd->bhmn", qw, kw)

        # padding mask: push scores of padded positions to -inf (additively)
        pad_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, self.ent_type_size, seq_len, seq_len)
        logits = logits * pad_mask - (1 - pad_mask) * 1e12

        # exclude the lower triangle (spans whose end precedes their start)
        mask = torch.tril(torch.ones_like(logits), -1)
        logits = logits - mask * 1e12

        return logits / self.inner_dim ** 0.5
|
88 |
+
|
89 |
+
|
90 |
+
class SinusoidalPositionEmbedding(nn.Module):
    """Sin-Cos (Transformer-style) position embedding.

    Args:
        output_dim: embedding dimension (must be even; built from sin/cos pairs).
        merge_mode: how to combine with the input —
            "add": inputs + embeddings; "mul": inputs * (embeddings + 1);
            "zero": return the embeddings alone.
        custom_position_ids: if True, ``forward`` expects a
            ``(inputs, position_ids)`` pair instead of a bare tensor.
    """

    def __init__(
            self, output_dim, merge_mode="add", custom_position_ids=False):
        super(SinusoidalPositionEmbedding, self).__init__()
        self.output_dim = output_dim
        self.merge_mode = merge_mode
        self.custom_position_ids = custom_position_ids

    def forward(self, inputs):
        """Return position-encoded output (shape [batch_or_1, seq_len, output_dim])."""
        if self.custom_position_ids:
            # BUG FIX: unpack the (inputs, position_ids) pair *before* reading
            # inputs.shape — the original read .shape on the tuple first, which
            # raised AttributeError whenever custom_position_ids=True.
            inputs, position_ids = inputs
            seq_len = inputs.shape[1]
            position_ids = position_ids.type(torch.float)
        else:
            input_shape = inputs.shape
            batch_size, seq_len = input_shape[0], input_shape[1]
            position_ids = torch.arange(seq_len).type(torch.float)[None]
        indices = torch.arange(self.output_dim // 2).type(torch.float)
        indices = torch.pow(10000.0, -2 * indices / self.output_dim)
        embeddings = torch.einsum("bn,d->bnd", position_ids, indices)
        # interleave as (sin, cos) pairs: [..., sin(p*w_i), cos(p*w_i), ...]
        embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        embeddings = torch.reshape(embeddings, (-1, seq_len, self.output_dim))
        if self.merge_mode == "add":
            return inputs + embeddings.to(inputs.device)
        elif self.merge_mode == "mul":
            return inputs * (embeddings + 1.0).to(inputs.device)
        elif self.merge_mode == "zero":
            return embeddings.to(inputs.device)
|
121 |
+
|
122 |
+
|
123 |
+
def multilabel_categorical_crossentropy(y_pred, y_true):
    """Multi-label categorical cross-entropy (Su Jianlin's formulation).

    Positive classes are pushed above 0 and negative classes below 0 via two
    logsumexp terms; masking is done additively with a large constant (1e12).

    Args:
        y_pred: raw scores, shape [..., num_labels].
        y_true: binary targets (1 = positive class), same shape.

    Returns:
        Scalar tensor: mean of the per-row loss.
    """
    # Negate the scores of positive classes so both terms become "push below 0".
    scores = (1 - 2 * y_true) * y_pred
    neg_scores = scores - y_true * 1e12        # suppress positive classes
    pos_scores = scores - (1 - y_true) * 1e12  # suppress negative classes
    # Append a constant 0 logit acting as the decision threshold.
    pad = torch.zeros_like(scores[..., :1])
    loss_neg = torch.logsumexp(torch.cat([neg_scores, pad], dim=-1), dim=-1)
    loss_pos = torch.logsumexp(torch.cat([pos_scores, pad], dim=-1), dim=-1)
    return (loss_neg + loss_pos).mean()
|
134 |
+
|
135 |
+
|
136 |
+
def multilabel_categorical_crossentropy2(y_pred, y_true):
    """Variant of :func:`multilabel_categorical_crossentropy` that masks with
    exact ``-inf`` via in-place boolean indexing instead of subtracting 1e12.

    Args:
        y_pred: raw scores, shape [..., num_labels].
        y_true: binary targets (1 = positive class), same shape.

    Returns:
        Scalar tensor: mean of the per-row loss.
    """
    # Flip positive-class scores so both logsumexp terms share one direction.
    scores = (1 - 2 * y_true) * y_pred
    neg_scores, pos_scores = scores.clone(), scores.clone()
    neg_scores[y_true > 0] -= float("inf")  # drop positive classes from the neg term
    pos_scores[y_true < 1] -= float("inf")  # drop negative classes from the pos term
    # A constant 0 logit serves as the threshold class.
    pad = torch.zeros_like(scores[..., :1])
    loss_neg = torch.logsumexp(torch.cat([neg_scores, pad], dim=-1), dim=-1)
    loss_pos = torch.logsumexp(torch.cat([pos_scores, pad], dim=-1), dim=-1)
    return (loss_neg + loss_pos).mean()
|
151 |
+
|
152 |
+
@dataclass
class GlobalPointerOutput(ModelOutput):
    """Output container returned by ``SpanDetector.forward``."""
    # Span-detection loss (None when no labels were given).
    loss: Optional[torch.FloatTensor] = None
    # Top-k span probabilities, shape [batch, ent_type_size, k].
    topk_probs: torch.FloatTensor = None
    # Flattened (start*seq_len + end) indices of the top-k spans.
    topk_indices: torch.IntTensor = None
    # Encoder token representations, shape [batch, seq_len, hidden].
    last_hidden_state: torch.FloatTensor = None
|
158 |
+
|
159 |
+
|
160 |
+
@dataclass
class SpanProtoOutput(ModelOutput):
    """Output container returned by ``SpanProto.forward``."""
    # Combined detector + prototype loss (None at inference).
    loss: Optional[torch.FloatTensor] = None
    # Per-sentence predicted query spans.
    query_spans: list = None
    # Per-sentence span-to-prototype logits.
    proto_logits: list = None
    # Top-k span probabilities from the detector.
    topk_probs: torch.FloatTensor = None
    # Top-k flattened span indices from the detector.
    topk_indices: torch.IntTensor = None
|
167 |
+
|
168 |
+
|
169 |
+
class SpanDetector(BertPreTrainedModel):
    """Stage-1 class-agnostic global span detector (GlobalPointer-style).

    Scores every (start, end) token pair with a single "is-a-mention" head
    (``ent_type_size == 1``) on top of BERT; RoPE rotary position encoding is
    mixed into the query/key projections so the pair score is position-aware.
    """

    def __init__(self, config):
        # encoder: BERT backbone built from `config`
        # inner_dim: per-head query/key dimension (64)
        # ent_type_size: 1 -> class-agnostic mention detection
        super().__init__(config)
        self.bert = BertModel(config)
        # self.ent_type_size = config.ent_type_size
        self.ent_type_size = 1
        self.inner_dim = 64
        self.hidden_size = config.hidden_size
        self.RoPE = True

        self.dense_1 = nn.Linear(self.hidden_size, self.inner_dim * 2)
        # NOTE: the original GlobalPointer dense_2 was (inner_dim * 2, ent_type_size * 2)
        self.dense_2 = nn.Linear(self.hidden_size, self.ent_type_size * 2)

    def sequence_masking(self, x, mask, value="-inf", axis=None):
        """Additively mask `x` along `axis` with `mask`; masked slots get `value`."""
        if mask is None:
            return x
        else:
            if value == "-inf":
                value = -1e12
            elif value == "inf":
                value = 1e12
            assert axis > 0, "axis must be greater than 0"
            # align mask rank with x by unsqueezing before and after `axis`
            for _ in range(axis - 1):
                mask = torch.unsqueeze(mask, 1)
            for _ in range(x.ndim - mask.ndim):
                mask = torch.unsqueeze(mask, mask.ndim)
            return x * mask + value * (1 - mask)

    def add_mask_tril(self, logits, mask):
        """Apply the padding mask on both span axes and suppress the lower
        triangle (spans whose end precedes their start)."""
        if mask.dtype != logits.dtype:
            mask = mask.type(logits.dtype)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 2)
        logits = self.sequence_masking(logits, mask, "-inf", logits.ndim - 1)
        # exclude the lower triangle
        mask = torch.tril(torch.ones_like(logits), diagonal=-1)
        logits = logits - mask * 1e12
        return logits

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None, short_labels=None):
        """Score all spans; return GlobalPointerOutput with loss/top-k spans.

        labels: binary span targets used by the multilabel CE loss (optional).
        short_labels: accepted but unused here — NOTE(review): kept only for a
            uniform call signature; confirm against callers.
        """
        # with torch.no_grad():
        context_outputs = self.bert(input_ids, attention_mask, token_type_ids)
        last_hidden_state = context_outputs.last_hidden_state  # [bz, seq_len, hidden_dim]
        del context_outputs
        outputs = self.dense_1(last_hidden_state)  # [bz, seq_len, 2*inner_dim]
        # even feature positions -> query, odd feature positions -> key (last dim, stride 2)
        qw, kw = outputs[..., ::2], outputs[..., 1::2]
        batch_size = input_ids.shape[0]
        if self.RoPE:  # whether to apply RoPE rotary position encoding
            pos = SinusoidalPositionEmbedding(self.inner_dim, "zero")(outputs)
            # duplicate sin/cos entries to line up with (even, odd) feature pairs,
            # e.g. [0.34, 0.90] -> [0.34, 0.34, 0.90, 0.90]
            cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1)
            sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)
            qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], 3)
            qw2 = torch.reshape(qw2, qw.shape)
            qw = qw * cos_pos + qw2 * sin_pos
            kw2 = torch.stack([-kw[..., 1::2], kw[..., ::2]], 3)
            kw2 = torch.reshape(kw2, kw.shape)
            kw = kw * cos_pos + kw2 * sin_pos
        logits = torch.einsum("bmd,bnd->bmn", qw, kw) / self.inner_dim ** 0.5
        bias = torch.einsum("bnh->bhn", self.dense_2(last_hidden_state)) / 2
        # logits[:, None] inserts the entity-type axis; the two bias slices add
        # per-start and per-end offsets
        logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None]
        # logit_mask = self.add_mask_tril(logits, mask=attention_mask)
        loss = None

        # upper-triangular mask of valid (non-padding, end >= start) spans
        mask = torch.triu(attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1))
        # mask = torch.where(mask > 0, 0.0, 1)
        if labels is not None:
            # --- earlier experiments kept for reference ---
            # y_pred = torch.zeros(input_ids.shape[0], self.ent_type_size, input_ids.shape[1], input_ids.shape[1], device=input_ids.device)
            # for i in range(input_ids.shape[0]):
            #     for j in range(self.ent_type_size):
            #         y_pred[i, j, labels[i, j, 0], labels[i, j, 1]] = 1
            # y_true = labels.reshape(input_ids.shape[0] * self.ent_type_size, -1)
            # y_pred = logit_mask.reshape(input_ids.shape[0] * self.ent_type_size, -1)
            # loss = multilabel_categorical_crossentropy(y_pred, y_true)
            #
            # weight = ((labels == 0).sum() / labels.sum())/5
            # loss_fct = nn.BCEWithLogitsLoss(weight=weight)
            # loss_fct = nn.BCEWithLogitsLoss(reduction="none")
            # unmask_labels = labels.view(-1)[mask.view(-1) > 0]
            # loss = loss_fct(logits.view(-1)[mask.view(-1) > 0], unmask_labels.float())
            # if unmask_labels.sum() > 0:
            #     loss = (loss[unmask_labels > 0].mean()+loss[unmask_labels < 1].mean())/2
            # else:
            #     loss = loss[unmask_labels < 1].mean()
            # y_pred = logits.view(-1)[mask.view(-1) > 0]
            # y_true = labels.view(-1)[mask.view(-1) > 0]
            # loss = multilabel_categorical_crossentropy2(y_pred, y_true)
            # y_pred = logits - torch.where(mask > 0, 0.0, float("inf")).unsqueeze(1)
            # --- active loss: multilabel CE over masked, flattened span scores ---
            y_pred = logits - (1-mask.unsqueeze(1))*1e12
            y_true = labels.view(input_ids.shape[0] * self.ent_type_size, -1)
            y_pred = y_pred.view(input_ids.shape[0] * self.ent_type_size, -1)
            loss = multilabel_categorical_crossentropy(y_pred, y_true)

        with torch.no_grad():
            # span probability, zeroed outside the valid-span mask; keep top 50
            prob = torch.sigmoid(logits) * mask.unsqueeze(1)
            topk = torch.topk(prob.view(batch_size, self.ent_type_size, -1), 50, dim=-1)

        return GlobalPointerOutput(
            loss=loss,
            topk_probs=topk.values,
            topk_indices=topk.indices,
            last_hidden_state=last_hidden_state
        )
|
276 |
+
|
277 |
+
|
278 |
+
class SpanProto(nn.Module):
|
279 |
+
    def __init__(self, config):
        """Two-stage few-shot NER model (SpanProto).

        Stage 1: a class-agnostic SpanDetector proposes mention spans.
        Stage 2: spans are classified by distance to class prototypes.

        config: transformers config used to build the detector (provides
            hidden_size). NOTE(review): the original docstring mentioned a
            separate word_encoder / self.cost hook that is not present here.
        """
        nn.Module.__init__(self)
        self.config = config
        self.output_dir = "./outputs"
        # self.predict_dir = self.predict_result_path(self.output_dir)
        self.drop = nn.Dropout()
        self.global_span_detector = SpanDetector(config=self.config)  # global span detector
        self.projector = nn.Sequential(  # projector
            nn.Linear(self.config.hidden_size, self.config.hidden_size),
            nn.Sigmoid(),
            # nn.LayerNorm(2)
        )
        self.tag_embeddings = nn.Embedding(2, self.config.hidden_size)  # tag for labeled / unlabeled span set
        # self.tag_mlp = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.max_length = 64       # overwritten per batch in forward()
        self.margin_distance = 6.0 # unlabeled spans are pushed beyond this distance from every prototype
        self.global_step = 0       # incremented once per training forward pass
|
301 |
+
|
302 |
+
def predict_result_path(self, path=None):
|
303 |
+
if path is None:
|
304 |
+
predict_dir = os.path.join(
|
305 |
+
self.output_dir, "{}-{}-{}".format(self.mode, self.num_class, self.num_example), "predict"
|
306 |
+
)
|
307 |
+
else:
|
308 |
+
predict_dir = os.path.join(
|
309 |
+
path, "predict"
|
310 |
+
)
|
311 |
+
# if os.path.exists(predict_dir):
|
312 |
+
# os.rmdir(predict_dir) # 删除历史记录
|
313 |
+
if not os.path.exists(predict_dir): # 重新创建一个新的目录
|
314 |
+
os.makedirs(predict_dir)
|
315 |
+
return predict_dir
|
316 |
+
|
317 |
+
|
318 |
+
@classmethod
|
319 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
|
320 |
+
config = kwargs.pop("config", None)
|
321 |
+
model = SpanProto(config=config)
|
322 |
+
# 将bert部分参数加载进去
|
323 |
+
model.global_span_detector = SpanDetector.from_pretrained(
|
324 |
+
pretrained_model_name_or_path,
|
325 |
+
*model_args,
|
326 |
+
**kwargs
|
327 |
+
)
|
328 |
+
# 将剩余的参数加载进来
|
329 |
+
return model
|
330 |
+
|
331 |
+
# @classmethod
|
332 |
+
# def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
|
333 |
+
# self.global_span_detector.resize_token_embeddings(new_num_tokens)
|
334 |
+
|
335 |
+
def __dist__(self, x, y, dim, use_dot=False):
|
336 |
+
# x: [1, class_num, hidden_dim], y: [span_num, 1, hidden_dim]
|
337 |
+
# x - y: [span_num, class_num, hidden_dim]
|
338 |
+
# (x - y)^2.sum(2): [span_num, class_num]
|
339 |
+
if use_dot:
|
340 |
+
return (x * y).sum(dim)
|
341 |
+
else:
|
342 |
+
return -(torch.pow(x - y, 2)).sum(dim)
|
343 |
+
|
344 |
+
    def __get_proto__(self, support_emb: torch, support_span: list, support_span_type: list, use_tag=False):
        """Build one prototype vector per entity class from the support set.

        support_emb: [n', seq_len, dim] token embeddings of support sentences
        support_span: [n', m, 2] e.g. [[[3, 6], [12, 13]], [[1, 3]], ...]
        support_span_type: [n', m] e.g. [[2, 1], [5], ...]
        use_tag: currently unused (the tag-embedding branch is commented out)

        Returns (prototypes [num_class, dim], span embeddings as nested lists
        for visualization, span tags). A span's embedding is emb[start]+emb[end].
        """
        prototype = list()       # one prototype per class
        all_span_embs = list()   # embedding of every support span (visualization)
        all_span_tags = list()
        # iterate over the classes
        for tag in range(self.num_class):
            # tag_id = torch.Tensor([1 if tag == self.num_class else 0]).long().cuda()
            # tag_embeddings = self.tag_embeddings(tag_id).view(-1)
            tag_prototype = list()  # [k, dim]
            # iterate over the sentences of the current episode
            for emb, span, type in zip(support_emb, support_span, support_span_type):
                # emb: [seq_len, dim], span: [m, 2], type: [m]
                span = torch.Tensor(span).long().cuda()  # e.g. [[3, 4], [9, 11]]
                type = torch.Tensor(type).long().cuda()  # e.g. [1, 4]
                # pick the spans of this sentence that belong to class `tag`
                try:
                    tag_span = span[type == tag]  # e.g. span==[[3, 4]], tag==1
                    # build the embedding of every retrieved span
                    for (s, e) in tag_span:
                        # tag_emb = torch.cat([emb[s], emb[e - 1]])  # [2*dim]
                        tag_emb = emb[s] + emb[e]  # [dim]
                        # if use_tag:  # mark whether a span is labeled (0) or unlabeled
                        #     tag_emb = tag_emb + tag_embeddings
                        tag_prototype.append(tag_emb)
                        all_span_embs.append(tag_emb)
                        all_span_tags.append(tag)
                # NOTE(review): bare except — presumably guards empty/ragged span
                # lists, but it also hides real errors; consider narrowing.
                except:
                    # this class has no span here: fall back to a random vector
                    tag_prototype.append(torch.randn(support_emb.shape[-1]).cuda())
                    # assert 1 > 2
            try:
                # prototype = mean of the class's span embeddings
                prototype.append(torch.mean(torch.stack(tag_prototype), dim=0))
            except:
                # print("the class {} has no span".format(tag))
                prototype.append(torch.randn(support_emb.shape[-1]).cuda())
                # assert 1 > 2
        all_span_embs = torch.stack(all_span_embs).detach().cpu().numpy().tolist()

        return torch.stack(prototype), all_span_embs, all_span_tags  # [num_class + 1, dim]
|
389 |
+
|
390 |
+
|
391 |
+
    def __batch_dist__(self, prototype: torch, query_emb: torch, query_spans: list, query_span_type: Union[list, None]):
        """Classify every query span by its distance to each class prototype.

        prototype: [num_class, dim]; query_emb: [n', seq_len, dim];
        query_spans: per-sentence [(start, end), ...]; query_span_type:
        per-sentence gold types, or None at inference (then loss is None).

        Returns (loss, all_logits, all_types, visual_all_types, visual_all_embs).
        Spans farther than self.margin_distance from every prototype get the
        extra class id `self.num_class` ("unlabeled").
        """
        # first compute the representation of every span of every sentence
        all_logits = list()  # per-sentence span logits of this episode
        all_types = list()
        visual_all_types, visual_all_embs = list(), list()  # for visualization
        # num = 0
        for emb, span in zip(query_emb, query_spans):  # iterate over sentences
            # assert len(span) == len(query_span_type[num]), "span={}\ntype{}".format(span, query_span_type[num])
            span_emb = list()  # embeddings of all spans of this sentence [m', dim]
            try:
                for (s, e) in span:  # iterate over spans
                    tag_emb = emb[s] + emb[e]  # [dim]
                    span_emb.append(tag_emb)
            # NOTE(review): bare except — presumably guards malformed span lists
            except:
                span_emb = []
            if len(span_emb) != 0:
                span_emb = torch.stack(span_emb)  # [span_num, dim]
                # distance of every span to every prototype (negative squared euclid)
                logits = self.__dist__(prototype.unsqueeze(0), span_emb.unsqueeze(1), 2)  # [span_num, num_class]
                # pred_types = torch.argmax(logits, -1).detach().cpu().numpy().tolist()
                with torch.no_grad():
                    # nearest prototype per span; recover the actual distance
                    pred_dist, pred_types = torch.max(logits, -1)
                    pred_dist = torch.pow(-1 * pred_dist, 0.5)
                    # spans beyond the margin from every prototype become the
                    # special "unlabeled" class self.num_class
                    pred_types[pred_dist > self.margin_distance] = self.num_class
                    pred_types = pred_types.detach().cpu().numpy().tolist()
                # # softmax-probability variant kept for reference:
                # with torch.no_grad():
                #     prob = torch.softmax(logits, -1)
                #     pred_proba, pred_types = torch.max(logits, -1)
                #     pred_types[pred_proba <= 0.6] = self.num_class
                #     pred_types = pred_types.detach().cpu().numpy().tolist()

                all_logits.append(logits)
                all_types.append(pred_types)
                visual_all_types.extend(pred_types)
                visual_all_embs.extend(span_emb.detach().cpu().numpy().tolist())
            else:
                all_logits.append([])
                all_types.append([])
            # num += 1

        if query_span_type is not None:
            # query_span_type: [n', m]
            try:
                # fast path: flatten all gold types and all logits at once
                all_type = torch.Tensor([type for types in query_span_type for type in types]).long().cuda()  # [span_num]
                loss = nn.CrossEntropyLoss()(torch.cat(all_logits, 0), all_type)
            # fall back to a sentence-by-sentence alignment when lengths mismatch
            except:
                all_logit, all_type = list(), list()
                for logits, types in zip(all_logits, query_span_type):
                    if len(logits) != 0 and len(types) != 0 and len(logits) == len(types):
                        all_logit.append(logits)
                        all_type.extend(types)
                if len(all_logit) != 0:
                    all_logit = torch.cat(all_logit, 0)
                    all_type = torch.Tensor(all_type).long().cuda()
                    loss = nn.CrossEntropyLoss()(all_logit, all_type)
                else:
                    # nothing alignable: contribute zero loss
                    loss = 0.

        else:
            loss = None
        all_logits = [i.detach().cpu().numpy().tolist() for i in all_logits if len(i) != 0]
        return loss, all_logits, all_types, visual_all_types, visual_all_embs
|
475 |
+
|
476 |
+
|
477 |
+
    def __batch_margin__(self, prototype: torch, query_emb: torch, query_unlabeled_spans: list,
                         query_labeled_spans: list, query_span_type: list):
        """Margin loss that pushes unlabeled (false-positive) spans at least
        ``self.margin_distance`` away from every class prototype.

        prototype: [num_class, dim]; query_emb: [n', seq_len, dim];
        query_unlabeled_spans: per-sentence spans detected but not labeled.
        query_labeled_spans / query_span_type: accepted but unused — the
        labeled-span attraction term is commented out below.
        Returns a scalar loss (0. when there are no unlabeled spans).
        """
        # prototype: [num_class, dim], negative: [span_num, dim]
        # squared distance of every unlabeled span to every prototype; each one
        # should exceed the margin threshold
        def distance(input1, input2, p=2, eps=1e-6):
            # Compute the distance (p-norm)
            norm = torch.pow(torch.abs((input1 - input2 + eps)), p)
            pnorm = torch.pow(torch.sum(norm, -1), 1.0 / p)
            return pnorm

        unlabeled_span_emb, labeled_span_emb, labeled_span_type = list(), list(), list()
        for emb, span in zip(query_emb, query_unlabeled_spans):  # iterate over sentences
            # embedding of every unlabeled span of this sentence [m', dim]
            for (s, e) in span:  # iterate over spans
                tag_emb = emb[s] + emb[e]  # [dim]
                unlabeled_span_emb.append(tag_emb)

        # labeled-span attraction term, kept from the original but disabled:
        # for emb, span, type in zip(query_emb, query_labeled_spans, query_span_type):
        #     for (s, e) in span:
        #         tag_emb = emb[s] + emb[e]  # [dim]
        #         labeled_span_emb.append(tag_emb)
        #         labeled_span_type.extend(type)

        try:
            unlabeled_span_emb = torch.stack(unlabeled_span_emb)  # [span_num, dim]
            # labeled_span_emb = torch.stack(labeled_span_emb)  # [span_num, dim]
            # labeled_span_type = torch.stack(labeled_span_type)  # [span_num]
        # NOTE(review): bare except — stack fails on an empty list; no spans
        # means no margin loss
        except:
            return 0.

        unlabeled_dist = distance(prototype.unsqueeze(0), unlabeled_span_emb.unsqueeze(1))  # [span_num, num_class]
        # labeled_dist = distance(prototype.unsqueeze(0), labeled_span_emb.unsqueeze(1))  # [span_num, num_class]
        # distance of each labeled span to its ground-truth prototype:
        # labeled_type_dist = torch.gather(labeled_dist, -1, labeled_span_type.unsqueeze(1))  # [span_num, 1]
        # hinge: penalize only distances still inside the margin
        unlabeled_output = torch.maximum(torch.zeros_like(unlabeled_dist), self.margin_distance - unlabeled_dist)
        # labeled_output = torch.maximum(torch.zeros_like(labeled_type_dist), labeled_type_dist)
        # return torch.mean(unlabeled_output) + torch.mean(labeled_output)
        return torch.mean(unlabeled_output)
|
521 |
+
|
522 |
+
|
523 |
+
def forward(
    self,
    episode_ids,
    support, query,
    num_class,
    num_example,
    mode=None,
    short_labels=None,
    stage:str ="train",
    path: str=None
):
    """
    Two-stage forward pass: (1) the global span detector scores candidate
    mention spans; (2) a prototypical network assigns types to spans.

    episode_ids: Input of the idx of each episode data. (only list)
    support: Inputs of the support set.
    query: Inputs of the query set.
    num_class: Num of classes (the N in N-way K-shot)
    num_example: Num of instances for each class in the support set (the K)
    mode: FewNERD setting (inter / intra)
    short_labels: auxiliary labels forwarded to the span detector
    stage: "train*" computes losses; other values dump query predictions
    path: optional output directory for dumped evaluation results
    return: SpanProtoOutput
    """
    if stage.startswith("train"):
        self.global_step += 1
    self.num_class = num_class  # the N in N-way K-shot
    self.num_example = num_example  # the K in N-way K-shot
    self.mode = mode  # FewNERD mode=inter/intra
    self.max_length = support["input_ids"].shape[1]
    support_inputs, support_attention_masks, support_type_ids = \
        support["input_ids"], support["attention_mask"], support["token_type_ids"]  # torch, [n, seq_len]
    query_inputs, query_attention_masks, query_type_ids = \
        query["input_ids"], query["attention_mask"], query["token_type_ids"]  # torch, [n, seq_len]
    support_labels = support["labels"]  # torch,
    query_labels = query["labels"]  # torch,
    # global span detector: obtain all mention spans and the detection loss
    support_detector_outputs = self.global_span_detector(
        support_inputs, support_attention_masks, support_type_ids, support_labels, short_labels=short_labels
    )
    query_detector_outputs = self.global_span_detector(
        query_inputs, query_attention_masks, query_type_ids, query_labels, short_labels=short_labels
    )
    device_id = support_inputs.device.index

    # Warm-up: for the first 500 training steps only the span detector is trained.
    if self.global_step <= 500 and stage == "train":
        # only train span detector
        return SpanProtoOutput(
            loss=support_detector_outputs.loss,
            topk_probs=query_detector_outputs.topk_probs,
            topk_indices=query_detector_outputs.topk_indices,
        )
    # obtain labeled spans from the support set
    support_labeled_spans = support["labeled_spans"]  # all labeled spans, list, [n, m, 2]: n sentences, m spans, 2 = (start, end)
    support_labeled_types = support["labeled_types"]  # all labeled entity type ids, list, [n, m]
    query_labeled_spans = query["labeled_spans"]  # all labeled spans, list, [n, m, 2]
    query_labeled_types = query["labeled_types"]  # all labeled entity type ids, list, [n, m]

    # Spans predicted by the detector on the query set. These include spans
    # that are not labeled in the current n-way k-shot episode.
    query_predict_spans = self.get_topk_spans(
        query_detector_outputs.topk_probs,
        query_detector_outputs.topk_indices,
        query["input_ids"],
        threshold=0.9 if stage.startswith("train") else 0.95,  # stricter recall at inference
        is_query=True
    )  # [n, m, 2]

    if stage.startswith("train"):
        # Training: isolate the detector's unlabeled spans (consumed by the
        # margin loss below); typing itself is trained on gold spans only.
        query_unlabeled_spans = self.split_span(
            labeled_spans=query_labeled_spans,
            labeled_types=query_labeled_types,
            predict_spans=query_predict_spans,
            stage=stage
        )  # [n, m, 2] per-sentence span lists
        query_all_spans = query_labeled_spans
        query_span_types = query_labeled_types

    else:
        # Inference: merge gold and predicted spans; the types of predicted
        # spans are unknown at this point.
        query_unlabeled_spans = None
        query_all_spans, _ = self.merge_span(
            labeled_spans=query_labeled_spans,
            labeled_types=query_labeled_types,
            predict_spans=query_predict_spans,
            stage=stage
        )  # [n, m, 2] per-sentence span lists
        query_span_types = None

    # token representations from the detector's encoder, then projected
    support_emb, query_emb = support_detector_outputs.last_hidden_state, \
        query_detector_outputs.last_hidden_state  # [n, seq_len, dim]
    support_emb, query_emb = self.projector(support_emb), self.projector(query_emb)  # [n, seq_len, dim]

    batch_result = dict()  # episode id -> {"spans", "types", "visualization"}
    proto_losses = list()  # per-episode typing loss
    current_support_num = 0  # running sentence offset into the support batch
    current_query_num = 0  # running sentence offset into the query batch
    typing_loss = None
    # iterate over the episodes packed into this batch
    for i, sent_support_num in enumerate(support["sentence_num"]):
        sent_query_num = query["sentence_num"][i]
        id_ = episode_ids[i]  # index of the current episode

        # Prototypes are built from labeled support spans only.
        # support_proto: [num_class + 1, dim] (extra row for the "other" class)
        support_proto, all_span_embs, all_span_tags = self.__get_proto__(
            support_emb[current_support_num: current_support_num + sent_support_num],  # [n', seq_len, dim]
            support_labeled_spans[current_support_num: current_support_num + sent_support_num],  # [n', m]
            support_labeled_types[current_support_num: current_support_num + sent_support_num],  # [n', m]
        )

        # Standard prototypical typing: distance from each query span to
        # every prototype (gold types supplied only during training).
        proto_loss, proto_logits, all_types, visual_all_types, visual_all_embs = self.__batch_dist__(
            support_proto,
            query_emb[current_query_num: current_query_num + sent_query_num],  # [n', seq_len, dim]
            query_all_spans[current_query_num: current_query_num + sent_query_num],  # [n', m]
            query_span_types[current_query_num: current_query_num + sent_query_num] if query_span_types else None,  # [n', m]
        )

        # span embeddings + tags collected for offline visualization
        visual_data = {
            "data": all_span_embs + visual_all_embs,
            "target": all_span_tags + visual_all_types,
        }

        # Margin loss: push every unlabeled query span away from all prototypes.
        if stage.startswith("train"):

            margin_loss = self.__batch_margin__(
                support_proto,
                query_emb[current_query_num: current_query_num + sent_query_num],  # [n', seq_len, dim]
                query_unlabeled_spans[current_query_num: current_query_num + sent_query_num],  # [n', span_num]
                query_all_spans[current_query_num: current_query_num + sent_query_num],
                query_span_types[current_query_num: current_query_num + sent_query_num],
            )

            proto_losses.append(proto_loss + margin_loss)

        batch_result[id_] = {
            "spans": query_all_spans[current_query_num: current_query_num + sent_query_num],
            "types": all_types,
            "visualization": visual_data
        }

        current_query_num += sent_query_num
        current_support_num += sent_support_num
    if stage.startswith("train"):
        typing_loss = torch.mean(torch.stack(proto_losses), dim=-1)

    if not stage.startswith("train"):
        # dev/test: persist this batch's query predictions to disk
        self.__save_evaluate_predicted_result__(batch_result, device_id=device_id, stage=stage, path=path)

    # NOTE: any logits returned here must bottom out in tensors (whether the
    # outer container is a list or tuple), otherwise HuggingFace's
    # nested_detach will fail.
    return SpanProtoOutput(
        loss=(support_detector_outputs.loss + typing_loss)
        if stage.startswith("train") else query_detector_outputs.loss,
    )
|
725 |
+
|
726 |
+
def __save_evaluate_predicted_result__(self, new_result: dict, device_id: int = 0, stage="dev", path=None):
|
727 |
+
"""
|
728 |
+
本函数用于在forward时保存每一个batch内的预测span以及span type
|
729 |
+
new_result / result: {
|
730 |
+
"(id)": { # id-th episode query
|
731 |
+
"spans": [[[1, 4], [6, 7], xxx], ... ] # [sent_num, span_num, 2]
|
732 |
+
"types": [[2, 0, xxx], ...] # [sent_num, span_num]
|
733 |
+
},
|
734 |
+
xxx
|
735 |
+
}
|
736 |
+
"""
|
737 |
+
# 拉取当前任务中已经预测的结果
|
738 |
+
self.predict_dir = self.predict_result_path(path)
|
739 |
+
npy_file_name = os.path.join(self.predict_dir, "{}_predictions_{}.npy".format(stage, device_id))
|
740 |
+
result = dict()
|
741 |
+
if os.path.exists(npy_file_name):
|
742 |
+
result = np.load(npy_file_name, allow_pickle=True)[()]
|
743 |
+
# 合并
|
744 |
+
for episode_id, query_res in new_result.items():
|
745 |
+
result[episode_id] = query_res
|
746 |
+
# 保存
|
747 |
+
np.save(npy_file_name, result, allow_pickle=True)
|
748 |
+
|
749 |
+
|
750 |
+
def get_topk_spans(self, probs, indices, input_ids, threshold=0.60, low_threshold=0.1, is_query=False):
    """Decode the detector's top-k flattened span scores into per-sentence spans.

    probs:   [n, 1, m] top-k probabilities, sorted in descending order.
    indices: [n, 1, m] matching flattened span indices (start * max_length + end).
    input_ids: [n, seq_len]; retained for interface compatibility — the
        original sliced candidate token ids from it but never used them
        (dead code, removed here).
    threshold / low_threshold: initial and minimal acceptance probabilities.
    is_query: if True, every sentence must recall at least one span, so the
        lower bound is relaxed to 0.
    Returns: list [n, m', 2] of [start, end] spans per sentence; a sentence
        with no recalled span falls back to [[0, 0]] (the [CLS] span,
        analogous to "unanswerable" in MRC).
    """
    probs = probs.squeeze(1).detach().cpu()      # [n, m], descending by probability
    indices = indices.squeeze(1).detach().cpu()  # [n, m]
    if is_query:
        low_threshold = 0.0
    predict_span = list()
    for prob, index in zip(probs, indices):  # per sentence
        # pre-filter candidates below the hard lower bound
        keep = prob >= low_threshold
        entity_index = index[keep]
        index_ids = torch.arange(len(index)).long()[keep]
        span = set()
        threshold_ = threshold
        # Lower the threshold in 0.05 steps until more than 3 spans are
        # recalled (or the lower bound is reached), so that recall counts
        # are roughly uniform across sentences instead of one fixed cutoff.
        while threshold_ >= low_threshold:
            for ei, entity in enumerate(entity_index):
                if prob[index_ids[ei]] < threshold_:
                    # candidates are sorted descending, so all later ones
                    # are below the threshold too
                    break
                # map the flat 1-D index back to a (start, end) pair
                start_end = np.unravel_index(entity, (self.max_length, self.max_length))
                span.add((start_end[0], start_end[1]))
            if len(span) > 3:
                break
            threshold_ -= 0.05
        if len(span) == 0:
            # nothing recalled: fall back to the [CLS] span
            span = [[0, 0]]
        predict_span.append([list(pair) for pair in span])
    return predict_span
|
799 |
+
|
800 |
+
|
801 |
+
def split_span(self, labeled_spans: list, labeled_types: list, predict_spans: list, stage: str = "train"):
    """For each sentence, return the detector-predicted spans that are NOT labeled.

    labeled_spans:  [n, m, 2] gold spans per sentence.
    labeled_types:  [n, m] gold types; unused by the logic but kept in the
        zip so length-mismatch truncation behaves as before.
    predict_spans:  [n, m', 2] spans proposed by the span detector.
    stage: kept for interface compatibility.
    Returns: [n, k, 2] unlabeled spans per sentence (input order preserved).
    """
    def check_similar_span(span1, span2):
        """True if the two spans have near-identical boundaries (each end off
        by at most 1). Global-pointer boundaries are fuzzy, so such near
        duplicates of a gold span are discarded rather than treated as
        unlabeled. Special case: two adjacent single-token spans (e.g.
        [12, 12] vs [13, 13]) are genuinely different entities."""
        if len(span1) == 0 or len(span2) == 0:
            return False
        if span1[0] == span1[1] and span2[0] == span2[1] and abs(span1[0] - span2[0]) == 1:
            return False
        return abs(span1[0] - span2[0]) <= 1 and abs(span1[1] - span2[1]) <= 1

    unlabeled_spans = list()
    for labeled_span, _labeled_type, predict_span in zip(labeled_spans, labeled_types, predict_spans):
        # keep a predicted span only if it is neither labeled nor a near
        # duplicate of a labeled span
        unlabeled_span = [
            span for span in predict_span
            if span not in labeled_span
            and not any(check_similar_span(gold, span) for gold in labeled_span)
        ]
        unlabeled_spans.append(unlabeled_span)
    return unlabeled_spans
|
861 |
+
|
862 |
+
|
863 |
+
def merge_span(self, labeled_spans: list, labeled_types: list, predict_spans: list, stage: str = "train"):
    """Merge gold spans with detector-predicted spans, per sentence.

    Predicted spans absent from the gold set (and not boundary-adjacent to
    a gold span) are appended with the pseudo-class ``self.num_class``
    (e.g. in a 5-way episode gold labels are 0..4, so "other" is 5).
    During training, a sentence with no unlabeled span gets one random
    negative span so the "other" class is never empty.

    Bug fix vs. the original: the per-sentence lists are copied before
    being extended, so the caller's labeled_spans / labeled_types are no
    longer mutated in place.

    Returns: (all_spans, span_types), both [n, m]-shaped nested lists.
    """
    def check_similar_span(span1, span2):
        """True if the two spans have near-identical boundaries (each end off
        by at most 1) — global-pointer boundaries are fuzzy. Special case:
        two adjacent single-token spans (e.g. [12, 12] vs [13, 13]) count
        as different."""
        if len(span1) == 0 or len(span2) == 0:
            return False
        if span1[0] == span1[1] and span2[0] == span2[1] and abs(span1[0] - span2[0]) == 1:
            return False
        return abs(span1[0] - span2[0]) <= 1 and abs(span1[1] - span2[1]) <= 1

    all_spans, span_types = list(), list()  # [n, m]
    for labeled_span, labeled_type, predict_span in zip(labeled_spans, labeled_types, predict_spans):
        unlabeled_num = 0
        # copy so appending below does not mutate the caller's lists
        all_span, span_type = list(labeled_span), list(labeled_type)
        if len(all_span) != len(span_type):
            # defensively truncate mismatched span/type lists to equal length
            length = min(len(all_span), len(span_type))
            all_span, span_type = all_span[: length], span_type[: length]
        for span in predict_span:
            if span in all_span:
                continue  # already labeled
            # drop predicted spans whose boundaries nearly coincide with a
            # span already merged
            if any(check_similar_span(span_x, span) for span_x in all_span):
                continue
            all_span.append(span)
            span_type.append(self.num_class)  # the extra "other" class
            unlabeled_num += 1
        if unlabeled_num == 0 and stage.startswith("train"):
            # Negative sampling: make sure at least one unlabeled span exists.
            while True:
                random_span = np.random.randint(0, 32, 2).tolist()
                if abs(random_span[0] - random_span[1]) > 10:
                    continue  # avoid implausibly long spans
                random_span = [random_span[1], random_span[0]] if random_span[0] > random_span[1] else random_span
                if random_span in all_span:
                    continue
                all_span.append(random_span)
                span_type.append(self.num_class)
                break

        all_spans.append(all_span)
        span_types.append(span_type)

    return all_spans, span_types
|
models/fewshot_learning/token_proto.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2022/4/21 5:30 下午
|
3 |
+
# @Author : JianingWang
|
4 |
+
# @File : token_proto.py
|
5 |
+
|
6 |
+
"""
|
7 |
+
This code is implemented for the paper ""SpanProto: A Two-stage Span-based Prototypical Network for Few-shot Named Entity Recognition""
|
8 |
+
"""
|
9 |
+
|
10 |
+
import os
|
11 |
+
from typing import Optional
|
12 |
+
import torch
|
13 |
+
import numpy as np
|
14 |
+
import torch.nn as nn
|
15 |
+
from typing import Union
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from torch.nn import BCEWithLogitsLoss
|
18 |
+
from transformers import MegatronBertModel, MegatronBertPreTrainedModel
|
19 |
+
from transformers.file_utils import ModelOutput
|
20 |
+
from transformers.models.bert import BertPreTrainedModel, BertModel
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass
class TokenProtoOutput(ModelOutput):
    """Output container for TokenProto.forward."""
    # optional training loss (not populated by the current forward pass)
    loss: Optional[torch.FloatTensor] = None
    # per-query-token scores against each prototype class; forward builds
    # this as [num_query_tokens, num_class]
    logits: Optional[torch.FloatTensor] = None
|
27 |
+
|
28 |
+
|
29 |
+
class TokenProto(nn.Module):
    """Token-level prototypical network for few-shot NER.

    Prototypes are the mean embeddings of support tokens sharing a label;
    each query token is classified by its (negative squared-Euclidean or
    dot-product) distance to every prototype.

    NOTE(review): ``forward`` calls ``self.word_encoder`` and ``__dist__``
    reads ``self.dot``, but neither attribute is assigned anywhere in this
    class — they must be set externally before use, otherwise forward
    raises AttributeError. Confirm against the caller.
    """
    def __init__(self, config):
        """
        config: model config providing ``hidden_size``.

        word_encoder: Sentence encoder.
        You need to set self.cost as your own loss function.
        """
        nn.Module.__init__(self)
        self.config = config
        self.output_dir = "./outputs"
        # self.predict_dir = self.predict_result_path(self.output_dir)
        self.drop = nn.Dropout()
        self.projector = nn.Sequential(  # projector
            nn.Linear(self.config.hidden_size, self.config.hidden_size),
            nn.Sigmoid(),
            # nn.LayerNorm(2)
        )
        self.tag_embeddings = nn.Embedding(2, self.config.hidden_size)  # tag for labeled / unlabeled span set
        # self.tag_mlp = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.max_length = 64
        self.margin_distance = 6.0
        self.global_step = 0

    def predict_result_path(self, path=None):
        """Return (and create if missing) the directory for prediction dumps.

        With no explicit path, the directory is derived from the episode
        setting: <output_dir>/<mode>-<N>-<K>/predict.
        """
        if path is None:
            predict_dir = os.path.join(
                self.output_dir, "{}-{}-{}".format(self.mode, self.num_class, self.num_example), "predict"
            )
        else:
            predict_dir = os.path.join(
                path, "predict"
            )
        # if os.path.exists(predict_dir):
        #     os.rmdir(predict_dir)  # clear history
        if not os.path.exists(predict_dir):  # create a fresh directory
            os.makedirs(predict_dir)
        return predict_dir

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        """HF-style constructor; only the ``config`` kwarg is consumed —
        pretrained weights are NOT loaded here."""
        config = kwargs.pop("config", None)
        model = TokenProto(config=config)
        return model

    def __dist__(self, x, y, dim):
        """Similarity between x and y along ``dim``: dot product when
        self.dot is truthy, otherwise negative squared Euclidean distance."""
        if self.dot:
            return (x * y).sum(dim)
        else:
            return -(torch.pow(x - y, 2)).sum(dim)

    def __batch_dist__(self, S, Q, q_mask):
        # S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim]
        assert Q.size()[:2] == q_mask.size()
        # keep only real text tokens -> [num_of_all_text_tokens, embed_dim]
        Q = Q[q_mask==1].view(-1, Q.size(-1))
        # pairwise similarity of every query token to every prototype
        return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2)

    def __get_proto__(self, embedding, tag, mask):
        """Mean-pool masked token embeddings per label into prototypes.

        Returns (proto [num_labels, dim], flattened token embeddings).
        Assumes labels 0..max(tag) are all present; an absent label would
        yield a NaN prototype (mean of an empty slice).
        """
        proto = []
        embedding = embedding[mask==1].view(-1, embedding.size(-1))
        tag = torch.cat(tag, 0)
        assert tag.size(0) == embedding.size(0)
        for label in range(torch.max(tag)+1):
            proto.append(torch.mean(embedding[tag==label], 0))
        proto = torch.stack(proto)
        return proto, embedding

    def forward(self, support, query):
        """
        support: Inputs of the support set.
        query: Inputs of the query set.
        N: Num of classes
        K: Num of instances for each class in the support set
        Q: Num of instances in the query set

        support/query = {"index": [], "word": [], "mask": [], "label": [], "sentence_num": [], "text_mask": []}

        Returns a TokenProtoOutput whose logits hold, for every query text
        token, its similarity to each support prototype.
        """
        # encode both sets with the (externally supplied) word encoder
        support_emb = self.word_encoder(support["word"], support["mask"])  # [num_sent, number_of_tokens, 768]
        query_emb = self.word_encoder(query["word"], query["mask"])  # [num_sent, number_of_tokens, 768]
        support_emb = self.drop(support_emb)
        query_emb = self.drop(query_emb)

        # Prototypical Networks
        logits = []
        current_support_num = 0
        current_query_num = 0
        assert support_emb.size()[:2] == support["mask"].size()
        assert query_emb.size()[:2] == query["mask"].size()

        # one batch packs several sampled N-way K-shot episodes back to back;
        # the running offsets select the sentences of the current episode
        for i, sent_support_num in enumerate(support["sentence_num"]):
            sent_query_num = query["sentence_num"][i]
            # prototypes for each class from this episode's support sentences
            support_proto, embedding = self.__get_proto__(
                support_emb[current_support_num:current_support_num+sent_support_num],
                support["label"][current_support_num:current_support_num+sent_support_num],
                support["text_mask"][current_support_num: current_support_num+sent_support_num])
            # distance of every query token to each prototype
            logits.append(self.__batch_dist__(
                support_proto,
                query_emb[current_query_num:current_query_num+sent_query_num],
                query["text_mask"][current_query_num: current_query_num+sent_query_num]))  # [num_of_query_tokens, class_num]
            current_query_num += sent_query_num
            current_support_num += sent_support_num
        logits = torch.cat(logits, 0)  # per-token class scores across all episodes
        _, pred = torch.max(logits, 1)  # argmax prototype per token (currently unused)


        # return logits, pred, embedding

        # NOTE: any logits returned must bottom out in tensors (whether the
        # outer container is a list or tuple), otherwise HuggingFace's
        # nested_detach will fail.
        return TokenProtoOutput(
            logits=logits
        )
|