jaandoui committed on
Commit 212e84a
1 Parent(s): e1d8da4

Update bert_layers.py

Files changed (1)
  1. bert_layers.py +912 -912
bert_layers.py CHANGED
@@ -1,912 +1,912 @@
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0

# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, Tri Dao.

import copy
import logging
import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (MaskedLMOutput,
                                           SequenceClassifierOutput)
from transformers.models.bert.modeling_bert import BertPreTrainedModel
from transformers.modeling_utils import PreTrainedModel

from .bert_padding import (index_first_axis,
                           index_put_first_axis, pad_input,
                           unpad_input, unpad_input_only)

try:
    from .flash_attn_triton import flash_attn_qkvpacked_func
except ImportError as e:
    flash_attn_qkvpacked_func = None

logger = logging.getLogger(__name__)


class BertEmbeddings(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size,
                                            config.hidden_size,
                                            padding_idx=config.pad_token_id)
        # ALiBi doesn't use position embeddings
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
                                                  config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model
        # variable name and be able to load any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.register_buffer('token_type_ids',
                             torch.zeros(config.max_position_embeddings,
                                         dtype=torch.long),
                             persistent=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        if (input_ids is not None) == (inputs_embeds is not None):
            raise ValueError('Must specify either input_ids or input_embeds!')
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            assert inputs_embeds is not None  # just for type checking
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            # great! ALiBi
            pass

        # Setting the token_type_ids to the registered buffer in constructor
        # where it is all zeros, which usually occurs when it's auto-generated;
        # registered buffer helps users when tracing the model without passing
        # token_type_ids, solves issue #5664
        if token_type_ids is None:
            if hasattr(self, 'token_type_ids'):
                assert isinstance(self.token_type_ids, torch.LongTensor)
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
                    input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded  # type: ignore
            else:
                token_type_ids = torch.zeros(input_shape,  # type: ignore
                                             dtype=torch.long,
                                             device=self.word_embeddings.device)  # type: ignore # yapf: disable

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        # no position embeddings! ALiBi
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertUnpadSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
                config, 'embedding_size'):
            raise ValueError(
                f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention '
                f'heads ({config.num_attention_heads})')

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size /
                                       config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.p_dropout = config.attention_probs_dropout_prob
        self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)

        # Warn if defaulting to pytorch because of import issues
        if flash_attn_qkvpacked_func is None:
            warnings.warn(
                'Unable to import Triton; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model).'
            )

    def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                max_seqlen_in_batch: int, indices: torch.Tensor,
                attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        """Perform self-attention.

        If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
        implementation of self-attention.

        The arguments are unpadded, and our implementations of attention require padded arguments,
        so we first call `pad_input`. Once we compute attention, we re-unpad our outputs for the other layers.
        The pad/unpad operations add overhead, but not sending pad tokens through ffs saves compute.
        It is possible to write an unpadded implementation of attention (in Triton and PyTorch), which we will eventually do.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_seqlen_in_batch: int
            indices: (total_nnz,)
            attn_mask: (batch, max_seqlen_in_batch)
            bias: (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)

        Returns:
            attention: (total_nnz, dim)
        """
        qkv = self.Wqkv(hidden_states)
        qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1,
                        max_seqlen_in_batch)  # batch, max_seqlen_in_batch, thd
        qkv = rearrange(qkv,
                        'b s (t h d) -> b s t h d',
                        t=3,
                        h=self.num_attention_heads)
        if self.p_dropout or flash_attn_qkvpacked_func is None:
            # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
            q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3)  # b h s d
            k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1)  # b h d s
            v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3)  # b h s d
            attention_scores = torch.matmul(q, k) / math.sqrt(
                self.attention_head_size)
            attention_scores = attention_scores + bias
            attention_probs = nn.functional.softmax(attention_scores, dim=-1)
            attention_probs = self.dropout(attention_probs)
            attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
                                                                 3)  # b s h d
        else:
            # Triton implementation only supports 0 attention dropout
            convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
            if convert_dtype:
                # Triton implementation only supports fp16 and bf16
                orig_dtype = qkv.dtype
                qkv = qkv.to(torch.float16)
                bias_dtype = bias.dtype
                bias = bias.to(torch.float16)
                attention = flash_attn_qkvpacked_func(qkv, bias)
                attention = attention.to(orig_dtype)
                bias = bias.to(bias_dtype)
            else:
                attention = flash_attn_qkvpacked_func(qkv, bias)

        # attn_mask is 1 for attend and 0 for don't
        attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
        return rearrange(attention, 'nnz h d -> nnz (h d)')


# Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
class BertSelfOutput(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor,
                input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertUnpadAttention(nn.Module):
    """Chains attention, Dropout, and LayerNorm for Mosaic BERT."""

    def __init__(self, config):
        super().__init__()
        self.self = BertUnpadSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(
        self,
        input_tensor: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_s: int,
        subset_idx: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for scaled self-attention without padding.

        Arguments:
            input_tensor: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_s: int
            subset_idx: () set of indices whose values we care about at the end of the layer
                (e.g., the masked tokens, if this is the final layer).
            indices: None or (total_nnz,)
            attn_mask: None or (batch, max_seqlen_in_batch)
            bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
        """
        self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
                                attn_mask, bias)
        if subset_idx is not None:
            return self.output(index_first_axis(self_output, subset_idx),
                               index_first_axis(input_tensor, subset_idx))
        else:
            return self.output(self_output, input_tensor)


class BertGatedLinearUnitMLP(nn.Module):
    """Applies the FFN at the end of each Mosaic BERT layer.

    Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
    and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality, but
    introduces Gated Linear Units.

    Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep parameter count consistent with that of a
    standard Hugging Face BERT, scale down `config.intermediate_size` by 2/3. For example, a Mosaic BERT constructed with
    `config.intermediate_size=2048` will have the same parameter footprint as its Hugging Face BERT counterpart constructed
    with the `config.intermediate_size=3072`.
    However, in most cases it will not be necessary to adjust `config.intermediate_size` since, despite the increased
    parameter size, Mosaic BERT typically offers a net higher throughput than a Hugging Face BERT built from the same `config`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gated_layers = nn.Linear(config.hidden_size,
                                      config.intermediate_size * 2,
                                      bias=False)
        self.act = nn.GELU(approximate='none')
        self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute new hidden states from current hidden states.

        Args:
            hidden_states (torch.Tensor): The (unpadded) hidden states from
                the attention layer [nnz, dim].
        """
        residual_connection = hidden_states
        # compute the activation
        hidden_states = self.gated_layers(hidden_states)
        gated = hidden_states[:, :self.config.intermediate_size]
        non_gated = hidden_states[:, self.config.intermediate_size:]
        hidden_states = self.act(gated) * non_gated
        hidden_states = self.dropout(hidden_states)
        # multiply by the second matrix
        hidden_states = self.wo(hidden_states)
        # add the residual connection and post-LN
        hidden_states = self.layernorm(hidden_states + residual_connection)
        return hidden_states


class BertLayer(nn.Module):
    """Composes the Mosaic BERT attention and FFN blocks into a single layer."""

    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertUnpadAttention(config)
        self.mlp = BertGatedLinearUnitMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        seqlen: int,
        subset_idx: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for a BERT layer, including both attention and MLP.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            seqlen: int
            subset_idx: () set of indices whose values we care about at the end of the layer
                (e.g., the masked tokens, if this is the final layer).
            indices: None or (total_nnz,)
            attn_mask: None or (batch, max_seqlen_in_batch)
            bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
        """
        attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
                                          subset_idx, indices, attn_mask, bias)
        layer_output = self.mlp(attention_output)
        return layer_output


class BertEncoder(nn.Module):
    """A stack of BERT layers providing the backbone of Mosaic BERT.

    This module is modeled after the Hugging Face BERT's :class:`~transformers.model.bert.modeling_bert.BertEncoder`,
    but with substantial modifications to implement unpadding and ALiBi.

    Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation
    at padded tokens, and pre-computes attention biases to implement ALiBi.
    """

    def __init__(self, config):
        super().__init__()
        layer = BertLayer(config)
        self.layer = nn.ModuleList(
            [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

        self.num_attention_heads = config.num_attention_heads

        # The alibi mask will be dynamically expanded if it is too small for
        # the input the model receives. But it generally helps to initialize it
        # to a reasonably large size to help pre-allocate CUDA memory.
        # The default `alibi_starting_size` is 512.
        self._current_alibi_size = int(config.alibi_starting_size)
        self.alibi = torch.zeros(
            (1, self.num_attention_heads, self._current_alibi_size,
             self._current_alibi_size))
        self.rebuild_alibi_tensor(size=config.alibi_starting_size)

    def rebuild_alibi_tensor(self,
                             size: int,
                             device: Optional[Union[torch.device, str]] = None):
        # Alibi
        # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
        # In the causal case, you can exploit the fact that softmax is invariant to a uniform translation
        # of the logits, which makes the math work out *after* applying causal masking. If no causal masking
        # will be applied, it is necessary to construct the diagonal mask.
        n_heads = self.num_attention_heads

        def _get_alibi_head_slopes(n_heads: int) -> List[float]:

            def get_slopes_power_of_2(n_heads: int) -> List[float]:
                start = (2**(-2**-(math.log2(n_heads) - 3)))
                ratio = start
                return [start * ratio**i for i in range(n_heads)]

            # In the paper, they only train models that have 2^a heads for some a. This function
            # has some good properties that only occur when the input is a power of 2. To
            # maintain that even when the number of heads is not a power of 2, we use a
            # workaround.
            if math.log2(n_heads).is_integer():
                return get_slopes_power_of_2(n_heads)

            closest_power_of_2 = 2**math.floor(math.log2(n_heads))
            slopes_a = get_slopes_power_of_2(closest_power_of_2)
            slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2)
            slopes_b = slopes_b[0::2][:n_heads - closest_power_of_2]
            return slopes_a + slopes_b

        context_position = torch.arange(size, device=device)[:, None]
        memory_position = torch.arange(size, device=device)[None, :]
        relative_position = torch.abs(memory_position - context_position)
        # [n_heads, max_token_length, max_token_length]
        relative_position = relative_position.unsqueeze(0).expand(
            n_heads, -1, -1)
        slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device)
        alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position
        # [1, n_heads, max_token_length, max_token_length]
        alibi = alibi.unsqueeze(0)
        assert alibi.shape == torch.Size([1, n_heads, size, size])

        self._current_alibi_size = size
        self.alibi = alibi

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_all_encoded_layers: Optional[bool] = True,
        subset_mask: Optional[torch.Tensor] = None,
    ) -> List[torch.Tensor]:

        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(
            dtype=torch.float32)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        attention_mask_bool = attention_mask.bool()
        batch, seqlen = hidden_states.shape[:2]
        # Unpad inputs and mask. It will remove tokens that are padded.
        # Assume ntokens is total number of tokens (padded and non-padded)
        # and ntokens_unpad is total number of non-padded tokens.
        # Then unpadding performs the following compression of the inputs:
        # hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden]
        hidden_states, indices, cu_seqlens, _ = unpad_input(
            hidden_states, attention_mask_bool)

        # Add alibi matrix to extended_attention_mask
        if self._current_alibi_size < seqlen:
            # Rebuild the alibi tensor when needed
            warnings.warn(
                f'Increasing alibi size from {self._current_alibi_size} to {seqlen}'
            )
            self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device)
        elif self.alibi.device != hidden_states.device:
            # Device catch-up
            self.alibi = self.alibi.to(hidden_states.device)
        alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
        attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
        alibi_attn_mask = attn_bias + alibi_bias

        all_encoder_layers = []
        if subset_mask is None:
            for layer_module in self.layer:
                hidden_states = layer_module(hidden_states,
                                             cu_seqlens,
                                             seqlen,
                                             None,
                                             indices,
                                             attn_mask=attention_mask,
                                             bias=alibi_attn_mask)
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)
            # Pad inputs and mask. It will insert back zero-padded tokens.
            # Assume ntokens is total number of tokens (padded and non-padded)
            # and ntokens_unpad is total number of non-padded tokens.
            # Then padding performs the following de-compression:
            # hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden]
            hidden_states = pad_input(hidden_states, indices, batch, seqlen)
        else:
            for i in range(len(self.layer) - 1):
                layer_module = self.layer[i]
                hidden_states = layer_module(hidden_states,
                                             cu_seqlens,
                                             seqlen,
                                             None,
                                             indices,
                                             attn_mask=attention_mask,
                                             bias=alibi_attn_mask)
                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)
            subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
                                       as_tuple=False).flatten()
            hidden_states = self.layer[-1](hidden_states,
                                           cu_seqlens,
                                           seqlen,
                                           subset_idx=subset_idx,
                                           indices=indices,
                                           attn_mask=attention_mask,
                                           bias=alibi_attn_mask)

        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers


class BertPooler(nn.Module):

    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self,
                hidden_states: torch.Tensor,
                pool: Optional[bool] = True) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0] if pool else hidden_states
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertModel(BertPreTrainedModel):
    """Overall BERT model.

    Args:
        config: a BertConfig class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
              of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
              encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
              to the last attention block of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated to the first character of the
            input (`CLS`) to train on the Next-Sentence task (see BERT's paper).

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                                 num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    model = BertModel(config=config)
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """

    def __init__(self, config, add_pooling_layer=True):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_all_encoded_layers: Optional[bool] = False,
        masked_tokens_mask: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        embedding_output = self.embeddings(input_ids, token_type_ids,
                                           position_ids)

        subset_mask = []
        first_col_mask = []

        if masked_tokens_mask is None:
            subset_mask = None
        else:
            first_col_mask = torch.zeros_like(masked_tokens_mask)
            first_col_mask[:, 0] = True
            subset_mask = masked_tokens_mask | first_col_mask

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask,
            output_all_encoded_layers=output_all_encoded_layers,
            subset_mask=subset_mask)

        if masked_tokens_mask is None:
            sequence_output = encoder_outputs[-1]
            pooled_output = self.pooler(
                sequence_output) if self.pooler is not None else None
        else:
            # TD [2022-03-01]: the indexing here is very tricky.
            attention_mask_bool = attention_mask.bool()
            subset_idx = subset_mask[attention_mask_bool]  # type: ignore
            sequence_output = encoder_outputs[-1][
                masked_tokens_mask[attention_mask_bool][subset_idx]]
            if self.pooler is not None:
                pool_input = encoder_outputs[-1][
                    first_col_mask[attention_mask_bool][subset_idx]]
                pooled_output = self.pooler(pool_input, pool=False)
            else:
                pooled_output = None

        if not output_all_encoded_layers:
            encoder_outputs = sequence_output

        if self.pooler is not None:
            return encoder_outputs, pooled_output

        return encoder_outputs, None


###################
# Bert Heads
###################
class BertLMPredictionHead(nn.Module):

    def __init__(self, config, bert_model_embedding_weights):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                 bert_model_embedding_weights.size(0))
        self.decoder.weight = bert_model_embedding_weights

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertOnlyMLMHead(nn.Module):

    def __init__(self, config, bert_model_embedding_weights):
        super().__init__()
        self.predictions = BertLMPredictionHead(config,
                                                bert_model_embedding_weights)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertOnlyNSPHead(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


class BertForMaskedLM(BertPreTrainedModel):

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            warnings.warn(
                'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
                'bi-directional self-attention.')

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config,
                                   self.bert.embeddings.word_embeddings.weight)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        # labels should be a `torch.LongTensor` of shape
        # `(batch_size, sequence_length)`. These are used for computing the
        # masked language modeling loss.
        #
        # Indices should be in `[-100, 0, ..., config.vocab_size]` (see
        # `input_ids` docstring) Tokens with indices set to `-100` are ignored
        # (masked), the loss is only computed for the tokens with labels in `[0,
        # ..., config.vocab_size]`
        #
        # Prediction scores are only computed for masked tokens and the (bs,
        # seqlen) dimensions are flattened
        if (input_ids is not None) == (inputs_embeds is not None):
            raise ValueError('Must specify either input_ids or input_embeds!')

        if labels is None:
            masked_tokens_mask = None
        else:
            masked_tokens_mask = labels > 0

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            masked_tokens_mask=masked_tokens_mask,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        loss = None
        if labels is not None:
            # Compute loss
            loss_fct = nn.CrossEntropyLoss()
            masked_token_idx = torch.nonzero(labels.flatten() > 0,
                                             as_tuple=False).flatten()
            loss = loss_fct(prediction_scores,
                            labels.flatten()[masked_token_idx])

            assert input_ids is not None, 'Coding error; please open an issue'
            batch, seqlen = input_ids.shape[:2]
            prediction_scores = rearrange(index_put_first_axis(
                prediction_scores, masked_token_idx, batch * seqlen),
                                          '(b s) d -> b s d',
                                          b=batch)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=outputs[0],
            attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
                                      attention_mask: torch.Tensor,
                                      **model_kwargs):
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        # add a dummy token
        if self.config.pad_token_id is None:
            raise ValueError('The PAD token should be defined for generation')

        attention_mask = torch.cat([
            attention_mask,
            attention_mask.new_zeros((attention_mask.shape[0], 1))
        ],
                                   dim=-1)
        dummy_token = torch.full((effective_batch_size, 1),
                                 self.config.pad_token_id,
                                 dtype=torch.long,
                                 device=input_ids.device)
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {'input_ids': input_ids, 'attention_mask': attention_mask}


class BertForSequenceClassification(BertPreTrainedModel):
    """Bert Model transformer with a sequence classification/regression head.

    This head is just a linear layer on top of the pooled output. Used for,
    e.g., GLUE tasks.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        classifier_dropout = (config.classifier_dropout
                              if config.classifier_dropout is not None else
                              config.hidden_dropout_prob)
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        # Labels for computing the sequence classification/regression loss.
        # Indices should be in `[0, ..., config.num_labels - 1]`.
        # If `config.num_labels == 1` a regression loss is computed
        # (mean-square loss). If `config.num_labels > 1` a classification loss
        # is computed (cross-entropy).

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Compute loss
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = 'regression'
                elif self.num_labels > 1 and (labels.dtype == torch.long or
                                              labels.dtype == torch.int):
                    self.config.problem_type = 'single_label_classification'
                else:
                    self.config.problem_type = 'multi_label_classification'

            if self.config.problem_type == 'regression':
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == 'single_label_classification':
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels),
                                labels.view(-1))
            elif self.config.problem_type == 'multi_label_classification':
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs[0],
            attentions=None,
        )
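The docstring of `BertGatedLinearUnitMLP` above claims that a Mosaic BERT built with `config.intermediate_size=2048` has roughly the same parameter footprint as a standard Hugging Face BERT built with `config.intermediate_size=3072`. The following standalone sketch is not part of `bert_layers.py`; it only mirrors the layer shapes defined above (assuming `hidden_size=768`) to sanity-check that claim:

```python
# Back-of-the-envelope parameter comparison for the GLU FFN note above.
# Illustrative only; mirrors the module shapes in bert_layers.py but is not part of the file.
import torch.nn as nn

hidden_size = 768

def count_params(*modules) -> int:
    return sum(p.numel() for m in modules for p in m.parameters())

# Mosaic BERT GLU FFN (see BertGatedLinearUnitMLP), intermediate_size = 2048
glu_ffn = count_params(
    nn.Linear(hidden_size, 2 * 2048, bias=False),  # gated_layers
    nn.Linear(2048, hidden_size),                   # wo
    nn.LayerNorm(hidden_size),                      # layernorm
)

# Standard Hugging Face BERT FFN (BertIntermediate + BertOutput), intermediate_size = 3072
std_ffn = count_params(
    nn.Linear(hidden_size, 3072),                   # intermediate.dense
    nn.Linear(3072, hidden_size),                   # output.dense
    nn.LayerNorm(hidden_size),                      # output.LayerNorm
)

print(glu_ffn, std_ffn)  # 4720896 vs 4723968 per encoder layer
```

The weight matrices match exactly (768·4096 + 2048·768 = 2·768·3072); the two counts differ only by the 3072-element bias that `gated_layers` omits via `bias=False`, roughly 0.07% of the block.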
 
1
+ # Copyright 2022 MosaicML Examples authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
5
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
6
+ # Copyright (c) 2022, Tri Dao.
7
+
8
+ import copy
9
+ import logging
10
+ import math
11
+ import warnings
12
+ from typing import List, Optional, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from einops import rearrange
17
+ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
18
+ from transformers.activations import ACT2FN
19
+ from transformers.modeling_outputs import (MaskedLMOutput,
20
+ SequenceClassifierOutput)
21
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel
22
+ from transformers.modeling_utils import PreTrainedModel
23
+
24
+ from .bert_padding import (index_first_axis,
25
+ index_put_first_axis, pad_input,
26
+ unpad_input, unpad_input_only)
27
+
28
+ try:
29
+ from .flash_attn_triton import flash_attn_qkvpacked_func
30
+ except ImportError as e:
31
+ flash_attn_qkvpacked_func = None
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ class BertEmbeddings(nn.Module):
37
+
38
+ def __init__(self, config):
39
+ super().__init__()
40
+ self.word_embeddings = nn.Embedding(config.vocab_size,
41
+ config.hidden_size,
42
+ padding_idx=config.pad_token_id)
43
+ # ALiBi doesn't use position embeddings
44
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
45
+ config.hidden_size)
46
+
47
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model
48
+ # variable name and be able to load any TensorFlow checkpoint file
49
+ self.LayerNorm = nn.LayerNorm(config.hidden_size,
50
+ eps=config.layer_norm_eps)
51
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
52
+ self.register_buffer('token_type_ids',
53
+ torch.zeros(config.max_position_embeddings,
54
+ dtype=torch.long),
55
+ persistent=False)
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: Optional[torch.LongTensor] = None,
60
+ token_type_ids: Optional[torch.LongTensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ inputs_embeds: Optional[torch.FloatTensor] = None,
63
+ past_key_values_length: int = 0,
64
+ ) -> torch.Tensor:
65
+ if (input_ids is not None) == (inputs_embeds is not None):
66
+ raise ValueError('Must specify either input_ids or input_embeds!')
67
+ if input_ids is not None:
68
+ input_shape = input_ids.size()
69
+ else:
70
+ assert inputs_embeds is not None # just for type checking
71
+ input_shape = inputs_embeds.size()[:-1]
72
+
73
+ seq_length = input_shape[1]
74
+
75
+ if position_ids is None:
76
+ # great! ALiBi
77
+ pass
78
+
79
+ # Setting the token_type_ids to the registered buffer in constructor
80
+ # where it is all zeros, which usually occurs when it's auto-generated;
81
+ # registered buffer helps users when tracing the model without passing
82
+ # token_type_ids, solves issue #5664
83
+ if token_type_ids is None:
84
+ if hasattr(self, 'token_type_ids'):
85
+ assert isinstance(self.token_type_ids, torch.LongTensor)
86
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
87
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
88
+ input_shape[0], seq_length)
89
+ token_type_ids = buffered_token_type_ids_expanded # type: ignore
90
+ else:
91
+ token_type_ids = torch.zeros(input_shape, # type: ignore
92
+ dtype=torch.long,
93
+ device=self.word_embeddings.device) # type: ignore # yapf: disable
94
+
95
+ if inputs_embeds is None:
96
+ inputs_embeds = self.word_embeddings(input_ids)
97
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
98
+
99
+ embeddings = inputs_embeds + token_type_embeddings
100
+ # no position embeddings! ALiBi
101
+ embeddings = self.LayerNorm(embeddings)
102
+ embeddings = self.dropout(embeddings)
103
+ return embeddings
104
+
105
+
106
+ class BertUnpadSelfAttention(nn.Module):
107
+
108
+ def __init__(self, config):
109
+ super().__init__()
110
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
111
+ config, 'embedding_size'):
112
+ raise ValueError(
113
+ f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention '
114
+ f'heads ({config.num_attention_heads})')
115
+
116
+ self.num_attention_heads = config.num_attention_heads
117
+ self.attention_head_size = int(config.hidden_size /
118
+ config.num_attention_heads)
119
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
120
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
121
+ self.p_dropout = config.attention_probs_dropout_prob
122
+ self.Wqkv = nn.Linear(self.all_head_size, 3 * config.hidden_size)
123
+
124
+ # Warn if defaulting to pytorch because of import issues
125
+ if flash_attn_qkvpacked_func is None:
126
+ warnings.warn(
127
+ 'Unable to import Triton; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model).'
128
+ )
129
+
130
+ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
131
+ max_seqlen_in_batch: int, indices: torch.Tensor,
132
+ attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
133
+ """Perform self-attention.
134
+
135
+ If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
136
+ implementation of self-attention.
137
+
138
+ The arguments are unpadded, and our implementations of attention require padded arguments,
139
+ so we first call `pad_input`. Once we compute attention, we re-unpad our outputs for the other layers.
140
+ The pad/unpad operations add overhead, but not sending pad tokens through ffs saves compute.
141
+ It is possible to write an unpadded implementation of attention (in Triton and PyTorch), which we will eventually do.
142
+
143
+ Args:
144
+ hidden_states: (total_nnz, dim)
145
+ cu_seqlens: (batch + 1,)
146
+ max_seqlen_in_batch: int
147
+ indices: (total_nnz,)
148
+ attn_mask: (batch, max_seqlen_in_batch)
149
+ bias: (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
150
+
151
+ Returns:
152
+ attention: (total_nnz, dim)
153
+ """
154
+ qkv = self.Wqkv(hidden_states)
155
+ qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1,
156
+ max_seqlen_in_batch) # batch, max_seqlen_in_batch, thd
157
+ qkv = rearrange(qkv,
158
+ 'b s (t h d) -> b s t h d',
159
+ t=3,
160
+ h=self.num_attention_heads)
161
+ if self.p_dropout or flash_attn_qkvpacked_func is None:
162
+ # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
163
+ q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
164
+ k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
165
+ v = qkv[:, :, 2, :, :].permute(0, 2, 1, 3) # b h s d
166
+ attention_scores = torch.matmul(q, k) / math.sqrt(
167
+ self.attention_head_size)
168
+ attention_scores = attention_scores + bias
169
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
170
+ attention_probs = self.dropout(attention_probs)
171
+ attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
172
+ 3) # b s h d
173
+ else:
174
+ # Triton implementation only supports 0 attention dropout
175
+ convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
176
+ if convert_dtype:
177
+ # Triton implementation only supports fp16 and bf16
178
+ orig_dtype = qkv.dtype
179
+ qkv = qkv.to(torch.float16)
180
+ bias_dtype = bias.dtype
181
+ bias = bias.to(torch.float16)
182
+ attention = flash_attn_qkvpacked_func(qkv, bias)
183
+ attention = attention.to(orig_dtype)
184
+ bias = bias.to(bias_dtype)
185
+ else:
186
+ attention = flash_attn_qkvpacked_func(qkv, bias)
187
+
188
+ # attn_mask is 1 for attend and 0 for don't
189
+ attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
190
+ return rearrange(attention, 'nnz h d -> nnz (h d)')
191
+
192
+
193
+ # Copy of transformer's library BertSelfOutput that will not be caught by surgery methods looking for HF BERT modules.
194
+ class BertSelfOutput(nn.Module):
195
+
196
+ def __init__(self, config):
197
+ super().__init__()
198
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
199
+ self.LayerNorm = nn.LayerNorm(config.hidden_size,
200
+ eps=config.layer_norm_eps)
201
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
202
+
203
+ def forward(self, hidden_states: torch.Tensor,
204
+ input_tensor: torch.Tensor) -> torch.Tensor:
205
+ hidden_states = self.dense(hidden_states)
206
+ hidden_states = self.dropout(hidden_states)
207
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
208
+ return hidden_states
209
+
210
+
211
+ class BertUnpadAttention(nn.Module):
212
+ """Chains attention, Dropout, and LayerNorm for Mosaic BERT."""
213
+
214
+ def __init__(self, config):
215
+ super().__init__()
216
+ self.self = BertUnpadSelfAttention(config)
217
+ self.output = BertSelfOutput(config)
218
+
219
+ def forward(
220
+ self,
221
+ input_tensor: torch.Tensor,
222
+ cu_seqlens: torch.Tensor,
223
+ max_s: int,
224
+ subset_idx: Optional[torch.Tensor] = None,
225
+ indices: Optional[torch.Tensor] = None,
226
+ attn_mask: Optional[torch.Tensor] = None,
227
+ bias: Optional[torch.Tensor] = None,
228
+ ) -> torch.Tensor:
229
+ """Forward pass for scaled self-attention without padding.
230
+
231
+ Arguments:
232
+ input_tensor: (total_nnz, dim)
233
+ cu_seqlens: (batch + 1,)
234
+ max_s: int
235
+ subset_idx: () set of indices whose values we care about at the end of the layer
236
+ (e.g., the masked tokens, if this is the final layer).
237
+ indices: None or (total_nnz,)
238
+ attn_mask: None or (batch, max_seqlen_in_batch)
239
+ bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
240
+ """
241
+ self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
242
+ attn_mask, bias)
243
+ if subset_idx is not None:
244
+ return self.output(index_first_axis(self_output, subset_idx),
245
+ index_first_axis(input_tensor, subset_idx))
246
+ else:
247
+ return self.output(self_output, input_tensor)
248
+
249
+
250
+ class BertGatedLinearUnitMLP(nn.Module):
251
+ """Applies the FFN at the end of each Mosaic BERT layer.
252
+
253
+ Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
254
+ and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality, but
255
+ introduces Gated Linear Units.
256
+
257
+ Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep parameter count consistent with that of a
258
+ standard Hugging Face BERT, scale down `config.intermediate_size` by 2/3. For example, a Mosaic BERT constructed with
259
+ `config.intermediate_size=2048` will have the same parameter footprint as its Hugging Face BERT counterpart constructed
260
+ with the `config.intermediate_size=3072`.
261
+ However, in most cases it will not be necessary to adjust `config.intermediate_size` since, despite the increased
262
+ parameter size, Mosaic BERT typically offers a net higher throughput than a Hugging Face BERT built from the same `config`.
263
+ """
264
+
265
+ def __init__(self, config):
266
+ super().__init__()
267
+ self.config = config
268
+ self.gated_layers = nn.Linear(config.hidden_size,
269
+ config.intermediate_size * 2,
270
+ bias=False)
271
+ self.act = nn.GELU(approximate='none')
272
+ self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
273
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
274
+ self.layernorm = nn.LayerNorm(config.hidden_size,
275
+ eps=config.layer_norm_eps)
276
+
277
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
278
+ """Compute new hidden states from current hidden states.
279
+
280
+ Args:
281
+ hidden_states (torch.Tensor): The (unpadded) hidden states from
282
+ the attention layer [nnz, dim].
283
+ """
284
+ residual_connection = hidden_states
285
+ # compute the activation
286
+ hidden_states = self.gated_layers(hidden_states)
287
+ gated = hidden_states[:, :self.config.intermediate_size]
288
+ non_gated = hidden_states[:, self.config.intermediate_size:]
289
+ hidden_states = self.act(gated) * non_gated
290
+ hidden_states = self.dropout(hidden_states)
291
+ # multiply by the second matrix
292
+ hidden_states = self.wo(hidden_states)
293
+ # add the residual connection and post-LN
294
+ hidden_states = self.layernorm(hidden_states + residual_connection)
295
+ return hidden_states
296
+
297
+
298
+ class BertLayer(nn.Module):
299
+ """Composes the Mosaic BERT attention and FFN blocks into a single layer."""
300
+
301
+ def __init__(self, config):
302
+ super(BertLayer, self).__init__()
303
+ self.attention = BertUnpadAttention(config)
304
+ self.mlp = BertGatedLinearUnitMLP(config)
305
+
306
+ def forward(
307
+ self,
308
+ hidden_states: torch.Tensor,
309
+ cu_seqlens: torch.Tensor,
310
+ seqlen: int,
311
+ subset_idx: Optional[torch.Tensor] = None,
312
+ indices: Optional[torch.Tensor] = None,
313
+ attn_mask: Optional[torch.Tensor] = None,
314
+ bias: Optional[torch.Tensor] = None,
315
+ ) -> torch.Tensor:
316
+ """Forward pass for a BERT layer, including both attention and MLP.
317
+
318
+ Args:
319
+ hidden_states: (total_nnz, dim)
320
+ cu_seqlens: (batch + 1,)
321
+ seqlen: int
322
+ subset_idx: () set of indices whose values we care about at the end of the layer
323
+ (e.g., the masked tokens, if this is the final layer).
324
+ indices: None or (total_nnz,)
325
+ attn_mask: None or (batch, max_seqlen_in_batch)
326
+ bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
327
+ """
328
+ attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
329
+ subset_idx, indices, attn_mask, bias)
330
+ layer_output = self.mlp(attention_output)
331
+ return layer_output
332
+
333
+
334
+ class BertEncoder(nn.Module):
335
+ """A stack of BERT layers providing the backbone of Mosaic BERT.
336
+
337
+ This module is modeled after the Hugging Face BERT's :class:`~transformers.model.bert.modeling_bert.BertEncoder`,
338
+ but with substantial modifications to implement unpadding and ALiBi.
339
+
340
+ Compared to the analogous Hugging Face BERT module, this module handles unpadding to reduce unnecessary computation
341
+ at padded tokens, and pre-computes attention biases to implement ALiBi.
342
+ """
343
+
344
+ def __init__(self, config):
345
+ super().__init__()
346
+ layer = BertLayer(config)
347
+ self.layer = nn.ModuleList(
348
+ [copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
349
+
350
+ self.num_attention_heads = config.num_attention_heads
351
+
352
+ # The alibi mask will be dynamically expanded if it is too small for
353
+ # the input the model receives. But it generally helps to initialize it
354
+ # to a reasonably large size to help pre-allocate CUDA memory.
355
+ # The default `alibi_starting_size` is 512.
356
+ self._current_alibi_size = int(config.alibi_starting_size)
357
+ self.alibi = torch.zeros(
358
+ (1, self.num_attention_heads, self._current_alibi_size,
359
+ self._current_alibi_size))
360
+ self.rebuild_alibi_tensor(size=config.alibi_starting_size)
361
+
362
+ def rebuild_alibi_tensor(self,
363
+ size: int,
364
+ device: Optional[Union[torch.device, str]] = None):
365
+ # Alibi
366
+ # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
367
+ # In the causal case, you can exploit the fact that softmax is invariant to a uniform translation
368
+ # of the logits, which makes the math work out *after* applying causal masking. If no causal masking
369
+ # will be applied, it is necessary to construct the diagonal mask.
370
+ n_heads = self.num_attention_heads
371
+
372
+ def _get_alibi_head_slopes(n_heads: int) -> List[float]:
373
+
374
+ def get_slopes_power_of_2(n_heads: int) -> List[float]:
375
+ start = (2**(-2**-(math.log2(n_heads) - 3)))
376
+ ratio = start
377
+ return [start * ratio**i for i in range(n_heads)]
378
+
379
+ # In the paper, they only train models that have 2^a heads for some a. This function
380
+ # has some good properties that only occur when the input is a power of 2. To
381
+ # maintain that even when the number of heads is not a power of 2, we use a
382
+ # workaround.
383
+ if math.log2(n_heads).is_integer():
384
+ return get_slopes_power_of_2(n_heads)
385
+
386
+ closest_power_of_2 = 2**math.floor(math.log2(n_heads))
387
+ slopes_a = get_slopes_power_of_2(closest_power_of_2)
388
+ slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2)
389
+ slopes_b = slopes_b[0::2][:n_heads - closest_power_of_2]
390
+ return slopes_a + slopes_b
391
+
392
+ context_position = torch.arange(size, device=device)[:, None]
393
+ memory_position = torch.arange(size, device=device)[None, :]
394
+ relative_position = torch.abs(memory_position - context_position)
395
+ # [n_heads, max_token_length, max_token_length]
396
+ relative_position = relative_position.unsqueeze(0).expand(
397
+ n_heads, -1, -1)
398
+ slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device)
399
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position
400
+ # [1, n_heads, max_token_length, max_token_length]
401
+ alibi = alibi.unsqueeze(0)
402
+ assert alibi.shape == torch.Size([1, n_heads, size, size])
403
+
404
+ self._current_alibi_size = size
405
+ self.alibi = alibi
406
+
407
+ def forward(
408
+ self,
409
+ hidden_states: torch.Tensor,
410
+ attention_mask: torch.Tensor,
411
+ output_all_encoded_layers: Optional[bool] = True,
412
+ subset_mask: Optional[torch.Tensor] = None,
413
+ ) -> List[torch.Tensor]:
414
+
415
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
416
+ extended_attention_mask = extended_attention_mask.to(
417
+ dtype=torch.float32) # fp16 compatibility
418
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
419
+
420
+ attention_mask_bool = attention_mask.bool()
421
+ batch, seqlen = hidden_states.shape[:2]
422
+ # Unpad inputs and mask. It will remove tokens that are padded.
423
+ # Assume ntokens is total number of tokens (padded and non-padded)
424
+ # and ntokens_unpad is total number of non-padded tokens.
425
+ # Then unpadding performs the following compression of the inputs:
426
+ # hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden]
427
+ hidden_states, indices, cu_seqlens, _ = unpad_input(
428
+ hidden_states, attention_mask_bool)
429
+
430
+ # Add alibi matrix to extended_attention_mask
431
+ if self._current_alibi_size < seqlen:
432
+ # Rebuild the alibi tensor when needed
433
+ warnings.warn(
434
+ f'Increasing alibi size from {self._current_alibi_size} to {seqlen}'
435
+ )
436
+ self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device)
437
+ elif self.alibi.device != hidden_states.device:
438
+ # Device catch-up
439
+ self.alibi = self.alibi.to(hidden_states.device)
440
+ alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
441
+ attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
442
+ alibi_attn_mask = attn_bias + alibi_bias
443
+
444
+ all_encoder_layers = []
445
+ if subset_mask is None:
446
+ for layer_module in self.layer:
447
+ hidden_states = layer_module(hidden_states,
448
+ cu_seqlens,
449
+ seqlen,
450
+ None,
451
+ indices,
452
+ attn_mask=attention_mask,
453
+ bias=alibi_attn_mask)
454
+ if output_all_encoded_layers:
455
+ all_encoder_layers.append(hidden_states)
456
+ # Pad inputs and mask. It will insert back zero-padded tokens.
457
+ # Assume ntokens is total number of tokens (padded and non-padded)
458
+ # and ntokens_unpad is total number of non-padded tokens.
459
+ # Then padding performs the following de-compression:
460
+ # hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden]
461
+ hidden_states = pad_input(hidden_states, indices, batch, seqlen)
462
+ else:
463
+ for i in range(len(self.layer) - 1):
464
+ layer_module = self.layer[i]
465
+ hidden_states = layer_module(hidden_states,
466
+ cu_seqlens,
467
+ seqlen,
468
+ None,
469
+ indices,
470
+ attn_mask=attention_mask,
471
+ bias=alibi_attn_mask)
472
+ if output_all_encoded_layers:
473
+ all_encoder_layers.append(hidden_states)
474
+ subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
475
+ as_tuple=False).flatten()
476
+ hidden_states = self.layer[-1](hidden_states,
477
+ cu_seqlens,
478
+ seqlen,
479
+ subset_idx=subset_idx,
480
+ indices=indices,
481
+ attn_mask=attention_mask,
482
+ bias=alibi_attn_mask)
483
+
484
+ if not output_all_encoded_layers:
485
+ all_encoder_layers.append(hidden_states)
486
+ return all_encoder_layers
487
+
488
+
489
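To make the unpad/pad comments in `BertEncoder.forward` concrete, here is a toy round trip written with plain tensor indexing. It is only a sketch of what `unpad_input`/`pad_input` accomplish, not the library functions themselves:

```python
import torch

batch, seqlen, hidden = 2, 4, 3
hidden_states = torch.randn(batch, seqlen, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.bool)

# Unpad: keep only rows of real tokens -> [ntokens_unpad, hidden]
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
unpadded = hidden_states.reshape(-1, hidden)[indices]
cu_seqlens = torch.nn.functional.pad(attention_mask.sum(-1).cumsum(0), (1, 0))
print(unpadded.shape, cu_seqlens.tolist())   # torch.Size([5, 3]) [0, 3, 5]

# Pad: scatter the rows back into a zero-filled [batch, seqlen, hidden] tensor
repadded = hidden_states.new_zeros(batch * seqlen, hidden)
repadded[indices] = unpadded
repadded = repadded.reshape(batch, seqlen, hidden)
assert torch.equal(repadded[attention_mask], unpadded)
```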
+ class BertPooler(nn.Module):
490
+
491
+ def __init__(self, config):
492
+ super(BertPooler, self).__init__()
493
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
494
+ self.activation = nn.Tanh()
495
+
496
+ def forward(self,
497
+ hidden_states: torch.Tensor,
498
+ pool: Optional[bool] = True) -> torch.Tensor:
499
+ # We "pool" the model by simply taking the hidden state corresponding
500
+ # to the first token.
501
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
502
+ pooled_output = self.dense(first_token_tensor)
503
+ pooled_output = self.activation(pooled_output)
504
+ return pooled_output
505
+
506
+
507
+ class BertPredictionHeadTransform(nn.Module):
508
+
509
+ def __init__(self, config):
510
+ super().__init__()
511
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
512
+ if isinstance(config.hidden_act, str):
513
+ self.transform_act_fn = ACT2FN[config.hidden_act]
514
+ else:
515
+ self.transform_act_fn = config.hidden_act
516
+ self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12)
517
+
518
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
519
+ hidden_states = self.dense(hidden_states)
520
+ hidden_states = self.transform_act_fn(hidden_states)
521
+ hidden_states = self.LayerNorm(hidden_states)
522
+ return hidden_states
523
+
524
+
525
+ class BertModel(BertPreTrainedModel):
526
+ """Overall BERT model.
527
+
528
+ Args:
529
+ config: a BertConfig class instance with the configuration to build a new model
530
+
531
+ Inputs:
532
+ `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
533
+ with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
534
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
535
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
536
+ types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
537
+ a `sentence B` token (see BERT paper for more details).
538
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
539
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
540
+ input sequence length in the current batch. It's the mask that we typically use for attention when
541
+ a batch has varying length sentences.
542
+ `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
543
+
544
+ Outputs: Tuple of (encoded_layers, pooled_output)
545
+ `encoded_layers`: controlled by `output_all_encoded_layers` argument:
546
+ - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
547
+ of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
548
+ encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
549
+ - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
550
+ to the last attention block of shape [batch_size, sequence_length, hidden_size],
551
+ `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
552
+ classifier pretrained on top of the hidden state associated with the first token of the
553
+ input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
554
+
555
+ Example usage:
556
+ ```python
557
+ # Already been converted into WordPiece token ids
558
+ input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
559
+ input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
560
+ token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
561
+ config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
562
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
563
+ model = BertModel(config=config)
564
+ all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
565
+ ```
566
+ """
567
+
568
+ def __init__(self, config, add_pooling_layer=True):
569
+ super(BertModel, self).__init__(config)
570
+ self.embeddings = BertEmbeddings(config)
571
+ self.encoder = BertEncoder(config)
572
+ self.pooler = BertPooler(config) if add_pooling_layer else None
573
+ self.post_init()
574
+
575
+ def get_input_embeddings(self):
576
+ return self.embeddings.word_embeddings
577
+
578
+ def set_input_embeddings(self, value):
579
+ self.embeddings.word_embeddings = value
580
+
581
+ def forward(
582
+ self,
583
+ input_ids: torch.Tensor,
584
+ token_type_ids: Optional[torch.Tensor] = None,
585
+ attention_mask: Optional[torch.Tensor] = None,
586
+ position_ids: Optional[torch.Tensor] = None,
587
+ output_all_encoded_layers: Optional[bool] = False,
588
+ masked_tokens_mask: Optional[torch.Tensor] = None,
589
+ **kwargs
590
+ ) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
591
+ if attention_mask is None:
592
+ attention_mask = torch.ones_like(input_ids)
593
+ if token_type_ids is None:
594
+ token_type_ids = torch.zeros_like(input_ids)
595
+
596
+ embedding_output = self.embeddings(input_ids, token_type_ids,
597
+ position_ids)
598
+
599
+ subset_mask = []
600
+ first_col_mask = []
601
+
602
+ if masked_tokens_mask is None:
603
+ subset_mask = None
604
+ else:
605
+ first_col_mask = torch.zeros_like(masked_tokens_mask)
606
+ first_col_mask[:, 0] = True
607
+ subset_mask = masked_tokens_mask | first_col_mask
608
+
609
+ encoder_outputs = self.encoder(
610
+ embedding_output,
611
+ attention_mask,
612
+ output_all_encoded_layers=output_all_encoded_layers,
613
+ subset_mask=subset_mask)
614
+
615
+ if masked_tokens_mask is None:
616
+ sequence_output = encoder_outputs[-1]
617
+ pooled_output = self.pooler(
618
+ sequence_output) if self.pooler is not None else None
619
+ else:
620
+ # TD [2022-03-01]: the indexing here is very tricky.
621
+ attention_mask_bool = attention_mask.bool()
622
+ subset_idx = subset_mask[attention_mask_bool] # type: ignore
623
+ sequence_output = encoder_outputs[-1][
624
+ masked_tokens_mask[attention_mask_bool][subset_idx]]
625
+ if self.pooler is not None:
626
+ pool_input = encoder_outputs[-1][
627
+ first_col_mask[attention_mask_bool][subset_idx]]
628
+ pooled_output = self.pooler(pool_input, pool=False)
629
+ else:
630
+ pooled_output = None
631
+
632
+ if not output_all_encoded_layers:
633
+ encoder_outputs = sequence_output
634
+
635
+ if self.pooler is not None:
636
+ return encoder_outputs, pooled_output
637
+
638
+ return encoder_outputs, None
639
+
640
+
641
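The masked-token indexing in `BertModel.forward` (flagged above as tricky) is easier to follow with a toy trace of `subset_mask`, `subset_idx` and the final boolean selections. Illustrative values only:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.bool)
masked_tokens_mask = torch.tensor([[0, 1, 0, 0],
                                   [0, 1, 0, 0]], dtype=torch.bool)

first_col_mask = torch.zeros_like(masked_tokens_mask)
first_col_mask[:, 0] = True
subset_mask = masked_tokens_mask | first_col_mask

# The encoder's last layer keeps only unpadded rows where subset_mask is True.
subset_idx = torch.nonzero(subset_mask[attention_mask], as_tuple=False).flatten()
print(subset_idx.tolist())                                      # [0, 1, 3, 4]

# Within those kept rows, split the masked-token rows from the CLS rows.
print(masked_tokens_mask[attention_mask][subset_idx].tolist())  # [False, True, False, True]
print(first_col_mask[attention_mask][subset_idx].tolist())      # [True, False, True, False]
```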
+ ###################
642
+ # Bert Heads
643
+ ###################
644
+ class BertLMPredictionHead(nn.Module):
645
+
646
+ def __init__(self, config, bert_model_embedding_weights):
647
+ super().__init__()
648
+ self.transform = BertPredictionHeadTransform(config)
649
+ # The output weights are the same as the input embeddings, but there is
650
+ # an output-only bias for each token.
651
+ self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
652
+ bert_model_embedding_weights.size(0))
653
+ self.decoder.weight = bert_model_embedding_weights
654
+
655
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
656
+ hidden_states = self.transform(hidden_states)
657
+ hidden_states = self.decoder(hidden_states)
658
+ return hidden_states
659
+
660
+
661
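A minimal sketch of the weight tying performed by `BertLMPredictionHead` (toy sizes, hypothetical variable names): the decoder shares the embedding matrix but keeps its own output-only bias.

```python
import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
decoder = nn.Linear(4, 10)      # creates the output-only bias
decoder.weight = emb.weight     # tie: both modules now hold the same Parameter

assert decoder.weight.data_ptr() == emb.weight.data_ptr()
print(decoder(torch.randn(2, 4)).shape)   # torch.Size([2, 10])
```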
+ class BertOnlyMLMHead(nn.Module):
662
+
663
+ def __init__(self, config, bert_model_embedding_weights):
664
+ super().__init__()
665
+ self.predictions = BertLMPredictionHead(config,
666
+ bert_model_embedding_weights)
667
+
668
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
669
+ prediction_scores = self.predictions(sequence_output)
670
+ return prediction_scores
671
+
672
+
673
+ class BertOnlyNSPHead(nn.Module):
674
+
675
+ def __init__(self, config):
676
+ super().__init__()
677
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
678
+
679
+ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
680
+ seq_relationship_score = self.seq_relationship(pooled_output)
681
+ return seq_relationship_score
682
+
683
+
684
+
685
+ class BertForMaskedLM(BertPreTrainedModel):
686
+
687
+ def __init__(self, config):
688
+ super().__init__(config)
689
+
690
+ if config.is_decoder:
691
+ warnings.warn(
692
+ 'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
693
+ 'bi-directional self-attention.')
694
+
695
+ self.bert = BertModel(config, add_pooling_layer=False)
696
+ self.cls = BertOnlyMLMHead(config,
697
+ self.bert.embeddings.word_embeddings.weight)
698
+
699
+ # Initialize weights and apply final processing
700
+ self.post_init()
701
+
702
+ def get_output_embeddings(self):
703
+ return self.cls.predictions.decoder
704
+
705
+ def set_output_embeddings(self, new_embeddings):
706
+ self.cls.predictions.decoder = new_embeddings
707
+
708
+ def forward(
709
+ self,
710
+ input_ids: Optional[torch.Tensor] = None,
711
+ attention_mask: Optional[torch.Tensor] = None,
712
+ token_type_ids: Optional[torch.Tensor] = None,
713
+ position_ids: Optional[torch.Tensor] = None,
714
+ head_mask: Optional[torch.Tensor] = None,
715
+ inputs_embeds: Optional[torch.Tensor] = None,
716
+ encoder_hidden_states: Optional[torch.Tensor] = None,
717
+ encoder_attention_mask: Optional[torch.Tensor] = None,
718
+ labels: Optional[torch.Tensor] = None,
719
+ output_attentions: Optional[bool] = None,
720
+ output_hidden_states: Optional[bool] = None,
721
+ return_dict: Optional[bool] = None,
722
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
723
+ # labels should be a `torch.LongTensor` of shape
724
+ # `(batch_size, sequence_length)`. These are used for computing the
725
+ # masked language modeling loss.
726
+ #
727
+ # Indices should be in `[-100, 0, ..., config.vocab_size]` (see
728
+ # `input_ids` docstring) Tokens with indices set to `-100` are ignored
729
+ # (masked), the loss is only computed for the tokens with labels in `[0,
730
+ # ..., config.vocab_size]`
731
+ #
732
+ # Prediction scores are only computed for masked tokens and the (bs,
733
+ # seqlen) dimensions are flattened
734
+ if (input_ids is not None) == (inputs_embeds is not None):
735
+ raise ValueError('Must specify either input_ids or input_embeds!')
736
+
737
+ if labels is None:
738
+ masked_tokens_mask = None
739
+ else:
740
+ masked_tokens_mask = labels > 0
741
+
742
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
743
+
744
+ outputs = self.bert(
745
+ input_ids,
746
+ attention_mask=attention_mask,
747
+ token_type_ids=token_type_ids,
748
+ position_ids=position_ids,
749
+ head_mask=head_mask,
750
+ inputs_embeds=inputs_embeds,
751
+ encoder_hidden_states=encoder_hidden_states,
752
+ encoder_attention_mask=encoder_attention_mask,
753
+ output_attentions=output_attentions,
754
+ output_hidden_states=output_hidden_states,
755
+ return_dict=return_dict,
756
+ masked_tokens_mask=masked_tokens_mask,
757
+ )
758
+
759
+ sequence_output = outputs[0]
760
+ prediction_scores = self.cls(sequence_output)
761
+
762
+ loss = None
763
+ if labels is not None:
764
+ # Compute loss
765
+ loss_fct = nn.CrossEntropyLoss()
766
+ masked_token_idx = torch.nonzero(labels.flatten() > 0,
767
+ as_tuple=False).flatten()
768
+ loss = loss_fct(prediction_scores,
769
+ labels.flatten()[masked_token_idx])
770
+
771
+ assert input_ids is not None, 'Coding error; please open an issue'
772
+ batch, seqlen = input_ids.shape[:2]
773
+ prediction_scores = rearrange(index_put_first_axis(
774
+ prediction_scores, masked_token_idx, batch * seqlen),
775
+ '(b s) d -> b s d',
776
+ b=batch)
777
+
778
+ if not return_dict:
779
+ output = (prediction_scores,) + outputs[2:]
780
+ return ((loss,) + output) if loss is not None else output
781
+
782
+ return MaskedLMOutput(
783
+ loss=loss,
784
+ logits=prediction_scores,
785
+ hidden_states=outputs[0],  # note: exposes the final-layer sequence output (a single tensor, not the usual per-layer tuple)
786
+ attentions=None,
787
+ )
788
+
789
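The scatter of per-masked-token logits back to a padded `[batch, seqlen, vocab]` tensor in `forward` above can be hard to visualize. A toy equivalent of the `index_put_first_axis` + `rearrange` step, with hypothetical sizes:

```python
import torch
from einops import rearrange

batch, seqlen, vocab = 2, 4, 5
labels = torch.tensor([[-100, 7, -100, -100],
                       [-100, -100, 3, -100]])
masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()  # tensor([1, 6])

# Pretend the MLM head produced one score row per masked token.
scores_masked_only = torch.randn(len(masked_token_idx), vocab)

full = scores_masked_only.new_zeros(batch * seqlen, vocab)
full[masked_token_idx] = scores_masked_only
full = rearrange(full, '(b s) d -> b s d', b=batch)
print(full.shape)   # torch.Size([2, 4, 5])
```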
+ def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
790
+ attention_mask: torch.Tensor,
791
+ **model_kwargs):
792
+ input_shape = input_ids.shape
793
+ effective_batch_size = input_shape[0]
794
+
795
+ # add a dummy token
796
+ if self.config.pad_token_id is None:
797
+ raise ValueError('The PAD token should be defined for generation')
798
+
799
+ attention_mask = torch.cat([
800
+ attention_mask,
801
+ attention_mask.new_zeros((attention_mask.shape[0], 1))
802
+ ],
803
+ dim=-1)
804
+ dummy_token = torch.full((effective_batch_size, 1),
805
+ self.config.pad_token_id,
806
+ dtype=torch.long,
807
+ device=input_ids.device)
808
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
809
+
810
+ return {'input_ids': input_ids, 'attention_mask': attention_mask}
811
+
812
+
813
+
814
+ class BertForSequenceClassification(BertPreTrainedModel):
815
+ """Bert Model transformer with a sequence classification/regression head.
816
+
817
+ This head is just a linear layer on top of the pooled output. Used for,
818
+ e.g., GLUE tasks.
819
+ """
820
+
821
+ def __init__(self, config):
822
+ super().__init__(config)
823
+ self.num_labels = config.num_labels
824
+ self.config = config
825
+
826
+ self.bert = BertModel(config)
827
+ classifier_dropout = (config.classifier_dropout
828
+ if config.classifier_dropout is not None else
829
+ config.hidden_dropout_prob)
830
+ self.dropout = nn.Dropout(classifier_dropout)
831
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
832
+
833
+ # Initialize weights and apply final processing
834
+ self.post_init()
835
+
836
+
837
+ def forward(
838
+ self,
839
+ input_ids: Optional[torch.Tensor] = None,
840
+ attention_mask: Optional[torch.Tensor] = None,
841
+ token_type_ids: Optional[torch.Tensor] = None,
842
+ position_ids: Optional[torch.Tensor] = None,
843
+ head_mask: Optional[torch.Tensor] = None,
844
+ inputs_embeds: Optional[torch.Tensor] = None,
845
+ labels: Optional[torch.Tensor] = None,
846
+ output_attentions: Optional[bool] = None,
847
+ output_hidden_states: Optional[bool] = None,
848
+ return_dict: Optional[bool] = None,
849
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
850
+ # labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
851
+ # Labels for computing the sequence classification/regression loss.
852
+ # Indices should be in `[0, ..., config.num_labels - 1]`.
853
+ # If `config.num_labels == 1` a regression loss is computed
854
+ # (mean-square loss). If `config.num_labels > 1` a classification loss
855
+ # is computed (cross-entropy).
856
+
857
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
858
+
859
+ outputs = self.bert(
860
+ input_ids,
861
+ attention_mask=attention_mask,
862
+ token_type_ids=token_type_ids,
863
+ position_ids=position_ids,
864
+ head_mask=head_mask,
865
+ inputs_embeds=inputs_embeds,
866
+ output_attentions=output_attentions,
867
+ output_hidden_states=output_hidden_states,
868
+ return_dict=return_dict,
869
+ )
870
+
871
+ pooled_output = outputs[1]
872
+
873
+ pooled_output = self.dropout(pooled_output)
874
+ logits = self.classifier(pooled_output)
875
+
876
+ loss = None
877
+ if labels is not None:
878
+ # Compute loss
879
+ if self.config.problem_type is None:
880
+ if self.num_labels == 1:
881
+ self.config.problem_type = 'regression'
882
+ elif self.num_labels > 1 and (labels.dtype == torch.long or
883
+ labels.dtype == torch.int):
884
+ self.config.problem_type = 'single_label_classification'
885
+ else:
886
+ self.config.problem_type = 'multi_label_classification'
887
+
888
+ if self.config.problem_type == 'regression':
889
+ loss_fct = nn.MSELoss()
890
+ if self.num_labels == 1:
891
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
892
+ else:
893
+ loss = loss_fct(logits, labels)
894
+ elif self.config.problem_type == 'single_label_classification':
895
+ loss_fct = nn.CrossEntropyLoss()
896
+ loss = loss_fct(logits.view(-1, self.num_labels),
897
+ labels.view(-1))
898
+ elif self.config.problem_type == 'multi_label_classification':
899
+ loss_fct = nn.BCEWithLogitsLoss()
900
+ loss = loss_fct(logits, labels)
901
+
902
+ if not return_dict:
903
+ output = (logits,) + outputs[2:]
904
+ return ((loss,) + output) if loss is not None else output
905
+
906
+ return SequenceClassifierOutput(
907
+ loss=loss,
908
+ logits=logits,
909
+ hidden_states=outputs[0],
910
+ attentions=None,  # attention weights are not returned by this implementation
911
+ )
912
+
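Finally, a compact standalone restatement of the problem-type dispatch in `BertForSequenceClassification.forward` (a sketch with hypothetical tensors and a made-up helper name, not the model's API):

```python
import torch
import torch.nn as nn

def classification_loss(logits, labels, num_labels):
    if num_labels == 1:                            # regression
        return nn.MSELoss()(logits.squeeze(), labels.squeeze().float())
    if labels.dtype in (torch.long, torch.int):    # single-label classification
        return nn.CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
    return nn.BCEWithLogitsLoss()(logits, labels)  # multi-label classification

logits = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 2])
print(classification_loss(logits, labels, num_labels=3))
```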