# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from mmengine.dist import all_reduce
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS


@torch.no_grad()
def distributed_sinkhorn(out: torch.Tensor, sinkhorn_iterations: int,
world_size: int, epsilon: float) -> torch.Tensor:
"""Apply the distributed sinknorn optimization on the scores matrix to find
the assignments.
This function is modified from
https://github.com/facebookresearch/swav/blob/main/main_swav.py
Args:
out (torch.Tensor): The scores matrix
sinkhorn_iterations (int): Number of iterations in Sinkhorn-Knopp
algorithm.
world_size (int): The world size of the process group.
epsilon (float): regularization parameter for Sinkhorn-Knopp algorithm.
Returns:
torch.Tensor: Output of sinkhorn algorithm.
"""
eps_num_stab = 1e-12
    # Q is K-by-B for consistency with notations from the SwAV paper
    Q = torch.exp(out / epsilon).t()
B = Q.shape[1] * world_size # number of samples to assign
K = Q.shape[0] # how many prototypes
    # make the matrix sum to 1
sum_Q = torch.sum(Q)
all_reduce(sum_Q)
Q /= sum_Q
for it in range(sinkhorn_iterations):
# normalize each row: total weight per prototype must be 1/K
u = torch.sum(Q, dim=1, keepdim=True)
        if torch.any(u == 0):
Q += eps_num_stab
u = torch.sum(Q, dim=1, keepdim=True, dtype=Q.dtype)
all_reduce(u)
Q /= u
Q /= K
        # normalize each column: total weight per sample must be 1/B
        Q /= torch.sum(Q, dim=0, keepdim=True)
        Q /= B

    Q *= B  # the columns must sum to 1 so that Q is an assignment
    return Q.t()
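

# A minimal single-process sanity check for ``distributed_sinkhorn``
# (illustrative, not part of the original file). It assumes ``world_size=1``,
# in which case ``mmengine.dist.all_reduce`` is a no-op, and verifies that the
# result is a soft assignment: each sample's row sums to 1.
def _demo_distributed_sinkhorn() -> None:
    scores = torch.randn(8, 16)  # 8 samples, 16 prototypes
    q = distributed_sinkhorn(
        scores, sinkhorn_iterations=3, world_size=1, epsilon=0.05)
    assert tuple(q.shape) == (8, 16)
    # the columns of Q sum to 1, so after the final transpose each sample's
    # assignment distribution over the prototypes sums to 1
    assert torch.allclose(q.sum(dim=1), torch.ones(8), atol=1e-4)

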
class MultiPrototypes(BaseModule):
"""Multi-prototypes for SwAV head.
Args:
output_dim (int): The output dim from SwAV neck.
num_prototypes (List[int]): The number of prototypes needed.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
output_dim: int,
num_prototypes: List[int],
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(init_cfg=init_cfg)
assert isinstance(num_prototypes, list)
self.num_heads = len(num_prototypes)
for i, k in enumerate(num_prototypes):
self.add_module('prototypes' + str(i),
                            nn.Linear(output_dim, k, bias=False))

def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Run forward for every prototype."""
out = []
for i in range(self.num_heads):
out.append(getattr(self, 'prototypes' + str(i))(x))
return out
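

# Illustrative usage sketch for ``MultiPrototypes`` (not part of the original
# file): with ``num_prototypes=[3000, 2000]``, the module holds two
# independent linear prototype heads and returns one score matrix per head.
def _demo_multi_prototypes() -> None:
    heads = MultiPrototypes(output_dim=128, num_prototypes=[3000, 2000])
    feats = torch.randn(4, 128)
    outs = heads(feats)
    assert [tuple(o.shape) for o in outs] == [(4, 3000), (4, 2000)]

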
@MODELS.register_module()
class SwAVLoss(BaseModule):
"""The Loss for SwAV.
This Loss contains clustering and sinkhorn algorithms to compute Q codes.
Part of the code is borrowed from `script
<https://github.com/facebookresearch/swav>`_.
The queue is built in `engine/hooks/swav_hook.py`.
Args:
feat_dim (int): feature dimension of the prototypes.
sinkhorn_iterations (int): number of iterations in Sinkhorn-Knopp
algorithm. Defaults to 3.
epsilon (float): regularization parameter for Sinkhorn-Knopp algorithm.
Defaults to 0.05.
temperature (float): temperature parameter in training loss.
Defaults to 0.1.
crops_for_assign (List[int]): list of crops id used for computing
assignments. Defaults to [0, 1].
num_crops (List[int]): list of number of crops. Defaults to [2].
num_prototypes (int): number of prototypes. Defaults to 3000.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
feat_dim: int,
sinkhorn_iterations: int = 3,
epsilon: float = 0.05,
temperature: float = 0.1,
crops_for_assign: List[int] = [0, 1],
num_crops: List[int] = [2],
num_prototypes: int = 3000,
init_cfg: Optional[Union[List[dict], dict]] = None):
super().__init__(init_cfg=init_cfg)
self.sinkhorn_iterations = sinkhorn_iterations
self.epsilon = epsilon
self.temperature = temperature
self.crops_for_assign = crops_for_assign
self.num_crops = num_crops
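        # the queue itself is allocated externally by the SwAV hook
        # (`engine/hooks/swav_hook.py`); it caches embeddings from past
        # batches to enlarge the assignment problem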
self.use_queue = False
self.queue = None
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
# prototype layer
self.prototypes = None
if isinstance(num_prototypes, list):
self.prototypes = MultiPrototypes(feat_dim, num_prototypes)
elif num_prototypes > 0:
self.prototypes = nn.Linear(feat_dim, num_prototypes, bias=False)
        assert self.prototypes is not None

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function of SwAV loss.
Args:
x (torch.Tensor): NxC input features.
Returns:
torch.Tensor: The returned loss.
"""
# normalize the prototypes
with torch.no_grad():
w = self.prototypes.weight.data.clone()
w = nn.functional.normalize(w, dim=1, p=2)
self.prototypes.weight.copy_(w)
embedding, output = x, self.prototypes(x)
embedding = embedding.detach()
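        # per-crop batch size: the input stacks the features of all crops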
bs = int(embedding.size(0) / sum(self.num_crops))
loss = 0
for i, crop_id in enumerate(self.crops_for_assign):
with torch.no_grad():
out = output[bs * crop_id:bs * (crop_id + 1)].detach()
                # time to use the queue
                if self.queue is not None:
                    if self.use_queue or not torch.all(
                            self.queue[i, -1, :] == 0):
                        self.use_queue = True
                        # prepend queued embeddings, scored against the
                        # current prototypes, so small batches still provide
                        # enough samples for a stable assignment
                        out = torch.cat(
                            (torch.mm(self.queue[i],
                                      self.prototypes.weight.t()), out))
                    # fill the queue: shift older entries back and store
                    # the newest embeddings at the front (FIFO)
                    self.queue[i, bs:] = self.queue[i, :-bs].clone()
                    self.queue[i, :bs] = \
                        embedding[crop_id * bs:(crop_id + 1) * bs]
# get assignments (batch_size * num_prototypes)
q = distributed_sinkhorn(out, self.sinkhorn_iterations,
self.world_size, self.epsilon)[-bs:]
# cluster assignment prediction
subloss = 0
for v in np.delete(np.arange(np.sum(self.num_crops)), crop_id):
x = output[bs * v:bs * (v + 1)] / self.temperature
subloss -= torch.mean(
torch.sum(q * nn.functional.log_softmax(x, dim=1), dim=1))
loss += subloss / (np.sum(self.num_crops) - 1)
loss /= len(self.crops_for_assign)
return loss
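

# A minimal end-to-end sketch of ``SwAVLoss`` (illustrative; the feature
# dimension, prototype count, and batch size below are assumptions, not
# values from a real config). With the default ``num_crops=[2]`` and a batch
# size of 4, the input stacks the features of all crops into ``2 * 4 = 8``
# rows. The queue stays ``None``, as it is before the SwAV hook builds it.
if __name__ == '__main__':
    loss_module = SwAVLoss(feat_dim=128, num_prototypes=100)
    multi_crop_feats = nn.functional.normalize(torch.randn(8, 128), dim=1)
    loss = loss_module(multi_crop_feats)
    print(f'SwAV loss: {loss.item():.4f}')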