import torch
import torch.nn as nn

from opencd.registry import MODELS


def bcl_loss(pred,
             target,
             margin=2.0,
             eps=1e-4,
             ignore_index=255,
             **kwargs):
    """Compute the batch-balanced contrastive loss (BCL).

    Pixels labelled 0 ("unchanged") are penalised by ``pred ** 2``.
    Pixels labelled 1 ("changed") are penalised by
    ``clamp(margin - pred, min=0) ** 2``. Each term is normalised by the
    number of valid pixels of its own class in the batch, so the two
    classes contribute on an equal footing. Pixels whose label equals
    ``ignore_index`` are excluded from both the numerators and the
    denominators.

    Args:
        pred (Tensor): Predicted distance map. It is squeezed before use,
            and afterwards must have the same shape as the squeezed
            ``target``.
        target (Tensor): Label map. Presumably binary {0, 1} plus
            ``ignore_index`` — TODO confirm against the dataset pipeline.
        margin (float): Distance beyond which "changed" pixels are no
            longer penalised. Default: 2.0.
        eps (float): Added to each per-class pixel count to avoid division
            by zero when a class is absent from the batch. Default: 1e-4.
        ignore_index (int): Label value to ignore. Default: 255.
        **kwargs: Accepted for call-site compatibility; unused.

    Returns:
        Tensor: Scalar loss value.
    """
    pred = pred.squeeze()
    target = target.squeeze()
    assert pred.size() == target.size() and target.numel() > 0

    mask = (target != ignore_index).float()
    # Zero ignored pixels so they do not count as "changed"...
    target = target * mask
    # ...and mask the complement too. Without this, ignored pixels (now 0)
    # would become 1 here and inflate the "unchanged" normaliser n_u.
    utarget = (1 - target) * mask

    n_u = utarget.sum() + eps
    n_c = target.sum() + eps

    loss_unchanged = torch.sum(utarget * torch.pow(pred, 2)) / n_u
    loss_changed = torch.sum(
        target * torch.pow(torch.clamp(margin - pred, min=0.), 2)) / n_c
    return loss_unchanged + loss_changed


@MODELS.register_module()
class BCLLoss(nn.Module):
    """Batch-balanced Contrastive Loss.

    Module wrapper around ``bcl_loss`` that scales its result by
    ``loss_weight``.

    Args:
        margin (float): Margin for "changed" pixels. Default: 2.0.
        loss_weight (float): Multiplier applied to the loss. Default: 1.0.
        ignore_index (int): Label value to ignore. Default: 255.
        loss_name (str): Name reported by the ``loss_name`` property.
            Default: 'bcl_loss'.
        **kwargs: Accepted for config compatibility; unused.
    """

    def __init__(self,
                 margin=2.0,
                 loss_weight=1.0,
                 ignore_index=255,
                 loss_name='bcl_loss',
                 **kwargs):
        super().__init__()
        self.margin = margin
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index
        self._loss_name = loss_name

    def forward(self, pred, target, **kwargs):
        """Return ``loss_weight * bcl_loss(pred, target)``.

        Extra keyword arguments are accepted and ignored. The configured
        ``self.ignore_index`` is always used.
        """
        # Pass by keyword: positionally, ignore_index used to bind to `eps`.
        loss = self.loss_weight * bcl_loss(
            pred,
            target,
            margin=self.margin,
            ignore_index=self.ignore_index)
        return loss

    @property
    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        NOTE(review): the default 'bcl_loss' does not start with 'loss_';
        confirm how the training framework selects loss terms for backward.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name