import torch
import torch.nn.functional as F

from segmentation_models_pytorch.utils import base
from segmentation_models_pytorch.base.modules import Activation


class FocalLossFunction(base.Loss):
    """Focal loss for binary or multi-class segmentation."""

    def __init__(self, activation=None, alpha=0.25, gamma=1.5, reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        binary = inputs.shape[1] == 1
        if binary:
            # Binary case: expand single-channel predictions and targets to a
            # two-class layout, then reduce targets to class indices.
            inputs = torch.cat((inputs, 1 - inputs), dim=1)
            targets = torch.cat((targets, 1 - targets), dim=1)
            targets = torch.argmax(targets, dim=1)

        cross_entropy = F.cross_entropy(inputs, targets, reduction='none')
        probability = torch.exp(-cross_entropy)

        # Scalar alpha in the multi-class case; per-pixel, class-dependent alpha
        # in the binary case (checked before the channel expansion above).
        alpha_factor = (
            torch.where(targets == 1, 1 - self.alpha, self.alpha)
            if binary
            else self.alpha
        )
        focal_weight = alpha_factor * (1 - probability) ** self.gamma * cross_entropy

        if self.reduction == 'mean':
            return focal_weight.mean()
        elif self.reduction == 'sum':
            return focal_weight.sum()
        return focal_weight


class TverskyLossFunction(base.Loss):
    """Tversky loss: alpha weights false positives, beta weights false negatives."""

    def __init__(self, activation=None, alpha=0.5, beta=0.5, ignore_channels=None,
                 reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.alpha = alpha
        self.beta = beta
        self.ignore_channels = ignore_channels
        self.reduction = reduction

    def forward(self, inputs, targets):
        if self.ignore_channels is not None:
            mask = torch.ones(inputs.shape[1], dtype=torch.bool, device=inputs.device)
            mask[self.ignore_channels] = False
            inputs = inputs[:, mask, ...]
            # Drop the same channels from the targets so they stay aligned
            # with the remaining prediction channels.
            if targets.dim() == inputs.dim() and targets.shape[1] == mask.numel():
                targets = targets[:, mask, ...]

        num_classes = inputs.shape[1]
        inputs_softmax = (torch.sigmoid(inputs) if num_classes == 1
                          else F.softmax(inputs, dim=1))

        if num_classes == 1:
            inputs_softmax = inputs_softmax.squeeze(1)
            targets = targets.squeeze(1)

        smooth = 1e-10
        tversky_loss = 0
        for class_idx in range(num_classes):
            if num_classes == 1:
                flat_inputs = inputs_softmax.reshape(-1)
                flat_targets = targets.reshape(-1)
            else:
                flat_inputs = inputs_softmax[:, class_idx].reshape(-1)
                flat_targets = targets[:, class_idx].reshape(-1)

            intersection = (flat_inputs * flat_targets).sum()   # true positives
            fps = ((1 - flat_targets) * flat_inputs).sum()      # false positives
            fns = (flat_targets * (1 - flat_inputs)).sum()      # false negatives

            # Smooth both numerator and denominator so an empty class
            # (no targets, no predictions) contributes zero loss.
            tversky_index = (intersection + smooth) / (
                intersection + self.alpha * fps + self.beta * fns + smooth
            )
            tversky_loss += 1 - tversky_index

        if self.reduction == 'mean':
            return tversky_loss / num_classes
        elif self.reduction == 'sum':
            return tversky_loss
        # Fallback: normalise the summed loss by batch size.
        return tversky_loss / inputs.shape[0]


class EnhancedCrossEntropy(base.Loss):
    """Cross-entropy loss with optional ignored channels and one-hot target support."""

    def __init__(self, activation=None, ignore_channels=None, reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Apply the configured activation (identity when activation=None).
        inputs = self.activation(inputs)

        if self.ignore_channels is not None:
            mask = torch.ones(inputs.shape[1], dtype=torch.bool, device=inputs.device)
            mask[self.ignore_channels] = False
            inputs = inputs[:, mask, ...]
            # Drop the same channels from one-hot targets to keep classes aligned.
            if targets.dim() == 4:
                targets = targets[:, mask, ...]

        if targets.dim() == 4:
            # Convert one-hot targets to class indices.
            targets = torch.argmax(targets, dim=1)

        return F.cross_entropy(inputs, targets, reduction=self.reduction)
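

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the original module).
# The tensor shapes, class count, and loss weights below are assumptions made
# for demonstration, not values prescribed by this code base.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical 3-class problem: batch of two 64x64 images.
    logits = torch.randn(2, 3, 64, 64)                 # raw, unnormalised model outputs
    index_targets = torch.randint(0, 3, (2, 64, 64))   # class-index masks
    one_hot_targets = (
        F.one_hot(index_targets, num_classes=3).permute(0, 3, 1, 2).float()
    )                                                  # (B, C, H, W) one-hot masks

    focal = FocalLossFunction(alpha=0.25, gamma=1.5)
    tversky = TverskyLossFunction(alpha=0.3, beta=0.7)
    ce_loss = EnhancedCrossEntropy()

    # An arbitrary 50/50 weighting of the two losses, purely for illustration.
    combined = 0.5 * focal(logits, index_targets) + 0.5 * tversky(logits, one_hot_targets)
    print("focal + tversky:", float(combined))
    print("cross-entropy:", float(ce_loss(logits, one_hot_targets)))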