| | import torch |
| | import torch.nn as nn |
| | import utils |
| |
|
| | from utils import trunc_normal_ |
| |
|
class CSyncBatchNorm(nn.SyncBatchNorm):
    """SyncBatchNorm variant that normalises with the *running* statistics.

    Every forward normalises ``x`` with the running statistics accumulated so far
    (i.e. an eval-mode batch-norm pass). When ``with_var`` is False the running
    variance is reset to ones first, so the layer only subtracts the running mean
    (then applies the affine transform, if enabled).

    Only when the module is in training mode is a second, training-mode pass run
    on the same input. Its output is discarded; it exists solely to update the
    running statistics from the current batch. The mode flag is always restored,
    so calling the module after ``.eval()`` neither updates the statistics nor
    switches it back to training mode.

    Args:
        *args, **kwargs: forwarded to ``nn.SyncBatchNorm``.
        with_var: if True, also divide by the running variance.

    NOTE(review): with ``with_var=False`` this needs ``running_var`` to be a tensor,
    i.e. ``track_running_stats=True``.
    """

    def __init__(self,
                 *args,
                 with_var=False,
                 **kwargs):
        super(CSyncBatchNorm, self).__init__(*args, **kwargs)
        self.with_var = with_var

    def forward(self, x):
        was_training = self.training
        try:
            # Normalise with the statistics gathered *before* this batch.
            self.training = False
            if not self.with_var:
                self.running_var = torch.ones_like(self.running_var)
            normed_x = super(CSyncBatchNorm, self).forward(x)

            if was_training:
                # Training-mode pass only to update the running stats; output unused.
                self.training = True
                super(CSyncBatchNorm, self).forward(x)
        finally:
            # Never leave the mode flag different from what the caller set.
            self.training = was_training
        return normed_x
| |
|
class PSyncBatchNorm(nn.SyncBatchNorm):
    """SyncBatchNorm that synchronises statistics only within a "bunch" of ranks.

    The ranks ``0 .. world_size - 1`` are split into consecutive, equally sized
    groups of ``min(bunch_size, world_size)`` ranks. This layer's ``process_group``
    is the group containing the current rank (``utils.get_rank()``).

    Args:
        *args, **kwargs: forwarded to ``nn.SyncBatchNorm``. ``process_group`` is
            supplied here and must not be passed by the caller.
        bunch_size: requested number of ranks per group (keyword-only, >= 1).

    Raises:
        ValueError: if ``bunch_size < 1`` or the world size is not a multiple of
            the resulting group size.
    """

    def __init__(self,
                 *args,
                 bunch_size,
                 **kwargs):
        world_size = utils.get_world_size()
        if bunch_size < 1:
            raise ValueError("bunch_size must be >= 1, got {}".format(bunch_size))
        procs_per_bunch = min(bunch_size, world_size)
        if world_size % procs_per_bunch != 0:
            raise ValueError(
                "world size {} is not divisible by bunch size {}".format(world_size, procs_per_bunch))
        n_bunch = world_size // procs_per_bunch

        ranks = list(range(world_size))
        print('---ALL RANKS----\n{}'.format(ranks))
        rank_groups = [ranks[i * procs_per_bunch:(i + 1) * procs_per_bunch] for i in range(n_bunch)]
        print('---RANK GROUPS----\n{}'.format(rank_groups))
        # Every rank creates every group, in the same order; torch.distributed.new_group
        # must be entered by all processes even for groups they do not belong to.
        process_groups = [torch.distributed.new_group(pids) for pids in rank_groups]
        bunch_id = utils.get_rank() // procs_per_bunch
        process_group = process_groups[bunch_id]
        print('---CURRENT GROUP----\n{}'.format(process_group))
        super(PSyncBatchNorm, self).__init__(*args, process_group=process_group, **kwargs)
