mix-bt / ssl-sota /methods /norm_mse.py
wgcban's picture
Upload 98 files
803ef9e
raw
history blame
159 Bytes
import torch.nn.functional as F
def norm_mse_loss(x0, x1):
    """Return the mean squared distance between L2-normalized inputs.

    Both inputs are scaled to unit L2 norm along their last dimension. For
    unit vectors ``a`` and ``b``, ``||a - b||^2 == 2 - 2 * <a, b>``, so the
    loss is computed from the dot product instead of an explicit difference.

    Args:
        x0: Tensor whose last dimension is treated as the feature vector.
        x1: Tensor that can be multiplied elementwise with ``x0``.

    Returns:
        A scalar tensor: ``2 - 2 * mean(dot(x0_hat, x1_hat))`` where the mean
        is taken over every dimension except the last. The value is 0 for
        vectors pointing the same way and 4 for opposite vectors.
    """
    # Normalize along the same axis that is reduced below. F.normalize
    # defaults to dim=1, which matches the last axis only for 2-D input
    # (and raises for 1-D input), so the axis is spelled out explicitly.
    x0 = F.normalize(x0, dim=-1)
    x1 = F.normalize(x1, dim=-1)
    return 2 - 2 * (x0 * x1).sum(dim=-1).mean()