# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.dist import all_gather
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class ITCHead(BaseModule):
    """Image-text matching head for multi-modal pre-trained task. Adapted by
    BLIP, ALBEF. Normally used for retrieval task.

    Args:
        embed_dim (int): The embedding dimension of the features and queues.
        queue_size (int): Queue size for image and text. Defaults to 57600.
        temperature (float): Temperature to calculate the similarity.
            Defaults to 0.07.
        use_distill (bool): Whether to use momentum distillation to
            calculate the loss. Defaults to True.
        alpha (float): Weight for momentum similarity. Defaults to 0.4.
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to None.
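
    Examples:
        A minimal construction/forward sketch. The sizes below are
        illustrative assumptions, not values from any released config:

        >>> import torch
        >>> head = ITCHead(embed_dim=256, queue_size=512)
        >>> img_feats = torch.randn(4, 256)
        >>> text_feats = torch.randn(4, 256)
        >>> # ``forward`` simply returns the last element of ``feats``.
        >>> feats = (img_feats, text_feats, img_feats.clone(),
        ...          text_feats.clone())
        >>> head(feats).shape
        torch.Size([4, 256])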
    """

    def __init__(self,
                 embed_dim: int,
                 queue_size: int = 57600,
                 temperature: float = 0.07,
                 use_distill: bool = True,
                 alpha: float = 0.4,
                 init_cfg: Optional[dict] = None):
        super().__init__(init_cfg=init_cfg)
        self.temp = nn.Parameter(temperature * torch.ones([]))
        self.use_distill = use_distill

        # Create the feature queues. ``loss`` always compares the current
        # batch against the queued features, so the queues are registered
        # regardless of ``use_distill``.
        self.register_buffer('image_queue',
                             torch.randn(embed_dim, queue_size))
        self.register_buffer('text_queue',
                             torch.randn(embed_dim, queue_size))
        self.register_buffer('idx_queue', torch.full((1, queue_size), -100))
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))

        self.image_queue = F.normalize(self.image_queue, dim=0)
        self.text_queue = F.normalize(self.text_queue, dim=0)

        self.queue_size = queue_size
        # This value will be warmed up by `WarmupParamHook`.
        self.alpha = alpha

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The forward process."""
        return feats[-1]

    def loss(self, feats: Tuple[torch.Tensor], data_samples, **kwargs) -> dict:
        """Calculate losses from the classification score.

        Args:
            feats (tuple): The features extracted from the model. Only the
                last element is used, and it is expected to be a tuple of
                ``(img_feats, text_feats, img_feats_m, text_feats_m)``,
                where the ``*_m`` tensors come from the momentum encoder
                and each item has shape ``(num_samples, embed_dim)``.
            data_samples (List[ClsDataSample]): The annotation data of
                every sample. Each sample must provide an ``image_id``
                field.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
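
        Examples:
            A hedged single-process sketch. ``SimpleNamespace`` stands in
            for the real data-sample type and only provides the
            ``image_id`` field this method reads; all sizes below are
            illustrative assumptions:

            >>> import torch
            >>> import torch.nn.functional as F
            >>> from types import SimpleNamespace
            >>> head = ITCHead(embed_dim=8, queue_size=16)
            >>> img = F.normalize(torch.randn(4, 8), dim=-1)
            >>> txt = F.normalize(torch.randn(4, 8), dim=-1)
            >>> feats = ((img, txt, img.clone(), txt.clone()),)
            >>> samples = [SimpleNamespace(image_id=i) for i in range(4)]
            >>> head.loss(feats, samples)['itc_loss'].shape
            torch.Size([])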
        """

        # This part can be traced by torch.fx
        img_feats, text_feats, img_feats_m, text_feats_m = self(feats)

        img_feats_all = torch.cat(
            [img_feats_m.t(),
             self.image_queue.clone().detach()], dim=1)
        text_feats_all = torch.cat(
            [text_feats_m.t(),
             self.text_queue.clone().detach()], dim=1)

        # This part cannot be traced by torch.fx
        losses = self._get_loss(img_feats, text_feats, img_feats_m,
                                text_feats_m, img_feats_all, text_feats_all,
                                data_samples, **kwargs)
        return losses

    def _get_loss(self, img_feats, text_feats, img_feats_m, text_feats_m,
                  img_feats_all, text_feats_all, data_samples, **kwargs):
        """Unpack data samples and compute loss."""

        idx = torch.tensor([ds.image_id
                            for ds in data_samples]).to(img_feats.device)
        idx = idx.view(-1, 1)
        idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1)
        pos_idx = torch.eq(idx, idx_all).float()
        sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)
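        # ``sim_targets`` is the ground-truth matching distribution: every
        # entry of ``idx_all`` (in-batch ids followed by the queued ids)
        # that equals the row's ``image_id`` marks a positive, and the
        # positives in each row share the probability mass evenly. The
        # ``no_grad`` block below optionally softens these targets with the
        # momentum model's predictions (momentum distillation), weighted by
        # ``self.alpha``.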

        with torch.no_grad():
            if self.use_distill:
                sim_i2t_m = img_feats_m @ text_feats_all / self.temp
                sim_t2i_m = text_feats_m @ img_feats_all / self.temp

                sim_i2t_targets = (
                    self.alpha * F.softmax(sim_i2t_m, dim=1) +
                    (1 - self.alpha) * sim_targets)
                sim_t2i_targets = (
                    self.alpha * F.softmax(sim_t2i_m, dim=1) +
                    (1 - self.alpha) * sim_targets)

        sim_i2t = img_feats @ text_feats_all / self.temp
        sim_t2i = text_feats @ img_feats_all / self.temp

        if self.use_distill:
            loss_i2t = -torch.sum(
                F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean()
            loss_t2i = -torch.sum(
                F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean()
        else:
            loss_i2t = -torch.sum(
                F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean()
            loss_t2i = -torch.sum(
                F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean()

        # The final ITC loss averages the two directional terms.
        losses = dict()

        losses['itc_loss'] = (loss_i2t + loss_t2i) / 2
        self._dequeue_and_enqueue(img_feats_m, text_feats_m, idx)
        return losses

    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
        # gather keys before updating queue
        image_feats = torch.cat(all_gather(image_feat))
        text_feats = torch.cat(all_gather(text_feat))

        batch_size = image_feats.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_size % batch_size == 0  # for simplicity
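        # A worked example with illustrative numbers (not from any config):
        # with ``queue_size=57600`` and a gathered batch of 576 rows, the
        # pointer visits 0, 576, 1152, ... and wraps back to 0 after 100
        # updates, so the oldest entries are always overwritten first.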

        # replace the keys at ptr (dequeue and enqueue)
        self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
        self.text_queue[:, ptr:ptr + batch_size] = text_feats.T

        if idxs is not None:
            idxs = torch.cat(all_gather(idxs))
            self.idx_queue[:, ptr:ptr + batch_size] = idxs.T

        ptr = (ptr + batch_size) % self.queue_size  # move pointer
        self.queue_ptr[0] = ptr