salma-remyx committed
Commit aea84eb · verified · 1 Parent(s): ae0cb85

Upload mllava/modeling_llava.py with huggingface_hub

Files changed (1)
  1. mllava/modeling_llava.py +770 -0
mllava/modeling_llava.py ADDED
@@ -0,0 +1,770 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Llava model."""
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+
23
+ # from ... import PreTrainedModel
24
+ # from ...activations import ACT2FN
25
+ # from ...cache_utils import Cache
26
+ # from ...modeling_outputs import ModelOutput
27
+ # from ...utils import (
28
+ # add_start_docstrings,
29
+ # add_start_docstrings_to_model_forward,
30
+ # logging,
31
+ # replace_return_docstrings,
32
+ # )
33
+ # from ..auto import AutoModel, AutoModelForCausalLM
34
+
35
+ from .configuration_llava import LlavaConfig
36
+
37
+ from transformers import PreTrainedModel
38
+ from transformers.activations import ACT2FN
39
+ from transformers.cache_utils import Cache
40
+ from transformers.modeling_outputs import ModelOutput
41
+ from transformers.utils import (
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM
48
+ from .configuration_llava import LlavaConfig
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+ _CONFIG_FOR_DOC = "LlavaConfig"
54
+
55
+ LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [
56
+ "llava-hf/llava-1.5-7b-hf",
57
+ "llava-hf/llava-1.5-13b-hf",
58
+ "llava-hf/bakLlava-v1-hf",
59
+ # See all Llava models at https://huggingface.co/models?filter=llava
60
+ ]
61
+
62
+
63
+ @dataclass
64
+ # Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava
65
+ class LlavaCausalLMOutputWithPast(ModelOutput):
66
+ """
67
+ Base class for Llava causal language model (or autoregressive) outputs.
68
+
69
+ Args:
70
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
71
+ Language modeling loss (for next-token prediction).
72
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
73
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
74
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
75
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
76
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
77
+
78
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
79
+ `past_key_values` input) to speed up sequential decoding.
80
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
81
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
82
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
83
+
84
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
85
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
86
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
87
+ sequence_length)`.
88
+
89
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
90
+ heads.
91
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
92
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size, num_images,
93
+ sequence_length, hidden_size)`.
94
+
95
+ Image hidden-states of the model produced by the vision encoder.
96
+ """
97
+
98
+ loss: Optional[torch.FloatTensor] = None
99
+ logits: torch.FloatTensor = None
100
+ past_key_values: Optional[List[torch.FloatTensor]] = None
101
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
102
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
103
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
104
+
105
+
106
+ class LlavaMultiModalProjector(nn.Module):
107
+ def __init__(self, config: LlavaConfig):
108
+ super().__init__()
109
+
110
+ self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
111
+ self.act = ACT2FN[config.projector_hidden_act]
112
+ self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
113
+
114
+ def forward(self, image_features):
115
+ hidden_states = self.linear_1(image_features)
116
+ hidden_states = self.act(hidden_states)
117
+ hidden_states = self.linear_2(hidden_states)
118
+ return hidden_states
119
+
120
+
121
+ LLAVA_START_DOCSTRING = r"""
122
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
123
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
124
+ etc.)
125
+
126
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
127
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
128
+ and behavior.
129
+
130
+ Parameters:
131
+ config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
132
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
133
+ load the weights associated with the model, only the configuration. Check out the
134
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
135
+ """
136
+
137
+
138
+ @add_start_docstrings(
139
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
140
+ LLAVA_START_DOCSTRING,
141
+ )
142
+ class LlavaPreTrainedModel(PreTrainedModel):
143
+ config_class = LlavaConfig
144
+ base_model_prefix = "model"
145
+ supports_gradient_checkpointing = True
146
+ _no_split_modules = ["LlavaVisionAttention"]
147
+ _skip_keys_device_placement = "past_key_values"
148
+ _supports_flash_attn_2 = True
149
+
150
+ def _init_weights(self, module):
151
+ # important: this ported version of Llava isn't meant for training from scratch - only
152
+ # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
153
+ # https://github.com/haotian-liu/LLaVA/tree/main/llava should serve for that purpose
154
+ std = (
155
+ self.config.initializer_range
156
+ if hasattr(self.config, "initializer_range")
157
+ else self.config.text_config.initializer_range
158
+ )
159
+
160
+ if hasattr(module, "class_embedding"):
161
+ module.class_embedding.data.normal_(mean=0.0, std=std)
162
+
163
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
164
+ module.weight.data.normal_(mean=0.0, std=std)
165
+ if module.bias is not None:
166
+ module.bias.data.zero_()
167
+ elif isinstance(module, nn.Embedding):
168
+ module.weight.data.normal_(mean=0.0, std=std)
169
+ if module.padding_idx is not None:
170
+ module.weight.data[module.padding_idx].zero_()
171
+
172
+ @property
173
+ def _supports_sdpa(self):
174
+ """
175
+ Retrieve language_model's attribute to check whether the model supports
176
+ SDPA or not.
177
+ """
178
+ return self.language_model._supports_sdpa
179
+
180
+
181
+ LLAVA_INPUTS_DOCSTRING = r"""
182
+ Args:
183
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
184
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
185
+ it.
186
+
187
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
188
+ [`PreTrainedTokenizer.__call__`] for details.
189
+
190
+ [What are input IDs?](../glossary#input-ids)
191
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
192
+ The tensors corresponding to the input images. Pixel values can be obtained using
193
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
194
+ [`CLIPImageProcessor`] for processing images).
195
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
196
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
197
+
198
+ - 1 for tokens that are **not masked**,
199
+ - 0 for tokens that are **masked**.
200
+
201
+ [What are attention masks?](../glossary#attention-mask)
202
+
203
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
204
+ [`PreTrainedTokenizer.__call__`] for details.
205
+
206
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
207
+ `past_key_values`).
208
+
209
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
210
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
211
+ information on the default strategy.
212
+
213
+ - 1 indicates the head is **not masked**,
214
+ - 0 indicates the head is **masked**.
215
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
216
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
217
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
218
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
219
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
220
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
221
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
222
+
223
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
224
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
225
+
226
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
227
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
228
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
229
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
230
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
231
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
232
+ model's internal embedding lookup matrix.
233
+ use_cache (`bool`, *optional*):
234
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
235
+ `past_key_values`).
236
+ output_attentions (`bool`, *optional*):
237
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
238
+ tensors for more detail.
239
+ output_hidden_states (`bool`, *optional*):
240
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
241
+ more detail.
242
+ return_dict (`bool`, *optional*):
243
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
244
+ """
245
+
246
+
247
+ @add_start_docstrings(
248
+ """The LLAVA model which consists of a vision backbone and a language model.""",
249
+ LLAVA_START_DOCSTRING,
250
+ )
251
+ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
252
+ def __init__(self, config: LlavaConfig, vision_tower=None, language_model=None):
253
+ super().__init__(config)
254
+ self.vision_tower = AutoModel.from_config(config.vision_config) if vision_tower is None else vision_tower
255
+
256
+ self.multi_modal_projector = LlavaMultiModalProjector(config)
257
+ self.vocab_size = config.vocab_size
258
+ self.language_model = AutoModelForCausalLM.from_config(
259
+ config.text_config, attn_implementation=config._attn_implementation
260
+ ) if language_model is None else language_model
261
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
262
+ self.post_init()
263
+
264
+ def get_input_embeddings(self):
265
+ return self.language_model.get_input_embeddings()
266
+
267
+ def set_input_embeddings(self, value):
268
+ self.language_model.set_input_embeddings(value)
269
+
270
+ def get_output_embeddings(self):
271
+ return self.language_model.get_output_embeddings()
272
+
273
+ def set_output_embeddings(self, new_embeddings):
274
+ self.language_model.set_output_embeddings(new_embeddings)
275
+
276
+ def set_decoder(self, decoder):
277
+ self.language_model.set_decoder(decoder)
278
+
279
+ def get_decoder(self):
280
+ return self.language_model.get_decoder()
281
+
282
+ def tie_weights(self):
283
+ return self.language_model.tie_weights()
284
+
285
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
286
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
287
+ # update vocab size
288
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
289
+ self.config.vocab_size = model_embeds.num_embeddings
290
+ self.vocab_size = model_embeds.num_embeddings
291
+ return model_embeds
292
+
293
+ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
294
+ num_images, num_image_patches, embed_dim = image_features.shape
295
+ batch_size, sequence_length = input_ids.shape
296
+ left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
297
+ # 1. Create a mask to know where special image tokens are
298
+ special_image_token_mask = input_ids == self.config.image_token_index
299
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
300
+ # Compute the maximum embed dimension
301
+ max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
302
+ batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
303
+
304
+ # 2. Compute the positions where text should be written
305
+ # Calculate new positions for text tokens in merged image-text sequence.
306
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
307
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
308
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
309
+ new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
310
+ nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
311
+ if left_padding:
312
+ new_token_positions += nb_image_pad[:, None] # offset for left padding
313
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
314
+
315
+ # 3. Create the full embedding, already padded to the maximum position
316
+ final_embedding = torch.zeros(
317
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
318
+ )
319
+ final_attention_mask = torch.zeros(
320
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
321
+ )
322
+ if labels is not None:
323
+ final_labels = torch.full(
324
+ (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
325
+ )
326
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
327
+ # set the corresponding tensors into their correct target device.
328
+ target_device = inputs_embeds.device
329
+ batch_indices, non_image_indices, text_to_overwrite = (
330
+ batch_indices.to(target_device),
331
+ non_image_indices.to(target_device),
332
+ text_to_overwrite.to(target_device),
333
+ )
334
+ attention_mask = attention_mask.to(target_device)
335
+
336
+ # 4. Fill the embeddings based on the mask. If we have ["hey", "<image>", "how", "are"]
337
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
338
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
339
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
340
+ if labels is not None:
341
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
342
+
343
+ # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
344
+ image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
345
+ image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
346
+
347
+ if image_to_overwrite.sum() != image_features.shape[:-1].numel():
348
+ raise ValueError(
349
+ f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
350
+ f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
351
+ )
352
+
353
+ final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
354
+ final_attention_mask |= image_to_overwrite
355
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
356
+
357
+ if labels is None:
358
+ final_labels = None
359
+
360
+ return final_embedding, final_attention_mask, final_labels, position_ids
361
+
362
+ @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
363
+ @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
364
+ def forward(
365
+ self,
366
+ input_ids: torch.LongTensor = None,
367
+ pixel_values: torch.FloatTensor = None,
368
+ attention_mask: Optional[torch.Tensor] = None,
369
+ position_ids: Optional[torch.LongTensor] = None,
370
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
371
+ inputs_embeds: Optional[torch.FloatTensor] = None,
372
+ vision_feature_layer: Optional[int] = None,
373
+ vision_feature_select_strategy: Optional[str] = None,
374
+ labels: Optional[torch.LongTensor] = None,
375
+ use_cache: Optional[bool] = None,
376
+ output_attentions: Optional[bool] = None,
377
+ output_hidden_states: Optional[bool] = None,
378
+ return_dict: Optional[bool] = None,
379
+ ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
380
+ r"""
381
+ Args:
382
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
383
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
384
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
385
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
386
+
387
+ Returns:
388
+
389
+ Example:
390
+
391
+ ```python
392
+ >>> from PIL import Image
393
+ >>> import requests
394
+ >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
395
+
396
+ >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
397
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
398
+
399
+ >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
400
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
401
+ >>> image = Image.open(requests.get(url, stream=True).raw)
402
+
403
+ >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
404
+
405
+ >>> # Generate
406
+ >>> generate_ids = model.generate(**inputs, max_length=30)
407
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
408
+ "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
409
+ ```"""
410
+
411
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
412
+ output_hidden_states = (
413
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
414
+ )
415
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
416
+ vision_feature_layer = (
417
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
418
+ )
419
+ vision_feature_select_strategy = (
420
+ vision_feature_select_strategy
421
+ if vision_feature_select_strategy is not None
422
+ else self.config.vision_feature_select_strategy
423
+ )
424
+
425
+ if inputs_embeds is None:
426
+ # 1. Extract the input embeddings
427
+ inputs_embeds = self.get_input_embeddings()(input_ids)
428
+
429
+ # 2. Merge text and images
430
+ if pixel_values is not None and input_ids.shape[1] != 1:
431
+ if isinstance(pixel_values, list):
432
+ pixel_values = torch.cat([x for x in pixel_values if x is not None], dim=0)
433
+ # for siglip, need to transform the pixel_values to the right data type
434
+ if pixel_values.dtype != self.vision_tower.dtype:
435
+ pixel_values = pixel_values.type(self.vision_tower.dtype)
436
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
437
+ # this is not memory efficient at all: (output_hidden_states=True) will save all the hidden states.
438
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
439
+
440
+ if vision_feature_select_strategy == "default":
441
+ selected_image_feature = selected_image_feature[:, 1:]
442
+ elif vision_feature_select_strategy == "full":
443
+ selected_image_feature = selected_image_feature
444
+ else:
445
+ raise ValueError(
446
+ f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
447
+ )
448
+
449
+ image_features = self.multi_modal_projector(selected_image_feature)
450
+ inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
451
+ image_features, inputs_embeds, input_ids, attention_mask, labels
452
+ )
453
+ if labels is None:
454
+ labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
455
+ else:
456
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
457
+ # generation with cache
458
+ if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
459
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
460
+ # that are set to 0
461
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
462
+
463
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
464
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
465
+
466
+ # Get the target length
467
+ target_seqlen = first_layer_past_key_value.shape[-1] + 1
468
+
469
+ extended_attention_mask = torch.ones(
470
+ (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
471
+ dtype=attention_mask.dtype,
472
+ device=attention_mask.device,
473
+ )
474
+
475
+ # Filter out only the tokens that can be un-attended, this can happen
476
+ # if one uses Llava + Fused modules where the cache on the
477
+ # first iteration is already big enough, or if one passes custom cache
478
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
479
+ new_batch_index = batch_index[valid_indices]
480
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
481
+
482
+ # Zero-out the places where we don't need to attend
483
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
484
+
485
+ attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
486
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
487
+
488
+ outputs = self.language_model(
489
+ attention_mask=attention_mask,
490
+ position_ids=position_ids,
491
+ past_key_values=past_key_values,
492
+ inputs_embeds=inputs_embeds,
493
+ use_cache=use_cache,
494
+ output_attentions=output_attentions,
495
+ output_hidden_states=output_hidden_states,
496
+ return_dict=return_dict,
497
+ )
498
+
499
+ logits = outputs[0]
500
+
501
+ loss = None
502
+ if labels is not None:
503
+ # Shift so that tokens < n predict n
504
+ if attention_mask is not None:
505
+ shift_attention_mask = attention_mask[..., 1:]
506
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
507
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
508
+ else:
509
+ shift_logits = logits[..., :-1, :].contiguous()
510
+ shift_labels = labels[..., 1:].contiguous()
511
+ # Flatten the tokens
512
+ loss_fct = nn.CrossEntropyLoss()
513
+ loss = loss_fct(
514
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
515
+ )
516
+
517
+ if not return_dict:
518
+ output = (logits,) + outputs[1:]
519
+ return (loss,) + output if loss is not None else output
520
+
521
+ return LlavaCausalLMOutputWithPast(
522
+ loss=loss,
523
+ logits=logits,
524
+ past_key_values=outputs.past_key_values,
525
+ hidden_states=outputs.hidden_states,
526
+ attentions=outputs.attentions,
527
+ )
528
+
529
+ def prepare_inputs_for_generation(
530
+ self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
531
+ ):
532
+ if past_key_values is not None:
533
+ if isinstance(past_key_values, Cache):
534
+ cache_length = past_key_values.get_seq_length()
535
+ past_length = past_key_values.seen_tokens
536
+ else:
537
+ cache_length = past_length = past_key_values[0][0].shape[2]
538
+
539
+ # Keep only the unprocessed tokens:
540
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
541
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
542
+ # input)
543
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
544
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
545
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
546
+ # input_ids based on the past_length.
547
+ elif past_length < input_ids.shape[1]:
548
+ input_ids = input_ids[:, past_length:]
549
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
550
+ elif self.config.image_token_index in input_ids:
551
+ input_ids = input_ids[:, input_ids.shape[1] - 1 :]
552
+ # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
553
+ # older attention values, as their corresponding values are not part of the input.
554
+ if cache_length < past_length and attention_mask is not None:
555
+ attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
556
+
557
+ position_ids = kwargs.get("position_ids", None)
558
+ if attention_mask is not None and position_ids is None:
559
+ # create position_ids on the fly for batch generation
560
+ position_ids = attention_mask.long().cumsum(-1) - 1
561
+ position_ids.masked_fill_(attention_mask == 0, 1)
562
+ if past_key_values:
563
+ position_ids = position_ids[:, -input_ids.shape[1] :]
564
+
565
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
566
+ if inputs_embeds is not None and past_key_values is None:
567
+ model_inputs = {"inputs_embeds": inputs_embeds}
568
+ else:
569
+ model_inputs = {"input_ids": input_ids}
570
+
571
+ model_inputs.update(
572
+ {
573
+ "position_ids": position_ids,
574
+ "past_key_values": past_key_values,
575
+ "use_cache": kwargs.get("use_cache"),
576
+ "attention_mask": attention_mask,
577
+ "pixel_values": pixel_values,
578
+ }
579
+ )
580
+ return model_inputs
581
+
582
+ def _reorder_cache(self, *args, **kwargs):
583
+ return self.language_model._reorder_cache(*args, **kwargs)
584
+
585
+
586
+
587
+
588
+ from transformers.models.clip.modeling_clip import CLIPEncoderLayer, CLIPEncoder
589
+ @add_start_docstrings(
590
+ """The MLLAVA model which consists of a vision backbone and a language model.""",
591
+ LLAVA_START_DOCSTRING,
592
+ )
593
+ class MLlavaForConditionalGeneration(LlavaForConditionalGeneration):
594
+ def __init__(self, config: LlavaConfig):
595
+ super().__init__(config)
596
+ config.vision_config.type_vocab_size = 144
597
+ self.image_type_embeddings = nn.Embedding(config.vision_config.type_vocab_size, config.vision_config.hidden_size)
598
+ # self.vision_xatten_layers = nn.ModuleList([CLIPEncoderLayer(config.vision_config) for _ in range(config.vision_config.num_hidden_layers)])
599
+ self.vision_xatten_layers = CLIPEncoder(config.vision_config)
600
+
601
+
602
+ @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
603
+ @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
604
+ def forward(
605
+ self,
606
+ input_ids: torch.LongTensor = None,
607
+ pixel_values: torch.FloatTensor = None,
608
+ attention_mask: Optional[torch.Tensor] = None,
609
+ position_ids: Optional[torch.LongTensor] = None,
610
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
611
+ inputs_embeds: Optional[torch.FloatTensor] = None,
612
+ vision_feature_layer: Optional[int] = None,
613
+ vision_feature_select_strategy: Optional[str] = None,
614
+ labels: Optional[torch.LongTensor] = None,
615
+ use_cache: Optional[bool] = None,
616
+ output_attentions: Optional[bool] = None,
617
+ output_hidden_states: Optional[bool] = None,
618
+ return_dict: Optional[bool] = None,
619
+ ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
620
+ r"""
621
+ Args:
622
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
623
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
624
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
625
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
626
+
627
+ Returns:
628
+
629
+ Example:
630
+
631
+ ```python
632
+ >>> from PIL import Image
633
+ >>> import requests
634
+ >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
635
+
636
+ >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
637
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
638
+
639
+ >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
640
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
641
+ >>> image = Image.open(requests.get(url, stream=True).raw)
642
+
643
+ >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
644
+
645
+ >>> # Generate
646
+ >>> generate_ids = model.generate(**inputs, max_length=30)
647
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
648
+ "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
649
+ ```"""
650
+
651
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
652
+ output_hidden_states = (
653
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
654
+ )
655
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
656
+ vision_feature_layer = (
657
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
658
+ )
659
+ vision_feature_select_strategy = (
660
+ vision_feature_select_strategy
661
+ if vision_feature_select_strategy is not None
662
+ else self.config.vision_feature_select_strategy
663
+ )
664
+
665
+ if inputs_embeds is None:
666
+ # 1. Extract the input embeddings
667
+ inputs_embeds = self.get_input_embeddings()(input_ids)
668
+
669
+ # 2. Merge text and images
670
+ if pixel_values is not None and input_ids.shape[1] != 1:
671
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
672
+ # this is not memory efficient at all: (output_hidden_states=True) will save all the hidden states.
673
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
674
+
675
+ if vision_feature_select_strategy == "default":
676
+ selected_image_feature = selected_image_feature[:, 1:]
677
+ elif vision_feature_select_strategy == "full":
678
+ selected_image_feature = selected_image_feature
679
+ else:
680
+ raise ValueError(
681
+ f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
682
+ )
683
+
684
+ # added by Dongfu
685
+ num_images, num_image_patches, embed_dim = selected_image_feature.shape
686
+ image_type_embeddings = self.image_type_embeddings(torch.arange(num_images, device=selected_image_feature.device))
687
+ selected_image_feature += image_type_embeddings.unsqueeze(1)
688
+ xatten_output = self.vision_xatten_layers(selected_image_feature, attention_mask=None, causal_attention_mask=None)
689
+ selected_image_feature = xatten_output[0]
690
+ # end of added by Dongfu
691
+
692
+ image_features = self.multi_modal_projector(selected_image_feature)
693
+ inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
694
+ image_features, inputs_embeds, input_ids, attention_mask, labels
695
+ )
696
+ if labels is None:
697
+ labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
698
+ else:
699
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
700
+ # generation with cache
701
+ if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
702
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
703
+ # that are set to 0
704
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
705
+
706
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
707
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
708
+
709
+ # Get the target length
710
+ target_seqlen = first_layer_past_key_value.shape[-1] + 1
711
+
712
+ extended_attention_mask = torch.ones(
713
+ (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
714
+ dtype=attention_mask.dtype,
715
+ device=attention_mask.device,
716
+ )
717
+
718
+ # Filter out only the tokens that can be un-attended, this can happen
719
+ # if one uses Llava + Fused modules where the cache on the
720
+ # first iteration is already big enough, or if one passes custom cache
721
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
722
+ new_batch_index = batch_index[valid_indices]
723
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
724
+
725
+ # Zero-out the places where we don't need to attend
726
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
727
+
728
+ attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
729
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
730
+
731
+ outputs = self.language_model(
732
+ attention_mask=attention_mask,
733
+ position_ids=position_ids,
734
+ past_key_values=past_key_values,
735
+ inputs_embeds=inputs_embeds,
736
+ use_cache=use_cache,
737
+ output_attentions=output_attentions,
738
+ output_hidden_states=output_hidden_states,
739
+ return_dict=return_dict,
740
+ )
741
+
742
+ logits = outputs[0]
743
+
744
+ loss = None
745
+ if labels is not None:
746
+ # Shift so that tokens < n predict n
747
+ if attention_mask is not None:
748
+ shift_attention_mask = attention_mask[..., 1:]
749
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
750
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
751
+ else:
752
+ shift_logits = logits[..., :-1, :].contiguous()
753
+ shift_labels = labels[..., 1:].contiguous()
754
+ # Flatten the tokens
755
+ loss_fct = nn.CrossEntropyLoss()
756
+ loss = loss_fct(
757
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
758
+ )
759
+
760
+ if not return_dict:
761
+ output = (logits,) + outputs[1:]
762
+ return (loss,) + output if loss is not None else output
763
+
764
+ return LlavaCausalLMOutputWithPast(
765
+ loss=loss,
766
+ logits=logits,
767
+ past_key_values=outputs.past_key_values,
768
+ hidden_states=outputs.hidden_states,
769
+ attentions=outputs.attentions,
770
+ )
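
A minimal usage sketch for the uploaded module, kept outside the diff above. It is illustrative only: the checkpoint path is a placeholder, the inputs are assumed to come from the processor that accompanies the repository, and the class name comes straight from this file.

```python
# Usage sketch only -- not part of the uploaded file.
# "path/to/mllava-checkpoint" is a placeholder for a local directory that
# contains config.json and the weights for this architecture.
import torch

from mllava.modeling_llava import MLlavaForConditionalGeneration

model = MLlavaForConditionalGeneration.from_pretrained(
    "path/to/mllava-checkpoint", torch_dtype=torch.float16
)

# `inputs` would normally come from the matching processor and contain
# input_ids, attention_mask and pixel_values, as in the docstring examples above.
# generate_ids = model.generate(**inputs, max_new_tokens=30)
```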