# coding=utf-8
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from modeling_phi import PhiForCausalLM
from configuration_llava import LlavaConfig
from open_clip import create_model


@dataclass
class LlavaCausalLMOutputWithPast(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_features: Optional[torch.FloatTensor] = None


class LlavaMultiModalProjector(nn.Module):
    def __init__(self, config: LlavaConfig):
        super().__init__()
        self.linear_1 = nn.Linear(
            config.vision_embed_dim,
            config.text_config.n_embd * config.projector_tokens_num,
            bias=True,
        )
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(
            config.text_config.n_embd * config.projector_tokens_num,
            config.text_config.n_embd * config.projector_tokens_num,
            bias=True,
        )
        self.projector_tokens_num = config.projector_tokens_num

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        # Reshape the flat projection into `projector_tokens_num` embeddings of
        # size `n_embd`, i.e. one short pseudo-token sequence per image.
        hidden_states = hidden_states.reshape(
            hidden_states.shape[0],
            self.projector_tokens_num,
            int(hidden_states.shape[1] / self.projector_tokens_num),
        )
        return hidden_states


class LlavaPreTrainedModel(PreTrainedModel):
    config_class = LlavaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlavaVisionAttention"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True

    def __init__(self, config):
        super().__init__(config)

    def _init_weights(self, module):
        return

    @property
    def _supports_sdpa(self):
        """
        Retrieve language_model's attribute to check whether the model supports SDPA or not.
        """
        return self.language_model._supports_sdpa


class LlavaForConditionalGeneration(LlavaPreTrainedModel):
    def __init__(self, config: LlavaConfig):
        super().__init__(config)
        clip_model = create_model(config.vision_tower_name)
        self.vision_model = clip_model.visual
        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.vocab_size = config.vocab_size
        self.language_model = PhiForCausalLM(config.text_config)
        self.pad_token_id = (
            self.config.pad_token_id if self.config.pad_token_id is not None else -1
        )
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.transformer = decoder

    def get_decoder(self):
        return self.language_model.transformer

    def tie_weights(self):
        return self.language_model.tie_weights()

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None
    ) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(
            new_num_tokens, pad_to_multiple_of
        )
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def _merge_input_ids_with_image_features(
        self, image_features, inputs_embeds, input_ids, attention_mask, position_ids
    ):
        num_images, num_image_patches, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
        left_padding = not torch.sum(
            input_ids[:, -1] == torch.tensor(self.pad_token_id)
        )
        # 1. Create a mask to know where special image tokens are
        special_image_token_mask = input_ids == self.config.image_token_index
        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
        # Compute the maximum embed dimension
        max_embed_dim = (
            num_special_image_tokens.max() * (num_image_patches - 1)
        ) + sequence_length
        batch_indices, non_image_indices = torch.where(
            input_ids != self.config.image_token_index
        )

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in the merged image-text sequence.
        # `special_image_token_mask` identifies image tokens. Each image token is expanded into
        # `num_image_patches` feature slots, so it shifts subsequent positions by `num_image_patches - 1`.
        # `torch.cumsum` computes how each image token shifts subsequent text token positions.
        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
        new_token_positions = (
            torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1)
            - 1
        )
        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_image_pad[:, None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size,
            max_embed_dim,
            embed_dim,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )
        final_attention_mask = torch.zeros(
            batch_size,
            max_embed_dim,
            dtype=attention_mask.dtype,
            device=inputs_embeds.device,
        )
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_image_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_image_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask. If we have ["hey", "<image>", "how", "are"]
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
            batch_indices, non_image_indices
        ]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
            batch_indices, non_image_indices
        ]

        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[
            :, None
        ].to(target_device)

        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
            raise ValueError(
                f"The inputs provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
                f" the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
            )

        final_embedding[image_to_overwrite] = (
            image_features.contiguous().reshape(-1, embed_dim).to(target_device)
        )
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
            (final_attention_mask == 0), 1
        )

        return final_embedding, final_attention_mask, position_ids

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        image_features: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
            # Only merge image features on the prefill step; during cached decoding
            # the sequence length is 1 and the image context already lives in the cache.
            if image_features is not None and input_ids.shape[1] != 1:
                (
                    inputs_embeds,
                    attention_mask,
                    position_ids,
                ) = self._merge_input_ids_with_image_features(
                    image_features,
                    inputs_embeds,
                    input_ids,
                    attention_mask,
                    position_ids,
                )

        outputs = self.language_model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = outputs[0]

        if not return_dict:
            output = (logits,) + outputs[1:]
            return output

        return LlavaCausalLMOutputWithPast(
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_features=image_features,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        attention_mask=None,
        image_features=None,
        **kwargs,
    ):
        # Delegate cache/attention-mask handling to the language model, then
        # re-attach the image features so they reach forward() on the prefill step.
        res = self.language_model.prepare_inputs_for_generation(
            input_ids, past_key_values, attention_mask, **kwargs
        )
        input_ids = res["input_ids"]
        past_key_values = res["past_key_values"]
        attention_mask = res["attention_mask"]

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "image_features": image_features,
            }
        )
        return model_inputs

    def _reorder_cache(self, *args, **kwargs):
        return self.language_model._reorder_cache(*args, **kwargs)