Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Optional | |
import torch | |
from mmengine.model import BaseModule | |
from mmpretrain.registry import MODELS | |
class PixelReconstructionLoss(BaseModule):
    """Pixel-wise reconstruction loss for Masked Image Modeling.

    This module measures the element-wise distance between the target image
    and the reconstructed image and reduces it to a scalar loss. Currently,
    only the L1 and L2 penalties are provided. A mask can optionally be
    passed to ``forward`` to weight the element-wise loss, so that only the
    positions selected by the mask contribute (as done in masked image
    modeling methods such as MAE).

    Args:
        criterion (str): The penalty applied to the reconstruction error.
            Currently, only ``'L1'`` and ``'L2'`` are supported.
        channel (int, optional): Divisor applied to the masked loss. It is
            only used when a mask is passed to ``forward``; the unmasked
            loss is a plain mean and is not divided by it. ``None`` is
            treated as 1. Defaults to None.

    Raises:
        NotImplementedError: If ``criterion`` is neither ``'L1'`` nor
            ``'L2'``.
    """

    def __init__(self, criterion: str, channel: Optional[int] = None) -> None:
        super().__init__()

        # ``reduction='none'`` keeps the element-wise loss so that the mask
        # can be applied in ``forward`` before any reduction happens.
        if criterion == 'L1':
            self.penalty = torch.nn.L1Loss(reduction='none')
        elif criterion == 'L2':
            self.penalty = torch.nn.MSELoss(reduction='none')
        else:
            # Adjacent string literals are used instead of a backslash
            # continuation inside the literal, which would splice the
            # continuation line's leading whitespace into the message.
            raise NotImplementedError(
                'Currently, PixelReconstructionLoss only supports L1 and '
                f'L2 loss, but get {criterion}')

        self.channel = channel if channel is not None else 1

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the reconstruction loss.

        Args:
            pred (torch.Tensor): The reconstructed image.
            target (torch.Tensor): The target image.
            mask (torch.Tensor, optional): Weights multiplied with the
                element-wise loss (after the optional last-dim averaging).
                If None, the loss is averaged over all elements.
                Defaults to None.

        Returns:
            torch.Tensor: The reconstruction loss, reduced to a scalar.
        """
        loss = self.penalty(pred, target)

        # A 3-dim element-wise loss is first averaged along its last dim.
        if loss.dim() == 3:
            loss = loss.mean(dim=-1)

        if mask is None:
            return loss.mean()

        # Masked average, then divided by ``self.channel``.
        # NOTE: a mask that sums to zero makes this division non-finite.
        return (loss * mask).sum() / mask.sum() / self.channel