# NOTE: web-page residue ("Spaces: Runtime error"), not part of the module.
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union | |
import torch | |
import torch.nn as nn | |
from mmpretrain.registry import MODELS | |
from mmpretrain.structures import DataSample | |
from ..utils import CosineEMA | |
from .base import BaseSelfSupervisor | |
# NOTE(review): ``MODELS`` is imported by this file but no registry decorator
# is visible on this class -- confirm whether ``@MODELS.register_module()``
# is expected here.
class BYOL(BaseSelfSupervisor):
    """BYOL self-supervised learner.

    Implementation of `Bootstrap Your Own Latent: A New Approach to
    Self-Supervised Learning <https://arxiv.org/abs/2006.07733>`_.

    The online branch is ``backbone`` followed by ``neck``. The target
    branch is a :class:`CosineEMA` instance constructed from those same two
    modules and refreshed on every call to :meth:`loss`.

    Args:
        backbone (dict): Config dict for module of backbone.
        neck (dict): Config dict for module of deep features
            to compact feature vectors.
        head (dict): Config dict for module of head functions.
        base_momentum (float): The base momentum coefficient for the target
            network. Defaults to 0.004.
        pretrained (str, optional): The pretrained checkpoint path, support
            local path and remote path. Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing
            input data. If None or no specified type, it will use
            "SelfSupDataPreprocessor" as type.
            See :class:`SelfSupDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (Union[List[dict], dict], optional): Config dict for weight
            initialization. Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 base_momentum: float = 0.004,
                 pretrained: Optional[str] = None,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # The momentum (target) model is built from the online
        # backbone + neck pipeline.
        online_net = nn.Sequential(self.backbone, self.neck)
        self.target_net = CosineEMA(online_net, momentum=base_momentum)

    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """Compute the training loss from two views of the input.

        Args:
            inputs (List[torch.Tensor]): The input images. Elements 0 and 1
                are used as the two views.
            data_samples (List[DataSample]): All elements required
                during the forward function. Not read by this method.

        Returns:
            Dict[str, torch.Tensor]: A dictionary with the single key
            ``loss``.
        """
        assert isinstance(inputs, list)
        view_a, view_b = inputs[0], inputs[1]

        # Online projections: first element of the neck output per view.
        online_a = self.neck(self.backbone(view_a))[0]
        online_b = self.neck(self.backbone(view_b))[0]

        # Target projections are produced without gradient tracking.
        with torch.no_grad():
            # Refresh the target net from the current online modules
            # before it is evaluated on the two views.
            online_net = nn.Sequential(self.backbone, self.neck)
            self.target_net.update_parameters(online_net)

            target_a = self.target_net(view_a)[0]
            target_b = self.target_net(view_b)[0]

        # Cross pairing: each online view is matched with the target
        # projection of the other view.
        loss_a_to_b = self.head.loss(online_a, target_b)
        loss_b_to_a = self.head.loss(online_b, target_a)

        return {'loss': 2. * (loss_a_to_b + loss_b_to_a)}