"""Custom loss functions (focal, Tversky, cross-entropy) for the SemanticModel
semantic-segmentation package (custom_losses.py)."""
import torch
import torch.nn.functional as F

from segmentation_models_pytorch.utils import base
from segmentation_models_pytorch.base.modules import Activation
class FocalLossFunction(base.Loss):
    """Focal loss: cross-entropy scaled by (1 - p_t)**gamma, so well-classified
    pixels are down-weighted and training focuses on hard examples; ``alpha``
    balances the positive and negative classes."""

    def __init__(self, activation=None, alpha=0.25, gamma=1.5, reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Activation(None) is the identity; pass 'sigmoid' when the binary
        # branch below should see probabilities rather than raw scores.
        inputs = self.activation(inputs)
        is_binary = inputs.shape[1] == 1
        if is_binary:
            # Expand the single channel into an explicit 2-class problem.
            # This assumes ``inputs`` are probabilities, since 1 - inputs is
            # taken as the complement channel.
            inputs = torch.cat((inputs, 1 - inputs), dim=1)
            targets = torch.cat((targets, 1 - targets), dim=1)
        # Collapse one-hot targets to class indices, as F.cross_entropy expects.
        targets = torch.argmax(targets, dim=1)
        cross_entropy = F.cross_entropy(inputs, targets, reduction='none')
        probability = torch.exp(-cross_entropy)  # p_t, probability of the true class
        if is_binary:
            # Per-pixel balancing: alpha for the positive class (index 0 after
            # the concatenation above), 1 - alpha for the negative class.
            alpha_factor = torch.where(targets == 1, 1 - self.alpha, self.alpha)
        else:
            alpha_factor = self.alpha
        focal_loss = alpha_factor * (1 - probability) ** self.gamma * cross_entropy
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss  # 'none': per-pixel loss map
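
# Usage sketch (illustrative only; the 4-class shapes and gamma value below are
# assumptions, not part of the original module):
#
#   logits = torch.randn(2, 4, 64, 64)                        # (N, C, H, W) raw scores
#   onehot = F.one_hot(torch.randint(0, 4, (2, 64, 64)), 4)   # (N, H, W, C)
#   onehot = onehot.permute(0, 3, 1, 2).float()               # -> (N, C, H, W)
#   loss = FocalLossFunction(gamma=2.0)(logits, onehot)
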
class TverskyLossFunction(base.Loss):
    """Tversky loss: a generalisation of Dice loss in which ``alpha`` weights
    false positives and ``beta`` weights false negatives; with
    alpha = beta = 0.5 it reduces to the Dice loss. The ``activation`` argument
    is stored for API symmetry; probabilities are computed explicitly in
    forward() via sigmoid/softmax."""

    def __init__(self, activation=None, alpha=0.5, beta=0.5, ignore_channels=None,
                 reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.alpha = alpha
        self.beta = beta
        self.ignore_channels = ignore_channels
        self.reduction = reduction

    def forward(self, inputs, targets):
        if self.ignore_channels is not None:
            # Drop ignored channels from predictions and (one-hot) targets
            # alike, so class indices stay aligned.
            mask = torch.ones(inputs.shape[1], dtype=torch.bool, device=inputs.device)
            mask[self.ignore_channels] = False
            inputs = inputs[:, mask, ...]
            if targets.dim() == 4:
                targets = targets[:, mask, ...]
        num_classes = inputs.shape[1]
        # Map scores to probabilities: sigmoid for a single-channel (binary)
        # head, softmax across channels otherwise.
        probs = torch.sigmoid(inputs) if num_classes == 1 else F.softmax(inputs, dim=1)
        if num_classes == 1:
            probs = probs.squeeze(1)
            targets = targets.squeeze(1)
        tversky_loss = 0
        for class_idx in range(num_classes):
            if num_classes == 1:
                flat_inputs = probs.reshape(-1)
                flat_targets = targets.reshape(-1)
            else:
                flat_inputs = probs[:, class_idx].reshape(-1)
                flat_targets = targets[:, class_idx].reshape(-1)
            # Soft counts of true positives, false positives, false negatives.
            intersection = (flat_inputs * flat_targets).sum()
            fps = ((1 - flat_targets) * flat_inputs).sum()
            fns = (flat_targets * (1 - flat_inputs)).sum()
            # Tversky index = TP / (TP + alpha*FP + beta*FN); epsilon avoids 0/0.
            denominator = intersection + self.alpha * fps + self.beta * fns + 1e-10
            tversky_loss += 1 - intersection / denominator
        if self.reduction == 'mean':
            return tversky_loss / num_classes
        elif self.reduction == 'sum':
            return tversky_loss
        return tversky_loss / inputs.shape[0]  # fallback: normalise by batch size
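
# Usage sketch (illustrative; the shapes and the recall-oriented alpha/beta
# split are assumptions): with beta > alpha, false negatives are penalised more
# heavily than false positives, which favours recall on thin or rare classes.
#
#   logits = torch.randn(2, 3, 64, 64)
#   onehot = F.one_hot(torch.randint(0, 3, (2, 64, 64)), 3).permute(0, 3, 1, 2).float()
#   loss = TverskyLossFunction(alpha=0.3, beta=0.7)(logits, onehot)
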
class EnhancedCrossEntropy(base.Loss):
    """Cross-entropy loss with an optional activation, channel masking, and
    automatic conversion of one-hot targets to class indices."""

    def __init__(self, activation=None, ignore_channels=None, reduction='mean', **kwargs):
        super().__init__(**kwargs)
        self.activation = Activation(activation)
        self.ignore_channels = ignore_channels
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Activation(None) is the identity; F.cross_entropy expects raw logits,
        # so the activation should normally be left as None here.
        inputs = self.activation(inputs)
        if self.ignore_channels is not None:
            mask = torch.ones(inputs.shape[1], dtype=torch.bool, device=inputs.device)
            mask[self.ignore_channels] = False
            inputs = inputs[:, mask, ...]
            if targets.dim() == 4:
                # Mask one-hot targets too, keeping channels aligned with inputs.
                targets = targets[:, mask, ...]
        if targets.dim() == 4:  # Convert one-hot targets to class indices.
            targets = torch.argmax(targets, dim=1)
        return F.cross_entropy(inputs, targets, reduction=self.reduction)
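
# Usage sketch (illustrative; the channel-0 "background" reading is an
# assumption). Since these classes subclass base.Loss from
# segmentation_models_pytorch's utils API, they should also support arithmetic
# composition into a summed loss:
#
#   ce = EnhancedCrossEntropy(ignore_channels=[0])   # exclude channel 0 from the loss
#   combined = FocalLossFunction() + TverskyLossFunction()
#   loss = combined(logits, onehot)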