#!/usr/bin/env python
# -*- encoding: utf-8 -*-

"""
@Author  :   Peike Li
@Contact :   peike.li@yahoo.com
@File    :   soft_dice_loss.py
@Time    :   8/13/19 5:09 PM
@Desc    :
@License :   This source code is licensed under the license found in the
             LICENSE file in the root directory of this source tree.
"""

from __future__ import print_function, division

import torch
import torch.nn.functional as F
from torch import nn

try:
    from itertools import ifilterfalse
except ImportError:  # py3k
    from itertools import filterfalse as ifilterfalse


def tversky_loss(probas, labels, alpha=0.5, beta=0.5, epsilon=1e-6):
    """
    Tversky loss function, averaged over the classes present in `labels`.

    probas: [P, C] tensor, class probabilities at each prediction (between 0 and 1)
    labels: [P] tensor, ground truth labels (between 0 and C - 1)
    alpha: weight of the false-positive term, sum(p_c * (1 - g_c))
    beta: weight of the false-negative term, sum((1 - p_c) * g_c)
    epsilon: small constant added to the denominator to avoid division by zero

    Same as soft dice loss when alpha=beta=0.5.
    Same as Jaccard loss when alpha=beta=1.0.
    See `Tversky loss function for image segmentation using 3D fully
    convolutional deep networks` https://arxiv.org/pdf/1706.05721.pdf

    Returns a scalar tensor. When no class of `labels` falls in [0, C - 1]
    (for instance every pixel was ignored), the result is a zero tensor that
    is still connected to `probas`, so calling `.backward()` on it is safe.
    """
    C = probas.size(1)
    losses = []
    for c in range(C):
        fg = (labels == c).float()
        # Classes absent from the ground truth do not contribute to the mean.
        if fg.sum() == 0:
            continue
        class_pred = probas[:, c]
        true_pos = torch.sum(class_pred * fg)
        false_pos = torch.sum(class_pred * (1 - fg))
        false_neg = torch.sum((1 - class_pred) * fg)
        denominator = true_pos + alpha * false_pos + beta * false_neg
        losses.append(1 - true_pos / (denominator + epsilon))
    if not losses:
        # Keep the return type a tensor (instead of the plain `0` that
        # `mean([])` yields) so the training loop can always backpropagate.
        return probas.sum() * 0.
    return mean(losses)


def flatten_probas(probas, labels, ignore=255):
    """
    Flattens predictions in the batch.

    probas: [B, C, H, W] tensor of class probabilities
    labels: tensor with B * H * W elements holding the ground truth labels
    ignore: label value to drop, or None to keep every position

    Returns (probas, labels) with shapes [P, C] and [P], where P is the
    number of positions kept.
    """
    _, C, _, _ = probas.size()  # also enforces a 4-D input
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.reshape(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    # Boolean-mask indexing always yields a [P, C] result; indexing with
    # `valid.nonzero().squeeze()` would drop the first dimension when exactly
    # one position is valid.
    vprobas = probas[valid]
    vlabels = labels[valid]
    return vprobas, vlabels


def isnan(x):
    """Return whether `x` is NaN, using the fact that NaN differs from itself."""
    return x != x


def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.

    l: iterable of values supporting `+` and `/`
    ignore_nan: when True, NaN items are skipped
    empty: value returned for an empty input; the string 'raise' makes the
        function raise ValueError('Empty mean') instead
    """
    values = iter(l)
    if ignore_nan:
        values = ifilterfalse(isnan, values)
    try:
        acc = next(values)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    n = 1
    for n, v in enumerate(values, 2):
        # Out-of-place addition: `acc += v` would modify the caller's first
        # item in place when the items are tensors.
        acc = acc + v
    if n == 1:
        return acc
    return acc / n


class SoftDiceLoss(nn.Module):
    """
    Soft dice loss: Tversky loss with alpha=beta=0.5 on softmax probabilities.

    ignore_index: label value excluded from the loss.
    """

    def __init__(self, ignore_index=255):
        super(SoftDiceLoss, self).__init__()
        self.ignore_index = ignore_index

    def forward(self, pred, label):
        """
        pred: [B, C, H, W] tensor of raw scores; softmax is applied over dim 1
        label: tensor of ground truth labels with B * H * W elements
        """
        pred = F.softmax(pred, dim=1)
        probas, labels = flatten_probas(pred, label, ignore=self.ignore_index)
        return tversky_loss(probas, labels, alpha=0.5, beta=0.5)


class SoftJaccordLoss(nn.Module):
    """
    Soft Jaccard loss: Tversky loss with alpha=beta=1.0 on softmax probabilities.

    ignore_index: label value excluded from the loss.
    """

    def __init__(self, ignore_index=255):
        super(SoftJaccordLoss, self).__init__()
        self.ignore_index = ignore_index

    def forward(self, pred, label):
        """
        pred: [B, C, H, W] tensor of raw scores; softmax is applied over dim 1
        label: tensor of ground truth labels with B * H * W elements
        """
        pred = F.softmax(pred, dim=1)
        probas, labels = flatten_probas(pred, label, ignore=self.ignore_index)
        return tversky_loss(probas, labels, alpha=1.0, beta=1.0)