# virtex/models/zero_shot_classification_eval.py
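"""
Zero-shot image classification with a pretrained bicaptioning model.

Each candidate category name is rendered as a caption; an image is scored by
the length-normalized log-likelihood that the forward and backward textual
heads assign to that caption. Higher scores indicate a better match.
"""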
import copy
from typing import Dict

import torch
from torch import nn

from virtex.modules.textual_heads import TextualHead
from virtex.modules.visual_backbones import VisualBackbone
class ZeroShotClassifier(nn.Module):
def __init__(
self,
visual: VisualBackbone,
textual: TextualHead,
):
super().__init__()
self.visual = visual
self.textual = textual
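        # Padding index is used to ignore padded caption positions in the loss.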
self.padding_idx = self.textual.padding_idx
        # Clone the textual module to run captioning in the backward
        # direction; the two directions are scored separately.
        self.backward_textual = copy.deepcopy(self.textual)
# Share weights for visual projection, and input/output embeddings.
self.backward_textual.visual_projection = self.textual.visual_projection
self.backward_textual.embedding = self.textual.embedding
self.backward_textual.output = self.textual.output
        # Per-token loss: reduction="none" lets us sum losses per caption.
        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx, reduction="none")
    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
# shape: (batch_size, channels, height, width)
visual_features = self.visual(batch["image"])
batch_size = visual_features.size(0)
classification_losses = []
        # Category captions shape: (1000, 20)
caption_tokens = batch["caption_tokens"]
backward_caption_tokens = batch["noitpac_tokens"]
caption_lengths = batch["caption_lengths"]
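        # For each category caption, tile it across the batch and compute the
        # captioning loss for every image, in both directions; a lower loss
        # means the category name is a better description of the image.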
        for i in range(caption_tokens.shape[0]):
            # shape: (batch_size, 20)
            category_caption_tokens = (
                caption_tokens[i, :].unsqueeze(0).repeat(batch_size, 1)
            )
            # shape: (batch_size, 20)
            category_backward_caption_tokens = (
                backward_caption_tokens[i, :].unsqueeze(0).repeat(batch_size, 1)
            )
            # shape: (batch_size, )
            category_caption_lengths = (
                caption_lengths[i].unsqueeze(0).repeat(batch_size)
            )
            output_logits = self.textual(
                visual_features, category_caption_tokens, category_caption_lengths
            )
            # Per-token loss against next-token targets (teacher forcing).
            loss = self.loss(
                output_logits[:, :-1].contiguous().view(-1, self.textual.vocab_size),
                category_caption_tokens[:, 1:].contiguous().view(-1),
            )
            # Score the same caption in the backward direction.
            backward_output_logits = self.backward_textual(
                visual_features,
                category_backward_caption_tokens,
                category_caption_lengths,
            )
            backward_loss = self.loss(
                backward_output_logits[:, :-1]
                .contiguous()
                .view(-1, self.textual.vocab_size),
                category_backward_caption_tokens[:, 1:].contiguous().view(-1),
            )
            # Sum per-token losses per image; negate so that higher means more
            # likely, and normalize by caption length so categories with
            # longer names remain comparable.
            loss = loss.view(batch_size, -1).sum(dim=1)
            backward_loss = backward_loss.view(batch_size, -1).sum(dim=1)
            total_scores = (-loss - backward_loss) / category_caption_lengths
            classification_losses.append(total_scores)

        # shape: (batch_size, num_categories)
        classification_losses = torch.stack(classification_losses).t()
return classification_losses
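

# A minimal usage sketch (illustrative, not from the original file). How the
# `visual` and `textual` modules are built is left out -- any concrete
# `VisualBackbone` and `TextualHead` from virtex will do -- and the tensor
# names below are assumptions; the batch layout mirrors what `forward` above
# expects:
#
#     model = ZeroShotClassifier(visual, textual).eval()
#     batch = {
#         "image": images,                    # (batch_size, 3, H, W)
#         "caption_tokens": category_tokens,  # (num_categories, max_length)
#         "noitpac_tokens": reversed_tokens,  # category captions, reversed
#         "caption_lengths": category_lens,   # (num_categories, )
#     }
#     with torch.no_grad():
#         scores = model(batch)               # (batch_size, num_categories)
#     predictions = scores.argmax(dim=1)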