| |
|
class CustomSequential(nn.Sequential):
    """``nn.Sequential`` that lets BatchNorm layers consume channel-last tensors.

    Two cases are passed through unchanged: modules that are not in ``bn_types``,
    and tensors with at most two dimensions.

    Otherwise (a module in ``bn_types`` receiving a tensor with more than two
    dimensions):

    1. The last axis is moved to position 1, where BatchNorm expects channels.
    2. The module is applied.
    3. The axes are moved back to the original layout.
    """

    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

    def forward(self, input):
        for layer in self:
            ndim = len(input.shape)
            if ndim <= 2 or not isinstance(layer, self.bn_types):
                input = layer(input)
                continue
            # (N, d1, ..., dk, C) -> (N, C, d1, ..., dk)
            channels_first = [0, ndim - 1, *range(1, ndim - 1)]
            # (N, C, d1, ..., dk) -> (N, d1, ..., dk, C)
            channels_last = [0, *range(2, ndim), 1]
            input = layer(input.permute(*channels_first)).permute(*channels_last)
        return input
| |
|
class DINOHead(nn.Module):
    """DINO-style projection head.

    Layout:

    - An MLP trunk of ``nlayers`` Linear layers. Each hidden layer is optionally
      followed by a norm and always by the activation.
    - When ``bottleneck_dim > 0``: L2 normalisation over the last dimension, then a
      bias-free, weight-normalised Linear projection to ``out_dim``.
    - Finally, an optional ``last_norm``.

    Args:
        in_dim: input feature size (Linear layers act on the last dimension).
        out_dim: output feature size.
        norm: hidden-layer norm type: 'bn', 'syncbn', 'csyncbn', 'psyncbn', 'ln' or
            None. Each hidden layer gets its own norm instance. The hidden norm is
            built without ``**kwargs``, so 'psyncbn' (which needs ``bunch_size``)
            cannot be used here.
        act: hidden activation, 'relu' or 'gelu'.
        last_norm: norm type applied to the final output, built with
            ``affine=False`` and ``**kwargs``; None disables it.
        nlayers: number of Linear layers in the trunk (values < 1 are treated as 1).
        hidden_dim: width of the hidden layers.
        bottleneck_dim: trunk output width. If <= 0 the trunk maps straight to
            ``out_dim`` and ``last_layer`` is None.
        norm_last_layer: if True, the weight-norm magnitude ``weight_g`` of
            ``last_layer`` (initialised to 1) is frozen.
        **kwargs: forwarded only to the ``last_norm`` constructor (e.g. ``bunch_size``
            for 'psyncbn', ``with_var`` for 'csyncbn').

    Raises:
        ValueError: on an unknown ``norm`` / ``last_norm`` / ``act`` type.
    """

    def __init__(self, in_dim, out_dim, norm=None, act='gelu', last_norm=None,
                 nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True, **kwargs):
        super().__init__()
        norm_type = norm
        # Built up front (even when nlayers == 1) so an unknown type is rejected
        # immediately and norms are constructed in the same order as before.
        first_norm = self._build_norm(norm_type, hidden_dim)
        last_norm = self._build_norm(last_norm, out_dim, affine=False, **kwargs)
        act = self._build_act(act)

        nlayers = max(nlayers, 1)
        mlp_out_dim = bottleneck_dim if bottleneck_dim > 0 else out_dim
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, mlp_out_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if first_norm is not None:
                layers.append(first_norm)
            layers.append(act)
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if norm_type is not None:
                    # A fresh instance per hidden layer: reusing one module would make
                    # all hidden layers share affine parameters and running statistics.
                    layers.append(self._build_norm(norm_type, hidden_dim))
                layers.append(act)
            layers.append(nn.Linear(hidden_dim, mlp_out_dim))
            self.mlp = CustomSequential(*layers)
        # Initialise the trunk only; last_layer below keeps its own initialisation.
        self.apply(self._init_weights)

        if bottleneck_dim > 0:
            self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
            self.last_layer.weight_g.data.fill_(1)
            if norm_last_layer:
                self.last_layer.weight_g.requires_grad = False
        else:
            self.last_layer = None

        self.last_norm = last_norm

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) weights and zero biases for Linear modules."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Apply trunk, optional normalise+project, and optional ``last_norm``."""
        x = self.mlp(x)
        if self.last_layer is not None:
            x = nn.functional.normalize(x, dim=-1, p=2)
            x = self.last_layer(x)
        if self.last_norm is not None:
            x = self.last_norm(x)
        return x

    def _build_norm(self, norm, hidden_dim, **kwargs):
        """Build a norm module of type ``norm`` over ``hidden_dim`` features.

        Returns None when ``norm`` is None. For 'ln', an ``affine`` keyword is
        translated to LayerNorm's ``elementwise_affine``.

        Raises:
            ValueError: if ``norm`` is not a recognised type.
        """
        if norm is None:
            return None
        if norm == 'bn':
            return nn.BatchNorm1d(hidden_dim, **kwargs)
        if norm == 'syncbn':
            return nn.SyncBatchNorm(hidden_dim, **kwargs)
        if norm == 'csyncbn':
            return CSyncBatchNorm(hidden_dim, **kwargs)
        if norm == 'psyncbn':
            return PSyncBatchNorm(hidden_dim, **kwargs)
        if norm == 'ln':
            # nn.LayerNorm has no `affine` argument; its equivalent is `elementwise_affine`.
            if 'affine' in kwargs:
                kwargs['elementwise_affine'] = kwargs.pop('affine')
            return nn.LayerNorm(hidden_dim, **kwargs)
        raise ValueError("unknown norm type {}".format(norm))

    def _build_act(self, act):
        """Build the activation module for ``act`` ('relu' or 'gelu').

        Raises:
            ValueError: if ``act`` is not a recognised type.
        """
        if act == 'relu':
            return nn.ReLU()
        if act == 'gelu':
            return nn.GELU()
        raise ValueError("unknown act type {}".format(act))
