from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.modeling_outputs import ModelOutput
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from transformers.models.llava.modeling_llava import (
    _CONFIG_FOR_DOC,
    LLAVA_INPUTS_DOCSTRING,
    LLAVA_START_DOCSTRING,
    LlavaForConditionalGeneration,
)


@dataclass
class LlavaCausalLMOutputWithPast(ModelOutput):
""" |
|
Base class for Llava causal language model (or autoregressive) outputs. |
|
|
|
Args: |
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): |
|
Language modeling loss (for next-token prediction). |
|
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): |
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). |
|
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): |
|
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape |
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) |
|
|
|
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see |
|
`past_key_values` input) to speed up sequential decoding. |
|
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): |
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + |
|
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. |
|
|
|
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. |
|
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): |
|
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, |
|
sequence_length)`. |
|
|
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention |
|
heads. |
|
image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): |
|
Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, |
|
sequence_length, hidden_size)`. |
|
|
|
image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver |
|
""" |
|
|
|
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    image_to_overwrite: Optional[Tuple[torch.BoolTensor]] = None
    mask_ids: Optional[Tuple[torch.LongTensor]] = None
    labels: Optional[Tuple[torch.LongTensor]] = None


@add_start_docstrings(
    """The LLAVA model which consists of a vision backbone and a language model.""",
    LLAVA_START_DOCSTRING,
)
class CustomLlavaForConditionalGeneration(LlavaForConditionalGeneration):
    def _merge_input_ids_with_image_features(
        self, image_features, inputs_embeds, input_ids, attention_mask, labels, mask_ids=None
    ):
        """
        Replace every image token in `input_ids` with its `num_image_patches` projected image features,
        returning the merged embeddings, attention mask, labels, position ids and mask ids, together with
        the boolean map `image_to_overwrite` of the positions that hold image features.
        """
        num_images, num_image_patches, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
        # The batch is considered left-padded when no sequence ends with the pad token.
        left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))

        # Locate the special image tokens and compute the length of the merged sequence:
        # each image token expands into `num_image_patches` positions.
        special_image_token_mask = input_ids == self.config.image_token_index
        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
        max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
        batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
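
        # Compute where each original token lands in the merged sequence. The cumulative sum adds
        # `num_image_patches - 1` extra slots for every image token seen so far, and the trailing `- 1`
        # converts the running count into a zero-based position. For left-padded batches the positions
        # are shifted right by the per-sample amount of padding (`nb_image_pad`).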
|
        new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_image_pad[:, None]
        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

        # Allocate the merged tensors, already padded to the maximum merged length.
        final_embedding = torch.zeros(
            batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )
        final_attention_mask = torch.zeros(
            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
        )
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
            )
        if mask_ids is not None:
            final_mask_ids = torch.full(
                (batch_size, max_embed_dim), -1, dtype=input_ids.dtype, device=input_ids.device
            )

        # Move the indexing tensors to the device of `inputs_embeds`, in case parts of the model have
        # been offloaded to a different device.
        target_device = inputs_embeds.device
        batch_indices, non_image_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_image_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # Scatter the text tokens (and their mask, labels and mask ids) into their new positions.
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
        if mask_ids is not None:
            final_mask_ids[batch_indices, text_to_overwrite] = mask_ids[batch_indices, non_image_indices]
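
        # Every position that is still all-zeros needs an image feature; the cumulative-sum test below
        # filters out the `nb_image_pad` slots per row that correspond to padding rather than to image
        # patches.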
|
        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
            raise ValueError(
                f"The inputs provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
                f" the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
            )

        final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
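
        # Zero out the embeddings at padding positions; the cached-generation path below later recovers
        # these positions from the past key values as the entries whose cached states are all zeros.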
|
        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0

        if labels is None:
            final_labels = None
        if mask_ids is None:
            final_mask_ids = None

        return final_embedding, final_attention_mask, final_labels, position_ids, final_mask_ids, image_to_overwrite

    @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        mask_ids: Optional[torch.LongTensor] = None,
        image_to_overwrite: Optional[torch.BoolTensor] = None,
    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

        >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
        >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

        >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
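
        # Two vision-related paths follow: (1) a prefill step with `pixel_values`, where image features
        # are computed and merged into the text embeddings, and (2) a cached decoding step (sequence
        # length 1), where only the attention mask over the already-merged past has to be rebuilt.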
|
        if pixel_values is not None and input_ids.shape[1] != 1:
            image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
            selected_image_feature = image_outputs.hidden_states[vision_feature_layer]

            if vision_feature_select_strategy == "default":
                # Drop the CLS token and keep only the patch embeddings.
                selected_image_feature = selected_image_feature[:, 1:]
            elif vision_feature_select_strategy == "full":
                selected_image_feature = selected_image_feature
            else:
                raise ValueError(
                    f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
                )

            image_features = self.multi_modal_projector(selected_image_feature)
            inputs_embeds, attention_mask, labels, position_ids, mask_ids, image_to_overwrite = (
                self._merge_input_ids_with_image_features(
                    image_features, inputs_embeds, input_ids, attention_mask, labels, mask_ids=mask_ids
                )
            )
            if labels is None:
                labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)

        # Cached generation step: `input_ids` has length 1 and the image features are already part of
        # `past_key_values`, so only the attention mask over the past has to be reconstructed.
        elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
            # Inspect the first layer's cached keys to find the positions that were zeroed out.
            first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
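
            # Positions whose cached keys are zero for every head correspond to the embeddings zeroed
            # out for padding in `_merge_input_ids_with_image_features`; they must not be attended to
            # when the mask over the past is rebuilt below.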
|
            batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

            target_length = input_ids.shape[1]
            past_length = first_layer_past_key_value.shape[-1]

            extended_attention_mask = torch.ones(
                (attention_mask.shape[0], past_length),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )

            # Keep only indices that fall inside the past; larger indices can appear when a custom or
            # pre-filled cache is used.
            valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
            new_batch_index = batch_index[valid_indices]
            new_non_attended_tokens = non_attended_tokens[valid_indices]

            # Zero out the places that should not be attended to.
            extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

            attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
            position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = outputs[0]

        # No loss is computed here; the merged `labels` (and `mask_ids`) are returned so the caller can
        # compute its own objective.
        loss = None

        assert return_dict, "Use dict in our implementation"

        if not return_dict:
            # Unreachable while the assert above is active; kept as the tuple-style fallback.
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return LlavaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_to_overwrite=image_to_overwrite,
            mask_ids=mask_ids,
            labels=labels,
        )
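

# Hedged usage sketch (not part of the upstream module): one way the extra outputs could be consumed.
# The checkpoint, the token-shifted cross-entropy and the variable names below are assumptions for
# illustration only; the class itself leaves `loss=None` and returns the merged `labels`, `mask_ids`
# and `image_to_overwrite` for the caller to use. `prompt` and `image` are as in the docstring example.
#
#   model = CustomLlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
#   processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
#   inputs = processor(text=prompt, images=image, return_tensors="pt")
#   outputs = model(**inputs, labels=inputs["input_ids"], return_dict=True)
#
#   shift_logits = outputs.logits[:, :-1].contiguous()
#   shift_labels = outputs.labels[:, 1:].contiguous()
#   loss_fct = nn.CrossEntropyLoss(ignore_index=model.config.ignore_index)
#   loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
#
#   # Hidden states at image positions (requires output_hidden_states=True in the call above):
#   # image_states = outputs.hidden_states[-1][outputs.image_to_overwrite]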