# Copyright (c) OpenMMLab. All rights reserved. from typing import Optional import torch from mmengine.model import BaseModule from mmpretrain.registry import MODELS @MODELS.register_module() class PixelReconstructionLoss(BaseModule): """Loss for the reconstruction of pixel in Masked Image Modeling. This module measures the distance between the target image and the reconstructed image and compute the loss to optimize the model. Currently, This module only provides L1 and L2 loss to penalize the reconstructed error. In addition, a mask can be passed in the ``forward`` function to only apply loss on visible region, like that in MAE. Args: criterion (str): The loss the penalize the reconstructed error. Currently, only supports L1 and L2 loss channel (int, optional): The number of channels to average the reconstruction loss. If not None, the reconstruction loss will be divided by the channel. Defaults to None. """ def __init__(self, criterion: str, channel: Optional[int] = None) -> None: super().__init__() if criterion == 'L1': self.penalty = torch.nn.L1Loss(reduction='none') elif criterion == 'L2': self.penalty = torch.nn.MSELoss(reduction='none') else: raise NotImplementedError(f'Currently, PixelReconstructionLoss \ only supports L1 and L2 loss, but get {criterion}') self.channel = channel if channel is not None else 1 def forward(self, pred: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: """Forward function to compute the reconstrction loss. Args: pred (torch.Tensor): The reconstructed image. target (torch.Tensor): The target image. mask (torch.Tensor): The mask of the target image. Returns: torch.Tensor: The reconstruction loss. """ loss = self.penalty(pred, target) # if the dim of the loss is 3, take the average of the loss # along the last dim if len(loss.shape) == 3: loss = loss.mean(dim=-1) if mask is None: loss = loss.mean() else: loss = (loss * mask).sum() / mask.sum() / self.channel return loss