# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn

logger = logging.getLogger("dinov2")


try:
    from xformers.ops import cross_entropy

    def lossfunc(t, s, temp):
        s = s.float()
        t = t.float()
        if s.ndim == 2:
            return -cross_entropy(s.unsqueeze(0), t.unsqueeze(0), temp, bw_inplace=True).squeeze(0)
        elif s.ndim == 3:
            return -cross_entropy(s, t, temp, bw_inplace=True)
        raise ValueError(f"lossfunc expects a 2D or 3D student tensor, got ndim={s.ndim}")

except ImportError:

    def lossfunc(t, s, temp):
        return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1)


class iBOTPatchLoss(nn.Module):
    def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, 1, patch_out_dim))
        self.updated = True
        self.reduce_handle = None
        self.len_teacher_patch_tokens = None
        self.async_batch_center = None

    @torch.no_grad()
    def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp):
        self.apply_center_update()
        # teacher centering and sharpening
        #
        # WARNING:
        #   as self.center is a float32, everything gets casted to float32 afterwards
        #
        # teacher_patch_tokens = teacher_patch_tokens.float()
        # return F.softmax((teacher_patch_tokens.sub_(self.center.to(teacher_patch_tokens.dtype))).mul_(1 / teacher_temp), dim=-1)

        return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1)

        # this is experimental, keep everything in float16 and let's see what happens:
        # return F.softmax((teacher_patch_tokens.sub_(self.center)) / teacher_temp, dim=-1)

    @torch.no_grad()
    def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3):
        teacher_output = teacher_output.float()
        # world_size = dist.get_world_size() if dist.is_initialized() else 1
        Q = torch.exp(teacher_output / teacher_temp).t()  # Q is K-by-B for consistency with notations from our paper
        # B = Q.shape[1] * world_size  # number of samples to assign
        B = n_masked_patches_tensor
        if dist.is_initialized():
            dist.all_reduce(B)
        K = Q.shape[0]  # how many prototypes

        # make the matrix sum to 1
        sum_Q = torch.sum(Q)
        if dist.is_initialized():
            dist.all_reduce(sum_Q)
        Q /= sum_Q

        for it in range(n_iterations):
            # normalize each row: total weight per prototype must be 1/K
            sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
            if dist.is_initialized():
                dist.all_reduce(sum_of_rows)
            Q /= sum_of_rows
            Q /= K

            # normalize each column: total weight per sample must be 1/B
            Q /= torch.sum(Q, dim=0, keepdim=True)
            Q /= B

        Q *= B  # the columns must sum to 1 so that Q is an assignment
        return Q.t()
    def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.

        student_patch_tokens: (B, N, D) tensor
        teacher_patch_tokens: (B, N, D) tensor
        student_masks_flat: (B, N) tensor
        """
        t = teacher_patch_tokens
        s = student_patch_tokens
        loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
        loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0)
        return -loss.mean()

    def forward_masked(
        self,
        student_patch_tokens_masked,
        teacher_patch_tokens_masked,
        student_masks_flat,
        n_masked_patches=None,
        masks_weight=None,
    ):
        t = teacher_patch_tokens_masked
        s = student_patch_tokens_masked
        # loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
        loss = lossfunc(t, s, self.student_temp)
        if masks_weight is None:
            # weight each masked patch by the inverse of the number of masked
            # patches in its image, so that every image contributes equally
            masks_weight = (
                (1 / student_masks_flat.sum(-1).clamp(min=1.0))
                .unsqueeze(-1)
                .expand_as(student_masks_flat)[student_masks_flat]
            )
        if n_masked_patches is not None:
            loss = loss[:n_masked_patches]
        loss = loss * masks_weight
        return -loss.sum() / student_masks_flat.shape[0]

    @torch.no_grad()
    def update_center(self, teacher_patch_tokens):
        self.reduce_center_update(teacher_patch_tokens)

    @torch.no_grad()
    def reduce_center_update(self, teacher_patch_tokens):
        self.updated = False
        self.len_teacher_patch_tokens = len(teacher_patch_tokens)
        self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True)
        if dist.is_initialized():
            self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)

    @torch.no_grad()
    def apply_center_update(self):
        if self.updated is False:
            world_size = dist.get_world_size() if dist.is_initialized() else 1

            if self.reduce_handle is not None:
                self.reduce_handle.wait()
            _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size)

            self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)

            self.updated = True
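
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): the shapes, the
# teacher temperature, and the single-process flow below are illustrative
# assumptions. In DINOv2 this loss is driven by the training code, with
# torch.distributed initialized and the teacher/student heads producing the
# prototype logits.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, N, D = 4, 16, 1024  # assumed batch size, patches per image, prototype count
    loss_fn = iBOTPatchLoss(patch_out_dim=D)

    student_logits = torch.randn(B, N, D)
    teacher_logits = torch.randn(B, N, D)
    masks = torch.rand(B, N) > 0.5  # bool mask of patches hidden from the student

    # center + sharpen the teacher logits, then compute the patch-level loss
    teacher_probs = loss_fn.softmax_center_teacher(teacher_logits, teacher_temp=0.07)
    loss = loss_fn(student_logits, teacher_probs, masks)

    # keep the EMA center up to date with this batch of teacher logits;
    # it is applied lazily on the next softmax_center_teacher call
    loss_fn.update_center(teacher_logits)
    print(f"iBOT patch loss: {loss.item():.4f}")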