MultiMAE / multimae /criterion.py
Bachmann Roman Christian
Initial commit
3b49518
raw
history blame contribute delete
No virus
6.41 kB
# Copyright (c) EPFL VILAB.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on timm, DeiT, DINO, MoCo-v3, BEiT, MAE-priv and MAE code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/facebookresearch/moco-v3
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/BUPT-PRIV/MAE-priv
# https://github.com/facebookresearch/mae
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class MaskedCrossEntropyLoss(nn.Module):
    """Cross-entropy loss with masking

    The per-element cross-entropy is weighted by a patch-level mask that is
    upsampled to the spatial size of the prediction, averaged per sample over
    the mask, and finally averaged over the batch.

    :param patch_size: Patch size
    :param stride: Stride of task / modality
    :param label_smoothing: Amount of smoothing in the loss (default is 0.0)
    """

    def __init__(self, patch_size: int = 16, stride: int = 1, label_smoothing: float = 0.0):
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride
        # Side length (in prediction pixels) covered by one mask entry.
        self.scale_factor = patch_size // stride
        self.label_smoothing = label_smoothing

    def forward(self, input, target, mask=None):
        """Compute the (optionally masked) cross-entropy loss.

        :param input: Predicted logits; the last two dimensions are treated as (H, W)
        :param target: Ground-truth labels, as accepted by ``F.cross_entropy``
        :param mask: Optional tensor of shape (B, nh * nw) with
            ``nh = H // scale_factor`` and ``nw = W // scale_factor``. Entries
            weight the loss of the corresponding patch; zero entries are ignored.
        :return: Scalar loss tensor. A floating-point zero (same dtype and
            device as the unreduced loss) is returned if the mask is all zeros.
        """
        loss = F.cross_entropy(input, target, reduction='none', label_smoothing=self.label_smoothing)

        if mask is None:
            # If this is ever nan, we want it to stop training
            return loss.mean()

        if mask.sum() == 0:
            # Nothing to supervise. Return a zero in the loss dtype, created
            # directly on the loss device, so that the return type matches the
            # other code paths (torch.tensor(0) would be an int64 tensor).
            return loss.new_zeros(())

        H, W = input.shape[-2:]
        nh, nw = H // self.scale_factor, W // self.scale_factor

        # Resize mask and upsample
        mask = rearrange(mask, "b (nh nw) -> b nh nw", nh=nh, nw=nw)
        mask = F.interpolate(mask.unsqueeze(1).float(), size=(H, W), mode='nearest').squeeze(1)
        loss = loss * mask

        # Compute mean per sample; samples whose mask is entirely zero yield
        # 0 / 0 = nan and are skipped by nanmean below.
        loss = loss.flatten(start_dim=1).sum(dim=1) / mask.flatten(start_dim=1).sum(dim=1)
        return loss.nanmean()  # Account for zero masks
class MaskedMSELoss(nn.Module):
    """MSE (L2) loss with masking

    The per-pixel squared error is averaged over channels, weighted by a
    patch-level mask that is upsampled to the spatial size of the prediction,
    averaged per sample over the mask, and finally averaged over the batch.

    :param patch_size: Patch size
    :param stride: Stride of task / modality
    :param norm_pix: Normalized pixel loss
    """

    def __init__(self, patch_size: int = 16, stride: int = 1, norm_pix=False):
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride
        # Side length (in prediction pixels) covered by one mask entry / patch.
        self.scale_factor = patch_size // stride
        self.norm_pix = norm_pix

    def patchify(self, imgs, nh, nw):
        """Split (B, C, H, W) images into (B, nh * nw, p * p * C) flattened patches.

        :param imgs: Image tensor with H = nh * p and W = nw * p, where p is ``scale_factor``
        :param nh: Number of patches along the height
        :param nw: Number of patches along the width
        """
        p = self.scale_factor
        return rearrange(imgs, "b c (nh p1) (nw p2) -> b (nh nw) (p1 p2 c)", nh=nh, nw=nw, p1=p, p2=p)

    def unpatchify(self, x, nh, nw):
        """Inverse of :meth:`patchify`: (B, nh * nw, p * p * C) -> (B, C, H, W).

        :param x: Flattened patches
        :param nh: Number of patches along the height
        :param nw: Number of patches along the width
        """
        p = self.scale_factor
        return rearrange(x, "b (nh nw) (p1 p2 c) -> b c (nh p1) (nw p2)", nh=nh, nw=nw, p1=p, p2=p)

    def forward(self, input, target, mask=None):
        """Compute the (optionally masked) mean squared error.

        :param input: Prediction of shape (B, C, H, W)
        :param target: Ground truth with the same shape as ``input``
        :param mask: Optional tensor of shape (B, nh * nw) with
            ``nh = H // scale_factor`` and ``nw = W // scale_factor``. Entries
            weight the loss of the corresponding patch; zero entries are ignored.
        :return: Scalar loss tensor. A floating-point zero (same dtype and
            device as the unreduced loss) is returned if the mask is all zeros.
        """
        H, W = input.shape[-2:]
        nh, nw = H // self.scale_factor, W // self.scale_factor

        if self.norm_pix:
            # Standardize every target patch (over its pixels and channels)
            # to zero mean and unit variance before comparing.
            target = self.patchify(target, nh, nw)
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            eps = 1e-6
            target = (target - mean) / torch.sqrt(var + eps)
            target = self.unpatchify(target, nh, nw)

        loss = F.mse_loss(input, target, reduction='none')

        if mask is None:
            # If this is ever nan, we want it to stop training
            return loss.mean()

        if mask.sum() == 0:
            # Nothing to supervise. Return a zero in the loss dtype, created
            # directly on the loss device, so that the return type matches the
            # other code paths (torch.tensor(0) would be an int64 tensor).
            return loss.new_zeros(())

        # Resize mask and upsample
        mask = rearrange(mask, "b (nh nw) -> b nh nw", nh=nh, nw=nw)
        mask = F.interpolate(mask.unsqueeze(1).float(), size=(H, W), mode='nearest').squeeze(1)
        loss = loss.mean(dim=1)  # B, C, H, W -> B, H, W
        loss = loss * mask

        # Compute mean per sample; samples whose mask is entirely zero yield
        # 0 / 0 = nan and are skipped by nanmean below.
        loss = loss.flatten(start_dim=1).sum(dim=1) / mask.flatten(start_dim=1).sum(dim=1)
        return loss.nanmean()  # Account for zero masks
