florentgbelidji committed
Commit d188349
1 Parent(s): 72ddfac

Create med.py

Files changed (1):
  med.py: +953, -0
med.py ADDED
@@ -0,0 +1,953 @@
'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
 * Based on huggingface code base
 * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
'''

import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import Tensor, device, dtype, nn
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F

from transformers.activations import ACT2FN
from transformers.file_utils import (
    ModelOutput,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig


logger = logging.get_logger(__name__)


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertSelfAttention(nn.Module):
    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        outputs = outputs + (past_key_value,)
        return outputs


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        if self.config.add_cross_attention:
            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        mode=None,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        outputs = self_attention_outputs[1:-1]
        present_key_value = self_attention_outputs[-1]

        if mode == 'multimodal':
            assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"

            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        mode='multimodal',
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warn(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    mode=mode,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    mode=mode,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    To behave as a decoder, the model needs to be called with the :obj:`is_decoder`
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()


    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)


    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.
        Returns:
            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    causal_mask = torch.cat(
                        [
                            torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
                            causal_mask,
                        ],
                        axis=-1,
                    )

                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
        mode='multimodal',
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = inputs_embeds.device
        elif encoder_embeds is not None:
            input_shape = encoder_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = encoder_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
                                                                                 device, is_decoder)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if encoder_embeds is None:
            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                past_key_values_length=past_key_values_length,
            )
        else:
            embedding_output = encoder_embeds

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            mode=mode,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )



class BertLMHeadModel(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction='mean',
        mode='multimodal',
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        Returns:
        Example::
            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
            >>> import torch
            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
            >>> config = BertConfig.from_pretrained("bert-base-cased")
            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> prediction_logits = outputs.logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
            mode=mode,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            if reduction == 'none':
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
        input_shape = input_ids.shape
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        reordered_past = ()
        for layer_past in past:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past
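
For context, the sketch below is not part of med.py; it shows one way the classes added in this file can be wired together as an image-grounded text encoder and a causal LM decoder. The import path `med`, the config numbers, and the random "image" features are illustrative assumptions, not settings taken from any particular checkpoint.

# Minimal usage sketch (assumes this file is saved and importable as med.py).
import torch
from transformers import BertTokenizer
from transformers.models.bert.configuration_bert import BertConfig

from med import BertModel, BertLMHeadModel

med_config = BertConfig(
    vocab_size=30522,            # assumed to match the bert-base-uncased tokenizer
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    add_cross_attention=True,    # makes each BertLayer build a `crossattention` block
)
med_config.encoder_width = 768   # width of the visual features fed to the cross-attention

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
text = tokenizer(["a photo of a dog"], return_tensors="pt")

# Dummy visual features standing in for a vision encoder's patch embeddings.
image_embeds = torch.randn(1, 197, med_config.encoder_width)
image_atts = torch.ones(image_embeds.shape[:-1], dtype=torch.long)

# Image-grounded text encoder: self-attention over text, cross-attention over image features.
text_encoder = BertModel(med_config, add_pooling_layer=False)
enc_out = text_encoder(
    text.input_ids,
    attention_mask=text.attention_mask,
    encoder_hidden_states=image_embeds,
    encoder_attention_mask=image_atts,
    mode="multimodal",
)

# Image-grounded text decoder: same backbone with a causal mask (is_decoder=True) and an LM head.
text_decoder = BertLMHeadModel(med_config)
dec_out = text_decoder(
    text.input_ids,
    attention_mask=text.attention_mask,
    encoder_hidden_states=image_embeds,
    encoder_attention_mask=image_atts,
    labels=text.input_ids,
)

print(enc_out.last_hidden_state.shape)  # torch.Size([1, seq_len, 768])
print(dec_out.loss)                     # next-token prediction loss (label smoothing 0.1)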