# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from mmengine.dist import all_gather
from mmengine.model import ExponentialMovingAverage

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample

from ..utils import batch_shuffle_ddp, batch_unshuffle_ddp
from .base import BaseSelfSupervisor


@MODELS.register_module()
class MoCo(BaseSelfSupervisor):
    """MoCo.

    Implementation of `Momentum Contrast for Unsupervised Visual
    Representation Learning <https://arxiv.org/abs/1911.05722>`_.

    Part of the code is borrowed from:
    `<https://github.com/facebookresearch/moco/blob/master/moco/builder.py>`_.

    Args:
        backbone (dict): Config dict for module of backbone.
        neck (dict): Config dict for module mapping deep features to
            compact feature vectors.
        head (dict): Config dict for module of head functions.
        queue_len (int): Number of negative keys maintained in the
            queue. Defaults to 65536.
        feat_dim (int): Dimension of compact feature vectors.
            Defaults to 128.
        momentum (float): Momentum coefficient for the momentum-updated
            encoder. Defaults to 0.001.
        pretrained (str, optional): The pretrained checkpoint path, supporting
            local and remote paths. Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing
            input data. If None or no type is specified, it will use
            "SelfSupDataPreprocessor" as the type.
            See :class:`SelfSupDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (Union[List[dict], dict], optional): Config dict for weight
            initialization. Defaults to None.
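
    Example:
        A minimal config sketch. The ``type`` names for the neck and head
        below follow the MoCo v2 reference configs and are illustrative
        rather than the only valid choices::

            model = dict(
                type='MoCo',
                queue_len=65536,
                feat_dim=128,
                momentum=0.001,
                backbone=dict(type='ResNet', depth=50),
                neck=dict(
                    type='MoCoV2Neck',
                    in_channels=2048,
                    hid_channels=2048,
                    out_channels=128,
                    with_avg_pool=True),
                head=dict(
                    type='ContrastiveHead',
                    loss=dict(type='CrossEntropyLoss'),
                    temperature=0.2))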
""" | |

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 queue_len: int = 65536,
                 feat_dim: int = 128,
                 momentum: float = 0.001,
                 pretrained: Optional[str] = None,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # create momentum model
        self.encoder_k = ExponentialMovingAverage(
            nn.Sequential(self.backbone, self.neck), momentum)
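        # note: mmengine's ExponentialMovingAverage treats ``momentum`` as the
        # weight of the new (source) parameters, i.e.
        # averaged = (1 - momentum) * averaged + momentum * source,
        # so momentum=0.001 corresponds to the 0.999 decay used in the paper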

        # create the queue
        self.queue_len = queue_len
        self.register_buffer('queue', torch.randn(feat_dim, queue_len))
        self.queue = nn.functional.normalize(self.queue, dim=0)
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None:
        """Update queue."""
        # gather keys before updating queue
        keys = torch.cat(all_gather(keys), dim=0)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
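        # the queue is a fixed-size ring buffer; requiring queue_len to be a
        # multiple of the gathered batch size lets every update overwrite one
        # contiguous slice without wrapping around the end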
        assert self.queue_len % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1)
        ptr = (ptr + batch_size) % self.queue_len  # move pointer

        self.queue_ptr[0] = ptr

    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        assert isinstance(inputs, list)
        im_q = inputs[0]
        im_k = inputs[1]
        # compute query features from encoder_q
        q = self.neck(self.backbone(im_q))[0]  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            # update the key encoder
            self.encoder_k.update_parameters(
                nn.Sequential(self.backbone, self.neck))

            # shuffle for making use of BN
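            # (shuffling the key batch across GPUs keeps per-GPU BN statistics
            # from leaking which keys share a device with which queries, which
            # would let the encoder cheat the contrastive task)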
            im_k, idx_unshuffle = batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)[0]  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
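        # the head is expected to turn (l_pos, l_neg) into an InfoNCE-style
        # cross-entropy loss; with a ContrastiveHead this includes temperature
        # scaling of the concatenated logits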

        loss = self.head.loss(l_pos, l_neg)
        # update the queue
        self._dequeue_and_enqueue(k)

        losses = dict(loss=loss)
        return losses