import logging
import warnings

import torch
import torch.nn as nn
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence, Union, List, Tuple, Any

from transformers import (
    LlamaForCausalLM,
    Blip2PreTrainedModel,
    Blip2VisionModel,
    Blip2Config,
    Blip2QFormerModel,
    GenerationConfig,
)
from transformers.utils import ModelOutput

warnings.filterwarnings("ignore")

logger = logging.getLogger(__name__)


@dataclass
class Blip2ForConditionalGenerationModelOutput(ModelOutput):
    """
    Class defining the outputs of [`Blip2ForConditionalGeneration`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss from the language model.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head of the language model.
        vision_outputs (`BaseModelOutputWithPooling`):
            Outputs of the vision encoder.
        qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
            Outputs of the Q-Former (Querying Transformer).
        language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
            Outputs of the language model.
    """

    loss: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    vision_outputs: Optional[torch.FloatTensor] = None
    qformer_outputs: Optional[Tuple[torch.FloatTensor]] = None
    language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k]
            if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"]
            else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class Blip2LlaMAForConditionalGeneration(Blip2PreTrainedModel):
    config_class = Blip2Config
    main_input_name = "pixel_values"

    def __init__(self, config: Blip2Config):
        super().__init__(config)

        self.vision_model = Blip2VisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)
        )
        self.qformer = Blip2QFormerModel(config.qformer_config)

        language_model = LlamaForCausalLM(config.text_config)
        self.language_model = language_model
        self.language_projection = nn.Linear(
            config.qformer_config.hidden_size, language_model.config.hidden_size
        )

        self.config.hidden_size = config.text_config.hidden_size
        self.num_queries = config.num_query_tokens
        # Number of prompt tokens that precede the visual query embeddings: positions
        # [offset: offset + num_queries] of the prompt are overwritten with the
        # projected Q-Former outputs in `forward` and `generate`.
        self.offset = 5

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def extract_feature(
        self,
        pixel_values: torch.FloatTensor,
    ):
        """Encode images with the vision model and project the Q-Former queries into the language model's embedding space."""
        image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
        query_output = query_outputs.last_hidden_state

        language_model_inputs = self.language_projection(query_output)
        return language_model_inputs
    def _tie_weights(self):
        # Only relevant for encoder-decoder language models; a decoder-only LLaMA never enters this branch.
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            # warn users about unexpected behavior when using multi-GPU + BLIP-2 + `accelerate`.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. This may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models."
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        assert language_model_inputs.shape[1] == self.num_queries

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        # Prompt template: "Human: <image placeholder>. Give the describe. Assistant:"
        # position of the image placeholder: [offset: offset + num_queries]
        inputs_embeds[:, self.offset:self.offset + self.num_queries, :] = language_model_inputs

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]
        loss = None
        # we compute the loss here since we need to take into account the sequence length of the query embeds
        if labels is not None:
            logits = logits[:, -labels.size(1):, :]
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device).to(torch.long)

            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss(reduction="mean")
            loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return Blip2ForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        language_model_inputs: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """
        Overrides the `generate` function to be able to use the model as a conditional generator.

        Args:
            pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
                Input images to be processed.
            input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                The sequence used as a prompt for the generation.
            attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                Mask to avoid performing attention on padding token indices.
            language_model_inputs (`torch.FloatTensor`, *optional*):
                Precomputed projected query embeddings (e.g. from `extract_feature`). If not provided, they are
                computed from `pixel_values`.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which has the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.

        Returns:
            `torch.LongTensor`: The generated token ids, which can be decoded into captions with the tokenizer.
""" if hasattr(self, "hf_device_map"): # preprocess for `accelerate` self._preprocess_accelerate() if language_model_inputs is None: batch_size = pixel_values.shape[0] image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) query_outputs = self.qformer( query_embeds=query_tokens, encoder_hidden_states=image_embeds, encoder_attention_mask=image_attention_mask, return_dict=True, ) query_output = query_outputs.last_hidden_state language_model_inputs = self.language_projection(query_output) assert language_model_inputs.shape[1] == self.num_queries if input_ids is None: input_ids = ( torch.LongTensor([[self.config.text_config.bos_token_id]]) .repeat(batch_size, 1) .to(image_embeds.device) ) if attention_mask is None: attention_mask = torch.ones_like(input_ids) inputs_embeds = self.language_model.get_input_embeddings()(input_ids) # position of : [offset: offset+num_queries] inputs_embeds[:, self.offset:self.offset + self.num_queries, :] = language_model_inputs outputs = self.language_model.generate( inputs_embeds=inputs_embeds, attention_mask=attention_mask, generation_config=generation_config, **generate_kwargs, ) return outputs