# Provenance: uploaded by obichimav, commit 8e5d8c7 ("Upload 42 files", verified).
import torch
import torch.nn.functional as F
from segmentation_models_pytorch.utils import base
from segmentation_models_pytorch.base.modules import Activation
class FocalLossFunction(base.Loss):
    """Focal loss for binary (single-channel) or multi-class logits.

    Per element the loss is ``alpha_t * (1 - p_t) ** gamma * CE``, where
    ``CE`` is the cross-entropy of the true class and ``p_t = exp(-CE)`` is
    the predicted probability of that class.

    Args:
        activation: Stored for API compatibility only; it is NOT applied in
            ``forward``. ``inputs`` are always treated as raw logits.
        alpha: Class-balance weight. In the single-channel (binary) case,
            positives are weighted by ``alpha`` and negatives by
            ``1 - alpha``. In the multi-class case every element is scaled
            by ``alpha``.
        gamma: Focusing exponent. Larger values down-weight well-classified
            elements more strongly.
        reduction: ``'mean'`` or ``'sum'``. Any other value returns the
            unreduced per-element loss with the channel dim removed.
        **kwargs: Forwarded to ``base.Loss``.
    """

    def __init__(self, activation=None, alpha=0.25, gamma=1.5, reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: Raw logits with channels on dim 1. One channel selects
                the binary (sigmoid) formulation; several channels select
                the multi-class (softmax) formulation.
            targets: Binary case: a 0/1 mask with or without the singleton
                channel dim. Multi-class case: one-hot targets laid out like
                ``inputs``, or class indices without the channel dim.

        Returns:
            A scalar for ``'mean'``/``'sum'``, otherwise the per-element loss.
        """
        if inputs.shape[1] == 1:
            loss = self._binary_focal(inputs, targets)
        else:
            loss = self._multiclass_focal(inputs, targets)
        return self._reduce(loss)

    def _binary_focal(self, inputs, targets):
        """Return the per-element focal loss for single-channel logits, channel dim removed."""
        if targets.dim() == inputs.dim() - 1:
            targets = targets.unsqueeze(1)
        targets = targets.to(inputs.dtype)
        # A one-logit sigmoid model must use BCE-with-logits. Building a fake
        # second logit (e.g. 1 - x) and applying softmax yields sigmoid(2x - 1),
        # not sigmoid(x).
        cross_entropy = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        probability = torch.exp(-cross_entropy)
        # alpha for positives, 1 - alpha for negatives (linearly blended for soft labels).
        alpha_factor = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        loss = alpha_factor * (1 - probability) ** self.gamma * cross_entropy
        return loss.squeeze(1)

    def _multiclass_focal(self, inputs, targets):
        """Return the per-element focal loss for multi-channel logits."""
        if targets.dim() == inputs.dim():
            # One-hot targets share the input's rank: collapse them to class indices.
            targets = torch.argmax(targets, dim=1)
        cross_entropy = F.cross_entropy(inputs, targets.long(), reduction='none')
        probability = torch.exp(-cross_entropy)
        return self.alpha * (1 - probability) ** self.gamma * cross_entropy

    def _reduce(self, loss):
        """Apply ``self.reduction`` to a per-element loss tensor."""
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss
class TverskyLossFunction(base.Loss):
    """Tversky loss, summed or averaged over classes.

    For each class the Tversky index is ``TP / (TP + alpha*FP + beta*FN)``.
    It is computed with batch and spatial positions pooled together, and the
    per-class loss is ``1 - index``. With ``alpha == beta == 0.5`` this
    equals the soft Dice loss.

    Args:
        activation: Stored for API compatibility only; it is NOT applied in
            ``forward``, which applies sigmoid/softmax itself.
        alpha: Weight of false positives.
        beta: Weight of false negatives.
        ignore_channels: Channel index or indices dropped from the
            predictions and, when they carry the full channel set, from the
            one-hot targets.
        reduction: ``'mean'`` averages the per-class losses and ``'sum'``
            sums them. Any other value returns the class sum divided by the
            batch size (behavior preserved from the original implementation).
        **kwargs: Forwarded to ``base.Loss``.
    """

    def __init__(self, activation=None, alpha=0.5, beta=0.5, ignore_channels=None,
                 reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.alpha = alpha
        self.beta = beta
        self.ignore_channels = ignore_channels
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the Tversky loss.

        Args:
            inputs: Raw logits with channels on dim 1. After channel masking,
                a single channel uses sigmoid and several use softmax.
            targets: Binary case: a mask with the same number of elements as
                ``inputs``. Multi-class case: one-hot targets laid out like
                ``inputs``.

        Returns:
            A scalar tensor.

        Raises:
            ValueError: If ``ignore_channels`` removes every channel.
        """
        if self.ignore_channels is not None:
            keep = torch.ones(inputs.shape[1], dtype=torch.bool, device=inputs.device)
            keep[self.ignore_channels] = False
            # NOTE(review): channels are dropped before softmax, so the kept
            # classes are renormalized among themselves. Confirm this is intended.
            inputs = inputs[:, keep, ...]
            # Drop the same channels from full one-hot targets. Otherwise class k
            # of the masked predictions is scored against class k of the
            # unmasked targets.
            if targets.dim() == inputs.dim() and targets.shape[1] == keep.numel():
                targets = targets[:, keep.to(targets.device), ...]

        num_classes = inputs.shape[1]
        if num_classes == 0:
            raise ValueError("ignore_channels removed every channel; nothing to score")
        if num_classes == 1:
            probs = torch.sigmoid(inputs)
        else:
            probs = F.softmax(inputs, dim=1)
        # Pair targets element-for-element with probs. This also accepts a binary
        # mask without the channel dim. Casting makes bool/int masks safe for `1 - t`.
        targets = targets.reshape(probs.shape).to(probs.dtype)

        # Pool batch and spatial positions; keep one entry per class.
        reduce_dims = [0] + list(range(2, probs.dim()))
        true_pos = (probs * targets).sum(dim=reduce_dims)
        false_pos = (probs * (1 - targets)).sum(dim=reduce_dims)
        false_neg = ((1 - probs) * targets).sum(dim=reduce_dims)
        denominator = true_pos + self.alpha * false_pos + self.beta * false_neg + 1e-10
        tversky_loss = (1 - true_pos / denominator).sum()

        if self.reduction == 'mean':
            return tversky_loss / num_classes
        if self.reduction == 'sum':
            return tversky_loss
        # NOTE(review): this is not an unreduced ('none') loss. The index above
        # already pools the whole batch, so dividing by batch size is kept only
        # for backward compatibility.
        return tversky_loss / inputs.shape[0]
class EnhancedCrossEntropy(base.Loss):
    """Cross-entropy loss accepting one-hot or class-index targets.

    Args:
        activation: Wrapped in ``Activation`` and applied to ``inputs`` before
            the loss. ``F.cross_entropy`` applies ``log_softmax`` itself, so
            ``inputs`` should normally reach it as raw logits.
        ignore_channels: Channel index or indices removed from the logits.
            Pixels labelled with an ignored class are excluded from the loss.
        reduction: Passed straight to ``F.cross_entropy``.
        **kwargs: Forwarded to ``base.Loss``.
    """

    def __init__(self, activation=None, ignore_channels=None, reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the cross-entropy loss.

        Args:
            inputs: Scores with channels on dim 1.
            targets: One-hot targets laid out like ``inputs``, a
                single-channel label map, or class indices without the
                channel dim.

        Returns:
            The result of ``F.cross_entropy`` with ``self.reduction``.
        """
        ignore_index = -100  # F.cross_entropy's default ignore_index
        inputs = self.activation(inputs)
        num_channels = inputs.shape[1]

        if targets.dim() == inputs.dim():
            if targets.shape[1] == 1 and num_channels > 1:
                # Single-channel label map: it already holds class indices
                # (an argmax over one channel would turn every label into 0).
                targets = targets.squeeze(1)
            else:
                # One-hot targets: collapse over the FULL channel set, before
                # any channel is dropped, so indices name the original classes.
                targets = torch.argmax(targets, dim=1)
        targets = targets.long()

        if self.ignore_channels is not None:
            keep = torch.ones(num_channels, dtype=torch.bool, device=inputs.device)
            keep[self.ignore_channels] = False
            inputs = inputs[:, keep, ...]
            # Renumber the kept classes 0..K-1 so they match the masked logits.
            # Labels of ignored classes map to the ignore index and drop out of the loss.
            keep_t = keep.to(targets.device)
            remap = torch.full((num_channels,), ignore_index, dtype=torch.long, device=targets.device)
            remap[keep_t] = torch.arange(int(keep_t.sum()), device=targets.device)
            labelled = targets >= 0  # leave pre-existing ignore-index labels untouched
            targets = torch.where(labelled, remap[targets.clamp(min=0)], targets)

        return F.cross_entropy(inputs, targets, reduction=self.reduction, ignore_index=ignore_index)