zpn committed
Commit
d9ac587
1 Parent(s): 212555e

Upload RoFormerForMaskedLM

Files changed (3)
  1. config.json +34 -0
  2. modeling_nt_roformer.py +1587 -0
  3. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "_name_or_path": "/fsx/home-zanussbaum/nucleotide-transformer/nt_transformer/saved_models/roformer/mamalian_cds_2k_window_step84999",
+   "architectures": [
+     "RoFormerForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "auto_map": {
+     "AutoModelForMaskedLM": "modeling_nt_roformer.RoFormerForMaskedLM"
+   },
+   "classifier_dropout": "None",
+   "embedding_size": 768,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "roformer",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "position_embedding_type": "absolute",
+   "rotary_value": true,
+   "summary_activation": "gelu",
+   "summary_last_dropout": 0.1,
+   "summary_type": "first",
+   "summary_use_proj": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.25.1",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "vocab_size": 9
+ }
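
Note: the `auto_map` entry above routes `AutoModelForMaskedLM` to the bundled `modeling_nt_roformer.RoFormerForMaskedLM`, so the checkpoint is meant to be loaded with `trust_remote_code=True`. A minimal loading sketch follows; the repo id used here is a placeholder (this diff does not name the hosting repository), so substitute the actual Hub repo for this upload.

import torch
from transformers import AutoModelForMaskedLM

repo_id = "user/nt-roformer"  # placeholder; replace with the repository that hosts this commit

# trust_remote_code=True lets transformers import RoFormerForMaskedLM from the
# uploaded modeling_nt_roformer.py instead of the built-in RoFormer implementation.
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# Per the config above: 9-token vocabulary, sequences of up to 512 positions.
input_ids = torch.randint(0, model.config.vocab_size, (1, 128))
with torch.no_grad():
    logits = model(input_ids=input_ids).logits  # shape (1, 128, 9)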
modeling_nt_roformer.py ADDED
@@ -0,0 +1,1587 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch RoFormer model."""
16
+
17
+
18
+ import math
19
+ import os
20
+ from typing import Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.modeling_outputs import (
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ QuestionAnsweringModelOutput,
35
+ SequenceClassifierOutput,
36
+ TokenClassifierOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel, SequenceSummary
39
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
40
+ from transformers.utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from transformers import RoFormerConfig
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CHECKPOINT_FOR_DOC = "junnyu/roformer_chinese_base"
53
+ _CONFIG_FOR_DOC = "RoFormerConfig"
54
+ _TOKENIZER_FOR_DOC = "RoFormerTokenizer"
55
+
56
+ ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
57
+ "junnyu/roformer_chinese_small",
58
+ "junnyu/roformer_chinese_base",
59
+ "junnyu/roformer_chinese_char_small",
60
+ "junnyu/roformer_chinese_char_base",
61
+ "junnyu/roformer_small_discriminator",
62
+ "junnyu/roformer_small_generator"
63
+ # See all RoFormer models at https://huggingface.co/models?filter=roformer
64
+ ]
65
+
66
+
67
+ # Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->RoFormer
68
+ class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
69
+ """This module produces sinusoidal positional embeddings of any length."""
70
+
71
+ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
72
+ super().__init__(num_positions, embedding_dim)
73
+ self.weight = self._init_weight(self.weight)
74
+
75
+ @staticmethod
76
+ def _init_weight(out: nn.Parameter) -> nn.Parameter:
77
+ """
78
+ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
79
+ the 2nd half of the vector. [dim // 2:]
80
+ """
81
+ n_pos, dim = out.shape
82
+ position_enc = np.array(
83
+ [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
84
+ )
85
+ out.requires_grad = False # set early to avoid an error in pytorch-1.8+
86
+ sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
87
+ out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
88
+ out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
89
+ out.detach_()
90
+ return out
91
+
92
+ @torch.no_grad()
93
+ def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
94
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
95
+ bsz, seq_len = input_ids_shape[:2]
96
+ positions = torch.arange(
97
+ past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
98
+ )
99
+ return super().forward(positions)
100
+
101
+
102
+ def load_tf_weights_in_roformer(model, config, tf_checkpoint_path):
103
+ """Load tf checkpoints in a pytorch model."""
104
+ try:
105
+ import re
106
+
107
+ import numpy as np
108
+ import tensorflow as tf
109
+ except ImportError:
110
+ logger.error(
111
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
112
+ "https://www.tensorflow.org/install/ for installation instructions."
113
+ )
114
+ raise
115
+ tf_path = os.path.abspath(tf_checkpoint_path)
116
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
117
+ # Load weights from TF model
118
+ init_vars = tf.train.list_variables(tf_path)
119
+ names = []
120
+ arrays = []
121
+ for name, shape in init_vars:
122
+ logger.info(f"Loading TF weight {name} with shape {shape}")
123
+ array = tf.train.load_variable(tf_path, name)
124
+ names.append(name.replace("bert", "roformer"))
125
+ arrays.append(array)
126
+
127
+ for name, array in zip(names, arrays):
128
+ name = name.split("/")
129
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
130
+ # which are not required for using the pretrained model
131
+ if any(
132
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
133
+ for n in name
134
+ ):
135
+ logger.info(f"Skipping {'/'.join(name)}")
136
+ continue
137
+ pointer = model
138
+ for m_name in name:
139
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
140
+ scope_names = re.split(r"_(\d+)", m_name)
141
+ else:
142
+ scope_names = [m_name]
143
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
144
+ pointer = getattr(pointer, "weight")
145
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
146
+ pointer = getattr(pointer, "bias")
147
+ elif scope_names[0] == "output_weights":
148
+ pointer = getattr(pointer, "weight")
149
+ elif scope_names[0] == "squad":
150
+ pointer = getattr(pointer, "classifier")
151
+ else:
152
+ try:
153
+ pointer = getattr(pointer, scope_names[0])
154
+ except AttributeError:
155
+ logger.info(f"Skipping {'/'.join(name)}")
156
+ continue
157
+ if len(scope_names) >= 2:
158
+ num = int(scope_names[1])
159
+ pointer = pointer[num]
160
+ if m_name[-11:] == "_embeddings":
161
+ pointer = getattr(pointer, "weight")
162
+ elif m_name == "kernel":
163
+ array = np.transpose(array)
164
+ try:
165
+ if not pointer.shape == array.shape:
166
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
167
+ except ValueError as e:  # the shape check above raises ValueError, not AssertionError
168
+ e.args += (pointer.shape, array.shape)
169
+ raise
170
+ logger.info(f"Initialize PyTorch weight {name}")
171
+ pointer.data = torch.from_numpy(array)
172
+ return model
173
+
174
+
175
+ class RoFormerEmbeddings(nn.Module):
176
+ """Construct the embeddings from word and token_type embeddings."""
177
+
178
+ def __init__(self, config):
179
+ super().__init__()
180
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
181
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
182
+
183
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
184
+ # any TensorFlow checkpoint file
185
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
186
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
187
+
188
+ def forward(self, input_ids=None, token_type_ids=None, inputs_embeds=None):
189
+ if input_ids is not None:
190
+ input_shape = input_ids.size()
191
+ else:
192
+ input_shape = inputs_embeds.size()[:-1]
193
+
194
+ if inputs_embeds is None:
195
+ inputs_embeds = self.word_embeddings(input_ids)
196
+
197
+ if token_type_ids is None:
198
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=inputs_embeds.device)
199
+
200
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
201
+
202
+ embeddings = inputs_embeds + token_type_embeddings
203
+
204
+ embeddings = self.LayerNorm(embeddings)
205
+ embeddings = self.dropout(embeddings)
206
+ return embeddings
207
+
208
+
209
+ class RoFormerSelfAttention(nn.Module):
210
+ def __init__(self, config):
211
+ super().__init__()
212
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
213
+ raise ValueError(
214
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
215
+ f"heads ({config.num_attention_heads})"
216
+ )
217
+
218
+ self.num_attention_heads = config.num_attention_heads
219
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
220
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
221
+
222
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
223
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
224
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
225
+
226
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
227
+
228
+ self.is_decoder = config.is_decoder
229
+ self.rotary_value = config.rotary_value
230
+
231
+ def transpose_for_scores(self, x):
232
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
233
+ x = x.view(*new_x_shape)
234
+ return x.permute(0, 2, 1, 3)
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states,
239
+ attention_mask=None,
240
+ sinusoidal_pos=None,
241
+ head_mask=None,
242
+ encoder_hidden_states=None,
243
+ encoder_attention_mask=None,
244
+ past_key_value=None,
245
+ output_attentions=False,
246
+ ):
247
+ mixed_query_layer = self.query(hidden_states)
248
+ query_layer = self.transpose_for_scores(mixed_query_layer)
249
+ # If this is instantiated as a cross-attention module, the keys
250
+ # and values come from an encoder; the attention mask needs to be
251
+ # such that the encoder's padding tokens are not attended to.
252
+ is_cross_attention = encoder_hidden_states is not None
253
+
254
+ if is_cross_attention and past_key_value is not None:
255
+ # reuse k,v, cross_attentions
256
+ key_layer = past_key_value[0]
257
+ value_layer = past_key_value[1]
258
+ attention_mask = encoder_attention_mask
259
+ elif is_cross_attention:
260
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
261
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
262
+ attention_mask = encoder_attention_mask
263
+ elif past_key_value is not None:
264
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
265
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
266
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
267
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
268
+ else:
269
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
270
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
271
+ if sinusoidal_pos is not None:
272
+ if self.rotary_value:
273
+ query_layer, key_layer, value_layer = self.apply_rotary_position_embeddings(
274
+ sinusoidal_pos, query_layer, key_layer, value_layer
275
+ )
276
+ else:
277
+ query_layer, key_layer = self.apply_rotary_position_embeddings(
278
+ sinusoidal_pos, query_layer, key_layer
279
+ )
280
+ if self.is_decoder:
281
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
282
+ # Further calls to cross_attention layer can then reuse all cross-attention
283
+ # key/value_states (first "if" case)
284
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
285
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
286
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
287
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
288
+ past_key_value = (key_layer, value_layer)
289
+
290
+ # Take the dot product between "query" and "key" to get the raw attention scores.
291
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
292
+
293
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
294
+ if attention_mask is not None:
295
+ # Apply the attention mask is (precomputed for all layers in RoFormerModel forward() function)
296
+ attention_scores = attention_scores + attention_mask
297
+
298
+ # Normalize the attention scores to probabilities.
299
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
300
+
301
+ # This is actually dropping out entire tokens to attend to, which might
302
+ # seem a bit unusual, but is taken from the original Transformer paper.
303
+ attention_probs = self.dropout(attention_probs)
304
+
305
+ # Mask heads if we want to
306
+ if head_mask is not None:
307
+ attention_probs = attention_probs * head_mask
308
+
309
+ context_layer = torch.matmul(attention_probs, value_layer)
310
+
311
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
312
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
313
+ context_layer = context_layer.view(*new_context_layer_shape)
314
+
315
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
316
+
317
+ if self.is_decoder:
318
+ outputs = outputs + (past_key_value,)
319
+ return outputs
320
+
321
+ @staticmethod
322
+ def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
323
+ # https://kexue.fm/archives/8265
324
+ # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2]
325
+ # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
326
+ sin, cos = sinusoidal_pos.chunk(2, dim=-1)
327
+ # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
328
+ sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(sinusoidal_pos)
329
+ # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
330
+ cos_pos = torch.stack([cos, cos], dim=-1).reshape_as(sinusoidal_pos)
331
+ # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
332
+ rotate_half_query_layer = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1).reshape_as(
333
+ query_layer
334
+ )
335
+ query_layer = query_layer * cos_pos + rotate_half_query_layer * sin_pos
336
+ # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
337
+ rotate_half_key_layer = torch.stack([-key_layer[..., 1::2], key_layer[..., ::2]], dim=-1).reshape_as(key_layer)
338
+ key_layer = key_layer * cos_pos + rotate_half_key_layer * sin_pos
339
+ if value_layer is not None:
340
+ # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2]
341
+ rotate_half_value_layer = torch.stack([-value_layer[..., 1::2], value_layer[..., ::2]], dim=-1).reshape_as(
342
+ value_layer
343
+ )
344
+ value_layer = value_layer * cos_pos + rotate_half_value_layer * sin_pos
345
+ return query_layer, key_layer, value_layer
346
+ return query_layer, key_layer
347
+
348
+
349
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->RoFormer
350
+ class RoFormerSelfOutput(nn.Module):
351
+ def __init__(self, config):
352
+ super().__init__()
353
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
354
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
355
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
356
+
357
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
358
+ hidden_states = self.dense(hidden_states)
359
+ hidden_states = self.dropout(hidden_states)
360
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
361
+ return hidden_states
362
+
363
+
364
+ class RoFormerAttention(nn.Module):
365
+ def __init__(self, config):
366
+ super().__init__()
367
+ self.self = RoFormerSelfAttention(config)
368
+ self.output = RoFormerSelfOutput(config)
369
+ self.pruned_heads = set()
370
+
371
+ # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
372
+ def prune_heads(self, heads):
373
+ if len(heads) == 0:
374
+ return
375
+ heads, index = find_pruneable_heads_and_indices(
376
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
377
+ )
378
+
379
+ # Prune linear layers
380
+ self.self.query = prune_linear_layer(self.self.query, index)
381
+ self.self.key = prune_linear_layer(self.self.key, index)
382
+ self.self.value = prune_linear_layer(self.self.value, index)
383
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
384
+
385
+ # Update hyper params and store pruned heads
386
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
387
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
388
+ self.pruned_heads = self.pruned_heads.union(heads)
389
+
390
+ # End Copy
391
+ def forward(
392
+ self,
393
+ hidden_states,
394
+ attention_mask=None,
395
+ sinusoidal_pos=None,
396
+ head_mask=None,
397
+ encoder_hidden_states=None,
398
+ encoder_attention_mask=None,
399
+ past_key_value=None,
400
+ output_attentions=False,
401
+ ):
402
+ self_outputs = self.self(
403
+ hidden_states,
404
+ attention_mask,
405
+ sinusoidal_pos,
406
+ head_mask,
407
+ encoder_hidden_states,
408
+ encoder_attention_mask,
409
+ past_key_value,
410
+ output_attentions,
411
+ )
412
+ attention_output = self.output(self_outputs[0], hidden_states)
413
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
414
+ return outputs
415
+
416
+
417
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->RoFormer
418
+ class RoFormerIntermediate(nn.Module):
419
+ def __init__(self, config):
420
+ super().__init__()
421
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
422
+ if isinstance(config.hidden_act, str):
423
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
424
+ else:
425
+ self.intermediate_act_fn = config.hidden_act
426
+
427
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
428
+ hidden_states = self.dense(hidden_states)
429
+ hidden_states = self.intermediate_act_fn(hidden_states)
430
+ return hidden_states
431
+
432
+
433
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->RoFormer
434
+ class RoFormerOutput(nn.Module):
435
+ def __init__(self, config):
436
+ super().__init__()
437
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
438
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
439
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
440
+
441
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
442
+ hidden_states = self.dense(hidden_states)
443
+ hidden_states = self.dropout(hidden_states)
444
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
445
+ return hidden_states
446
+
447
+
448
+ class RoFormerLayer(nn.Module):
449
+ def __init__(self, config):
450
+ super().__init__()
451
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
452
+ self.seq_len_dim = 1
453
+ self.attention = RoFormerAttention(config)
454
+ self.is_decoder = config.is_decoder
455
+ self.add_cross_attention = config.add_cross_attention
456
+ if self.add_cross_attention:
457
+ if not self.is_decoder:
458
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
459
+ self.crossattention = RoFormerAttention(config)
460
+ self.intermediate = RoFormerIntermediate(config)
461
+ self.output = RoFormerOutput(config)
462
+
463
+ def forward(
464
+ self,
465
+ hidden_states,
466
+ attention_mask=None,
467
+ sinusoidal_pos=None,
468
+ head_mask=None,
469
+ encoder_hidden_states=None,
470
+ encoder_attention_mask=None,
471
+ past_key_value=None,
472
+ output_attentions=False,
473
+ ):
474
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
475
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
476
+ self_attention_outputs = self.attention(
477
+ hidden_states,
478
+ attention_mask,
479
+ sinusoidal_pos,
480
+ head_mask,
481
+ output_attentions=output_attentions,
482
+ past_key_value=self_attn_past_key_value,
483
+ )
484
+ attention_output = self_attention_outputs[0]
485
+
486
+ # if decoder, the last output is tuple of self-attn cache
487
+ if self.is_decoder:
488
+ outputs = self_attention_outputs[1:-1]
489
+ present_key_value = self_attention_outputs[-1]
490
+ else:
491
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
492
+
493
+ cross_attn_present_key_value = None
494
+ if self.is_decoder and encoder_hidden_states is not None:
495
+ if not hasattr(self, "crossattention"):
496
+ raise ValueError(
497
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention "
498
+ "layers by setting `config.add_cross_attention=True`"
499
+ )
500
+
501
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
502
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
503
+ cross_attention_outputs = self.crossattention(
504
+ attention_output,
505
+ attention_mask,
506
+ sinusoidal_pos,
507
+ head_mask,
508
+ encoder_hidden_states,
509
+ encoder_attention_mask,
510
+ cross_attn_past_key_value,
511
+ output_attentions,
512
+ )
513
+ attention_output = cross_attention_outputs[0]
514
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
515
+
516
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
517
+ cross_attn_present_key_value = cross_attention_outputs[-1]
518
+ present_key_value = present_key_value + cross_attn_present_key_value
519
+
520
+ layer_output = apply_chunking_to_forward(
521
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
522
+ )
523
+ outputs = (layer_output,) + outputs
524
+
525
+ # if decoder, return the attn key/values as the last output
526
+ if self.is_decoder:
527
+ outputs = outputs + (present_key_value,)
528
+
529
+ return outputs
530
+
531
+ def feed_forward_chunk(self, attention_output):
532
+ intermediate_output = self.intermediate(attention_output)
533
+ layer_output = self.output(intermediate_output, attention_output)
534
+ return layer_output
535
+
536
+
537
+ class RoFormerEncoder(nn.Module):
538
+ def __init__(self, config):
539
+ super().__init__()
540
+ self.config = config
541
+ self.embed_positions = RoFormerSinusoidalPositionalEmbedding(
542
+ config.max_position_embeddings, config.hidden_size // config.num_attention_heads
543
+ )
544
+ self.layer = nn.ModuleList([RoFormerLayer(config) for _ in range(config.num_hidden_layers)])
545
+ self.gradient_checkpointing = False
546
+
547
+ def forward(
548
+ self,
549
+ hidden_states,
550
+ attention_mask=None,
551
+ head_mask=None,
552
+ encoder_hidden_states=None,
553
+ encoder_attention_mask=None,
554
+ past_key_values=None,
555
+ use_cache=None,
556
+ output_attentions=False,
557
+ output_hidden_states=False,
558
+ return_dict=True,
559
+ ):
560
+ all_hidden_states = () if output_hidden_states else None
561
+ all_self_attentions = () if output_attentions else None
562
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
563
+
564
+ # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head]
565
+ sinusoidal_pos = self.embed_positions(hidden_states.shape[:-1])[None, None, :, :]
566
+
567
+ next_decoder_cache = () if use_cache else None
568
+ for i, layer_module in enumerate(self.layer):
569
+ if output_hidden_states:
570
+ all_hidden_states = all_hidden_states + (hidden_states,)
571
+
572
+ layer_head_mask = head_mask[i] if head_mask is not None else None
573
+ past_key_value = past_key_values[i] if past_key_values is not None else None
574
+
575
+ if self.gradient_checkpointing and self.training:
576
+
577
+ if use_cache:
578
+ logger.warning(
579
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
580
+ )
581
+ use_cache = False
582
+
583
+ def create_custom_forward(module):
584
+ def custom_forward(*inputs):
585
+ return module(*inputs, past_key_value, output_attentions)
586
+
587
+ return custom_forward
588
+
589
+ layer_outputs = torch.utils.checkpoint.checkpoint(
590
+ create_custom_forward(layer_module),
591
+ hidden_states,
592
+ attention_mask,
593
+ sinusoidal_pos,
594
+ layer_head_mask,
595
+ encoder_hidden_states,
596
+ encoder_attention_mask,
597
+ )
598
+ else:
599
+ layer_outputs = layer_module(
600
+ hidden_states,
601
+ attention_mask,
602
+ sinusoidal_pos,
603
+ layer_head_mask,
604
+ encoder_hidden_states,
605
+ encoder_attention_mask,
606
+ past_key_value,
607
+ output_attentions,
608
+ )
609
+
610
+ hidden_states = layer_outputs[0]
611
+ if use_cache:
612
+ next_decoder_cache += (layer_outputs[-1],)
613
+ if output_attentions:
614
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
615
+ if self.config.add_cross_attention:
616
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
617
+
618
+ if output_hidden_states:
619
+ all_hidden_states = all_hidden_states + (hidden_states,)
620
+
621
+ if not return_dict:
622
+ return tuple(
623
+ v
624
+ for v in [
625
+ hidden_states,
626
+ next_decoder_cache,
627
+ all_hidden_states,
628
+ all_self_attentions,
629
+ all_cross_attentions,
630
+ ]
631
+ if v is not None
632
+ )
633
+ return BaseModelOutputWithPastAndCrossAttentions(
634
+ last_hidden_state=hidden_states,
635
+ past_key_values=next_decoder_cache,
636
+ hidden_states=all_hidden_states,
637
+ attentions=all_self_attentions,
638
+ cross_attentions=all_cross_attentions,
639
+ )
640
+
641
+
642
+ class RoFormerPredictionHeadTransform(nn.Module):
643
+ def __init__(self, config):
644
+ super().__init__()
645
+ self.dense = nn.Linear(config.hidden_size, config.embedding_size)
646
+ if isinstance(config.hidden_act, str):
647
+ self.transform_act_fn = ACT2FN[config.hidden_act]
648
+ else:
649
+ self.transform_act_fn = config.hidden_act
650
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
651
+
652
+ def forward(self, hidden_states):
653
+ hidden_states = self.dense(hidden_states)
654
+ hidden_states = self.transform_act_fn(hidden_states)
655
+ hidden_states = self.LayerNorm(hidden_states)
656
+ return hidden_states
657
+
658
+
659
+ class RoFormerLMPredictionHead(nn.Module):
660
+ def __init__(self, config):
661
+ super().__init__()
662
+ self.transform = RoFormerPredictionHeadTransform(config)
663
+
664
+ # The output weights are the same as the input embeddings, but there is
665
+ # an output-only bias for each token.
666
+ self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=False)
667
+
668
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
669
+
670
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
671
+ self.decoder.bias = self.bias
672
+
673
+ def forward(self, hidden_states):
674
+ hidden_states = self.transform(hidden_states)
675
+ hidden_states = self.decoder(hidden_states)
676
+ return hidden_states
677
+
678
+
679
+ # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->RoFormer
680
+ class RoFormerOnlyMLMHead(nn.Module):
681
+ def __init__(self, config):
682
+ super().__init__()
683
+ self.predictions = RoFormerLMPredictionHead(config)
684
+
685
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
686
+ prediction_scores = self.predictions(sequence_output)
687
+ return prediction_scores
688
+
689
+
690
+ class RoFormerPreTrainedModel(PreTrainedModel):
691
+ """
692
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
693
+ models.
694
+ """
695
+
696
+ config_class = RoFormerConfig
697
+ load_tf_weights = load_tf_weights_in_roformer
698
+ base_model_prefix = "roformer"
699
+ supports_gradient_checkpointing = True
700
+ _keys_to_ignore_on_load_missing = []
701
+ _keys_to_ignore_on_load_unexpected = [
702
+ r"roformer.embeddings_project.weight",
703
+ r"roformer.embeddings_project.bias",
704
+ ]
705
+
706
+ def _init_weights(self, module):
707
+ """Initialize the weights"""
708
+ if isinstance(module, nn.Linear):
709
+ # Slightly different from the TF version which uses truncated_normal for initialization
710
+ # cf https://github.com/pytorch/pytorch/pull/5617
711
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
712
+ if module.bias is not None:
713
+ module.bias.data.zero_()
714
+ elif isinstance(module, RoFormerSinusoidalPositionalEmbedding):
715
+ pass
716
+ elif isinstance(module, nn.Embedding):
717
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
718
+ if module.padding_idx is not None:
719
+ module.weight.data[module.padding_idx].zero_()
720
+ elif isinstance(module, nn.LayerNorm):
721
+ module.bias.data.zero_()
722
+ module.weight.data.fill_(1.0)
723
+
724
+ def _set_gradient_checkpointing(self, module, value=False):
725
+ if isinstance(module, RoFormerEncoder):
726
+ module.gradient_checkpointing = value
727
+
728
+
729
+ ROFORMER_START_DOCSTRING = r"""
730
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
731
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
732
+ behavior.
733
+
734
+ Parameters:
735
+ config ([`RoFormerConfig`]): Model configuration class with all the parameters of the model.
736
+ Initializing with a config file does not load the weights associated with the model, only the
737
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
738
+ """
739
+
740
+ ROFORMER_INPUTS_DOCSTRING = r"""
741
+ Args:
742
+ input_ids (`torch.LongTensor` of shape `({0})`):
743
+ Indices of input sequence tokens in the vocabulary.
744
+
745
+ Indices can be obtained using [`RoFormerTokenizer`]. See [`PreTrainedTokenizer.encode`] and
746
+ [`PreTrainedTokenizer.__call__`] for details.
747
+
748
+ [What are input IDs?](../glossary#input-ids)
749
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
750
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
751
+
752
+ - 1 for tokens that are **not masked**,
753
+ - 0 for tokens that are **masked**.
754
+
755
+ [What are attention masks?](../glossary#attention-mask)
756
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
757
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
758
+ 1]`:
759
+
760
+ - 0 corresponds to a *sentence A* token,
761
+ - 1 corresponds to a *sentence B* token.
762
+
763
+ [What are token type IDs?](../glossary#token-type-ids)
764
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
765
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
766
+
767
+ - 1 indicates the head is **not masked**,
768
+ - 0 indicates the head is **masked**.
769
+
770
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
771
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
772
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
773
+ model's internal embedding lookup matrix.
774
+ output_attentions (`bool`, *optional*):
775
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
776
+ tensors for more detail.
777
+ output_hidden_states (`bool`, *optional*):
778
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
779
+ more detail.
780
+ return_dict (`bool`, *optional*):
781
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
782
+ """
783
+
784
+
785
+ @add_start_docstrings(
786
+ "The bare RoFormer Model transformer outputting raw hidden-states without any specific head on top.",
787
+ ROFORMER_START_DOCSTRING,
788
+ )
789
+ class RoFormerModel(RoFormerPreTrainedModel):
790
+ """
791
+
792
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
793
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
794
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
795
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
796
+
797
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
798
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
799
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
800
+ """
801
+
802
+ def __init__(self, config):
803
+ super().__init__(config)
804
+ self.config = config
805
+ self.embeddings = RoFormerEmbeddings(config)
806
+
807
+ if config.embedding_size != config.hidden_size:
808
+ self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)
809
+
810
+ self.encoder = RoFormerEncoder(config)
811
+
812
+ # Initialize weights and apply final processing
813
+ self.post_init()
814
+
815
+ def get_input_embeddings(self):
816
+ return self.embeddings.word_embeddings
817
+
818
+ def set_input_embeddings(self, value):
819
+ self.embeddings.word_embeddings = value
820
+
821
+ def _prune_heads(self, heads_to_prune):
822
+ """
823
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
824
+ class PreTrainedModel
825
+ """
826
+ for layer, heads in heads_to_prune.items():
827
+ self.encoder.layer[layer].attention.prune_heads(heads)
828
+
829
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
830
+ @add_code_sample_docstrings(
831
+ processor_class=_TOKENIZER_FOR_DOC,
832
+ checkpoint=_CHECKPOINT_FOR_DOC,
833
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
834
+ config_class=_CONFIG_FOR_DOC,
835
+ )
836
+ def forward(
837
+ self,
838
+ input_ids: Optional[torch.LongTensor] = None,
839
+ attention_mask: Optional[torch.FloatTensor] = None,
840
+ token_type_ids: Optional[torch.LongTensor] = None,
841
+ head_mask: Optional[torch.FloatTensor] = None,
842
+ inputs_embeds: Optional[torch.FloatTensor] = None,
843
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
844
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
845
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
846
+ use_cache: Optional[bool] = None,
847
+ output_attentions: Optional[bool] = None,
848
+ output_hidden_states: Optional[bool] = None,
849
+ return_dict: Optional[bool] = None,
850
+ ) -> Union[BaseModelOutputWithPastAndCrossAttentions, Tuple[torch.Tensor]]:
851
+ r"""
852
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
853
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
854
+ the model is configured as a decoder.
855
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
856
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
857
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
858
+
859
+ - 1 for tokens that are **not masked**,
860
+ - 0 for tokens that are **masked**.
861
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
862
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
863
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
864
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
865
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
866
+ use_cache (`bool`, *optional*):
867
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
868
+ `past_key_values`).
869
+ """
870
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
871
+ output_hidden_states = (
872
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
873
+ )
874
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
875
+
876
+ if self.config.is_decoder:
877
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
878
+ else:
879
+ use_cache = False
880
+
881
+ if input_ids is not None and inputs_embeds is not None:
882
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
883
+ elif input_ids is not None:
884
+ input_shape = input_ids.size()
885
+ elif inputs_embeds is not None:
886
+ input_shape = inputs_embeds.size()[:-1]
887
+ else:
888
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
889
+
890
+ batch_size, seq_length = input_shape
891
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
892
+
893
+ # past_key_values_length
894
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
895
+
896
+ if attention_mask is None:
897
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
898
+ if token_type_ids is None:
899
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
900
+
901
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
902
+ # ourselves in which case we just need to make it broadcastable to all heads.
903
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
904
+
905
+ # If a 2D or 3D attention mask is provided for the cross-attention
906
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
907
+ if self.config.is_decoder and encoder_hidden_states is not None:
908
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
909
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
910
+ if encoder_attention_mask is None:
911
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
912
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
913
+ else:
914
+ encoder_extended_attention_mask = None
915
+
916
+ # Prepare head mask if needed
917
+ # 1.0 in head_mask indicate we keep the head
918
+ # attention_probs has shape bsz x n_heads x N x N
919
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
920
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
921
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
922
+
923
+ embedding_output = self.embeddings(
924
+ input_ids=input_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
925
+ )
926
+ if hasattr(self, "embeddings_project"):
927
+ embedding_output = self.embeddings_project(embedding_output)
928
+
929
+ encoder_outputs = self.encoder(
930
+ embedding_output,
931
+ attention_mask=extended_attention_mask,
932
+ head_mask=head_mask,
933
+ encoder_hidden_states=encoder_hidden_states,
934
+ encoder_attention_mask=encoder_extended_attention_mask,
935
+ past_key_values=past_key_values,
936
+ use_cache=use_cache,
937
+ output_attentions=output_attentions,
938
+ output_hidden_states=output_hidden_states,
939
+ return_dict=return_dict,
940
+ )
941
+ sequence_output = encoder_outputs[0]
942
+
943
+ if not return_dict:
944
+ return (sequence_output,) + encoder_outputs[1:]
945
+
946
+ return BaseModelOutputWithPastAndCrossAttentions(
947
+ last_hidden_state=sequence_output,
948
+ past_key_values=encoder_outputs.past_key_values,
949
+ hidden_states=encoder_outputs.hidden_states,
950
+ attentions=encoder_outputs.attentions,
951
+ cross_attentions=encoder_outputs.cross_attentions,
952
+ )
953
+
954
+
955
+ @add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING)
956
+ class RoFormerForMaskedLM(RoFormerPreTrainedModel):
957
+ _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
958
+
959
+ def __init__(self, config):
960
+ super().__init__(config)
961
+
962
+ if config.is_decoder:
963
+ logger.warning(
964
+ "If you want to use `RoFormerForMaskedLM` make sure `config.is_decoder=False` for "
965
+ "bi-directional self-attention."
966
+ )
967
+
968
+ self.roformer = RoFormerModel(config)
969
+ self.cls = RoFormerOnlyMLMHead(config)
970
+
971
+ # Initialize weights and apply final processing
972
+ self.post_init()
973
+
974
+ def get_output_embeddings(self):
975
+ return self.cls.predictions.decoder
976
+
977
+ def set_output_embeddings(self, new_embeddings):
978
+ self.cls.predictions.decoder = new_embeddings
979
+
980
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
981
+ @add_code_sample_docstrings(
982
+ processor_class=_TOKENIZER_FOR_DOC,
983
+ checkpoint=_CHECKPOINT_FOR_DOC,
984
+ output_type=MaskedLMOutput,
985
+ config_class=_CONFIG_FOR_DOC,
986
+ )
987
+ def forward(
988
+ self,
989
+ input_ids: Optional[torch.LongTensor] = None,
990
+ attention_mask: Optional[torch.FloatTensor] = None,
991
+ token_type_ids: Optional[torch.LongTensor] = None,
992
+ head_mask: Optional[torch.FloatTensor] = None,
993
+ inputs_embeds: Optional[torch.FloatTensor] = None,
994
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
995
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
996
+ labels: Optional[torch.LongTensor] = None,
997
+ output_attentions: Optional[bool] = None,
998
+ output_hidden_states: Optional[bool] = None,
999
+ return_dict: Optional[bool] = None,
1000
+ loss_weight: Optional[torch.FloatTensor] = None,
1001
+ ) -> Union[MaskedLMOutput, Tuple[torch.Tensor]]:
1002
+ r"""
1003
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1004
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1005
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1006
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1007
+ """
1008
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1009
+
1010
+ outputs = self.roformer(
1011
+ input_ids,
1012
+ attention_mask=attention_mask,
1013
+ token_type_ids=token_type_ids,
1014
+ head_mask=head_mask,
1015
+ inputs_embeds=inputs_embeds,
1016
+ encoder_hidden_states=encoder_hidden_states,
1017
+ encoder_attention_mask=encoder_attention_mask,
1018
+ output_attentions=output_attentions,
1019
+ output_hidden_states=output_hidden_states,
1020
+ return_dict=return_dict,
1021
+ )
1022
+
1023
+ sequence_output = outputs[0]
1024
+ prediction_scores = self.cls(sequence_output)
1025
+
1026
+ masked_lm_loss = None
1027
+ if labels is not None:
1028
+ loss_fct = CrossEntropyLoss(reduction="none") # -100 index = padding token
1029
+ labels = labels.view(-1)
1030
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels)
1031
+ loss_weight = torch.ones_like(masked_lm_loss) if loss_weight is None else loss_weight.view(-1)  # default to uniform weights when no loss_weight is passed
1032
+ loss_weight[labels==-100] = 0.0
1033
+ masked_lm_loss = (masked_lm_loss * loss_weight / loss_weight.sum()).sum()
1034
+
1035
+ if not return_dict:
1036
+ output = (prediction_scores,) + outputs[1:]
1037
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1038
+
1039
+ return MaskedLMOutput(
1040
+ loss=masked_lm_loss,
1041
+ logits=prediction_scores,
1042
+ hidden_states=outputs.hidden_states,
1043
+ attentions=outputs.attentions,
1044
+ )
1045
+
1046
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1047
+ input_shape = input_ids.shape
1048
+ effective_batch_size = input_shape[0]
1049
+
1050
+ # add a dummy token
1051
+ assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
1052
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1053
+ dummy_token = torch.full(
1054
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1055
+ )
1056
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1057
+
1058
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1059
+
1060
+
1061
+ @add_start_docstrings(
1062
+ """RoFormer Model with a `language modeling` head on top for CLM fine-tuning.""", ROFORMER_START_DOCSTRING
1063
+ )
1064
+ class RoFormerForCausalLM(RoFormerPreTrainedModel):
1065
+ _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
1066
+
1067
+ def __init__(self, config):
1068
+ super().__init__(config)
1069
+
1070
+ if not config.is_decoder:
1071
+ logger.warning("If you want to use `RoFormerForCausalLM` as a standalone, add `is_decoder=True.`")
1072
+
1073
+ self.roformer = RoFormerModel(config)
1074
+ self.cls = RoFormerOnlyMLMHead(config)
1075
+
1076
+ # Initialize weights and apply final processing
1077
+ self.post_init()
1078
+
1079
+ def get_output_embeddings(self):
1080
+ return self.cls.predictions.decoder
1081
+
1082
+ def set_output_embeddings(self, new_embeddings):
1083
+ self.cls.predictions.decoder = new_embeddings
1084
+
1085
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1086
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1087
+ def forward(
1088
+ self,
1089
+ input_ids: Optional[torch.LongTensor] = None,
1090
+ attention_mask: Optional[torch.FloatTensor] = None,
1091
+ token_type_ids: Optional[torch.LongTensor] = None,
1092
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1093
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1094
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1095
+ head_mask: Optional[torch.FloatTensor] = None,
1096
+