from typing import Any, Dict, List

import torch
from torch import nn
from torch.nn import functional as F

from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.modules.textual_heads import TextualHead
from virtex.modules.visual_backbones import VisualBackbone


class ClassificationModel(nn.Module):
    r"""
    A model to perform classification (generally, with multiple targets). It is
    composed of a :class:`~virtex.modules.visual_backbones.VisualBackbone` and a
    :class:`~virtex.modules.textual_heads.TextualHead` on top of it.

    .. note::

        Of the currently available textual heads, only one is supported here:
        :class:`~virtex.modules.textual_heads.LinearTextualHead`.

    During training, this model minimizes the KL-divergence loss with a K-hot
    vector, with values ``1/K``, where ``K`` is the number of unique labels to
    classify.

    Parameters
    ----------
    visual: virtex.modules.visual_backbones.VisualBackbone
        A :class:`~virtex.modules.visual_backbones.VisualBackbone` which
        computes visual features from an input image.
    textual: virtex.modules.textual_heads.TextualHead
        A :class:`~virtex.modules.textual_heads.TextualHead` which
        makes final predictions conditioned on visual features.
    ignore_indices: List[int]
        Ignore a set of token indices while computing KL-divergence loss. These
        are usually the special tokens such as ``[SOS]``, ``[EOS]`` etc.
    """

    def __init__(
        self, visual: VisualBackbone, textual: TextualHead, ignore_indices: List[int]
    ):
        super().__init__()
        self.visual = visual
        self.textual = textual
        self.ignore_indices = ignore_indices

    def forward(self, batch: Dict[str, torch.Tensor]):
        r"""
        Given a batch of images and a set of labels, perform classification with
        multiple targets by minimizing a KL-divergence loss.

        Parameters
        ----------
        batch: Dict[str, torch.Tensor]
            A batch of images and labels. Possible set of keys:
            ``{"image_id", "image", "labels"}``

        Returns
        -------
        Dict[str, Any]
            A dict with the following structure, containing loss for optimization,
            loss components to log directly to tensorboard, and optionally
            predictions.

            .. code-block::

                {
                    "loss": torch.Tensor,
                    "loss_components": {
                        "classification": torch.Tensor,
                    },
                    "predictions": torch.Tensor
                }
        """
        # shape: (batch_size, visual_feature_size, ...)
        visual_features = self.visual(batch["image"])
        batch_size = visual_features.size(0)

        # Get logits and further log-probabilities.
        # shape: (batch_size, vocab_size)
        logits = self.textual(visual_features)
        logprobs = F.log_softmax(logits, dim=1)

        # Average log-probs per unique token in the associated caption to compute
        # loss. This is simply cross-entropy with the target vector as a K-hot
        # vector. Do it in a for-loop; there isn't a straightforward vectorized way.
        loss = torch.tensor(0.0, device=logprobs.device)
        for index in range(batch_size):
            # Get unique labels for this particular instance.
            unique_labels = batch["labels"][index].unique()

            # Ignore indices of special tokens such as [SOS], [EOS] etc. and
            # any other tokens specified.
            unique_labels = [l for l in unique_labels if l not in self.ignore_indices]

            # Get log-probabilities corresponding to these tokens.
            instance_logprobs = logprobs[index, unique_labels].mean()

            # Accumulate negative log-probability for this instance in the loss.
            loss = loss - instance_logprobs

        # Average loss across instances.
        output_dict: Dict[str, Any] = {"loss": loss / batch_size}

        # Single scalar per batch for logging to tensorboard in the training script.
        output_dict["loss_components"] = {
            "classification": loss.clone().detach() / batch_size
        }

        # Return top-10 tokens according to log-probabilities during validation.
        # Useful for logging.
        if not self.training:
            top_logprobs, top_tokens = logprobs.topk(k=10, dim=1)
            output_dict["predictions"] = top_tokens

        return output_dict
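
# A minimal usage sketch for the class above (variable names here are assumed;
# in practice the backbone, head, and batches come from the experiment config
# and dataloaders):
#
#   >>> model = ClassificationModel(visual_backbone, linear_textual_head,
#   ...                             ignore_indices=[0, 1, 2])
#   >>> output = model({"image": images, "labels": labels})
#   >>> output["loss"].backward()
#   >>> output["loss_components"]["classification"]  # scalar for tensorboard
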

class TokenClassificationModel(ClassificationModel):
    r"""
    Convenient extension of :class:`~virtex.models.classification.ClassificationModel`
    for better readability (this only modifies the tensorboard logging logic).

    Ground truth targets here are a set of unique caption tokens (ignoring the
    special tokens like ``[SOS]``, ``[EOS]`` etc.).
    """

    def log_predictions(
        self, batch: Dict[str, torch.Tensor], tokenizer: SentencePieceBPETokenizer
    ) -> str:

        self.eval()
        with torch.no_grad():
            predictions = self.forward(batch)["predictions"]
        self.train()

        predictions_str = ""
        for tokens, preds in zip(batch["caption_tokens"], predictions):
            # Predictions here are individual tokens, and do not have any order
            # like captions, so decode them separately so we don't strip off
            # the metaspace character and special tokens, if any.
            preds = [tokenizer.id_to_token(p) for p in preds.tolist()]
            predictions_str += f"""
                Caption tokens : {tokenizer.decode(tokens.tolist())}
                Predictions (f): {" ".join(preds)}
                """

        return predictions_str
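
# Usage sketch for the logging helper above (assumed names; the training script
# owns the tokenizer, the validation batch, and the tensorboard writer):
#
#   >>> text = model.log_predictions(val_batch, tokenizer)
#   >>> summary_writer.add_text("val/predictions", text, global_step=iteration)
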

class MultiLabelClassificationModel(ClassificationModel):
    r"""
    Convenient extension of :class:`~virtex.models.classification.ClassificationModel`
    for better readability (this only modifies the tensorboard logging logic).

    Ground truth targets here are a set of unique instances in images (ignoring
    the special background token, category id = 0 in COCO).
    """

    def log_predictions(
        self,
        batch: Dict[str, torch.Tensor],
        tokenizer: SentencePieceBPETokenizer = None,
    ) -> str:

        # We accept `tokenizer` for a consistent API but don't use it here.
        self.eval()
        with torch.no_grad():
            predictions = self.forward(batch)["predictions"]
        self.train()

        predictions_str = ""
        for tokens, preds in zip(batch["caption_tokens"], predictions):
            # Predictions here are COCO category IDs; let them be as is.
            # Sort the ground truth and remove background tokens.
            tokens = sorted([t for t in tokens.tolist() if t != 0])
            preds = sorted(preds.tolist()[: len(tokens)])

            predictions_str += f"""
                COCO Instance IDs (GT)   : {tokens}
                COCO Instance IDs (Pred) : {preds}
                """

        return predictions_str
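

if __name__ == "__main__":
    # Self-contained sanity check of the loss formulation used in
    # ``ClassificationModel.forward``: the mean negative log-probability over K
    # unique labels equals the KL-divergence against a K-hot target with values
    # 1/K, up to the constant entropy term log(K). Uses random tensors only; the
    # numbers below are arbitrary.
    vocab_size = 10
    logits = torch.randn(1, vocab_size)
    labels = torch.tensor([2, 5, 7])  # K = 3 unique labels

    logprobs = F.log_softmax(logits, dim=1)
    loss_mean_nll = -logprobs[0, labels].mean()

    # Build the K-hot target vector with values 1/K at the label positions.
    khot = torch.zeros(1, vocab_size)
    khot[0, labels] = 1.0 / len(labels)
    loss_kl = F.kl_div(logprobs, khot, reduction="sum")

    # The two quantities differ exactly by the target entropy, log(K).
    log_k = torch.log(torch.tensor(float(len(labels))))
    assert torch.allclose(loss_mean_nll, loss_kl + log_k, atol=1e-5)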