# virtex/models/captioning.py
import copy
import functools
from typing import Any, Dict
import torch
from torch import nn
from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.modules.label_smoothing import CrossEntropyLossWithLabelSmoothing
from virtex.modules.textual_heads import TextualHead
from virtex.modules.visual_backbones import VisualBackbone
class CaptioningModel(nn.Module):
r"""
    A model to perform image captioning (in both forward and backward directions
    independently, or only in the forward direction). It is composed of a
:class:`~virtex.modules.visual_backbones.VisualBackbone` and a
:class:`~virtex.modules.textual_heads.TextualHead` on top of it.
    During training, it maximizes the likelihood of the ground-truth caption
    conditioned on image features. During inference, it predicts a caption for
    an input image through beam search decoding or nucleus sampling.
Parameters
----------
visual: virtex.modules.visual_backbones.VisualBackbone
A :class:`~virtex.modules.visual_backbones.VisualBackbone` which
computes visual features from an input image.
textual: virtex.modules.textual_heads.TextualHead
A :class:`~virtex.modules.textual_heads.TextualHead` which
makes final predictions conditioned on visual features.
sos_index: int, optional (default = 1)
        The index of the start token (``[SOS]``) in the vocabulary.
    eos_index: int, optional (default = 2)
        The index of the end token (``[EOS]``) in the vocabulary.
caption_backward: bool, optional (default = False)
Whether to *also* perform captioning in backward direction. Default is
``False`` -- only forward captioning is performed. When ``True``, a
        clone of the textual head is created, which shares no weights with the
        "forward" head except the visual projection and input/output embeddings.
decoder: Any, optional (default = None)
An instance of :class:`~virtex.utils.beam_search.AutoRegressiveBeamSearch`
or :class:`~virtex.utils.nucleus_sampling.AutoRegressiveNucleusSampling`
for decoding captions during inference (unused during training).
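
    Examples
    --------
    A minimal sketch of a training-style forward pass. ``visual`` and
    ``textual`` are assumed to be concrete modules built elsewhere (for
    example through the factories in ``virtex.factories``), and the tensor
    names below are illustrative, not part of this class's API:

    .. code-block::

        model = CaptioningModel(visual, textual, caption_backward=True)
        batch = {
            "image": images,                   # (batch_size, 3, height, width)
            "caption_tokens": caption_tokens,  # (batch_size, max_caption_length)
            "noitpac_tokens": noitpac_tokens,  # reversed captions from the dataset
            "caption_lengths": caption_lengths,
        }
        output_dict = model(batch)
        output_dict["loss"].backward()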
"""
def __init__(
self,
visual: VisualBackbone,
textual: TextualHead,
caption_backward: bool = False,
sos_index: int = 1,
eos_index: int = 2,
label_smoothing: float = 0.0,
decoder: Any = None,
):
super().__init__()
self.visual = visual
self.textual = textual
self.padding_idx = self.textual.padding_idx
self.caption_backward = caption_backward
# Clone the textual module for backward direction if doing captioning
# in both directions (separately).
if self.caption_backward:
self.backward_textual = copy.deepcopy(self.textual)
# Share weights for visual projection, and input/output embeddings.
self.backward_textual.visual_projection = self.textual.visual_projection
self.backward_textual.embedding = self.textual.embedding
self.backward_textual.output = self.textual.output
# These boundary indices are needed for beam search.
self.sos_index = sos_index
self.eos_index = eos_index
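        # The decoder (beam search or nucleus sampling) is only used at
        # inference time; training never touches it.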
self.decoder = decoder
self.loss = CrossEntropyLossWithLabelSmoothing(
label_smoothing, ignore_index=self.padding_idx
)
def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:
r"""
        Given a batch of images and captions, compute the log-likelihood loss per
        caption token during training. During inference (with only images), predict
        a caption through either beam search decoding or nucleus sampling.
Parameters
----------
batch: Dict[str, torch.Tensor]
Training or inference batch. During training, a batch would at least
contain keys ``{"image", "caption_tokens", "caption_lengths"}`` and
also ``"noitpac_tokens"`` for bicaptioning.
During inference, a batch would contain key ``{"image"}`` and
optionally ``"decode_prompt"`` as a partial sequence for decoding.
Returns
-------
Dict[str, Any]
A dict with the following structure, containing loss for optimization,
loss components to log directly to tensorboard, and optionally
predictions.
.. code-block::
{
"loss": torch.Tensor,
"loss_components": {
"captioning_forward": torch.Tensor,
"captioning_backward": torch.Tensor, (optional)
},
"predictions": torch.Tensor
}
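
        Examples
        --------
        A hedged inference sketch (assumes a ``decoder`` was passed at
        construction time; variable names are illustrative):

        .. code-block::

            model.eval()
            with torch.no_grad():
                predictions = model({"image": images})["predictions"]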
"""
# shape: (batch_size, channels, height, width)
visual_features = self.visual(batch["image"])
batch_size = visual_features.size(0)
if "caption_tokens" in batch:
caption_tokens = batch["caption_tokens"]
caption_lengths = batch["caption_lengths"]
# shape: (batch_size, max_caption_length, vocab_size)
output_logits = self.textual(
visual_features, caption_tokens, caption_lengths
)
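            # Teacher forcing uses a one-token shift: logits at timestep ``t``
            # are scored against the ground-truth token at ``t + 1``, so drop
            # the last logit and the first (SOS) token before flattening.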
loss = self.loss(
output_logits[:, :-1].contiguous().view(-1, self.textual.vocab_size),
caption_tokens[:, 1:].contiguous().view(-1),
)
output_dict: Dict[str, Any] = {
"loss": loss,
# Single scalar per batch for logging in training script.
"loss_components": {"captioning_forward": loss.clone().detach()},
}
# Do captioning in backward direction if specified.
if self.caption_backward:
backward_caption_tokens = batch["noitpac_tokens"]
backward_output_logits = self.backward_textual(
visual_features, backward_caption_tokens, caption_lengths
)
backward_loss = self.loss(
backward_output_logits[:, :-1]
.contiguous()
.view(-1, self.textual.vocab_size),
backward_caption_tokens[:, 1:].contiguous().view(-1),
)
output_dict["loss"] += backward_loss
# Single scalar per batch for logging in training script.
output_dict["loss_components"].update(
captioning_backward=backward_loss.clone().detach()
)
if not self.training:
# During validation (while pretraining), get best prediction
# at every timestep.
output_dict["predictions"] = torch.argmax(output_logits, dim=-1)
else:
if self.decoder is None:
raise ValueError("Decoder for predicting captions is missing!")
# During inference, decode captions from forward transformer model.
# Check if the batch contains decoding prompt.
if "decode_prompt" in batch:
# shape: (batch_size, prompt_length)
start_predictions = torch.unsqueeze(batch["decode_prompt"], 0)
start_predictions = start_predictions.repeat(batch_size, 1)
else:
# shape: (batch_size, )
start_predictions = torch.full(
(batch_size,), self.sos_index, device=visual_features.device
).long()
# Add image features as a default argument to match callable
# signature accepted by beam search class (partial captions only).
decoding_step = functools.partial(self.decoding_step, visual_features)
predicted_caption, _ = self.decoder.search(
start_predictions, decoding_step
)
output_dict = {"predictions": predicted_caption}
return output_dict
def decoding_step(
self, visual_features: torch.Tensor, partial_captions: torch.Tensor
) -> torch.Tensor:
r"""
Given visual features and a batch of (assumed) partial captions, predict
the logits over output vocabulary tokens for next timestep. This method
is used by :class:`~virtex.utils.beam_search.AutoRegressiveBeamSearch`
and :class:`~virtex.utils.nucleus_sampling.AutoRegressiveNucleusSampling`.
.. note::
For nucleus sampling, ``beam_size`` will always be 1 (not relevant).
Parameters
----------
        visual_features: torch.Tensor
            A tensor of shape ``(batch_size, channels, height, width)`` containing
            features from the visual backbone (projected inside the textual head).
partial_captions: torch.Tensor
A tensor of shape ``(batch_size * beam_size, timesteps)``
containing tokens predicted so far -- one for each beam. We need all
prior predictions because our model is auto-regressive.
Returns
-------
torch.Tensor
A tensor of shape ``(batch_size * beam_size, vocab_size)`` -- logits
over output vocabulary tokens for next timestep.
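
        Examples
        --------
        This method is not called directly; it is wrapped and handed to the
        decoder, mirroring :meth:`forward`:

        .. code-block::

            decoding_step = functools.partial(self.decoding_step, visual_features)
            predictions, _ = self.decoder.search(start_predictions, decoding_step)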
"""
# Expand and repeat image features while doing beam search.
batch_size, channels, height, width = visual_features.size()
beam_size = int(partial_captions.size(0) / batch_size)
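        # The decoder flattens hypotheses as (batch_size * beam_size, ...), so
        # image features must be tiled per beam to stay aligned with
        # ``partial_captions`` below.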
if beam_size > 1:
# shape: (batch_size * beam_size, channels, height, width)
visual_features = visual_features.unsqueeze(1).repeat(1, beam_size, 1, 1, 1)
visual_features = visual_features.view(
batch_size * beam_size, channels, height, width
)
# Provide caption lengths as current length (irrespective of predicted
# EOS/padding tokens). shape: (batch_size, )
caption_lengths = torch.ones_like(partial_captions)
if len(caption_lengths.size()) == 2:
caption_lengths = caption_lengths.sum(1)
else:
# Add a timestep. shape: (batch_size, 1)
partial_captions = partial_captions.unsqueeze(1)
# shape: (batch_size * beam_size, partial_caption_length, vocab_size)
logits = self.textual(visual_features, partial_captions, caption_lengths)
# Return logits from the last timestep.
return logits[:, -1, :]
def log_predictions(
self, batch: Dict[str, torch.Tensor], tokenizer: SentencePieceBPETokenizer
) -> str:
self.eval()
with torch.no_grad():
predictions = self.forward(batch)["predictions"]
self.train()
predictions_str = ""
for tokens, preds in zip(batch["caption_tokens"], predictions):
predictions_str += f"""
                Caption tokens : {tokenizer.decode(tokens.tolist())}
                Predictions (f): {tokenizer.decode(preds.tolist())}
"""
return predictions_str
class ForwardCaptioningModel(CaptioningModel):
r"""
Convenient extension of :class:`~virtex.models.captioning.CaptioningModel`
    for better readability: this passes ``caption_backward=False`` to the superclass.
"""
def __init__(
self,
visual: VisualBackbone,
textual: TextualHead,
sos_index: int = 1,
eos_index: int = 2,
label_smoothing: float = 0.0,
decoder: Any = None,
):
super().__init__(
visual,
textual,
sos_index=sos_index,
eos_index=eos_index,
caption_backward=False,
label_smoothing=label_smoothing,
decoder=decoder,
)
class BidirectionalCaptioningModel(CaptioningModel):
r"""
Convenient extension of :class:`~virtex.models.captioning.CaptioningModel`
    for better readability: this passes ``caption_backward=True`` to the superclass.
"""
def __init__(
self,
visual: VisualBackbone,
textual: TextualHead,
sos_index: int = 1,
eos_index: int = 2,
label_smoothing: float = 0.0,
decoder: Any = None,
):
super().__init__(
visual,
textual,
sos_index=sos_index,
eos_index=eos_index,
caption_backward=True,
label_smoothing=label_smoothing,
decoder=decoder,
)
# Convenient handle for our main model.
VirTexModel = BidirectionalCaptioningModel