zpn commited on
Commit
d9ac587
·
1 Parent(s): 212555e

Upload RoFormerForMaskedLM

Browse files
Files changed (3) hide show
  1. config.json +34 -0
  2. modeling_nt_roformer.py +1587 -0
  3. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/fsx/home-zanussbaum/nucleotide-transformer/nt_transformer/saved_models/roformer/mamalian_cds_2k_window_step84999",
3
+ "architectures": [
4
+ "RoFormerForMaskedLM"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "auto_map": {
8
+ "AutoModelForMaskedLM": "modeling_nt_roformer.RoFormerForMaskedLM"
9
+ },
10
+ "classifier_dropout": "None",
11
+ "embedding_size": 768,
12
+ "hidden_act": "gelu",
13
+ "hidden_dropout_prob": 0.1,
14
+ "hidden_size": 768,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 3072,
17
+ "layer_norm_eps": 1e-12,
18
+ "max_position_embeddings": 512,
19
+ "model_type": "roformer",
20
+ "num_attention_heads": 12,
21
+ "num_hidden_layers": 12,
22
+ "pad_token_id": 0,
23
+ "position_embedding_type": "absolute",
24
+ "rotary_value": true,
25
+ "summary_activation": "gelu",
26
+ "summary_last_dropout": 0.1,
27
+ "summary_type": "first",
28
+ "summary_use_proj": true,
29
+ "torch_dtype": "float32",
30
+ "transformers_version": "4.25.1",
31
+ "type_vocab_size": 2,
32
+ "use_cache": true,
33
+ "vocab_size": 9
34
+ }
modeling_nt_roformer.py ADDED
@@ -0,0 +1,1587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch RoFormer model."""
16
+
17
+
18
+ import math
19
+ import os
20
+ from typing import Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.modeling_outputs import (
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ QuestionAnsweringModelOutput,
35
+ SequenceClassifierOutput,
36
+ TokenClassifierOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel, SequenceSummary
39
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
40
+ from transformers.utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from transformers import RoFormerConfig
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CHECKPOINT_FOR_DOC = "junnyu/roformer_chinese_base"
53
+ _CONFIG_FOR_DOC = "RoFormerConfig"
54
+ _TOKENIZER_FOR_DOC = "RoFormerTokenizer"
55
+
56
+ ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
57
+ "junnyu/roformer_chinese_small",
58
+ "junnyu/roformer_chinese_base",
59
+ "junnyu/roformer_chinese_char_small",
60
+ "junnyu/roformer_chinese_char_base",
61
+ "junnyu/roformer_small_discriminator",
62
+ "junnyu/roformer_small_generator"
63
+ # See all RoFormer models at https://huggingface.co/models?filter=roformer
64
+ ]
65
+
66
+
67
+ # Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->RoFormer
68
+ class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
69
+ """This module produces sinusoidal positional embeddings of any length."""
70
+
71
+ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
72
+ super().__init__(num_positions, embedding_dim)
73
+ self.weight = self._init_weight(self.weight)
74
+
75
+ @staticmethod
76
+ def _init_weight(out: nn.Parameter) -> nn.Parameter:
77
+ """
78
+ Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
79
+ the 2nd half of the vector. [dim // 2:]
80
+ """
81
+ n_pos, dim = out.shape
82
+ position_enc = np.array(
83
+ [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
84
+ )
85
+ out.requires_grad = False # set early to avoid an error in pytorch-1.8+
86
+ sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
87
+ out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
88
+ out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
89
+ out.detach_()
90
+ return out
91
+
92
+ @torch.no_grad()
93
+ def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
94
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
95
+ bsz, seq_len = input_ids_shape[:2]
96
+ positions = torch.arange(
97
+ past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
98
+ )
99
+ return super().forward(positions)
100
+
101
+
102
+ def load_tf_weights_in_roformer(model, config, tf_checkpoint_path):
103
+ """Load tf checkpoints in a pytorch model."""
104
+ try:
105
+ import re
106
+
107
+ import numpy as np
108
+ import tensorflow as tf
109
+ except ImportError:
110
+ logger.error(
111
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
112
+ "https://www.tensorflow.org/install/ for installation instructions."
113
+ )
114
+ raise
115
+ tf_path = os.path.abspath(tf_checkpoint_path)
116
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
117
+ # Load weights from TF model
118
+ init_vars = tf.train.list_variables(tf_path)
119
+ names = []
120
+ arrays = []
121
+ for name, shape in init_vars:
122
+ logger.info(f"Loading TF weight {name} with shape {shape}")
123
+ array = tf.train.load_variable(tf_path, name)
124
+ names.append(name.replace("bert", "roformer"))
125
+ arrays.append(array)
126
+
127
+ for name, array in zip(names, arrays):
128
+ name = name.split("/")
129
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
130
+ # which are not required for using pretrained model
131
+ if any(
132
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
133
+ for n in name
134
+ ):
135
+ logger.info(f"Skipping {'/'.join(name)}")
136
+ continue
137
+ pointer = model
138
+ for m_name in name:
139
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
140
+ scope_names = re.split(r"_(\d+)", m_name)
141
+ else:
142
+ scope_names = [m_name]
143
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
144
+ pointer = getattr(pointer, "weight")
145
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
146
+ pointer = getattr(pointer, "bias")
147
+ elif scope_names[0] == "output_weights":
148
+ pointer = getattr(pointer, "weight")
149
+ elif scope_names[0] == "squad":
150
+ pointer = getattr(pointer, "classifier")
151
+ else:
152
+ try:
153
+ pointer = getattr(pointer, scope_names[0])
154
+ except AttributeError:
155
+ logger.info(f"Skipping {'/'.join(name)}")
156
+ continue
157
+ if len(scope_names) >= 2:
158
+ num = int(scope_names[1])
159
+ pointer = pointer[num]
160
+ if m_name[-11:] == "_embeddings":
161
+ pointer = getattr(pointer, "weight")
162
+ elif m_name == "kernel":
163
+ array = np.transpose(array)
164
+ try:
165
+ if not pointer.shape == array.shape:
166
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
167
+ except AssertionError as e:
168
+ e.args += (pointer.shape, array.shape)
169
+ raise
170
+ logger.info(f"Initialize PyTorch weight {name}")
171
+ pointer.data = torch.from_numpy(array)
172
+ return model
173
+
174
+
175
+ class RoFormerEmbeddings(nn.Module):
176
+ """Construct the embeddings from word and token_type embeddings."""
177
+
178
+ def __init__(self, config):
179
+ super().__init__()
180
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
181
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
182
+
183
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
184
+ # any TensorFlow checkpoint file
185
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
186
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
187
+
188
+ def forward(self, input_ids=None, token_type_ids=None, inputs_embeds=None):
189
+ if input_ids is not None:
190
+ input_shape = input_ids.size()
191
+ else:
192
+ input_shape = inputs_embeds.size()[:-1]
193
+
194
+ if inputs_embeds is None:
195
+ inputs_embeds = self.word_embeddings(input_ids)
196
+
197
+ if token_type_ids is None:
198
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=inputs_embeds.device)
199
+
200
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
201
+
202
+ embeddings = inputs_embeds + token_type_embeddings
203
+
204
+ embeddings = self.LayerNorm(embeddings)
205
+ embeddings = self.dropout(embeddings)
206
+ return embeddings
207
+
208
+
209
+ class RoFormerSelfAttention(nn.Module):
210
+ def __init__(self, config):
211
+ super().__init__()
212
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
213
+ raise ValueError(
214
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
215
+ f"heads ({config.num_attention_heads})"
216
+ )
217
+
218
+ self.num_attention_heads = config.num_attention_heads
219
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
220
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
221
+
222
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
223
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
224
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
225
+
226
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
227
+
228
+ self.is_decoder = config.is_decoder
229
+ self.rotary_value = config.rotary_value
230
+
231
+ def transpose_for_scores(self, x):
232
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
233
+ x = x.view(*new_x_shape)
234
+ return x.permute(0, 2, 1, 3)
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states,
239
+ attention_mask=None,
240
+ sinusoidal_pos=None,
241
+ head_mask=None,
242
+ encoder_hidden_states=None,
243
+ encoder_attention_mask=None,
244
+ past_key_value=None,
245
+ output_attentions=False,
246
+ ):
247
+ mixed_query_layer = self.query(hidden_states)
248
+ query_layer = self.transpose_for_scores(mixed_query_layer)
249
+ # If this is instantiated as a cross-attention module, the keys
250
+ # and values come from an encoder; the attention mask needs to be
251
+ # such that the encoder's padding tokens are not attended to.
252
+ is_cross_attention = encoder_hidden_states is not None
253
+
254
+ if is_cross_attention and past_key_value is not None:
255
+ # reuse k,v, cross_attentions
256
+ key_layer = past_key_value[0]
257
+ value_layer = past_key_value[1]
258
+ attention_mask = encoder_attention_mask
259
+ elif is_cross_attention:
260
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
261
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
262
+ attention_mask = encoder_attention_mask
263
+ elif past_key_value is not None:
264
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
265
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
266
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
267
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
268
+ else:
269
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
270
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
271
+ if sinusoidal_pos is not None:
272
+ if self.rotary_value:
273
+ query_layer, key_layer, value_layer = self.apply_rotary_position_embeddings(
274
+ sinusoidal_pos, query_layer, key_layer, value_layer
275
+ )
276
+ else:
277
+ query_layer, key_layer = self.apply_rotary_position_embeddings(
278
+ sinusoidal_pos, query_layer, key_layer
279
+ )
280
+ if self.is_decoder:
281
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
282
+ # Further calls to cross_attention layer can then reuse all cross-attention
283
+ # key/value_states (first "if" case)
284
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
285
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
286
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
287
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
288
+ past_key_value = (key_layer, value_layer)
289
+
290
+ # Take the dot product between "query" and "key" to get the raw attention scores.
291
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
292
+
293
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
294
+ if attention_mask is not None:
295
+ # Apply the attention mask is (precomputed for all layers in RoFormerModel forward() function)
296
+ attention_scores = attention_scores + attention_mask
297
+
298
+ # Normalize the attention scores to probabilities.
299
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
300
+
301
+ # This is actually dropping out entire tokens to attend to, which might
302
+ # seem a bit unusual, but is taken from the original Transformer paper.
303
+ attention_probs = self.dropout(attention_probs)
304
+
305
+ # Mask heads if we want to
306
+ if head_mask is not None:
307
+ attention_probs = attention_probs * head_mask
308
+
309
+ context_layer = torch.matmul(attention_probs, value_layer)
310
+
311
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
312
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
313
+ context_layer = context_layer.view(*new_context_layer_shape)
314
+
315
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
316
+
317
+ if self.is_decoder:
318
+ outputs = outputs + (past_key_value,)
319
+ return outputs
320
+
321
+ @staticmethod
322
+ def apply_rotary_position_embeddings(sinusoidal_pos, query_layer, key_layer, value_layer=None):
323
+ # https://kexue.fm/archives/8265
324
+ # sin [batch_size, num_heads, sequence_length, embed_size_per_head//2]
325
+ # cos [batch_size, num_heads, sequence_length, embed_size_per_head//2]
326
+ sin, cos = sinusoidal_pos.chunk(2, dim=-1)
327
+ # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
328
+ sin_pos = torch.stack([sin, sin], dim=-1).reshape_as(sinusoidal_pos)
329
+ # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
330
+ cos_pos = torch.stack([cos, cos], dim=-1).reshape_as(sinusoidal_pos)
331
+ # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
332
+ rotate_half_query_layer = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], dim=-1).reshape_as(
333
+ query_layer
334
+ )
335
+ query_layer = query_layer * cos_pos + rotate_half_query_layer * sin_pos
336
+ # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
337
+ rotate_half_key_layer = torch.stack([-key_layer[..., 1::2], key_layer[..., ::2]], dim=-1).reshape_as(key_layer)
338
+ key_layer = key_layer * cos_pos + rotate_half_key_layer * sin_pos
339
+ if value_layer is not None:
340
+ # rotate_half_value_layer [-v1,v0,-v3,v2......,-vd-1,vd-2]
341
+ rotate_half_value_layer = torch.stack([-value_layer[..., 1::2], value_layer[..., ::2]], dim=-1).reshape_as(
342
+ value_layer
343
+ )
344
+ value_layer = value_layer * cos_pos + rotate_half_value_layer * sin_pos
345
+ return query_layer, key_layer, value_layer
346
+ return query_layer, key_layer
347
+
348
+
349
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->RoFormer
350
+ class RoFormerSelfOutput(nn.Module):
351
+ def __init__(self, config):
352
+ super().__init__()
353
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
354
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
355
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
356
+
357
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
358
+ hidden_states = self.dense(hidden_states)
359
+ hidden_states = self.dropout(hidden_states)
360
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
361
+ return hidden_states
362
+
363
+
364
+ class RoFormerAttention(nn.Module):
365
+ def __init__(self, config):
366
+ super().__init__()
367
+ self.self = RoFormerSelfAttention(config)
368
+ self.output = RoFormerSelfOutput(config)
369
+ self.pruned_heads = set()
370
+
371
+ # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
372
+ def prune_heads(self, heads):
373
+ if len(heads) == 0:
374
+ return
375
+ heads, index = find_pruneable_heads_and_indices(
376
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
377
+ )
378
+
379
+ # Prune linear layers
380
+ self.self.query = prune_linear_layer(self.self.query, index)
381
+ self.self.key = prune_linear_layer(self.self.key, index)
382
+ self.self.value = prune_linear_layer(self.self.value, index)
383
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
384
+
385
+ # Update hyper params and store pruned heads
386
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
387
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
388
+ self.pruned_heads = self.pruned_heads.union(heads)
389
+
390
+ # End Copy
391
+ def forward(
392
+ self,
393
+ hidden_states,
394
+ attention_mask=None,
395
+ sinusoidal_pos=None,
396
+ head_mask=None,
397
+ encoder_hidden_states=None,
398
+ encoder_attention_mask=None,
399
+ past_key_value=None,
400
+ output_attentions=False,
401
+ ):
402
+ self_outputs = self.self(
403
+ hidden_states,
404
+ attention_mask,
405
+ sinusoidal_pos,
406
+ head_mask,
407
+ encoder_hidden_states,
408
+ encoder_attention_mask,
409
+ past_key_value,
410
+ output_attentions,
411
+ )
412
+ attention_output = self.output(self_outputs[0], hidden_states)
413
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
414
+ return outputs
415
+
416
+
417
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->RoFormer
418
+ class RoFormerIntermediate(nn.Module):
419
+ def __init__(self, config):
420
+ super().__init__()
421
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
422
+ if isinstance(config.hidden_act, str):
423
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
424
+ else:
425
+ self.intermediate_act_fn = config.hidden_act
426
+
427
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
428
+ hidden_states = self.dense(hidden_states)
429
+ hidden_states = self.intermediate_act_fn(hidden_states)
430
+ return hidden_states
431
+
432
+
433
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->RoFormer
434
+ class RoFormerOutput(nn.Module):
435
+ def __init__(self, config):
436
+ super().__init__()
437
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
438
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
439
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
440
+
441
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
442
+ hidden_states = self.dense(hidden_states)
443
+ hidden_states = self.dropout(hidden_states)
444
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
445
+ return hidden_states
446
+
447
+
448
+ class RoFormerLayer(nn.Module):
449
+ def __init__(self, config):
450
+ super().__init__()
451
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
452
+ self.seq_len_dim = 1
453
+ self.attention = RoFormerAttention(config)
454
+ self.is_decoder = config.is_decoder
455
+ self.add_cross_attention = config.add_cross_attention
456
+ if self.add_cross_attention:
457
+ if not self.is_decoder:
458
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
459
+ self.crossattention = RoFormerAttention(config)
460
+ self.intermediate = RoFormerIntermediate(config)
461
+ self.output = RoFormerOutput(config)
462
+
463
+ def forward(
464
+ self,
465
+ hidden_states,
466
+ attention_mask=None,
467
+ sinusoidal_pos=None,
468
+ head_mask=None,
469
+ encoder_hidden_states=None,
470
+ encoder_attention_mask=None,
471
+ past_key_value=None,
472
+ output_attentions=False,
473
+ ):
474
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
475
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
476
+ self_attention_outputs = self.attention(
477
+ hidden_states,
478
+ attention_mask,
479
+ sinusoidal_pos,
480
+ head_mask,
481
+ output_attentions=output_attentions,
482
+ past_key_value=self_attn_past_key_value,
483
+ )
484
+ attention_output = self_attention_outputs[0]
485
+
486
+ # if decoder, the last output is tuple of self-attn cache
487
+ if self.is_decoder:
488
+ outputs = self_attention_outputs[1:-1]
489
+ present_key_value = self_attention_outputs[-1]
490
+ else:
491
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
492
+
493
+ cross_attn_present_key_value = None
494
+ if self.is_decoder and encoder_hidden_states is not None:
495
+ if not hasattr(self, "crossattention"):
496
+ raise ValueError(
497
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention "
498
+ "layers by setting `config.add_cross_attention=True`"
499
+ )
500
+
501
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
502
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
503
+ cross_attention_outputs = self.crossattention(
504
+ attention_output,
505
+ attention_mask,
506
+ sinusoidal_pos,
507
+ head_mask,
508
+ encoder_hidden_states,
509
+ encoder_attention_mask,
510
+ cross_attn_past_key_value,
511
+ output_attentions,
512
+ )
513
+ attention_output = cross_attention_outputs[0]
514
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
515
+
516
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
517
+ cross_attn_present_key_value = cross_attention_outputs[-1]
518
+ present_key_value = present_key_value + cross_attn_present_key_value
519
+
520
+ layer_output = apply_chunking_to_forward(
521
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
522
+ )
523
+ outputs = (layer_output,) + outputs
524
+
525
+ # if decoder, return the attn key/values as the last output
526
+ if self.is_decoder:
527
+ outputs = outputs + (present_key_value,)
528
+
529
+ return outputs
530
+
531
+ def feed_forward_chunk(self, attention_output):
532
+ intermediate_output = self.intermediate(attention_output)
533
+ layer_output = self.output(intermediate_output, attention_output)
534
+ return layer_output
535
+
536
+
537
+ class RoFormerEncoder(nn.Module):
538
+ def __init__(self, config):
539
+ super().__init__()
540
+ self.config = config
541
+ self.embed_positions = RoFormerSinusoidalPositionalEmbedding(
542
+ config.max_position_embeddings, config.hidden_size // config.num_attention_heads
543
+ )
544
+ self.layer = nn.ModuleList([RoFormerLayer(config) for _ in range(config.num_hidden_layers)])
545
+ self.gradient_checkpointing = False
546
+
547
+ def forward(
548
+ self,
549
+ hidden_states,
550
+ attention_mask=None,
551
+ head_mask=None,
552
+ encoder_hidden_states=None,
553
+ encoder_attention_mask=None,
554
+ past_key_values=None,
555
+ use_cache=None,
556
+ output_attentions=False,
557
+ output_hidden_states=False,
558
+ return_dict=True,
559
+ ):
560
+ all_hidden_states = () if output_hidden_states else None
561
+ all_self_attentions = () if output_attentions else None
562
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
563
+
564
+ # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head]
565
+ sinusoidal_pos = self.embed_positions(hidden_states.shape[:-1])[None, None, :, :]
566
+
567
+ next_decoder_cache = () if use_cache else None
568
+ for i, layer_module in enumerate(self.layer):
569
+ if output_hidden_states:
570
+ all_hidden_states = all_hidden_states + (hidden_states,)
571
+
572
+ layer_head_mask = head_mask[i] if head_mask is not None else None
573
+ past_key_value = past_key_values[i] if past_key_values is not None else None
574
+
575
+ if self.gradient_checkpointing and self.training:
576
+
577
+ if use_cache:
578
+ logger.warning(
579
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
580
+ )
581
+ use_cache = False
582
+
583
+ def create_custom_forward(module):
584
+ def custom_forward(*inputs):
585
+ return module(*inputs, past_key_value, output_attentions)
586
+
587
+ return custom_forward
588
+
589
+ layer_outputs = torch.utils.checkpoint.checkpoint(
590
+ create_custom_forward(layer_module),
591
+ hidden_states,
592
+ attention_mask,
593
+ sinusoidal_pos,
594
+ layer_head_mask,
595
+ encoder_hidden_states,
596
+ encoder_attention_mask,
597
+ )
598
+ else:
599
+ layer_outputs = layer_module(
600
+ hidden_states,
601
+ attention_mask,
602
+ sinusoidal_pos,
603
+ layer_head_mask,
604
+ encoder_hidden_states,
605
+ encoder_attention_mask,
606
+ past_key_value,
607
+ output_attentions,
608
+ )
609
+
610
+ hidden_states = layer_outputs[0]
611
+ if use_cache:
612
+ next_decoder_cache += (layer_outputs[-1],)
613
+ if output_attentions:
614
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
615
+ if self.config.add_cross_attention:
616
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
617
+
618
+ if output_hidden_states:
619
+ all_hidden_states = all_hidden_states + (hidden_states,)
620
+
621
+ if not return_dict:
622
+ return tuple(
623
+ v
624
+ for v in [
625
+ hidden_states,
626
+ next_decoder_cache,
627
+ all_hidden_states,
628
+ all_self_attentions,
629
+ all_cross_attentions,
630
+ ]
631
+ if v is not None
632
+ )
633
+ return BaseModelOutputWithPastAndCrossAttentions(
634
+ last_hidden_state=hidden_states,
635
+ past_key_values=next_decoder_cache,
636
+ hidden_states=all_hidden_states,
637
+ attentions=all_self_attentions,
638
+ cross_attentions=all_cross_attentions,
639
+ )
640
+
641
+
642
+ class RoFormerPredictionHeadTransform(nn.Module):
643
+ def __init__(self, config):
644
+ super().__init__()
645
+ self.dense = nn.Linear(config.hidden_size, config.embedding_size)
646
+ if isinstance(config.hidden_act, str):
647
+ self.transform_act_fn = ACT2FN[config.hidden_act]
648
+ else:
649
+ self.transform_act_fn = config.hidden_act
650
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
651
+
652
+ def forward(self, hidden_states):
653
+ hidden_states = self.dense(hidden_states)
654
+ hidden_states = self.transform_act_fn(hidden_states)
655
+ hidden_states = self.LayerNorm(hidden_states)
656
+ return hidden_states
657
+
658
+
659
+ class RoFormerLMPredictionHead(nn.Module):
660
+ def __init__(self, config):
661
+ super().__init__()
662
+ self.transform = RoFormerPredictionHeadTransform(config)
663
+
664
+ # The output weights are the same as the input embeddings, but there is
665
+ # an output-only bias for each token.
666
+ self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=False)
667
+
668
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
669
+
670
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
671
+ self.decoder.bias = self.bias
672
+
673
+ def forward(self, hidden_states):
674
+ hidden_states = self.transform(hidden_states)
675
+ hidden_states = self.decoder(hidden_states)
676
+ return hidden_states
677
+
678
+
679
+ # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->RoFormer
680
+ class RoFormerOnlyMLMHead(nn.Module):
681
+ def __init__(self, config):
682
+ super().__init__()
683
+ self.predictions = RoFormerLMPredictionHead(config)
684
+
685
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
686
+ prediction_scores = self.predictions(sequence_output)
687
+ return prediction_scores
688
+
689
+
690
+ class RoFormerPreTrainedModel(PreTrainedModel):
691
+ """
692
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
693
+ models.
694
+ """
695
+
696
+ config_class = RoFormerConfig
697
+ load_tf_weights = load_tf_weights_in_roformer
698
+ base_model_prefix = "roformer"
699
+ supports_gradient_checkpointing = True
700
+ _keys_to_ignore_on_load_missing = []
701
+ _keys_to_ignore_on_load_unexpected = [
702
+ r"roformer.embeddings_project.weight",
703
+ r"roformer.embeddings_project.bias",
704
+ ]
705
+
706
+ def _init_weights(self, module):
707
+ """Initialize the weights"""
708
+ if isinstance(module, nn.Linear):
709
+ # Slightly different from the TF version which uses truncated_normal for initialization
710
+ # cf https://github.com/pytorch/pytorch/pull/5617
711
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
712
+ if module.bias is not None:
713
+ module.bias.data.zero_()
714
+ elif isinstance(module, RoFormerSinusoidalPositionalEmbedding):
715
+ pass
716
+ elif isinstance(module, nn.Embedding):
717
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
718
+ if module.padding_idx is not None:
719
+ module.weight.data[module.padding_idx].zero_()
720
+ elif isinstance(module, nn.LayerNorm):
721
+ module.bias.data.zero_()
722
+ module.weight.data.fill_(1.0)
723
+
724
+ def _set_gradient_checkpointing(self, module, value=False):
725
+ if isinstance(module, RoFormerEncoder):
726
+ module.gradient_checkpointing = value
727
+
728
+
729
+ ROFORMER_START_DOCSTRING = r"""
730
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
731
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
732
+ behavior.
733
+
734
+ Parameters:
735
+ config ([`RoFormerConfig`]): Model configuration class with all the parameters of the model.
736
+ Initializing with a config file does not load the weights associated with the model, only the
737
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
738
+ """
739
+
740
+ ROFORMER_INPUTS_DOCSTRING = r"""
741
+ Args:
742
+ input_ids (`torch.LongTensor` of shape `({0})`):
743
+ Indices of input sequence tokens in the vocabulary.
744
+
745
+ Indices can be obtained using [`RoFormerTokenizer`]. See [`PreTrainedTokenizer.encode`] and
746
+ [`PreTrainedTokenizer.__call__`] for details.
747
+
748
+ [What are input IDs?](../glossary#input-ids)
749
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
750
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
751
+
752
+ - 1 for tokens that are **not masked**,
753
+ - 0 for tokens that are **masked**.
754
+
755
+ [What are attention masks?](../glossary#attention-mask)
756
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
757
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
758
+ 1]`:
759
+
760
+ - 0 corresponds to a *sentence A* token,
761
+ - 1 corresponds to a *sentence B* token.
762
+
763
+ [What are token type IDs?](../glossary#token-type-ids)
764
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
765
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
766
+
767
+ - 1 indicates the head is **not masked**,
768
+ - 0 indicates the head is **masked**.
769
+
770
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
771
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
772
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
773
+ model's internal embedding lookup matrix.
774
+ output_attentions (`bool`, *optional*):
775
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
776
+ tensors for more detail.
777
+ output_hidden_states (`bool`, *optional*):
778
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
779
+ more detail.
780
+ return_dict (`bool`, *optional*):
781
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
782
+ """
783
+
784
+
785
+ @add_start_docstrings(
786
+ "The bare RoFormer Model transformer outputting raw hidden-states without any specific head on top.",
787
+ ROFORMER_START_DOCSTRING,
788
+ )
789
+ class RoFormerModel(RoFormerPreTrainedModel):
790
+ """
791
+
792
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
793
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
794
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
795
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
796
+
797
+ To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
798
+ to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
799
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
800
+ """
801
+
802
+ def __init__(self, config):
803
+ super().__init__(config)
804
+ self.config = config
805
+ self.embeddings = RoFormerEmbeddings(config)
806
+
807
+ if config.embedding_size != config.hidden_size:
808
+ self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)
809
+
810
+ self.encoder = RoFormerEncoder(config)
811
+
812
+ # Initialize weights and apply final processing
813
+ self.post_init()
814
+
815
+ def get_input_embeddings(self):
816
+ return self.embeddings.word_embeddings
817
+
818
+ def set_input_embeddings(self, value):
819
+ self.embeddings.word_embeddings = value
820
+
821
+ def _prune_heads(self, heads_to_prune):
822
+ """
823
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
824
+ class PreTrainedModel
825
+ """
826
+ for layer, heads in heads_to_prune.items():
827
+ self.encoder.layer[layer].attention.prune_heads(heads)
828
+
829
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
830
+ @add_code_sample_docstrings(
831
+ processor_class=_TOKENIZER_FOR_DOC,
832
+ checkpoint=_CHECKPOINT_FOR_DOC,
833
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
834
+ config_class=_CONFIG_FOR_DOC,
835
+ )
836
+ def forward(
837
+ self,
838
+ input_ids: Optional[torch.LongTensor] = None,
839
+ attention_mask: Optional[torch.FloatTensor] = None,
840
+ token_type_ids: Optional[torch.LongTensor] = None,
841
+ head_mask: Optional[torch.FloatTensor] = None,
842
+ inputs_embeds: Optional[torch.FloatTensor] = None,
843
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
844
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
845
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
846
+ use_cache: Optional[bool] = None,
847
+ output_attentions: Optional[bool] = None,
848
+ output_hidden_states: Optional[bool] = None,
849
+ return_dict: Optional[bool] = None,
850
+ ) -> Union[BaseModelOutputWithPastAndCrossAttentions, Tuple[torch.Tensor]]:
851
+ r"""
852
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
853
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
854
+ the model is configured as a decoder.
855
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
856
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
857
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
858
+
859
+ - 1 for tokens that are **not masked**,
860
+ - 0 for tokens that are **masked**.
861
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
862
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
863
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
864
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
865
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
866
+ use_cache (`bool`, *optional*):
867
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
868
+ `past_key_values`).
869
+ """
870
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
871
+ output_hidden_states = (
872
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
873
+ )
874
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
875
+
876
+ if self.config.is_decoder:
877
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
878
+ else:
879
+ use_cache = False
880
+
881
+ if input_ids is not None and inputs_embeds is not None:
882
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
883
+ elif input_ids is not None:
884
+ input_shape = input_ids.size()
885
+ elif inputs_embeds is not None:
886
+ input_shape = inputs_embeds.size()[:-1]
887
+ else:
888
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
889
+
890
+ batch_size, seq_length = input_shape
891
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
892
+
893
+ # past_key_values_length
894
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
895
+
896
+ if attention_mask is None:
897
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
898
+ if token_type_ids is None:
899
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
900
+
901
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
902
+ # ourselves in which case we just need to make it broadcastable to all heads.
903
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
904
+
905
+ # If a 2D or 3D attention mask is provided for the cross-attention
906
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
907
+ if self.config.is_decoder and encoder_hidden_states is not None:
908
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
909
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
910
+ if encoder_attention_mask is None:
911
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
912
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
913
+ else:
914
+ encoder_extended_attention_mask = None
915
+
916
+ # Prepare head mask if needed
917
+ # 1.0 in head_mask indicate we keep the head
918
+ # attention_probs has shape bsz x n_heads x N x N
919
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
920
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
921
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
922
+
923
+ embedding_output = self.embeddings(
924
+ input_ids=input_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
925
+ )
926
+ if hasattr(self, "embeddings_project"):
927
+ embedding_output = self.embeddings_project(embedding_output)
928
+
929
+ encoder_outputs = self.encoder(
930
+ embedding_output,
931
+ attention_mask=extended_attention_mask,
932
+ head_mask=head_mask,
933
+ encoder_hidden_states=encoder_hidden_states,
934
+ encoder_attention_mask=encoder_extended_attention_mask,
935
+ past_key_values=past_key_values,
936
+ use_cache=use_cache,
937
+ output_attentions=output_attentions,
938
+ output_hidden_states=output_hidden_states,
939
+ return_dict=return_dict,
940
+ )
941
+ sequence_output = encoder_outputs[0]
942
+
943
+ if not return_dict:
944
+ return (sequence_output,) + encoder_outputs[1:]
945
+
946
+ return BaseModelOutputWithPastAndCrossAttentions(
947
+ last_hidden_state=sequence_output,
948
+ past_key_values=encoder_outputs.past_key_values,
949
+ hidden_states=encoder_outputs.hidden_states,
950
+ attentions=encoder_outputs.attentions,
951
+ cross_attentions=encoder_outputs.cross_attentions,
952
+ )
953
+
954
+
955
+ @add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING)
956
+ class RoFormerForMaskedLM(RoFormerPreTrainedModel):
957
+ _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
958
+
959
+ def __init__(self, config):
960
+ super().__init__(config)
961
+
962
+ if config.is_decoder:
963
+ logger.warning(
964
+ "If you want to use `RoFormerForMaskedLM` make sure `config.is_decoder=False` for "
965
+ "bi-directional self-attention."
966
+ )
967
+
968
+ self.roformer = RoFormerModel(config)
969
+ self.cls = RoFormerOnlyMLMHead(config)
970
+
971
+ # Initialize weights and apply final processing
972
+ self.post_init()
973
+
974
+ def get_output_embeddings(self):
975
+ return self.cls.predictions.decoder
976
+
977
+ def set_output_embeddings(self, new_embeddings):
978
+ self.cls.predictions.decoder = new_embeddings
979
+
980
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
981
+ @add_code_sample_docstrings(
982
+ processor_class=_TOKENIZER_FOR_DOC,
983
+ checkpoint=_CHECKPOINT_FOR_DOC,
984
+ output_type=MaskedLMOutput,
985
+ config_class=_CONFIG_FOR_DOC,
986
+ )
987
+ def forward(
988
+ self,
989
+ input_ids: Optional[torch.LongTensor] = None,
990
+ attention_mask: Optional[torch.FloatTensor] = None,
991
+ token_type_ids: Optional[torch.LongTensor] = None,
992
+ head_mask: Optional[torch.FloatTensor] = None,
993
+ inputs_embeds: Optional[torch.FloatTensor] = None,
994
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
995
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
996
+ labels: Optional[torch.LongTensor] = None,
997
+ output_attentions: Optional[bool] = None,
998
+ output_hidden_states: Optional[bool] = None,
999
+ return_dict: Optional[bool] = None,
1000
+ loss_weight: Optional[torch.FloatTensor] = None,
1001
+ ) -> Union[MaskedLMOutput, Tuple[torch.Tensor]]:
1002
+ r"""
1003
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1004
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1005
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1006
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1007
+ """
1008
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1009
+
1010
+ outputs = self.roformer(
1011
+ input_ids,
1012
+ attention_mask=attention_mask,
1013
+ token_type_ids=token_type_ids,
1014
+ head_mask=head_mask,
1015
+ inputs_embeds=inputs_embeds,
1016
+ encoder_hidden_states=encoder_hidden_states,
1017
+ encoder_attention_mask=encoder_attention_mask,
1018
+ output_attentions=output_attentions,
1019
+ output_hidden_states=output_hidden_states,
1020
+ return_dict=return_dict,
1021
+ )
1022
+
1023
+ sequence_output = outputs[0]
1024
+ prediction_scores = self.cls(sequence_output)
1025
+
1026
+ masked_lm_loss = None
1027
+ if labels is not None:
1028
+ loss_fct = CrossEntropyLoss(reduction="none") # -100 index = padding token
1029
+ labels = labels.view(-1)
1030
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels)
1031
+ loss_weight = loss_weight.view(-1)
1032
+ loss_weight[labels==-100] = 0.0
1033
+ masked_lm_loss = (masked_lm_loss * loss_weight / loss_weight.sum()).sum()
1034
+
1035
+ if not return_dict:
1036
+ output = (prediction_scores,) + outputs[1:]
1037
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1038
+
1039
+ return MaskedLMOutput(
1040
+ loss=masked_lm_loss,
1041
+ logits=prediction_scores,
1042
+ hidden_states=outputs.hidden_states,
1043
+ attentions=outputs.attentions,
1044
+ )
1045
+
1046
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1047
+ input_shape = input_ids.shape
1048
+ effective_batch_size = input_shape[0]
1049
+
1050
+ # add a dummy token
1051
+ assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
1052
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1053
+ dummy_token = torch.full(
1054
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1055
+ )
1056
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1057
+
1058
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1059
+
1060
+
1061
+ @add_start_docstrings(
1062
+ """RoFormer Model with a `language modeling` head on top for CLM fine-tuning.""", ROFORMER_START_DOCSTRING
1063
+ )
1064
+ class RoFormerForCausalLM(RoFormerPreTrainedModel):
1065
+ _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
1066
+
1067
+ def __init__(self, config):
1068
+ super().__init__(config)
1069
+
1070
+ if not config.is_decoder:
1071
+ logger.warning("If you want to use `RoFormerForCausalLM` as a standalone, add `is_decoder=True.`")
1072
+
1073
+ self.roformer = RoFormerModel(config)
1074
+ self.cls = RoFormerOnlyMLMHead(config)
1075
+
1076
+ # Initialize weights and apply final processing
1077
+ self.post_init()
1078
+
1079
+ def get_output_embeddings(self):
1080
+ return self.cls.predictions.decoder
1081
+
1082
+ def set_output_embeddings(self, new_embeddings):
1083
+ self.cls.predictions.decoder = new_embeddings
1084
+
1085
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1086
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1087
+ def forward(
1088
+ self,
1089
+ input_ids: Optional[torch.LongTensor] = None,
1090
+ attention_mask: Optional[torch.FloatTensor] = None,
1091
+ token_type_ids: Optional[torch.LongTensor] = None,
1092
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1093
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1094
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1095
+ head_mask: Optional[torch.FloatTensor] = None,
1096
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1097
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1098
+ labels: Optional[torch.LongTensor] = None,
1099
+ use_cache: Optional[bool] = None,
1100
+ output_attentions: Optional[bool] = None,
1101
+ output_hidden_states: Optional[bool] = None,
1102
+ return_dict: Optional[bool] = None,
1103
+ ) -> Union[CausalLMOutputWithCrossAttentions, Tuple[torch.Tensor]]:
1104
+ r"""
1105
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1106
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1107
+ the model is configured as a decoder.
1108
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1109
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1110
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1111
+
1112
+ - 1 for tokens that are **not masked**,
1113
+ - 0 for tokens that are **masked**.
1114
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1115
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1116
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1117
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1118
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1119
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1120
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1121
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
1122
+ ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`.
1123
+ use_cache (`bool`, *optional*):
1124
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1125
+ `past_key_values`).
1126
+
1127
+ Returns:
1128
+
1129
+ Example:
1130
+
1131
+ ```python
1132
+ >>> from transformers import RoFormerTokenizer, RoFormerForCausalLM, RoFormerConfig
1133
+ >>> import torch
1134
+
1135
+ >>> tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
1136
+ >>> config = RoFormerConfig.from_pretrained("junnyu/roformer_chinese_base")
1137
+ >>> config.is_decoder = True
1138
+ >>> model = RoFormerForCausalLM.from_pretrained("junnyu/roformer_chinese_base", config=config)
1139
+
1140
+ >>> inputs = tokenizer("今天天气非常好。", return_tensors="pt")
1141
+ >>> outputs = model(**inputs)
1142
+
1143
+ >>> prediction_logits = outputs.logits
1144
+ ```"""
1145
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1146
+
1147
+ outputs = self.roformer(
1148
+ input_ids,
1149
+ attention_mask=attention_mask,
1150
+ token_type_ids=token_type_ids,
1151
+ head_mask=head_mask,
1152
+ inputs_embeds=inputs_embeds,
1153
+ encoder_hidden_states=encoder_hidden_states,
1154
+ encoder_attention_mask=encoder_attention_mask,
1155
+ past_key_values=past_key_values,
1156
+ use_cache=use_cache,
1157
+ output_attentions=output_attentions,
1158
+ output_hidden_states=output_hidden_states,
1159
+ return_dict=return_dict,
1160
+ )
1161
+
1162
+ sequence_output = outputs[0]
1163
+ prediction_scores = self.cls(sequence_output)
1164
+
1165
+ lm_loss = None
1166
+ if labels is not None:
1167
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1168
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1169
+ labels = labels[:, 1:].contiguous()
1170
+ loss_fct = CrossEntropyLoss()
1171
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1172
+
1173
+ if not return_dict:
1174
+ output = (prediction_scores,) + outputs[1:]
1175
+ return ((lm_loss,) + output) if lm_loss is not None else output
1176
+
1177
+ return CausalLMOutputWithCrossAttentions(
1178
+ loss=lm_loss,
1179
+ logits=prediction_scores,
1180
+ past_key_values=outputs.past_key_values,
1181
+ hidden_states=outputs.hidden_states,
1182
+ attentions=outputs.attentions,
1183
+ cross_attentions=outputs.cross_attentions,
1184
+ )
1185
+
1186
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
1187
+ input_shape = input_ids.shape
1188
+
1189
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1190
+ if attention_mask is None:
1191
+ attention_mask = input_ids.new_ones(input_shape)
1192
+
1193
+ # cut decoder_input_ids if past is used
1194
+ if past is not None:
1195
+ input_ids = input_ids[:, -1:]
1196
+
1197
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
1198
+
1199
+ def _reorder_cache(self, past, beam_idx):
1200
+ reordered_past = ()
1201
+ for layer_past in past:
1202
+ reordered_past += (
1203
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
1204
+ )
1205
+ return reordered_past
1206
+
1207
+
1208
+ class RoFormerClassificationHead(nn.Module):
1209
+ """Head for sentence-level classification tasks."""
1210
+
1211
+ def __init__(self, config):
1212
+ super().__init__()
1213
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1214
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1215
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
1216
+
1217
+ self.config = config
1218
+
1219
+ def forward(self, features, **kwargs):
1220
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
1221
+ x = self.dropout(x)
1222
+ x = self.dense(x)
1223
+ x = ACT2FN[self.config.hidden_act](x)
1224
+ x = self.dropout(x)
1225
+ x = self.out_proj(x)
1226
+ return x
1227
+
1228
+
1229
+ @add_start_docstrings(
1230
+ """
1231
+ RoFormer Model transformer with a sequence classification/regression head on top (a linear layer on top of the
1232
+ pooled output) e.g. for GLUE tasks.
1233
+ """,
1234
+ ROFORMER_START_DOCSTRING,
1235
+ )
1236
+ class RoFormerForSequenceClassification(RoFormerPreTrainedModel):
1237
+ def __init__(self, config):
1238
+ super().__init__(config)
1239
+ self.num_labels = config.num_labels
1240
+ self.roformer = RoFormerModel(config)
1241
+ self.classifier = RoFormerClassificationHead(config)
1242
+
1243
+ # Initialize weights and apply final processing
1244
+ self.post_init()
1245
+
1246
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1247
+ @add_code_sample_docstrings(
1248
+ processor_class=_TOKENIZER_FOR_DOC,
1249
+ checkpoint=_CHECKPOINT_FOR_DOC,
1250
+ output_type=SequenceClassifierOutput,
1251
+ config_class=_CONFIG_FOR_DOC,
1252
+ )
1253
+ def forward(
1254
+ self,
1255
+ input_ids: Optional[torch.LongTensor] = None,
1256
+ attention_mask: Optional[torch.FloatTensor] = None,
1257
+ token_type_ids: Optional[torch.LongTensor] = None,
1258
+ head_mask: Optional[torch.FloatTensor] = None,
1259
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1260
+ labels: Optional[torch.LongTensor] = None,
1261
+ output_attentions: Optional[bool] = None,
1262
+ output_hidden_states: Optional[bool] = None,
1263
+ return_dict: Optional[bool] = None,
1264
+ ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor]]:
1265
+ r"""
1266
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1267
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1268
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1269
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1270
+ """
1271
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1272
+
1273
+ outputs = self.roformer(
1274
+ input_ids,
1275
+ attention_mask=attention_mask,
1276
+ token_type_ids=token_type_ids,
1277
+ head_mask=head_mask,
1278
+ inputs_embeds=inputs_embeds,
1279
+ output_attentions=output_attentions,
1280
+ output_hidden_states=output_hidden_states,
1281
+ return_dict=return_dict,
1282
+ )
1283
+
1284
+ sequence_output = outputs[0]
1285
+ logits = self.classifier(sequence_output)
1286
+
1287
+ loss = None
1288
+ if labels is not None:
1289
+ if self.config.problem_type is None:
1290
+ if self.num_labels == 1:
1291
+ self.config.problem_type = "regression"
1292
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1293
+ self.config.problem_type = "single_label_classification"
1294
+ else:
1295
+ self.config.problem_type = "multi_label_classification"
1296
+
1297
+ if self.config.problem_type == "regression":
1298
+ loss_fct = MSELoss()
1299
+ if self.num_labels == 1:
1300
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1301
+ else:
1302
+ loss = loss_fct(logits, labels)
1303
+ elif self.config.problem_type == "single_label_classification":
1304
+ loss_fct = CrossEntropyLoss()
1305
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1306
+ elif self.config.problem_type == "multi_label_classification":
1307
+ loss_fct = BCEWithLogitsLoss()
1308
+ loss = loss_fct(logits, labels)
1309
+ if not return_dict:
1310
+ output = (logits,) + outputs[1:]
1311
+ return ((loss,) + output) if loss is not None else output
1312
+
1313
+ return SequenceClassifierOutput(
1314
+ loss=loss,
1315
+ logits=logits,
1316
+ hidden_states=outputs.hidden_states,
1317
+ attentions=outputs.attentions,
1318
+ )
1319
+
1320
+
1321
+ @add_start_docstrings(
1322
+ """
1323
+ RoFormer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1324
+ softmax) e.g. for RocStories/SWAG tasks.
1325
+ """,
1326
+ ROFORMER_START_DOCSTRING,
1327
+ )
1328
+ class RoFormerForMultipleChoice(RoFormerPreTrainedModel):
1329
+ def __init__(self, config):
1330
+ super().__init__(config)
1331
+
1332
+ self.roformer = RoFormerModel(config)
1333
+ self.sequence_summary = SequenceSummary(config)
1334
+ self.classifier = nn.Linear(config.hidden_size, 1)
1335
+
1336
+ # Initialize weights and apply final processing
1337
+ self.post_init()
1338
+
1339
+ @add_start_docstrings_to_model_forward(
1340
+ ROFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
1341
+ )
1342
+ @add_code_sample_docstrings(
1343
+ processor_class=_TOKENIZER_FOR_DOC,
1344
+ checkpoint=_CHECKPOINT_FOR_DOC,
1345
+ output_type=MultipleChoiceModelOutput,
1346
+ config_class=_CONFIG_FOR_DOC,
1347
+ )
1348
+ def forward(
1349
+ self,
1350
+ input_ids: Optional[torch.LongTensor] = None,
1351
+ attention_mask: Optional[torch.FloatTensor] = None,
1352
+ token_type_ids: Optional[torch.LongTensor] = None,
1353
+ head_mask: Optional[torch.FloatTensor] = None,
1354
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1355
+ labels: Optional[torch.LongTensor] = None,
1356
+ output_attentions: Optional[bool] = None,
1357
+ output_hidden_states: Optional[bool] = None,
1358
+ return_dict: Optional[bool] = None,
1359
+ ) -> Union[MultipleChoiceModelOutput, Tuple[torch.Tensor]]:
1360
+ r"""
1361
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1362
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1363
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1364
+ `input_ids` above)
1365
+ """
1366
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1367
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1368
+
1369
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1370
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1371
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1372
+
1373
+ inputs_embeds = (
1374
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1375
+ if inputs_embeds is not None
1376
+ else None
1377
+ )
1378
+
1379
+ outputs = self.roformer(
1380
+ input_ids,
1381
+ attention_mask=attention_mask,
1382
+ token_type_ids=token_type_ids,
1383
+ head_mask=head_mask,
1384
+ inputs_embeds=inputs_embeds,
1385
+ output_attentions=output_attentions,
1386
+ output_hidden_states=output_hidden_states,
1387
+ return_dict=return_dict,
1388
+ )
1389
+
1390
+ sequence_output = outputs[0]
1391
+
1392
+ pooled_output = self.sequence_summary(sequence_output)
1393
+ logits = self.classifier(pooled_output)
1394
+ reshaped_logits = logits.view(-1, num_choices)
1395
+
1396
+ loss = None
1397
+ if labels is not None:
1398
+ loss_fct = CrossEntropyLoss()
1399
+ loss = loss_fct(reshaped_logits, labels)
1400
+
1401
+ if not return_dict:
1402
+ output = (reshaped_logits,) + outputs[1:]
1403
+ return ((loss,) + output) if loss is not None else output
1404
+
1405
+ return MultipleChoiceModelOutput(
1406
+ loss=loss,
1407
+ logits=reshaped_logits,
1408
+ hidden_states=outputs.hidden_states,
1409
+ attentions=outputs.attentions,
1410
+ )
1411
+
1412
+
1413
+ @add_start_docstrings(
1414
+ """
1415
+ RoFormer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1416
+ Named-Entity-Recognition (NER) tasks.
1417
+ """,
1418
+ ROFORMER_START_DOCSTRING,
1419
+ )
1420
+ class RoFormerForTokenClassification(RoFormerPreTrainedModel):
1421
+ def __init__(self, config):
1422
+ super().__init__(config)
1423
+ self.num_labels = config.num_labels
1424
+
1425
+ self.roformer = RoFormerModel(config)
1426
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1427
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1428
+
1429
+ # Initialize weights and apply final processing
1430
+ self.post_init()
1431
+
1432
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1433
+ @add_code_sample_docstrings(
1434
+ processor_class=_TOKENIZER_FOR_DOC,
1435
+ checkpoint=_CHECKPOINT_FOR_DOC,
1436
+ output_type=TokenClassifierOutput,
1437
+ config_class=_CONFIG_FOR_DOC,
1438
+ )
1439
+ def forward(
1440
+ self,
1441
+ input_ids: Optional[torch.LongTensor] = None,
1442
+ attention_mask: Optional[torch.FloatTensor] = None,
1443
+ token_type_ids: Optional[torch.LongTensor] = None,
1444
+ head_mask: Optional[torch.FloatTensor] = None,
1445
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1446
+ labels: Optional[torch.LongTensor] = None,
1447
+ output_attentions: Optional[bool] = None,
1448
+ output_hidden_states: Optional[bool] = None,
1449
+ return_dict: Optional[bool] = None,
1450
+ ) -> Union[TokenClassifierOutput, Tuple[torch.Tensor]]:
1451
+ r"""
1452
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1453
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1454
+ """
1455
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1456
+
1457
+ outputs = self.roformer(
1458
+ input_ids,
1459
+ attention_mask=attention_mask,
1460
+ token_type_ids=token_type_ids,
1461
+ head_mask=head_mask,
1462
+ inputs_embeds=inputs_embeds,
1463
+ output_attentions=output_attentions,
1464
+ output_hidden_states=output_hidden_states,
1465
+ return_dict=return_dict,
1466
+ )
1467
+
1468
+ sequence_output = outputs[0]
1469
+
1470
+ sequence_output = self.dropout(sequence_output)
1471
+ logits = self.classifier(sequence_output)
1472
+
1473
+ loss = None
1474
+ if labels is not None:
1475
+ loss_fct = CrossEntropyLoss()
1476
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1477
+
1478
+ if not return_dict:
1479
+ output = (logits,) + outputs[1:]
1480
+ return ((loss,) + output) if loss is not None else output
1481
+
1482
+ return TokenClassifierOutput(
1483
+ loss=loss,
1484
+ logits=logits,
1485
+ hidden_states=outputs.hidden_states,
1486
+ attentions=outputs.attentions,
1487
+ )
1488
+
1489
+
1490
+ @add_start_docstrings(
1491
+ """
1492
+ RoFormer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1493
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1494
+ """,
1495
+ ROFORMER_START_DOCSTRING,
1496
+ )
1497
+ class RoFormerForQuestionAnswering(RoFormerPreTrainedModel):
1498
+ def __init__(self, config):
1499
+ super().__init__(config)
1500
+
1501
+ config.num_labels = 2
1502
+ self.num_labels = config.num_labels
1503
+
1504
+ self.roformer = RoFormerModel(config)
1505
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1506
+
1507
+ # Initialize weights and apply final processing
1508
+ self.post_init()
1509
+
1510
+ @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1511
+ @add_code_sample_docstrings(
1512
+ processor_class=_TOKENIZER_FOR_DOC,
1513
+ checkpoint=_CHECKPOINT_FOR_DOC,
1514
+ output_type=QuestionAnsweringModelOutput,
1515
+ config_class=_CONFIG_FOR_DOC,
1516
+ )
1517
+ def forward(
1518
+ self,
1519
+ input_ids: Optional[torch.LongTensor] = None,
1520
+ attention_mask: Optional[torch.FloatTensor] = None,
1521
+ token_type_ids: Optional[torch.LongTensor] = None,
1522
+ head_mask: Optional[torch.FloatTensor] = None,
1523
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1524
+ start_positions: Optional[torch.LongTensor] = None,
1525
+ end_positions: Optional[torch.LongTensor] = None,
1526
+ output_attentions: Optional[bool] = None,
1527
+ output_hidden_states: Optional[bool] = None,
1528
+ return_dict: Optional[bool] = None,
1529
+ ) -> Union[QuestionAnsweringModelOutput, Tuple[torch.Tensor]]:
1530
+ r"""
1531
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1532
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1533
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1534
+ are not taken into account for computing the loss.
1535
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1536
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1537
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1538
+ are not taken into account for computing the loss.
1539
+ """
1540
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1541
+
1542
+ outputs = self.roformer(
1543
+ input_ids,
1544
+ attention_mask=attention_mask,
1545
+ token_type_ids=token_type_ids,
1546
+ head_mask=head_mask,
1547
+ inputs_embeds=inputs_embeds,
1548
+ output_attentions=output_attentions,
1549
+ output_hidden_states=output_hidden_states,
1550
+ return_dict=return_dict,
1551
+ )
1552
+
1553
+ sequence_output = outputs[0]
1554
+
1555
+ logits = self.qa_outputs(sequence_output)
1556
+ start_logits, end_logits = logits.split(1, dim=-1)
1557
+ start_logits = start_logits.squeeze(-1)
1558
+ end_logits = end_logits.squeeze(-1)
1559
+
1560
+ total_loss = None
1561
+ if start_positions is not None and end_positions is not None:
1562
+ # If we are on multi-GPU, split add a dimension
1563
+ if len(start_positions.size()) > 1:
1564
+ start_positions = start_positions.squeeze(-1)
1565
+ if len(end_positions.size()) > 1:
1566
+ end_positions = end_positions.squeeze(-1)
1567
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1568
+ ignored_index = start_logits.size(1)
1569
+ start_positions = start_positions.clamp(0, ignored_index)
1570
+ end_positions = end_positions.clamp(0, ignored_index)
1571
+
1572
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1573
+ start_loss = loss_fct(start_logits, start_positions)
1574
+ end_loss = loss_fct(end_logits, end_positions)
1575
+ total_loss = (start_loss + end_loss) / 2
1576
+
1577
+ if not return_dict:
1578
+ output = (start_logits, end_logits) + outputs[1:]
1579
+ return ((total_loss,) + output) if total_loss is not None else output
1580
+
1581
+ return QuestionAnsweringModelOutput(
1582
+ loss=total_loss,
1583
+ start_logits=start_logits,
1584
+ end_logits=end_logits,
1585
+ hidden_states=outputs.hidden_states,
1586
+ attentions=outputs.attentions,
1587
+ )
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae885a57f9cc3d9491bb414c0471026a7695381cb83c719a76f8f0094ad7b44a
3
+ size 342824309