ping yang commited on
Commit
2e08a92
1 Parent(s): 3fa8111

add albert and deberta

Browse files
Files changed (2) hide show
  1. modeling_albert.py +1363 -0
  2. modeling_deberta_v2.py +1617 -0
modeling_albert.py ADDED
@@ -0,0 +1,1363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch ALBERT model. """
16
+
17
+ import math
18
+ import os
19
+ from dataclasses import dataclass
20
+ from typing import Optional, Tuple
21
+
22
+ import torch
23
+ from packaging import version
24
+ from torch import nn
25
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
+
27
+ from transformers.activations import ACT2FN
28
+ from transformers.file_utils import (
29
+ ModelOutput,
30
+ add_code_sample_docstrings,
31
+ add_start_docstrings,
32
+ add_start_docstrings_to_model_forward,
33
+ replace_return_docstrings,
34
+ )
35
+ from transformers.modeling_outputs import (
36
+ BaseModelOutput,
37
+ BaseModelOutputWithPooling,
38
+ MaskedLMOutput,
39
+ MultipleChoiceModelOutput,
40
+ QuestionAnsweringModelOutput,
41
+ SequenceClassifierOutput,
42
+ TokenClassifierOutput,
43
+ )
44
+ from transformers.modeling_utils import (
45
+ PreTrainedModel,
46
+ apply_chunking_to_forward,
47
+ find_pruneable_heads_and_indices,
48
+ prune_linear_layer,
49
+ )
50
+ from transformers.utils import logging
51
+ from transformers import AlbertConfig
52
+
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+ _CHECKPOINT_FOR_DOC = "albert-base-v2"
58
+ _CONFIG_FOR_DOC = "AlbertConfig"
59
+ _TOKENIZER_FOR_DOC = "AlbertTokenizer"
60
+
61
+
62
+ ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
63
+ "albert-base-v1",
64
+ "albert-large-v1",
65
+ "albert-xlarge-v1",
66
+ "albert-xxlarge-v1",
67
+ "albert-base-v2",
68
+ "albert-large-v2",
69
+ "albert-xlarge-v2",
70
+ "albert-xxlarge-v2",
71
+ # See all ALBERT models at https://huggingface.co/models?filter=albert
72
+ ]
73
+
74
+
75
+ def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
76
+ """Load tf checkpoints in a pytorch model."""
77
+ try:
78
+ import re
79
+
80
+ import numpy as np
81
+ import tensorflow as tf
82
+ except ImportError:
83
+ logger.error(
84
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
85
+ "https://www.tensorflow.org/install/ for installation instructions."
86
+ )
87
+ raise
88
+ tf_path = os.path.abspath(tf_checkpoint_path)
89
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
90
+ # Load weights from TF model
91
+ init_vars = tf.train.list_variables(tf_path)
92
+ names = []
93
+ arrays = []
94
+ for name, shape in init_vars:
95
+ logger.info(f"Loading TF weight {name} with shape {shape}")
96
+ array = tf.train.load_variable(tf_path, name)
97
+ names.append(name)
98
+ arrays.append(array)
99
+
100
+ for name, array in zip(names, arrays):
101
+ print(name)
102
+
103
+ for name, array in zip(names, arrays):
104
+ original_name = name
105
+
106
+ # If saved from the TF HUB module
107
+ name = name.replace("module/", "")
108
+
109
+ # Renaming and simplifying
110
+ name = name.replace("ffn_1", "ffn")
111
+ name = name.replace("bert/", "albert/")
112
+ name = name.replace("attention_1", "attention")
113
+ name = name.replace("transform/", "")
114
+ name = name.replace("LayerNorm_1", "full_layer_layer_norm")
115
+ name = name.replace("LayerNorm", "attention/LayerNorm")
116
+ name = name.replace("transformer/", "")
117
+
118
+ # The feed forward layer had an 'intermediate' step which has been abstracted away
119
+ name = name.replace("intermediate/dense/", "")
120
+ name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")
121
+
122
+ # ALBERT attention was split between self and output which have been abstracted away
123
+ name = name.replace("/output/", "/")
124
+ name = name.replace("/self/", "/")
125
+
126
+ # The pooler is a linear layer
127
+ name = name.replace("pooler/dense", "pooler")
128
+
129
+ # The classifier was simplified to predictions from cls/predictions
130
+ name = name.replace("cls/predictions", "predictions")
131
+ name = name.replace("predictions/attention", "predictions")
132
+
133
+ # Naming was changed to be more explicit
134
+ name = name.replace("embeddings/attention", "embeddings")
135
+ name = name.replace("inner_group_", "albert_layers/")
136
+ name = name.replace("group_", "albert_layer_groups/")
137
+
138
+ # Classifier
139
+ if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
140
+ name = "classifier/" + name
141
+
142
+ # No ALBERT model currently handles the next sentence prediction task
143
+ if "seq_relationship" in name:
144
+ name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
145
+ name = name.replace("weights", "weight")
146
+
147
+ name = name.split("/")
148
+
149
+ # Ignore the gradients applied by the LAMB/ADAM optimizers.
150
+ if (
151
+ "adam_m" in name
152
+ or "adam_v" in name
153
+ or "AdamWeightDecayOptimizer" in name
154
+ or "AdamWeightDecayOptimizer_1" in name
155
+ or "global_step" in name
156
+ ):
157
+ logger.info(f"Skipping {'/'.join(name)}")
158
+ continue
159
+
160
+ pointer = model
161
+ for m_name in name:
162
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
163
+ scope_names = re.split(r"_(\d+)", m_name)
164
+ else:
165
+ scope_names = [m_name]
166
+
167
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
168
+ pointer = getattr(pointer, "weight")
169
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
170
+ pointer = getattr(pointer, "bias")
171
+ elif scope_names[0] == "output_weights":
172
+ pointer = getattr(pointer, "weight")
173
+ elif scope_names[0] == "squad":
174
+ pointer = getattr(pointer, "classifier")
175
+ else:
176
+ try:
177
+ pointer = getattr(pointer, scope_names[0])
178
+ except AttributeError:
179
+ logger.info(f"Skipping {'/'.join(name)}")
180
+ continue
181
+ if len(scope_names) >= 2:
182
+ num = int(scope_names[1])
183
+ pointer = pointer[num]
184
+
185
+ if m_name[-11:] == "_embeddings":
186
+ pointer = getattr(pointer, "weight")
187
+ elif m_name == "kernel":
188
+ array = np.transpose(array)
189
+ try:
190
+ if pointer.shape != array.shape:
191
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
192
+ except AssertionError as e:
193
+ e.args += (pointer.shape, array.shape)
194
+ raise
195
+ print(f"Initialize PyTorch weight {name} from {original_name}")
196
+ pointer.data = torch.from_numpy(array)
197
+
198
+ return model
199
+
200
+
201
+ class AlbertEmbeddings(nn.Module):
202
+ """
203
+ Construct the embeddings from word, position and token_type embeddings.
204
+ """
205
+
206
+ def __init__(self, config):
207
+ super().__init__()
208
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
209
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
210
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
211
+
212
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
213
+ # any TensorFlow checkpoint file
214
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
215
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
216
+
217
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
218
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
219
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
220
+ if version.parse(torch.__version__) > version.parse("1.6.0"):
221
+ self.register_buffer(
222
+ "token_type_ids",
223
+ torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
224
+ persistent=False,
225
+ )
226
+
227
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
228
+ def forward(
229
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
230
+ ):
231
+ if input_ids is not None:
232
+ input_shape = input_ids.size()
233
+ else:
234
+ input_shape = inputs_embeds.size()[:-1]
235
+
236
+ seq_length = input_shape[1]
237
+
238
+ if position_ids is None:
239
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
240
+
241
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
242
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
243
+ # issue #5664
244
+ if token_type_ids is None:
245
+ if hasattr(self, "token_type_ids"):
246
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
247
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
248
+ token_type_ids = buffered_token_type_ids_expanded
249
+ else:
250
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
251
+
252
+ if inputs_embeds is None:
253
+ inputs_embeds = self.word_embeddings(input_ids)
254
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
255
+
256
+ embeddings = inputs_embeds + token_type_embeddings
257
+ if self.position_embedding_type == "absolute":
258
+ position_embeddings = self.position_embeddings(position_ids)
259
+ embeddings += position_embeddings
260
+ embeddings = self.LayerNorm(embeddings)
261
+ embeddings = self.dropout(embeddings)
262
+ return embeddings
263
+
264
+
265
+ class AlbertAttention(nn.Module):
266
+ def __init__(self, config):
267
+ super().__init__()
268
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
269
+ raise ValueError(
270
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
271
+ f"heads ({config.num_attention_heads}"
272
+ )
273
+
274
+ self.num_attention_heads = config.num_attention_heads
275
+ self.hidden_size = config.hidden_size
276
+ self.attention_head_size = config.hidden_size // config.num_attention_heads
277
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
278
+
279
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
280
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
281
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
282
+
283
+ self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
284
+ self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
285
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
286
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
287
+ self.pruned_heads = set()
288
+
289
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
290
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
291
+ self.max_position_embeddings = config.max_position_embeddings
292
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
293
+
294
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
295
+ def transpose_for_scores(self, x):
296
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
297
+ x = x.view(*new_x_shape)
298
+ return x.permute(0, 2, 1, 3)
299
+
300
+ def prune_heads(self, heads):
301
+ if len(heads) == 0:
302
+ return
303
+ heads, index = find_pruneable_heads_and_indices(
304
+ heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
305
+ )
306
+
307
+ # Prune linear layers
308
+ self.query = prune_linear_layer(self.query, index)
309
+ self.key = prune_linear_layer(self.key, index)
310
+ self.value = prune_linear_layer(self.value, index)
311
+ self.dense = prune_linear_layer(self.dense, index, dim=1)
312
+
313
+ # Update hyper params and store pruned heads
314
+ self.num_attention_heads = self.num_attention_heads - len(heads)
315
+ self.all_head_size = self.attention_head_size * self.num_attention_heads
316
+ self.pruned_heads = self.pruned_heads.union(heads)
317
+
318
+ def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
319
+ mixed_query_layer = self.query(hidden_states)
320
+ mixed_key_layer = self.key(hidden_states)
321
+ mixed_value_layer = self.value(hidden_states)
322
+
323
+ query_layer = self.transpose_for_scores(mixed_query_layer)
324
+ key_layer = self.transpose_for_scores(mixed_key_layer)
325
+ value_layer = self.transpose_for_scores(mixed_value_layer)
326
+
327
+ # Take the dot product between "query" and "key" to get the raw attention scores.
328
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
329
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
330
+
331
+ if attention_mask is not None:
332
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
333
+ attention_scores = attention_scores + attention_mask
334
+
335
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
336
+ seq_length = hidden_states.size()[1]
337
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
338
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
339
+ distance = position_ids_l - position_ids_r
340
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
341
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
342
+
343
+ if self.position_embedding_type == "relative_key":
344
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
345
+ attention_scores = attention_scores + relative_position_scores
346
+ elif self.position_embedding_type == "relative_key_query":
347
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
348
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
349
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
350
+
351
+ # Normalize the attention scores to probabilities.
352
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
353
+
354
+ # This is actually dropping out entire tokens to attend to, which might
355
+ # seem a bit unusual, but is taken from the original Transformer paper.
356
+ attention_probs = self.attention_dropout(attention_probs)
357
+
358
+ # Mask heads if we want to
359
+ if head_mask is not None:
360
+ attention_probs = attention_probs * head_mask
361
+
362
+ context_layer = torch.matmul(attention_probs, value_layer)
363
+ context_layer = context_layer.transpose(2, 1).flatten(2)
364
+
365
+ projected_context_layer = self.dense(context_layer)
366
+ projected_context_layer_dropout = self.output_dropout(projected_context_layer)
367
+ layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
368
+ return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
369
+
370
+
371
+ class AlbertLayer(nn.Module):
372
+ def __init__(self, config):
373
+ super().__init__()
374
+
375
+ self.config = config
376
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
377
+ self.seq_len_dim = 1
378
+ self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
379
+ self.attention = AlbertAttention(config)
380
+ self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
381
+ self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
382
+ self.activation = ACT2FN[config.hidden_act]
383
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
384
+
385
+ def forward(
386
+ self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
387
+ ):
388
+ attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
389
+
390
+ ffn_output = apply_chunking_to_forward(
391
+ self.ff_chunk,
392
+ self.chunk_size_feed_forward,
393
+ self.seq_len_dim,
394
+ attention_output[0],
395
+ )
396
+ hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])
397
+
398
+ return (hidden_states,) + attention_output[1:] # add attentions if we output them
399
+
400
+ def ff_chunk(self, attention_output):
401
+ ffn_output = self.ffn(attention_output)
402
+ ffn_output = self.activation(ffn_output)
403
+ ffn_output = self.ffn_output(ffn_output)
404
+ return ffn_output
405
+
406
+
407
+ class AlbertLayerGroup(nn.Module):
408
+ def __init__(self, config):
409
+ super().__init__()
410
+
411
+ self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
412
+
413
+ def forward(
414
+ self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, output_hidden_states=False
415
+ ):
416
+ layer_hidden_states = ()
417
+ layer_attentions = ()
418
+
419
+ for layer_index, albert_layer in enumerate(self.albert_layers):
420
+ layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
421
+ hidden_states = layer_output[0]
422
+
423
+ if output_attentions:
424
+ layer_attentions = layer_attentions + (layer_output[1],)
425
+
426
+ if output_hidden_states:
427
+ layer_hidden_states = layer_hidden_states + (hidden_states,)
428
+
429
+ outputs = (hidden_states,)
430
+ if output_hidden_states:
431
+ outputs = outputs + (layer_hidden_states,)
432
+ if output_attentions:
433
+ outputs = outputs + (layer_attentions,)
434
+ return outputs # last-layer hidden state, (layer hidden states), (layer attentions)
435
+
436
+
437
+ class AlbertTransformer(nn.Module):
438
+ def __init__(self, config):
439
+ super().__init__()
440
+
441
+ self.config = config
442
+ self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
443
+ self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])
444
+
445
+ def forward(
446
+ self,
447
+ hidden_states,
448
+ attention_mask=None,
449
+ head_mask=None,
450
+ output_attentions=False,
451
+ output_hidden_states=False,
452
+ return_dict=True,
453
+ ):
454
+ hidden_states = self.embedding_hidden_mapping_in(hidden_states)
455
+
456
+ all_hidden_states = (hidden_states,) if output_hidden_states else None
457
+ all_attentions = () if output_attentions else None
458
+
459
+ head_mask = [None] * self.config.num_hidden_layers if head_mask is None else head_mask
460
+
461
+ for i in range(self.config.num_hidden_layers):
462
+ # Number of layers in a hidden group
463
+ layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
464
+
465
+ # Index of the hidden group
466
+ group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
467
+
468
+ layer_group_output = self.albert_layer_groups[group_idx](
469
+ hidden_states,
470
+ attention_mask,
471
+ head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
472
+ output_attentions,
473
+ output_hidden_states,
474
+ )
475
+ hidden_states = layer_group_output[0]
476
+
477
+ if output_attentions:
478
+ all_attentions = all_attentions + layer_group_output[-1]
479
+
480
+ if output_hidden_states:
481
+ all_hidden_states = all_hidden_states + (hidden_states,)
482
+
483
+ if not return_dict:
484
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
485
+ return BaseModelOutput(
486
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
487
+ )
488
+
489
+
490
+ class AlbertPreTrainedModel(PreTrainedModel):
491
+ """
492
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
493
+ models.
494
+ """
495
+
496
+ config_class = AlbertConfig
497
+ load_tf_weights = load_tf_weights_in_albert
498
+ base_model_prefix = "albert"
499
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
500
+
501
+ def _init_weights(self, module):
502
+ """Initialize the weights."""
503
+ if isinstance(module, nn.Linear):
504
+ # Slightly different from the TF version which uses truncated_normal for initialization
505
+ # cf https://github.com/pytorch/pytorch/pull/5617
506
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
507
+ if module.bias is not None:
508
+ module.bias.data.zero_()
509
+ elif isinstance(module, nn.Embedding):
510
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
511
+ if module.padding_idx is not None:
512
+ module.weight.data[module.padding_idx].zero_()
513
+ elif isinstance(module, nn.LayerNorm):
514
+ module.bias.data.zero_()
515
+ module.weight.data.fill_(1.0)
516
+
517
+
518
+ @dataclass
519
+ class AlbertForPreTrainingOutput(ModelOutput):
520
+ """
521
+ Output type of :class:`~transformers.AlbertForPreTraining`.
522
+
523
+ Args:
524
+ loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
525
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
526
+ (classification) loss.
527
+ prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
528
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
529
+ sop_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
530
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
531
+ before SoftMax).
532
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
533
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
534
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
535
+
536
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
537
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
538
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
539
+ sequence_length, sequence_length)`.
540
+
541
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
542
+ heads.
543
+ """
544
+
545
+ loss: Optional[torch.FloatTensor] = None
546
+ prediction_logits: torch.FloatTensor = None
547
+ sop_logits: torch.FloatTensor = None
548
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
549
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
550
+
551
+
552
+ ALBERT_START_DOCSTRING = r"""
553
+
554
+ This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
555
+ methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
556
+ pruning heads etc.)
557
+
558
+ This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
559
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
560
+ general usage and behavior.
561
+
562
+ Args:
563
+ config (:class:`~transformers.AlbertConfig`): Model configuration class with all the parameters of the model.
564
+ Initializing with a config file does not load the weights associated with the model, only the
565
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
566
+ weights.
567
+ """
568
+
569
+ ALBERT_INPUTS_DOCSTRING = r"""
570
+ Args:
571
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
572
+ Indices of input sequence tokens in the vocabulary.
573
+
574
+ Indices can be obtained using :class:`~transformers.AlbertTokenizer`. See
575
+ :meth:`transformers.PreTrainedTokenizer.__call__` and :meth:`transformers.PreTrainedTokenizer.encode` for
576
+ details.
577
+
578
+ `What are input IDs? <../glossary.html#input-ids>`__
579
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
580
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
581
+
582
+ - 1 for tokens that are **not masked**,
583
+ - 0 for tokens that are **masked**.
584
+
585
+ `What are attention masks? <../glossary.html#attention-mask>`__
586
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
587
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
588
+ 1]``:
589
+
590
+ - 0 corresponds to a `sentence A` token,
591
+ - 1 corresponds to a `sentence B` token.
592
+
593
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
594
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
595
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
596
+ config.max_position_embeddings - 1]``.
597
+
598
+ `What are position IDs? <../glossary.html#position-ids>`_
599
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
600
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
601
+
602
+ - 1 indicates the head is **not masked**,
603
+ - 0 indicates the head is **masked**.
604
+
605
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
606
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
607
+ This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
608
+ vectors than the model's internal embedding lookup matrix.
609
+ output_attentions (:obj:`bool`, `optional`):
610
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
611
+ tensors for more detail.
612
+ output_hidden_states (:obj:`bool`, `optional`):
613
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
614
+ more detail.
615
+ return_dict (:obj:`bool`, `optional`):
616
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
617
+ """
618
+
619
+
620
+ @add_start_docstrings(
621
+ "The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
622
+ ALBERT_START_DOCSTRING,
623
+ )
624
+ class AlbertModel(AlbertPreTrainedModel):
625
+
626
+ config_class = AlbertConfig
627
+ base_model_prefix = "albert"
628
+
629
+ def __init__(self, config, add_pooling_layer=True):
630
+ super().__init__(config)
631
+
632
+ self.config = config
633
+ self.embeddings = AlbertEmbeddings(config)
634
+ self.encoder = AlbertTransformer(config)
635
+ if add_pooling_layer:
636
+ self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
637
+ self.pooler_activation = nn.Tanh()
638
+ else:
639
+ self.pooler = None
640
+ self.pooler_activation = None
641
+
642
+ self.init_weights()
643
+
644
+ def get_input_embeddings(self):
645
+ return self.embeddings.word_embeddings
646
+
647
+ def set_input_embeddings(self, value):
648
+ self.embeddings.word_embeddings = value
649
+
650
+ def _prune_heads(self, heads_to_prune):
651
+ """
652
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} ALBERT has
653
+ a different architecture in that its layers are shared across groups, which then has inner groups. If an ALBERT
654
+ model has 12 hidden layers and 2 hidden groups, with two inner groups, there is a total of 4 different layers.
655
+
656
+ These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden layer,
657
+ while [2,3] correspond to the two inner groups of the second hidden layer.
658
+
659
+ Any layer with in index other than [0,1,2,3] will result in an error. See base class PreTrainedModel for more
660
+ information about head pruning
661
+ """
662
+ for layer, heads in heads_to_prune.items():
663
+ group_idx = int(layer / self.config.inner_group_num)
664
+ inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
665
+ self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
666
+
667
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
668
+ @add_code_sample_docstrings(
669
+ processor_class=_TOKENIZER_FOR_DOC,
670
+ checkpoint=_CHECKPOINT_FOR_DOC,
671
+ output_type=BaseModelOutputWithPooling,
672
+ config_class=_CONFIG_FOR_DOC,
673
+ )
674
+ def forward(
675
+ self,
676
+ input_ids=None,
677
+ attention_mask=None,
678
+ token_type_ids=None,
679
+ position_ids=None,
680
+ head_mask=None,
681
+ inputs_embeds=None,
682
+ output_attentions=None,
683
+ output_hidden_states=None,
684
+ return_dict=None,
685
+ ):
686
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
687
+ output_hidden_states = (
688
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
689
+ )
690
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
691
+
692
+ if input_ids is not None and inputs_embeds is not None:
693
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
694
+ elif input_ids is not None:
695
+ input_shape = input_ids.size()
696
+ elif inputs_embeds is not None:
697
+ input_shape = inputs_embeds.size()[:-1]
698
+ else:
699
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
700
+
701
+ batch_size, seq_length = input_shape
702
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
703
+
704
+ if attention_mask is None:
705
+ attention_mask = torch.ones(input_shape, device=device)
706
+ if token_type_ids is None:
707
+ if hasattr(self.embeddings, "token_type_ids"):
708
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
709
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
710
+ token_type_ids = buffered_token_type_ids_expanded
711
+ else:
712
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
713
+
714
+ # extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) #
715
+ extended_attention_mask = attention_mask[:, None, :, :]
716
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
717
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
718
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
719
+
720
+ embedding_output = self.embeddings(
721
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
722
+ )
723
+ encoder_outputs = self.encoder(
724
+ embedding_output,
725
+ extended_attention_mask,
726
+ head_mask=head_mask,
727
+ output_attentions=output_attentions,
728
+ output_hidden_states=output_hidden_states,
729
+ return_dict=return_dict,
730
+ )
731
+
732
+ sequence_output = encoder_outputs[0]
733
+
734
+ pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None
735
+
736
+ if not return_dict:
737
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
738
+
739
+ return BaseModelOutputWithPooling(
740
+ last_hidden_state=sequence_output,
741
+ pooler_output=pooled_output,
742
+ hidden_states=encoder_outputs.hidden_states,
743
+ attentions=encoder_outputs.attentions,
744
+ )
745
+
746
+
747
+ @add_start_docstrings(
748
+ """
749
+ Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
750
+ `sentence order prediction (classification)` head.
751
+ """,
752
+ ALBERT_START_DOCSTRING,
753
+ )
754
+ class AlbertForPreTraining(AlbertPreTrainedModel):
755
+ def __init__(self, config):
756
+ super().__init__(config)
757
+
758
+ self.albert = AlbertModel(config)
759
+ self.predictions = AlbertMLMHead(config)
760
+ self.sop_classifier = AlbertSOPHead(config)
761
+
762
+ self.init_weights()
763
+
764
+ def get_output_embeddings(self):
765
+ return self.predictions.decoder
766
+
767
+ def set_output_embeddings(self, new_embeddings):
768
+ self.predictions.decoder = new_embeddings
769
+
770
+ def get_input_embeddings(self):
771
+ return self.albert.embeddings.word_embeddings
772
+
773
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
774
+ @replace_return_docstrings(output_type=AlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
775
+ def forward(
776
+ self,
777
+ input_ids=None,
778
+ attention_mask=None,
779
+ token_type_ids=None,
780
+ position_ids=None,
781
+ head_mask=None,
782
+ inputs_embeds=None,
783
+ labels=None,
784
+ sentence_order_label=None,
785
+ output_attentions=None,
786
+ output_hidden_states=None,
787
+ return_dict=None,
788
+ ):
789
+ r"""
790
+ labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`):
791
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
792
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
793
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
794
+ sentence_order_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
795
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
796
+ (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``. ``0`` indicates original order (sequence
797
+ A, then sequence B), ``1`` indicates switched order (sequence B, then sequence A).
798
+
799
+ Returns:
800
+
801
+ Example::
802
+
803
+ >>> from transformers import AlbertTokenizer, AlbertForPreTraining
804
+ >>> import torch
805
+
806
+ >>> tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
807
+ >>> model = AlbertForPreTraining.from_pretrained('albert-base-v2')
808
+
809
+ >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
810
+ >>> outputs = model(input_ids)
811
+
812
+ >>> prediction_logits = outputs.prediction_logits
813
+ >>> sop_logits = outputs.sop_logits
814
+
815
+ """
816
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
817
+
818
+ outputs = self.albert(
819
+ input_ids,
820
+ attention_mask=attention_mask,
821
+ token_type_ids=token_type_ids,
822
+ position_ids=position_ids,
823
+ head_mask=head_mask,
824
+ inputs_embeds=inputs_embeds,
825
+ output_attentions=output_attentions,
826
+ output_hidden_states=output_hidden_states,
827
+ return_dict=return_dict,
828
+ )
829
+
830
+ sequence_output, pooled_output = outputs[:2]
831
+
832
+ prediction_scores = self.predictions(sequence_output)
833
+ sop_scores = self.sop_classifier(pooled_output)
834
+
835
+ total_loss = None
836
+ if labels is not None and sentence_order_label is not None:
837
+ loss_fct = CrossEntropyLoss()
838
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
839
+ sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
840
+ total_loss = masked_lm_loss + sentence_order_loss
841
+
842
+ if not return_dict:
843
+ output = (prediction_scores, sop_scores) + outputs[2:]
844
+ return ((total_loss,) + output) if total_loss is not None else output
845
+
846
+ return AlbertForPreTrainingOutput(
847
+ loss=total_loss,
848
+ prediction_logits=prediction_scores,
849
+ sop_logits=sop_scores,
850
+ hidden_states=outputs.hidden_states,
851
+ attentions=outputs.attentions,
852
+ )
853
+
854
+
855
+ class AlbertMLMHead(nn.Module):
856
+ def __init__(self, config):
857
+ super().__init__()
858
+
859
+ self.LayerNorm = nn.LayerNorm(config.embedding_size)
860
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
861
+ self.dense = nn.Linear(config.hidden_size, config.embedding_size)
862
+ self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
863
+ self.activation = ACT2FN[config.hidden_act]
864
+ self.decoder.bias = self.bias
865
+
866
+ def forward(self, hidden_states):
867
+ hidden_states = self.dense(hidden_states)
868
+ hidden_states = self.activation(hidden_states)
869
+ hidden_states = self.LayerNorm(hidden_states)
870
+ hidden_states = self.decoder(hidden_states)
871
+
872
+ prediction_scores = hidden_states
873
+
874
+ return prediction_scores
875
+
876
+ def _tie_weights(self):
877
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
878
+ self.bias = self.decoder.bias
879
+
880
+
881
+ class AlbertSOPHead(nn.Module):
882
+ def __init__(self, config):
883
+ super().__init__()
884
+
885
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
886
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
887
+
888
+ def forward(self, pooled_output):
889
+ dropout_pooled_output = self.dropout(pooled_output)
890
+ logits = self.classifier(dropout_pooled_output)
891
+ return logits
892
+
893
+
894
+ @add_start_docstrings(
895
+ "Albert Model with a `language modeling` head on top.",
896
+ ALBERT_START_DOCSTRING,
897
+ )
898
+ class AlbertForMaskedLM(AlbertPreTrainedModel):
899
+
900
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
901
+
902
+ def __init__(self, config):
903
+ super().__init__(config)
904
+
905
+ self.albert = AlbertModel(config, add_pooling_layer=False)
906
+ self.predictions = AlbertMLMHead(config)
907
+
908
+ self.init_weights()
909
+
910
+ def get_output_embeddings(self):
911
+ return self.predictions.decoder
912
+
913
+ def set_output_embeddings(self, new_embeddings):
914
+ self.predictions.decoder = new_embeddings
915
+
916
+ def get_input_embeddings(self):
917
+ return self.albert.embeddings.word_embeddings
918
+
919
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
920
+ @add_code_sample_docstrings(
921
+ processor_class=_TOKENIZER_FOR_DOC,
922
+ checkpoint=_CHECKPOINT_FOR_DOC,
923
+ output_type=MaskedLMOutput,
924
+ config_class=_CONFIG_FOR_DOC,
925
+ )
926
+ def forward(
927
+ self,
928
+ input_ids=None,
929
+ attention_mask=None,
930
+ token_type_ids=None,
931
+ position_ids=None,
932
+ head_mask=None,
933
+ inputs_embeds=None,
934
+ labels=None,
935
+ output_attentions=None,
936
+ output_hidden_states=None,
937
+ return_dict=None,
938
+ ):
939
+ r"""
940
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
941
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
942
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
943
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
944
+ """
945
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
946
+
947
+ outputs = self.albert(
948
+ input_ids=input_ids,
949
+ attention_mask=attention_mask,
950
+ token_type_ids=token_type_ids,
951
+ position_ids=position_ids,
952
+ head_mask=head_mask,
953
+ inputs_embeds=inputs_embeds,
954
+ output_attentions=output_attentions,
955
+ output_hidden_states=output_hidden_states,
956
+ return_dict=return_dict,
957
+ )
958
+ sequence_outputs = outputs[0]
959
+
960
+ prediction_scores = self.predictions(sequence_outputs)
961
+
962
+ masked_lm_loss = None
963
+ if labels is not None:
964
+ loss_fct = CrossEntropyLoss()
965
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
966
+
967
+ if not return_dict:
968
+ output = (prediction_scores,) + outputs[2:]
969
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
970
+
971
+ return MaskedLMOutput(
972
+ loss=masked_lm_loss,
973
+ logits=prediction_scores,
974
+ hidden_states=outputs.hidden_states,
975
+ attentions=outputs.attentions,
976
+ )
977
+
978
+
979
+ @add_start_docstrings(
980
+ """
981
+ Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
982
+ output) e.g. for GLUE tasks.
983
+ """,
984
+ ALBERT_START_DOCSTRING,
985
+ )
986
+ class AlbertForSequenceClassification(AlbertPreTrainedModel):
987
+ def __init__(self, config):
988
+ super().__init__(config)
989
+ self.num_labels = config.num_labels
990
+ self.config = config
991
+
992
+ self.albert = AlbertModel(config)
993
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
994
+ self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
995
+
996
+ self.init_weights()
997
+
998
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
999
+ @add_code_sample_docstrings(
1000
+ processor_class=_TOKENIZER_FOR_DOC,
1001
+ checkpoint=_CHECKPOINT_FOR_DOC,
1002
+ output_type=SequenceClassifierOutput,
1003
+ config_class=_CONFIG_FOR_DOC,
1004
+ )
1005
+ def forward(
1006
+ self,
1007
+ input_ids=None,
1008
+ attention_mask=None,
1009
+ token_type_ids=None,
1010
+ position_ids=None,
1011
+ head_mask=None,
1012
+ inputs_embeds=None,
1013
+ labels=None,
1014
+ output_attentions=None,
1015
+ output_hidden_states=None,
1016
+ return_dict=None,
1017
+ ):
1018
+ r"""
1019
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1020
+ Labels for computing the sequence classification/regression loss. Indices should be in ``[0, ...,
1021
+ config.num_labels - 1]``. If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
1022
+ If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
1023
+ """
1024
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1025
+
1026
+ outputs = self.albert(
1027
+ input_ids=input_ids,
1028
+ attention_mask=attention_mask,
1029
+ token_type_ids=token_type_ids,
1030
+ position_ids=position_ids,
1031
+ head_mask=head_mask,
1032
+ inputs_embeds=inputs_embeds,
1033
+ output_attentions=output_attentions,
1034
+ output_hidden_states=output_hidden_states,
1035
+ return_dict=return_dict,
1036
+ )
1037
+
1038
+ pooled_output = outputs[1]
1039
+
1040
+ pooled_output = self.dropout(pooled_output)
1041
+ logits = self.classifier(pooled_output)
1042
+
1043
+ loss = None
1044
+ if labels is not None:
1045
+ if self.config.problem_type is None:
1046
+ if self.num_labels == 1:
1047
+ self.config.problem_type = "regression"
1048
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1049
+ self.config.problem_type = "single_label_classification"
1050
+ else:
1051
+ self.config.problem_type = "multi_label_classification"
1052
+
1053
+ if self.config.problem_type == "regression":
1054
+ loss_fct = MSELoss()
1055
+ if self.num_labels == 1:
1056
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1057
+ else:
1058
+ loss = loss_fct(logits, labels)
1059
+ elif self.config.problem_type == "single_label_classification":
1060
+ loss_fct = CrossEntropyLoss()
1061
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1062
+ elif self.config.problem_type == "multi_label_classification":
1063
+ loss_fct = BCEWithLogitsLoss()
1064
+ loss = loss_fct(logits, labels)
1065
+
1066
+ if not return_dict:
1067
+ output = (logits,) + outputs[2:]
1068
+ return ((loss,) + output) if loss is not None else output
1069
+
1070
+ return SequenceClassifierOutput(
1071
+ loss=loss,
1072
+ logits=logits,
1073
+ hidden_states=outputs.hidden_states,
1074
+ attentions=outputs.attentions,
1075
+ )
1076
+
1077
+
1078
+ @add_start_docstrings(
1079
+ """
1080
+ Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1081
+ Named-Entity-Recognition (NER) tasks.
1082
+ """,
1083
+ ALBERT_START_DOCSTRING,
1084
+ )
1085
+ class AlbertForTokenClassification(AlbertPreTrainedModel):
1086
+
1087
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1088
+
1089
+ def __init__(self, config):
1090
+ super().__init__(config)
1091
+ self.num_labels = config.num_labels
1092
+
1093
+ self.albert = AlbertModel(config, add_pooling_layer=False)
1094
+ classifier_dropout_prob = (
1095
+ config.classifier_dropout_prob
1096
+ if config.classifier_dropout_prob is not None
1097
+ else config.hidden_dropout_prob
1098
+ )
1099
+ self.dropout = nn.Dropout(classifier_dropout_prob)
1100
+ self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
1101
+
1102
+ self.init_weights()
1103
+
1104
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1105
+ @add_code_sample_docstrings(
1106
+ processor_class=_TOKENIZER_FOR_DOC,
1107
+ checkpoint=_CHECKPOINT_FOR_DOC,
1108
+ output_type=TokenClassifierOutput,
1109
+ config_class=_CONFIG_FOR_DOC,
1110
+ )
1111
+ def forward(
1112
+ self,
1113
+ input_ids=None,
1114
+ attention_mask=None,
1115
+ token_type_ids=None,
1116
+ position_ids=None,
1117
+ head_mask=None,
1118
+ inputs_embeds=None,
1119
+ labels=None,
1120
+ output_attentions=None,
1121
+ output_hidden_states=None,
1122
+ return_dict=None,
1123
+ ):
1124
+ r"""
1125
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1126
+ Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1127
+ 1]``.
1128
+ """
1129
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1130
+
1131
+ outputs = self.albert(
1132
+ input_ids,
1133
+ attention_mask=attention_mask,
1134
+ token_type_ids=token_type_ids,
1135
+ position_ids=position_ids,
1136
+ head_mask=head_mask,
1137
+ inputs_embeds=inputs_embeds,
1138
+ output_attentions=output_attentions,
1139
+ output_hidden_states=output_hidden_states,
1140
+ return_dict=return_dict,
1141
+ )
1142
+
1143
+ sequence_output = outputs[0]
1144
+
1145
+ sequence_output = self.dropout(sequence_output)
1146
+ logits = self.classifier(sequence_output)
1147
+
1148
+ loss = None
1149
+ if labels is not None:
1150
+ loss_fct = CrossEntropyLoss()
1151
+ # Only keep active parts of the loss
1152
+ if attention_mask is not None:
1153
+ active_loss = attention_mask.view(-1) == 1
1154
+ active_logits = logits.view(-1, self.num_labels)
1155
+ active_labels = torch.where(
1156
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
1157
+ )
1158
+ loss = loss_fct(active_logits, active_labels)
1159
+ else:
1160
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1161
+
1162
+ if not return_dict:
1163
+ output = (logits,) + outputs[2:]
1164
+ return ((loss,) + output) if loss is not None else output
1165
+
1166
+ return TokenClassifierOutput(
1167
+ loss=loss,
1168
+ logits=logits,
1169
+ hidden_states=outputs.hidden_states,
1170
+ attentions=outputs.attentions,
1171
+ )
1172
+
1173
+
1174
+ @add_start_docstrings(
1175
+ """
1176
+ Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1177
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1178
+ """,
1179
+ ALBERT_START_DOCSTRING,
1180
+ )
1181
+ class AlbertForQuestionAnswering(AlbertPreTrainedModel):
1182
+
1183
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1184
+
1185
+ def __init__(self, config):
1186
+ super().__init__(config)
1187
+ self.num_labels = config.num_labels
1188
+
1189
+ self.albert = AlbertModel(config, add_pooling_layer=False)
1190
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1191
+
1192
+ self.init_weights()
1193
+
1194
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1195
+ @add_code_sample_docstrings(
1196
+ processor_class=_TOKENIZER_FOR_DOC,
1197
+ checkpoint=_CHECKPOINT_FOR_DOC,
1198
+ output_type=QuestionAnsweringModelOutput,
1199
+ config_class=_CONFIG_FOR_DOC,
1200
+ )
1201
+ def forward(
1202
+ self,
1203
+ input_ids=None,
1204
+ attention_mask=None,
1205
+ token_type_ids=None,
1206
+ position_ids=None,
1207
+ head_mask=None,
1208
+ inputs_embeds=None,
1209
+ start_positions=None,
1210
+ end_positions=None,
1211
+ output_attentions=None,
1212
+ output_hidden_states=None,
1213
+ return_dict=None,
1214
+ ):
1215
+ r"""
1216
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1217
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1218
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
1219
+ sequence are not taken into account for computing the loss.
1220
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1221
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1222
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
1223
+ sequence are not taken into account for computing the loss.
1224
+ """
1225
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1226
+
1227
+ outputs = self.albert(
1228
+ input_ids=input_ids,
1229
+ attention_mask=attention_mask,
1230
+ token_type_ids=token_type_ids,
1231
+ position_ids=position_ids,
1232
+ head_mask=head_mask,
1233
+ inputs_embeds=inputs_embeds,
1234
+ output_attentions=output_attentions,
1235
+ output_hidden_states=output_hidden_states,
1236
+ return_dict=return_dict,
1237
+ )
1238
+
1239
+ sequence_output = outputs[0]
1240
+
1241
+ logits = self.qa_outputs(sequence_output)
1242
+ start_logits, end_logits = logits.split(1, dim=-1)
1243
+ start_logits = start_logits.squeeze(-1).contiguous()
1244
+ end_logits = end_logits.squeeze(-1).contiguous()
1245
+
1246
+ total_loss = None
1247
+ if start_positions is not None and end_positions is not None:
1248
+ # If we are on multi-GPU, split add a dimension
1249
+ if len(start_positions.size()) > 1:
1250
+ start_positions = start_positions.squeeze(-1)
1251
+ if len(end_positions.size()) > 1:
1252
+ end_positions = end_positions.squeeze(-1)
1253
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1254
+ ignored_index = start_logits.size(1)
1255
+ start_positions = start_positions.clamp(0, ignored_index)
1256
+ end_positions = end_positions.clamp(0, ignored_index)
1257
+
1258
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1259
+ start_loss = loss_fct(start_logits, start_positions)
1260
+ end_loss = loss_fct(end_logits, end_positions)
1261
+ total_loss = (start_loss + end_loss) / 2
1262
+
1263
+ if not return_dict:
1264
+ output = (start_logits, end_logits) + outputs[2:]
1265
+ return ((total_loss,) + output) if total_loss is not None else output
1266
+
1267
+ return QuestionAnsweringModelOutput(
1268
+ loss=total_loss,
1269
+ start_logits=start_logits,
1270
+ end_logits=end_logits,
1271
+ hidden_states=outputs.hidden_states,
1272
+ attentions=outputs.attentions,
1273
+ )
1274
+
1275
+
1276
+ @add_start_docstrings(
1277
+ """
1278
+ Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1279
+ softmax) e.g. for RocStories/SWAG tasks.
1280
+ """,
1281
+ ALBERT_START_DOCSTRING,
1282
+ )
1283
+ class AlbertForMultipleChoice(AlbertPreTrainedModel):
1284
+ def __init__(self, config):
1285
+ super().__init__(config)
1286
+
1287
+ self.albert = AlbertModel(config)
1288
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
1289
+ self.classifier = nn.Linear(config.hidden_size, 1)
1290
+
1291
+ self.init_weights()
1292
+
1293
+ @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
1294
+ @add_code_sample_docstrings(
1295
+ processor_class=_TOKENIZER_FOR_DOC,
1296
+ checkpoint=_CHECKPOINT_FOR_DOC,
1297
+ output_type=MultipleChoiceModelOutput,
1298
+ config_class=_CONFIG_FOR_DOC,
1299
+ )
1300
+ def forward(
1301
+ self,
1302
+ input_ids=None,
1303
+ attention_mask=None,
1304
+ token_type_ids=None,
1305
+ position_ids=None,
1306
+ head_mask=None,
1307
+ inputs_embeds=None,
1308
+ labels=None,
1309
+ output_attentions=None,
1310
+ output_hidden_states=None,
1311
+ return_dict=None,
1312
+ ):
1313
+ r"""
1314
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1315
+ Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
1316
+ num_choices-1]`` where `num_choices` is the size of the second dimension of the input tensors. (see
1317
+ `input_ids` above)
1318
+ """
1319
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1320
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1321
+
1322
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1323
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1324
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1325
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1326
+ inputs_embeds = (
1327
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1328
+ if inputs_embeds is not None
1329
+ else None
1330
+ )
1331
+ outputs = self.albert(
1332
+ input_ids,
1333
+ attention_mask=attention_mask,
1334
+ token_type_ids=token_type_ids,
1335
+ position_ids=position_ids,
1336
+ head_mask=head_mask,
1337
+ inputs_embeds=inputs_embeds,
1338
+ output_attentions=output_attentions,
1339
+ output_hidden_states=output_hidden_states,
1340
+ return_dict=return_dict,
1341
+ )
1342
+
1343
+ pooled_output = outputs[1]
1344
+
1345
+ pooled_output = self.dropout(pooled_output)
1346
+ logits = self.classifier(pooled_output)
1347
+ reshaped_logits = logits.view(-1, num_choices)
1348
+
1349
+ loss = None
1350
+ if labels is not None:
1351
+ loss_fct = CrossEntropyLoss()
1352
+ loss = loss_fct(reshaped_logits, labels)
1353
+
1354
+ if not return_dict:
1355
+ output = (reshaped_logits,) + outputs[2:]
1356
+ return ((loss,) + output) if loss is not None else output
1357
+
1358
+ return MultipleChoiceModelOutput(
1359
+ loss=loss,
1360
+ logits=reshaped_logits,
1361
+ hidden_states=outputs.hidden_states,
1362
+ attentions=outputs.attentions,
1363
+ )
modeling_deberta_v2.py ADDED
@@ -0,0 +1,1617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 Microsoft and the Hugging Face Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch DeBERTa-v2 model."""
16
+
17
+ import math
18
+ from collections.abc import Sequence
19
+ from typing import Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ from torch import nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutput,
29
+ MaskedLMOutput,
30
+ MultipleChoiceModelOutput,
31
+ QuestionAnsweringModelOutput,
32
+ SequenceClassifierOutput,
33
+ TokenClassifierOutput,
34
+ )
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.pytorch_utils import softmax_backward_data
37
+ from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
38
+ from transformers import DebertaV2Config
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+ _CONFIG_FOR_DOC = "DebertaV2Config"
44
+ _TOKENIZER_FOR_DOC = "DebertaV2Tokenizer"
45
+ _CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge"
46
+
47
+ DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [
48
+ "microsoft/deberta-v2-xlarge",
49
+ "microsoft/deberta-v2-xxlarge",
50
+ "microsoft/deberta-v2-xlarge-mnli",
51
+ "microsoft/deberta-v2-xxlarge-mnli",
52
+ ]
53
+
54
+
55
+ # Copied from transformers.models.deberta.modeling_deberta.ContextPooler
56
+ class ContextPooler(nn.Module):
57
+ def __init__(self, config):
58
+ super().__init__()
59
+ self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
60
+ self.dropout = StableDropout(config.pooler_dropout)
61
+ self.config = config
62
+
63
+ def forward(self, hidden_states):
64
+ # We "pool" the model by simply taking the hidden state corresponding
65
+ # to the first token.
66
+
67
+ context_token = hidden_states[:, 0]
68
+ context_token = self.dropout(context_token)
69
+ pooled_output = self.dense(context_token)
70
+ pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
71
+ return pooled_output
72
+
73
+ @property
74
+ def output_dim(self):
75
+ return self.config.hidden_size
76
+
77
+
78
+ # Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2
79
+ class XSoftmax(torch.autograd.Function):
80
+ """
81
+ Masked Softmax which is optimized for saving memory
82
+
83
+ Args:
84
+ input (`torch.tensor`): The input tensor that will apply softmax.
85
+ mask (`torch.IntTensor`):
86
+ The mask matrix where 0 indicate that element will be ignored in the softmax calculation.
87
+ dim (int): The dimension that will apply softmax
88
+
89
+ Example:
90
+
91
+ ```python
92
+ >>> import torch
93
+ >>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax
94
+
95
+ >>> # Make a tensor
96
+ >>> x = torch.randn([4, 20, 100])
97
+
98
+ >>> # Create a mask
99
+ >>> mask = (x > 0).int()
100
+
101
+ >>> # Specify the dimension to apply softmax
102
+ >>> dim = -1
103
+
104
+ >>> y = XSoftmax.apply(x, mask, dim)
105
+ ```"""
106
+
107
+ @staticmethod
108
+ def forward(self, input, mask, dim):
109
+ self.dim = dim
110
+ rmask = ~(mask.to(torch.bool))
111
+
112
+ output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
113
+ output = torch.softmax(output, self.dim)
114
+ output.masked_fill_(rmask, 0)
115
+ self.save_for_backward(output)
116
+ return output
117
+
118
+ @staticmethod
119
+ def backward(self, grad_output):
120
+ (output,) = self.saved_tensors
121
+ inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
122
+ return inputGrad, None, None
123
+
124
+ @staticmethod
125
+ def symbolic(g, self, mask, dim):
126
+ import torch.onnx.symbolic_helper as sym_help
127
+ from torch.onnx.symbolic_opset9 import masked_fill, softmax
128
+
129
+ mask_cast_value = g.op("Cast", mask, to_i=sym_help.cast_pytorch_to_onnx["Long"])
130
+ r_mask = g.op(
131
+ "Cast",
132
+ g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
133
+ to_i=sym_help.cast_pytorch_to_onnx["Byte"],
134
+ )
135
+ output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
136
+ output = softmax(g, output, dim)
137
+ return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
138
+
139
+
140
+ # Copied from transformers.models.deberta.modeling_deberta.DropoutContext
141
+ class DropoutContext(object):
142
+ def __init__(self):
143
+ self.dropout = 0
144
+ self.mask = None
145
+ self.scale = 1
146
+ self.reuse_mask = True
147
+
148
+
149
+ # Copied from transformers.models.deberta.modeling_deberta.get_mask
150
+ def get_mask(input, local_context):
151
+ if not isinstance(local_context, DropoutContext):
152
+ dropout = local_context
153
+ mask = None
154
+ else:
155
+ dropout = local_context.dropout
156
+ dropout *= local_context.scale
157
+ mask = local_context.mask if local_context.reuse_mask else None
158
+
159
+ if dropout > 0 and mask is None:
160
+ mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).to(torch.bool)
161
+
162
+ if isinstance(local_context, DropoutContext):
163
+ if local_context.mask is None:
164
+ local_context.mask = mask
165
+
166
+ return mask, dropout
167
+
168
+
169
+ # Copied from transformers.models.deberta.modeling_deberta.XDropout
170
+ class XDropout(torch.autograd.Function):
171
+ """Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
172
+
173
+ @staticmethod
174
+ def forward(ctx, input, local_ctx):
175
+ mask, dropout = get_mask(input, local_ctx)
176
+ ctx.scale = 1.0 / (1 - dropout)
177
+ if dropout > 0:
178
+ ctx.save_for_backward(mask)
179
+ return input.masked_fill(mask, 0) * ctx.scale
180
+ else:
181
+ return input
182
+
183
+ @staticmethod
184
+ def backward(ctx, grad_output):
185
+ if ctx.scale > 1:
186
+ (mask,) = ctx.saved_tensors
187
+ return grad_output.masked_fill(mask, 0) * ctx.scale, None
188
+ else:
189
+ return grad_output, None
190
+
191
+
192
+ # Copied from transformers.models.deberta.modeling_deberta.StableDropout
193
+ class StableDropout(nn.Module):
194
+ """
195
+ Optimized dropout module for stabilizing the training
196
+
197
+ Args:
198
+ drop_prob (float): the dropout probabilities
199
+ """
200
+
201
+ def __init__(self, drop_prob):
202
+ super().__init__()
203
+ self.drop_prob = drop_prob
204
+ self.count = 0
205
+ self.context_stack = None
206
+
207
+ def forward(self, x):
208
+ """
209
+ Call the module
210
+
211
+ Args:
212
+ x (`torch.tensor`): The input tensor to apply dropout
213
+ """
214
+ if self.training and self.drop_prob > 0:
215
+ return XDropout.apply(x, self.get_context())
216
+ return x
217
+
218
+ def clear_context(self):
219
+ self.count = 0
220
+ self.context_stack = None
221
+
222
+ def init_context(self, reuse_mask=True, scale=1):
223
+ if self.context_stack is None:
224
+ self.context_stack = []
225
+ self.count = 0
226
+ for c in self.context_stack:
227
+ c.reuse_mask = reuse_mask
228
+ c.scale = scale
229
+
230
+ def get_context(self):
231
+ if self.context_stack is not None:
232
+ if self.count >= len(self.context_stack):
233
+ self.context_stack.append(DropoutContext())
234
+ ctx = self.context_stack[self.count]
235
+ ctx.dropout = self.drop_prob
236
+ self.count += 1
237
+ return ctx
238
+ else:
239
+ return self.drop_prob
240
+
241
+
242
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm
243
+ class DebertaV2SelfOutput(nn.Module):
244
+ def __init__(self, config):
245
+ super().__init__()
246
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
247
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
248
+ self.dropout = StableDropout(config.hidden_dropout_prob)
249
+
250
+ def forward(self, hidden_states, input_tensor):
251
+ hidden_states = self.dense(hidden_states)
252
+ hidden_states = self.dropout(hidden_states)
253
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
254
+ return hidden_states
255
+
256
+
257
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2
258
+ class DebertaV2Attention(nn.Module):
259
+ def __init__(self, config):
260
+ super().__init__()
261
+ self.self = DisentangledSelfAttention(config)
262
+ self.output = DebertaV2SelfOutput(config)
263
+ self.config = config
264
+
265
+ def forward(
266
+ self,
267
+ hidden_states,
268
+ attention_mask,
269
+ output_attentions=False,
270
+ query_states=None,
271
+ relative_pos=None,
272
+ rel_embeddings=None,
273
+ ):
274
+ self_output = self.self(
275
+ hidden_states,
276
+ attention_mask,
277
+ output_attentions,
278
+ query_states=query_states,
279
+ relative_pos=relative_pos,
280
+ rel_embeddings=rel_embeddings,
281
+ )
282
+ if output_attentions:
283
+ self_output, att_matrix = self_output
284
+ if query_states is None:
285
+ query_states = hidden_states
286
+ attention_output = self.output(self_output, query_states)
287
+
288
+ if output_attentions:
289
+ return (attention_output, att_matrix)
290
+ else:
291
+ return attention_output
292
+
293
+
294
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2
295
+ class DebertaV2Intermediate(nn.Module):
296
+ def __init__(self, config):
297
+ super().__init__()
298
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
299
+ if isinstance(config.hidden_act, str):
300
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
301
+ else:
302
+ self.intermediate_act_fn = config.hidden_act
303
+
304
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
305
+ hidden_states = self.dense(hidden_states)
306
+ hidden_states = self.intermediate_act_fn(hidden_states)
307
+ return hidden_states
308
+
309
+
310
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm
311
+ class DebertaV2Output(nn.Module):
312
+ def __init__(self, config):
313
+ super().__init__()
314
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
315
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
316
+ self.dropout = StableDropout(config.hidden_dropout_prob)
317
+ self.config = config
318
+
319
+ def forward(self, hidden_states, input_tensor):
320
+ hidden_states = self.dense(hidden_states)
321
+ hidden_states = self.dropout(hidden_states)
322
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
323
+ return hidden_states
324
+
325
+
326
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2
327
+ class DebertaV2Layer(nn.Module):
328
+ def __init__(self, config):
329
+ super().__init__()
330
+ self.attention = DebertaV2Attention(config)
331
+ self.intermediate = DebertaV2Intermediate(config)
332
+ self.output = DebertaV2Output(config)
333
+
334
+ def forward(
335
+ self,
336
+ hidden_states,
337
+ attention_mask,
338
+ query_states=None,
339
+ relative_pos=None,
340
+ rel_embeddings=None,
341
+ output_attentions=False,
342
+ ):
343
+ attention_output = self.attention(
344
+ hidden_states,
345
+ attention_mask,
346
+ output_attentions=output_attentions,
347
+ query_states=query_states,
348
+ relative_pos=relative_pos,
349
+ rel_embeddings=rel_embeddings,
350
+ )
351
+ if output_attentions:
352
+ attention_output, att_matrix = attention_output
353
+ intermediate_output = self.intermediate(attention_output)
354
+ layer_output = self.output(intermediate_output, attention_output)
355
+ if output_attentions:
356
+ return (layer_output, att_matrix)
357
+ else:
358
+ return layer_output
359
+
360
+
361
+ class ConvLayer(nn.Module):
362
+ def __init__(self, config):
363
+ super().__init__()
364
+ kernel_size = getattr(config, "conv_kernel_size", 3)
365
+ groups = getattr(config, "conv_groups", 1)
366
+ self.conv_act = getattr(config, "conv_act", "tanh")
367
+ self.conv = nn.Conv1d(
368
+ config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
369
+ )
370
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
371
+ self.dropout = StableDropout(config.hidden_dropout_prob)
372
+ self.config = config
373
+
374
+ def forward(self, hidden_states, residual_states, input_mask):
375
+ out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
376
+ rmask = (1 - input_mask).bool()
377
+ out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
378
+ out = ACT2FN[self.conv_act](self.dropout(out))
379
+
380
+ layer_norm_input = residual_states + out
381
+ output = self.LayerNorm(layer_norm_input).to(layer_norm_input)
382
+
383
+ if input_mask is None:
384
+ output_states = output
385
+ else:
386
+ if input_mask.dim() != layer_norm_input.dim():
387
+ if input_mask.dim() == 4:
388
+ input_mask = input_mask.squeeze(1).squeeze(1)
389
+ input_mask = input_mask.unsqueeze(2)
390
+
391
+ input_mask = input_mask.to(output.dtype)
392
+ output_states = output * input_mask
393
+
394
+ return output_states
395
+
396
+
397
+ class DebertaV2Encoder(nn.Module):
398
+ """Modified BertEncoder with relative position bias support"""
399
+
400
+ def __init__(self, config):
401
+ super().__init__()
402
+
403
+ self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)])
404
+ self.relative_attention = getattr(config, "relative_attention", False)
405
+
406
+ if self.relative_attention:
407
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
408
+ if self.max_relative_positions < 1:
409
+ self.max_relative_positions = config.max_position_embeddings
410
+
411
+ self.position_buckets = getattr(config, "position_buckets", -1)
412
+ pos_ebd_size = self.max_relative_positions * 2
413
+
414
+ if self.position_buckets > 0:
415
+ pos_ebd_size = self.position_buckets * 2
416
+
417
+ self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)
418
+
419
+ self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")]
420
+
421
+ if "layer_norm" in self.norm_rel_ebd:
422
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
423
+
424
+ self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None
425
+ self.gradient_checkpointing = False
426
+
427
+ def get_rel_embedding(self):
428
+ rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
429
+ if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd):
430
+ rel_embeddings = self.LayerNorm(rel_embeddings)
431
+ return rel_embeddings
432
+
433
+ def get_attention_mask(self, attention_mask):
434
+ if attention_mask.dim() <= 2:
435
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
436
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
437
+ attention_mask = attention_mask.byte()
438
+ elif attention_mask.dim() == 3:
439
+ attention_mask = attention_mask.unsqueeze(1)
440
+
441
+ return attention_mask
442
+
443
+ def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
444
+ if self.relative_attention and relative_pos is None:
445
+ q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
446
+ relative_pos = build_relative_position(
447
+ q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
448
+ )
449
+ return relative_pos
450
+
451
+ def forward(
452
+ self,
453
+ hidden_states,
454
+ attention_mask,
455
+ output_hidden_states=True,
456
+ output_attentions=False,
457
+ query_states=None,
458
+ relative_pos=None,
459
+ return_dict=True,
460
+ ):
461
+ if attention_mask.dim() <= 2:
462
+ input_mask = attention_mask
463
+ else:
464
+ input_mask = (attention_mask.sum(-2) > 0).byte()
465
+ attention_mask = self.get_attention_mask(attention_mask)
466
+ relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
467
+
468
+ all_hidden_states = () if output_hidden_states else None
469
+ all_attentions = () if output_attentions else None
470
+
471
+ if isinstance(hidden_states, Sequence):
472
+ next_kv = hidden_states[0]
473
+ else:
474
+ next_kv = hidden_states
475
+ rel_embeddings = self.get_rel_embedding()
476
+ output_states = next_kv
477
+ for i, layer_module in enumerate(self.layer):
478
+
479
+ if output_hidden_states:
480
+ all_hidden_states = all_hidden_states + (output_states,)
481
+
482
+ if self.gradient_checkpointing and self.training:
483
+
484
+ def create_custom_forward(module):
485
+ def custom_forward(*inputs):
486
+ return module(*inputs, output_attentions)
487
+
488
+ return custom_forward
489
+
490
+ output_states = torch.utils.checkpoint.checkpoint(
491
+ create_custom_forward(layer_module),
492
+ next_kv,
493
+ attention_mask,
494
+ query_states,
495
+ relative_pos,
496
+ rel_embeddings,
497
+ )
498
+ else:
499
+ output_states = layer_module(
500
+ next_kv,
501
+ attention_mask,
502
+ query_states=query_states,
503
+ relative_pos=relative_pos,
504
+ rel_embeddings=rel_embeddings,
505
+ output_attentions=output_attentions,
506
+ )
507
+
508
+ if output_attentions:
509
+ output_states, att_m = output_states
510
+
511
+ if i == 0 and self.conv is not None:
512
+ output_states = self.conv(hidden_states, output_states, input_mask)
513
+
514
+ if query_states is not None:
515
+ query_states = output_states
516
+ if isinstance(hidden_states, Sequence):
517
+ next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
518
+ else:
519
+ next_kv = output_states
520
+
521
+ if output_attentions:
522
+ all_attentions = all_attentions + (att_m,)
523
+
524
+ if output_hidden_states:
525
+ all_hidden_states = all_hidden_states + (output_states,)
526
+
527
+ if not return_dict:
528
+ return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None)
529
+ return BaseModelOutput(
530
+ last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions
531
+ )
532
+
533
+
534
+ def make_log_bucket_position(relative_pos, bucket_size, max_position):
535
+ sign = np.sign(relative_pos)
536
+ mid = bucket_size // 2
537
+ abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos))
538
+ log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid
539
+ bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int)
540
+ return bucket_pos
541
+
542
+
543
+ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
544
+ """
545
+ Build relative position according to the query and key
546
+
547
+ We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key
548
+ \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q -
549
+ P_k\\)
550
+
551
+ Args:
552
+ query_size (int): the length of query
553
+ key_size (int): the length of key
554
+ bucket_size (int): the size of position bucket
555
+ max_position (int): the maximum allowed absolute position
556
+
557
+ Return:
558
+ `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
559
+
560
+ """
561
+ q_ids = np.arange(0, query_size)
562
+ k_ids = np.arange(0, key_size)
563
+ rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
564
+ if bucket_size > 0 and max_position > 0:
565
+ rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
566
+ rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
567
+ rel_pos_ids = rel_pos_ids[:query_size, :]
568
+ rel_pos_ids = rel_pos_ids.unsqueeze(0)
569
+ return rel_pos_ids
570
+
571
+
572
+ @torch.jit.script
573
+ # Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand
574
+ def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
575
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
576
+
577
+
578
+ @torch.jit.script
579
+ # Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand
580
+ def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
581
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
582
+
583
+
584
+ @torch.jit.script
585
+ # Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand
586
+ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
587
+ return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
588
+
589
+
590
+ class DisentangledSelfAttention(nn.Module):
591
+ """
592
+ Disentangled self-attention module
593
+
594
+ Parameters:
595
+ config (`DebertaV2Config`):
596
+ A model config class instance with the configuration to build a new model. The schema is similar to
597
+ *BertConfig*, for more details, please refer [`DebertaV2Config`]
598
+
599
+ """
600
+
601
+ def __init__(self, config):
602
+ super().__init__()
603
+ if config.hidden_size % config.num_attention_heads != 0:
604
+ raise ValueError(
605
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
606
+ f"heads ({config.num_attention_heads})"
607
+ )
608
+ self.num_attention_heads = config.num_attention_heads
609
+ _attention_head_size = config.hidden_size // config.num_attention_heads
610
+ self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size)
611
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
612
+ self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
613
+ self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
614
+ self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
615
+
616
+ self.share_att_key = getattr(config, "share_att_key", False)
617
+ self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
618
+ self.relative_attention = getattr(config, "relative_attention", False)
619
+
620
+ if self.relative_attention:
621
+ self.position_buckets = getattr(config, "position_buckets", -1)
622
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
623
+ if self.max_relative_positions < 1:
624
+ self.max_relative_positions = config.max_position_embeddings
625
+ self.pos_ebd_size = self.max_relative_positions
626
+ if self.position_buckets > 0:
627
+ self.pos_ebd_size = self.position_buckets
628
+
629
+ self.pos_dropout = StableDropout(config.hidden_dropout_prob)
630
+
631
+ if not self.share_att_key:
632
+ if "c2p" in self.pos_att_type:
633
+ self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
634
+ if "p2c" in self.pos_att_type:
635
+ self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)
636
+
637
+ self.dropout = StableDropout(config.attention_probs_dropout_prob)
638
+
639
+ def transpose_for_scores(self, x, attention_heads):
640
+ new_x_shape = x.size()[:-1] + (attention_heads, -1)
641
+ x = x.view(new_x_shape)
642
+ return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
643
+
644
+ def forward(
645
+ self,
646
+ hidden_states,
647
+ attention_mask,
648
+ output_attentions=False,
649
+ query_states=None,
650
+ relative_pos=None,
651
+ rel_embeddings=None,
652
+ ):
653
+ """
654
+ Call the module
655
+
656
+ Args:
657
+ hidden_states (`torch.FloatTensor`):
658
+ Input states to the module usually the output from previous layer, it will be the Q,K and V in
659
+ *Attention(Q,K,V)*
660
+
661
+ attention_mask (`torch.ByteTensor`):
662
+ An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
663
+ sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
664
+ th token.
665
+
666
+ output_attentions (`bool`, optional):
667
+ Whether return the attention matrix.
668
+
669
+ query_states (`torch.FloatTensor`, optional):
670
+ The *Q* state in *Attention(Q,K,V)*.
671
+
672
+ relative_pos (`torch.LongTensor`):
673
+ The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
674
+ values ranging in [*-max_relative_positions*, *max_relative_positions*].
675
+
676
+ rel_embeddings (`torch.FloatTensor`):
677
+ The embedding of relative distances. It's a tensor of shape [\\(2 \\times
678
+ \\text{max_relative_positions}\\), *hidden_size*].
679
+
680
+
681
+ """
682
+ if query_states is None:
683
+ query_states = hidden_states
684
+ query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads)
685
+ key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
686
+ value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)
687
+
688
+ rel_att = None
689
+ # Take the dot product between "query" and "key" to get the raw attention scores.
690
+ scale_factor = 1
691
+ if "c2p" in self.pos_att_type:
692
+ scale_factor += 1
693
+ if "p2c" in self.pos_att_type:
694
+ scale_factor += 1
695
+ scale = math.sqrt(query_layer.size(-1) * scale_factor)
696
+ attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
697
+ if self.relative_attention:
698
+ rel_embeddings = self.pos_dropout(rel_embeddings)
699
+ rel_att = self.disentangled_attention_bias(
700
+ query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
701
+ )
702
+
703
+ if rel_att is not None:
704
+ attention_scores = attention_scores + rel_att
705
+ attention_scores = attention_scores
706
+ attention_scores = attention_scores.view(
707
+ -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
708
+ )
709
+
710
+ # bsz x height x length x dimension
711
+ attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
712
+ attention_probs = self.dropout(attention_probs)
713
+ context_layer = torch.bmm(
714
+ attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer
715
+ )
716
+ context_layer = (
717
+ context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1))
718
+ .permute(0, 2, 1, 3)
719
+ .contiguous()
720
+ )
721
+ new_context_layer_shape = context_layer.size()[:-2] + (-1,)
722
+ context_layer = context_layer.view(new_context_layer_shape)
723
+ if output_attentions:
724
+ return (context_layer, attention_probs)
725
+ else:
726
+ return context_layer
727
+
728
+ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
729
+ if relative_pos is None:
730
+ q = query_layer.size(-2)
731
+ relative_pos = build_relative_position(
732
+ q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
733
+ )
734
+ if relative_pos.dim() == 2:
735
+ relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
736
+ elif relative_pos.dim() == 3:
737
+ relative_pos = relative_pos.unsqueeze(1)
738
+ # bsz x height x query x key
739
+ elif relative_pos.dim() != 4:
740
+ raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
741
+
742
+ att_span = self.pos_ebd_size
743
+ relative_pos = relative_pos.long().to(query_layer.device)
744
+
745
+ rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0)
746
+ if self.share_att_key:
747
+ pos_query_layer = self.transpose_for_scores(
748
+ self.query_proj(rel_embeddings), self.num_attention_heads
749
+ ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
750
+ pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(
751
+ query_layer.size(0) // self.num_attention_heads, 1, 1
752
+ )
753
+ else:
754
+ if "c2p" in self.pos_att_type:
755
+ pos_key_layer = self.transpose_for_scores(
756
+ self.pos_key_proj(rel_embeddings), self.num_attention_heads
757
+ ).repeat(
758
+ query_layer.size(0) // self.num_attention_heads, 1, 1
759
+ ) # .split(self.all_head_size, dim=-1)
760
+ if "p2c" in self.pos_att_type:
761
+ pos_query_layer = self.transpose_for_scores(
762
+ self.pos_query_proj(rel_embeddings), self.num_attention_heads
763
+ ).repeat(
764
+ query_layer.size(0) // self.num_attention_heads, 1, 1
765
+ ) # .split(self.all_head_size, dim=-1)
766
+
767
+ score = 0
768
+ # content->position
769
+ if "c2p" in self.pos_att_type:
770
+ scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
771
+ c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
772
+ c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
773
+ c2p_att = torch.gather(
774
+ c2p_att,
775
+ dim=-1,
776
+ index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
777
+ )
778
+ score += c2p_att / scale
779
+
780
+ # position->content
781
+ if "p2c" in self.pos_att_type:
782
+ scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
783
+ if key_layer.size(-2) != query_layer.size(-2):
784
+ r_pos = build_relative_position(
785
+ key_layer.size(-2),
786
+ key_layer.size(-2),
787
+ bucket_size=self.position_buckets,
788
+ max_position=self.max_relative_positions,
789
+ ).to(query_layer.device)
790
+ r_pos = r_pos.unsqueeze(0)
791
+ else:
792
+ r_pos = relative_pos
793
+
794
+ p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
795
+ p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
796
+ p2c_att = torch.gather(
797
+ p2c_att,
798
+ dim=-1,
799
+ index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
800
+ ).transpose(-1, -2)
801
+ score += p2c_att / scale
802
+
803
+ return score
804
+
805
+
806
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm
807
+ class DebertaV2Embeddings(nn.Module):
808
+ """Construct the embeddings from word, position and token_type embeddings."""
809
+
810
+ def __init__(self, config):
811
+ super().__init__()
812
+ pad_token_id = getattr(config, "pad_token_id", 0)
813
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
814
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id)
815
+
816
+ self.position_biased_input = getattr(config, "position_biased_input", True)
817
+ if not self.position_biased_input:
818
+ self.position_embeddings = None
819
+ else:
820
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
821
+
822
+ if config.type_vocab_size > 0:
823
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
824
+
825
+ if self.embedding_size != config.hidden_size:
826
+ self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
827
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
828
+ self.dropout = StableDropout(config.hidden_dropout_prob)
829
+ self.config = config
830
+
831
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
832
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
833
+
834
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
835
+ if input_ids is not None:
836
+ input_shape = input_ids.size()
837
+ else:
838
+ input_shape = inputs_embeds.size()[:-1]
839
+
840
+ seq_length = input_shape[1]
841
+
842
+ if position_ids is None:
843
+ position_ids = self.position_ids[:, :seq_length]
844
+
845
+ if token_type_ids is None:
846
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
847
+
848
+ if inputs_embeds is None:
849
+ inputs_embeds = self.word_embeddings(input_ids)
850
+
851
+ if self.position_embeddings is not None:
852
+ position_embeddings = self.position_embeddings(position_ids.long())
853
+ else:
854
+ position_embeddings = torch.zeros_like(inputs_embeds)
855
+
856
+ embeddings = inputs_embeds
857
+ if self.position_biased_input:
858
+ embeddings += position_embeddings
859
+ if self.config.type_vocab_size > 0:
860
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
861
+ embeddings += token_type_embeddings
862
+
863
+ if self.embedding_size != self.config.hidden_size:
864
+ embeddings = self.embed_proj(embeddings)
865
+
866
+ embeddings = self.LayerNorm(embeddings)
867
+
868
+ # if mask is not None:
869
+ # if mask.dim() != embeddings.dim():
870
+ # if mask.dim() == 4:
871
+ # mask = mask.squeeze(1).squeeze(1)
872
+ # mask = mask.unsqueeze(2)
873
+ # mask = mask.to(embeddings.dtype)
874
+
875
+ # embeddings = embeddings * mask
876
+
877
+ embeddings = self.dropout(embeddings)
878
+ return embeddings
879
+
880
+
881
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2
882
+ class DebertaV2PreTrainedModel(PreTrainedModel):
883
+ """
884
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
885
+ models.
886
+ """
887
+
888
+ config_class = DebertaV2Config
889
+ base_model_prefix = "deberta"
890
+ _keys_to_ignore_on_load_missing = ["position_ids"]
891
+ _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
892
+ supports_gradient_checkpointing = True
893
+
894
+ def _init_weights(self, module):
895
+ """Initialize the weights."""
896
+ if isinstance(module, nn.Linear):
897
+ # Slightly different from the TF version which uses truncated_normal for initialization
898
+ # cf https://github.com/pytorch/pytorch/pull/5617
899
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
900
+ if module.bias is not None:
901
+ module.bias.data.zero_()
902
+ elif isinstance(module, nn.Embedding):
903
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
904
+ if module.padding_idx is not None:
905
+ module.weight.data[module.padding_idx].zero_()
906
+
907
+ def _set_gradient_checkpointing(self, module, value=False):
908
+ if isinstance(module, DebertaV2Encoder):
909
+ module.gradient_checkpointing = value
910
+
911
+
912
+ DEBERTA_START_DOCSTRING = r"""
913
+ The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled
914
+ Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build
915
+ on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two
916
+ improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data.
917
+
918
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
919
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
920
+ and behavior.```
921
+
922
+
923
+ Parameters:
924
+ config ([`DebertaV2Config`]): Model configuration class with all the parameters of the model.
925
+ Initializing with a config file does not load the weights associated with the model, only the
926
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
927
+ """
928
+
929
+ DEBERTA_INPUTS_DOCSTRING = r"""
930
+ Args:
931
+ input_ids (`torch.LongTensor` of shape `({0})`):
932
+ Indices of input sequence tokens in the vocabulary.
933
+
934
+ Indices can be obtained using [`DebertaV2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
935
+ [`PreTrainedTokenizer.__call__`] for details.
936
+
937
+ [What are input IDs?](../glossary#input-ids)
938
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
939
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
940
+
941
+ - 1 for tokens that are **not masked**,
942
+ - 0 for tokens that are **masked**.
943
+
944
+ [What are attention masks?](../glossary#attention-mask)
945
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
946
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
947
+ 1]`:
948
+
949
+ - 0 corresponds to a *sentence A* token,
950
+ - 1 corresponds to a *sentence B* token.
951
+
952
+ [What are token type IDs?](../glossary#token-type-ids)
953
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
954
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
955
+ config.max_position_embeddings - 1]`.
956
+
957
+ [What are position IDs?](../glossary#position-ids)
958
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
959
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
960
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
961
+ model's internal embedding lookup matrix.
962
+ output_attentions (`bool`, *optional*):
963
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
964
+ tensors for more detail.
965
+ output_hidden_states (`bool`, *optional*):
966
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
967
+ more detail.
968
+ return_dict (`bool`, *optional*):
969
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
970
+ """
971
+
972
+
973
+ @add_start_docstrings(
974
+ "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
975
+ DEBERTA_START_DOCSTRING,
976
+ )
977
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2
978
+ class DebertaV2Model(DebertaV2PreTrainedModel):
979
+ def __init__(self, config):
980
+ super().__init__(config)
981
+
982
+ self.embeddings = DebertaV2Embeddings(config)
983
+ self.encoder = DebertaV2Encoder(config)
984
+ self.z_steps = 2
985
+ self.config = config
986
+ # Initialize weights and apply final processing
987
+ self.post_init()
988
+
989
+ def get_input_embeddings(self):
990
+ return self.embeddings.word_embeddings
991
+
992
+ def set_input_embeddings(self, new_embeddings):
993
+ self.embeddings.word_embeddings = new_embeddings
994
+
995
+ def _prune_heads(self, heads_to_prune):
996
+ """
997
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
998
+ class PreTrainedModel
999
+ """
1000
+ raise NotImplementedError("The prune function is not implemented in DeBERTa model.")
1001
+
1002
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1003
+ @add_code_sample_docstrings(
1004
+ processor_class=_TOKENIZER_FOR_DOC,
1005
+ checkpoint=_CHECKPOINT_FOR_DOC,
1006
+ output_type=BaseModelOutput,
1007
+ config_class=_CONFIG_FOR_DOC,
1008
+ )
1009
+ def forward(
1010
+ self,
1011
+ input_ids: Optional[torch.Tensor] = None,
1012
+ attention_mask: Optional[torch.Tensor] = None,
1013
+ token_type_ids: Optional[torch.Tensor] = None,
1014
+ position_ids: Optional[torch.Tensor] = None,
1015
+ inputs_embeds: Optional[torch.Tensor] = None,
1016
+ output_attentions: Optional[bool] = None,
1017
+ output_hidden_states: Optional[bool] = None,
1018
+ return_dict: Optional[bool] = None,
1019
+ ) -> Union[Tuple, BaseModelOutput]:
1020
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1021
+ output_hidden_states = (
1022
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1023
+ )
1024
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1025
+
1026
+ if input_ids is not None and inputs_embeds is not None:
1027
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1028
+ elif input_ids is not None:
1029
+ input_shape = input_ids.size()
1030
+ elif inputs_embeds is not None:
1031
+ input_shape = inputs_embeds.size()[:-1]
1032
+ else:
1033
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1034
+
1035
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1036
+
1037
+ if attention_mask is None:
1038
+ attention_mask = torch.ones(input_shape, device=device)
1039
+ if token_type_ids is None:
1040
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1041
+
1042
+ embedding_output = self.embeddings(
1043
+ input_ids=input_ids,
1044
+ token_type_ids=token_type_ids,
1045
+ position_ids=position_ids,
1046
+ mask=attention_mask,
1047
+ inputs_embeds=inputs_embeds,
1048
+ )
1049
+
1050
+ encoder_outputs = self.encoder(
1051
+ embedding_output,
1052
+ attention_mask,
1053
+ output_hidden_states=True,
1054
+ output_attentions=output_attentions,
1055
+ return_dict=return_dict,
1056
+ )
1057
+ encoded_layers = encoder_outputs[1]
1058
+
1059
+ if self.z_steps > 1:
1060
+ hidden_states = encoded_layers[-2]
1061
+ layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
1062
+ query_states = encoded_layers[-1]
1063
+ rel_embeddings = self.encoder.get_rel_embedding()
1064
+ attention_mask = self.encoder.get_attention_mask(attention_mask)
1065
+ rel_pos = self.encoder.get_rel_pos(embedding_output)
1066
+ for layer in layers[1:]:
1067
+ query_states = layer(
1068
+ hidden_states,
1069
+ attention_mask,
1070
+ output_attentions=False,
1071
+ query_states=query_states,
1072
+ relative_pos=rel_pos,
1073
+ rel_embeddings=rel_embeddings,
1074
+ )
1075
+ encoded_layers = encoded_layers + (query_states,)
1076
+
1077
+ sequence_output = encoded_layers[-1]
1078
+
1079
+ if not return_dict:
1080
+ return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :]
1081
+
1082
+ return BaseModelOutput(
1083
+ last_hidden_state=sequence_output,
1084
+ hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
1085
+ attentions=encoder_outputs.attentions,
1086
+ )
1087
+
1088
+
1089
+ @add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
1090
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2
1091
+ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
1092
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1093
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1094
+
1095
+ def __init__(self, config):
1096
+ super().__init__(config)
1097
+
1098
+ self.deberta = DebertaV2Model(config)
1099
+ self.cls = DebertaV2OnlyMLMHead(config)
1100
+
1101
+ # Initialize weights and apply final processing
1102
+ self.post_init()
1103
+
1104
+ def get_output_embeddings(self):
1105
+ return self.cls.predictions.decoder
1106
+
1107
+ def set_output_embeddings(self, new_embeddings):
1108
+ self.cls.predictions.decoder = new_embeddings
1109
+
1110
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1111
+ @add_code_sample_docstrings(
1112
+ processor_class=_TOKENIZER_FOR_DOC,
1113
+ checkpoint=_CHECKPOINT_FOR_DOC,
1114
+ output_type=MaskedLMOutput,
1115
+ config_class=_CONFIG_FOR_DOC,
1116
+ )
1117
+ def forward(
1118
+ self,
1119
+ input_ids: Optional[torch.Tensor] = None,
1120
+ attention_mask: Optional[torch.Tensor] = None,
1121
+ token_type_ids: Optional[torch.Tensor] = None,
1122
+ position_ids: Optional[torch.Tensor] = None,
1123
+ inputs_embeds: Optional[torch.Tensor] = None,
1124
+ labels: Optional[torch.Tensor] = None,
1125
+ output_attentions: Optional[bool] = None,
1126
+ output_hidden_states: Optional[bool] = None,
1127
+ return_dict: Optional[bool] = None,
1128
+ ) -> Union[Tuple, MaskedLMOutput]:
1129
+ r"""
1130
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1131
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1132
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1133
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1134
+ """
1135
+
1136
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1137
+
1138
+ outputs = self.deberta(
1139
+ input_ids,
1140
+ attention_mask=attention_mask,
1141
+ token_type_ids=token_type_ids,
1142
+ position_ids=position_ids,
1143
+ inputs_embeds=inputs_embeds,
1144
+ output_attentions=output_attentions,
1145
+ output_hidden_states=output_hidden_states,
1146
+ return_dict=return_dict,
1147
+ )
1148
+
1149
+ sequence_output = outputs[0]
1150
+ prediction_scores = self.cls(sequence_output)
1151
+
1152
+ masked_lm_loss = None
1153
+ if labels is not None:
1154
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1155
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1156
+
1157
+ if not return_dict:
1158
+ output = (prediction_scores,) + outputs[1:]
1159
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1160
+
1161
+ return MaskedLMOutput(
1162
+ loss=masked_lm_loss,
1163
+ logits=prediction_scores,
1164
+ hidden_states=outputs.hidden_states,
1165
+ attentions=outputs.attentions,
1166
+ )
1167
+
1168
+
1169
+ # copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta
1170
+ class DebertaV2PredictionHeadTransform(nn.Module):
1171
+ def __init__(self, config):
1172
+ super().__init__()
1173
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1174
+ if isinstance(config.hidden_act, str):
1175
+ self.transform_act_fn = ACT2FN[config.hidden_act]
1176
+ else:
1177
+ self.transform_act_fn = config.hidden_act
1178
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1179
+
1180
+ def forward(self, hidden_states):
1181
+ hidden_states = self.dense(hidden_states)
1182
+ hidden_states = self.transform_act_fn(hidden_states)
1183
+ hidden_states = self.LayerNorm(hidden_states)
1184
+ return hidden_states
1185
+
1186
+
1187
+ # copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta
1188
+ class DebertaV2LMPredictionHead(nn.Module):
1189
+ def __init__(self, config):
1190
+ super().__init__()
1191
+ self.transform = DebertaV2PredictionHeadTransform(config)
1192
+
1193
+ # The output weights are the same as the input embeddings, but there is
1194
+ # an output-only bias for each token.
1195
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1196
+
1197
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
1198
+
1199
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
1200
+ self.decoder.bias = self.bias
1201
+
1202
+ def forward(self, hidden_states):
1203
+ hidden_states = self.transform(hidden_states)
1204
+ hidden_states = self.decoder(hidden_states)
1205
+ return hidden_states
1206
+
1207
+
1208
+ # copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta
1209
+ class DebertaV2OnlyMLMHead(nn.Module):
1210
+ def __init__(self, config):
1211
+ super().__init__()
1212
+ self.predictions = DebertaV2LMPredictionHead(config)
1213
+
1214
+ def forward(self, sequence_output):
1215
+ prediction_scores = self.predictions(sequence_output)
1216
+ return prediction_scores
1217
+
1218
+
1219
+ @add_start_docstrings(
1220
+ """
1221
+ DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
1222
+ pooled output) e.g. for GLUE tasks.
1223
+ """,
1224
+ DEBERTA_START_DOCSTRING,
1225
+ )
1226
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2
1227
+ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
1228
+ def __init__(self, config):
1229
+ super().__init__(config)
1230
+
1231
+ num_labels = getattr(config, "num_labels", 2)
1232
+ self.num_labels = num_labels
1233
+
1234
+ self.deberta = DebertaV2Model(config)
1235
+ self.pooler = ContextPooler(config)
1236
+ output_dim = self.pooler.output_dim
1237
+
1238
+ self.classifier = nn.Linear(output_dim, num_labels)
1239
+ drop_out = getattr(config, "cls_dropout", None)
1240
+ drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
1241
+ self.dropout = StableDropout(drop_out)
1242
+
1243
+ # Initialize weights and apply final processing
1244
+ self.post_init()
1245
+
1246
+ def get_input_embeddings(self):
1247
+ return self.deberta.get_input_embeddings()
1248
+
1249
+ def set_input_embeddings(self, new_embeddings):
1250
+ self.deberta.set_input_embeddings(new_embeddings)
1251
+
1252
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1253
+ @add_code_sample_docstrings(
1254
+ processor_class=_TOKENIZER_FOR_DOC,
1255
+ checkpoint=_CHECKPOINT_FOR_DOC,
1256
+ output_type=SequenceClassifierOutput,
1257
+ config_class=_CONFIG_FOR_DOC,
1258
+ )
1259
+ def forward(
1260
+ self,
1261
+ input_ids: Optional[torch.Tensor] = None,
1262
+ attention_mask: Optional[torch.Tensor] = None,
1263
+ token_type_ids: Optional[torch.Tensor] = None,
1264
+ position_ids: Optional[torch.Tensor] = None,
1265
+ inputs_embeds: Optional[torch.Tensor] = None,
1266
+ labels: Optional[torch.Tensor] = None,
1267
+ output_attentions: Optional[bool] = None,
1268
+ output_hidden_states: Optional[bool] = None,
1269
+ return_dict: Optional[bool] = None,
1270
+ ) -> Union[Tuple, SequenceClassifierOutput]:
1271
+ r"""
1272
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1273
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1274
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1275
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1276
+ """
1277
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1278
+
1279
+ outputs = self.deberta(
1280
+ input_ids,
1281
+ token_type_ids=token_type_ids,
1282
+ attention_mask=attention_mask,
1283
+ position_ids=position_ids,
1284
+ inputs_embeds=inputs_embeds,
1285
+ output_attentions=output_attentions,
1286
+ output_hidden_states=output_hidden_states,
1287
+ return_dict=return_dict,
1288
+ )
1289
+
1290
+ encoder_layer = outputs[0]
1291
+ pooled_output = self.pooler(encoder_layer)
1292
+ pooled_output = self.dropout(pooled_output)
1293
+ logits = self.classifier(pooled_output)
1294
+
1295
+ loss = None
1296
+ if labels is not None:
1297
+ if self.config.problem_type is None:
1298
+ if self.num_labels == 1:
1299
+ # regression task
1300
+ loss_fn = nn.MSELoss()
1301
+ logits = logits.view(-1).to(labels.dtype)
1302
+ loss = loss_fn(logits, labels.view(-1))
1303
+ elif labels.dim() == 1 or labels.size(-1) == 1:
1304
+ label_index = (labels >= 0).nonzero()
1305
+ labels = labels.long()
1306
+ if label_index.size(0) > 0:
1307
+ labeled_logits = torch.gather(
1308
+ logits, 0, label_index.expand(label_index.size(0), logits.size(1))
1309
+ )
1310
+ labels = torch.gather(labels, 0, label_index.view(-1))
1311
+ loss_fct = CrossEntropyLoss()
1312
+ loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
1313
+ else:
1314
+ loss = torch.tensor(0).to(logits)
1315
+ else:
1316
+ log_softmax = nn.LogSoftmax(-1)
1317
+ loss = -((log_softmax(logits) * labels).sum(-1)).mean()
1318
+ elif self.config.problem_type == "regression":
1319
+ loss_fct = MSELoss()
1320
+ if self.num_labels == 1:
1321
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1322
+ else:
1323
+ loss = loss_fct(logits, labels)
1324
+ elif self.config.problem_type == "single_label_classification":
1325
+ loss_fct = CrossEntropyLoss()
1326
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1327
+ elif self.config.problem_type == "multi_label_classification":
1328
+ loss_fct = BCEWithLogitsLoss()
1329
+ loss = loss_fct(logits, labels)
1330
+ if not return_dict:
1331
+ output = (logits,) + outputs[1:]
1332
+ return ((loss,) + output) if loss is not None else output
1333
+
1334
+ return SequenceClassifierOutput(
1335
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
1336
+ )
1337
+
1338
+
1339
+ @add_start_docstrings(
1340
+ """
1341
+ DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1342
+ Named-Entity-Recognition (NER) tasks.
1343
+ """,
1344
+ DEBERTA_START_DOCSTRING,
1345
+ )
1346
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2
1347
+ class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
1348
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1349
+
1350
+ def __init__(self, config):
1351
+ super().__init__(config)
1352
+ self.num_labels = config.num_labels
1353
+
1354
+ self.deberta = DebertaV2Model(config)
1355
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1356
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1357
+
1358
+ # Initialize weights and apply final processing
1359
+ self.post_init()
1360
+
1361
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1362
+ @add_code_sample_docstrings(
1363
+ processor_class=_TOKENIZER_FOR_DOC,
1364
+ checkpoint=_CHECKPOINT_FOR_DOC,
1365
+ output_type=TokenClassifierOutput,
1366
+ config_class=_CONFIG_FOR_DOC,
1367
+ )
1368
+ def forward(
1369
+ self,
1370
+ input_ids: Optional[torch.Tensor] = None,
1371
+ attention_mask: Optional[torch.Tensor] = None,
1372
+ token_type_ids: Optional[torch.Tensor] = None,
1373
+ position_ids: Optional[torch.Tensor] = None,
1374
+ inputs_embeds: Optional[torch.Tensor] = None,
1375
+ labels: Optional[torch.Tensor] = None,
1376
+ output_attentions: Optional[bool] = None,
1377
+ output_hidden_states: Optional[bool] = None,
1378
+ return_dict: Optional[bool] = None,
1379
+ ) -> Union[Tuple, TokenClassifierOutput]:
1380
+ r"""
1381
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1382
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1383
+ """
1384
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1385
+
1386
+ outputs = self.deberta(
1387
+ input_ids,
1388
+ attention_mask=attention_mask,
1389
+ token_type_ids=token_type_ids,
1390
+ position_ids=position_ids,
1391
+ inputs_embeds=inputs_embeds,
1392
+ output_attentions=output_attentions,
1393
+ output_hidden_states=output_hidden_states,
1394
+ return_dict=return_dict,
1395
+ )
1396
+
1397
+ sequence_output = outputs[0]
1398
+
1399
+ sequence_output = self.dropout(sequence_output)
1400
+ logits = self.classifier(sequence_output)
1401
+
1402
+ loss = None
1403
+ if labels is not None:
1404
+ loss_fct = CrossEntropyLoss()
1405
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1406
+
1407
+ if not return_dict:
1408
+ output = (logits,) + outputs[1:]
1409
+ return ((loss,) + output) if loss is not None else output
1410
+
1411
+ return TokenClassifierOutput(
1412
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
1413
+ )
1414
+
1415
+
1416
+ @add_start_docstrings(
1417
+ """
1418
+ DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1419
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1420
+ """,
1421
+ DEBERTA_START_DOCSTRING,
1422
+ )
1423
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering with Deberta->DebertaV2
1424
+ class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
1425
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1426
+
1427
+ def __init__(self, config):
1428
+ super().__init__(config)
1429
+ self.num_labels = config.num_labels
1430
+
1431
+ self.deberta = DebertaV2Model(config)
1432
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1433
+
1434
+ # Initialize weights and apply final processing
1435
+ self.post_init()
1436
+
1437
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1438
+ @add_code_sample_docstrings(
1439
+ processor_class=_TOKENIZER_FOR_DOC,
1440
+ checkpoint=_CHECKPOINT_FOR_DOC,
1441
+ output_type=QuestionAnsweringModelOutput,
1442
+ config_class=_CONFIG_FOR_DOC,
1443
+ )
1444
+ def forward(
1445
+ self,
1446
+ input_ids: Optional[torch.Tensor] = None,
1447
+ attention_mask: Optional[torch.Tensor] = None,
1448
+ token_type_ids: Optional[torch.Tensor] = None,
1449
+ position_ids: Optional[torch.Tensor] = None,
1450
+ inputs_embeds: Optional[torch.Tensor] = None,
1451
+ start_positions: Optional[torch.Tensor] = None,
1452
+ end_positions: Optional[torch.Tensor] = None,
1453
+ output_attentions: Optional[bool] = None,
1454
+ output_hidden_states: Optional[bool] = None,
1455
+ return_dict: Optional[bool] = None,
1456
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1457
+ r"""
1458
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1459
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1460
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1461
+ are not taken into account for computing the loss.
1462
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1463
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1464
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1465
+ are not taken into account for computing the loss.
1466
+ """
1467
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1468
+
1469
+ outputs = self.deberta(
1470
+ input_ids,
1471
+ attention_mask=attention_mask,
1472
+ token_type_ids=token_type_ids,
1473
+ position_ids=position_ids,
1474
+ inputs_embeds=inputs_embeds,
1475
+ output_attentions=output_attentions,
1476
+ output_hidden_states=output_hidden_states,
1477
+ return_dict=return_dict,
1478
+ )
1479
+
1480
+ sequence_output = outputs[0]
1481
+
1482
+ logits = self.qa_outputs(sequence_output)
1483
+ start_logits, end_logits = logits.split(1, dim=-1)
1484
+ start_logits = start_logits.squeeze(-1).contiguous()
1485
+ end_logits = end_logits.squeeze(-1).contiguous()
1486
+
1487
+ total_loss = None
1488
+ if start_positions is not None and end_positions is not None:
1489
+ # If we are on multi-GPU, split add a dimension
1490
+ if len(start_positions.size()) > 1:
1491
+ start_positions = start_positions.squeeze(-1)
1492
+ if len(end_positions.size()) > 1:
1493
+ end_positions = end_positions.squeeze(-1)
1494
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1495
+ ignored_index = start_logits.size(1)
1496
+ start_positions = start_positions.clamp(0, ignored_index)
1497
+ end_positions = end_positions.clamp(0, ignored_index)
1498
+
1499
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1500
+ start_loss = loss_fct(start_logits, start_positions)
1501
+ end_loss = loss_fct(end_logits, end_positions)
1502
+ total_loss = (start_loss + end_loss) / 2
1503
+
1504
+ if not return_dict:
1505
+ output = (start_logits, end_logits) + outputs[1:]
1506
+ return ((total_loss,) + output) if total_loss is not None else output
1507
+
1508
+ return QuestionAnsweringModelOutput(
1509
+ loss=total_loss,
1510
+ start_logits=start_logits,
1511
+ end_logits=end_logits,
1512
+ hidden_states=outputs.hidden_states,
1513
+ attentions=outputs.attentions,
1514
+ )
1515
+
1516
+
1517
+ @add_start_docstrings(
1518
+ """
1519
+ DeBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1520
+ softmax) e.g. for RocStories/SWAG tasks.
1521
+ """,
1522
+ DEBERTA_START_DOCSTRING,
1523
+ )
1524
+ class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel):
1525
+ def __init__(self, config):
1526
+ super().__init__(config)
1527
+
1528
+ num_labels = getattr(config, "num_labels", 2)
1529
+ self.num_labels = num_labels
1530
+
1531
+ self.deberta = DebertaV2Model(config)
1532
+ self.pooler = ContextPooler(config)
1533
+ output_dim = self.pooler.output_dim
1534
+
1535
+ self.classifier = nn.Linear(output_dim, 1)
1536
+ drop_out = getattr(config, "cls_dropout", None)
1537
+ drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
1538
+ self.dropout = StableDropout(drop_out)
1539
+
1540
+ self.init_weights()
1541
+
1542
+ def get_input_embeddings(self):
1543
+ return self.deberta.get_input_embeddings()
1544
+
1545
+ def set_input_embeddings(self, new_embeddings):
1546
+ self.deberta.set_input_embeddings(new_embeddings)
1547
+
1548
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1549
+ @add_code_sample_docstrings(
1550
+ processor_class=_TOKENIZER_FOR_DOC,
1551
+ checkpoint=_CHECKPOINT_FOR_DOC,
1552
+ output_type=MultipleChoiceModelOutput,
1553
+ config_class=_CONFIG_FOR_DOC,
1554
+ )
1555
+ def forward(
1556
+ self,
1557
+ input_ids=None,
1558
+ attention_mask=None,
1559
+ token_type_ids=None,
1560
+ position_ids=None,
1561
+ inputs_embeds=None,
1562
+ labels=None,
1563
+ output_attentions=None,
1564
+ output_hidden_states=None,
1565
+ return_dict=None,
1566
+ ):
1567
+ r"""
1568
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1569
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1570
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1571
+ `input_ids` above)
1572
+ """
1573
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1574
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1575
+
1576
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1577
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1578
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1579
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1580
+ flat_inputs_embeds = (
1581
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1582
+ if inputs_embeds is not None
1583
+ else None
1584
+ )
1585
+
1586
+ outputs = self.deberta(
1587
+ flat_input_ids,
1588
+ position_ids=flat_position_ids,
1589
+ token_type_ids=flat_token_type_ids,
1590
+ attention_mask=flat_attention_mask,
1591
+ inputs_embeds=flat_inputs_embeds,
1592
+ output_attentions=output_attentions,
1593
+ output_hidden_states=output_hidden_states,
1594
+ return_dict=return_dict,
1595
+ )
1596
+
1597
+ encoder_layer = outputs[0]
1598
+ pooled_output = self.pooler(encoder_layer)
1599
+ pooled_output = self.dropout(pooled_output)
1600
+ logits = self.classifier(pooled_output)
1601
+ reshaped_logits = logits.view(-1, num_choices)
1602
+
1603
+ loss = None
1604
+ if labels is not None:
1605
+ loss_fct = CrossEntropyLoss()
1606
+ loss = loss_fct(reshaped_logits, labels)
1607
+
1608
+ if not return_dict:
1609
+ output = (reshaped_logits,) + outputs[1:]
1610
+ return ((loss,) + output) if loss is not None else output
1611
+
1612
+ return MultipleChoiceModelOutput(
1613
+ loss=loss,
1614
+ logits=reshaped_logits,
1615
+ hidden_states=outputs.hidden_states,
1616
+ attentions=outputs.attentions,
1617
+ )