SemanticSegmentationModel
/
semantic-segmentation
/SemanticModel
/.ipynb_checkpoints
/custom_losses-checkpoint.py
| import torch | |
| import torch.nn.functional as F | |
| from segmentation_models_pytorch.utils import base | |
| from segmentation_models_pytorch.base.modules import Activation | |
class FocalLossFunction(base.Loss):
    """Focal loss for binary (1-channel) or multi-class segmentation logits.

    ``inputs`` are treated as raw logits of shape (N, C, ...). With C == 1 the
    single channel is the foreground logit; with C > 1 ``F.cross_entropy``
    applies the softmax over the C channels.

    Args:
        activation: Wrapped in ``Activation`` and stored as ``self.activation``.
            NOTE: ``forward`` does not apply it.
        alpha: Class-balancing weight. In the binary case foreground pixels are
            weighted by ``alpha`` and background pixels by ``1 - alpha``; in the
            multi-class case every pixel is scaled by ``alpha``.
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-pixel loss.
        **kwargs: Forwarded to ``base.Loss``.
    """

    def __init__(self, activation=None, alpha=0.25, gamma=1.5, reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of logits ``inputs`` against ``targets``.

        Args:
            inputs: Logits of shape (N, C, ...).
            targets: If C == 1, a mask with the same shape as ``inputs``
                (after the [t, 1 - t] stacking, pixels with t >= 0.5 count as
                foreground). If C > 1, either one-hot / soft labels with the
                same rank as ``inputs`` (reduced with ``argmax`` over dim 1) or
                class indices of shape (N, ...).

        Returns:
            A scalar for ``'mean'`` / ``'sum'``, otherwise the per-pixel loss.
        """
        # Decide this BEFORE building the 2-class tensor, otherwise the binary
        # alpha weighting below can never be selected.
        is_binary = inputs.shape[1] == 1
        if is_binary:
            # Channel 0 = foreground, channel 1 = background. Pairing the
            # foreground logit with a constant 0 logit makes softmax(...)[0]
            # equal sigmoid(logit), i.e. the usual binary probability.
            inputs = torch.cat((inputs, torch.zeros_like(inputs)), dim=1)
            targets = torch.cat((targets, 1 - targets), dim=1)
        if targets.dim() == inputs.dim():
            # One-hot / soft labels -> class indices; already-indexed targets
            # (one dim fewer than inputs) are used as-is.
            targets = torch.argmax(targets, dim=1)

        cross_entropy = F.cross_entropy(inputs, targets, reduction='none')
        probability = torch.exp(-cross_entropy)  # p_t, probability of the true class

        if is_binary:
            # Index 1 is background -> 1 - alpha; foreground (index 0) -> alpha.
            alpha_factor = torch.where(
                targets == 1,
                torch.full_like(cross_entropy, 1 - self.alpha),
                torch.full_like(cross_entropy, self.alpha),
            )
        else:
            alpha_factor = self.alpha

        focal_weight = alpha_factor * (1 - probability) ** self.gamma * cross_entropy
        if self.reduction == 'mean':
            return focal_weight.mean()
        if self.reduction == 'sum':
            return focal_weight.sum()
        return focal_weight
class TverskyLossFunction(base.Loss):
    """Tversky loss accumulated over classes.

    For every class, TP / FP / FN are summed over the whole batch from soft
    predictions, and the class term is
    ``1 - TP / (TP + alpha * FP + beta * FN + 1e-10)``.

    Args:
        activation: Wrapped in ``Activation`` and stored as ``self.activation``.
            NOTE: ``forward`` does not apply it; it applies sigmoid (one
            channel) or softmax over channels itself.
        alpha: Weight of false positives.
        beta: Weight of false negatives.
        ignore_channels: Channel indices removed from both ``inputs`` and
            ``targets`` before the loss is computed.
        reduction: ``'mean'`` averages the class terms, ``'sum'`` adds them;
            any other value returns the class sum divided by the batch size.
        **kwargs: Forwarded to ``base.Loss``.
    """

    def __init__(self, activation=None, alpha=0.5, beta=0.5, ignore_channels=None,
                 reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.alpha = alpha
        self.beta = beta
        self.ignore_channels = ignore_channels
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the Tversky loss of logits ``inputs`` against ``targets``.

        Args:
            inputs: Logits of shape (N, C, ...).
            targets: For C > 1, per-channel (e.g. one-hot) masks laid out like
                ``inputs``; for C == 1, a mask with the same number of elements
                as ``inputs``.

        Raises:
            ValueError: If ``ignore_channels`` removes every channel.
        """
        if self.ignore_channels is not None:
            keep = torch.ones(inputs.shape[1], dtype=torch.bool, device=inputs.device)
            keep[self.ignore_channels] = False
            inputs = inputs[:, keep, ...]
            # Drop the same channels from the targets so predicted channel k is
            # still compared with target channel k.
            targets = targets[:, keep.to(targets.device), ...]

        num_classes = inputs.shape[1]
        if num_classes == 0:
            raise ValueError("ignore_channels removed every input channel")

        if num_classes == 1:
            # Binary: one sigmoid probability map against one mask.
            pairs = [(torch.sigmoid(inputs).reshape(-1), targets.reshape(-1))]
        else:
            probs = F.softmax(inputs, dim=1)
            pairs = [(probs[:, c].reshape(-1), targets[:, c].reshape(-1))
                     for c in range(num_classes)]

        tversky_loss = 0
        for flat_inputs, flat_targets in pairs:
            intersection = (flat_inputs * flat_targets).sum()
            fps = ((1 - flat_targets) * flat_inputs).sum()
            fns = (flat_targets * (1 - flat_inputs)).sum()
            denominator = intersection + self.alpha * fps + self.beta * fns + 1e-10
            tversky_loss += 1 - intersection / denominator

        if self.reduction == 'mean':
            return tversky_loss / num_classes
        if self.reduction == 'sum':
            return tversky_loss
        # Any other reduction: kept for backward compatibility. This is the
        # class-summed scalar divided by batch size, not a per-element map.
        return tversky_loss / inputs.shape[0]
class EnhancedCrossEntropy(base.Loss):
    """Cross-entropy with optional activation, one-hot targets and ignored channels.

    Args:
        activation: Wrapped in ``Activation`` and applied to ``inputs`` before
            ``F.cross_entropy`` (which itself applies log-softmax).
        ignore_channels: Channel indices removed from ``inputs``. Pixels whose
            target class is one of these channels are excluded from the loss,
            and the remaining class indices are renumbered to match the kept
            input channels.
        reduction: Forwarded to ``F.cross_entropy``.
        **kwargs: Forwarded to ``base.Loss``.
    """

    # Target value skipped by F.cross_entropy (its default ``ignore_index``).
    _IGNORE_INDEX = -100

    def __init__(self, activation=None, ignore_channels=None, reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return ``F.cross_entropy`` of (activated) ``inputs`` against ``targets``.

        Args:
            inputs: Scores of shape (N, C, ...).
            targets: 4-D one-hot targets (reduced with ``argmax`` over dim 1)
                or whatever ``F.cross_entropy`` accepts directly.
        """
        inputs = self.activation(inputs)
        if targets.dim() == 4:  # Convert one-hot to class indices (full channel set)
            targets = torch.argmax(targets, dim=1)
        if self.ignore_channels is not None:
            inputs, targets = self._drop_ignored_channels(inputs, targets)
        return F.cross_entropy(inputs, targets, reduction=self.reduction)

    def _drop_ignored_channels(self, inputs, targets):
        """Remove ``self.ignore_channels`` from ``inputs`` and align ``targets`` with the rest."""
        num_channels = inputs.shape[1]
        keep = torch.ones(num_channels, dtype=torch.bool, device=inputs.device)
        keep[self.ignore_channels] = False
        inputs = inputs[:, keep, ...]
        keep = keep.to(targets.device)

        if targets.dim() == inputs.dim():
            # Class-probability targets laid out like inputs: drop the same channels.
            return inputs, targets[:, keep, ...]

        # Class-index targets: old index -> position among the kept channels;
        # ignored classes -> _IGNORE_INDEX so F.cross_entropy skips those pixels.
        remap = (keep.long().cumsum(0) - 1).masked_fill(~keep, self._IGNORE_INDEX)
        indices = targets.long()
        # Preserve labels that were already marked as ignored (-100).
        already_ignored = indices == self._IGNORE_INDEX
        remapped = remap[indices.masked_fill(already_ignored, 0)]
        return inputs, remapped.masked_fill(already_ignored, self._IGNORE_INDEX)