from typing import Any, Dict

import torch
from torch import nn
import torch.distributed as dist

from virtex.modules.label_smoothing import CrossEntropyLossWithLabelSmoothing
from virtex.modules.textual_heads import TextualHead
from virtex.modules.visual_backbones import VisualBackbone


class ImageTextContrastiveModel(nn.Module):
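    """
    Learn joint visual and textual representations with a symmetric image-text
    contrastive loss, similar in spirit to CLIP: image and caption features are
    projected to a shared space, L2-normalized, and compared with a learnable
    temperature. During distributed training, features are all-gathered across
    processes so each example sees more negatives.
    """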
    def __init__(
        self,
        visual: VisualBackbone,
        textual: TextualHead,
        label_smoothing: float = 0.0
    ):
        super().__init__()
        self.visual = visual
        self.textual = textual
        self.padding_idx = self.textual.padding_idx

        self.visual_projection = nn.Linear(
            self.visual.visual_feature_size,
            self.textual.textual_feature_size,
            bias=False,
        )
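        # Learnable scale for contrastive logits, kept in log space and
        # initialized to log(1 / 0.07), i.e. an initial temperature of 0.07
        # (the value popularized by CLIP).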
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1/0.07)))
        self.loss = CrossEntropyLossWithLabelSmoothing(
            label_smoothing, ignore_index=self.padding_idx
        )

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:

        # Clamp the learnable logit scale (it may have drifted after the last
        # optimizer step) to at most log(50) ~ 3.912, so similarities are never
        # scaled by more than 50.
        self.logit_scale.data = torch.clamp(self.logit_scale.data, 0, 3.912)

        # shape: (batch_size, channels, height, width)
        visual_features = self.visual(batch["image"])
        batch_size = visual_features.size(0)

        # shape: (batch_size, channels)
        visual_features = visual_features.mean(dim=[2, 3]).view(batch_size, -1)

        # shape: (batch_size, textual_feature_size)
        visual_features = self.visual_projection(visual_features)

        caption_tokens = batch["caption_tokens"]
        caption_lengths = batch["caption_lengths"]

        # shape: (batch_size, max_caption_length, hidden_size)
        textual_features = self.textual(caption_tokens, caption_lengths)

        # Take features from the first time-step (as BERT-* models do).
        # shape: (batch_size, hidden_size)
        textual_features = textual_features[:, 0, :]

        # Normalize visual and textual features.
        # shape: (batch_size, textual_feature_size)
        visual_features = visual_features / visual_features.norm(dim=-1, keepdim=True)
        textual_features = textual_features / textual_features.norm(
            dim=-1, keepdim=True
        )
        # Gather textual features from all processes into one large tensor to
        # increase negative samples for contrastive learning.
        gathered_textual_features = [
            torch.zeros_like(textual_features) for _ in range(dist.get_world_size())
        ]
        dist.all_gather(gathered_textual_features, textual_features)

        # Swap this rank's features to index zero so that, after concatenation,
        # the positive caption for local image [i] sits in row [i].
        gathered_textual_features[0], gathered_textual_features[dist.get_rank()] = (
            gathered_textual_features[dist.get_rank()],
            gathered_textual_features[0],
        )
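        # For example, with world_size = 4 and rank = 2, the gathered list
        # [t0, t1, t2, t3] becomes [t2, t1, t0, t3]; rows 0..batch_size-1 of
        # the concatenated tensor below are then this rank's own captions.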
        # shape: (batch_size * world_size, textual_feature_size)
        gathered_textual_features = torch.cat(gathered_textual_features, dim=0)

        # Calculate pairwise cosine similarity as logits.
        logit_scale = self.logit_scale.exp()
        visual_logits = logit_scale * visual_features @ gathered_textual_features.t()

        # Target for image [i] is index [i], i.e. its own caption in the
        # gathered features (as one-hot vectors, the targets form an identity matrix).
        visual_loss = self.loss(
            visual_logits, torch.arange(visual_logits.size(0)).to(visual_logits.device)
        )

        # Do the same thing for visual features.
        gathered_visual_features = [
            torch.zeros_like(visual_features) for _ in range(dist.get_world_size())
        ]
        dist.all_gather(gathered_visual_features, visual_features)

        gathered_visual_features[0], gathered_visual_features[dist.get_rank()] = (
            gathered_visual_features[dist.get_rank()],
            gathered_visual_features[0],
        )
        # shape: (batch_size * world_size, textual_feature_size)
        gathered_visual_features = torch.cat(gathered_visual_features, dim=0)

        # Calculate pairwise cosine similarity as logits.
        logit_scale = self.logit_scale.exp()
        textual_logits = logit_scale * textual_features @ gathered_visual_features.t()

        # Target for caption [i] is index [i], i.e. its own image in the gathered features.
        textual_loss = self.loss(
            textual_logits,
            torch.arange(textual_logits.size(0)).to(textual_logits.device),
        )
        loss = 0.5 * (visual_loss + textual_loss)
        output_dict: Dict[str, Any] = {
            "loss": loss,
            # Single scalar per batch for logging in training script.
            "loss_components": {"contrastive": loss.clone().detach()},
        }

        return output_dict
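

# A minimal, single-process sketch of the symmetric contrastive loss computed
# above, runnable as a quick sanity check. It substitutes random placeholder
# features for the backbone outputs, plain cross-entropy for the label-smoothed
# loss, and skips the distributed all-gather (the "gathered" tensors are just
# the local batch). Shapes are illustrative, not taken from any virtex config.
if __name__ == "__main__":
    import torch.nn.functional as F

    batch_size, feature_size = 8, 512
    visual_feats = F.normalize(torch.randn(batch_size, feature_size), dim=-1)
    textual_feats = F.normalize(torch.randn(batch_size, feature_size), dim=-1)

    scale = torch.log(torch.tensor(1 / 0.07)).exp()
    image_logits = scale * visual_feats @ textual_feats.t()
    text_logits = scale * textual_feats @ visual_feats.t()

    # The positive pair for row [i] is column [i] in both directions.
    targets = torch.arange(batch_size)
    loss = 0.5 * (
        F.cross_entropy(image_logits, targets)
        + F.cross_entropy(text_logits, targets)
    )
    print(f"Contrastive loss on random features: {loss.item():.3f}")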