simpleParadox committed
Commit 3d7108b · verified · 1 Parent(s): 3c87b62

Upload 10 files

config.json ADDED
@@ -0,0 +1,45 @@
{
  "_name_or_path": "facebook/opt-125m",
  "_remove_final_layer_norm": false,
  "activation_dropout": 0.0,
  "activation_function": "relu",
  "architectures": [
    "FlamingoForCausalLM"
  ],
  "attention_dropout": 0.0,
  "auto_map": {
    "AutoConfig": "configuration_flamingo.FlamingoConfig",
    "AutoModelForCausalLM": "modeling_flamingo.FlamingoForCausalLM",
    "AutoModelForSequenceClassification": "modeling_flamingo.FlamingoForSequenceClassification"
  },
  "bos_token_id": 2,
  "cross_attn_every": 2,
  "do_layer_norm_before": true,
  "dropout": 0.1,
  "enable_bias": true,
  "eos_token_id": 2,
  "ffn_dim": 3072,
  "finetune_LM": true,
  "hidden_size": 768,
  "id_perceiver": false,
  "init_std": 0.02,
  "inp_dim": 768,
  "layer_norm_elementwise_affine": true,
  "layerdrop": 0.0,
  "manual_seed": 0,
  "max_position_embeddings": 2048,
  "media_token_id": 32768,
  "model_type": "flamingo",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "only_attend_immediate_media": true,
  "pad_token_id": 1,
  "perceiver_depth": 2,
  "perceiver_num_latents": 64,
  "prefix": "</s>",
  "torch_dtype": "float32",
  "transformers_version": "4.38.2",
  "use_cache": true,
  "vocab_size": 32778,
  "word_embed_proj_dim": 768
}
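The `auto_map` block routes `AutoConfig` and the `AutoModel*` entry points to the custom classes shipped in this repository, so loading requires `trust_remote_code=True`. A minimal loading sketch, assuming the files above are hosted under a Hub repo id written here as the placeholder "<repo-id>":

    # Illustrative loading snippet; "<repo-id>" stands in for this model's actual Hub id.
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("<repo-id>", trust_remote_code=True)           # resolves to FlamingoConfig
    model = AutoModelForCausalLM.from_pretrained("<repo-id>", trust_remote_code=True)  # resolves to FlamingoForCausalLM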
configuration_flamingo.py ADDED
@@ -0,0 +1,35 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Union

import transformers.models.opt.configuration_opt as configuration_opt


class FlamingoConfig(configuration_opt.OPTConfig, dict):
    model_type = "flamingo"
    def __init__(
        self,
        cross_attn_every=2,
        vocab_size=32778,
        media_token_id=32768,
        **kwargs,
    ):
        configuration_opt.OPTConfig.__init__(
            self, vocab_size=vocab_size, **kwargs)
        self.media_token_id = media_token_id
        self.cross_attn_every = cross_attn_every
        dict.__init__(self, **self.__dict__)
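For reference, a small sketch of how this config behaves when constructed directly; the defaults mirror config.json, and the flat import path assumes the snippet runs next to this repo's files:

    # Illustrative only; extra OPT fields such as hidden_size pass through **kwargs to OPTConfig.
    from configuration_flamingo import FlamingoConfig

    cfg = FlamingoConfig(cross_attn_every=2, vocab_size=32778, media_token_id=32768, hidden_size=768)
    print(cfg.model_type)         # "flamingo"
    print(cfg["media_token_id"])  # the dict mixin also exposes the fields set at init time as items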
generation_config.json ADDED
@@ -0,0 +1,7 @@
{
  "_from_model_config": true,
  "bos_token_id": 2,
  "eos_token_id": 2,
  "pad_token_id": 1,
  "transformers_version": "4.38.2"
}
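These are the defaults that `generate()` picks up for this checkpoint; they can also be inspected on their own (a sketch, with "<repo-id>" again standing in for the actual Hub id):

    from transformers import GenerationConfig

    gen_cfg = GenerationConfig.from_pretrained("<repo-id>")
    print(gen_cfg.bos_token_id, gen_cfg.eos_token_id, gen_cfg.pad_token_id)  # 2 2 1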
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a64aa594817a36b3c148659a4cf764fc08efb9227e7fa98e0f7da7787131213b
size 1022011440
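This is a Git LFS pointer, not the weights themselves; the roughly 1 GB safetensors file is fetched on clone or download. A sketch for verifying a downloaded copy against the pointer:

    # Assumes model.safetensors has already been downloaded locally (e.g. via `git lfs pull`).
    import hashlib, os

    path = "model.safetensors"
    digest = hashlib.sha256(open(path, "rb").read()).hexdigest()
    assert os.path.getsize(path) == 1022011440
    assert digest == "a64aa594817a36b3c148659a4cf764fc08efb9227e7fa98e0f7da7787131213b"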
modeling_flamingo.py ADDED
@@ -0,0 +1,661 @@
import random
import pdb
from einops import rearrange
from typing import List, Optional, Tuple, Union
import os
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss, MSELoss

from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
import transformers.models.opt.modeling_opt as modeling_opt
from transformers.models.opt.modeling_opt\
    import OPTDecoderLayer, OPTPreTrainedModel, OPTConfig
from transformers import ViTModel

try:
    from transformers.models.opt.modeling_opt import _prepare_4d_causal_attention_mask
except ImportError:
    _prepare_4d_causal_attention_mask = None

from utils import exists, freeze_all_layers_, unfreeze_all_layers_
from flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler
from .configuration_flamingo import FlamingoConfig


class OPTLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
        """`input_ids_shape` is expected to be [bsz x seqlen]."""
        attention_mask = attention_mask.long()

        # create positions depending on attention_mask
        positions = torch.cumsum(attention_mask, dim=1)
        positions = (positions.type_as(attention_mask) * attention_mask).long() - 1

        # cut positions if `past_key_values_length` is > 0
        positions = positions[:, past_key_values_length:]

        return super().forward(positions + self.offset)


class OPTDecoder(modeling_opt.OPTDecoder):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]
    Args:
        config: OPTConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: OPTConfig):
        OPTPreTrainedModel.__init__(self, config)
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size
        self.media_token_id = config.media_token_id

        self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
        self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
        # with checkpoints that have been fine-tuned before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(config.hidden_size)
        else:
            self.final_layer_norm = None

        dim_head = config.hidden_size // config.num_attention_heads
        if not config.id_perceiver:
            self.perceiver_resampler = PerceiverResampler(
                dim=config.hidden_size,
                depth=config.perceiver_depth,
                dim_head=dim_head,
                heads=config.num_attention_heads,
                num_latents=config.perceiver_num_latents,
                inp_dim=config.inp_dim,
            )
        else:
            if config.inp_dim is None:
                self.perceiver_resampler = nn.Identity()
            else:
                self.perceiver_resampler = nn.Linear(
                    config.inp_dim, config.hidden_size,
                    bias=False)
        self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gated_attn_layers = nn.ModuleList(
            [GatedCrossAttentionBlock(
                dim=config.hidden_size, dim_head=dim_head, heads=config.num_attention_heads,
                only_attend_immediate_media=config.only_attend_immediate_media)\
                if not (ind % config.cross_attn_every) else None \
                for ind in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

        # in flamingo mode, freeze everything but perceiver and gated cross attention
        if not config.finetune_LM:
            freeze_all_layers_(self)
            unfreeze_all_layers_(self.perceiver_resampler)
            [unfreeze_all_layers_(cross_attn) for cross_attn in self.gated_attn_layers if exists(cross_attn)]

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values=None,
        image_embeds=None
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.
                Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.
                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
        batch, device = input_ids.shape[0], input_ids.device

        flamingo_mode = exists(pixel_values) or exists(image_embeds)

        # derive the media token ids (as a boolean tensor), for calculating the masked cross attention
        if flamingo_mode:
            media_locations = input_ids == self.media_token_id

        assert not (exists(pixel_values) and exists(image_embeds))
        # encode images into embeddings
        # with the img_encoder passed in at init
        # it can also accept precomputed image embeddings

        if exists(pixel_values):
            assert exists(self.img_encoder), 'img_encoder must be passed in for automatic image encoding'
            if len(pixel_values.shape) == 4:
                pixel_values = torch.unsqueeze(pixel_values, 1)
            pixel_values = rearrange(pixel_values, 'b t ... -> (b t) ...')

            with torch.no_grad():
                if getattr(self.img_encoder, 'vision_model', None) is not None:
                    image_outputs = self.img_encoder.vision_model(
                        pixel_values=pixel_values,
                        output_hidden_states=True, return_dict=True)
                else:
                    image_outputs = self.img_encoder(
                        pixel_values=pixel_values,
                        output_hidden_states=True, return_dict=True)

                image_embeds = image_outputs['last_hidden_state']
            image_embeds = rearrange(image_embeds, '(b t) ... -> b t ...', b = batch)

        if exists(image_embeds):
            image_embeds = self.perceiver_resampler(image_embeds)

        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)

        if _prepare_4d_causal_attention_mask is None:
            attention_mask = self._prepare_decoder_attention_mask(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )
        else:
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )

        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)

        hidden_states = inputs_embeds + pos_embeds
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != (len(self.layers)):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {head_mask.size()[0]}."
                    )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            flamingo_cross_attn = self.gated_attn_layers[idx]
            if exists(flamingo_cross_attn) and exists(image_embeds):
                hidden_states = flamingo_cross_attn(
                    hidden_states,
                    image_embeds,
                    media_locations = media_locations
                )

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)

        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class OPTModel(modeling_opt.OPTModel):
    def __init__(self, config: OPTConfig):
        OPTPreTrainedModel.__init__(self, config)
        self.decoder = OPTDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()


class OPTForCausalLM(modeling_opt.OPTForCausalLM):
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config):
        OPTPreTrainedModel.__init__(self, config)
        self.model = OPTModel(config)

        # the lm_head weight is automatically tied to the embed tokens weight
        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()


def set_default_if_nonexist(config, key, value):
    if getattr(config, key, None) is None:
        setattr(config, key, value)
    return config


def setup_default_flamingo_configs(config):
    set_default_if_nonexist(config, 'perceiver_depth', 2)
    set_default_if_nonexist(config, 'perceiver_num_latents', 64)
    set_default_if_nonexist(config, 'cross_attn_every', 3)
    set_default_if_nonexist(config, 'only_attend_immediate_media', True)
    set_default_if_nonexist(config, 'media_token_id', 50265)
    set_default_if_nonexist(config, 'inp_dim', 768)
    set_default_if_nonexist(config, 'finetune_LM', True)
    set_default_if_nonexist(config, 'id_perceiver', False)
    return config


class FlamingoForCausalLM(modeling_opt.OPTForCausalLM):
    _keys_to_ignore_on_load_missing = [
        r"lm_head.weight",
    ]
    config_class = FlamingoConfig

    def __init__(self, config):
        OPTPreTrainedModel.__init__(self, config)
        config = setup_default_flamingo_configs(config)
        self.model = OPTModel(config)

        random.seed(self.config.manual_seed)
        np.random.seed(self.config.manual_seed)
        torch.manual_seed(self.config.manual_seed)
        torch.cuda.manual_seed_all(self.config.manual_seed)

        # Also setting deterministic behaviour for cudnn and mixed precision.
        torch.backends.cudnn.deterministic = True
        torch.use_deterministic_algorithms(True)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        # the lm_head weight is automatically tied to the embed tokens weight
        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
        self.model.decoder.img_encoder = None
        self.loss_fct = CrossEntropyLoss()
        dino_model = ViTModel.from_pretrained("facebook/dino-vitb16")
        self.setup_vis_encoder(dino_model)

    def setup_vis_encoder(self, img_encoder):
        self.model.decoder.img_encoder = img_encoder
        freeze_all_layers_(img_encoder)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        *args, **kwargs) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.
                Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.
                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        Returns:
        Example:
        ```python
        >>> from transformers import GPT2Tokenizer, OPTForCausalLM
        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
        >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")
        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            *args, **kwargs)

        logits = self.lm_head(outputs[0]).contiguous()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss = self.loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FlamingoForSequenceClassification(OPTPreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"score.weight",
    ]

    def __init__(self, config: OPTConfig):
        OPTPreTrainedModel.__init__(self, config)
        config = setup_default_flamingo_configs(config)
        self.num_labels = config.num_labels

        random.seed(self.config.manual_seed)
        np.random.seed(self.config.manual_seed)
        torch.manual_seed(self.config.manual_seed)
        torch.cuda.manual_seed_all(self.config.manual_seed)

        # Also setting deterministic behaviour for cudnn and mixed precision.
        torch.backends.cudnn.deterministic = True
        torch.use_deterministic_algorithms(True)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        self.model = OPTModel(config)

        # the lm_head weight is automatically tied to the embed tokens weight
        self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
        self.model.decoder.img_encoder = None
        self.loss_fct = CrossEntropyLoss()
        dino_model = ViTModel.from_pretrained("facebook/dino-vitb16")
        self.setup_vis_encoder(dino_model)

    def setup_vis_encoder(self, img_encoder):
        self.model.decoder.img_encoder = img_encoder
        freeze_all_layers_(img_encoder)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        *args, **kwargs) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            *args, **kwargs)

        hidden_states = outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1
                # logger.warning(
                #     f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                #     "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                # )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value
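A rough end-to-end sketch of how these classes are meant to be driven. Assumptions: "<repo-id>" is a placeholder for this model's Hub id, the companion `utils.py` and `flamingo_pytorch` dependencies are importable, and the media token is inserted by id because its string form is not shown in this commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("<repo-id>")
    model = AutoModelForCausalLM.from_pretrained("<repo-id>", trust_remote_code=True)
    model.eval()

    inputs = tok("a photo of", return_tensors="pt")
    # Prepend the media token id (32768 per config.json) so the gated cross-attention
    # layers know where the image sits in the sequence.
    input_ids = torch.cat([torch.tensor([[32768]]), inputs.input_ids], dim=1)
    pixel_values = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image for the frozen DINO ViT

    with torch.no_grad():
        out = model(input_ids=input_ids, pixel_values=pixel_values)
    print(out.logits.shape)  # (1, sequence_length, 32778)

In the decoder, `pixel_values` is routed through the frozen DINO encoder and the PerceiverResampler, and the resampled image embeddings are injected through a GatedCrossAttentionBlock every `cross_attn_every` layers before each frozen OPT decoder layer.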
special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
{
  "cls_token": {
    "content": "[CLS]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "mask_token": {
    "content": "[MASK]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "[PAD]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "sep_token": {
    "content": "[SEP]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "[UNK]",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,60 @@
{
  "added_tokens_decoder": {
    "0": {
      "content": "[PAD]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "100": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "101": {
      "content": "[CLS]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "102": {
      "content": "[SEP]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "103": {
      "content": "[MASK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "clean_up_tokenization_spaces": true,
  "cls_token": "[CLS]",
  "do_lower_case": true,
  "mask_token": "[MASK]",
  "model_input_names": [
    "input_ids",
    "attention_mask"
  ],
  "model_max_length": 512,
  "pad_token": "[PAD]",
  "processor_class": "GitProcessor",
  "sep_token": "[SEP]",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]"
}
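Note that the shipped tokenizer is a lower-cased BertTokenizer (WordPiece) rather than OPT's GPT-2 BPE tokenizer; the `vocab_size` of 32778 and `media_token_id` of 32768 in config.json appear to refer to this custom vocabulary. A small inspection sketch ("<repo-id>" is again a placeholder):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("<repo-id>")
    print(tok.cls_token_id, tok.sep_token_id, tok.pad_token_id)  # 101, 102, 0 per added_tokens_decoder
    print(tok("a photo of a cat").input_ids)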
utils.py ADDED
@@ -0,0 +1,37 @@
import torch

def exists(val):
    return val is not None

# for controlling freezing during training of flamingo

def set_module_requires_grad_(module, requires_grad):
    for param in module.parameters():
        param.requires_grad = requires_grad

def freeze_all_layers_(module):
    set_module_requires_grad_(module, False)

def unfreeze_all_layers_(module):
    set_module_requires_grad_(module, True)

def freeze_model_and_make_eval_(model):
    model.eval()
    freeze_all_layers_(model)

def _make_att_wd_mask(
    input_ids_shape: torch.Size,
    dtype: torch.dtype, device: torch.device,
    past_key_values_length: int = 0,
    att_wd_size: int = 0,
):
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(
        mask_cond > (mask_cond - att_wd_size).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
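A small sketch of how the windowed-attention helper behaves, based on my reading of the code (the function is not referenced elsewhere in this upload, so treat it as illustrative):

    # Window mask for a batch of 1, sequence length 4, window size 2.
    import torch
    from utils import _make_att_wd_mask  # assumes utils.py from this repo is on the path

    mask = _make_att_wd_mask(torch.Size([1, 4]), dtype=torch.float32, device=torch.device("cpu"), att_wd_size=2)
    print(mask[0, 0])
    # Row i keeps 0 for positions j > i - 2 and the large negative fill for older positions,
    # i.e. it bounds how far back attention can reach once combined with a causal mask.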
vocab.txt ADDED
The diff for this file is too large to render. See raw diff