Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from
https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/dice_loss.py#L333
(Apache-2.0 License)"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import get_class_weight, weighted_loss
def tversky_loss(pred,
                 target,
                 valid_mask,
                 alpha=0.3,
                 beta=0.7,
                 smooth=1,
                 class_weight=None,
                 ignore_index=255):
    """Average the per-class Tversky losses over every class channel.

    Each class channel ``i`` of ``pred`` is paired with ``target[..., i]`` and
    scored by ``binary_tversky_loss``. The per-class terms are optionally
    scaled by ``class_weight`` and summed.

    Args:
        pred (torch.Tensor): Class scores. Dim 1 indexes classes and dim 0
            is the batch.
        target (torch.Tensor): Ground truth whose last dim indexes classes
            (e.g. a one-hot encoding). Dim 0 must match ``pred``.
        valid_mask (torch.Tensor): Passed unchanged to
            ``binary_tversky_loss`` for every class.
        alpha (float): Coefficient of false positives. Default: 0.3.
        beta (float): Coefficient of false negatives. Default: 0.7.
        smooth (float): Smoothing term. Default: 1.
        class_weight (indexable, optional): Per-class multipliers applied
            to each class term. Default: None.
        ignore_index (int | None): Class channel index to skip. The sum is
            still divided by the total number of channels. Default: 255.

    Returns:
        The sum of the (optionally weighted) per-class losses divided by
        ``pred.shape[1]``.
    """
    # NOTE(review): `weighted_loss` is imported by this module but is not
    # applied to this function, so no batch reduction happens here. Confirm
    # that this is intended.
    assert pred.shape[0] == target.shape[0]
    num_classes = pred.shape[1]

    def _weighted_class_loss(cls_idx):
        # Tversky loss for a single class channel, scaled by its weight.
        class_loss = binary_tversky_loss(
            pred[:, cls_idx],
            target[..., cls_idx],
            valid_mask=valid_mask,
            alpha=alpha,
            beta=beta,
            smooth=smooth)
        if class_weight is not None:
            class_loss *= class_weight[cls_idx]
        return class_loss

    kept_classes = (idx for idx in range(num_classes) if idx != ignore_index)
    # sum() starts from int 0, exactly like the running total it replaces.
    total = sum(_weighted_class_loss(idx) for idx in kept_classes)
    return total / num_classes
def binary_tversky_loss(pred,
                        target,
                        valid_mask,
                        alpha=0.3,
                        beta=0.7,
                        smooth=1):
    """Per-sample Tversky loss for a single class channel.

    Every input is flattened to ``(batch, -1)``. True positives, false
    positives and false negatives are then accumulated over all remaining
    elements of each sample. Positions where ``valid_mask`` is zero
    contribute nothing.

    Args:
        pred (torch.Tensor): Predicted scores for one class. Dim 0 is the
            batch.
        target (torch.Tensor): Ground truth for the same class. Dim 0 must
            match ``pred``.
        valid_mask (torch.Tensor): Multiplicative mask applied to every
            statistic.
        alpha (float): Coefficient of false positives. Default: 0.3.
        beta (float): Coefficient of false negatives. Default: 0.7.
        smooth (float): Added to numerator and denominator. Default: 1.

    Returns:
        torch.Tensor: ``1 - tversky_index`` with one value per sample.
    """
    assert pred.shape[0] == target.shape[0]
    batch = pred.shape[0]
    flat_pred = pred.reshape(batch, -1)
    flat_target = target.reshape(batch, -1)
    # The mask is flattened using its own leading dimension.
    flat_mask = valid_mask.reshape(valid_mask.shape[0], -1)

    true_pos = (flat_pred * flat_target * flat_mask).sum(dim=1)
    false_pos = (flat_pred * (1 - flat_target) * flat_mask).sum(dim=1)
    false_neg = ((1 - flat_pred) * flat_target * flat_mask).sum(dim=1)

    numerator = true_pos + smooth
    denominator = true_pos + alpha * false_pos + beta * false_neg + smooth
    return 1 - numerator / denominator
class TverskyLoss(nn.Module):
    """TverskyLoss. This loss is proposed in `Tversky loss function for image
    segmentation using 3D fully convolutional deep networks.
    <https://arxiv.org/abs/1706.05721>`_.

    Args:
        smooth (float): A float number to smooth loss, and avoid NaN error.
            Default: 1.
        class_weight (list[float] | str, optional): Weight of each class. If in
            str format, read them from a file. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        ignore_index (int | None): The label index to be ignored. If None,
            every position is treated as valid. Default: 255.
        alpha (float, in [0, 1]): The coefficient of false positives.
            Default: 0.3.
        beta (float, in [0, 1]): The coefficient of false negatives.
            Default: 0.7.
            Note: alpha + beta must equal 1 (checked with a small
            floating-point tolerance).
        loss_name (str, optional): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_tversky'.
    """
    # NOTE(review): `LOSSES` is imported by this module but this class is not
    # registered with it here. Confirm whether config-based construction is
    # expected to find it.

    def __init__(self,
                 smooth=1,
                 class_weight=None,
                 loss_weight=1.0,
                 ignore_index=255,
                 alpha=0.3,
                 beta=0.7,
                 loss_name='loss_tversky'):
        super().__init__()
        # Validate first so a bad config fails before any class-weight file
        # is read. Use isclose because exact float equality can reject
        # pairs whose decimal sum is 1.
        assert math.isclose(alpha + beta, 1.0), \
            'Sum of alpha and beta must be 1.0!'
        self.smooth = smooth
        self.class_weight = get_class_weight(class_weight)
        self.loss_weight = loss_weight
        self.ignore_index = ignore_index
        self.alpha = alpha
        self.beta = beta
        self._loss_name = loss_name

    def forward(self, pred, target, **kwargs):
        """Compute the weighted Tversky loss.

        Args:
            pred (torch.Tensor): Raw scores. Softmax is applied over dim 1,
                which is taken as the class dimension.
            target (torch.Tensor): Integer label map. Entries equal to
                ``ignore_index`` are masked out. All labels are clamped into
                ``[0, num_classes - 1]`` before one-hot encoding.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``loss_weight`` times the result of ``tversky_loss``.
        """
        if self.class_weight is not None:
            class_weight = pred.new_tensor(self.class_weight)
        else:
            class_weight = None

        pred = F.softmax(pred, dim=1)
        num_classes = pred.shape[1]
        # Clamp so ignored labels (e.g. 255) do not break one_hot. They are
        # excluded again through valid_mask below.
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        if self.ignore_index is None:
            # `target != None` is not an elementwise comparison, so the
            # result has no `.long()`. With nothing to ignore, every
            # position is valid.
            valid_mask = torch.ones_like(target, dtype=torch.long)
        else:
            valid_mask = (target != self.ignore_index).long()

        loss = self.loss_weight * tversky_loss(
            pred,
            one_hot_target,
            valid_mask=valid_mask,
            alpha=self.alpha,
            beta=self.beta,
            smooth=self.smooth,
            class_weight=class_weight,
            ignore_index=self.ignore_index)
        return loss

    def loss_name(self):
        """Loss Name.

        This function must be implemented and will return the name of this
        loss function. This name will be used to combine different loss items
        by simple sum operation. In addition, if you want this loss item to be
        included into the backward graph, `loss_` must be the prefix of the
        name.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name