# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from mmengine.dist import all_gather
from mmengine.model import ExponentialMovingAverage

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import batch_shuffle_ddp, batch_unshuffle_ddp
from .base import BaseSelfSupervisor


@MODELS.register_module()
class MoCo(BaseSelfSupervisor):
    """MoCo.

    Implementation of `Momentum Contrast for Unsupervised Visual
    Representation Learning <https://arxiv.org/abs/1911.05722>`_.
    Part of the code is borrowed from:
    `<https://github.com/facebookresearch/moco/blob/master/moco/builder.py>`_.

    Args:
        backbone (dict): Config dict for module of backbone.
        neck (dict): Config dict for module of deep features to compact
            feature vectors.
        head (dict): Config dict for module of head functions.
        queue_len (int): Number of negative keys maintained in the queue.
            Must be divisible by the global (all-rank) batch size.
            Defaults to 65536.
        feat_dim (int): Dimension of compact feature vectors. Defaults to 128.
        momentum (float): Momentum coefficient passed to mmengine's
            ``ExponentialMovingAverage`` for the momentum-updated (key)
            encoder. mmengine updates the average as
            ``ema = (1 - momentum) * ema + momentum * online``, so the default
            0.001 corresponds to m = 0.999 in the MoCo paper's notation.
            Defaults to 0.001.
        pretrained (str, optional): The pretrained checkpoint path, support
            local path and remote path. Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing
            input data. If None or no type is specified, it will use
            "SelfSupDataPreprocessor" as type.
            See :class:`SelfSupDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (Union[List[dict], dict], optional): Config dict for weight
            initialization. Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 queue_len: int = 65536,
                 feat_dim: int = 128,
                 momentum: float = 0.001,
                 pretrained: Optional[str] = None,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # Momentum (key) encoder: an EMA copy of backbone + neck, refreshed
        # via ``update_parameters`` at every training step in :meth:`loss`.
        self.encoder_k = ExponentialMovingAverage(
            nn.Sequential(self.backbone, self.neck), momentum)

        # Queue of negative keys, stored column-wise as (feat_dim, queue_len)
        # and initialised with random unit-norm columns.
        self.queue_len = queue_len
        self.register_buffer(
            'queue',
            nn.functional.normalize(torch.randn(feat_dim, queue_len), dim=0))
        # Column index at which the next batch of keys will be written.
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None:
        """Update the queue with the current batch of keys.

        Keys from all ranks are gathered with ``all_gather`` and concatenated.
        They are then written into the queue starting at ``queue_ptr``,
        replacing the columns stored there. The pointer then advances,
        wrapping around at ``queue_len``.

        Args:
            keys (torch.Tensor): Keys of shape (N, C) produced on this rank.
        """
        # gather keys before updating queue
        keys = torch.cat(all_gather(keys), dim=0)
        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        # A single write never wraps around the end of the queue, so the
        # queue length must be a multiple of the gathered batch size.
        assert self.queue_len % batch_size == 0, (
            f'queue_len ({self.queue_len}) must be divisible by the global '
            f'batch size ({batch_size})')

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        # move pointer
        self.queue_ptr[0] = (ptr + batch_size) % self.queue_len

    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images. ``inputs[0]`` is
                encoded as the query view and ``inputs[1]`` as the key view.
            data_samples (List[DataSample]): All elements required during
                the forward function. Not read by this method.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        assert isinstance(inputs, list), (
            f'inputs must be a list of image views, got {type(inputs)}')
        im_q, im_k = inputs[0], inputs[1]

        # compute query features from encoder_q
        q = self.neck(self.backbone(im_q))[0]  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            # update the key encoder from the current backbone + neck
            self.encoder_k.update_parameters(
                nn.Sequential(self.backbone, self.neck))

            # shuffle for making use of BN
            im_k, idx_unshuffle = batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)[0]  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = batch_unshuffle_ddp(k, idx_unshuffle)

        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', q, k).unsqueeze(-1)
        # negative logits: NxK. The queue is cloned because the in-place
        # update in _dequeue_and_enqueue below would otherwise modify the
        # operand autograd saved for the gradient w.r.t. q.
        l_neg = torch.einsum('nc,ck->nk', q, self.queue.clone().detach())

        loss = self.head.loss(l_pos, l_neg)

        # update the queue
        self._dequeue_and_enqueue(k)

        return dict(loss=loss)