yangwang825 committed
Commit 8b2a0e9 · verified · 1 parent: a5babd8

Upload PureBertForSequenceClassification

Files changed (3)
  1. config.json +3 -2
  2. model.safetensors +3 -0
  3. modeling_pure_bert.py +859 -0
config.json CHANGED
@@ -2,11 +2,12 @@
   "_name_or_path": "avsolatorio/GIST-large-Embedding-v0",
   "alpha": 1,
   "architectures": [
-    "BertModel"
+    "PureBertForSequenceClassification"
   ],
   "attention_probs_dropout_prob": 0.1,
   "auto_map": {
-    "AutoConfig": "configuration_pure_bert.PureBertConfig"
+    "AutoConfig": "configuration_pure_bert.PureBertConfig",
+    "AutoModelForSequenceClassification": "modeling_pure_bert.PureBertForSequenceClassification"
   },
   "center": false,
   "classifier_dropout": null,
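With this change, the auto_map entry routes AutoModelForSequenceClassification to the custom class in modeling_pure_bert.py, so loading needs trust_remote_code=True. A minimal loading sketch; the repository id below is a placeholder for this repo, and the tokenizer is assumed to come from the base checkpoint named in _name_or_path:

import torch
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

repo_id = "yangwang825/pure-bert-base"  # placeholder: substitute the actual id of this repository

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("avsolatorio/GIST-large-Embedding-v0")  # assumed tokenizer source

inputs = tokenizer("A quick smoke test sentence.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.shape)  # [1, config.num_labels]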
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dc15a7070af74fa98568b09e9b0fd8d4c5254cf7d2344545d197e123635d0002
size 1336420068
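model.safetensors is stored through Git LFS, so the entry above is only a pointer recording the object's sha256 and size. A quick integrity check against that pointer, assuming the weights have already been downloaded to a local file named model.safetensors:

import hashlib

EXPECTED_OID = "dc15a7070af74fa98568b09e9b0fd8d4c5254cf7d2344545d197e123635d0002"
EXPECTED_SIZE = 1336420068

sha256 = hashlib.sha256()
size = 0
with open("model.safetensors", "rb") as f:  # assumed local path of the downloaded weights
    for chunk in iter(lambda: f.read(1024 * 1024), b""):
        sha256.update(chunk)
        size += len(chunk)

assert size == EXPECTED_SIZE, "size does not match the LFS pointer"
assert sha256.hexdigest() == EXPECTED_OID, "sha256 does not match the LFS pointer"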
modeling_pure_bert.py ADDED
@@ -0,0 +1,859 @@
import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Function
from transformers import (
    BertModel,
    PreTrainedModel,
)
from typing import Union, Tuple, Optional
from transformers.modeling_outputs import (
    SequenceClassifierOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput
)
from transformers.utils import ModelOutput

from .configuration_pure_bert import PureBertConfig

PureBertModel = BertModel


class CovarianceFunction(Function):

    @staticmethod
    def forward(ctx, inputs):
        x = inputs
        b, c, h, w = x.data.shape
        m = h * w
        x = x.view(b, c, m)
        I_hat = (-1.0 / m / m) * torch.ones(m, m, device=x.device) + (
            1.0 / m
        ) * torch.eye(m, m, device=x.device)
        I_hat = I_hat.view(1, m, m).repeat(b, 1, 1).type(x.dtype)
        y = x @ I_hat @ x.transpose(-1, -2)
        ctx.save_for_backward(inputs, I_hat)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        inputs, I_hat = ctx.saved_tensors
        x = inputs
        b, c, h, w = x.data.shape
        m = h * w
        x = x.view(b, c, m)
        grad_input = grad_output + grad_output.transpose(1, 2)
        grad_input = grad_input @ x @ I_hat
        grad_input = grad_input.reshape(b, c, h, w)
        return grad_input


class Covariance(nn.Module):

    def __init__(self):
        super(Covariance, self).__init__()

    def _covariance(self, x):
        return CovarianceFunction.apply(x)

    def forward(self, x):
        # x: [seq_len, embed_dim]; features are moved to the leading axis so the
        # covariance is computed across the sequence dimension.
        if x.dim() == 2:
            x = x.transpose(-1, -2)
        C = self._covariance(x[None, :, :, None])
        C = C.squeeze(dim=0)
        return C


class PFSA(torch.nn.Module):
    """
    https://openreview.net/pdf?id=isodM5jTA7h
    """
    def __init__(self, input_dim, alpha=1):
        super(PFSA, self).__init__()
        self.input_dim = input_dim
        self.alpha = alpha

    def forward_one_sample(self, x):
        x = x.transpose(1, 2)[..., None]
        k = torch.mean(x, dim=[-1, -2], keepdim=True)
        kd = torch.sqrt((k - k.mean(dim=1, keepdim=True)).pow(2).sum(dim=1, keepdim=True))  # [B, 1, 1, 1]
        qd = torch.sqrt((x - x.mean(dim=1, keepdim=True)).pow(2).sum(dim=1, keepdim=True))  # [B, 1, T, 1]
        C_qk = (((x - x.mean(dim=1, keepdim=True)) * (k - k.mean(dim=1, keepdim=True))).sum(dim=1, keepdim=True)) / (qd * kd)
        A = (1 - torch.sigmoid(C_qk)) ** self.alpha
        out = x * A
        out = out.squeeze(dim=-1).transpose(1, 2)
        return out

    def forward(self, input_values, attention_mask=None):
        """
        input_values: [B, T, F]
        """
        out = []
        b, t, f = input_values.shape
        if attention_mask is None:
            # Treat every position as valid when no mask is provided.
            attention_mask = torch.ones(b, t, device=input_values.device, dtype=torch.long)
        for x, mask in zip(input_values, attention_mask):
            x = x.view(1, t, f)
            x_in = x[:, :int(mask.sum().item()), :]
            x_out = self.forward_one_sample(x_in)
            x_expanded = torch.zeros_like(x, device=x.device)
            x_expanded[:, :x_out.shape[-2], :x_out.shape[-1]] = x_out
            out.append(x_expanded)
        out = torch.vstack(out)
        out = out.view(b, t, f)
        return out


class PURE(torch.nn.Module):

    def __init__(
        self,
        in_dim,
        svd_rank=16,
        num_pc_to_remove=1,
        center=False,
        num_iters=2,
        alpha=1,
        disable_pcr=False,
        disable_pfsa=False,
        disable_covariance=True,
        *args, **kwargs
    ):
        super().__init__()
        self.in_dim = in_dim
        self.svd_rank = svd_rank
        self.num_pc_to_remove = num_pc_to_remove
        self.center = center
        self.num_iters = num_iters
        self.do_pcr = not disable_pcr
        self.do_pfsa = not disable_pfsa
        self.do_covariance = not disable_covariance
        self.attention = PFSA(in_dim, alpha=alpha)

    def _compute_pc(self, X, attention_mask):
        """
        X: (B, T, F); returns a list of per-sample principal components.
        """
        pcs = []
        bs, seqlen, dim = X.shape
        for x, mask in zip(X, attention_mask):
            rank = int(mask.sum().item())
            x = x[:rank, :]
            if self.do_covariance:
                x = Covariance()(x)
                q = self.svd_rank
            else:
                q = min(self.svd_rank, rank)
            _, _, V = torch.pca_lowrank(x, q=q, center=self.center, niter=self.num_iters)
            pc = V.transpose(0, 1)[:self.num_pc_to_remove, :]  # pc: [K, F]
            pcs.append(pc)
        return pcs

    def _remove_pc(self, X, pcs):
        """
        X: [B, T, F], pcs: list of [K, F]; projects out the top components per sample.
        """
        b, t, f = X.shape
        out = []
        for i, (x, pc) in enumerate(zip(X, pcs)):
            v = x - x @ pc.transpose(0, 1) @ pc
            out.append(v[None, ...])
        out = torch.vstack(out)
        return out

    def forward(self, input_values, attention_mask=None, *args, **kwargs):
        """
        PCR -> Attention
        input_values: (B, T, F)
        """
        x = input_values
        if attention_mask is None:
            # Fall back to a full mask so per-sample lengths are well defined.
            attention_mask = torch.ones(x.shape[:2], device=x.device, dtype=torch.long)
        if self.do_pcr:
            pc = self._compute_pc(x, attention_mask)  # pc: [B, K, F]
            xx = self._remove_pc(x, pc)
        else:
            xx = x
        if self.do_pfsa:
            xx = self.attention(xx, attention_mask)
        return xx


class StatisticsPooling(torch.nn.Module):

    def __init__(self, return_mean=True, return_std=True):
        super().__init__()

        # Small value for GaussNoise
        self.eps = 1e-5
        self.return_mean = return_mean
        self.return_std = return_std
        if not (self.return_mean or self.return_std):
            raise ValueError(
                "both statistics are set to False \n"
                "consider enabling mean and/or std statistic pooling"
            )

    def forward(self, input_values, attention_mask=None):
        """Calculates mean and std for a batch (input tensor).

        Arguments
        ---------
        input_values : torch.Tensor
            It represents a tensor for a mini-batch.
        """
        x = input_values
        if attention_mask is None:
            if self.return_mean:
                mean = x.mean(dim=1)
            if self.return_std:
                std = x.std(dim=1)
        else:
            mean = []
            std = []
            for snt_id in range(x.shape[0]):
                # Avoiding padded time steps
                lengths = torch.sum(attention_mask, dim=1)
                relative_lengths = lengths / torch.max(lengths)
                actual_size = torch.round(relative_lengths[snt_id] * x.shape[1]).int()

                # computing statistics
                if self.return_mean:
                    mean.append(
                        torch.mean(x[snt_id, 0:actual_size, ...], dim=0)
                    )
                if self.return_std:
                    std.append(torch.std(x[snt_id, 0:actual_size, ...], dim=0))
            if self.return_mean:
                mean = torch.stack(mean)
            if self.return_std:
                std = torch.stack(std)

        if self.return_mean:
            gnoise = self._get_gauss_noise(mean.size(), device=mean.device)
            mean += gnoise
        if self.return_std:
            std = std + self.eps

        # Append mean and std of the batch
        if self.return_mean and self.return_std:
            pooled_stats = torch.cat((mean, std), dim=1)
            pooled_stats = pooled_stats.unsqueeze(1)
        elif self.return_mean:
            pooled_stats = mean.unsqueeze(1)
        elif self.return_std:
            pooled_stats = std.unsqueeze(1)

        return pooled_stats

    def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
        """Returns a tensor of epsilon Gaussian noise.

        Arguments
        ---------
        shape_of_tensor : tensor
            It represents the size of tensor for generating Gaussian noise.
        """
        gnoise = torch.randn(shape_of_tensor, device=device)
        gnoise -= torch.min(gnoise)
        gnoise /= torch.max(gnoise)
        gnoise = self.eps * ((1 - 9) * gnoise + 9)

        return gnoise


class PureBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = PureBertConfig
    base_model_prefix = "bert"
    supports_gradient_checkpointing = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class BertClsForSequenceClassification(PureBertPreTrainedModel):

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = PureBertModel(config, add_pooling_layer=add_pooling_layer)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs.pooler_output
        if pooled_output is None:
            pooled_output = outputs.last_hidden_state[:, 0, :]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class BertMixupForSequenceClassification(PureBertPreTrainedModel):

    def __init__(self, config, alpha=1.0, label_smoothing=0.0):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.alpha = alpha
        self.label_smoothing = label_smoothing
        self.config = config

        self.bert = PureBertModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def mixup_data(self, embeddings, labels, alpha=1.0):
        """Compute the mixup data. Returns mixed inputs, pairs of targets, and lambda."""
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1

        batch_size = embeddings.size()[0]
        index = torch.randperm(batch_size).to(embeddings.device)

        mixed_x = lam * embeddings + (1 - lam) * embeddings[index, :]
        y_a, y_b = labels, labels[index]
        return mixed_x, y_a, y_b, lam

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.training:
            mixed_embeddings, targets_a, targets_b, lam = self.mixup_data(outputs.pooler_output, labels, self.alpha)
            mixed_embeddings = self.dropout(mixed_embeddings)
            logits = self.classifier(mixed_embeddings)
        else:
            pooler_output = self.dropout(outputs.pooler_output)
            logits = self.classifier(pooler_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
                logits = logits.view(-1, self.num_labels)
                if self.training:
                    targets_a = targets_a.view(-1)
                    targets_b = targets_b.view(-1)
                    loss = lam * loss_fct(logits, targets_a) + (1 - lam) * loss_fct(logits, targets_b)
                else:
                    loss = loss_fct(logits, labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class PureBertForSequenceClassification(PureBertPreTrainedModel):

    def __init__(
        self,
        config,
        label_smoothing=0.0,
    ):
        super().__init__(config)
        self.label_smoothing = label_smoothing
        self.num_labels = config.num_labels
        self.config = config

        self.bert = PureBertModel(config, add_pooling_layer=False)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.pure = PURE(
            in_dim=config.hidden_size,
            svd_rank=config.svd_rank,
            num_pc_to_remove=config.num_pc_to_remove,
            center=config.center,
            num_iters=config.num_iters,
            alpha=config.alpha,
            disable_pcr=config.disable_pcr,
            disable_pfsa=config.disable_pfsa,
            disable_covariance=config.disable_covariance
        )
        self.mean = StatisticsPooling(return_mean=True, return_std=False)
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward_pure_embeddings(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        Returns the PURE-processed token embeddings without pooling or a classification head.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        token_embeddings = outputs.last_hidden_state
        token_embeddings = self.pure(token_embeddings, attention_mask)

        return ModelOutput(
            last_hidden_state=token_embeddings,
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        token_embeddings = outputs.last_hidden_state
        token_embeddings = self.pure(token_embeddings, attention_mask)
        pooled_output = self.mean(token_embeddings).squeeze(1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class PureBertForMultipleChoice(PureBertPreTrainedModel):

    def __init__(
        self,
        config,
        label_smoothing=0.0,
    ):
        super().__init__(config)
        self.label_smoothing = label_smoothing

        self.bert = PureBertModel(config)
        self.pure = PURE(
            in_dim=config.hidden_size,
            svd_rank=config.svd_rank,
            num_pc_to_remove=config.num_pc_to_remove,
            center=config.center,
            num_iters=config.num_iters,
            alpha=config.alpha,
            disable_pcr=config.disable_pcr,
            disable_pfsa=config.disable_pfsa,
            disable_covariance=config.disable_covariance
        )
        self.mean = StatisticsPooling(return_mean=True, return_std=False)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        token_embeddings = outputs.last_hidden_state
        token_embeddings = self.pure(token_embeddings, attention_mask)
        pooled_output = self.mean(token_embeddings).squeeze(1)
        pooled_output = self.dropout(pooled_output)

        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class PureBertForQuestionAnswering(PureBertPreTrainedModel):

    def __init__(
        self,
        config,
        label_smoothing=0.0,
    ):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.label_smoothing = label_smoothing

        self.bert = PureBertModel(config, add_pooling_layer=False)
        self.pure = PURE(
            in_dim=config.hidden_size,
            svd_rank=config.svd_rank,
            num_pc_to_remove=config.num_pc_to_remove,
            center=config.center,
            num_iters=config.num_iters,
            alpha=config.alpha,
            disable_pcr=config.disable_pcr,
            disable_pfsa=config.disable_pfsa,
            disable_covariance=config.disable_covariance
        )
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        token_embeddings = outputs.last_hidden_state
        sequence_output = self.pure(token_embeddings, attention_mask)

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
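
As a small illustration of the PURE block defined above (principal-component removal followed by the PFSA step), here is a standalone sketch on random embeddings. The hyperparameters are arbitrary and are not the values stored in this checkpoint's config; the snippet assumes the PURE class above is in scope (for example, copied into a local script).

import torch

# Arbitrary sizes, for illustration only.
pure = PURE(in_dim=1024, svd_rank=5, num_pc_to_remove=1)

embeddings = torch.randn(2, 16, 1024)       # [batch, seq_len, hidden]
mask = torch.ones(2, 16, dtype=torch.long)  # 1 = real token, 0 = padding
mask[1, 10:] = 0                            # second sample has 10 valid tokens

out = pure(embeddings, attention_mask=mask)
print(out.shape)  # torch.Size([2, 16, 1024]): same shape, with the top component removed and PFSA reweighting applied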