|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Created in September 2022 |
|
|
@author: fabrizio.guillaro |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import functional as F |
|
|
|
|
|
|
|
|
|
|
|
class CrossEntropy(nn.Module):
    """Cross-entropy loss that resizes the logits to the target's spatial size.

    Thin wrapper around ``nn.CrossEntropyLoss``. Target entries equal to
    ``ignore_label`` are passed to the criterion as its ``ignore_index``.
    """

    def __init__(self, ignore_label=-1, weight=None):
        """Build the wrapped criterion.

        Args:
            ignore_label: target value forwarded as ``ignore_index``.
            weight: optional per-class weights forwarded to
                ``nn.CrossEntropyLoss``.
        """
        super().__init__()
        self.ignore_label = ignore_label
        self.criterion = nn.CrossEntropyLoss(weight=weight,
                                             ignore_index=ignore_label)

    def forward(self, score, target):
        """Return the cross-entropy between ``score`` and ``target``.

        Args:
            score: logits; dims 2 and 3 are read as height and width.
            target: class-index map; dims 1 and 2 are read as height and width.

        Returns:
            The loss tensor produced by the wrapped criterion.
        """
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated. interpolate with align_corners=False
            # reproduces what upsample's default resolved to for bilinear.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear', align_corners=False)

        return self.criterion(score, target)
|
|
|
|
|
|
|
|
|
|
|
class DiceLoss(nn.Module):
    """Multi-class soft Dice loss computed on softmax probabilities.

    A binary Dice loss is computed per class channel against the one-hot
    target, summed over classes and divided by the number of classes.
    """

    def __init__(self, ignore_label=-1, smooth=1, exponent=2):
        """Store the loss hyper-parameters.

        Args:
            ignore_label: target value excluded through the valid mask; a
                class channel whose index equals this value is also skipped.
            smooth: additive smoothing term for numerator and denominator.
            exponent: power applied to predictions and targets in the
                denominator.
        """
        super().__init__()
        self.ignore_index = ignore_label
        self.smooth = smooth
        self.exponent = exponent

    def dice_loss(self, pred, target, valid_mask):
        """Average the per-class binary Dice losses.

        Args:
            pred: probabilities with the class axis at dim 1.
            target: one-hot target with the class axis last.
            valid_mask: mask that zeroes out ignored positions.

        Returns:
            Sum of per-class losses divided by the total number of classes
            (skipped classes still count in the divisor).
        """
        assert pred.shape[0] == target.shape[0]
        num_classes = pred.shape[1]
        total_loss = 0
        for i in range(num_classes):
            if i != self.ignore_index:
                total_loss += self.binary_dice_loss(
                    pred[:, i],
                    target[..., i],
                    valid_mask=valid_mask)
        return total_loss / num_classes

    def binary_dice_loss(self, pred, target, valid_mask):
        """Return ``1 - mean(dice)`` for a single class.

        All inputs are flattened per sample; the Dice coefficient is computed
        per sample over valid positions and then averaged over the batch.
        """
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        # The denominator smoothing is floored at 1e-5 so smooth=0 cannot
        # produce a division by zero.
        den = torch.sum(
            pred.pow(self.exponent) * valid_mask + target.pow(self.exponent) * valid_mask,
            dim=1) + max(self.smooth, 1e-5)

        dice = torch.mean(num / den)
        return 1 - dice

    def forward(self, score, target):
        """Compute the multi-class Dice loss.

        Args:
            score: logits; dims 2 and 3 are read as height and width, dim 1
                as the class axis.
            target: class-index map; dims 1 and 2 are read as height and width.

        Returns:
            Scalar loss tensor.
        """
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated. interpolate with align_corners=False
            # reproduces what upsample's default resolved to for bilinear.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear', align_corners=False)

        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]

        # Clamp so out-of-range labels (e.g. the ignore value) are valid
        # one_hot indices; those positions are removed by valid_mask.
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        return self.dice_loss(score, one_hot_target, valid_mask)
|
|
|
|
|
|
|
|
class BinaryDiceLoss(nn.Module):
    """Soft Dice loss on class channel 1 of the softmax probabilities."""

    def __init__(self, smooth=1, exponent=2, ignore_label=-1):
        """Store the loss hyper-parameters.

        Args:
            smooth: additive smoothing term for numerator and denominator.
            exponent: power applied to predictions and targets in the
                denominator.
            ignore_label: target value excluded through the valid mask.
        """
        super().__init__()
        self.ignore_index = ignore_label
        self.smooth = smooth
        self.exponent = exponent

    def binary_dice_loss(self, pred, target, valid_mask):
        """Return ``1 - mean(dice)`` for a single class.

        All inputs are flattened per sample; the Dice coefficient is computed
        per sample over valid positions and then averaged over the batch.
        """
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        # The denominator smoothing is floored at 1e-5 so smooth=0 cannot
        # produce a division by zero.
        den = torch.sum(
            pred.pow(self.exponent) * valid_mask + target.pow(self.exponent) * valid_mask,
            dim=1) + max(self.smooth, 1e-5)

        dice = torch.mean(num / den)
        return 1 - dice

    def forward(self, score, target):
        """Compute the Dice loss for class index 1.

        Args:
            score: logits; dims 2 and 3 are read as height and width, dim 1
                as the class axis (channel 1 must exist).
            target: class-index map; dims 1 and 2 are read as height and width.

        Returns:
            Scalar loss tensor.
        """
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated. interpolate with align_corners=False
            # reproduces what upsample's default resolved to for bilinear.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear', align_corners=False)

        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]

        # Clamp so out-of-range labels (e.g. the ignore value) are valid
        # one_hot indices; those positions are removed by valid_mask.
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_index).long()

        return self.binary_dice_loss(
            score[:, 1],
            one_hot_target[..., 1],
            valid_mask)
