# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.dist import all_gather
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS


@MODELS.register_module()
class ITCHead(BaseModule):
"""Image-text matching head for multi-modal pre-trained task. Adapted by
BLIP, ALBEF. Normally used for retrieval task.
Args:
embed_dim (int): Embed channel size for queue.
queue_size (int): Queue size for image and text. Defaults to 57600.
temperature (float): Temperature to calculate the similarity.
Defaults to 0.07.
use_distill (bool): Whether to use distill to calculate loss.
Defaults to True.
alpha (float): Weight for momentum similarity. Defaults to 0.4.
init_cfg (dict, optional): the config to control the initialization.
Defaults to None.
"""

    def __init__(self,
embed_dim: int,
queue_size: int = 57600,
temperature: float = 0.07,
use_distill: bool = True,
alpha: float = 0.4,
init_cfg: Optional[dict] = None):
super(ITCHead, self).__init__(init_cfg=init_cfg)
self.temp = nn.Parameter(temperature * torch.ones([]))
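        # Learnable temperature that scales the similarity logits; it is
        # initialised from ``temperature`` and updated during training.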
self.use_distill = use_distill
if self.use_distill:
# create the queue
self.register_buffer('image_queue',
torch.randn(embed_dim, queue_size))
self.register_buffer('text_queue',
torch.randn(embed_dim, queue_size))
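            # ``idx_queue`` stores the image ids of the queued features;
            # -100 marks empty slots that never match a real image id.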
self.register_buffer('idx_queue', torch.full((1, queue_size),
-100))
self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))
self.image_queue = F.normalize(self.image_queue, dim=0)
self.text_queue = F.normalize(self.text_queue, dim=0)
self.queue_size = queue_size
        # This value will be warmed up by `WarmupParamHook`.
self.alpha = alpha

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
"""The forward process."""
return feats[-1]

    def loss(self, feats: Tuple[torch.Tensor], data_samples, **kwargs) -> dict:
"""Calculate losses from the classification score.
Args:
feats (tuple[Tensor]): The features extracted from the backbone.
Multiple stage inputs are acceptable but only the last stage
will be used to classify. The shape of every item should be
``(num_samples, num_classes)``.
data_samples (List[ClsDataSample]): The annotation data of
every samples.
**kwargs: Other keyword arguments to forward the loss module.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
        # This part can be traced by torch.fx.
img_feats, text_feats, img_feats_m, text_feats_m = self(feats)
img_feats_all = torch.cat(
[img_feats_m.t(),
self.image_queue.clone().detach()], dim=1)
text_feats_all = torch.cat(
[text_feats_m.t(),
self.text_queue.clone().detach()], dim=1)
        # This part cannot be traced by torch.fx.
losses = self._get_loss(img_feats, text_feats, img_feats_m,
text_feats_m, img_feats_all, text_feats_all,
data_samples, **kwargs)
return losses

    def _get_loss(self, img_feats, text_feats, img_feats_m, text_feats_m,
img_feats_all, text_feats_all, data_samples, **kwargs):
"""Unpack data samples and compute loss."""
idx = torch.tensor([ds.image_id
for ds in data_samples]).to(img_feats.device)
idx = idx.view(-1, 1)
idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1)
pos_idx = torch.eq(idx, idx_all).float()
sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)
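        # Momentum distillation: soft targets mix the momentum encoders'
        # similarity distribution with the hard ground-truth targets.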
with torch.no_grad():
if self.use_distill:
sim_i2t_m = img_feats_m @ text_feats_all / self.temp
sim_t2i_m = text_feats_m @ img_feats_all / self.temp
sim_i2t_targets = (
self.alpha * F.softmax(sim_i2t_m, dim=1) +
(1 - self.alpha) * sim_targets)
sim_t2i_targets = (
self.alpha * F.softmax(sim_t2i_m, dim=1) +
(1 - self.alpha) * sim_targets)
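        # Similarities of the current (non-momentum) features against the
        # in-batch momentum features concatenated with the queue.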
sim_i2t = img_feats @ text_feats_all / self.temp
sim_t2i = text_feats @ img_feats_all / self.temp
if self.use_distill:
loss_i2t = -torch.sum(
F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean()
loss_t2i = -torch.sum(
F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean()
else:
loss_i2t = -torch.sum(
F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean()
loss_t2i = -torch.sum(
F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean()
# compute loss
losses = dict()
losses['itc_loss'] = (loss_i2t + loss_t2i) / 2
self._dequeue_and_enqueue(img_feats_m, text_feats_m, idx)
return losses

    @torch.no_grad()
def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
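        """Update the feature/id queues with the gathered current batch."""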
# gather keys before updating queue
image_feats = torch.cat(all_gather(image_feat))
text_feats = torch.cat(all_gather(text_feat))
batch_size = image_feats.shape[0]
ptr = int(self.queue_ptr)
        # For simplicity, require the queue size to be divisible by the
        # (global) batch size so whole batches replace queue slots in place.
        assert self.queue_size % batch_size == 0
# replace the keys at ptr (dequeue and enqueue)
self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
if idxs is not None:
idxs = torch.cat(all_gather(idxs))
self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
ptr = (ptr + batch_size) % self.queue_size # move pointer
self.queue_ptr[0] = ptr
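

if __name__ == '__main__':
    # Minimal single-process smoke-test sketch showing the input structure
    # that ``loss`` expects. It assumes ``mmengine.dist.all_gather`` falls
    # back to returning ``[data]`` outside distributed training, and uses
    # plain ``SimpleNamespace`` stand-ins because the head only reads the
    # ``image_id`` field; real pipelines pass mmpretrain data samples.
    from types import SimpleNamespace

    head = ITCHead(embed_dim=8, queue_size=16)
    batch_size = 4
    img_feats = F.normalize(torch.randn(batch_size, 8), dim=-1)
    text_feats = F.normalize(torch.randn(batch_size, 8), dim=-1)
    # Momentum features normally come from momentum encoders; the current
    # features are reused here just to exercise the code path.
    feats = ((img_feats, text_feats, img_feats.clone(), text_feats.clone()), )
    data_samples = [SimpleNamespace(image_id=i) for i in range(batch_size)]
    print(head.loss(feats, data_samples))  # -> {'itc_loss': tensor(...)}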