valhalla committed on
Commit 3470b2b
1 Parent(s): 1ec58a5

Create modeling_ldmbert.py

Files changed (1)
  1. modeling_ldmbert.py +705 -0
modeling_ldmbert.py ADDED
@@ -0,0 +1,705 @@
# coding=utf-8
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch LDMBERT model."""
import copy
import math
import random
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqQuestionAnsweringModelOutput,
    Seq2SeqSequenceClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_ldmbert import LDMBertConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "ldm-bert"
_CONFIG_FOR_DOC = "LDMBertConfig"
_TOKENIZER_FOR_DOC = "BartTokenizer"

# Base model docstring
_EXPECTED_OUTPUT_SHAPE = [1, 8, 768]

# SequenceClassification docstring
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "valhalla/ldmbert-large-sst2"
_SEQ_CLASS_EXPECTED_LOSS = 0.0
_SEQ_CLASS_EXPECTED_OUTPUT = "'POSITIVE'"

# QuestionAnswering docstring
_CHECKPOINT_FOR_QA = "valhalla/ldmbert-large-finetuned-squadv1"
_QA_EXPECTED_LOSS = 0.59
_QA_EXPECTED_OUTPUT = "' nice puppet'"


LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "ldm-bert",
    # See all LDMBert models at https://huggingface.co/models?filter=ldmbert
]


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

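# Illustrative sketch (not part of the original commit): `_expand_mask` turns a 0/1 padding mask into the
# additive bias consumed by the attention layers below. For a toy batch with one padded position:
#
#     >>> mask = torch.tensor([[1, 1, 0]])          # [bsz=1, seq_len=3], last token is padding
#     >>> bias = _expand_mask(mask, torch.float32)  # -> shape [1, 1, 3, 3]
#     >>> bias[0, 0, 0]
#     tensor([ 0.0000e+00,  0.0000e+00, -3.4028e+38])
#
# Kept positions become 0.0 and padded positions become torch.finfo(dtype).min, so adding this bias to the
# raw attention scores drives padded columns to effectively -inf before the softmax.

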
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert
class LDMBertAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        head_dim: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = head_dim
        self.inner_dim = head_dim * num_heads

        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
        self.out_proj = nn.Linear(self.inner_dim, embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value

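# Shape sketch (illustrative, not part of the original commit): for example, with embed_dim=1280,
# num_heads=8 and head_dim=64 (so inner_dim = 8 * 64 = 512), an encoder-style self-attention call
# behaves as follows:
#
#     >>> attn = LDMBertAttention(embed_dim=1280, num_heads=8, head_dim=64)
#     >>> x = torch.randn(2, 77, 1280)                           # [bsz, seq_len, embed_dim]
#     >>> out, weights, past = attn(x)                           # out: [2, 77, 1280], weights/past: None
#     >>> out, weights, past = attn(x, output_attentions=True)
#     >>> weights.shape                                          # [bsz, num_heads, tgt_len, src_len]
#     torch.Size([2, 8, 77, 77])
#
# Unlike the stock BART attention this was copied from, `head_dim` is a free parameter here, so the
# per-head projection width `inner_dim` may differ from `embed_dim`; `out_proj` maps the concatenated
# heads back to `embed_dim`.

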
class LDMBertEncoderLayer(nn.Module):
    def __init__(self, config: LDMBertConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = LDMBertAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            head_dim=config.head_dim,
            dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        layer_head_mask: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

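# Shape sketch (illustrative, not part of the original commit): one pre-LayerNorm encoder block maps
# `(batch, seq_len, d_model)` to the same shape. Note that `attention_mask` at this level is the already
# expanded additive bias produced by `_expand_mask`, not the 0/1 padding mask:
#
#     >>> layer = LDMBertEncoderLayer(config)                    # any LDMBertConfig instance, e.g. d_model=1280
#     >>> x = torch.randn(2, 77, 1280)
#     >>> bias = _expand_mask(torch.ones(2, 77, dtype=torch.long), x.dtype)   # all-zero bias, no padding
#     >>> (out,) = layer(x, bias, layer_head_mask=None)
#     >>> out.shape
#     torch.Size([2, 77, 1280])

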
# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert
class LDMBertPreTrainedModel(PreTrainedModel):
    config_class = LDMBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # Only the encoder exists in this file (there is no LDMBertDecoder class).
        if isinstance(module, LDMBertEncoder):
            module.gradient_checkpointing = value

    @property
    def dummy_inputs(self):
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
        }
        return dummy_inputs

LDMBERT_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config ([`LDMBertConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

LDMBERT_GENERATION_EXAMPLE = r"""
    Summarization example:

    ```python
    >>> from transformers import BartTokenizer, LDMBertForConditionalGeneration

    >>> model = LDMBertForConditionalGeneration.from_pretrained("facebook/ldmbert-large-cnn")
    >>> tokenizer = BartTokenizer.from_pretrained("facebook/ldmbert-large-cnn")

    >>> ARTICLE_TO_SUMMARIZE = (
    ...     "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    ...     "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
    ...     "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
    ... )
    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")

    >>> # Generate Summary
    >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20)
    >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions'
    ```

    Mask filling example:

    ```python
    >>> from transformers import BartTokenizer, LDMBertForConditionalGeneration

    >>> tokenizer = BartTokenizer.from_pretrained("ldm-bert")
    >>> model = LDMBertForConditionalGeneration.from_pretrained("ldm-bert")

    >>> TXT = "My friends are <mask> but they eat too many carbs."
    >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
    >>> logits = model(input_ids).logits

    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
    >>> probs = logits[0, masked_index].softmax(dim=0)
    >>> values, predictions = probs.topk(5)

    >>> tokenizer.decode(predictions).split()
    ['not', 'good', 'healthy', 'great', 'very']
    ```
"""

LDMBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            LDMBert uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the
            right for denoising pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            If you want to change padding behavior, you should read
            [`modeling_ldmbert._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the
            paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
            cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
            that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
            all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
            (see `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

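# Usage sketch (illustrative, not part of the original commit): the encoder-only classes below consume only
# `input_ids`, `attention_mask`, and optionally `position_ids` / `head_mask` / `inputs_embeds`. A padding mask
# in the documented 1 = keep / 0 = pad convention can be built directly from the token ids, mirroring
# `dummy_inputs` above:
#
#     >>> pad_token_id = 0                                       # hypothetical value; take it from the config
#     >>> input_ids = torch.tensor([[101, 7592, 2088, 102, pad_token_id]])
#     >>> attention_mask = input_ids.ne(pad_token_id).long()     # tensor([[1, 1, 1, 1, 0]])
#
# The seq2seq arguments documented above (`decoder_input_ids`, `past_key_values`, ...) come from the BART
# docstring this file was adapted from and are not used by `LDMBertEncoder` / `LDMBertModel`.

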
class LDMBertEncoder(LDMBertPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`LDMBertEncoderLayer`].

    Args:
        config: LDMBertConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: LDMBertConfig):
        super().__init__(config)

        self.dropout = config.dropout

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings

        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim)
        self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim)
        self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layer_norm = nn.LayerNorm(embed_dim)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
                representation. This is useful if you want more control over how to convert `input_ids` indices into
                associated vectors than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        seq_len = input_shape[1]
        if position_ids is None:
            position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1))
        embed_pos = self.embed_positions(position_ids)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                    (head_mask[idx] if head_mask is not None else None),
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )

class LDMBertModel(LDMBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.model = LDMBertEncoder(config)
        self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        # logits = self.to_logits(sequence_output)
        # outputs = (logits,) + outputs[1:]

        # if labels is not None:
        #     loss_fct = CrossEntropyLoss()
        #     loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
        #     outputs = (loss,) + outputs

        # if not return_dict:
        #     return outputs

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            # hidden_states=outputs[1],
            # attentions=outputs[2],
        )
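

# Quick smoke test (illustrative, not part of the original commit). It assumes `LDMBertConfig` accepts the
# same field names this module reads from it (d_model, encoder_layers, encoder_attention_heads, head_dim,
# encoder_ffn_dim, max_position_embeddings, vocab_size, pad_token_id, ...) and maps `hidden_size` to
# `d_model`; adjust if the config signature differs. Because of the relative import above, run it as a
# module, e.g. `python -m <package>.modeling_ldmbert`.
if __name__ == "__main__":
    config = LDMBertConfig(
        vocab_size=1000,
        max_position_embeddings=16,
        d_model=64,
        encoder_layers=2,
        encoder_attention_heads=4,
        head_dim=16,
        encoder_ffn_dim=128,
        pad_token_id=0,
    )
    model = LDMBertModel(config).eval()

    input_ids = torch.tensor([[5, 42, 7, 3, 0, 0]])  # [bsz=1, seq_len=6], two trailing padding tokens
    attention_mask = input_ids.ne(config.pad_token_id).long()

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)

    # The wrapper returns a BaseModelOutput whose last_hidden_state has shape [bsz, seq_len, d_model].
    print(out.last_hidden_state.shape)  # expected: torch.Size([1, 6, 64])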