#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author  :   Peike Li
@Contact :   peike.li@yahoo.com
@File    :   lovasz_softmax.py
@Time    :   8/30/19 7:12 PM
@Desc    :   Lovasz-Softmax and Jaccard hinge loss in PyTorch
             Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
@License :   This source code is licensed under the license found in the
             LICENSE file in the root directory of this source tree.
"""
from __future__ import print_function, division | |
import torch | |
from torch.autograd import Variable | |
import torch.nn.functional as F | |
import numpy as np | |
from torch import nn | |
try: | |
from itertools import ifilterfalse | |
except ImportError: # py3k | |
from itertools import filterfalse as ifilterfalse | |
def lovasz_grad(gt_sorted):
    """
    Gradient of the Lovasz extension of the Jaccard loss w.r.t. the sorted errors
    (Alg. 1 of the Lovasz-Softmax paper).

    gt_sorted: 1-D tensor of binary ground-truth values (0/1), ordered by
        decreasing error.
    Returns a float tensor of the same length holding the discrete derivative of
    1 - IoU along the sorted order.
    """
    num_items = len(gt_sorted)
    total_fg = gt_sorted.sum()
    # Foreground items not yet covered after each prefix of the ordering.
    inter = total_fg - gt_sorted.float().cumsum(0)
    # The union grows by one for every background item admitted into the prefix.
    uni = total_fg + (1 - gt_sorted).float().cumsum(0)
    jac = 1. - inter / uni
    if num_items <= 1:
        # A single element has no predecessor to difference against.
        return jac
    jac[1:num_items] = jac[1:num_items] - jac[0:-1]
    return jac
def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """
    Foreground IoU, in percent, for binary masks (1 = foreground, 0 = background).

    preds, labels: prediction and ground-truth masks. With per_image=True they are
        iterated along their first axis and the per-image IoUs are averaged;
        otherwise the whole input is scored as a single image.
    EMPTY: score used for an image whose union is empty.
    ignore: label value excluded from the predicted-foreground part of the union.
    """
    pairs = zip(preds, labels) if per_image else [(preds, labels)]
    scores = []
    for pred, label in pairs:
        gt_fg = label == 1
        pred_fg = pred == 1
        inter = (gt_fg & pred_fg).sum()
        uni = (gt_fg | (pred_fg & (label != ignore))).sum()
        scores.append(float(inter) / float(uni) if uni else EMPTY)
    # mean across images if per_image
    return 100 * mean(scores)
def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Per-class IoU, in percent, for every class in range(C) except `ignore`.

    Returns a numpy array with one entry per non-ignored class. With per_image=True
    the inputs are iterated along their first axis and each class's IoU is averaged
    across images; otherwise the whole input is scored as a single image.
    EMPTY: score used for a class whose union is empty.
    """
    pairs = zip(preds, labels) if per_image else [(preds, labels)]
    per_image_scores = []
    for pred, label in pairs:
        class_scores = []
        for cls in range(C):
            # The ignored label is sometimes among predicted classes (ENet - CityScapes)
            if cls == ignore:
                continue
            inter = ((label == cls) & (pred == cls)).sum()
            uni = ((label == cls) | ((pred == cls) & (label != ignore))).sum()
            class_scores.append(float(inter) / float(uni) if uni else EMPTY)
        per_image_scores.append(class_scores)
    # Transpose to per-class columns, then average each column across images.
    per_class = [mean(column) for column in zip(*per_image_scores)]
    return 100 * np.array(per_class)
# --------------------------- BINARY LOSSES --------------------------- | |
def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss.

    logits: [B, H, W] tensor, logits at each pixel (between -inf and +inf)
    labels: [B, H, W] tensor, binary ground-truth masks (0 or 1)
    per_image: compute the loss per image (then average) instead of over the whole batch
    ignore: void class id; pixels with this label are dropped
    Returns the loss.
    """
    if per_image:
        # unsqueeze(0) restores the batch axis removed by iterating over the batch.
        loss = mean(lovasz_hinge_flat(*flatten_binary_scores(img_logits.unsqueeze(0),
                                                             img_labels.unsqueeze(0),
                                                             ignore))
                    for img_logits, img_labels in zip(logits, labels))
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss on already-flattened inputs.

    logits: [P] tensor, logits at each prediction (between -inf and +inf)
    labels: [P] tensor, binary ground-truth labels (0 or 1)
    Returns the loss; when there are no predictions, a zero that stays attached
    to `logits`.
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    # Map {0, 1} labels to {-1, +1} so the hinge margin is logits * sign.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    # Hinge (ReLU) on the sorted errors, weighted by the Lovasz extension gradient.
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss
def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flatten the binary scores and labels of a batch into 1-D tensors.

    When `ignore` is given, positions whose label equals it are removed from both.
    Returns (scores, labels).
    """
    flat_scores, flat_labels = scores.view(-1), labels.view(-1)
    if ignore is not None:
        keep = flat_labels != ignore
        flat_scores, flat_labels = flat_scores[keep], flat_labels[keep]
    return flat_scores, flat_labels
class StableBCELoss(torch.nn.modules.Module):
    """
    Binary cross-entropy computed on raw logits, averaged over all elements.

    Evaluates max(x, 0) - x * t + log(1 + exp(-|x|)), which equals the usual
    -t*log(sigmoid(x)) - (1-t)*log(1-sigmoid(x)) but only ever exponentiates a
    non-positive number.
    """

    def forward(self, input, target):
        softplus_term = (1 + (-input.abs()).exp()).log()
        per_element = input.clamp(min=0) - input * target + softplus_term
        return per_element.mean()
