"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/ segmentron/solver/loss.py (Apache-2.0 License)""" import torch import torch.nn as nn import torch.nn.functional as F from ..builder import LOSSES from .utils import get_class_weight, weighted_loss @weighted_loss def dice_loss(pred, target, valid_mask, smooth=1, exponent=2, class_weight=None, ignore_index=255): assert pred.shape[0] == target.shape[0] total_loss = 0 num_classes = pred.shape[1] for i in range(num_classes): if i != ignore_index: dice_loss = binary_dice_loss( pred[:, i], target[..., i], valid_mask=valid_mask, smooth=smooth, exponent=exponent) if class_weight is not None: dice_loss *= class_weight[i] total_loss += dice_loss return total_loss / num_classes @weighted_loss def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards): assert pred.shape[0] == target.shape[0] pred = pred.reshape(pred.shape[0], -1) target = target.reshape(target.shape[0], -1) valid_mask = valid_mask.reshape(valid_mask.shape[0], -1) num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth return 1 - num / den @LOSSES.register_module() class DiceLoss(nn.Module): """DiceLoss. This loss is proposed in `V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation `_. Args: loss_type (str, optional): Binary or multi-class loss. Default: 'multi_class'. Options are "binary" and "multi_class". smooth (float): A float number to smooth loss, and avoid NaN error. Default: 1 exponent (float): An float number to calculate denominator value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2. reduction (str, optional): The method used to reduce the loss. Options are "none", "mean" and "sum". This parameter only works when per_image is True. Default: 'mean'. class_weight (list[float] | str, optional): Weight of each class. If in str format, read them from a file. Defaults to None. loss_weight (float, optional): Weight of the loss. Default to 1.0. ignore_index (int | None): The label index to be ignored. Default: 255. """ def __init__(self, smooth=1, exponent=2, reduction='mean', class_weight=None, loss_weight=1.0, ignore_index=255, **kwards): super(DiceLoss, self).__init__() self.smooth = smooth self.exponent = exponent self.reduction = reduction self.class_weight = get_class_weight(class_weight) self.loss_weight = loss_weight self.ignore_index = ignore_index def forward(self, pred, target, avg_factor=None, reduction_override=None, **kwards): assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) if self.class_weight is not None: class_weight = pred.new_tensor(self.class_weight) else: class_weight = None pred = F.softmax(pred, dim=1) num_classes = pred.shape[1] one_hot_target = F.one_hot( torch.clamp(target.long(), 0, num_classes - 1), num_classes=num_classes) valid_mask = (target != self.ignore_index).long() loss = self.loss_weight * dice_loss( pred, one_hot_target, valid_mask=valid_mask, reduction=reduction, avg_factor=avg_factor, smooth=self.smooth, exponent=self.exponent, class_weight=class_weight, ignore_index=self.ignore_index) return loss