Spaces:
Runtime error
Runtime error
"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/ | |
segmentron/solver/loss.py (Apache-2.0 License)""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from ..builder import LOSSES | |
from .utils import get_class_weight, weighted_loss | |
def dice_loss(pred,
              target,
              valid_mask,
              smooth=1,
              exponent=2,
              class_weight=None,
              ignore_index=255,
              reduction='none',
              avg_factor=None):
    """Compute the multi-class Dice loss, averaged over the class channels.

    One binary Dice loss is computed per class channel with
    ``binary_dice_loss``; the per-class results are optionally weighted,
    summed, and divided by the total number of classes.

    Args:
        pred (Tensor): Class scores with the class axis at dim 1
            (channel ``i`` is read as ``pred[:, i]``).
        target (Tensor): One-hot labels with the class axis last
            (channel ``i`` is read as ``target[..., i]``).
        valid_mask (Tensor): Mask forwarded unchanged to
            ``binary_dice_loss``.
        smooth (float): Smoothing term forwarded to ``binary_dice_loss``.
            Default: 1.
        exponent (float): Exponent forwarded to ``binary_dice_loss``.
            Default: 2.
        class_weight (Sequence | Tensor, optional): Per-class multipliers,
            indexed by class id. Default: None.
        ignore_index (int): Class channel that is skipped. The divisor is
            still the full number of classes. Default: 255.
        reduction (str): 'none', 'mean' or 'sum'. Defaults to 'none' so
            that callers which do not pass it get the same unreduced value
            as before this parameter existed.
        avg_factor (float, optional): If given together with
            ``reduction='mean'``, the loss is summed and divided by this
            value instead of being averaged. Default: None.

    Returns:
        Tensor: The loss, reduced as requested.

    Raises:
        ValueError: If ``reduction`` is not a supported option, or if
            ``avg_factor`` is combined with ``reduction='sum'``.
    """
    assert pred.shape[0] == target.shape[0]
    if reduction not in ('none', 'mean', 'sum'):
        raise ValueError(f'unsupported reduction: {reduction!r}')

    num_classes = pred.shape[1]
    total_loss = 0
    for i in range(num_classes):
        if i == ignore_index:
            continue
        # Named class_loss (not dice_loss) so the local does not shadow
        # this function's own name.
        class_loss = binary_dice_loss(
            pred[:, i],
            target[..., i],
            valid_mask=valid_mask,
            smooth=smooth,
            exponent=exponent)
        if class_weight is not None:
            class_loss = class_loss * class_weight[i]
        total_loss = total_loss + class_loss
    loss = total_loss / num_classes

    # DiceLoss.forward passes ``reduction`` and ``avg_factor``; handle them
    # here so that call does not fail with an unexpected-keyword TypeError.
    if avg_factor is None:
        if reduction == 'mean':
            return loss.mean()
        if reduction == 'sum':
            return loss.sum()
        return loss
    if reduction == 'mean':
        return loss.sum() / avg_factor
    if reduction == 'sum':
        raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss
def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
    """Compute the Dice loss of a single class for every sample in a batch.

    The three tensors are flattened to ``(batch, -1)``, so they must share
    the batch dimension and hold the same number of elements per sample.

    The loss for one sample is::

        1 - (2 * sum(m * p * t) + smooth)
            / (sum(m * (p ** exponent + t ** exponent)) + smooth)

    where ``p`` is ``pred``, ``t`` is ``target`` and ``m`` is
    ``valid_mask``.

    Args:
        pred (Tensor): Predicted scores for one class.
        target (Tensor): Ground-truth indicator for the same class.
        valid_mask (Tensor): Multiplicative mask; elements where it is 0
            do not contribute to the loss.
        smooth (float): Added to numerator and denominator so the ratio is
            defined when both sums are zero. Default: 1.
        exponent (float): Power applied to ``pred`` and ``target`` in the
            denominator. Default: 2.
        **kwards: Accepted and ignored.

    Returns:
        Tensor: Loss of shape ``(batch,)``.
    """
    assert pred.shape[0] == target.shape[0]
    pred = pred.reshape(pred.shape[0], -1)
    target = target.reshape(target.shape[0], -1)
    valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

    num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
    # The mask must cover the denominator as well: otherwise masked-out
    # elements still enlarge it, and a perfect prediction on the valid
    # elements cannot reach a loss of zero.
    den = torch.sum(
        (pred.pow(exponent) + target.pow(exponent)) * valid_mask,
        dim=1) + smooth
    return 1 - num / den
class DiceLoss(nn.Module):
    """Dice loss module for segmentation logits.

    The loss was introduced in `V-Net: Fully Convolutional Neural Networks
    for Volumetric Medical Image Segmentation
    <https://arxiv.org/abs/1606.04797>`_.

    Args:
        smooth (float): Smoothing term that keeps the Dice ratio defined
            and avoids NaN. Default: 1.
        exponent (float): Power applied to prediction and target in the
            denominator, i.e. sum(x^exponent) + sum(y^exponent).
            Default: 2.
        reduction (str, optional): Reduction forwarded to ``dice_loss``.
            Options are "none", "mean" and "sum". Default: 'mean'.
        class_weight (list[float] | str, optional): Weight of each class,
            resolved through ``get_class_weight``. If in str format, read
            them from a file. Defaults to None.
        loss_weight (float, optional): Multiplier applied to the final
            loss. Default to 1.0.
        ignore_index (int | None): The label index to be ignored.
            Default: 255.
        **kwards: Accepted and ignored.
    """

    def __init__(self,
                 smooth=1,
                 exponent=2,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 ignore_index=255,
                 **kwards):
        super().__init__()
        self.smooth = smooth
        self.exponent = exponent
        self.reduction = reduction
        self.class_weight = get_class_weight(class_weight)
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index

    def forward(self,
                pred,
                target,
                avg_factor=None,
                reduction_override=None,
                **kwards):
        """Compute the weighted Dice loss of ``pred`` against ``target``.

        Args:
            pred (Tensor): Raw class scores; softmax is taken over dim 1.
            target (Tensor): Integer label map. Values equal to
                ``ignore_index`` are masked out of the loss.
            avg_factor (float, optional): Forwarded to ``dice_loss``.
            reduction_override (str, optional): One of None, 'none',
                'mean', 'sum'. A truthy value replaces the reduction
                chosen at construction time.
            **kwards: Accepted and ignored.

        Returns:
            Tensor: ``loss_weight`` times the value of ``dice_loss``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        # Put the class weights on the same dtype/device as the input.
        class_weight = (
            None if self.class_weight is None else pred.new_tensor(
                self.class_weight))

        pred = F.softmax(pred, dim=1)
        num_classes = pred.shape[1]

        # Clamp first so out-of-range labels (such as the ignore value) are
        # valid one-hot indices; those positions are removed by valid_mask.
        clamped_labels = torch.clamp(target.long(), 0, num_classes - 1)
        one_hot_target = F.one_hot(clamped_labels, num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        unweighted = dice_loss(
            pred,
            one_hot_target,
            valid_mask=valid_mask,
            reduction=reduction,
            avg_factor=avg_factor,
            smooth=self.smooth,
            exponent=self.exponent,
            class_weight=class_weight,
            ignore_index=self.ignore_index)
        return self.loss_weight * unweighted