# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn

from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import CosineEMA
from .base import BaseSelfSupervisor


@MODELS.register_module()
class BYOL(BaseSelfSupervisor):
    """BYOL.

    Implementation of `Bootstrap Your Own Latent: A New Approach to
    Self-Supervised Learning <https://arxiv.org/abs/2006.07733>`_.

    The online branch is ``backbone -> neck``. A momentum ("target") copy
    of that same pair is kept in ``self.target_net`` and is refreshed from
    the online branch on every call to :meth:`loss`.

    Args:
        backbone (dict): Config dict for module of backbone.
        neck (dict): Config dict for module of deep features to compact
            feature vectors.
        head (dict): Config dict for module of head functions.
        base_momentum (float): The base momentum coefficient for the
            target network. Defaults to 0.004.
        pretrained (str, optional): The pretrained checkpoint path, support
            local path and remote path. Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing
            input data. If None or no specified type, it will use
            "SelfSupDataPreprocessor" as type.
            See :class:`SelfSupDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (Union[List[dict], dict], optional): Config dict for
            weight initialization. Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 base_momentum: float = 0.004,
                 pretrained: Optional[str] = None,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # The momentum model wraps the online backbone + neck pair; only the
        # wrapper is stored on ``self`` (the Sequential is a throwaway view).
        online_branch = nn.Sequential(self.backbone, self.neck)
        self.target_net = CosineEMA(online_branch, momentum=base_momentum)

    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images. The first two
                entries are used as the two views.
            data_samples (List[DataSample]): All elements required
                during the forward function. Not read by this method.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components,
            holding a single ``loss`` entry.
        """
        assert isinstance(inputs, list)
        view_a, view_b = inputs[0], inputs[1]

        # Online branch: backbone -> neck, keeping the first neck output.
        online_feat_a = self.backbone(view_a)
        online_proj_a = self.neck(online_feat_a)[0]
        online_feat_b = self.backbone(view_b)
        online_proj_b = self.neck(online_feat_b)[0]

        # Target branch: no gradients flow through it.
        with torch.no_grad():
            # Refresh the target weights from the current online weights
            # before the target projections are computed.
            online_branch = nn.Sequential(self.backbone, self.neck)
            self.target_net.update_parameters(online_branch)

            target_proj_a = self.target_net(view_a)[0]
            target_proj_b = self.target_net(view_b)[0]

        # Symmetric objective: each online view is matched against the
        # target projection of the *other* view.
        loss_a_to_b = self.head.loss(online_proj_a, target_proj_b)
        loss_b_to_a = self.head.loss(online_proj_b, target_proj_a)

        return {'loss': 2. * (loss_a_to_b + loss_b_to_a)}