Spaces:
Runtime error
Runtime error
File size: 7,932 Bytes
4d0eb62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
from mmengine.dist import all_gather
from mmengine.model import ExponentialMovingAverage
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import batch_shuffle_ddp, batch_unshuffle_ddp
from .base import BaseSelfSupervisor
@MODELS.register_module()
class DenseCL(BaseSelfSupervisor):
    """DenseCL.

    Implementation of `Dense Contrastive Learning for Self-Supervised Visual
    Pre-Training <https://arxiv.org/abs/2011.09157>`_.
    Borrowed from the authors' code: `<https://github.com/WXinlong/DenseCL>`_.
    The loss_lambda warmup is in `engine/hooks/densecl_hook.py`.

    Args:
        backbone (dict): Config dict for module of backbone.
        neck (dict): Config dict for module of deep features to compact
            feature vectors.
        head (dict): Config dict for module of head functions.
        queue_len (int): Number of negative keys maintained in the queue.
            Defaults to 65536.
        feat_dim (int): Dimension of compact feature vectors. Defaults to 128.
        momentum (float): Momentum handed to
            :class:`mmengine.model.ExponentialMovingAverage` for the
            momentum-updated key encoder. Defaults to 0.001.
            NOTE(review): per the mmengine documentation this value weights
            the *online* parameters in the update, so 0.001 here plays the
            role of the 0.999 coefficient quoted in the paper - confirm
            against the installed mmengine version.
        loss_lambda (float): Loss weight balancing the two contrastive
            terms: the single (global) loss is scaled by
            ``1 - loss_lambda`` and the dense loss by ``loss_lambda``.
            Defaults to 0.5.
        pretrained (str, optional): The pretrained checkpoint path, support
            local path and remote path. Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing
            input data. If None or no specified type, it will use
            "SelfSupDataPreprocessor" as type.
            See :class:`SelfSupDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (Union[List[dict], dict], optional): Config dict for weight
            initialization. Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 queue_len: int = 65536,
                 feat_dim: int = 128,
                 momentum: float = 0.001,
                 loss_lambda: float = 0.5,
                 pretrained: Optional[str] = None,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # create momentum model (key encoder) wrapping backbone + neck
        self.encoder_k = ExponentialMovingAverage(
            nn.Sequential(self.backbone, self.neck), momentum)

        self.queue_len = queue_len
        self.loss_lambda = loss_lambda

        # create the queue for the global (single) features; every column
        # is one key, normalized to unit length
        self.register_buffer('queue', torch.randn(feat_dim, queue_len))
        self.queue = nn.functional.normalize(self.queue, dim=0)
        self.register_buffer('queue_ptr', torch.zeros(1, dtype=torch.long))

        # create the second queue for dense output
        self.register_buffer('queue2', torch.randn(feat_dim, queue_len))
        self.queue2 = nn.functional.normalize(self.queue2, dim=0)
        self.register_buffer('queue2_ptr', torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _update_queue(self, queue: torch.Tensor, queue_ptr: torch.Tensor,
                      keys: torch.Tensor) -> None:
        """Overwrite the oldest columns of ``queue`` with the gathered keys.

        Shared implementation behind :meth:`_dequeue_and_enqueue` and
        :meth:`_dequeue_and_enqueue2`. Both tensors are modified in place.

        Args:
            queue (torch.Tensor): Queue buffer; keys are stored as columns.
            queue_ptr (torch.Tensor): One-element buffer holding the column
                index at which the next write starts.
            keys (torch.Tensor): Local keys, one per row.

        Raises:
            AssertionError: If ``queue_len`` is not a multiple of the
                gathered batch size.
        """
        # gather keys before updating queue
        keys = torch.cat(all_gather(keys), dim=0)

        batch_size = keys.shape[0]
        ptr = int(queue_ptr)
        # for simplicity: a write then never has to wrap around the end
        assert self.queue_len % batch_size == 0, (
            f'queue_len ({self.queue_len}) must be divisible by the '
            f'gathered batch size ({batch_size})')

        # replace the keys at ptr (dequeue and enqueue)
        queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1)
        # move pointer
        queue_ptr[0] = (ptr + batch_size) % self.queue_len

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys: torch.Tensor) -> None:
        """Update queue."""
        self._update_queue(self.queue, self.queue_ptr, keys)

    @torch.no_grad()
    def _dequeue_and_enqueue2(self, keys: torch.Tensor) -> None:
        """Update queue2."""
        self._update_queue(self.queue2, self.queue2_ptr, keys)

    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images. ``inputs[0]`` is
                used as the query view and ``inputs[1]`` as the key view.
            data_samples (List[DataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components with
            the keys ``loss_single`` and ``loss_dense``.
        """
        assert isinstance(inputs, list)
        im_q = inputs[0]
        im_k = inputs[1]

        # compute query features
        q_b = self.backbone(im_q)  # backbone features
        q, q_grid, q2 = self.neck(q_b)  # queries: NxC; NxCxS^2
        # keep the first backbone output and flatten its trailing dims
        q_b = q_b[0]
        q_b = q_b.view(q_b.size(0), q_b.size(1), -1)

        q = nn.functional.normalize(q, dim=1)
        q2 = nn.functional.normalize(q2, dim=1)
        q_grid = nn.functional.normalize(q_grid, dim=1)
        q_b = nn.functional.normalize(q_b, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            # update the key encoder
            self.encoder_k.update_parameters(
                nn.Sequential(self.backbone, self.neck))

            # shuffle for making use of BN
            im_k, idx_unshuffle = batch_shuffle_ddp(im_k)

            k_b = self.encoder_k.module[0](im_k)  # backbone features
            # keys: NxC; NxCxS^2
            k, k_grid, k2 = self.encoder_k.module[1](k_b)
            k_b = k_b[0]
            k_b = k_b.view(k_b.size(0), k_b.size(1), -1)

            k = nn.functional.normalize(k, dim=1)
            k2 = nn.functional.normalize(k2, dim=1)
            k_grid = nn.functional.normalize(k_grid, dim=1)
            k_b = nn.functional.normalize(k_b, dim=1)

            # undo shuffle
            k = batch_unshuffle_ddp(k, idx_unshuffle)
            k2 = batch_unshuffle_ddp(k2, idx_unshuffle)
            k_grid = batch_unshuffle_ddp(k_grid, idx_unshuffle)
            k_b = batch_unshuffle_ddp(k_b, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # feat point set sim: similarity between every query location and
        # every key location of the backbone features
        backbone_sim_matrix = torch.matmul(q_b.permute(0, 2, 1), k_b)
        # for each query location, index of the most similar key location
        densecl_sim_ind = backbone_sim_matrix.max(dim=2)[1]  # NxS^2

        # pick the key grid feature at the matched location
        indexed_k_grid = torch.gather(k_grid, 2,
                                      densecl_sim_ind.unsqueeze(1).expand(
                                          -1, k_grid.size(1), -1))  # NxCxS^2
        densecl_sim_q = (q_grid * indexed_k_grid).sum(1)  # NxS^2

        # dense positive logits: NS^2X1
        l_pos_dense = densecl_sim_q.view(-1).unsqueeze(-1)

        # flatten the query grid so every location becomes one row
        q_grid = q_grid.permute(0, 2, 1)
        q_grid = q_grid.reshape(-1, q_grid.size(2))
        # dense negative logits: NS^2xK
        l_neg_dense = torch.einsum(
            'nc,ck->nk', [q_grid, self.queue2.clone().detach()])

        loss_single = self.head.loss(l_pos, l_neg)
        loss_dense = self.head.loss(l_pos_dense, l_neg_dense)

        losses = dict()
        losses['loss_single'] = loss_single * (1 - self.loss_lambda)
        losses['loss_dense'] = loss_dense * self.loss_lambda

        # enqueue the new keys only after the logits have been computed
        self._dequeue_and_enqueue(k)
        self._dequeue_and_enqueue2(k2)

        return losses
|