import copy
import functools
from typing import Any, Dict

import torch
from torch import nn

from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.modules.label_smoothing import CrossEntropyLossWithLabelSmoothing
from virtex.modules.textual_heads import TextualHead
from virtex.modules.visual_backbones import VisualBackbone

class CaptioningModel(nn.Module):
    r"""
    A model to perform image captioning (in both forward and backward
    directions independently during training; only the forward direction is
    used during inference). It is composed of a
    :class:`~virtex.modules.visual_backbones.VisualBackbone` and a
    :class:`~virtex.modules.textual_heads.TextualHead` on top of it.

    During training, it maximizes the likelihood of the ground truth caption
    conditioned on image features. During inference, it predicts a caption for
    an input image through beam search decoding.

    Parameters
    ----------
    visual: virtex.modules.visual_backbones.VisualBackbone
        A :class:`~virtex.modules.visual_backbones.VisualBackbone` which
        computes visual features from an input image.
    textual: virtex.modules.textual_heads.TextualHead
        A :class:`~virtex.modules.textual_heads.TextualHead` which
        makes final predictions conditioned on visual features.
    sos_index: int, optional (default = 1)
        The index of the start token (``[SOS]``) in vocabulary.
    eos_index: int, optional (default = 2)
        The index of the end token (``[EOS]``) in vocabulary.
    caption_backward: bool, optional (default = False)
        Whether to *also* perform captioning in backward direction. Default is
        ``False`` -- only forward captioning is performed. When ``True``, a
        clone of the textual head is created, which does not share weights with
        the "forward" model except the visual projection and input/output
        embeddings.
    label_smoothing: float, optional (default = 0.0)
        Label smoothing factor for the captioning cross entropy loss.
    decoder: Any, optional (default = None)
        An instance of :class:`~virtex.utils.beam_search.AutoRegressiveBeamSearch`
        or :class:`~virtex.utils.nucleus_sampling.AutoRegressiveNucleusSampling`
        for decoding captions during inference (unused during training).
    """
    def __init__(
        self,
        visual: VisualBackbone,
        textual: TextualHead,
        caption_backward: bool = False,
        sos_index: int = 1,
        eos_index: int = 2,
        label_smoothing: float = 0.0,
        decoder: Any = None,
    ):
        super().__init__()
        self.visual = visual
        self.textual = textual
        self.padding_idx = self.textual.padding_idx
        self.caption_backward = caption_backward

        # Clone the textual module for backward direction if doing captioning
        # in both directions (separately).
        if self.caption_backward:
            self.backward_textual = copy.deepcopy(self.textual)

            # Share weights for visual projection, and input/output embeddings.
            self.backward_textual.visual_projection = self.textual.visual_projection
            self.backward_textual.embedding = self.textual.embedding
            self.backward_textual.output = self.textual.output

        # These boundary indices are needed for beam search.
        self.sos_index = sos_index
        self.eos_index = eos_index
        self.decoder = decoder
        self.loss = CrossEntropyLossWithLabelSmoothing(
            label_smoothing, ignore_index=self.padding_idx
        )
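        # Note on label smoothing: with a smoothing factor ``eps``, a standard
        # formulation keeps ``1 - eps`` of the target probability on the
        # ground-truth token and spreads ``eps`` over the rest of the
        # vocabulary; see ``virtex.modules.label_smoothing`` for the exact
        # variant implemented by ``CrossEntropyLossWithLabelSmoothing``.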
    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        r"""
        Given a batch of images and captions, compute log likelihood loss per
        caption token during training. During inference (with images only),
        predict a caption through either beam search decoding or nucleus
        sampling.

        Parameters
        ----------
        batch: Dict[str, torch.Tensor]
            Training or inference batch. During training, a batch would at least
            contain keys ``{"image", "caption_tokens", "caption_lengths"}`` and
            also ``"noitpac_tokens"`` for bicaptioning.
            During inference, a batch would contain key ``{"image"}`` and
            optionally ``"decode_prompt"`` as a partial sequence for decoding.

        Returns
        -------
        Dict[str, Any]
            A dict with the following structure, containing loss for optimization,
            loss components to log directly to tensorboard, and optionally
            predictions.

            .. code-block::

                {
                    "loss": torch.Tensor,
                    "loss_components": {
                        "captioning_forward": torch.Tensor,
                        "captioning_backward": torch.Tensor, (optional)
                    },
                    "predictions": torch.Tensor
                }
        """
        # shape: (batch_size, channels, height, width)
        visual_features = self.visual(batch["image"])
        batch_size = visual_features.size(0)

        if "caption_tokens" in batch:
            caption_tokens = batch["caption_tokens"]
            caption_lengths = batch["caption_lengths"]

            # shape: (batch_size, max_caption_length, vocab_size)
            output_logits = self.textual(
                visual_features, caption_tokens, caption_lengths
            )
            loss = self.loss(
                output_logits[:, :-1].contiguous().view(-1, self.textual.vocab_size),
                caption_tokens[:, 1:].contiguous().view(-1),
            )
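            # The slicing above applies the usual one-step shift for teacher
            # forcing: logits at timestep ``t`` are scored against the ground
            # truth token at timestep ``t + 1``, and positions equal to
            # ``padding_idx`` are ignored by the loss.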
            output_dict: Dict[str, Any] = {
                "loss": loss,
                # Single scalar per batch for logging in training script.
                "loss_components": {"captioning_forward": loss.clone().detach()},
            }
            # Do captioning in backward direction if specified.
            if self.caption_backward:
                backward_caption_tokens = batch["noitpac_tokens"]

                backward_output_logits = self.backward_textual(
                    visual_features, backward_caption_tokens, caption_lengths
                )
                backward_loss = self.loss(
                    backward_output_logits[:, :-1]
                    .contiguous()
                    .view(-1, self.textual.vocab_size),
                    backward_caption_tokens[:, 1:].contiguous().view(-1),
                )
                output_dict["loss"] += backward_loss

                # Single scalar per batch for logging in training script.
                output_dict["loss_components"].update(
                    captioning_backward=backward_loss.clone().detach()
                )

            if not self.training:
                # During validation (while pretraining), get best prediction
                # at every timestep.
                output_dict["predictions"] = torch.argmax(output_logits, dim=-1)
        else:
            if self.decoder is None:
                raise ValueError("Decoder for predicting captions is missing!")

            # During inference, decode captions from forward transformer model.
            # Check if the batch contains a decoding prompt.
            if "decode_prompt" in batch:
                # shape: (batch_size, prompt_length)
                start_predictions = torch.unsqueeze(batch["decode_prompt"], 0)
                start_predictions = start_predictions.repeat(batch_size, 1)
            else:
                # shape: (batch_size, )
                start_predictions = torch.full(
                    (batch_size,), self.sos_index, device=visual_features.device
                ).long()

            # Add image features as a default argument to match callable
            # signature accepted by beam search class (partial captions only).
            decoding_step = functools.partial(self.decoding_step, visual_features)

            predicted_caption, _ = self.decoder.search(
                start_predictions, decoding_step
            )
            output_dict = {"predictions": predicted_caption}

        return output_dict
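    # Inference note: when ``"decode_prompt"`` is present, the same 1-D prompt
    # tensor (a partial caption such as ``[SOS, w1, w2]``) is repeated across
    # the batch, so every image in the batch is decoded from an identical
    # prefix. Without a prompt, decoding starts from a single ``[SOS]`` token.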
    def decoding_step(
        self, visual_features: torch.Tensor, partial_captions: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Given visual features and a batch of (assumed) partial captions, predict
        the logits over output vocabulary tokens for the next timestep. This
        method is used by :class:`~virtex.utils.beam_search.AutoRegressiveBeamSearch`
        and :class:`~virtex.utils.nucleus_sampling.AutoRegressiveNucleusSampling`.

        .. note::

            For nucleus sampling, ``beam_size`` will always be 1 (not relevant).

        Parameters
        ----------
        visual_features: torch.Tensor
            A tensor of shape ``(batch_size, channels, height, width)``
            containing visual features computed by the visual backbone.
        partial_captions: torch.Tensor
            A tensor of shape ``(batch_size * beam_size, timesteps)``
            containing tokens predicted so far -- one for each beam. We need all
            prior predictions because our model is auto-regressive.

        Returns
        -------
        torch.Tensor
            A tensor of shape ``(batch_size * beam_size, vocab_size)`` -- logits
            over output vocabulary tokens for the next timestep.
        """
        # Expand and repeat image features while doing beam search.
        batch_size, channels, height, width = visual_features.size()
        beam_size = int(partial_captions.size(0) / batch_size)
        if beam_size > 1:
            # shape: (batch_size * beam_size, channels, height, width)
            visual_features = visual_features.unsqueeze(1).repeat(1, beam_size, 1, 1, 1)
            visual_features = visual_features.view(
                batch_size * beam_size, channels, height, width
            )

        # Provide caption lengths as current length (irrespective of predicted
        # EOS/padding tokens). shape: (batch_size, )
        caption_lengths = torch.ones_like(partial_captions)
        if len(caption_lengths.size()) == 2:
            caption_lengths = caption_lengths.sum(1)
        else:
            # Add a timestep. shape: (batch_size, 1)
            partial_captions = partial_captions.unsqueeze(1)

        # shape: (batch_size * beam_size, partial_caption_length, vocab_size)
        logits = self.textual(visual_features, partial_captions, caption_lengths)

        # Return logits from the last timestep.
        return logits[:, -1, :]
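    # Shape walk-through with illustrative numbers (assuming batch_size = 2 and
    # beam_size = 5): ``partial_captions`` arrives as (10, t), so backbone
    # features of shape (2, C, H, W) are unsqueezed to (2, 1, C, H, W),
    # repeated to (2, 5, C, H, W), and flattened to (10, C, H, W) so that each
    # beam hypothesis is paired with its source image.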
    def log_predictions(
        self, batch: Dict[str, torch.Tensor], tokenizer: SentencePieceBPETokenizer
    ) -> str:

        self.eval()
        with torch.no_grad():
            predictions = self.forward(batch)["predictions"]
        self.train()

        predictions_str = ""
        for tokens, preds in zip(batch["caption_tokens"], predictions):
            # Decode token indices back to caption strings for readable logs.
            predictions_str += f"""
                Caption tokens : {tokenizer.decode(tokens.tolist())}
                Predictions (f): {tokenizer.decode(preds.tolist())}
                """
        return predictions_str
class ForwardCaptioningModel(CaptioningModel):
    r"""
    Convenient extension of :class:`~virtex.models.captioning.CaptioningModel`
    for better readability: this passes ``caption_backward=False`` to super class.
    """

    def __init__(
        self,
        visual: VisualBackbone,
        textual: TextualHead,
        sos_index: int = 1,
        eos_index: int = 2,
        label_smoothing: float = 0.0,
        decoder: Any = None,
    ):
        super().__init__(
            visual,
            textual,
            sos_index=sos_index,
            eos_index=eos_index,
            caption_backward=False,
            label_smoothing=label_smoothing,
            decoder=decoder,
        )
class BidirectionalCaptioningModel(CaptioningModel):
    r"""
    Convenient extension of :class:`~virtex.models.captioning.CaptioningModel`
    for better readability: this passes ``caption_backward=True`` to super class.
    """

    def __init__(
        self,
        visual: VisualBackbone,
        textual: TextualHead,
        sos_index: int = 1,
        eos_index: int = 2,
        label_smoothing: float = 0.0,
        decoder: Any = None,
    ):
        super().__init__(
            visual,
            textual,
            sos_index=sos_index,
            eos_index=eos_index,
            caption_backward=True,
            label_smoothing=label_smoothing,
            decoder=decoder,
        )


# Convenient handle for our main model.
VirTexModel = BidirectionalCaptioningModel
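

# Minimal usage sketch (a hedged illustration, not a definitive pipeline): it
# only shows the batch layout that ``CaptioningModel.forward`` expects during
# training. Constructing a real ``VisualBackbone`` / ``TextualHead`` depends on
# config machinery defined elsewhere in the repository, so model instantiation
# is left as a placeholder comment; the image size, vocabulary range, and max
# caption length below are illustrative assumptions.
if __name__ == "__main__":
    batch_size, max_caption_length = 4, 30

    dummy_batch = {
        # RGB image crops as produced by the data pipeline (assumed 224x224).
        "image": torch.randn(batch_size, 3, 224, 224),
        # Token indices padded to a fixed length; 1 = [SOS], 2 = [EOS] by default.
        "caption_tokens": torch.randint(0, 100, (batch_size, max_caption_length)),
        # Reversed caption tokens ("caption" -> "noitpac"), used for bicaptioning.
        "noitpac_tokens": torch.randint(0, 100, (batch_size, max_caption_length)),
        "caption_lengths": torch.full((batch_size,), max_caption_length).long(),
    }
    # model = VirTexModel(visual, textual, label_smoothing=0.1, decoder=beam_search)
    # output = model(dummy_batch)
    # output["loss"]            -> scalar tensor to backpropagate
    # output["loss_components"] -> {"captioning_forward": ..., "captioning_backward": ...}
    print({key: tuple(value.shape) for key, value in dummy_batch.items()})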