| |
|
class iBOTHead(DINOHead):
    """DINOHead with a second output branch for tokens after index 0.

    For 2-D input this behaves exactly like ``DINOHead``. For higher-rank input it
    indexes ``x[:, 0]`` and ``x[:, 1:]``. That is presumably (batch, tokens, dim)
    with the [CLS] token first — confirm against callers. It returns ``(x1, x2)``:

    - ``x1``: the head output for token 0.
    - ``x2``: the output for the remaining tokens. Without ``shared_head`` this is
      produced by a separate projection to ``patch_out_dim``.

    Args:
        *args: ``in_dim``, ``out_dim`` for ``DINOHead``.
        patch_out_dim: output size of the second branch (ignored when ``shared_head``).
        shared_head: if True, the second branch reuses the first branch's final
            projection and ``last_norm``.
        norm, act, last_norm, nlayers, hidden_dim, bottleneck_dim, norm_last_layer,
        **kwargs: as for ``DINOHead``.
    """

    def __init__(self, *args, patch_out_dim=8192, norm=None, act='gelu', last_norm=None,
                 nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True,
                 shared_head=False, **kwargs):
        super(iBOTHead, self).__init__(*args,
                                       norm=norm,
                                       act=act,
                                       last_norm=last_norm,
                                       nlayers=nlayers,
                                       hidden_dim=hidden_dim,
                                       bottleneck_dim=bottleneck_dim,
                                       norm_last_layer=norm_last_layer,
                                       **kwargs)

        if not shared_head:
            if bottleneck_dim > 0:
                self.last_layer2 = nn.utils.weight_norm(
                    nn.Linear(bottleneck_dim, patch_out_dim, bias=False))
                self.last_layer2.weight_g.data.fill_(1)
                if norm_last_layer:
                    self.last_layer2.weight_g.requires_grad = False
            else:
                # Parallel to the trunk's final Linear, so take its input width from
                # that layer: hidden_dim when nlayers >= 2, in_dim when nlayers == 1.
                # NOTE(review): unlike the trunk, this layer keeps PyTorch's default
                # init (DINOHead applied _init_weights before it existed) — confirm.
                self.mlp2 = nn.Linear(self._mlp_final_layer().in_features, patch_out_dim)
                self.last_layer2 = None

            self.last_norm2 = self._build_norm(last_norm, patch_out_dim, affine=False, **kwargs)
        else:
            if bottleneck_dim > 0:
                self.last_layer2 = self.last_layer
            else:
                self.mlp2 = self._mlp_final_layer()
                self.last_layer2 = None

            self.last_norm2 = self.last_norm

    def _mlp_final_layer(self):
        """Return the trunk's final Linear (``self.mlp`` itself when it is a single Linear)."""
        return self.mlp[-1] if isinstance(self.mlp, nn.Sequential) else self.mlp

    def forward(self, x):
        if len(x.shape) == 2:
            return super(iBOTHead, self).forward(x)

        if self.last_layer is not None:
            x = self.mlp(x)
            x = nn.functional.normalize(x, dim=-1, p=2)
            x1 = self.last_layer(x[:, 0])
            x2 = self.last_layer2(x[:, 1:])
        else:
            # Run everything except the final Linear, then split the tokens.
            if isinstance(self.mlp, nn.Sequential):
                x = self.mlp[:-1](x)
            x1 = self._mlp_final_layer()(x[:, 0])
            x2 = self.mlp2(x[:, 1:])

        if self.last_norm is not None:
            x1 = self.last_norm(x1)
            x2 = self.last_norm2(x2)

        return x1, x2
