from typing import Any, Dict

import torch
from torch import nn

from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.modules.textual_heads import TextualHead
from virtex.modules.visual_backbones import VisualBackbone


class MaskedLMModel(nn.Module):
    r"""
    A model to perform BERT-like masked language modeling. It is composed of a
    :class:`~virtex.modules.visual_backbones.VisualBackbone` and a
    :class:`~virtex.modules.textual_heads.TextualHead` on top of it.

    During training, the model receives caption tokens with certain tokens
    replaced by the ``[MASK]`` token, and it predicts these masked tokens based
    on the surrounding context.

    Parameters
    ----------
    visual: virtex.modules.visual_backbones.VisualBackbone
        A :class:`~virtex.modules.visual_backbones.VisualBackbone` which
        computes visual features from an input image.
    textual: virtex.modules.textual_heads.TextualHead
        A :class:`~virtex.modules.textual_heads.TextualHead` which
        makes final predictions conditioned on visual features.
    """

    def __init__(self, visual: VisualBackbone, textual: TextualHead):
        super().__init__()
        self.visual = visual
        self.textual = textual

        self.padding_idx = self.textual.padding_idx
        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx)

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        r"""
        Given a batch of images and captions with certain masked tokens,
        predict the tokens at the masked positions.

        Parameters
        ----------
        batch: Dict[str, torch.Tensor]
            A batch of images, ground truth caption tokens and masked labels.
            Possible set of keys: ``{"image_id", "image", "caption_tokens",
            "masked_labels", "caption_lengths"}``.

        Returns
        -------
        Dict[str, Any]
            A dict with the following structure, containing loss for optimization,
            loss components to log directly to tensorboard, and optionally
            predictions.

            .. code-block::

                {
                    "loss": torch.Tensor,
                    "loss_components": {"masked_lm": torch.Tensor},
                    "predictions": torch.Tensor
                }
        """
        # shape: (batch_size, channels, height, width)
        visual_features = self.visual(batch["image"])

        caption_tokens = batch["caption_tokens"]
        caption_lengths = batch["caption_lengths"]
        masked_labels = batch["masked_labels"]

        # shape: (batch_size, num_caption_tokens, vocab_size)
        output_logits = self.textual(visual_features, caption_tokens, caption_lengths)

        output_dict: Dict[str, Any] = {
            "loss": self.loss(
                output_logits.view(-1, output_logits.size(-1)), masked_labels.view(-1)
            )
        }
        # Single scalar per batch for logging in training script.
        output_dict["loss_components"] = {
            "masked_lm": output_dict["loss"].clone().detach()
        }
        # During evaluation, get predictions from logits. Useful for logging.
        # Only the predictions at [MASK]ed positions are relevant.
        if not self.training:
            predictions = torch.argmax(output_logits, dim=-1)
            redundant_positions = masked_labels == self.padding_idx
            predictions[redundant_positions] = self.padding_idx
            output_dict["predictions"] = predictions

        return output_dict

    def log_predictions(
        self, batch: Dict[str, torch.Tensor], tokenizer: SentencePieceBPETokenizer
    ) -> str:
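        r"""
        Decode caption tokens, masked labels and model predictions for a batch
        into a human-readable string, useful for qualitative logging during
        training.
        """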
        self.eval()
        with torch.no_grad():
            predictions = self.forward(batch)["predictions"]
        self.train()

        predictions_str = ""
        for tokens, labels, preds in zip(
            batch["caption_tokens"], batch["masked_labels"], predictions
        ):
            predictions_str += f"""
                Caption tokens : {tokenizer.decode(tokens.tolist())}
                Masked Labels  : {tokenizer.decode(labels.tolist())}
                Predictions    : {tokenizer.decode(preds.tolist())}
                """

        return predictions_str
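

# --------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hedged example
# of the forward() contract documented above. The stub modules below are
# hypothetical stand-ins for illustration only; in VirTex, real
# ``VisualBackbone`` and ``TextualHead`` instances are built elsewhere.
# --------------------------------------------------------------------------
if __name__ == "__main__":

    class _StubVisual(nn.Module):
        # Mimics a CNN backbone: image -> (batch_size, channels, height, width).
        def forward(self, image):
            return torch.randn(image.size(0), 2048, 7, 7)

    class _StubTextual(nn.Module):
        # Mimics a textual head: (features, tokens, lengths) -> logits of shape
        # (batch_size, num_caption_tokens, vocab_size). Unlike a real head, it
        # ignores the visual features entirely.
        def __init__(self, vocab_size: int = 100, padding_idx: int = 0):
            super().__init__()
            self.padding_idx = padding_idx
            self.embed = nn.Embedding(vocab_size, 32, padding_idx=padding_idx)
            self.proj = nn.Linear(32, vocab_size)

        def forward(self, visual_features, caption_tokens, caption_lengths):
            return self.proj(self.embed(caption_tokens))

    model = MaskedLMModel(_StubVisual(), _StubTextual())

    batch = {
        "image": torch.randn(2, 3, 224, 224),
        "caption_tokens": torch.randint(1, 100, (2, 15)),
        # Labels hold padding_idx (0) everywhere except [MASK]ed positions.
        "masked_labels": torch.zeros(2, 15, dtype=torch.long),
        "caption_lengths": torch.tensor([15, 15]),
    }
    batch["masked_labels"][:, 3] = batch["caption_tokens"][:, 3]

    output = model(batch)  # training mode: "loss" and "loss_components" only
    print(output["loss"].item())

    model.eval()
    output = model(batch)  # eval mode: also returns "predictions"
    print(output["predictions"].shape)  # torch.Size([2, 15])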