import logging
import warnings

import torch
import torch.nn as nn

from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence, Union, List, Tuple, Any

from transformers import (
    LlamaForCausalLM,
    Blip2PreTrainedModel,
    Blip2VisionModel,
    Blip2Config,
    Blip2QFormerModel,
    GenerationConfig,
)

from transformers.utils import ModelOutput

warnings.filterwarnings('ignore')

logger = logging.getLogger(__name__)


@dataclass
class Blip2ForConditionalGenerationModelOutput(ModelOutput):
    """
    Class defining the outputs of [`Blip2ForConditionalGeneration`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss from the language model.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head of the language model.
        vision_outputs (`BaseModelOutputWithPooling`):
            Outputs of the vision encoder.
        qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
            Outputs of the Q-Former (Querying Transformer).
        language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
            Outputs of the language model.
    """

    loss: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    vision_outputs: Optional[torch.FloatTensor] = None
    qformer_outputs: Optional[Tuple[torch.FloatTensor]] = None
    language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k]
            if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"]
            else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class Blip2LlaMAForConditionalGeneration(Blip2PreTrainedModel):
    config_class = Blip2Config
    main_input_name = "pixel_values"

    def __init__(self, config: Blip2Config):
        super().__init__(config)

        # Vision encoder -> Q-Former (with learned query tokens) -> linear
        # projection into the LLaMA embedding space.
        self.vision_model = Blip2VisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = Blip2QFormerModel(config.qformer_config)

        language_model = LlamaForCausalLM(config.text_config)
        self.language_model = language_model

        self.language_projection = nn.Linear(config.qformer_config.hidden_size, language_model.config.hidden_size)

        self.config.hidden_size = config.text_config.hidden_size
        self.num_queries = config.num_query_tokens
        # Position in the prompt at which the projected query embeddings are
        # written back into `inputs_embeds` (see `forward` and `generate`).
        self.offset = 5

        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def extract_feature(
        self,
        pixel_values: torch.FloatTensor,
    ):
        """Encode images and return the projected Q-Former queries of shape
        `(batch_size, num_query_tokens, text_hidden_size)`. The result can be
        passed to `generate` as `language_model_inputs` to avoid recomputing
        the vision and Q-Former forward passes."""
        image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
        query_output = query_outputs.last_hidden_state

        language_model_inputs = self.language_projection(query_output)
        return language_model_inputs

    def _tie_weights(self):
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. This may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models."
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Step 1: encode the images with the vision encoder.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # Step 2: compress the image features into `num_query_tokens` query vectors with the Q-Former.
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        # Step 3: project the queries into the LLaMA embedding space and splice them
        # into the prompt embeddings at positions [offset, offset + num_queries).
        language_model_inputs = self.language_projection(query_output)
        assert language_model_inputs.shape[1] == self.num_queries

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        inputs_embeds[:, self.offset:self.offset + self.num_queries, :] = language_model_inputs

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = outputs.logits if return_dict else outputs[0]
        loss = None

        if labels is not None:
            logits = logits[:, -labels.size(1):, :]

            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device).to(torch.long)

            loss_fct = nn.CrossEntropyLoss(reduction="mean")
            loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return Blip2ForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )

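    # Note on the expected prompt layout (an inference from the slicing above, not
    # documented in the original source): both `forward` and `generate` assume that
    # `input_ids` already reserves `num_queries` placeholder positions starting at
    # index `offset`, e.g. with offset = 5 and num_query_tokens = 32:
    #
    #   [t0 t1 t2 t3 t4] [img_0 ... img_31] [instruction tokens ...]
    #
    # The embeddings of the 32 placeholder tokens are overwritten by the projected
    # Q-Former outputs; the surrounding text tokens are left untouched.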
    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        language_model_inputs: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """
        Overrides the `generate` function to be able to use the model as a conditional generator.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Input images to be processed.
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                The sequence used as a prompt for the generation.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices.
            language_model_inputs (`torch.FloatTensor`, *optional*):
                Precomputed projected query embeddings (e.g. from `extract_feature`). If not provided,
                they are computed from `pixel_values`.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which has the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.

        Returns:
            `torch.LongTensor`: The generated token ids, to be decoded with the tokenizer.
        """
        if hasattr(self, "hf_device_map"):
            self._preprocess_accelerate()

        batch_size = pixel_values.shape[0]
        if language_model_inputs is None:
            image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
            image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_outputs = self.qformer(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_attention_mask,
                return_dict=True,
            )
            query_output = query_outputs.last_hidden_state

            language_model_inputs = self.language_projection(query_output)
        assert language_model_inputs.shape[1] == self.num_queries

        if input_ids is None:
            # `language_model_inputs` is defined in both branches above, so use its device here.
            input_ids = (
                torch.LongTensor([[self.config.text_config.bos_token_id]])
                .repeat(batch_size, 1)
                .to(language_model_inputs.device)
            )

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        inputs_embeds[:, self.offset:self.offset + self.num_queries, :] = language_model_inputs

        outputs = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            **generate_kwargs,
        )

        return outputs
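

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original module). It assumes a
# HF tokenizer and a prompt template in which `model.num_queries` placeholder
# tokens occupy positions [model.offset, model.offset + model.num_queries) of
# `prompt_ids`; the helper name and the `max_new_tokens` value are illustrative
# assumptions, not an established API.
def _example_caption_images(model, tokenizer, pixel_values, prompt_ids):
    """Generate captions for a batch of images.

    `prompt_ids` must already contain the image placeholder tokens expected by
    `Blip2LlaMAForConditionalGeneration.generate` (see the note above `generate`).
    """
    # Optionally precompute the projected query embeddings once and reuse them.
    language_model_inputs = model.extract_feature(pixel_values)
    generated_ids = model.generate(
        pixel_values=pixel_values,
        input_ids=prompt_ids,
        language_model_inputs=language_model_inputs,
        max_new_tokens=64,
    )
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)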