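"""DINO-style distillation loss: cross-entropy between student and teacher
softmax outputs, with a teacher "center" maintained by an (optionally
asynchronous) exponential moving average, plus an alternative Sinkhorn-Knopp
normalization of the teacher outputs."""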
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn


class DINOLoss(nn.Module):
    def __init__(
        self,
        out_dim,
        student_temp=0.1,
        center_momentum=0.9,
    ):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        # Running estimate of the mean teacher output, used to center the
        # teacher logits before sharpening them with the softmax.
        self.register_buffer("center", torch.zeros(1, out_dim))
        # State for the asynchronous center update: reduce_center_update()
        # launches an all-reduce, apply_center_update() consumes it lazily.
        self.updated = True
        self.reduce_handle = None
        self.len_teacher_output = None
        self.async_batch_center = None

    @torch.no_grad()
    def softmax_center_teacher(self, teacher_output, teacher_temp):
        self.apply_center_update()
        # Center the teacher logits with the running mean, then sharpen them
        # with the teacher temperature.
        return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1)

    @torch.no_grad()
    def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3):
        teacher_output = teacher_output.float()
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        Q = torch.exp(teacher_output / teacher_temp).t()  # Q is K x B_local
        B = Q.shape[1] * world_size  # total number of samples across all ranks
        K = Q.shape[0]  # number of prototypes

        # Make the matrix sum to 1.
        sum_Q = torch.sum(Q)
        if dist.is_initialized():
            dist.all_reduce(sum_Q)
        Q /= sum_Q

        for it in range(n_iterations):
            # Normalize each row: total weight per prototype must be 1/K.
            sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
            if dist.is_initialized():
                dist.all_reduce(sum_of_rows)
            Q /= sum_of_rows
            Q /= K

            # Normalize each column: total weight per sample must be 1/B.
            Q /= torch.sum(Q, dim=0, keepdim=True)
            Q /= B

        Q *= B  # the columns must sum to 1 so that Q is an assignment
        return Q.t()

    def forward(self, student_output_list, teacher_out_softmaxed_centered_list):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        """
        total_loss = 0
        for s in student_output_list:
            lsm = F.log_softmax(s / self.student_temp, dim=-1)
            # Accumulate the cross-entropy over every (student, teacher) crop pair.
            for t in teacher_out_softmaxed_centered_list:
                loss = torch.sum(t * lsm, dim=-1)
                total_loss -= loss.mean()
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        # Schedule the center update; it is applied lazily by
        # softmax_center_teacher() via apply_center_update().
        self.reduce_center_update(teacher_output)

    @torch.no_grad()
    def reduce_center_update(self, teacher_output):
        self.updated = False
        self.len_teacher_output = len(teacher_output)
        self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        if dist.is_initialized():
            # Launch the all-reduce asynchronously so it can overlap with other
            # work; apply_center_update() waits on the handle before using it.
            self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)

    @torch.no_grad()
    def apply_center_update(self):
        if self.updated is False:
            world_size = dist.get_world_size() if dist.is_initialized() else 1

            if self.reduce_handle is not None:
                self.reduce_handle.wait()
            # Mean teacher output over the global batch.
            _t = self.async_batch_center / (self.len_teacher_output * world_size)

            # Exponential moving average update of the center.
            self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)

            self.updated = True
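
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It exercises the
# softmax-centering path and the Sinkhorn-Knopp path on random tensors in a
# single, non-distributed process. The sizes (batch 8, out_dim 16, two student
# crops) and the teacher temperature of 0.04 are illustrative assumptions,
# not values prescribed by this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    loss_fn = DINOLoss(out_dim=16)
    student = [torch.randn(8, 16) for _ in range(2)]  # two student crops
    teacher = torch.randn(8, 16)

    # Kick off the center update (synchronous here, since torch.distributed
    # is not initialized), then build the centered, sharpened teacher targets.
    loss_fn.update_center(teacher)
    teacher_probs = loss_fn.softmax_center_teacher(teacher, teacher_temp=0.04)
    loss = loss_fn(student, [teacher_probs])
    print(f"softmax-centered loss: {loss.item():.4f}")

    # Alternative teacher normalization: each row of the returned assignment
    # matrix sums to 1, so it can stand in for the softmaxed teacher targets.
    sk_probs = loss_fn.sinkhorn_knopp_teacher(teacher, teacher_temp=0.04)
    loss_sk = loss_fn(student, [sk_probs])
    print(f"sinkhorn-knopp loss: {loss_sk.item():.4f}")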