yangwang825 committed on
Commit
3550a02
1 Parent(s): 60b3791

Upload PureBertForSequenceClassification

Files changed (3)
  1. config.json +4 -3
  2. model.safetensors +3 -0
  3. modeling_pure_bert.py +864 -0
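
Together, these files register PURE-BERT as remote code on the Hub: the config.json diff below maps AutoConfig to PureBertConfig and AutoModelForSequenceClassification to the uploaded PureBertForSequenceClassification, so the checkpoint can be loaded through the Auto classes with trust_remote_code=True. A minimal loading sketch, assuming a placeholder repository id (substitute the actual repo name):

from transformers import AutoConfig, AutoModelForSequenceClassification

repo_id = "yangwang825/pure-bert"  # placeholder repository id

# Resolves to PureBertConfig / PureBertForSequenceClassification via the auto_map entries.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(repo_id, trust_remote_code=True)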
config.json CHANGED
@@ -1,12 +1,13 @@
  {
- "_name_or_path": "/root/.cache/torch/sentence_transformers/BAAI_bge-large-en/",
+ "_name_or_path": "BAAI/bge-large-en-v1.5",
  "alpha": 1,
  "architectures": [
- "BertModel"
+ "PureBertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "auto_map": {
- "AutoConfig": "configuration_pure_bert.PureBertConfig"
+ "AutoConfig": "configuration_pure_bert.PureBertConfig",
+ "AutoModelForSequenceClassification": "modeling_pure_bert.PureBertForSequenceClassification"
  },
  "center": false,
  "classifier_dropout": null,
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:88d12b0f3b00bfde7a5ee7d3e054714e68ee6c358de80add011040ac7a009e26
+ size 1336420068
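
The safetensors entry is a Git LFS pointer, so only the digest and size (roughly 1.3 GB) are tracked in the diff; the weight file itself lives in LFS storage. A small sketch, assuming huggingface_hub is available and using a placeholder repository id, that downloads the checkpoint and checks it against the sha256 recorded above:

import hashlib
from huggingface_hub import hf_hub_download

# Placeholder repo id; replace with the actual repository name.
path = hf_hub_download(repo_id="yangwang825/pure-bert", filename="model.safetensors")

digest = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        digest.update(chunk)

# Expected digest taken from the LFS pointer above.
assert digest.hexdigest() == "88d12b0f3b00bfde7a5ee7d3e054714e68ee6c358de80add011040ac7a009e26"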
modeling_pure_bert.py ADDED
@@ -0,0 +1,864 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from torch.autograd import Function
5
+ from transformers import PreTrainedModel
6
+ from transformers.models.bert.modeling_bert import (
7
+ BertEmbeddings, BertEncoder, BertPooler
8
+ )
9
+ from typing import Union, Tuple, Optional, List
10
+ from transformers.modeling_outputs import (
11
+ SequenceClassifierOutput,
12
+ MultipleChoiceModelOutput,
13
+ QuestionAnsweringModelOutput,
14
+ BaseModelOutputWithPoolingAndCrossAttentions
15
+ )
16
+ from transformers.modeling_attn_mask_utils import (
17
+ _prepare_4d_attention_mask_for_sdpa,
18
+ _prepare_4d_causal_attention_mask_for_sdpa,
19
+ )
20
+ from transformers.utils import ModelOutput
21
+
22
+ from .configuration_pure_bert import PureBertConfig
23
+
24
+
25
+ class CovarianceFunction(Function):
26
+
27
+ @staticmethod
28
+ def forward(ctx, inputs):
29
+ x = inputs
30
+ b, c, h, w = x.data.shape
31
+ m = h * w
32
+ x = x.view(b, c, m)
33
+ I_hat = (-1.0 / m / m) * torch.ones(m, m, device=x.device) + (
34
+ 1.0 / m
35
+ ) * torch.eye(m, m, device=x.device)
36
+ I_hat = I_hat.view(1, m, m).repeat(b, 1, 1).type(x.dtype)
37
+ y = x @ I_hat @ x.transpose(-1, -2)
38
+ ctx.save_for_backward(inputs, I_hat)
39
+ return y
40
+
41
+ @staticmethod
42
+ def backward(ctx, grad_output):
43
+ inputs, I_hat = ctx.saved_tensors
44
+ x = inputs
45
+ b, c, h, w = x.data.shape
46
+ m = h * w
47
+ x = x.view(b, c, m)
48
+ grad_input = grad_output + grad_output.transpose(1, 2)
49
+ grad_input = grad_input @ x @ I_hat
50
+ grad_input = grad_input.reshape(b, c, h, w)
51
+ return grad_input
52
+
53
+
54
+ class Covariance(nn.Module):
55
+
56
+ def __init__(self):
57
+ super(Covariance, self).__init__()
58
+
59
+ def _covariance(self, x):
60
+ return CovarianceFunction.apply(x)
61
+
62
+ def forward(self, x):
63
+ # x should be [batch_size, seq_len, embed_dim]
64
+ if x.dim() == 2:
65
+ x = x.transpose(-1, -2)
66
+ C = self._covariance(x[None, :, :, None])
67
+ C = C.squeeze(dim=0)
68
+ return C
69
+
70
+
71
+ class PFSA(torch.nn.Module):
72
+ """
73
+ https://openreview.net/pdf?id=isodM5jTA7h
74
+ """
75
+ def __init__(self, input_dim, alpha=1):
76
+ super(PFSA, self).__init__()
77
+ self.input_dim = input_dim
78
+ self.alpha = alpha
79
+
80
+ def forward_one_sample(self, x):
81
+ x = x.transpose(1, 2)[..., None]
82
+ k = torch.mean(x, dim=[-1, -2], keepdim=True)
83
+ kd = torch.sqrt((k - k.mean(dim=1, keepdim=True)).pow(2).sum(dim=1, keepdim=True)) # [B, 1, 1, 1]
84
+ qd = torch.sqrt((x - x.mean(dim=1, keepdim=True)).pow(2).sum(dim=1, keepdim=True)) # [B, 1, T, 1]
85
+ C_qk = (((x - x.mean(dim=1, keepdim=True)) * (k - k.mean(dim=1, keepdim=True))).sum(dim=1, keepdim=True)) / (qd * kd)
86
+ A = (1 - torch.sigmoid(C_qk)) ** self.alpha
87
+ out = x * A
88
+ out = out.squeeze(dim=-1).transpose(1, 2)
89
+ return out
90
+
91
+ def forward(self, input_values, attention_mask=None):
92
+ """
93
+ x: [B, T, F]
94
+ """
95
+ out = []
96
+ b, t, f = input_values.shape
97
+ for x, mask in zip(input_values, attention_mask):
98
+ x = x.view(1, t, f)
99
+ # x_in = x[:, :sum(mask), :]
100
+ x_in = x[:, :int(mask.sum().item()), :]
101
+ x_out = self.forward_one_sample(x_in)
102
+ x_expanded = torch.zeros_like(x, device=x.device)
103
+ x_expanded[:, :x_out.shape[-2], :x_out.shape[-1]] = x_out
104
+ out.append(x_expanded)
105
+ out = torch.vstack(out)
106
+ out = out.view(b, t, f)
107
+ return out
108
+
109
+
110
+ class PURE(torch.nn.Module):
111
+
112
+ def __init__(
113
+ self,
114
+ in_dim,
115
+ svd_rank=16,
116
+ num_pc_to_remove=1,
117
+ center=False,
118
+ num_iters=2,
119
+ alpha=1,
120
+ disable_pcr=False,
121
+ disable_pfsa=False,
122
+ disable_covariance=True,
123
+ *args, **kwargs
124
+ ):
125
+ super().__init__()
126
+ self.in_dim = in_dim
127
+ self.svd_rank = svd_rank
128
+ self.num_pc_to_remove = num_pc_to_remove
129
+ self.center = center
130
+ self.num_iters = num_iters
131
+ self.do_pcr = not disable_pcr
132
+ self.do_pfsa = not disable_pfsa
133
+ self.do_covariance = not disable_covariance
134
+ self.attention = PFSA(in_dim, alpha=alpha)
135
+
136
+ def _compute_pc(self, X, attention_mask):
137
+ """
138
+ x: (B, T, F)
139
+ """
140
+ pcs = []
141
+ bs, seqlen, dim = X.shape
142
+ for x, mask in zip(X, attention_mask):
143
+ rank = int(mask.sum().item())
144
+ x = x[:rank, :]
145
+ if self.do_covariance:
146
+ x = Covariance()(x)
147
+ q = self.svd_rank
148
+ else:
149
+ q = min(self.svd_rank, rank)
150
+ _, _, V = torch.pca_lowrank(x, q=q, center=self.center, niter=self.num_iters)
151
+ # _, _, Vh = torch.linalg.svd(x_, full_matrices=False)
152
+ # V = Vh.mH
153
+ pc = V.transpose(0, 1)[:self.num_pc_to_remove, :] # pc: [K, F]
154
+ pcs.append(pc)
155
+ # pcs = torch.vstack(pcs)
156
+ # pcs = pcs.view(bs, self.num_pc_to_remove, dim)
157
+ return pcs
158
+
159
+ def _remove_pc(self, X, pcs):
160
+ """
161
+ [B, T, F], [B, ..., F]
162
+ """
163
+ b, t, f = X.shape
164
+ out = []
165
+ for i, (x, pc) in enumerate(zip(X, pcs)):
166
+ # v = []
167
+ # for j, t in enumerate(x):
168
+ # t_ = t
169
+ # for c_ in c:
170
+ # t_ = t_.view(f, 1) - c_.view(f, 1) @ c_.view(1, f) @ t.view(f, 1)
171
+ # v.append(t_.transpose(-1, -2))
172
+ # v = torch.vstack(v)
173
+ v = x - x @ pc.transpose(0, 1) @ pc
174
+ out.append(v[None, ...])
175
+ out = torch.vstack(out)
176
+ return out
177
+
178
+ def forward(self, input_values, attention_mask=None, *args, **kwargs):
179
+ """
180
+ PCR -> Attention
181
+ x: (B, T, F)
182
+ """
183
+ x = input_values
184
+ if self.do_pcr:
185
+ pc = self._compute_pc(x, attention_mask) # pc: [B, K, F]
186
+ xx = self._remove_pc(x, pc)
187
+ # xx = xt - xt @ pc.transpose(1, 2) @ pc # [B, T, F] * [B, F, K] * [B, K, F] = [B, T, F]
188
+ else:
189
+ xx = x
190
+ if self.do_pfsa:
191
+ xx = self.attention(xx, attention_mask)
192
+ return xx
193
+
194
+
195
+ class StatisticsPooling(torch.nn.Module):
196
+
197
+ def __init__(self, return_mean=True, return_std=True):
198
+ super().__init__()
199
+
200
+ # Small value for GaussNoise
201
+ self.eps = 1e-5
202
+ self.return_mean = return_mean
203
+ self.return_std = return_std
204
+ if not (self.return_mean or self.return_std):
205
+ raise ValueError(
206
+ "both of statistics are equal to False \n"
207
+ "consider enabling mean and/or std statistic pooling"
208
+ )
209
+
210
+ def forward(self, input_values, attention_mask=None):
211
+ """Calculates mean and std for a batch (input tensor).
212
+
213
+ Arguments
214
+ ---------
215
+ x : torch.Tensor
216
+ It represents a tensor for a mini-batch.
217
+ """
218
+ x = input_values
219
+ if attention_mask is None:
220
+ if self.return_mean:
221
+ mean = x.mean(dim=1)
222
+ if self.return_std:
223
+ std = x.std(dim=1)
224
+ else:
225
+ mean = []
226
+ std = []
227
+ for snt_id in range(x.shape[0]):
228
+ # Avoiding padded time steps
229
+ lengths = torch.sum(attention_mask, dim=1)
230
+ relative_lengths = lengths / torch.max(lengths)
231
+ actual_size = torch.round(relative_lengths[snt_id] * x.shape[1]).int()
232
+ # actual_size = int(torch.round(lengths[snt_id] * x.shape[1]))
233
+
234
+ # computing statistics
235
+ if self.return_mean:
236
+ mean.append(
237
+ torch.mean(x[snt_id, 0:actual_size, ...], dim=0)
238
+ )
239
+ if self.return_std:
240
+ std.append(torch.std(x[snt_id, 0:actual_size, ...], dim=0))
241
+ if self.return_mean:
242
+ mean = torch.stack(mean)
243
+ if self.return_std:
244
+ std = torch.stack(std)
245
+
246
+ if self.return_mean:
247
+ gnoise = self._get_gauss_noise(mean.size(), device=mean.device)
248
+ gnoise = gnoise
249
+ mean += gnoise
250
+ if self.return_std:
251
+ std = std + self.eps
252
+
253
+ # Append mean and std of the batch
254
+ if self.return_mean and self.return_std:
255
+ pooled_stats = torch.cat((mean, std), dim=1)
256
+ pooled_stats = pooled_stats.unsqueeze(1)
257
+ elif self.return_mean:
258
+ pooled_stats = mean.unsqueeze(1)
259
+ elif self.return_std:
260
+ pooled_stats = std.unsqueeze(1)
261
+
262
+ return pooled_stats
263
+
264
+ def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
265
+ """Returns a tensor of epsilon Gaussian noise.
266
+
267
+ Arguments
268
+ ---------
269
+ shape_of_tensor : tensor
270
+ It represents the size of tensor for generating Gaussian noise.
271
+ """
272
+ gnoise = torch.randn(shape_of_tensor, device=device)
273
+ gnoise -= torch.min(gnoise)
274
+ gnoise /= torch.max(gnoise)
275
+ gnoise = self.eps * ((1 - 9) * gnoise + 9)
276
+
277
+ return gnoise
278
+
279
+
280
+ class PureBertPreTrainedModel(PreTrainedModel):
281
+ """
282
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
283
+ models.
284
+ """
285
+
286
+ config_class = PureBertConfig
287
+ base_model_prefix = "bert"
288
+ supports_gradient_checkpointing = True
289
+ _supports_sdpa = True
290
+
291
+ def _init_weights(self, module):
292
+ """Initialize the weights"""
293
+ if isinstance(module, nn.Linear):
294
+ # Slightly different from the TF version which uses truncated_normal for initialization
295
+ # cf https://github.com/pytorch/pytorch/pull/5617
296
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
297
+ if module.bias is not None:
298
+ module.bias.data.zero_()
299
+ elif isinstance(module, nn.Embedding):
300
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
301
+ if module.padding_idx is not None:
302
+ module.weight.data[module.padding_idx].zero_()
303
+ elif isinstance(module, nn.LayerNorm):
304
+ module.bias.data.zero_()
305
+ module.weight.data.fill_(1.0)
306
+
307
+
308
+ class PureBertModel(PureBertPreTrainedModel):
309
+ """
310
+
311
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
312
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
313
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
314
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
315
+
316
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
317
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
318
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
319
+ """
320
+
321
+ _no_split_modules = ["BertEmbeddings", "BertLayer"]
322
+
323
+ def __init__(self, config, add_pooling_layer=True):
324
+ super().__init__(config)
325
+ self.config = config
326
+
327
+ self.embeddings = BertEmbeddings(config)
328
+ self.encoder = BertEncoder(config)
329
+
330
+ self.pooler = BertPooler(config) if add_pooling_layer else None
331
+
332
+ self.attn_implementation = config._attn_implementation
333
+ self.position_embedding_type = config.position_embedding_type
334
+
335
+ # Initialize weights and apply final processing
336
+ self.post_init()
337
+
338
+ def get_input_embeddings(self):
339
+ return self.embeddings.word_embeddings
340
+
341
+ def set_input_embeddings(self, value):
342
+ self.embeddings.word_embeddings = value
343
+
344
+ def _prune_heads(self, heads_to_prune):
345
+ """
346
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
347
+ class PreTrainedModel
348
+ """
349
+ for layer, heads in heads_to_prune.items():
350
+ self.encoder.layer[layer].attention.prune_heads(heads)
351
+
352
+ def forward(
353
+ self,
354
+ input_ids: Optional[torch.Tensor] = None,
355
+ attention_mask: Optional[torch.Tensor] = None,
356
+ token_type_ids: Optional[torch.Tensor] = None,
357
+ position_ids: Optional[torch.Tensor] = None,
358
+ head_mask: Optional[torch.Tensor] = None,
359
+ inputs_embeds: Optional[torch.Tensor] = None,
360
+ encoder_hidden_states: Optional[torch.Tensor] = None,
361
+ encoder_attention_mask: Optional[torch.Tensor] = None,
362
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
363
+ use_cache: Optional[bool] = None,
364
+ output_attentions: Optional[bool] = None,
365
+ output_hidden_states: Optional[bool] = None,
366
+ return_dict: Optional[bool] = None,
367
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
368
+ r"""
369
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
370
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
371
+ the model is configured as a decoder.
372
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
373
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
374
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
375
+
376
+ - 1 for tokens that are **not masked**,
377
+ - 0 for tokens that are **masked**.
378
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
379
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
380
+
381
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
382
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
383
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
384
+ use_cache (`bool`, *optional*):
385
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
386
+ `past_key_values`).
387
+ """
388
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
389
+ output_hidden_states = (
390
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
391
+ )
392
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
393
+
394
+ if self.config.is_decoder:
395
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
396
+ else:
397
+ use_cache = False
398
+
399
+ if input_ids is not None and inputs_embeds is not None:
400
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
401
+ elif input_ids is not None:
402
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
403
+ input_shape = input_ids.size()
404
+ elif inputs_embeds is not None:
405
+ input_shape = inputs_embeds.size()[:-1]
406
+ else:
407
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
408
+
409
+ batch_size, seq_length = input_shape
410
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
411
+
412
+ # past_key_values_length
413
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
414
+
415
+ if token_type_ids is None:
416
+ if hasattr(self.embeddings, "token_type_ids"):
417
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
418
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
419
+ token_type_ids = buffered_token_type_ids_expanded
420
+ else:
421
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
422
+
423
+ embedding_output = self.embeddings(
424
+ input_ids=input_ids,
425
+ position_ids=position_ids,
426
+ token_type_ids=token_type_ids,
427
+ inputs_embeds=inputs_embeds,
428
+ past_key_values_length=past_key_values_length,
429
+ )
430
+
431
+ if attention_mask is None:
432
+ attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
433
+
434
+ use_sdpa_attention_masks = (
435
+ self.attn_implementation == "sdpa"
436
+ and self.position_embedding_type == "absolute"
437
+ and head_mask is None
438
+ and not output_attentions
439
+ )
440
+
441
+ # Expand the attention mask
442
+ if use_sdpa_attention_masks and attention_mask.dim() == 2:
443
+ # Expand the attention mask for SDPA.
444
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
445
+ if self.config.is_decoder:
446
+ extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
447
+ attention_mask,
448
+ input_shape,
449
+ embedding_output,
450
+ past_key_values_length,
451
+ )
452
+ else:
453
+ extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
454
+ attention_mask, embedding_output.dtype, tgt_len=seq_length
455
+ )
456
+ else:
457
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
458
+ # ourselves in which case we just need to make it broadcastable to all heads.
459
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
460
+
461
+ # If a 2D or 3D attention mask is provided for the cross-attention
462
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
463
+ if self.config.is_decoder and encoder_hidden_states is not None:
464
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
465
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
466
+ if encoder_attention_mask is None:
467
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
468
+
469
+ if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
470
+ # Expand the attention mask for SDPA.
471
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
472
+ encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
473
+ encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
474
+ )
475
+ else:
476
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
477
+ else:
478
+ encoder_extended_attention_mask = None
479
+
480
+ # Prepare head mask if needed
481
+ # 1.0 in head_mask indicate we keep the head
482
+ # attention_probs has shape bsz x n_heads x N x N
483
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
484
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
485
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
486
+
487
+ encoder_outputs = self.encoder(
488
+ embedding_output,
489
+ attention_mask=extended_attention_mask,
490
+ head_mask=head_mask,
491
+ encoder_hidden_states=encoder_hidden_states,
492
+ encoder_attention_mask=encoder_extended_attention_mask,
493
+ past_key_values=past_key_values,
494
+ use_cache=use_cache,
495
+ output_attentions=output_attentions,
496
+ output_hidden_states=output_hidden_states,
497
+ return_dict=return_dict,
498
+ )
499
+ sequence_output = encoder_outputs[0]
500
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
501
+
502
+ if not return_dict:
503
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
504
+
505
+ return BaseModelOutputWithPoolingAndCrossAttentions(
506
+ last_hidden_state=sequence_output,
507
+ pooler_output=pooled_output,
508
+ past_key_values=encoder_outputs.past_key_values,
509
+ hidden_states=encoder_outputs.hidden_states,
510
+ attentions=encoder_outputs.attentions,
511
+ cross_attentions=encoder_outputs.cross_attentions,
512
+ )
513
+
514
+
515
+ class PureBertForSequenceClassification(PureBertPreTrainedModel):
516
+
517
+ def __init__(
518
+ self,
519
+ config,
520
+ label_smoothing=0.0,
521
+ ):
522
+ super().__init__(config)
523
+ self.label_smoothing = label_smoothing
524
+ self.num_labels = config.num_labels
525
+ self.config = config
526
+
527
+ self.bert = PureBertModel(config, add_pooling_layer=False)
528
+ classifier_dropout = (
529
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
530
+ )
531
+ self.pure = PURE(
532
+ in_dim=config.hidden_size,
533
+ svd_rank=config.svd_rank,
534
+ num_pc_to_remove=config.num_pc_to_remove,
535
+ center=config.center,
536
+ num_iters=config.num_iters,
537
+ alpha=config.alpha,
538
+ disable_pcr=config.disable_pcr,
539
+ disable_pfsa=config.disable_pfsa,
540
+ disable_covariance=config.disable_covariance
541
+ )
542
+ self.mean = StatisticsPooling(return_mean=True, return_std=False)
543
+ self.dropout = nn.Dropout(classifier_dropout)
544
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
545
+
546
+ # Initialize weights and apply final processing
547
+ self.post_init()
548
+
549
+ def forward_pure_embeddings(
550
+ self,
551
+ input_ids: Optional[torch.Tensor] = None,
552
+ attention_mask: Optional[torch.Tensor] = None,
553
+ token_type_ids: Optional[torch.Tensor] = None,
554
+ position_ids: Optional[torch.Tensor] = None,
555
+ head_mask: Optional[torch.Tensor] = None,
556
+ inputs_embeds: Optional[torch.Tensor] = None,
557
+ labels: Optional[torch.Tensor] = None,
558
+ output_attentions: Optional[bool] = None,
559
+ output_hidden_states: Optional[bool] = None,
560
+ return_dict: Optional[bool] = None,
561
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
562
+ r"""
563
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
564
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
565
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
566
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
567
+ """
568
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
569
+
570
+ outputs = self.bert(
571
+ input_ids,
572
+ attention_mask=attention_mask,
573
+ token_type_ids=token_type_ids,
574
+ position_ids=position_ids,
575
+ head_mask=head_mask,
576
+ inputs_embeds=inputs_embeds,
577
+ output_attentions=output_attentions,
578
+ output_hidden_states=output_hidden_states,
579
+ return_dict=return_dict,
580
+ )
581
+
582
+ token_embeddings = outputs.last_hidden_state
583
+ token_embeddings = self.pure(token_embeddings, attention_mask)
584
+
585
+ return ModelOutput(
586
+ last_hidden_state=token_embeddings,
587
+ )
588
+
589
+ def forward(
590
+ self,
591
+ input_ids: Optional[torch.Tensor] = None,
592
+ attention_mask: Optional[torch.Tensor] = None,
593
+ token_type_ids: Optional[torch.Tensor] = None,
594
+ position_ids: Optional[torch.Tensor] = None,
595
+ head_mask: Optional[torch.Tensor] = None,
596
+ inputs_embeds: Optional[torch.Tensor] = None,
597
+ labels: Optional[torch.Tensor] = None,
598
+ output_attentions: Optional[bool] = None,
599
+ output_hidden_states: Optional[bool] = None,
600
+ return_dict: Optional[bool] = None,
601
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
602
+ r"""
603
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
604
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
605
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
606
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
607
+ """
608
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
609
+
610
+ outputs = self.bert(
611
+ input_ids,
612
+ attention_mask=attention_mask,
613
+ token_type_ids=token_type_ids,
614
+ position_ids=position_ids,
615
+ head_mask=head_mask,
616
+ inputs_embeds=inputs_embeds,
617
+ output_attentions=output_attentions,
618
+ output_hidden_states=output_hidden_states,
619
+ return_dict=return_dict,
620
+ )
621
+
622
+ token_embeddings = outputs.last_hidden_state
623
+ token_embeddings = self.pure(token_embeddings, attention_mask)
624
+ pooled_output = self.mean(token_embeddings).squeeze(1)
625
+ pooled_output = self.dropout(pooled_output)
626
+ logits = self.classifier(pooled_output)
627
+
628
+ loss = None
629
+ if labels is not None:
630
+ if self.config.problem_type is None:
631
+ if self.num_labels == 1:
632
+ self.config.problem_type = "regression"
633
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
634
+ self.config.problem_type = "single_label_classification"
635
+ else:
636
+ self.config.problem_type = "multi_label_classification"
637
+
638
+ if self.config.problem_type == "regression":
639
+ loss_fct = nn.MSELoss()
640
+ if self.num_labels == 1:
641
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
642
+ else:
643
+ loss = loss_fct(logits, labels)
644
+ elif self.config.problem_type == "single_label_classification":
645
+ loss_fct = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
646
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
647
+ elif self.config.problem_type == "multi_label_classification":
648
+ loss_fct = nn.BCEWithLogitsLoss()
649
+ loss = loss_fct(logits, labels)
650
+ if not return_dict:
651
+ output = (logits,) + outputs[2:]
652
+ return ((loss,) + output) if loss is not None else output
653
+
654
+ return SequenceClassifierOutput(
655
+ loss=loss,
656
+ logits=logits,
657
+ hidden_states=outputs.hidden_states,
658
+ attentions=outputs.attentions,
659
+ )
660
+
661
+
662
+ class PureBertForMultipleChoice(PureBertPreTrainedModel):
663
+
664
+ def __init__(
665
+ self,
666
+ config,
667
+ label_smoothing=0.0,
668
+ ):
669
+ super().__init__(config)
670
+ self.label_smoothing = label_smoothing
671
+
672
+ self.bert = PureBertModel(config)
673
+ self.pure = PURE(
674
+ in_dim=config.hidden_size,
675
+ svd_rank=config.svd_rank,
676
+ num_pc_to_remove=config.num_pc_to_remove,
677
+ center=config.center,
678
+ num_iters=config.num_iters,
679
+ alpha=config.alpha,
680
+ disable_pcr=config.disable_pcr,
681
+ disable_pfsa=config.disable_pfsa,
682
+ disable_covariance=config.disable_covariance
683
+ )
684
+ self.mean = StatisticsPooling(return_mean=True, return_std=False)
685
+ classifier_dropout = (
686
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
687
+ )
688
+ self.dropout = nn.Dropout(classifier_dropout)
689
+ self.classifier = nn.Linear(config.hidden_size, 1)
690
+
691
+ # Initialize weights and apply final processing
692
+ self.post_init()
693
+
694
+ def forward(
695
+ self,
696
+ input_ids: Optional[torch.Tensor] = None,
697
+ attention_mask: Optional[torch.Tensor] = None,
698
+ token_type_ids: Optional[torch.Tensor] = None,
699
+ position_ids: Optional[torch.Tensor] = None,
700
+ head_mask: Optional[torch.Tensor] = None,
701
+ inputs_embeds: Optional[torch.Tensor] = None,
702
+ labels: Optional[torch.Tensor] = None,
703
+ output_attentions: Optional[bool] = None,
704
+ output_hidden_states: Optional[bool] = None,
705
+ return_dict: Optional[bool] = None,
706
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
707
+ r"""
708
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
709
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
710
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
711
+ `input_ids` above)
712
+ """
713
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
714
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
715
+
716
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
717
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
718
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
719
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
720
+ inputs_embeds = (
721
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
722
+ if inputs_embeds is not None
723
+ else None
724
+ )
725
+
726
+ outputs = self.bert(
727
+ input_ids,
728
+ attention_mask=attention_mask,
729
+ token_type_ids=token_type_ids,
730
+ position_ids=position_ids,
731
+ head_mask=head_mask,
732
+ inputs_embeds=inputs_embeds,
733
+ output_attentions=output_attentions,
734
+ output_hidden_states=output_hidden_states,
735
+ return_dict=return_dict,
736
+ )
737
+
738
+ token_embeddings = outputs.last_hidden_state
739
+ token_embeddings = self.pure(token_embeddings, attention_mask)
740
+ pooled_output = self.mean(token_embeddings).squeeze(1)
741
+ pooled_output = self.dropout(pooled_output)
742
+
743
+ logits = self.classifier(pooled_output)
744
+ reshaped_logits = logits.view(-1, num_choices)
745
+
746
+ loss = None
747
+ if labels is not None:
748
+ loss_fct = nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
749
+ loss = loss_fct(reshaped_logits, labels)
750
+
751
+ if not return_dict:
752
+ output = (reshaped_logits,) + outputs[2:]
753
+ return ((loss,) + output) if loss is not None else output
754
+
755
+ return MultipleChoiceModelOutput(
756
+ loss=loss,
757
+ logits=reshaped_logits,
758
+ hidden_states=outputs.hidden_states,
759
+ attentions=outputs.attentions,
760
+ )
761
+
762
+
763
+ class PureBertForQuestionAnswering(PureBertPreTrainedModel):
764
+
765
+ def __init__(
766
+ self,
767
+ config,
768
+ label_smoothing=0.0,
769
+ ):
770
+ super().__init__(config)
771
+ self.num_labels = config.num_labels
772
+ self.label_smoothing = label_smoothing
773
+
774
+ self.bert = PureBertModel(config, add_pooling_layer=False)
775
+ self.pure = PURE(
776
+ in_dim=config.hidden_size,
777
+ svd_rank=config.svd_rank,
778
+ num_pc_to_remove=config.num_pc_to_remove,
779
+ center=config.center,
780
+ num_iters=config.num_iters,
781
+ alpha=config.alpha,
782
+ disable_pcr=config.disable_pcr,
783
+ disable_pfsa=config.disable_pfsa,
784
+ disable_covariance=config.disable_covariance
785
+ )
786
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
787
+
788
+ # Initialize weights and apply final processing
789
+ self.post_init()
790
+
791
+ def forward(
792
+ self,
793
+ input_ids: Optional[torch.Tensor] = None,
794
+ attention_mask: Optional[torch.Tensor] = None,
795
+ token_type_ids: Optional[torch.Tensor] = None,
796
+ position_ids: Optional[torch.Tensor] = None,
797
+ head_mask: Optional[torch.Tensor] = None,
798
+ inputs_embeds: Optional[torch.Tensor] = None,
799
+ start_positions: Optional[torch.Tensor] = None,
800
+ end_positions: Optional[torch.Tensor] = None,
801
+ output_attentions: Optional[bool] = None,
802
+ output_hidden_states: Optional[bool] = None,
803
+ return_dict: Optional[bool] = None,
804
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
805
+ r"""
806
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
807
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
808
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
809
+ are not taken into account for computing the loss.
810
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
811
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
812
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
813
+ are not taken into account for computing the loss.
814
+ """
815
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
816
+
817
+ outputs = self.bert(
818
+ input_ids,
819
+ attention_mask=attention_mask,
820
+ token_type_ids=token_type_ids,
821
+ position_ids=position_ids,
822
+ head_mask=head_mask,
823
+ inputs_embeds=inputs_embeds,
824
+ output_attentions=output_attentions,
825
+ output_hidden_states=output_hidden_states,
826
+ return_dict=return_dict,
827
+ )
828
+
829
+ token_embeddings = outputs.last_hidden_state
830
+ sequence_output = self.pure(token_embeddings, attention_mask)
831
+
832
+ logits = self.qa_outputs(sequence_output)
833
+ start_logits, end_logits = logits.split(1, dim=-1)
834
+ start_logits = start_logits.squeeze(-1).contiguous()
835
+ end_logits = end_logits.squeeze(-1).contiguous()
836
+
837
+ total_loss = None
838
+ if start_positions is not None and end_positions is not None:
839
+ # If we are on multi-GPU, split add a dimension
840
+ if len(start_positions.size()) > 1:
841
+ start_positions = start_positions.squeeze(-1)
842
+ if len(end_positions.size()) > 1:
843
+ end_positions = end_positions.squeeze(-1)
844
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
845
+ ignored_index = start_logits.size(1)
846
+ start_positions = start_positions.clamp(0, ignored_index)
847
+ end_positions = end_positions.clamp(0, ignored_index)
848
+
849
+ loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
850
+ start_loss = loss_fct(start_logits, start_positions)
851
+ end_loss = loss_fct(end_logits, end_positions)
852
+ total_loss = (start_loss + end_loss) / 2
853
+
854
+ if not return_dict:
855
+ output = (start_logits, end_logits) + outputs[2:]
856
+ return ((total_loss,) + output) if total_loss is not None else output
857
+
858
+ return QuestionAnsweringModelOutput(
859
+ loss=total_loss,
860
+ start_logits=start_logits,
861
+ end_logits=end_logits,
862
+ hidden_states=outputs.hidden_states,
863
+ attentions=outputs.attentions,
864
+ )
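
To make the PURE block defined above concrete, here is a shape-level sketch of principal-component removal followed by PFSA re-weighting on a padded batch. It assumes modeling_pure_bert is importable as a module (the file's relative import of configuration_pure_bert means this runs inside the repository package, or with that import adjusted); the batch size, hidden size, and svd_rank below are illustrative values only:

import torch
from modeling_pure_bert import PURE  # assumes the repository package is importable

# Defaults as in this file: PCR and PFSA enabled, the covariance branch disabled.
pure = PURE(in_dim=8, svd_rank=4, num_pc_to_remove=1)

embeddings = torch.randn(2, 6, 8)                    # [B, T, F] token embeddings
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],   # 4 valid tokens in the first sample
                               [1, 1, 1, 1, 1, 1]])  # 6 valid tokens in the second

out = pure(embeddings, attention_mask)  # remove the top principal component per sample,
                                        # then apply PFSA re-weighting
print(out.shape)                        # torch.Size([2, 6, 8])

PureBertForSequenceClassification then mean-pools the re-weighted token embeddings (StatisticsPooling with return_mean=True) before the dropout and linear classifier.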