| |
|
| |
|
| |
|
class TemporalSideContext(nn.Module):
    """Transformer encoder over a (batch, time, feature) sequence.

    Args:
        D: feature size (``d_model``). It must be divisible by ``n_head``.
        max_len: maximum sequence length supported by the optional positional
            embedding.
        n_layers: number of encoder layers.
        n_head: number of attention heads.
        dropout: dropout rate. The feed-forward width is ``4 * D``.
        use_pos_emb: if True, add a learned absolute positional embedding of shape
            ``(1, max_len, D)`` before the encoder.

    The default ``use_pos_emb=False`` keeps the original behaviour and state_dict.
    In that mode no positional information is injected, so the encoder is
    permutation-equivariant along the time axis.
    """

    def __init__(self, D, max_len=64, n_layers=6, n_head=8, dropout=0.1, use_pos_emb=False):
        super().__init__()
        layer = nn.TransformerEncoderLayer(D, n_head, 4 * D,
                                           dropout=dropout, batch_first=True)
        self.enc = nn.TransformerEncoder(layer, n_layers)
        self.max_len = max_len
        if use_pos_emb:
            self.pos_emb = nn.Parameter(torch.zeros(1, max_len, D))
            nn.init.trunc_normal_(self.pos_emb, std=0.02)
        else:
            self.pos_emb = None

    def forward(self, x):
        """Encode ``x`` of shape (B, T, D); returns a tensor of the same shape.

        Raises:
            ValueError: if ``x`` is not 3-D, or (with the positional embedding)
                ``T > max_len``.
        """
        if x.dim() != 3:
            raise ValueError(
                "expected a (batch, time, feature) tensor, got shape {}".format(tuple(x.shape)))
        if self.pos_emb is not None:
            T = x.shape[1]
            if T > self.max_len:
                raise ValueError(
                    "sequence length {} exceeds max_len {}".format(T, self.max_len))
            x = x + self.pos_emb[:, :T]
        return self.enc(x)
| |
|
| |
|
| |
|
class TemporalHead(nn.Module):
    """
    Converts backbone features [B,T,D] → logits [B,T,1] for Plackett–Luce.

    Pipeline:

    1. ``reduce``: Linear from ``backbone_dim`` to ``int(backbone_dim * hidden_mul)``,
       followed by GELU.
    2. ``temporal``: a ``TemporalSideContext`` encoder over the time axis.
    3. ``scorer``: two-layer MLP (GELU in between) producing one value per step.

    NOTE(review): ``TemporalSideContext`` is built with its default head count, so
    the reduced width presumably has to be divisible by that count — confirm.
    """

    def __init__(self, backbone_dim: int, hidden_mul: float = 0.5, max_len: int = 64):
        super().__init__()
        width = int(backbone_dim * hidden_mul)
        half_width = width // 2

        # Submodules are created in this order (reduce, temporal, scorer) on purpose:
        # it fixes the parameter-initialisation order.
        self.reduce = nn.Sequential(nn.Linear(backbone_dim, width), nn.GELU())
        self.temporal = TemporalSideContext(width, max_len=max_len)
        self.scorer = nn.Sequential(
            nn.Linear(width, half_width),
            nn.GELU(),
            nn.Linear(half_width, 1),
        )

    def forward(self, x: torch.Tensor):
        reduced = self.reduce(x)
        contextualised = self.temporal(reduced)
        return self.scorer(contextualised)
| |
|
| |
|
| |
|
| |
|