hasibzunair's picture
added files
4a285f6
raw history blame
No virus
3.17 kB
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author : Peike Li
@Contact : peike.li@yahoo.com
@File : soft_dice_loss.py
@Time : 8/13/19 5:09 PM
@Desc : Soft Dice and soft Jaccard segmentation losses, both special cases of the Tversky loss.
@License : This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import print_function, division
import torch
import torch.nn.functional as F
from torch import nn
try:
from itertools import ifilterfalse
except ImportError: # py3k
from itertools import filterfalse as ifilterfalse
def tversky_loss(probas, labels, alpha=0.5, beta=0.5, epsilon=1e-6):
    '''
    Tversky loss function.
    probas: [P, C] Tensor, class probabilities at each prediction (between 0 and 1)
    labels: [P] Tensor, ground truth labels (between 0 and C - 1)
    alpha: weight of the false-positive term (probability mass predicted for a
           class on pixels that do not belong to it)
    beta: weight of the false-negative term (probability mass missing from a
          class on pixels that do belong to it)
    epsilon: added to the denominator to avoid division by zero
    Returns a scalar Tensor: 1 - TP / (TP + alpha * FP + beta * FN), averaged over
    the classes that occur in `labels`. Classes absent from `labels` are skipped.
    If no class occurs at all (e.g. every pixel was ignored), a zero Tensor that is
    still attached to the autograd graph of `probas` is returned, so
    `loss.backward()` keeps working.
    Same as soft dice loss when alpha=beta=0.5.
    Same as Jaccard loss when alpha=beta=1.0.
    See `Tversky loss function for image segmentation using 3D fully convolutional deep networks`
    https://arxiv.org/pdf/1706.05721.pdf
    '''
    losses = []
    for c in range(probas.size(1)):
        fg = (labels == c).float()
        if fg.sum() == 0:
            # Class absent from the ground truth: with no true positives its
            # term would always be 1, so it is left out of the average.
            continue
        p_fg = probas[:, c]
        true_pos = torch.sum(p_fg * fg)
        false_pos = torch.sum(p_fg * (1 - fg))
        false_neg = torch.sum((1 - p_fg) * fg)
        denominator = true_pos + alpha * false_pos + beta * false_neg
        losses.append(1 - true_pos / (denominator + epsilon))
    if not losses:
        # Multiplying by zero keeps the result a Tensor connected to `probas`,
        # unlike returning a plain Python 0.
        return probas.sum() * 0.
    return torch.stack(losses).mean()
def flatten_probas(probas, labels, ignore=255):
    """
    Flattens predictions in the batch into a [P, C] matrix, dropping ignored pixels.

    probas: [B, C, ...] Tensor with the class axis at dim 1 (e.g. [B, C, H, W];
            any number of trailing spatial dims is accepted)
    labels: Tensor with one label per pixel of `probas` (e.g. [B, H, W]); it is
            flattened in the same pixel order as `probas`
    ignore: label value whose pixels are removed, or None to keep every pixel
    Returns (vprobas, vlabels): a [P, C] Tensor and a [P] Tensor, where P is the
    number of kept pixels. vprobas stays 2-D even when P is 0 or 1.
    """
    C = probas.size(1)
    # Move the class axis last so every row is one pixel:
    # [B, C, d1, ..., dk] -> [B, d1, ..., dk, C] -> [P, C].
    order = (0,) + tuple(range(2, probas.dim())) + (1,)
    probas = probas.permute(*order).reshape(-1, C)
    labels = labels.reshape(-1)
    if ignore is None:
        return probas, labels
    valid = labels != ignore
    # Boolean-mask indexing always keeps the row axis, unlike
    # `probas[valid.nonzero().squeeze()]`, which collapses to [C] when exactly
    # one pixel is valid.
    return probas[valid], labels[valid]
def isnan(x):
    """
    Return whether `x` is NaN.

    Relies on IEEE-754 NaN comparing unequal to itself. For Tensors/arrays the
    `!=` comparison is elementwise, so a boolean mask is returned.
    """
    differs_from_itself = x != x
    return differs_from_itself
def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.

    l: any iterable of values supporting `+` and division by an int (Python
       numbers, Tensors, ...); it is consumed exactly once.
    ignore_nan: if True, elements for which `isnan` is truthy are dropped first.
    empty: value returned when no elements remain; pass the string 'raise' to
           raise ValueError instead.
    Returns the arithmetic mean. A single element is returned as-is (no
    division). Input elements are never modified in place.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(isnan, l)
    try:
        acc = next(l)
    except StopIteration:
        if isinstance(empty, str) and empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    n = 1
    for n, v in enumerate(l, 2):
        # Out-of-place add: `acc += v` would mutate the caller's first element
        # whenever it is a mutable object such as a Tensor or array.
        acc = acc + v
    if n == 1:
        return acc
    return acc / n
class SoftDiceLoss(nn.Module):
    """
    Soft Dice loss computed on softmax-normalised logits.

    Implemented as the Tversky loss with alpha = beta = 0.5. Pixels whose label
    equals `ignore_index` are removed before the loss is computed.
    """

    def __init__(self, ignore_index=255):
        super(SoftDiceLoss, self).__init__()
        # Label value that marks pixels to leave out of the loss.
        self.ignore_index = ignore_index

    def forward(self, pred, label):
        """
        pred: raw logits with the class axis at dim 1 (softmax is taken over dim=1)
        label: per-pixel labels in the same pixel order as `pred`
        Returns the scalar soft Dice loss.
        """
        probabilities = F.softmax(pred, dim=1)
        flat_probas, flat_labels = flatten_probas(probabilities, label, ignore=self.ignore_index)
        return tversky_loss(flat_probas, flat_labels, alpha=0.5, beta=0.5)
class SoftJaccordLoss(nn.Module):
    """
    Soft Jaccard (IoU) loss computed on softmax-normalised logits.

    Implemented as the Tversky loss with alpha = beta = 1.0. The class name keeps
    its original "Jaccord" spelling so existing callers are unaffected. Pixels
    whose label equals `ignore_index` are removed before the loss is computed.
    """

    def __init__(self, ignore_index=255):
        super(SoftJaccordLoss, self).__init__()
        # Label value that marks pixels to leave out of the loss.
        self.ignore_index = ignore_index

    def forward(self, pred, label):
        """
        pred: raw logits with the class axis at dim 1 (softmax is taken over dim=1)
        label: per-pixel labels in the same pixel order as `pred`
        Returns the scalar soft Jaccard loss.
        """
        probabilities = F.softmax(pred, dim=1)
        flat_probas, flat_labels = flatten_probas(probabilities, label, ignore=self.ignore_index)
        return tversky_loss(flat_probas, flat_labels, alpha=1.0, beta=1.0)