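"""
Zero-shot classification with VirTex: score images against a fixed set of
category captions using forward and backward captioning losses.
"""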
import copy
from typing import Dict

import torch
from torch import nn

from virtex.modules.textual_heads import TextualHead
from virtex.modules.visual_backbones import VisualBackbone


class ZeroShotClassifier(nn.Module):
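    """
    Zero-shot classifier: scores each image against a set of category
    captions. Each category caption is decoded in the forward and backward
    directions, and the negated, length-normalized sum of the two captioning
    losses serves as the score for that (image, category) pair.
    """
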
    def __init__(
        self,
        visual: VisualBackbone,
        textual: TextualHead,
    ):
        super().__init__()
        self.visual = visual
        self.textual = textual
        self.padding_idx = self.textual.padding_idx

        # Clone the textual head to decode in the backward direction, so that
        # captioning is performed in both directions separately.
        self.backward_textual = copy.deepcopy(self.textual)

        # Share weights of the visual projection and the input/output
        # embeddings between both directions.
        self.backward_textual.visual_projection = self.textual.visual_projection
        self.backward_textual.embedding = self.textual.embedding
        self.backward_textual.output = self.textual.output

        # Per-token loss (no reduction), so losses can later be summed per
        # caption for each image-category pair.
        self.loss = nn.CrossEntropyLoss(
            ignore_index=self.padding_idx, reduction="none"
        )

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
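        """
        Compute a score for every (image, category) pair in the batch.

        ``batch`` must contain "image", the category "caption_tokens", their
        reversed counterparts "noitpac_tokens", and "caption_lengths".

        Returns a tensor of shape (batch_size, num_categories); higher scores
        indicate a better-fitting category caption.
        """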
        
        # shape: (batch_size, channels, height, width)
        visual_features = self.visual(batch["image"])
        batch_size = visual_features.size(0)

        classification_scores = []

        # Category captions, one per class.
        # shape: (num_categories, max_caption_length), e.g. (1000, 20)
        caption_tokens = batch["caption_tokens"]
        backward_caption_tokens = batch["noitpac_tokens"]
        caption_lengths = batch["caption_lengths"]

        # Score every image in the batch against each category caption.
        for i in range(caption_tokens.shape[0]):
            # Broadcast the i-th category caption across the batch.
            # shape: (batch_size, max_caption_length)
            category_caption_tokens = caption_tokens[i, :].unsqueeze(0).repeat(batch_size, 1)
            # shape: (batch_size, max_caption_length)
            category_backward_caption_tokens = backward_caption_tokens[i, :].unsqueeze(0).repeat(batch_size, 1)
            # shape: (batch_size, )
            category_caption_lengths = caption_lengths[i].unsqueeze(0).repeat(batch_size)

            # Forward-direction captioning logits.
            # shape: (batch_size, max_caption_length, vocab_size)
            output_logits = self.textual(
                visual_features, category_caption_tokens, category_caption_lengths
            )

            # Per-token forward loss; reduction="none" keeps shape
            # (batch_size * (max_caption_length - 1), ).
            loss = self.loss(
                output_logits[:, :-1].contiguous().view(-1, self.textual.vocab_size),
                category_caption_tokens[:, 1:].contiguous().view(-1),
            )

            # Caption the same image-category pair in the backward direction.
            backward_output_logits = self.backward_textual(
                visual_features, category_backward_caption_tokens, category_caption_lengths
            )
            backward_loss = self.loss(
                backward_output_logits[:, :-1].contiguous().view(-1, self.textual.vocab_size),
                category_backward_caption_tokens[:, 1:].contiguous().view(-1),
            )

            # Sum per-token losses into a single loss per image.
            loss = loss.view(batch_size, -1).sum(dim=1)
            backward_loss = backward_loss.view(batch_size, -1).sum(dim=1)
            
            # Negate and length-normalize: a higher score means the category
            # caption fits the image better. shape: (batch_size, )
            total_scores = (-loss - backward_loss) / category_caption_lengths

            classification_scores.append(total_scores)

        # Stack per-category scores and transpose, so that each row holds one
        # image's scores over all categories.
        # shape: (batch_size, num_categories)
        classification_scores = torch.stack(classification_scores).t()

        return classification_scores
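

# Usage sketch: assumes `visual`, `textual`, and a `batch` dict are built
# elsewhere (e.g., by the virtex factories; not shown here), with `batch`
# providing "image", "caption_tokens", "noitpac_tokens", and "caption_lengths".
# Since forward() returns negated losses, higher scores are better and the
# predicted class is the argmax over categories:
#
#     classifier = ZeroShotClassifier(visual, textual).eval()
#     with torch.no_grad():
#         scores = classifier(batch)      # (batch_size, num_categories)
#     predictions = scores.argmax(dim=1)  # (batch_size, )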