""" Module for defining utilities for training such as the negative class sampler and focal loss function. """ import numpy as np from sklearn.metrics import ( precision_recall_fscore_support, precision_recall_curve, auc, PrecisionRecallDisplay ) import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Sampler def compute_metrics(targs, scores, preds, plot_pr_curve: bool = True): precision, recall, f1, _ = precision_recall_fscore_support(targs, preds, average="binary") prs, rcs, _ = precision_recall_curve(targs, scores) if plot_pr_curve: display = PrecisionRecallDisplay.from_predictions( targs, scores, plot_chance_level=True ) display.ax_.set_title("Precision-Recall curve of subsample") display.figure_.show() try: pr_auc = auc(prs, rcs) except ValueError: print("Warning: curve is non-monotonic, returning None") pr_auc = None return { 'precision': precision, 'recall': recall, 'f1': f1, 'pr_auc': pr_auc } class FocalLoss(nn.Module): def __init__(self, class_frequencies: torch.Tensor, gamma: int = 2): super(FocalLoss, self).__init__() self.alpha = (1 / class_frequencies).to( torch.device("cuda" if torch.cuda.is_available() else "cpu") ) self.alpha = (self.alpha / self.alpha.sum()) self.gamma = gamma def forward(self, inputs, targets): alpha_targets = self.alpha[targets] if inputs.data.type() != targets.data.type(): targets = targets.type_as(inputs.data) if self.alpha.type() != inputs.data.type(): self.alpha = self.alpha.type_as(inputs.data) ce_loss = F.cross_entropy(inputs, targets, reduction='none') pt = torch.exp(-ce_loss) loss = (alpha_targets * (1 - pt) ** self.gamma * ce_loss).mean() return loss class NegClassRandomSampler(Sampler): """ Dataloader Sampler that subsamples the negative class after each epoch. The idea is that we want to keep the positive samples but select a random subset of negative samples each epoch for a fresh set. With the current settings, the sampling is done without replacement, and we end up with a roughly 20% data imbalance, which should hopefully be more manageable. """ def __init__(self, data_source, neg_class_ratio: float = 0.2, seed: int = 42): self._random_gen = np.random.default_rng(seed) self.data_source = data_source self._neg_class_ratio = neg_class_ratio # Get indices of the positive and negative cases self._pos_indices = np.argwhere(np.array(data_source['labels']) == 1).flatten() self._neg_indices = np.argwhere(np.array(data_source['labels']) == 0).flatten() self._neg_num_samples = int(len(self._neg_indices) * neg_class_ratio) self._pos_num_samples = len(self._pos_indices) @property def num_samples(self): return self._pos_num_samples + self._neg_num_samples def __iter__(self): """ Each time an iteration of this is requested, the resampling is done. """ _neg_samples = self._random_gen.choice(self._neg_indices, self._neg_num_samples, replace=False) _samples = np.concatenate((_neg_samples, self._pos_indices), axis=0) self._random_gen.shuffle(_samples) if (len(_samples) != len(self)): raise ValueError("Length of output samples (%d) does not match expected (%d)", len(_samples), len(self)) return iter(_samples.tolist()) def __len__(self): return self.num_samples