# albef-vqa/model.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torchmultimodal.models.albef.image_encoder import ALBEFVisionEncoder
from torchmultimodal.models.albef.model import ALBEFModel, ALBEFModelWithSimilarity
from torchmultimodal.models.albef.multimodal_encoder import ALBEFMultimodalEncoder
from torchmultimodal.modules.encoders.bert_text_encoder import bert_text_encoder
from torchmultimodal.modules.layers.text_embedding import BERTTextEmbeddings
from torchmultimodal.modules.losses.albef import (
CausalLanguageModelingLoss,
ImageTextContrastiveLoss,
)
from torchmultimodal.utils.attention import get_causal_attention_mask
from torchmultimodal.utils.common import momentum_update, remove_grad
_ALBEF_PRETRAINED_URLS = {
"vqa": "https://download.pytorch.org/models/multimodal/albef/pretrained_vqa_checkpoint.pt",
"retrieval": "https://download.pytorch.org/models/multimodal/albef/pretrained_retrieval_checkpoint.pt",
}
class PredictionHead(nn.Module):
"""
Predict the following token autoregressively.
Args:
vocab_size (int): The number of different tokens the prediction_head can predict.
hidden_size (int): The hidden size of the prediction_head.
layer_norm_eps (float): The epsilon used by the prediction_head normalization layer.
transform_act_fn (Callable[[Tensor], Tensor]): The activation function in the prediction_head.
Inputs:
hidden_states (Tensor): The hidden states of preceding tokens.
Returns:
Tensor: Prediction scores for the following token.
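    Example:
        A minimal sketch with assumed shapes (the sizes below match the constructor
        defaults; the batch size and sequence length are made up):

            head = PredictionHead(vocab_size=30522, hidden_size=768)
            hidden_states = torch.randn(2, 5, 768)  # (batch_size, seq_len, hidden_size)
            scores = head(hidden_states)            # (batch_size, seq_len, vocab_size) == (2, 5, 30522)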
"""
def __init__(
self,
vocab_size: int = 30522,
hidden_size: int = 768,
layer_norm_eps: float = 1e-12,
transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
) -> None:
super().__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.transform_act_fn = transform_act_fn
self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
self.decoder = nn.Linear(hidden_size, vocab_size)
def forward(self, hidden_states: Tensor) -> Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class ALBEFDecoder(nn.Module):
"""
Generate the prediction scores for answers from image and question hidden states.
Args:
        text_embeddings (BERTTextEmbeddings): Instantiated BERTTextEmbeddings.
multimodal_encoder (ALBEFMultimodalEncoder): Instantiated ALBEFMultimodalEncoder.
prediction_head (PredictionHead): Instantiated PredictionHead.
Inputs:
input_ids (Tensor of shape (batch_size, seq_len)):
Input ids for input text tokens.
attention_mask (Tensor of shape (batch_size, seq_len)):
Input attention mask to avoid performing attention on padding token indices.
encoder_hidden_states (Tensor of shape (batch_size, encoder_seq_len, hidden_size)):
The encoder hidden states.
encoder_attention_mask (Tensor of shape (batch_size, encoder_seq_len)):
The attention mask for encoder hidden states.
Returns:
Tensor: Prediction scores for answers.
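    Example:
        A minimal sketch of a forward call; it assumes ``decoder`` is an ALBEFDecoder built
        from default-sized components (e.g. as assembled in ``albef_model_for_vqa``) and
        uses made-up shapes:

            input_ids = torch.randint(0, 30522, (2, 6))       # (batch_size, seq_len)
            attention_mask = torch.ones(2, 6)
            encoder_hidden_states = torch.randn(2, 16, 768)   # (batch_size, encoder_seq_len, hidden_size)
            encoder_attention_mask = torch.ones(2, 16)
            scores = decoder(input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask)
            # scores has shape (batch_size, seq_len, vocab_size)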
"""
def __init__(
self,
text_embeddings: BERTTextEmbeddings,
multimodal_encoder: ALBEFMultimodalEncoder,
prediction_head: PredictionHead,
) -> None:
super().__init__()
self.text_embeddings = text_embeddings
self.multimodal_encoder = multimodal_encoder
self.prediction_head = prediction_head
def get_extended_attention_mask_for_decoder(self, attention_mask: Tensor) -> Tensor:
"""
Apply a causal mask in addition to the padding mask and make the mask broadcastable,
such that future and masked tokens are ignored.
Args:
attention_mask (Tensor):
Padding mask with ones indicating tokens to attend to, zeros for tokens to ignore.
Returns:
extended_attention_mask (Tensor):
The broadcastable attention mask, with the same dtype as ``attention_mask.dtype``.
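        Example:
            A small sketch, assuming ``get_causal_attention_mask`` yields a lower-triangular
            mask and ``decoder`` is an instantiated ALBEFDecoder:

                attention_mask = torch.tensor([[1, 1, 0]])  # last token is padding
                extended = decoder.get_extended_attention_mask_for_decoder(attention_mask)
                # extended.shape == (1, 1, 3, 3); position i attends only to
                # non-padding positions j <= i:
                # [[1, 0, 0],
                #  [1, 1, 0],
                #  [1, 1, 0]]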
"""
device = attention_mask.device
batch_size, seq_length = attention_mask.shape
causal_mask = get_causal_attention_mask(seq_length).to(device)
causal_mask = causal_mask.repeat(batch_size, 1).view(
batch_size, seq_length, seq_length
)
extended_attention_mask = (
causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
)
extended_attention_mask = extended_attention_mask.to(dtype=attention_mask.dtype)
return extended_attention_mask
def forward(
self,
input_ids: Tensor,
attention_mask: Tensor,
encoder_hidden_states: Tensor,
encoder_attention_mask: Tensor,
) -> Tensor:
hidden_states = self.text_embeddings(input_ids)
attention_mask = self.get_extended_attention_mask_for_decoder(attention_mask)
decoder_output = self.multimodal_encoder(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
prediction_scores = self.prediction_head(decoder_output)
return prediction_scores
class ALBEFModelForVQA(nn.Module):
"""
ALBEF Model for VQA finetuning and inference.
Args:
model (ALBEFModel): Instantiated ALBEFModel.
answer_decoder (ALBEFDecoder): Instantiated ALBEFDecoder.
loss (CausalLanguageModelingLoss): Instantiated CausalLanguageModelingLoss.
Inputs:
        image (Tensor of shape (B, C, H, W)): Batch of input images.
        question (Tensor of shape (B, L)): Question token ids.
        question_atts (Tensor of shape (B, L)): Question attention mask.
        answers (Tensor of shape (N, M)): Token ids of the candidate answers.
        answers_atts (Tensor of shape (N, M)): Answer attention mask.
ans_weights (Optional[Tensor] of shape (N)): Weights for each answer.
Required if is_train is True.
ans_lengths (Optional[List[int]] of length B): Number of answers for each question.
ans_lengths should sum to N.
Required if is_train is True.
alpha (Optional[float]): The interpolation value between clm_loss and loss_distill.
Required if is_train is True.
k (Optional[int]): The number of answers to return for inference.
Required if is_train is False.
is_train (Optional[bool]): Whether the model is in training.
Returns:
is_train is True:
Tensor: The masked language modeling loss for input.
is_train is False:
Tuple[Tensor, Tensor]: The ids and probabilities for the top k predicted answers.
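    Example:
        A minimal sketch of both modes; the ``config`` dict, image resolution, and token
        tensors below are assumptions for illustration, not values from a real dataset:

            model = albef_model_for_vqa(config)
            image = torch.randn(2, 3, 384, 384)
            question = torch.randint(0, 30522, (2, 12))
            question_atts = torch.ones(2, 12)
            answers = torch.randint(0, 30522, (3, 6))
            answers_atts = torch.ones(3, 6)
            # training: the first question has 2 candidate answers, the second has 1
            loss = model(
                image, question, question_atts, answers, answers_atts,
                ans_weights=torch.ones(3), ans_lengths=[2, 1], alpha=0.4, is_train=True,
            )
            # inference: rank the candidate answers and keep the top 2
            topk_ids, topk_probs = model(
                image, question, question_atts, answers, answers_atts, k=2, is_train=False,
            )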
"""
def __init__(
self,
model: ALBEFModel,
answer_decoder: ALBEFDecoder,
loss: CausalLanguageModelingLoss,
) -> None:
super().__init__()
self.model = model
self.answer_decoder = answer_decoder
self.loss = loss
self.answer_decoder_m = copy.deepcopy(self.answer_decoder)
remove_grad(
self.answer_decoder_m
) # remove gradient for the momentum decoder model
def _train_forward(
self,
image: Tensor,
question: Tensor,
question_atts: Tensor,
answers: Tensor,
answers_atts: Tensor,
ans_weights: Tensor,
ans_lengths: List[int],
alpha: float,
) -> Tensor:
"""
Forward step for training. Encode the inputs with the ALBEFModel.
Generate pseudo-targets using answer_decoder_m (momentum decoder model).
Generate answer predictions using answer_decoder.
Compute masked language modeling loss of the predictions using answers as labels,
pseudo-targets as soft-labels, and alpha as their interpolation value.
Inputs:
            image (Tensor of shape (B, C, H, W)): Batch of input images.
            question (Tensor of shape (B, L)): Question token ids.
            question_atts (Tensor of shape (B, L)): Question attention mask.
            answers (Tensor of shape (N, M)): Answer token ids.
            answers_atts (Tensor of shape (N, M)): Answer attention mask.
ans_weights (Tensor of shape (N)): Weights for each answer.
ans_lengths (List[int] of length B): Number of answers for each question.
ans_lengths should sum to N.
alpha (float): The interpolation value between clm_loss and loss_distill.
Returns:
Tensor: The masked language modeling loss for input.
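        A rough sketch of the interpolation performed inside CausalLanguageModelingLoss
        (an assumption based on the usual ALBEF momentum-distillation recipe, not an extra
        computation in this method):

            loss = (1 - alpha) * clm_loss + alpha * loss_distill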
"""
        # get image-question embeddings from the ALBEFModel and repeat them to match ans_lengths
encoder_outputs = self.model(image, question, question_atts)
(
encoder_hidden_states,
encoder_hidden_states_m,
encoder_attention_mask,
) = self._encoder_hidden_states(
encoder_outputs.multimodal_embeddings,
encoder_outputs.multimodal_embeddings_m,
question_atts,
ans_lengths,
)
# use the momentum model to generate pseudo-targets
with torch.no_grad():
momentum_update(
self.answer_decoder, self.answer_decoder_m, self.model.momentum
)
prediction_scores_m = self.answer_decoder_m(
input_ids=answers,
attention_mask=answers_atts,
encoder_hidden_states=encoder_hidden_states_m,
encoder_attention_mask=encoder_attention_mask,
)
# generate answer predictions
prediction_scores = self.answer_decoder(
input_ids=answers,
attention_mask=answers_atts,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
# compute masked language modeling loss from the prediction scores
labels = answers.masked_fill(answers == 0, self.loss.mask_token_id)
loss = self.loss(labels, prediction_scores, prediction_scores_m, alpha)
loss = ans_weights * loss
loss = loss.sum() / image.size(0)
return loss
def _eval_forward(
self,
image: Tensor,
question: Tensor,
question_atts: Tensor,
answers: Tensor,
answer_atts: Tensor,
k: int = 128,
) -> Tuple[Tensor, Tensor]:
"""
Forward step for evaluation. Encode the inputs with the ALBEFModel.
        Generate answers with the decoder, starting from the [CLS] token.
        Candidates are first shortlisted by the probability of their first answer token,
        then re-ranked by their full-sequence log-likelihood; return the answer ids and
        their respective probabilities for the top k predictions.
Inputs:
            image (Tensor of shape (B, C, H, W)): Batch of input images.
            question (Tensor of shape (B, L)): Question token ids.
            question_atts (Tensor of shape (B, L)): Question attention mask.
            answers (Tensor of shape (N, M)): Token ids of the candidate answers; the
                first token of each row is the start ([CLS]) token.
            answer_atts (Tensor of shape (N, M)): Answer attention mask.
k (int): The number of answers to return for inference.
Returns:
Tuple[Tensor, Tensor]: The ids and probabilities for the top k predicted answers.
"""
# get multimodal embeddings from the ALBEFModel and
# feed it to the decoder as cross attention
encoder_outputs = self.model(image, question, question_atts)
# use cls token as the decoder's initial input token
num_ques = question.size(0)
start_ids = answers[0, 0].repeat(num_ques, 1)
atts = torch.ones(start_ids.shape).to(image.device)
        # score the first answer token conditioned on the start ([CLS]) token
prediction_scores = self.answer_decoder(
input_ids=start_ids,
attention_mask=atts,
encoder_hidden_states=encoder_outputs.multimodal_embeddings,
encoder_attention_mask=question_atts,
)
logits = prediction_scores[:, 0, :]
answer_first_token = answers[:, 1]
prob_first_token = F.softmax(logits, dim=1).index_select(
dim=1, index=answer_first_token
)
topk_probs, topk_ids = prob_first_token.topk(k, dim=1)
input_ids = []
input_atts = []
for topk_id in topk_ids:
input_ids.append(answers.index_select(dim=0, index=topk_id))
input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
input_ids = torch.cat(input_ids)
input_atts = torch.cat(input_atts)
targets_ids = input_ids.masked_fill(input_ids == 0, self.loss.mask_token_id)
question_states = encoder_outputs.multimodal_embeddings.repeat_interleave(
k, dim=0
)
question_atts = question_atts.repeat_interleave(k, dim=0)
prediction_scores = self.answer_decoder(
input_ids=input_ids,
attention_mask=input_atts,
encoder_hidden_states=question_states,
encoder_attention_mask=question_atts,
)
answer_loss = self.loss(targets_ids, prediction_scores)
answer_loss = answer_loss.view(input_ids.size(0), -1)
# topk_prob: first token probability
topk_probs = topk_probs.view(-1, 1)
log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)
# re-calculate log probabilities for the answer sequences using chain rule
log_probs_sum = log_probs.sum(1)
log_probs_sum = log_probs_sum.view(num_ques, k)
topk_probs = F.softmax(log_probs_sum, dim=-1)
# get top-k after re-ranking
topk_probs, rerank_id = topk_probs.topk(k, dim=1)
topk_ids = torch.gather(topk_ids, 1, rerank_id)
return topk_ids, topk_probs
def _encoder_hidden_states(
self,
multimodal_embeds: Tensor,
multimodal_embeds_m: Tensor,
question_atts: Tensor,
ans_lengths: List[int],
) -> Tuple[Tensor, Tensor, Tensor]:
"""
        Repeat each image-question embedding and its attention mask to match the number of answers for that input.
Args:
multimodal_embeds (Tensor): Image-question embeddings.
multimodal_embeds_m (Tensor): Image-question embeddings from the momentum model.
question_atts (Tensor): Question attention mask.
ans_lengths (List[int]): The number of answers each image-question input has.
Returns:
encoder_hidden_states (Tensor): Image-question embeddings after the repetition.
encoder_hidden_states_m (Tensor): Image-question embeddings from the momentum model after the repetition.
encoder_attention_mask (Tensor): Question attention mask after the repetition.
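        Example:
            An illustrative sketch with made-up sizes: with ``ans_lengths = [2, 1]`` and
            ``multimodal_embeds`` of shape (2, 16, 768), the first row is repeated twice and
            the second once, so ``encoder_hidden_states`` and ``encoder_hidden_states_m``
            come out as (3, 16, 768) and ``encoder_attention_mask`` as (3, 16).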
"""
encoder_hidden_states = []
encoder_attention_mask = []
for b, n in enumerate(ans_lengths):
encoder_hidden_states += [multimodal_embeds[b]] * n
encoder_attention_mask += [question_atts[b]] * n
encoder_hidden_states = torch.stack(encoder_hidden_states)
encoder_attention_mask = torch.stack(encoder_attention_mask)
with torch.no_grad():
encoder_hidden_states_m = []
for b, n in enumerate(ans_lengths):
encoder_hidden_states_m += [multimodal_embeds_m[b]] * n
encoder_hidden_states_m = torch.stack(encoder_hidden_states_m)
return encoder_hidden_states, encoder_hidden_states_m, encoder_attention_mask
def forward(
self,
image: Tensor,
question: Tensor,
question_atts: Tensor,
answers: Tensor,
answers_atts: Tensor,
ans_weights: Optional[Tensor] = None,
ans_lengths: Optional[List[int]] = None,
alpha: Optional[float] = 0.0,
k: Optional[int] = 128,
is_train: Optional[bool] = True,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
if is_train:
return self._train_forward(
image,
question,
question_atts,
answers,
answers_atts,
ans_weights,
ans_lengths,
alpha,
)
else:
return self._eval_forward(
image,
question,
question_atts,
answers,
answers_atts,
k,
)
class ALBEFModelForRetrieval(nn.Module):
"""
ALBEF Model for Retrieval finetuning and inference.
In training mode, the forward step computes image-text contrastive loss and
image-text matching loss.
In evaluation mode, the forward step takes 3 types of input:
image: encode image input, project and normalize the embeddings.
text: encode text input, project and normalize the embeddings.
multimodal: create multimodal embeddings from image and text
embeddings, and compute image-text matching scores.
Args:
model_with_similarity (ALBEFModelWithSimilarity): Instantiated ALBEFModelWithSimilarity.
itc_loss (ImageTextContrastiveLoss): Instantiated ImageTextContrastiveLoss.
hidden_size (int): Dimensionality of encoder outputs.
Inputs:
        image (Optional[Tensor] of shape (B, C, H, W)): Batch of input images.
            Required if is_train is True.
            Required if input_type is "image" or "multimodal"; for "multimodal",
            pass the precomputed image embeddings from the "image" pass instead
            of raw pixels.
        text (Optional[Tensor] of shape (B, L)): Text token ids.
            Required if is_train is True.
            Required if input_type is "text" or "multimodal"; for "multimodal",
            pass the precomputed text embeddings from the "text" pass instead of
            token ids.
        text_atts (Optional[Tensor] of shape (B, L)): Text attention mask.
            Required if is_train is True.
            Required if input_type is "text" or "multimodal".
        idx (Optional[Tensor] of shape (B)): Identifier for each image sample.
            Required if is_train is True.
        alpha (Optional[float]): The interpolation value between the image-text contrastive loss and its momentum-distillation loss.
Default is 0.
input_type (Optional[str]): "image", "text", or "multimodal" indicating the encoding type.
Required if is_train is False.
is_train (Optional[bool]): Whether the model is in training.
Default is True.
Returns:
is_train is True:
Tensor: The sum of itc loss and itm loss.
is_train is False:
input_type is "image":
Tuple[Tensor, Tensor]: Image embeddings and projected image features.
input_type is "text":
Tuple[Tensor, Tensor]: Text embeddings and projected text features.
input_type is "multimodal"
Tensor: Scores for the retrieval task.
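    Example:
        A minimal sketch of both modes; the ``config`` dict, image resolution, and token
        tensors below are assumptions for illustration:

            model = albef_model_for_retrieval(config)
            image = torch.randn(2, 3, 384, 384)
            text = torch.randint(0, 30522, (2, 24))
            text_atts = torch.ones(2, 24)
            # training: image-text contrastive loss + image-text matching loss
            loss = model(image, text, text_atts, idx=torch.arange(2), alpha=0.4, is_train=True)
            # inference: encode each modality, then score candidate pairs
            image_embed, image_feat = model(image=image, input_type="image", is_train=False)
            text_embed, text_feat = model(text=text, text_atts=text_atts, input_type="text", is_train=False)
            scores = model(
                image=image_embed, text=text_embed, text_atts=text_atts,
                input_type="multimodal", is_train=False,
            )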
"""
def __init__(
self,
model_with_similarity: ALBEFModelWithSimilarity,
itc_loss: ImageTextContrastiveLoss,
hidden_size: int,
) -> None:
super().__init__()
self.model_with_similarity = model_with_similarity
self.itc_loss = itc_loss
self.itm_head = nn.Linear(hidden_size, 2)
def _train_forward(
self,
image: Tensor,
text: Tensor,
text_atts: Tensor,
idx: Tensor,
alpha: float,
) -> Tensor:
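        """
        Forward step for training. Compute the image-text contrastive loss from the
        similarity outputs of ALBEFModelWithSimilarity, compute the image-text matching
        loss over the positive and hard-negative multimodal embeddings, and return
        their sum.
        """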
encoder_output = self.model_with_similarity(image, text, text_atts, idx)
# compute image-text contrastive loss
similarity_outputs = encoder_output.similarity
similarity_targets = encoder_output.sim_targets
itc_loss = self.itc_loss(
similarity_outputs.sim_i2t,
similarity_outputs.sim_t2i,
similarity_outputs.sim_i2t_m,
similarity_outputs.sim_t2i_m,
similarity_targets,
alpha,
)
# compute image-text matching loss
pos_embeddings = encoder_output.multimodal_embeddings[:, 0, :]
neg_embeddings = encoder_output.multimodal_embeddings_neg[:, 0, :]
vl_embeddings = torch.cat([pos_embeddings, neg_embeddings], dim=0)
vl_output = self.itm_head(vl_embeddings)
itm_labels = torch.cat(
[
torch.ones(pos_embeddings.size(0), dtype=torch.long),
torch.zeros(neg_embeddings.size(0), dtype=torch.long),
],
dim=0,
).to(vl_embeddings.device)
itm_loss = F.cross_entropy(vl_output, itm_labels)
loss = itc_loss + itm_loss
return loss
def _encode_image(
self,
image: Tensor,
) -> Tuple[Tensor, Tensor]:
image_embed = self.model_with_similarity.albef_model.vision_encoder(image)
image_feat = F.normalize(
self.model_with_similarity.vision_proj(image_embed[:, 0, :]), dim=-1
)
return image_embed, image_feat
def _encode_text(
self,
text: Tensor,
text_atts: Tensor,
) -> Tuple[Tensor, Tensor]:
text_embed = self.model_with_similarity.albef_model.text_encoder(
text, text_atts
).last_hidden_state
text_feat = F.normalize(
self.model_with_similarity.text_proj(text_embed[:, 0, :]), dim=-1
)
return text_embed, text_feat
def _image_text_matching_score(
self,
image: Tensor,
text: Tensor,
text_atts: Tensor,
) -> Tensor:
multimodal_embeds = self.model_with_similarity.albef_model.multimodal_encoder(
text,
text_atts,
image,
)
score = self.itm_head(multimodal_embeds[:, 0, :])[:, 1]
return score
def _eval_forward(
self,
input_type: str,
image: Optional[Tensor],
text: Optional[Tensor],
text_atts: Optional[Tensor],
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
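        """
        Forward step for evaluation, dispatched on ``input_type``: encode and project the
        image input ("image"), encode and project the text input ("text"), or fuse
        precomputed image and text embeddings and return image-text matching scores
        ("multimodal").
        """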
if input_type == "image":
assert image is not None, "image input tensor cannot be None"
return self._encode_image(image)
elif input_type == "text":
assert (
text is not None and text_atts is not None
), "text and text attention mask cannot be None"
return self._encode_text(text, text_atts)
elif input_type == "multimodal":
assert (
image is not None and text is not None and text_atts is not None
), "image embeddings, text embeddings, and text attention mask cannot be None"
return self._image_text_matching_score(image, text, text_atts)
else:
raise ValueError("input_type must be image, text, or multimodal")
def forward(
self,
image: Optional[Tensor] = None,
text: Optional[Tensor] = None,
text_atts: Optional[Tensor] = None,
idx: Optional[Tensor] = None,
        alpha: Optional[float] = 0.0,
input_type: Optional[str] = None,
is_train: Optional[bool] = True,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
if is_train:
return self._train_forward(
image,
text,
text_atts,
idx,
alpha,
)
else:
return self._eval_forward(
input_type,
image,
text,
text_atts,
)
def albef_model_for_vqa(
config: Dict[str, Any], pretrained: bool = False
) -> ALBEFModelForVQA:
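    """
    Build an ALBEFModelForVQA from a config dict with the keys ``vision_encoder_args``,
    ``text_encoder_args``, ``multimodal_encoder_args``, ``text_embeddings_args``, and
    ``prediction_head_args``, each holding the kwargs for the corresponding component.
    If ``pretrained`` is True, the released VQA checkpoint is downloaded and loaded.
    """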
vision_encoder = ALBEFVisionEncoder(**config["vision_encoder_args"])
text_encoder = bert_text_encoder(**config["text_encoder_args"])
question_multimodal_encoder = ALBEFMultimodalEncoder(
**config["multimodal_encoder_args"]
)
text_embeddings = BERTTextEmbeddings(**config["text_embeddings_args"])
answer_multimodal_encoder = ALBEFMultimodalEncoder(
**config["multimodal_encoder_args"]
)
prediction_head = PredictionHead(**config["prediction_head_args"])
albef_model = ALBEFModel(vision_encoder, text_encoder, question_multimodal_encoder)
decoder = ALBEFDecoder(text_embeddings, answer_multimodal_encoder, prediction_head)
loss = CausalLanguageModelingLoss()
model = ALBEFModelForVQA(albef_model, decoder, loss)
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
_ALBEF_PRETRAINED_URLS["vqa"], map_location="cpu"
)
model.load_state_dict(checkpoint)
return model
def albef_model_for_retrieval(
config: Dict[str, Any], pretrained: bool = False
) -> ALBEFModelForRetrieval:
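    """
    Build an ALBEFModelForRetrieval from a config dict with the keys
    ``vision_encoder_args``, ``text_encoder_args``, ``multimodal_encoder_args``,
    ``projection_args``, ``similarity_args``, and ``hidden_size``. If ``pretrained`` is
    True, the released retrieval checkpoint is downloaded and loaded.
    """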
vision_encoder = ALBEFVisionEncoder(**config["vision_encoder_args"])
text_encoder = bert_text_encoder(**config["text_encoder_args"])
multimodal_encoder = ALBEFMultimodalEncoder(**config["multimodal_encoder_args"])
vision_proj = nn.Linear(**config["projection_args"])
text_proj = nn.Linear(**config["projection_args"])
albef_model = ALBEFModel(vision_encoder, text_encoder, multimodal_encoder)
albef_model_with_sim = ALBEFModelWithSimilarity(
albef_model, vision_proj, text_proj, **config["similarity_args"]
)
itc_loss = ImageTextContrastiveLoss()
model = ALBEFModelForRetrieval(
albef_model_with_sim, itc_loss, config["hidden_size"]
)
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
_ALBEF_PRETRAINED_URLS["retrieval"], map_location="cpu"
)
model.load_state_dict(checkpoint)
return model