class MaskedL1Loss(nn.Module):
    """L1 loss with masking

    The per-pixel absolute error is averaged over channels, weighted by a
    patch-level mask that is upsampled to the spatial size of the prediction,
    averaged per sample over the mask, and finally averaged over the batch.

    :param patch_size: Patch size
    :param stride: Stride of task / modality
    :param norm_pix: Normalized pixel loss
    """

    def __init__(self, patch_size: int = 16, stride: int = 1, norm_pix=False):
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride
        # Side length (in prediction pixels) covered by one mask entry / patch.
        self.scale_factor = patch_size // stride
        self.norm_pix = norm_pix

    def patchify(self, imgs, nh, nw):
        """Split (B, C, H, W) images into (B, nh * nw, p * p * C) flattened patches.

        :param imgs: Image tensor with H = nh * p and W = nw * p, where p is ``scale_factor``
        :param nh: Number of patches along the height
        :param nw: Number of patches along the width
        """
        p = self.scale_factor
        return rearrange(imgs, "b c (nh p1) (nw p2) -> b (nh nw) (p1 p2 c)", nh=nh, nw=nw, p1=p, p2=p)

    def unpatchify(self, x, nh, nw):
        """Inverse of :meth:`patchify`: (B, nh * nw, p * p * C) -> (B, C, H, W).

        :param x: Flattened patches
        :param nh: Number of patches along the height
        :param nw: Number of patches along the width
        """
        p = self.scale_factor
        return rearrange(x, "b (nh nw) (p1 p2 c) -> b c (nh p1) (nw p2)", nh=nh, nw=nw, p1=p, p2=p)

    def forward(self, input, target, mask=None):
        """Compute the (optionally masked) mean absolute error.

        :param input: Prediction of shape (B, C, H, W)
        :param target: Ground truth with the same shape as ``input``
        :param mask: Optional tensor of shape (B, nh * nw) with
            ``nh = H // scale_factor`` and ``nw = W // scale_factor``. Entries
            weight the loss of the corresponding patch; zero entries are ignored.
        :return: Scalar loss tensor. A floating-point zero (same dtype and
            device as the unreduced loss) is returned if the mask is all zeros.
        """
        H, W = input.shape[-2:]
        nh, nw = H // self.scale_factor, W // self.scale_factor

        if self.norm_pix:
            # Standardize every target patch (over its pixels and channels)
            # to zero mean and unit variance before comparing.
            target = self.patchify(target, nh, nw)
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            eps = 1e-6
            target = (target - mean) / torch.sqrt(var + eps)
            target = self.unpatchify(target, nh, nw)

        loss = F.l1_loss(input, target, reduction='none')

        if mask is None:
            # If this is ever nan, we want it to stop training
            return loss.mean()

        if mask.sum() == 0:
            # Nothing to supervise. Return a zero in the loss dtype, created
            # directly on the loss device, so that the return type matches the
            # other code paths (torch.tensor(0) would be an int64 tensor).
            return loss.new_zeros(())

        # Resize mask and upsample
        mask = rearrange(mask, "b (nh nw) -> b nh nw", nh=nh, nw=nw)
        mask = F.interpolate(mask.unsqueeze(1).float(), size=(H, W), mode='nearest').squeeze(1)
        loss = loss.mean(dim=1)  # B, C, H, W -> B, H, W
        loss = loss * mask

        # Compute mean per sample; samples whose mask is entirely zero yield
        # 0 / 0 = nan and are skipped by nanmean below.
        loss = loss.flatten(start_dim=1).sum(dim=1) / mask.flatten(start_dim=1).sum(dim=1)
        return loss.nanmean()  # Account for zero masks