from functools import partial
from typing import Union

import numpy as np
import torch
import torch.nn as nn
from torchvision.ops import sigmoid_focal_loss


class FocalLoss(nn.Module):
    """
    Focal Loss implementation.

    Thin ``nn.Module`` wrapper around :func:`torchvision.ops.sigmoid_focal_loss`
    with ``reduction="mean"``. Focal Loss is a variant of Binary Cross Entropy
    (BCE) designed for class-imbalanced binary / multi-label classification.
    Two hyperparameters, ``alpha`` and ``gamma``, control the balance between
    positive/negative and easy/hard examples during training.

    :param alpha: The balancing parameter between positive and negative examples.
        A float value between 0 and 1. If set to -1, no balancing is applied.
        Default is 0.25.
    :type alpha: float
    :param gamma: The focusing parameter to control the emphasis on hard examples.
        A non-negative number. Default is 2.
    :type gamma: float
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 2):
        super().__init__()
        # Bind the hyperparameters once so that forward() only has to supply tensors.
        self.loss_fn = partial(sigmoid_focal_loss, alpha=alpha, gamma=gamma, reduction="mean")

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the Focal Loss.

        :param inputs: The raw (pre-sigmoid) predictions from the model;
            ``sigmoid_focal_loss`` applies the sigmoid itself.
        :type inputs: torch.Tensor
        :param targets: The ground truth targets, as a float tensor with the same shape
            as ``inputs``.
        :type targets: torch.Tensor
        :return: The computed Focal Loss, averaged over all elements.
        :rtype: torch.Tensor
        :raises ValueError: If the inputs and targets have different shapes (propagated
            from the underlying binary cross-entropy call).
        """
        return self.loss_fn(inputs=inputs, targets=targets)


class HardDistillationLoss(nn.Module):
    """
    Hard Distillation Loss implementation.

    Model distillation transfers knowledge from a large teacher model to a smaller
    student model. This loss runs the teacher on the same inputs, converts its
    sigmoid outputs to hard binary labels using ``threshold``, and averages two
    terms computed with the same ``loss_fn``:

    * the student's classification output against the ground-truth targets, and
    * the student's distillation output against the teacher's hard labels.

    The teacher is evaluated without gradient tracking. Its train/eval mode is not
    changed here; the caller is responsible for calling ``teacher.eval()`` if needed.

    :param teacher: The teacher model used for distillation.
    :type teacher: torch.nn.Module
    :param loss_fn: The loss function used for both loss terms, called as
        ``loss_fn(predictions, labels)``.
    :type loss_fn: torch.nn.Module
    :param threshold: The threshold value(s) used to convert teacher sigmoid outputs to
        binary labels. Must be broadcastable against the teacher output.
    :type threshold: Union[list, np.ndarray]
    :param device: The device the threshold tensor is initially placed on.
        Default is "cuda".
    :type device: str
    """

    def __init__(
        self,
        teacher: nn.Module,
        loss_fn: nn.Module,
        threshold: Union[list, np.ndarray],
        device: str = "cuda",
    ):
        super().__init__()
        self.teacher = teacher
        self.loss_fn = loss_fn
        # Registered as a buffer so the threshold follows later ``.to(...)`` calls
        # together with the teacher; non-persistent so the state_dict keys stay
        # exactly as they were when this was a plain attribute.
        self.register_buffer("threshold", torch.tensor(threshold).to(device), persistent=False)

    def forward(self, inputs, student_outputs, targets):
        """
        Compute the Hard Distillation Loss.

        :param inputs: The input data, fed to the teacher model to obtain its labels.
        :type inputs: torch.Tensor
        :param student_outputs: The output predictions from the student model, as a pair
            ``(classification_output, distillation_output)``.
        :type student_outputs: tuple
        :param targets: The ground truth targets.
        :type targets: torch.Tensor
        :return: The mean of the classification loss and the distillation loss.
        :rtype: torch.Tensor

        Any exception raised by ``loss_fn`` (for example on mismatched shapes)
        propagates unchanged; this method performs no validation of its own.
        """
        outputs_cls, outputs_dist = student_outputs

        # The teacher only supplies fixed hard labels, so no autograd graph is
        # needed for its forward pass; skipping it saves memory and compute.
        with torch.no_grad():
            teacher_outputs = torch.sigmoid(self.teacher(inputs))
            teacher_labels = (teacher_outputs > self.threshold).float()

        base_loss = self.loss_fn(outputs_cls, targets)
        teacher_loss = self.loss_fn(outputs_dist, teacher_labels)
        return (base_loss + teacher_loss) / 2