doevent committed on
Commit 21522d5
1 Parent(s): 6ca9770

Upload models/nlvr_encoder.py

Files changed (1)
  1. models/nlvr_encoder.py +843 -0
models/nlvr_encoder.py ADDED
@@ -0,0 +1,843 @@
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import Tensor, device, dtype, nn
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F

from transformers.activations import ACT2FN
from transformers.file_utils import (
    ModelOutput,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig


logger = logging.get_logger(__name__)

class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

class BertSelfAttention(nn.Module):
    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        outputs = outputs + (past_key_value,)
        return outputs

class BertSelfOutput(nn.Module):
    def __init__(self, config, twin=False, merge=False):
        super().__init__()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        if twin:
            # Two output projections, one per cross-attention branch (two encoder streams).
            self.dense0 = nn.Linear(config.hidden_size, config.hidden_size)
            self.dense1 = nn.Linear(config.hidden_size, config.hidden_size)
        else:
            self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if merge:
            self.act = ACT2FN[config.hidden_act]
            self.merge_layer = nn.Linear(config.hidden_size * 2, config.hidden_size)
            self.merge = True
        else:
            self.merge = False

    def forward(self, hidden_states, input_tensor):
        if type(hidden_states) == list:
            hidden_states0 = self.dense0(hidden_states[0])
            hidden_states1 = self.dense1(hidden_states[1])
            if self.merge:
                # Fuse the two branches by concatenation followed by a linear projection.
                # hidden_states = self.merge_layer(self.act(torch.cat([hidden_states0, hidden_states1], dim=-1)))
                hidden_states = self.merge_layer(torch.cat([hidden_states0, hidden_states1], dim=-1))
            else:
                # Without merging, simply average the two branches.
                hidden_states = (hidden_states0 + hidden_states1) / 2
        else:
            hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

class BertAttention(nn.Module):
    def __init__(self, config, is_cross_attention=False, layer_num=-1):
        super().__init__()
        if is_cross_attention:
            # Twin cross-attention: one attention module per encoder stream.
            self.self0 = BertSelfAttention(config, is_cross_attention)
            self.self1 = BertSelfAttention(config, is_cross_attention)
        else:
            self.self = BertSelfAttention(config, is_cross_attention)
        # The two cross-attention outputs are merged (concat + linear) from layer 6 onwards, averaged before that.
        self.output = BertSelfOutput(config, twin=is_cross_attention, merge=(is_cross_attention and layer_num >= 6))
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        if type(encoder_hidden_states) == list:
            # Cross-attend to each of the two encoder streams separately, then merge in self.output.
            self_outputs0 = self.self0(
                hidden_states,
                attention_mask,
                head_mask,
                encoder_hidden_states[0],
                encoder_attention_mask[0],
                past_key_value,
                output_attentions,
            )
            self_outputs1 = self.self1(
                hidden_states,
                attention_mask,
                head_mask,
                encoder_hidden_states[1],
                encoder_attention_mask[1],
                past_key_value,
                output_attentions,
            )
            attention_output = self.output([self_outputs0[0], self_outputs1[0]], hidden_states)

            outputs = (attention_output,) + self_outputs0[1:]  # add attentions if we output them
        else:
            self_outputs = self.self(
                hidden_states,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
            )
            attention_output = self.output(self_outputs[0], hidden_states)
            outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs

class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

class BertLayer(nn.Module):
    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num
        if self.config.add_cross_attention:
            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention, layer_num=layer_num)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        mode=None,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        outputs = self_attention_outputs[1:-1]
        present_key_value = self_attention_outputs[-1]

        if mode == 'multimodal':
            # Cross-attention over the encoder features is only applied in 'multimodal' mode.
            assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        mode='multimodal',
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warn(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    mode=mode,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    mode=mode,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores

class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. To behave as a decoder the model needs to be
    initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an
    :obj:`encoder_hidden_states` is then expected as an input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)

        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model.
            device: (:obj:`torch.device`):
                The device of the input to the model.

        Returns:
            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if is_decoder:
                batch_size, seq_length = input_shape

                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
                # causal and attention masks must have same type with pytorch version < 1.3
                causal_mask = causal_mask.to(attention_mask.dtype)

                if causal_mask.shape[1] < attention_mask.shape[1]:
                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
                    causal_mask = torch.cat(
                        [
                            torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
                            causal_mask,
                        ],
                        axis=-1,
                    )

                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

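    # Example: for a 2D padding mask [1, 1, 0] in the encoder (non-decoder) case, the mask is
    # broadcast to shape [batch_size, 1, 1, 3] and converted to additive scores [0.0, 0.0, -10000.0],
    # so the padded position is effectively removed from the softmax.
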
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        is_decoder=False,
        mode='multimodal',
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
            decoding (see :obj:`past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = inputs_embeds.device
        elif encoder_embeds is not None:
            input_shape = encoder_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = encoder_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
                                                                                 device, is_decoder)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if encoder_embeds is None:
            embedding_output = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                inputs_embeds=inputs_embeds,
                past_key_values_length=past_key_values_length,
            )
        else:
            embedding_output = encoder_embeds

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            mode=mode,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
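
For reference, a minimal sketch of how this encoder might be driven. It is illustrative only: the feature width (768), the number of visual tokens (50), the random stand-ins for the two image feature streams, and the import path (which assumes the repository root is on the Python path) are all assumptions, not part of the uploaded file.

# Illustrative usage sketch (assumed shapes and random features, not part of the upload).
import torch
from transformers import BertConfig
from models.nlvr_encoder import BertModel

config = BertConfig()
config.encoder_width = 768          # width of the visual features fed to cross-attention (assumed)
config.add_cross_attention = True   # enable the twin cross-attention layers

model = BertModel(config, add_pooling_layer=False)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch, text length)
attention_mask = torch.ones(2, 16)

# Two streams of encoder features, e.g. one per image in an NLVR-style pair (random here).
image_embeds = [torch.randn(2, 50, config.encoder_width) for _ in range(2)]
image_atts = [torch.ones(2, 50, dtype=torch.long) for _ in range(2)]

with torch.no_grad():
    outputs = model(
        input_ids,
        attention_mask=attention_mask,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_atts,
        return_dict=True,
        mode="multimodal",
    )
print(outputs.last_hidden_state.shape)  # torch.Size([2, 16, 768])

Passing any mode other than "multimodal" skips the cross-attention blocks, so the model runs as a text-only self-attention encoder.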