def binary_xloss(logits, labels, ignore=None):
    """
    Binary cross-entropy loss computed on logits.

    logits: [B, H, W] tensor, logits at each pixel (between -inf and +inf)
    labels: [B, H, W] tensor, binary ground-truth masks (0 or 1)
    ignore: void class id; pixels with this label are excluded
    Returns the mean loss over the remaining pixels.
    """
    logits, labels = flatten_binary_scores(logits, labels, ignore)
    loss = StableBCELoss()(logits, Variable(labels.float()))
    return loss
# --------------------------- MULTICLASS LOSSES --------------------------- | |
def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=255, weighted=None):
    """
    Multi-class Lovasz-Softmax loss.

    probas: [B, C, H, W] tensor of class probabilities (between 0 and 1);
        a [B, H, W] tensor is interpreted as binary (sigmoid) output.
    labels: [B, H, W] tensor, ground-truth labels (between 0 and C - 1)
    classes: 'all' for all classes, 'present' for classes present in labels,
        or a list of classes to average.
    per_image: compute the loss per image (then average) instead of over the whole batch
    ignore: void class label; pixels with this label are dropped
    weighted: optional per-class weights, indexed by class id, that scale each
        class's loss term before averaging
    """
    if not per_image:
        return lovasz_softmax_flat(*flatten_probas(probas, labels, ignore),
                                   classes=classes, weighted=weighted)
    per_image_losses = (
        lovasz_softmax_flat(*flatten_probas(img_probas.unsqueeze(0), img_labels.unsqueeze(0), ignore),
                            classes=classes, weighted=weighted)
        for img_probas, img_labels in zip(probas, labels)
    )
    return mean(per_image_losses)
def lovasz_softmax_flat(probas, labels, classes='present', weighted=None):
    """
    Multi-class Lovasz-Softmax loss on already-flattened inputs.

    probas: [P, C] tensor, class probabilities at each prediction (between 0 and 1)
    labels: [P] tensor, ground-truth labels (between 0 and C - 1)
    classes: 'all' for all classes, 'present' for classes present in labels,
        or a list of classes to average.
    weighted: optional per-class weights indexed by class id; each class's term is
        multiplied by its weight before averaging.
    Returns the average of the per-class losses. When there are no predictions,
    returns a zero scalar that stays attached to `probas`.
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0. Reduce to a scalar so the
        # result can be back-propagated and averaged with other per-image losses.
        return probas.sum() * 0.
    C = probas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground for class c
        # Compare with `==`: identity of string literals is an interpreter detail.
        if classes == 'present' and fg.sum() == 0:
            continue
        if C == 1:
            # Single-channel (sigmoid) input needs an explicit one-element class
            # list; the strings 'all' / 'present' also have len > 1 and are rejected.
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (Variable(fg) - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        class_loss = torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))
        if weighted is not None:
            class_loss = weighted[c] * class_loss
        losses.append(class_loss)
    return mean(losses)
def flatten_probas(probas, labels, ignore=None):
    """
    Flatten batched class probabilities and labels.

    probas: [B, C, H, W] tensor, or [B, H, W] (treated as single-channel sigmoid output)
    labels: [B, H, W] tensor
    ignore: if given, pixels whose label equals it are dropped
    Returns (probas [P, C], labels [P]).
    """
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    # Boolean-mask indexing keeps a 2-D [P, C] result even when P is 0 or 1.
    vprobas = probas[valid]
    vlabels = labels[valid]
    return vprobas, vlabels
def xloss(logits, labels, ignore=None):
    """
    Cross entropy loss.

    logits: per-class scores in any layout accepted by F.cross_entropy
    labels: integer class indices matching `logits`
    ignore: label value excluded from the loss; 255 is used when None
    """
    # Previously `ignore` was silently discarded and 255 was always used.
    ignore_index = 255 if ignore is None else ignore
    return F.cross_entropy(logits, Variable(labels), ignore_index=ignore_index)
# --------------------------- HELPER FUNCTIONS --------------------------- | |
def isnan(x):
    """Return ``x != x``: true exactly where *x* is NaN (element-wise for tensors/arrays)."""
    not_equal_to_itself = x != x
    return not_equal_to_itself
def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.

    l: any iterable (list, generator, ...) of numbers, arrays or tensors.
    ignore_nan: if True, items for which `isnan` is true are skipped.
    empty: value returned when there is nothing to average; pass 'raise' to get
        a ValueError instead.
    Items are accumulated out of place, so the caller's first element is never
    modified (and leaf tensors requiring grad are safe to average).
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(isnan, l)
    try:
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    n = 1
    for n, v in enumerate(l, 2):
        # Out-of-place add: `acc += v` would mutate the first item in place.
        acc = acc + v
    if n == 1:
        return acc
    return acc / n
# --------------------------- Class --------------------------- | |
class LovaszSoftmax(nn.Module):
    """
    nn.Module wrapper around `lovasz_softmax` that takes raw (pre-softmax) scores.

    forward(pred, label) applies softmax over dim 1 of `pred`, then computes the
    Lovasz-Softmax loss against `label` using the options fixed at construction.
    """

    def __init__(self, per_image=False, ignore_index=255, weighted=None):
        super(LovaszSoftmax, self).__init__()
        self.per_image = per_image
        self.ignore_index = ignore_index
        self.weighted = weighted
        self.lovasz_softmax = lovasz_softmax

    def forward(self, pred, label):
        probabilities = F.softmax(pred, dim=1)
        return self.lovasz_softmax(
            probabilities,
            label,
            per_image=self.per_image,
            ignore=self.ignore_index,
            weighted=self.weighted,
        )