# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.dist import all_gather
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class ITCHead(BaseModule):
"""Image-text matching head for multi-modal pre-trained task. Adapted by | |
BLIP, ALBEF. Normally used for retrieval task. | |
    Args:
        embed_dim (int): Embedding channel size for the queue.
        queue_size (int): Queue size for image and text features.
            Defaults to 57600.
        temperature (float): Temperature used to scale the similarity.
            Defaults to 0.07.
        use_distill (bool): Whether to use momentum distillation when
            computing the loss. Defaults to True.
        alpha (float): Weight for the momentum similarity. Defaults to 0.4.
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to None.
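
    Example:
        A minimal construction sketch (shapes and values here are
        illustrative, not from a real config)::

            >>> import torch
            >>> head = ITCHead(embed_dim=256, queue_size=64)
            >>> # ``forward`` returns the last element of ``feats`` as-is.
            >>> feats = (torch.randn(2, 256), )
            >>> head(feats).shape
            torch.Size([2, 256])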
""" | |
def __init__(self, | |
embed_dim: int, | |
queue_size: int = 57600, | |
temperature: float = 0.07, | |
use_distill: bool = True, | |
alpha: float = 0.4, | |
init_cfg: Optional[dict] = None): | |
super(ITCHead, self).__init__(init_cfg=init_cfg) | |
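        # Learnable scalar temperature for scaling similarities,
        # initialized to ``temperature``.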
        self.temp = nn.Parameter(temperature * torch.ones([]))
        self.use_distill = use_distill
        if self.use_distill:
            # create the queue
            self.register_buffer('image_queue',
                                 torch.randn(embed_dim, queue_size))
            self.register_buffer('text_queue',
                                 torch.randn(embed_dim, queue_size))
            # -100 is a sentinel id that never matches a real ``image_id``.
            self.register_buffer('idx_queue', torch.full((1, queue_size),
                                                         -100))
            self.register_buffer('queue_ptr',
                                 torch.zeros(1, dtype=torch.long))
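
            # Similarities are dot products, so queue entries are kept
            # L2-normalized along the embedding dimension.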
            self.image_queue = F.normalize(self.image_queue, dim=0)
            self.text_queue = F.normalize(self.text_queue, dim=0)

            self.queue_size = queue_size
            # This value will be warmed up by `WarmupParamHook`.
            self.alpha = alpha

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The forward process. Returns the last item of ``feats``."""
        return feats[-1]

    def loss(self, feats: Tuple[torch.Tensor], data_samples,
             **kwargs) -> dict:
        """Calculate losses from the extracted features.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to compute the loss, and it should unpack into
                ``(img_feats, text_feats, img_feats_m, text_feats_m)``.
            data_samples (List[ClsDataSample]): The annotation data of
                every sample.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # This part can be traced by torch.fx.
        img_feats, text_feats, img_feats_m, text_feats_m = self(feats)
        img_feats_all = torch.cat(
            [img_feats_m.t(),
             self.image_queue.clone().detach()], dim=1)
        text_feats_all = torch.cat(
            [text_feats_m.t(),
             self.text_queue.clone().detach()], dim=1)

        # This part can not be traced by torch.fx.
        losses = self._get_loss(img_feats, text_feats, img_feats_m,
                                text_feats_m, img_feats_all, text_feats_all,
                                data_samples, **kwargs)
        return losses
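
    # The ITC loss is a soft cross-entropy between the softmax over the
    # image-to-text (and text-to-image) similarities and the targets. With
    # ``use_distill=True`` the targets are blended with the momentum
    # model's predictions:
    #     targets = alpha * softmax(sim_m) + (1 - alpha) * sim_targets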
    def _get_loss(self, img_feats, text_feats, img_feats_m, text_feats_m,
                  img_feats_all, text_feats_all, data_samples, **kwargs):
        """Unpack data samples and compute loss."""
        idx = torch.tensor([ds.image_id
                            for ds in data_samples]).to(img_feats.device)
        idx = idx.view(-1, 1)
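        # A queued feature counts as a positive if it shares the same
        # ``image_id``; each row of targets is normalized to sum to one, so
        # multiple positives (e.g. several captions of one image) are
        # handled naturally.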
        idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1)
        pos_idx = torch.eq(idx, idx_all).float()
        sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)
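
        # Momentum distillation: soften the targets with the momentum
        # model's similarity distribution, weighted by ``self.alpha``.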
        with torch.no_grad():
            if self.use_distill:
                sim_i2t_m = img_feats_m @ text_feats_all / self.temp
                sim_t2i_m = text_feats_m @ img_feats_all / self.temp

                sim_i2t_targets = (
                    self.alpha * F.softmax(sim_i2t_m, dim=1) +
                    (1 - self.alpha) * sim_targets)
                sim_t2i_targets = (
                    self.alpha * F.softmax(sim_t2i_m, dim=1) +
                    (1 - self.alpha) * sim_targets)

        sim_i2t = img_feats @ text_feats_all / self.temp
        sim_t2i = text_feats @ img_feats_all / self.temp

        if self.use_distill:
            loss_i2t = -torch.sum(
                F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean()
            loss_t2i = -torch.sum(
                F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean()
        else:
            loss_i2t = -torch.sum(
                F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean()
            loss_t2i = -torch.sum(
                F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean()

        # compute loss
        losses = dict()
        losses['itc_loss'] = (loss_i2t + loss_t2i) / 2

        self._dequeue_and_enqueue(img_feats_m, text_feats_m, idx)
        return losses
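
    # FIFO feature queue update in the MoCo style: gather features from all
    # ranks, overwrite the oldest ``batch_size`` columns, then advance the
    # pointer. Queue updates are bookkeeping only, so autograd is disabled.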
    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
        # gather keys before updating queue
        image_feats = torch.cat(all_gather(image_feat))
        text_feats = torch.cat(all_gather(text_feat))

        batch_size = image_feats.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_size % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
        self.text_queue[:, ptr:ptr + batch_size] = text_feats.T

        if idxs is not None:
            idxs = torch.cat(all_gather(idxs))
            self.idx_queue[:, ptr:ptr + batch_size] = idxs.T

        ptr = (ptr + batch_size) % self.queue_size  # move pointer
        self.queue_ptr[0] = ptr
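

if __name__ == '__main__':
    # Minimal single-process smoke test -- illustrative only, not part of
    # the upstream module. It assumes mmengine's ``all_gather`` returns the
    # local tensor when torch.distributed is not initialized.
    from types import SimpleNamespace

    head = ITCHead(embed_dim=8, queue_size=16)

    def fake_feats(batch_size: int, dim: int = 8) -> torch.Tensor:
        return F.normalize(torch.randn(batch_size, dim), dim=-1)

    # ``feats[-1]`` must unpack into (img, text, img_momentum, text_momentum).
    feats = (tuple(fake_feats(4) for _ in range(4)), )
    # ``loss`` only reads the ``image_id`` attribute of each data sample.
    data_samples = [SimpleNamespace(image_id=i) for i in range(4)]

    losses = head.loss(feats, data_samples)
    print({k: float(v) for k, v in losses.items()})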