from itertools import combinations

import torch
import torch.nn.functional as F
from .whitening import Whitening2d
from .base import BaseMethod
from .norm_mse import norm_mse_loss


class WMSE(BaseMethod):
    """Implements the W-MSE loss.

    Embeddings of every view are whitened in random sub-batches of
    ``w_size`` samples, then each pair of views is compared with ``loss_f``.
    """

    def __init__(self, cfg):
        """Initialise the whitening transform and the pairwise loss.

        Reads ``cfg.emb``, ``cfg.w_eps``, ``cfg.norm``, ``cfg.w_iter``,
        ``cfg.bs`` and ``cfg.w_size`` (``None`` means "use ``cfg.bs``").
        """
        super().__init__(cfg)
        self.whitening = Whitening2d(cfg.emb, eps=cfg.w_eps, track_running_stats=False)
        # norm_mse_loss when cfg.norm is truthy, otherwise plain MSE.
        self.loss_f = norm_mse_loss if cfg.norm else F.mse_loss
        # Number of independent random sub-batch partitions per forward pass.
        self.w_iter = cfg.w_iter
        # Samples per whitening sub-batch; defaults to the configured batch size.
        self.w_size = cfg.bs if cfg.w_size is None else cfg.w_size

    def forward(self, samples):
        """Compute the W-MSE loss for a batch of augmented views.

        Args:
            samples: sequence of per-view input tensors. Every view is assumed
                to contain the same number of samples, ``bs = len(samples[0])``
                (the indexing below relies on this).

        Returns:
            The summed pairwise loss divided by ``w_iter * num_pairs``.
            ``self.model``, ``self.head`` and ``self.num_pairs`` are not
            defined in this class; presumably ``BaseMethod`` provides them.

        Raises:
            ValueError: if ``bs`` is not a multiple of ``self.w_size``. The
                random permutation is reshaped into rows of exactly
                ``w_size`` indices, so an indivisible batch cannot be split.
        """
        n_views = len(samples)
        bs = len(samples[0])
        # Fail fast with a clear message instead of a reshape error from
        # .view() after the encoder has already run (e.g. a smaller final
        # batch when w_size defaulted to cfg.bs).
        if bs % self.w_size != 0:
            raise ValueError(
                f"batch size {bs} is not divisible by the whitening "
                f"sub-batch size w_size={self.w_size}"
            )

        h = [self.model(x.cuda(non_blocking=True)) for x in samples]
        # Rows are stacked view by view: view i occupies rows [i*bs, (i+1)*bs).
        h = self.head(torch.cat(h))

        loss = 0
        for _ in range(self.w_iter):
            z = torch.empty_like(h)
            # Random partition of the batch indices into rows of w_size.
            perm = torch.randperm(bs).view(-1, self.w_size)
            for idx in perm:
                # Whiten the same sample subset separately within each view.
                for i in range(n_views):
                    z[idx + i * bs] = self.whitening(h[idx + i * bs])
            # Contiguous per-view blocks of whitened embeddings.
            views = z.split(bs)
            for i, j in combinations(range(n_views), 2):
                loss += self.loss_f(views[i], views[j])
        loss /= self.w_iter * self.num_pairs
        return loss