haowei committed on
Commit
932302a
1 Parent(s): 1e8a310
Files changed (1)
  1. roberta_modeling.py +2195 -0
roberta_modeling.py ADDED
@@ -0,0 +1,2195 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch RoBERTa model, modified from the transformers implementation to accept **kwargs."""

import math

import torch
import torch.utils.checkpoint
from packaging import version
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import RobertaConfig
from transformers.activations import ACT2FN, gelu
from transformers.adapters.model_mixin import ModelWithHeadsAdaptersMixin
from transformers.adapters.models.bert import (
    BertEncoderAdaptersMixin,
    BertLayerAdaptersMixin,
    BertModelAdaptersMixin,
    BertModelHeadsMixin,
    BertOutputAdaptersMixin,
    BertSelfOutputAdaptersMixin,
)
from transformers.file_utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "roberta-base"
_CONFIG_FOR_DOC = "RobertaConfig"
_TOKENIZER_FOR_DOC = "RobertaTokenizer"

ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "roberta-base",
    "roberta-large",
    "roberta-large-mnli",
    "distilroberta-base",
    "roberta-base-openai-detector",
    "roberta-large-openai-detector",
    # See all RoBERTa models at https://huggingface.co/models?filter=roberta
]
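
# Illustrative usage sketch (editor's note, not part of the upstream transformers code): the point of the
# **kwargs additions throughout this file is that adapter-related arguments such as `adapter_names` can be
# threaded from a top-level forward call down through the encoder, layers, and output modules to the adapter
# mixins. Assuming adapter-transformers is installed and an adapter has already been added to the model under
# the (hypothetical) name "my_adapter", a call might look roughly like this:
#
#     from transformers import RobertaTokenizer
#
#     tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
#     model = RobertaModelWithHeads.from_pretrained("roberta-base")
#     inputs = tokenizer("Hello world", return_tensors="pt")
#     outputs = model(**inputs, adapter_names=["my_adapter"])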


class RobertaEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """

    # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        if version.parse(torch.__version__) > version.parse("1.6.0"):
            self.register_buffer(
                "token_type_ids",
                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
                persistent=False,
            )

        # End copy
        self.padding_idx = config.pad_token_id
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
        )

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if position_ids is None:
            if input_ids is not None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        # Setting the token_type_ids to the registered buffer in the constructor, where it is all zeros. This usually
        # occurs when token_type_ids is auto-generated; the registered buffer helps users trace the model without
        # passing token_type_ids, and solves issue #5664.
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded, so we just generate sequential
        position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)


# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta
class RobertaSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        **kwargs,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k, v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the RobertaModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs


# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
class RobertaSelfOutput(BertSelfOutputAdaptersMixin, nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self._init_adapter_modules()

    def forward(self, hidden_states, input_tensor, **kwargs):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.adapters_forward(hidden_states, input_tensor, **kwargs)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
class RobertaAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = RobertaSelfAttention(config)
        self.output = RobertaSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        **kwargs
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
            **kwargs,
        )
        attention_output = self.output(self_outputs[0], hidden_states, **kwargs)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


# Copied from transformers.models.bert.modeling_bert.BertIntermediate
class RobertaIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertOutput
class RobertaOutput(BertOutputAdaptersMixin, nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self._init_adapter_modules()

    def forward(self, hidden_states, input_tensor, **kwargs):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.adapters_forward(hidden_states, input_tensor, **kwargs)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta
class RobertaLayer(BertLayerAdaptersMixin, nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = RobertaAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
            self.crossattention = RobertaAttention(config)
        self.intermediate = RobertaIntermediate(config)
        self.output = RobertaOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        **kwargs
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
            **kwargs,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            assert hasattr(
                self, "crossattention"
            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output, **kwargs
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output, **kwargs):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output, **kwargs)
        return layer_output


# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta
class RobertaEncoder(BertEncoderAdaptersMixin, nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        **kwargs
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    **kwargs,
                )

            hidden_states = layer_outputs[0]
            attention_mask = self.adjust_attention_mask_for_parallel(hidden_states, attention_mask)

            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


# Copied from transformers.models.bert.modeling_bert.BertPooler
class RobertaPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class RobertaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = RobertaConfig
    base_model_prefix = "roberta"
    supports_gradient_checkpointing = True

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, RobertaEncoder):
            module.gradient_checkpointing = value

    def update_keys_to_ignore(self, config, del_keys_to_ignore):
        """Remove some keys from ignore list"""
        if not config.tie_word_embeddings:
            # must make a new list, or the class variable gets modified!
            self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
            self._keys_to_ignore_on_load_missing = [
                k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
            ]


ROBERTA_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
    general usage and behavior.

    Parameters:
        config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
            weights.
"""

ROBERTA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.RobertaTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
            1]``:

            - 0 corresponds to a `sentence A` token,
            - 1 corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
            config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
    ROBERTA_START_DOCSTRING,
)
class RobertaModel(BertModelAdaptersMixin, RobertaPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
    Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
    set to :obj:`True`. To be used in a Seq2Seq model, the model needs to be initialized with both the :obj:`is_decoder`
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.

    .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762

    """

    _keys_to_ignore_on_load_missing = [r"position_ids"]

    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = RobertaEmbeddings(config)
        self.encoder = RobertaEncoder(config)

        self.pooler = RobertaPooler(config) if add_pooling_layer else None

        self._init_adapter_modules()

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        self.pre_transformer_forward(**kwargs)

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        embedding_output = self.invertible_adapters_forward(embedding_output)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
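
# Editor's sketch (not in the original file): because RobertaModel.forward accepts **kwargs and hands them to
# self.pre_transformer_forward(...) and to the encoder, the bare model can also be called with adapter-specific
# arguments directly, e.g. (assuming an adapter has already been added under the hypothetical name "my_adapter"):
#
#     base_model = RobertaModel.from_pretrained("roberta-base")
#     encoded = base_model(input_ids, attention_mask=attention_mask, adapter_names=["my_adapter"])
#     sequence_output = encoded.last_hidden_state  # (batch_size, seq_len, hidden_size)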


@add_start_docstrings(
    """Roberta Model transformer with the option to add multiple flexible heads on top.""",
    ROBERTA_START_DOCSTRING,
)
class RobertaModelWithHeads(BertModelHeadsMixin, RobertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.roberta = RobertaModel(config)

        self._init_head_modules()

        self.init_weights()

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="roberta-base",
        output_type=ModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        adapter_names=None,
        head=None,
        **kwargs
    ):
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            adapter_names=adapter_names,
        )
        # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
        if not return_dict:
            head_inputs = (outputs[0],) + outputs[2:]
        else:
            head_inputs = outputs
        pooled_output = outputs[1]

        if head or self.active_head:
            head_outputs = self.forward_head(
                head_inputs,
                head_name=head,
                attention_mask=attention_mask,
                return_dict=return_dict,
                pooled_output=pooled_output,
                **kwargs,
            )
            return head_outputs
        else:
            # in case no head is used just return the output of the base model (including pooler output)
            return outputs


@add_start_docstrings(
    """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING
)
class RobertaForCausalLM(ModelWithHeadsAdaptersMixin, RobertaPreTrainedModel):
    _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)

        if not config.is_decoder:
            logger.warning("If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`")

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.lm_head = RobertaLMHead(config)

        # The LM head weights require special treatment only when they are tied with the word embeddings
        self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        adapter_names=None,
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).

        Returns:

        Example::

            >>> from transformers import RobertaTokenizer, RobertaForCausalLM, RobertaConfig
            >>> import torch

            >>> tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
            >>> config = RobertaConfig.from_pretrained("roberta-base")
            >>> config.is_decoder = True
            >>> model = RobertaForCausalLM.from_pretrained('roberta-base', config=config)

            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)

            >>> prediction_logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            adapter_names=adapter_names,
        )

        sequence_output = outputs[0]
        prediction_scores = self.lm_head(
            sequence_output,
            inv_lang_adapter=self.roberta.get_invertible_adapter(),
        )

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
        input_shape = input_ids.shape
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}

    def _reorder_cache(self, past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past
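
# Editor's sketch (not in the original file): prepare_inputs_for_generation/_reorder_cache above are the hooks
# that the generation utilities inherited via PreTrainedModel rely on, so a decoder-configured checkpoint can be
# used with .generate(), e.g.:
#
#     config = RobertaConfig.from_pretrained("roberta-base")
#     config.is_decoder = True
#     model = RobertaForCausalLM.from_pretrained("roberta-base", config=config)
#     generated = model.generate(input_ids, max_length=20, num_beams=4)  # beam search exercises _reorder_cache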


@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
class RobertaForMaskedLM(ModelWithHeadsAdaptersMixin, RobertaPreTrainedModel):
    _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.lm_head = RobertaLMHead(config)

        # The LM head weights require special treatment only when they are tied with the word embeddings
        self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        mask="<mask>",
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        adapter_names=None,
        **kwargs,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(
            sequence_output,
            inv_lang_adapter=self.roberta.get_invertible_adapter(),
        )

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1245
+
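+ # Usage sketch (added for illustration): running the RobertaForMaskedLM defined above as a
+ # fill-mask model. It assumes the class loads the standard `roberta-base` checkpoint through
+ # `from_pretrained`, as its Hugging Face counterpart does; the prompt is just an example.
+ def _demo_masked_lm():
+     from transformers import RobertaTokenizer
+     tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+     model = RobertaForMaskedLM.from_pretrained("roberta-base")
+     inputs = tokenizer("The capital of France is <mask>.", return_tensors="pt")
+     with torch.no_grad():
+         logits = model(**inputs).logits
+     mask_positions = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
+     return tokenizer.decode(logits[0, mask_positions].argmax(dim=-1))
+
+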
1246
+ class RobertaLMHead(nn.Module):
1247
+ """Roberta Head for masked language modeling."""
1248
+
1249
+ def __init__(self, config):
1250
+ super().__init__()
1251
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1252
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1253
+
1254
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
1255
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
1256
+ self.decoder.bias = self.bias
1257
+
1258
+ def forward(self, features, inv_lang_adapter=None, **kwargs):
1259
+ x = self.dense(features)
1260
+ x = gelu(x)
1261
+ x = self.layer_norm(x)
1262
+
1263
+ if inv_lang_adapter:
1264
+ x = inv_lang_adapter(x, rev=True)
1265
+
1266
+ # project back to size of vocabulary with bias
1267
+ x = self.decoder(x)
1268
+
1269
+ return x
1270
+
1271
+ def _tie_weights(self):
1272
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
1273
+ self.bias = self.decoder.bias
1274
+
1275
+
1276
+ @add_start_docstrings(
1277
+ """
1278
+ RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
1279
+ pooled output) e.g. for GLUE tasks.
1280
+ """,
1281
+ ROBERTA_START_DOCSTRING,
1282
+ )
1283
+ class RobertaForSequenceClassification(ModelWithHeadsAdaptersMixin, RobertaPreTrainedModel):
1284
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1285
+
1286
+ def __init__(self, config):
1287
+ super().__init__(config)
1288
+ self.num_labels = config.num_labels
1289
+ self.config = config
1290
+
1291
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
1292
+ self.classifier = RobertaClassificationHead(config)
1293
+
1294
+ self.init_weights()
1295
+
1296
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1297
+ @add_code_sample_docstrings(
1298
+ tokenizer_class=_TOKENIZER_FOR_DOC,
1299
+ checkpoint=_CHECKPOINT_FOR_DOC,
1300
+ output_type=SequenceClassifierOutput,
1301
+ config_class=_CONFIG_FOR_DOC,
1302
+ )
1303
+ def forward(
1304
+ self,
1305
+ input_ids=None,
1306
+ attention_mask=None,
1307
+ token_type_ids=None,
1308
+ position_ids=None,
1309
+ head_mask=None,
1310
+ inputs_embeds=None,
1311
+ labels=None,
1312
+ output_attentions=None,
1313
+ output_hidden_states=None,
1314
+ return_dict=None,
1315
+ adapter_names=None,
1316
+ **kwargs,
1317
+ ):
1318
+ r"""
1319
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1320
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1321
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss);
1322
+ if :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1323
+ """
1324
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1325
+
1326
+ outputs = self.roberta(
1327
+ input_ids,
1328
+ attention_mask=attention_mask,
1329
+ token_type_ids=token_type_ids,
1330
+ position_ids=position_ids,
1331
+ head_mask=head_mask,
1332
+ inputs_embeds=inputs_embeds,
1333
+ output_attentions=output_attentions,
1334
+ output_hidden_states=output_hidden_states,
1335
+ return_dict=return_dict,
1336
+ adapter_names=adapter_names,
1337
+ **kwargs,
1338
+ )
1339
+ sequence_output = outputs[0]
1340
+ logits = self.classifier(sequence_output)
1341
+
1342
+ loss = None
1343
+ if labels is not None:
1344
+ if self.config.problem_type is None:
1345
+ if self.num_labels == 1:
1346
+ self.config.problem_type = "regression"
1347
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1348
+ self.config.problem_type = "single_label_classification"
1349
+ else:
1350
+ self.config.problem_type = "multi_label_classification"
1351
+
1352
+ if self.config.problem_type == "regression":
1353
+ loss_fct = MSELoss()
1354
+ if self.num_labels == 1:
1355
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1356
+ else:
1357
+ loss = loss_fct(logits, labels)
1358
+ elif self.config.problem_type == "single_label_classification":
1359
+ loss_fct = CrossEntropyLoss()
1360
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1361
+ elif self.config.problem_type == "multi_label_classification":
1362
+ loss_fct = BCEWithLogitsLoss()
1363
+ loss = loss_fct(logits, labels)
1364
+
1365
+ if not return_dict:
1366
+ output = (logits,) + outputs[2:]
1367
+ return ((loss,) + output) if loss is not None else output
1368
+
1369
+ return SequenceClassifierOutput(
1370
+ loss=loss,
1371
+ logits=logits,
1372
+ hidden_states=outputs.hidden_states,
1373
+ attentions=outputs.attentions,
1374
+ )
1375
+
1376
+
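+ # Small sketch (added; not from the original file) of the three loss branches selected by
+ # `config.problem_type` in RobertaForSequenceClassification above. All tensors are dummies.
+ def _demo_problem_type_losses(num_labels=3, batch_size=8):
+     logits = torch.randn(batch_size, num_labels)
+     class_labels = torch.randint(0, num_labels, (batch_size,))            # -> single_label_classification
+     multi_labels = torch.randint(0, 2, (batch_size, num_labels)).float()  # -> multi_label_classification
+     reg_logits, reg_labels = torch.randn(batch_size, 1), torch.randn(batch_size)  # num_labels == 1 -> regression
+     single = CrossEntropyLoss()(logits.view(-1, num_labels), class_labels.view(-1))
+     multi = BCEWithLogitsLoss()(logits, multi_labels)
+     regression = MSELoss()(reg_logits.squeeze(), reg_labels.squeeze())
+     return single, multi, regression
+
+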
1377
+ @add_start_docstrings(
1378
+ """
1379
+ Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1380
+ softmax) e.g. for RocStories/SWAG tasks.
1381
+ """,
1382
+ ROBERTA_START_DOCSTRING,
1383
+ )
1384
+ class RobertaForMultipleChoice(ModelWithHeadsAdaptersMixin, RobertaPreTrainedModel):
1385
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1386
+
1387
+ def __init__(self, config):
1388
+ super().__init__(config)
1389
+
1390
+ self.roberta = RobertaModel(config)
1391
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1392
+ self.classifier = nn.Linear(config.hidden_size, 1)
1393
+
1394
+ self.init_weights()
1395
+
1396
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
1397
+ @add_code_sample_docstrings(
1398
+ tokenizer_class=_TOKENIZER_FOR_DOC,
1399
+ checkpoint=_CHECKPOINT_FOR_DOC,
1400
+ output_type=MultipleChoiceModelOutput,
1401
+ config_class=_CONFIG_FOR_DOC,
1402
+ )
1403
+ def forward(
1404
+ self,
1405
+ input_ids=None,
1406
+ token_type_ids=None,
1407
+ attention_mask=None,
1408
+ labels=None,
1409
+ position_ids=None,
1410
+ head_mask=None,
1411
+ inputs_embeds=None,
1412
+ output_attentions=None,
1413
+ output_hidden_states=None,
1414
+ return_dict=None,
1415
+ adapter_names=None,
1416
+ ):
1417
+ r"""
1418
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1419
+ Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
1420
+ num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
1421
+ :obj:`input_ids` above)
1422
+ """
1423
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1424
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1425
+
1426
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1427
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1428
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1429
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1430
+ flat_inputs_embeds = (
1431
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1432
+ if inputs_embeds is not None
1433
+ else None
1434
+ )
1435
+
1436
+ outputs = self.roberta(
1437
+ flat_input_ids,
1438
+ position_ids=flat_position_ids,
1439
+ token_type_ids=flat_token_type_ids,
1440
+ attention_mask=flat_attention_mask,
1441
+ head_mask=head_mask,
1442
+ inputs_embeds=flat_inputs_embeds,
1443
+ output_attentions=output_attentions,
1444
+ output_hidden_states=output_hidden_states,
1445
+ return_dict=return_dict,
1446
+ adapter_names=adapter_names,
1447
+ )
1448
+ pooled_output = outputs[1]
1449
+
1450
+ pooled_output = self.dropout(pooled_output)
1451
+ logits = self.classifier(pooled_output)
1452
+ reshaped_logits = logits.view(-1, num_choices)
1453
+
1454
+ loss = None
1455
+ if labels is not None:
1456
+ loss_fct = CrossEntropyLoss()
1457
+ loss = loss_fct(reshaped_logits, labels)
1458
+
1459
+ if not return_dict:
1460
+ output = (reshaped_logits,) + outputs[2:]
1461
+ return ((loss,) + output) if loss is not None else output
1462
+
1463
+ return MultipleChoiceModelOutput(
1464
+ loss=loss,
1465
+ logits=reshaped_logits,
1466
+ hidden_states=outputs.hidden_states,
1467
+ attentions=outputs.attentions,
1468
+ )
1469
+
1470
+
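+ # Illustration (added): the flatten/unflatten pattern used by RobertaForMultipleChoice above.
+ # Inputs of shape (batch, num_choices, seq_len) are flattened so the encoder sees ordinary
+ # sequences, and the per-sequence scores are reshaped back into one row of logits per example.
+ def _demo_multiple_choice_reshape(batch_size=2, num_choices=4, seq_len=16):
+     input_ids = torch.zeros(batch_size, num_choices, seq_len, dtype=torch.long)
+     flat_input_ids = input_ids.view(-1, input_ids.size(-1))       # (batch * num_choices, seq_len)
+     per_sequence_scores = torch.randn(flat_input_ids.size(0), 1)  # stand-in for the classifier output
+     reshaped_logits = per_sequence_scores.view(-1, num_choices)   # (batch, num_choices)
+     assert reshaped_logits.shape == (batch_size, num_choices)
+     return reshaped_logits
+
+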
1471
+ @add_start_docstrings(
1472
+ """
1473
+ Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1474
+ Named-Entity-Recognition (NER) tasks.
1475
+ """,
1476
+ ROBERTA_START_DOCSTRING,
1477
+ )
1478
+ class RobertaForTokenClassification(ModelWithHeadsAdaptersMixin, RobertaPreTrainedModel):
1479
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1480
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1481
+
1482
+ def __init__(self, config):
1483
+ super().__init__(config)
1484
+ self.num_labels = config.num_labels
1485
+
1486
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
1487
+ classifier_dropout = (
1488
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1489
+ )
1490
+ self.dropout = nn.Dropout(classifier_dropout)
1491
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1492
+
1493
+ self.init_weights()
1494
+
1495
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1496
+ @add_code_sample_docstrings(
1497
+ tokenizer_class=_TOKENIZER_FOR_DOC,
1498
+ checkpoint=_CHECKPOINT_FOR_DOC,
1499
+ output_type=TokenClassifierOutput,
1500
+ config_class=_CONFIG_FOR_DOC,
1501
+ )
1502
+ def forward(
1503
+ self,
1504
+ input_ids=None,
1505
+ attention_mask=None,
1506
+ token_type_ids=None,
1507
+ position_ids=None,
1508
+ head_mask=None,
1509
+ inputs_embeds=None,
1510
+ labels=None,
1511
+ output_attentions=None,
1512
+ output_hidden_states=None,
1513
+ return_dict=None,
1514
+ adapter_names=None,
1515
+ ):
1516
+ r"""
1517
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1518
+ Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1519
+ 1]``.
1520
+ """
1521
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1522
+
1523
+ outputs = self.roberta(
1524
+ input_ids,
1525
+ attention_mask=attention_mask,
1526
+ token_type_ids=token_type_ids,
1527
+ position_ids=position_ids,
1528
+ head_mask=head_mask,
1529
+ inputs_embeds=inputs_embeds,
1530
+ output_attentions=output_attentions,
1531
+ output_hidden_states=output_hidden_states,
1532
+ return_dict=return_dict,
1533
+ adapter_names=adapter_names,
1534
+ )
1535
+
1536
+ sequence_output = outputs[0]
1537
+
1538
+ sequence_output = self.dropout(sequence_output)
1539
+ logits = self.classifier(sequence_output)
1540
+
1541
+ loss = None
1542
+ if labels is not None:
1543
+ loss_fct = CrossEntropyLoss()
1544
+ # Only keep active parts of the loss
1545
+ if attention_mask is not None:
1546
+ active_loss = attention_mask.view(-1) == 1
1547
+ active_logits = logits.view(-1, self.num_labels)
1548
+ active_labels = torch.where(
1549
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
1550
+ )
1551
+ loss = loss_fct(active_logits, active_labels)
1552
+ else:
1553
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1554
+
1555
+ if not return_dict:
1556
+ output = (logits,) + outputs[2:]
1557
+ return ((loss,) + output) if loss is not None else output
1558
+
1559
+ return TokenClassifierOutput(
1560
+ loss=loss,
1561
+ logits=logits,
1562
+ hidden_states=outputs.hidden_states,
1563
+ attentions=outputs.attentions,
1564
+ )
1565
+
1566
+
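+ # Sketch (added for clarity): how the token-classification loss above skips padded positions by
+ # rewriting their labels to `ignore_index` before calling CrossEntropyLoss. Tensors are dummies.
+ def _demo_active_token_loss(num_labels=5):
+     logits = torch.randn(2, 6, num_labels)
+     labels = torch.randint(0, num_labels, (2, 6))
+     attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0]])
+     loss_fct = CrossEntropyLoss()
+     active_loss = attention_mask.view(-1) == 1
+     active_labels = torch.where(
+         active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+     )
+     return loss_fct(logits.view(-1, num_labels), active_labels)
+
+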
1567
+ class RobertaClassificationHead(nn.Module):
1568
+ """Head for sentence-level classification tasks."""
1569
+
1570
+ def __init__(self, config):
1571
+ super().__init__()
1572
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1573
+ classifier_dropout = (
1574
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1575
+ )
1576
+ self.dropout = nn.Dropout(classifier_dropout)
1577
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
1578
+
1579
+ def forward(self, features, **kwargs):
1580
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
1581
+ x = self.dropout(x)
1582
+ x = self.dense(x)
1583
+ x = torch.tanh(x)
1584
+ x = self.dropout(x)
1585
+ x = self.out_proj(x)
1586
+ return x
1587
+
1588
+
1589
+ @add_start_docstrings(
1590
+ """
1591
+ Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
1592
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1593
+ """,
1594
+ ROBERTA_START_DOCSTRING,
1595
+ )
1596
+ class RobertaForQuestionAnswering(ModelWithHeadsAdaptersMixin, RobertaPreTrainedModel):
1597
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1598
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1599
+
1600
+ def __init__(self, config):
1601
+ super().__init__(config)
1602
+ self.num_labels = config.num_labels
1603
+
1604
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
1605
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1606
+
1607
+ self.init_weights()
1608
+
1609
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1610
+ @add_code_sample_docstrings(
1611
+ tokenizer_class=_TOKENIZER_FOR_DOC,
1612
+ checkpoint=_CHECKPOINT_FOR_DOC,
1613
+ output_type=QuestionAnsweringModelOutput,
1614
+ config_class=_CONFIG_FOR_DOC,
1615
+ )
1616
+ def forward(
1617
+ self,
1618
+ input_ids=None,
1619
+ attention_mask=None,
1620
+ token_type_ids=None,
1621
+ position_ids=None,
1622
+ head_mask=None,
1623
+ inputs_embeds=None,
1624
+ start_positions=None,
1625
+ end_positions=None,
1626
+ output_attentions=None,
1627
+ output_hidden_states=None,
1628
+ return_dict=None,
1629
+ adapter_names=None,
1630
+ ):
1631
+ r"""
1632
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1633
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1634
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
1635
+ sequence are not taken into account for computing the loss.
1636
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1637
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1638
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
1639
+ sequence are not taken into account for computing the loss.
1640
+ """
1641
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1642
+
1643
+ outputs = self.roberta(
1644
+ input_ids,
1645
+ attention_mask=attention_mask,
1646
+ token_type_ids=token_type_ids,
1647
+ position_ids=position_ids,
1648
+ head_mask=head_mask,
1649
+ inputs_embeds=inputs_embeds,
1650
+ output_attentions=output_attentions,
1651
+ output_hidden_states=output_hidden_states,
1652
+ return_dict=return_dict,
1653
+ )
1654
+
1655
+ sequence_output = outputs[0]
1656
+
1657
+ logits = self.qa_outputs(sequence_output)
1658
+ start_logits, end_logits = logits.split(1, dim=-1)
1659
+ start_logits = start_logits.squeeze(-1).contiguous()
1660
+ end_logits = end_logits.squeeze(-1).contiguous()
1661
+
1662
+ total_loss = None
1663
+ if start_positions is not None and end_positions is not None:
1664
+ # If we are on multi-GPU, DataParallel may have added an extra dimension to the positions; squeeze it
1665
+ if len(start_positions.size()) > 1:
1666
+ start_positions = start_positions.squeeze(-1)
1667
+ if len(end_positions.size()) > 1:
1668
+ end_positions = end_positions.squeeze(-1)
1669
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
1670
+ ignored_index = start_logits.size(1)
1671
+ start_positions = start_positions.clamp(0, ignored_index)
1672
+ end_positions = end_positions.clamp(0, ignored_index)
1673
+
1674
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1675
+ start_loss = loss_fct(start_logits, start_positions)
1676
+ end_loss = loss_fct(end_logits, end_positions)
1677
+ total_loss = (start_loss + end_loss) / 2
1678
+
1679
+ if not return_dict:
1680
+ output = (start_logits, end_logits) + outputs[2:]
1681
+ return ((total_loss,) + output) if total_loss is not None else output
1682
+
1683
+ return QuestionAnsweringModelOutput(
1684
+ loss=total_loss,
1685
+ start_logits=start_logits,
1686
+ end_logits=end_logits,
1687
+ hidden_states=outputs.hidden_states,
1688
+ attentions=outputs.attentions,
1689
+ )
1690
+
1691
+
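+ # Illustration (added; not part of the original file): a naive argmax decode that turns the start/end
+ # logits produced by RobertaForQuestionAnswering above into an answer string. Production code would
+ # score all valid (start, end) pairs instead.
+ def _demo_extract_span(start_logits, end_logits, input_ids, tokenizer):
+     start_index = int(start_logits.argmax(dim=-1)[0])
+     end_index = int(end_logits.argmax(dim=-1)[0])
+     if end_index < start_index:  # guard against an end predicted before the start
+         end_index = start_index
+     return tokenizer.decode(input_ids[0, start_index : end_index + 1])
+
+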
1692
+ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
1693
+ """
1694
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
1695
+ are ignored. This is modified from fairseq's `utils.make_positions`.
1696
+
1697
+ Args:
1698
+ input_ids: torch.Tensor
+ padding_idx: int
+ past_key_values_length: int
1699
+
1700
+ Returns: torch.Tensor
1701
+ """
1702
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
1703
+ mask = input_ids.ne(padding_idx).int()
1704
+ incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
1705
+ return incremental_indices.long() + padding_idx
1706
+
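+ # Quick check (added for illustration) of create_position_ids_from_input_ids above: padding tokens
+ # keep `padding_idx`, while real tokens count up from `padding_idx + 1`, matching RoBERTa's convention.
+ def _demo_position_ids(padding_idx=1):
+     input_ids = torch.tensor([[0, 31414, 232, 2, padding_idx, padding_idx]])
+     position_ids = create_position_ids_from_input_ids(input_ids, padding_idx)
+     # -> tensor([[2, 3, 4, 5, 1, 1]])
+     return position_ids
+
+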
1707
+ from dataclasses import dataclass
1708
+ from typing import Union, Callable
1709
+
1710
+ import torch.nn as nn
1711
+
1712
+
1713
+ @dataclass
1714
+ class AdapterMaskConfig:
1715
+ hidden_size: int
1716
+ adapter_size: int
1717
+ ffn_adapter_size: int
1718
+ attn_adapter_size: int
1719
+ adapter_act: Union[str, Callable]
1720
+ adapter_initializer_range: float
1721
+ ntasks: int
1722
+ smax: int
1723
+ mode: str = "sequential" # "sequential" / "parallel"
1724
+
1725
+ def __post_init__(self):
1726
+ if self.mode not in ("sequential", "parallel"):
1727
+ raise NotImplementedError(f"The current mode {self.mode} is not supported!")
1728
+
1729
+
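+ # Example configuration (added; the values are illustrative, not prescribed defaults) for the
+ # adapter/mask plug-ins used throughout the rest of this file. `__post_init__` rejects any mode
+ # other than "sequential" or "parallel".
+ _EXAMPLE_ADAPTER_CONFIG = AdapterMaskConfig(
+     hidden_size=768,
+     adapter_size=128,
+     ffn_adapter_size=512,
+     attn_adapter_size=200,
+     adapter_act="relu",
+     adapter_initializer_range=1e-2,
+     ntasks=5,
+     smax=400,
+     mode="sequential",
+ )
+
+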
1730
+ def freeze_all_parameters(model: nn.Module) -> nn.Module:
1731
+ for param in model.parameters():
1732
+ param.requires_grad = False
1733
+ return model
1734
+
1735
+ """Roberta model with CPT CL-plugins."""
1736
+ import math
1737
+ from copy import deepcopy
1738
+
1739
+ import torch
1740
+ import torch.nn as nn
1741
+ from transformers import BertModel
1742
+ from transformers.models.bert.modeling_bert import BertSelfOutput
1743
+ from transformers.models.roberta.modeling_roberta import RobertaSelfAttention
1744
+
1745
+
1746
+ class RobertaAdapter(nn.Module):
1747
+ def __init__(self, config: AdapterMaskConfig):
1748
+ super().__init__()
1749
+ self.fc1 = torch.nn.Linear(config.hidden_size, config.adapter_size)
1750
+ self.fc2 = torch.nn.Linear(config.adapter_size, config.hidden_size)
1751
+ self.activation = torch.nn.ReLU()
1752
+
1753
+ def forward(self, x):
1754
+ h = self.activation(self.fc1(x))
1755
+ h = self.activation(self.fc2(h))
1756
+ return x + h
1757
+ # return h
1758
+
1759
+
1760
+ class RobertaAdapterMask(RobertaAdapter):
1761
+ def __init__(self, config: AdapterMaskConfig):
1762
+ super().__init__(config)
1763
+ self.efc1 = torch.nn.Embedding(config.ntasks, config.adapter_size)
1764
+ self.efc2 = torch.nn.Embedding(config.ntasks, config.hidden_size)
1765
+ self.gate = torch.nn.Sigmoid()
1766
+ self.config = config
1767
+ self.smax = config.smax
1768
+
1769
+ def forward(self, x, t, s, smax=400, add_residual=True, residual=None):
1770
+ residual = x if residual is None else residual
1771
+ gfc1, gfc2 = self.mask(t=t, s=s)
1772
+ h = self.get_feature(gfc1, gfc2, x)
1773
+ if add_residual:
1774
+ output = residual + h
1775
+ else:
1776
+ output = h
1777
+
1778
+ return output
1779
+
1780
+ def get_feature(self, gfc1, gfc2, x):
1781
+ h = self.activation(self.fc1(x))
1782
+ h = h * gfc1.expand_as(h)
1783
+
1784
+ h = self.activation(self.fc2(h))
1785
+ h = h * gfc2.expand_as(h)
1786
+
1787
+ return h
1788
+
1789
+ def mask(self, t: torch.LongTensor, s: int = None):
1790
+
1791
+ efc1 = self.efc1(t)
1792
+ efc2 = self.efc2(t)
1793
+
1794
+ gfc1 = self.gate(s * efc1)
1795
+ gfc2 = self.gate(s * efc2)
1796
+
1797
+ if s == self.smax:  # binarize the gates at s == smax (used at evaluation time)
1798
+ gfc1 = (gfc1 > 0.5).float()
1799
+ gfc2 = (gfc2 > 0.5).float()
1800
+
1801
+ return [gfc1, gfc2]
1802
+
1803
+
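+ # Sketch (added for illustration) of the hard-attention-style gating in RobertaAdapterMask above:
+ # per-task embeddings are pushed through a sigmoid scaled by `s`, and the gates are binarized once
+ # `s` reaches `smax`. The configuration values below are arbitrary.
+ def _demo_adapter_mask_gates():
+     cfg = AdapterMaskConfig(
+         hidden_size=768, adapter_size=64, ffn_adapter_size=64, attn_adapter_size=64,
+         adapter_act="relu", adapter_initializer_range=1e-2, ntasks=3, smax=400, mode="sequential",
+     )
+     adapter = RobertaAdapterMask(cfg)
+     t = torch.LongTensor([0])                            # task id
+     soft_gfc1, soft_gfc2 = adapter.mask(t, s=1)          # soft gates in (0, 1) during training
+     hard_gfc1, hard_gfc2 = adapter.mask(t, s=cfg.smax)   # {0, 1} gates at evaluation time
+     x = torch.randn(2, 10, cfg.hidden_size)
+     out = adapter(x, t=t, s=1)                           # adapted hidden states, same shape as x
+     return soft_gfc1, hard_gfc1, out.shape
+
+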
1804
+ class RobertaAdaptedSelfOutput(nn.Module):
1805
+ def __init__(self,
1806
+ self_output: BertSelfOutput,
1807
+ config: AdapterMaskConfig):
1808
+ super(RobertaAdaptedSelfOutput, self).__init__()
1809
+ self.self_output = self_output
1810
+ self.adapter_mask = RobertaAdapterMask(config)
1811
+ self.mode = config.mode
1812
+
1813
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, t, s, **kwargs):
1814
+ if self.mode == "sequential":
1815
+ hidden_states = self.self_output.dense(hidden_states)
1816
+ hidden_states = self.self_output.dropout(hidden_states)
1817
+ hidden_states = self.adapter_mask(hidden_states, t, s)
1818
+ elif self.mode == "parallel":
1819
+ adapter_change = self.adapter_mask(input_tensor, t, s)
1820
+ hidden_states = self.self_output.dense(hidden_states)
1821
+ hidden_states = self.self_output.dropout(hidden_states)
1822
+ hidden_states = hidden_states + adapter_change
1823
+ hidden_states = self.self_output.LayerNorm(hidden_states + input_tensor)
1824
+ return hidden_states
1825
+
1826
+
1827
+ class RobertaAdaptedSelfAttention(nn.Module):
1828
+ """For parallel adapter."""
1829
+
1830
+ def __init__(self,
1831
+ self_attn: RobertaSelfAttention,
1832
+ config: AdapterMaskConfig):
1833
+ super(RobertaAdaptedSelfAttention, self).__init__()
1834
+ if config.mode != "parallel":
1835
+ raise ValueError("This class is tailored for parallel adapter!")
1836
+ self.self_attn = self_attn
1837
+ self.adapter_mask = RobertaAdapterMask(config)
1838
+
1839
+ def forward(
1840
+ self,
1841
+ hidden_states,
1842
+ attention_mask=None,
1843
+ head_mask=None,
1844
+ encoder_hidden_states=None,
1845
+ encoder_attention_mask=None,
1846
+ past_key_value=None,
1847
+ output_attentions=False,
1848
+ t=None,
1849
+ s=None,
1850
+ **kwargs,
1851
+ ):
1852
+ mixed_query_layer = self.self_attn.query(hidden_states)
1853
+
1854
+ # If this is instantiated as a cross-attention module, the keys
1855
+ # and values come from an encoder; the attention mask needs to be
1856
+ # such that the encoder's padding tokens are not attended to.
1857
+ is_cross_attention = encoder_hidden_states is not None
1858
+
1859
+ if is_cross_attention and past_key_value is not None:
1860
+ # reuse k,v, cross_attentions
1861
+ key_layer = past_key_value[0]
1862
+ value_layer = past_key_value[1]
1863
+ attention_mask = encoder_attention_mask
1864
+ elif is_cross_attention:
1865
+ key_layer = self.self_attn.transpose_for_scores(self.self_attn.key(encoder_hidden_states))
1866
+ value_layer = self.self_attn.transpose_for_scores(self.self_attn.value(encoder_hidden_states))
1867
+ attention_mask = encoder_attention_mask
1868
+ elif past_key_value is not None:
1869
+ key_layer = self.self_attn.transpose_for_scores(self.self_attn.key(hidden_states))
1870
+ value_layer = self.self_attn.transpose_for_scores(self.self_attn.value(hidden_states))
1871
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
1872
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
1873
+ else:
1874
+ key_layer = self.self_attn.transpose_for_scores(self.self_attn.key(hidden_states))
1875
+ value_layer = self.self_attn.transpose_for_scores(self.self_attn.value(hidden_states))
1876
+
1877
+ query_layer = self.self_attn.transpose_for_scores(mixed_query_layer)
1878
+
1879
+ if self.self_attn.is_decoder:
1880
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
1881
+ # Further calls to cross_attention layer can then reuse all cross-attention
1882
+ # key/value_states (first "if" case)
1883
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
1884
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
1885
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
1886
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
1887
+ past_key_value = (key_layer, value_layer)
1888
+
1889
+ cross_attn_output = self.adapter_mask(hidden_states, t=t, s=s, add_residual=False) # For parallel adapter.
1890
+
1891
+ # Take the dot product between "query" and "key" to get the raw attention scores.
1892
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
1893
+
1894
+ if self.self_attn.position_embedding_type == "relative_key" or self.self_attn.position_embedding_type == "relative_key_query":
1895
+ seq_length = hidden_states.size()[1]
1896
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
1897
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
1898
+ distance = position_ids_l - position_ids_r
1899
+ positional_embedding = self.self_attn.distance_embedding(
1900
+ distance + self.self_attn.max_position_embeddings - 1)
1901
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
1902
+
1903
+ if self.self_attn.position_embedding_type == "relative_key":
1904
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
1905
+ attention_scores = attention_scores + relative_position_scores
1906
+ elif self.self_attn.position_embedding_type == "relative_key_query":
1907
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
1908
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
1909
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
1910
+
1911
+ attention_scores = attention_scores / math.sqrt(self.self_attn.attention_head_size)
1912
+ if attention_mask is not None:
1913
+ # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
1914
+ attention_scores = attention_scores + attention_mask
1915
+
1916
+ # Normalize the attention scores to probabilities.
1917
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
1918
+
1919
+ # This is actually dropping out entire tokens to attend to, which might
1920
+ # seem a bit unusual, but is taken from the original Transformer paper.
1921
+ attention_probs = self.self_attn.dropout(attention_probs)
1922
+
1923
+ # Mask heads if we want to
1924
+ if head_mask is not None:
1925
+ attention_probs = attention_probs * head_mask
1926
+
1927
+ context_layer = torch.matmul(attention_probs, value_layer)
1928
+
1929
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
1930
+ new_context_layer_shape = context_layer.size()[:-2] + (self.self_attn.all_head_size,)
1931
+ context_layer = context_layer.view(*new_context_layer_shape)
1932
+
1933
+ context_layer = context_layer + cross_attn_output # For parallel adapter.
1934
+
1935
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
1936
+
1937
+ if self.self_attn.is_decoder:
1938
+ outputs = outputs + (past_key_value,)
1939
+ return outputs
1940
+
1941
+
1942
+ def adapt_roberta_self_output(config: AdapterMaskConfig):
1943
+ return lambda self_output: RobertaAdaptedSelfOutput(self_output, config=config)
1944
+
1945
+
1946
+ def adapt_roberta_self_attn(config: AdapterMaskConfig):
1947
+ return lambda self_attn: RobertaAdaptedSelfAttention(self_attn, config=config)
1948
+
1949
+
1950
+ def add_roberta_adapters(roberta_model: BertModel, config: AdapterMaskConfig) -> BertModel:
1951
+ attn_config, ffn_config = deepcopy(config), deepcopy(config)
1952
+ attn_config.adapter_size = attn_config.attn_adapter_size
1953
+ ffn_config.adapter_size = ffn_config.ffn_adapter_size
1954
+
1955
+ if config.mode == "sequential":
1956
+ for layer in roberta_model.encoder.layer:
1957
+ layer.attention.output = adapt_roberta_self_output(
1958
+ attn_config)(layer.attention.output)
1959
+ layer.output = adapt_roberta_self_output(ffn_config)(layer.output)
1960
+ elif config.mode == "parallel":
1961
+ for layer in roberta_model.encoder.layer:
1962
+ layer.attention.self = adapt_roberta_self_attn(attn_config)(layer.attention.self)
1963
+ layer.output = adapt_roberta_self_output(ffn_config)(layer.output)
1964
+ return roberta_model
1965
+
1966
+
1967
+ def unfreeze_roberta_adapters(roberta_model: nn.Module) -> nn.Module:
1968
+ # Unfreeze the trainable parts: layer norms and adapters
1969
+ for name, sub_module in roberta_model.named_modules():
1970
+ if isinstance(sub_module, (RobertaAdapter, nn.LayerNorm)):
1971
+ for param_name, param in sub_module.named_parameters():
1972
+ param.requires_grad = True
1973
+ return roberta_model
1974
+
1975
+
1976
+ def load_roberta_adapter_model(
1977
+ roberta_model: nn.Module,
1978
+ checkpoint: str = None,
1979
+ mode: str = "sequential",
1980
+ attn_adapter_size: int = 200,
1981
+ ffn_adapter_size: int = 512,
1982
+ ntasks: int = 5):
1983
+ adapter_config = AdapterMaskConfig(
1984
+ hidden_size=768,
1985
+ adapter_size=-1, # Deprecated.
1986
+ adapter_act='relu',
1987
+ adapter_initializer_range=1e-2,
1988
+ ntasks=ntasks,
1989
+ smax=400,
1990
+ mode=mode,
1991
+ attn_adapter_size=attn_adapter_size,
1992
+ ffn_adapter_size=ffn_adapter_size,
1993
+ )
1994
+ roberta_model.roberta = add_roberta_adapters(
1995
+ roberta_model.roberta, adapter_config)
1996
+
1997
+ # freeze the bert model, unfreeze adapter
1998
+ roberta_model.roberta = freeze_all_parameters(roberta_model.roberta)
1999
+ roberta_model.roberta = unfreeze_roberta_adapters(roberta_model.roberta)
2000
+
2001
+ if checkpoint is not None and checkpoint != 'None':
2002
+ print("loading checkpoint...")
2003
+ model_dict = roberta_model.state_dict()
2004
+ pretrained_dict = torch.load(checkpoint, map_location='cpu')
2005
+ new_dict = {k: v for k, v in pretrained_dict.items()
2006
+ if k in model_dict.keys()}
2007
+ model_dict.update(new_dict)
2008
+ print('Loaded {} of {} checkpoint parameters.'.format(len(new_dict), len(pretrained_dict)))
2009
+ roberta_model.load_state_dict(model_dict)
2010
+ print("loaded finished!")
2011
+ else:
2012
+ print('No checkpoint provided; skipping checkpoint loading.')
2013
+ return roberta_model
2014
+
2015
+
2016
+ def save_roberta_adapter_model(roberta_model: nn.Module, save_path: str, accelerator=None):
2017
+ model_dict = {k: v for k, v in roberta_model.state_dict().items()
2018
+ if 'adapter' in k}
2019
+ if accelerator is not None:
2020
+ accelerator.save(model_dict, save_path)
2021
+ else:
2022
+ torch.save(model_dict, save_path)
2023
+
2024
+
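+ # Usage sketch (added; checkpoint names and paths are placeholders): plugging the adapters into one
+ # of the RoBERTa heads defined earlier and saving only the adapter weights. It assumes the classes
+ # above load standard checkpoints through `from_pretrained` like their Hugging Face counterparts;
+ # a previously saved file can later be restored on a fresh model via `checkpoint=...`.
+ def _demo_adapter_round_trip():
+     model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
+     model = load_roberta_adapter_model(model, checkpoint=None, mode="sequential", ntasks=5)
+     trainable = [n for n, p in model.named_parameters() if p.requires_grad]  # adapters, layer norms, head
+     save_roberta_adapter_model(model, "roberta_adapters.pt")  # keeps only keys containing 'adapter'
+     return trainable
+
+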
2025
+ # Note: the standalone `forward`, `forward_cls`, `mask`, and `get_view_for` helpers below operate on a
+ # model instance passed in explicitly (as `self` / `roberta_model`); they are not attached to a class here.
+ def forward(self, t, input_ids, segment_ids, input_mask, s=None):
2026
+ output_dict = {}
2027
+
2028
+ sequence_output, pooled_output = \
2029
+ self.bert(input_ids=input_ids, token_type_ids=segment_ids,
2030
+ attention_mask=input_mask, t=t, s=s)
2031
+ masks = self.mask(t, s)
2032
+ pooled_output = self.dropout(pooled_output)
2033
+
2034
+ y = self.last(sequence_output)
2035
+ output_dict['y'] = y
2036
+ output_dict['masks'] = masks
2037
+ return output_dict
2038
+
2039
+
2040
+ def forward_cls(self, t, input_ids, segment_ids, input_mask, start_mixup=None, s=None, l=None, idx=None, mix_type=None):
2041
+ output_dict = {}
2042
+
2043
+ sequence_output, pooled_output = \
2044
+ self.bert(input_ids=input_ids, token_type_ids=segment_ids,
2045
+ attention_mask=input_mask, t=t, s=s)
2046
+ masks = self.mask(t, s)
2047
+ pooled_output = self.dropout(pooled_output)
2048
+
2049
+ y = self.last_cls(pooled_output)
2050
+ output_dict['y'] = y
2051
+ output_dict['masks'] = masks
2052
+ return output_dict
2053
+
2054
+
2055
+ def mask(roberta_model, t, s, adapter_type="sequential"):
2056
+ masks = {}
2057
+ for layer_id in range(len(roberta_model.roberta.encoder.layer)):
2058
+ if adapter_type == "sequential":
2059
+ fc1_key = 'roberta.encoder.layer.' + \
2060
+ str(layer_id) + '.attention.output.adapter_mask.fc1' # gfc1
2061
+ fc2_key = 'roberta.encoder.layer.' + \
2062
+ str(layer_id) + '.attention.output.adapter_mask.fc2' # gfc2
2063
+
2064
+ masks[fc1_key], masks[fc2_key] = roberta_model.roberta.encoder.layer[
2065
+ layer_id].attention.output.adapter_mask.mask(
2066
+ t, s)
2067
+ else:
2068
+ fc1_key = 'roberta.encoder.layer.' + \
2069
+ str(layer_id) + '.attention.self.adapter_mask.fc1' # gfc1
2070
+ fc2_key = 'roberta.encoder.layer.' + \
2071
+ str(layer_id) + '.attention.self.adapter_mask.fc2' # gfc2
2072
+
2073
+ masks[fc1_key], masks[fc2_key] = roberta_model.roberta.encoder.layer[
2074
+ layer_id].attention.self.adapter_mask.mask(
2075
+ t, s)
2076
+
2077
+ fc1_key = 'roberta.encoder.layer.' + \
2078
+ str(layer_id) + '.output.adapter_mask.fc1' # gfc1
2079
+ fc2_key = 'roberta.encoder.layer.' + \
2080
+ str(layer_id) + '.output.adapter_mask.fc2' # gfc2
2081
+
2082
+ masks[fc1_key], masks[fc2_key] = roberta_model.roberta.encoder.layer[layer_id].output.adapter_mask.mask(
2083
+ t, s)
2084
+
2085
+ return masks
2086
+
2087
+
2088
+ def get_view_for(model, n, p, masks):
2089
+ for layer_id in range(12):
2090
+ if n == 'roberta.encoder.layer.' + str(layer_id) + '.attention.output.adapter_mask.fc1.weight':
2091
+ return masks[n.replace('.weight', '')].data.view(-1, 1).expand_as(p)
2092
+ elif n == 'roberta.encoder.layer.' + str(layer_id) + '.attention.output.adapter_mask.fc1.bias':
2093
+ return masks[n.replace('.bias', '')].data.view(-1)
2094
+ elif n == 'roberta.encoder.layer.' + str(layer_id) + '.attention.output.adapter_mask.fc2.weight':
2095
+ post = masks[n.replace('.weight', '')].data.view(-1, 1).expand_as(p)
2096
+ pre = masks[n.replace('.weight', '').replace('fc2', 'fc1')].data.view(1, -1).expand_as(p)
2097
+ return torch.min(post, pre)
2098
+ elif n == 'roberta.encoder.layer.' + str(layer_id) + '.attention.output.adapter_mask.fc2.bias':
2099
+ return masks[n.replace('.bias', '')].data.view(-1)
2100
+ elif n == 'roberta.encoder.layer.' + str(layer_id) + '.output.adapter_mask.fc1.weight':
2101
2102
+ return masks[n.replace('.weight', '')].data.view(-1, 1).expand_as(p)
2103
+ elif n == 'roberta.encoder.layer.' + str(layer_id) + '.output.adapter_mask.fc1.bias':
2104
+ return masks[n.replace('.bias', '')].data.view(-1)
2105
+ elif n == 'roberta.encoder.layer.' + str(layer_id) + '.output.adapter_mask.fc2.weight':
2106
+ post = masks[n.replace('.weight', '')].data.view(-1, 1).expand_as(p)
2107
+ pre = masks[n.replace('.weight', '').replace('fc2', 'fc1')].data.view(1, -1).expand_as(p)
2108
+ return torch.min(post, pre)
2109
+ elif n == 'roberta.encoder.layer.' + str(layer_id) + '.output.adapter_mask.fc2.bias':
2110
+ return masks[n.replace('.bias', '')].data.view(-1)
2111
+ elif n == 'roberta.encoder.layer.' + str(layer_id) + '.attention.self.adapter_mask.fc1.weight':
2112
+ return masks[n.replace('.weight', '')].data.view(-1, 1).expand_as(p)
2113
+ elif n == 'roberta.encoder.layer.' + str(layer_id) + '.attention.self.adapter_mask.fc1.bias':
2114
+ return masks[n.replace('.bias', '')].data.view(-1)
2115
+ elif n == 'roberta.encoder.layer.' + str(layer_id) + '.attention.self.adapter_mask.fc2.weight':
2116
+ post = masks[n.replace('.weight', '')].data.view(-1, 1).expand_as(p)
2117
+ pre = masks[n.replace('.weight', '').replace('fc2', 'fc1')].data.view(1, -1).expand_as(p)
2118
+ return torch.min(post, pre)
2119
+ elif n == 'roberta.encoder.layer.' + str(layer_id) + '.attention.self.adapter_mask.fc2.bias':
2120
+ return masks[n.replace('.bias', '')].data.view(-1)
2121
+ return None
2122
+
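+ # Sketch (added) of how `mask(...)` and `get_view_for(...)` above are typically combined in a
+ # HAT-style training loop: after backward(), each adapter gradient is rescaled by the view derived
+ # from the task gates. The exact scaling policy is an assumption, not code from the original file.
+ def _demo_mask_gradients(roberta_model, t, s):
+     masks = mask(roberta_model, t, s, adapter_type="sequential")
+     for n, p in roberta_model.named_parameters():
+         if p.grad is None:
+             continue
+         view = get_view_for(roberta_model, n, p, masks)
+         if view is not None:
+             p.grad.data *= view
+
+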
2123
+ import os
2124
+ import pdb
2125
+ from pathlib import Path
2126
+
2127
+ import torch
2128
+ import torch.nn as nn
2129
+ import sys
2130
+
2131
+ class RobertaMaskBasedModel:
2132
+
2133
+ def forward(
2134
+ self,
2135
+ input_ids=None,
2136
+ past_key_values=None,
2137
+ attention_mask=None,
2138
+ token_type_ids=None,
2139
+ position_ids=None,
2140
+ head_mask=None,
2141
+ inputs_embeds=None,
2142
+ encoder_hidden_states=None,
2143
+ encoder_attention_mask=None,
2144
+ labels=None,
2145
+ use_cache=None,
2146
+ output_attentions=None,
2147
+ output_hidden_states=None,
2148
+ return_dict=None,
2149
+ for_end_task=False,
2150
+ use_prompt=True,
2151
+ **kwargs
2152
+ ):
2153
+ # Drop most of the args for now
2154
+ outputs = super().forward(
2155
+ attention_mask=attention_mask,
2156
+ input_ids=input_ids,
2157
+ labels=labels,
2158
+ return_dict=return_dict,
2159
+ **kwargs
2160
+ )
2161
+ return outputs
2162
+
2163
+
2164
+ class RobertaMaskForMaskedLM(RobertaMaskBasedModel, RobertaForMaskedLM):
2165
+ def __init__(self, config):
2166
+ super().__init__(config)
2167
+ adapter_config = AdapterMaskConfig(
2168
+ hidden_size=768,
2169
+ adapter_size=-1, # Deprecated.
2170
+ adapter_act='relu',
2171
+ adapter_initializer_range=1e-2,
2172
+ ntasks=config.adapter_task,
2173
+ smax=config.smax,
2174
+ mode=config.adapter_mode,
2175
+ attn_adapter_size=config.attn_adapter_size,
2176
+ ffn_adapter_size=config.ffn_adapter_size,
2177
+ )
2178
+ self.roberta = add_roberta_adapters(self.roberta, adapter_config)
2179
+
2180
+ class RobertaMaskForSequenceClassification(RobertaMaskBasedModel, RobertaForSequenceClassification):
2181
+ def __init__(self, config):
2182
+ super().__init__(config)
2183
+ adapter_config = AdapterMaskConfig(
2184
+ hidden_size=768,
2185
+ adapter_size=-1, # Deprecated.
2186
+ adapter_act='relu',
2187
+ adapter_initializer_range=1e-2,
2188
+ ntasks=config.adapter_task,
2189
+ smax=config.smax,
2190
+ mode=config.adapter_mode,
2191
+ attn_adapter_size=config.attn_adapter_size,
2192
+ ffn_adapter_size=config.ffn_adapter_size,
2193
+ )
2194
+ self.roberta = add_roberta_adapters(self.roberta, adapter_config)
2195
+
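+ # Construction sketch (added; the attribute values are examples, not prescribed defaults): the
+ # RobertaMask* classes above read their adapter settings from extra attributes on the config.
+ def _demo_build_masked_classifier():
+     from transformers import RobertaConfig
+     config = RobertaConfig.from_pretrained("roberta-base", num_labels=2)
+     config.adapter_task = 5            # number of continual-learning tasks
+     config.smax = 400
+     config.adapter_mode = "sequential"
+     config.attn_adapter_size = 200
+     config.ffn_adapter_size = 512
+     return RobertaMaskForSequenceClassification(config)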