import copy
from typing import Dict

import torch
from torch import nn

from virtex.modules.textual_heads import TextualHead
from virtex.modules.visual_backbones import VisualBackbone
class ZeroShotClassifier(nn.Module):
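    """
    Score images against a fixed set of category captions for zero-shot
    classification. Each category is described by a caption; an image is
    scored by how well the bidirectional captioning model predicts that
    caption, and the highest-scoring category wins.
    """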
    def __init__(
        self,
        visual: VisualBackbone,
        textual: TextualHead,
    ):
        super().__init__()
        self.visual = visual
        self.textual = textual
        self.padding_idx = self.textual.padding_idx

        # Clone the textual module for the backward direction: captions are
        # scored left-to-right and right-to-left separately.
        self.backward_textual = copy.deepcopy(self.textual)

        # Share weights for visual projection, and input/output embeddings.
        self.backward_textual.visual_projection = self.textual.visual_projection
        self.backward_textual.embedding = self.textual.embedding
        self.backward_textual.output = self.textual.output

        # Per-token loss (reduction="none"), ignoring padded positions, so
        # per-caption losses can be summed later.
        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx, reduction="none")
    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
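        """
        Compute a score for every (image, category caption) pair.

        ``batch`` is expected to contain:

        - "image": images, shape (batch_size, channels, height, width)
        - "caption_tokens": tokenized category captions, one per category
        - "noitpac_tokens": the same captions with tokens reversed, for the
          backward decoder
        - "caption_lengths": unpadded length of each category caption

        Returns a tensor of shape (batch_size, num_categories); higher
        scores mean a better match.
        """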
        # shape: (batch_size, channels, height, width)
        visual_features = self.visual(batch["image"])
        batch_size = visual_features.size(0)

        classification_scores = []

        # Tokenized category captions, e.g. shape (1000, 20): one forward and
        # one reversed ("noitpac") caption per category.
        caption_tokens = batch["caption_tokens"]
        backward_caption_tokens = batch["noitpac_tokens"]
        caption_lengths = batch["caption_lengths"]
        # Score every category caption against every image in the batch.
        for i in range(caption_tokens.shape[0]):
            # shape: (batch_size, num_tokens)
            category_caption_tokens = caption_tokens[i, :].unsqueeze(0).repeat(batch_size, 1)
            # shape: (batch_size, num_tokens)
            category_backward_caption_tokens = backward_caption_tokens[i, :].unsqueeze(0).repeat(batch_size, 1)
            # shape: (batch_size,)
            category_caption_lengths = caption_lengths[i].unsqueeze(0).repeat(batch_size)
            # Predict each next token of the forward caption from the image.
            output_logits = self.textual(
                visual_features, category_caption_tokens, category_caption_lengths
            )
            loss = self.loss(
                output_logits[:, :-1].contiguous().view(-1, self.textual.vocab_size),
                category_caption_tokens[:, 1:].contiguous().view(-1),
            )

            # Repeat in the backward direction on the reversed caption.
            backward_output_logits = self.backward_textual(
                visual_features, category_backward_caption_tokens, category_caption_lengths
            )
            backward_loss = self.loss(
                backward_output_logits[:, :-1].contiguous().view(-1, self.textual.vocab_size),
                category_backward_caption_tokens[:, 1:].contiguous().view(-1),
            )
            # Sum per-token losses into one loss per image, then negate and
            # normalize by caption length so captions of different lengths
            # are comparable: higher score = better match.
            loss = loss.view(batch_size, -1).sum(dim=1)
            backward_loss = backward_loss.view(batch_size, -1).sum(dim=1)
            total_scores = (-loss - backward_loss) / category_caption_lengths
            classification_scores.append(total_scores)

        # shape: (batch_size, num_categories)
        return torch.stack(classification_scores).t()
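

# A minimal usage sketch (illustrative, not part of the module): it assumes
# `visual` and `textual` are trained VirTex modules with the interfaces used
# above, and that category captions have already been tokenized, reversed,
# and padded. All variable names below are hypothetical.
#
#   classifier = ZeroShotClassifier(visual, textual).eval()
#   batch = {
#       "image": images,                    # (batch_size, 3, H, W)
#       "caption_tokens": prompts,          # (num_categories, num_tokens)
#       "noitpac_tokens": reversed_prompts, # (num_categories, num_tokens)
#       "caption_lengths": prompt_lengths,  # (num_categories,)
#   }
#   with torch.no_grad():
#       scores = classifier(batch)          # (batch_size, num_categories)
#   predicted_category = scores.argmax(dim=1)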