Spaces:
Runtime error
Runtime error
File size: 7,034 Bytes
a5f8a35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
from typing import Any, Dict, List
import torch
from torch import nn
from torch.nn import functional as F
from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.modules.textual_heads import TextualHead
from virtex.modules.visual_backbones import VisualBackbone
class ClassificationModel(nn.Module):
    r"""
    A model to perform classification (generally, with multiple targets). It is
    composed of a :class:`~virtex.modules.visual_backbones.VisualBackbone` and a
    :class:`~virtex.modules.textual_heads.TextualHead` on top of it.

    .. note::

        As with currently available textual heads, only one textual head is
        supported here: :class:`~virtex.modules.textual_heads.LinearTextualHead`.

    During training, it minimizes the KL-divergence loss with a K-hot vector,
    with values ``1/K``, where K are the number of unique labels to classify.

    Parameters
    ----------
    visual: virtex.modules.visual_backbones.VisualBackbone
        A :class:`~virtex.modules.visual_backbones.VisualBackbone` which
        computes visual features from an input image.
    textual: virtex.modules.textual_heads.TextualHead
        A :class:`~virtex.modules.textual_heads.TextualHead` which
        makes final predictions conditioned on visual features.
    ignore_indices: List[int]
        Ignore a set of token indices while computing KL-divergence loss. These
        are usually the special tokens such as ``[SOS]``, ``[EOS]`` etc.
    """

    def __init__(
        self, visual: VisualBackbone, textual: TextualHead, ignore_indices: List[int]
    ):
        super().__init__()
        self.visual = visual
        self.textual = textual
        self.ignore_indices = ignore_indices

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        r"""
        Given a batch of images and set of labels, perform classification with
        multiple targets by minimizing a KL-divergence loss.

        Parameters
        ----------
        batch: Dict[str, torch.Tensor]
            A batch of images and labels. Possible set of keys:
            ``{"image_id", "image", "labels"}``

        Returns
        -------
        Dict[str, Any]
            A dict with the following structure, containing loss for optimization,
            loss components to log directly to tensorboard, and optionally
            predictions.

            .. code-block::

                {
                    "loss": torch.Tensor,
                    "loss_components": {
                        "classification": torch.Tensor,
                    },
                    "predictions": torch.Tensor
                }
        """
        # shape: (batch_size, visual_feature_size, ...)
        visual_features = self.visual(batch["image"])
        batch_size = visual_features.size(0)

        # Get logits and further log-probabilities.
        # shape: (batch_size, vocab_size)
        logits = self.textual(visual_features)
        logprobs = F.log_softmax(logits, dim=1)

        # Average log-probs per unique token in associated caption to compute
        # loss. This is simply cross-entropy with target-vector as a K-hot
        # vector. Do in a for-loop, there isn't a straightforward vectorized way.
        loss = torch.tensor(0.0, device=logprobs.device)
        for index in range(batch_size):
            # Get unique labels for particular instance.
            unique_labels = batch["labels"][index].unique()

            # Ignore indices of special tokens such as [SOS], [EOS] etc. and
            # any other token specified.
            kept_labels = [l for l in unique_labels if l not in self.ignore_indices]

            # Skip instances whose labels are ALL ignored: averaging over an
            # empty index would yield NaN and poison the whole batch loss.
            if not kept_labels:
                continue

            # Get log-probabilities corresponding to these tokens. Stack the
            # scalar tensors into a proper 1-D index tensor instead of passing
            # a Python list of 0-d tensors (fragile across torch versions).
            instance_logprobs = logprobs[index, torch.stack(kept_labels)].mean()

            # Accumulate negative log-probability for this instance in loss.
            loss = loss - instance_logprobs

        # Average loss across instances.
        output_dict: Dict[str, Any] = {"loss": loss / batch_size}

        # Single scalar per batch for logging to tensorboard in training script.
        # `detach()` is sufficient here; no clone needed for a logged scalar.
        output_dict["loss_components"] = {
            "classification": loss.detach() / batch_size
        }

        # Return top-10 tokens according to log-probabilities during validation.
        # Useful for logging. Guard `k` so a vocabulary smaller than 10 does
        # not make `topk` raise.
        if not self.training:
            top_k = min(10, logprobs.size(1))
            output_dict["predictions"] = logprobs.topk(k=top_k, dim=1).indices

        return output_dict
class TokenClassificationModel(ClassificationModel):
    r"""
    Convenient extension of :class:`~virtex.models.classification.ClassificationModel`
    for better readability (this only modifies the tensorboard logging logic).

    Ground truth targets here are a set of unique caption tokens (ignoring the
    special tokens like ``[SOS]``, ``[EOS]`` etc.).
    """

    def log_predictions(
        self, batch: Dict[str, torch.Tensor], tokenizer: SentencePieceBPETokenizer
    ) -> str:
        r"""Format ground-truth captions and predicted tokens for logging."""
        # Run one inference pass without gradients, then restore train mode.
        self.eval()
        with torch.no_grad():
            predicted_tokens = self.forward(batch)["predictions"]
        self.train()

        formatted_chunks = []
        for caption, prediction_row in zip(batch["caption_tokens"], predicted_tokens):
            # Predictions here are individual tokens, and do not have any order
            # like captions, so decode them separately so we don't strip off
            # metaspace character and special tokens if any.
            decoded = [tokenizer.id_to_token(token_id) for token_id in prediction_row.tolist()]
            formatted_chunks.append(
                f"""
                Caption tokens : {tokenizer.decode(caption.tolist())}
                Predictions (f): {" ".join(decoded)}
                """
            )
        return "".join(formatted_chunks)
class MultiLabelClassificationModel(ClassificationModel):
    r"""
    Convenient extension of :class:`~virtex.models.classification.ClassificationModel`
    for better readability (this only modifies the tensorboard logging logic).

    Ground truth targets here are a set of unique instances in images (ignoring
    the special background token, category id = 0 in COCO).
    """

    def log_predictions(
        self,
        batch: Dict[str, torch.Tensor],
        tokenizer: SentencePieceBPETokenizer = None,
    ) -> str:
        r"""Format ground-truth and predicted COCO category IDs for logging."""
        # We accept `tokenizer` for having consistent API but don't use it here.
        # One inference pass without gradients; restore train mode afterwards.
        self.eval()
        with torch.no_grad():
            predicted_ids = self.forward(batch)["predictions"]
        self.train()

        report_chunks = []
        for gt_row, pred_row in zip(batch["caption_tokens"], predicted_ids):
            # Predictions here are COCO category IDs, let them be as is.
            # Sorted ground truth, remove background tokens (category id 0).
            ground_truth = sorted([t for t in gt_row.tolist() if t != 0])
            # Keep only as many predictions as there are ground-truth IDs.
            top_predictions = sorted(pred_row.tolist()[: len(ground_truth)])
            report_chunks.append(
                f"""
                COCO Instance IDs (GT) : {ground_truth}
                COCO Instance IDs (Pred) : {top_predictions}
                """
            )
        return "".join(report_chunks)
|