|
|
|
|
|
|
|
|
class DiceEntropyLoss(nn.Module):
    """Weighted sum of cross-entropy (0.3) and class-1 soft Dice loss (0.7).

    ``score`` and ``target`` must already share the same spatial size; no
    resizing is performed here.
    """

    def __init__(self, smooth=1, exponent=2, ignore_label=-1, weight=None):
        """Store the Dice hyper-parameters and build the cross-entropy term.

        Args:
            smooth: additive smoothing term for the Dice numerator and
                denominator.
            exponent: power applied to predictions and targets in the Dice
                denominator.
            ignore_label: target value ignored by both loss terms.
            weight: optional per-class weights forwarded to
                ``nn.CrossEntropyLoss``.
        """
        super().__init__()
        self.ignore_label = ignore_label
        self.smooth = smooth
        self.exponent = exponent
        self.cross_entropy = nn.CrossEntropyLoss(weight=weight,
                                                 ignore_index=ignore_label)

    def binary_dice_loss(self, pred, target, valid_mask):
        """Return ``1 - mean(dice)`` for a single class.

        All inputs are flattened per sample; the Dice coefficient is computed
        per sample over valid positions and then averaged over the batch.
        """
        assert pred.shape[0] == target.shape[0]
        pred = pred.reshape(pred.shape[0], -1)
        target = target.reshape(target.shape[0], -1)
        valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)

        num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + self.smooth
        # The denominator smoothing is floored at 1e-5 so smooth=0 cannot
        # produce a division by zero.
        den = torch.sum(
            pred.pow(self.exponent) * valid_mask + target.pow(self.exponent) * valid_mask,
            dim=1) + max(self.smooth, 1e-5)

        dice = torch.mean(num / den)
        return 1 - dice

    def forward(self, score, target):
        """Compute ``0.3 * cross_entropy + 0.7 * dice`` for class index 1.

        Args:
            score: logits with the class axis at dim 1 (channel 1 must exist).
            target: class-index map with the same batch and spatial layout
                as ``score`` minus the class axis
                (assumed (N, H, W) -- TODO confirm against callers).

        Returns:
            Scalar loss tensor.
        """
        CE_loss = self.cross_entropy(score, target)

        score = F.softmax(score, dim=1)
        num_classes = score.shape[1]

        # Clamp so out-of-range labels (e.g. the ignore value) are valid
        # one_hot indices; those positions are removed by valid_mask.
        one_hot_target = F.one_hot(
            torch.clamp(target.long(), 0, num_classes - 1),
            num_classes=num_classes)
        valid_mask = (target != self.ignore_label).long()

        # Select the class-1 channel on the LAST axis only, so the target
        # keeps the same per-sample element count as score[:, 1].
        dice_loss = self.binary_dice_loss(
            score[:, 1],
            one_hot_target[..., 1],
            valid_mask)

        return 0.3 * CE_loss + 0.7 * dice_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FocalLoss(nn.Module):
    """Focal loss built on an unreduced cross-entropy.

    Each element's loss is ``alpha * (1 - pt) ** gamma * ce`` with
    ``pt = exp(-ce)``; the result is averaged over all elements.
    """

    def __init__(self, alpha=0.25, gamma=2., ignore_label=-1):
        """Store the focal parameters and build the per-element criterion.

        Args:
            alpha: constant scale applied to every element's loss.
            gamma: exponent of the ``(1 - pt)`` modulating factor.
            ignore_label: target value forwarded as ``ignore_index``.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label,
                                             reduction="none")

    def forward(self, score, target):
        """Compute the mean focal loss.

        Args:
            score: logits; dims 2 and 3 are read as height and width.
            target: class-index map; dims 1 and 2 are read as height and width.

        Returns:
            Scalar loss tensor.
        """
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            # F.upsample is deprecated. interpolate with align_corners=False
            # reproduces what upsample's default resolved to for bilinear.
            score = F.interpolate(
                input=score, size=(h, w), mode='bilinear', align_corners=False)

        ce_loss = self.criterion(score, target)
        # exp(-ce) recovers the probability assigned to the target class.
        pt = torch.exp(-ce_loss)
        f_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        # NOTE(review): the mean runs over every element, so positions equal
        # to ignore_label presumably still count in the divisor -- confirm
        # this is intended.
        return f_loss.mean()
|
|
|
|
|
|