|
""" |
|
Module for defining utilities for training such as the negative class sampler |
|
and focal loss function. |
|
""" |
|
|
|
import numpy as np
from sklearn.metrics import (
    precision_recall_fscore_support,
    precision_recall_curve,
    auc,
    PrecisionRecallDisplay,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Sampler


def compute_metrics(targs, scores, preds, plot_pr_curve: bool = True):
    """
    Compute binary precision, recall, F1, and the area under the
    precision-recall curve, optionally plotting the curve.
    """
    precision, recall, f1, _ = precision_recall_fscore_support(
        targs, preds, average="binary"
    )
    prs, rcs, _ = precision_recall_curve(targs, scores)

    if plot_pr_curve:
        display = PrecisionRecallDisplay.from_predictions(
            targs, scores, plot_chance_level=True
        )
        display.ax_.set_title("Precision-Recall curve of subsample")
        display.figure_.show()

    # auc(x, y) requires a monotonic x-axis: recall decreases monotonically
    # along the curve (precision does not), so recall is passed first.
    try:
        pr_auc = auc(rcs, prs)
    except ValueError:
        print("Warning: curve is non-monotonic, returning None")
        pr_auc = None

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'pr_auc': pr_auc,
    }
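
# Minimal usage sketch for compute_metrics (hypothetical arrays; `scores` are
# positive-class probabilities and `preds` their thresholded labels):
#
#   targs = np.array([0, 1, 1, 0, 1])
#   scores = np.array([0.1, 0.8, 0.4, 0.3, 0.9])
#   preds = (scores >= 0.5).astype(int)
#   metrics = compute_metrics(targs, scores, preds, plot_pr_curve=False)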
|
|
|
|
|
class FocalLoss(nn.Module):
    """
    Focal loss for class-imbalanced classification:
    FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t),
    with per-class weights alpha taken from inverse class frequencies.
    """

    def __init__(self, class_frequencies: torch.Tensor, gamma: float = 2.0):
        super().__init__()
        # Inverse-frequency class weights, normalized to sum to 1.
        alpha = (1.0 / class_frequencies).to(
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.alpha = alpha / alpha.sum()
        self.gamma = gamma

    def forward(self, inputs, targets):
        # Keep the class weights on the same device as the logits. `targets`
        # must stay integer-typed: it indexes the per-class weights and
        # serves as class indices for cross_entropy.
        if self.alpha.device != inputs.device:
            self.alpha = self.alpha.to(inputs.device)
        alpha_targets = self.alpha[targets]
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        loss = (alpha_targets * (1 - pt) ** self.gamma * ce_loss).mean()
        return loss
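
# Minimal usage sketch for FocalLoss (hypothetical tensors; class_frequencies
# holds per-class sample counts from the training set):
#
#   class_frequencies = torch.tensor([9000.0, 1000.0])
#   criterion = FocalLoss(class_frequencies, gamma=2.0)
#   logits = torch.randn(8, 2)          # unnormalized model outputs
#   labels = torch.randint(0, 2, (8,))  # integer class indices
#   loss = criterion(logits, labels)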
|
|
|
|
|
class NegClassRandomSampler(Sampler):
    """
    DataLoader sampler that subsamples the negative class after each epoch.
    All positive samples are kept, while a fresh random subset of negative
    samples is drawn each epoch.

    With the current settings, the sampling is done without replacement and
    only 20% of the negative samples are kept each epoch, which should make
    the class imbalance more manageable.
    """

    def __init__(self, data_source, neg_class_ratio: float = 0.2, seed: int = 42):
        self._random_gen = np.random.default_rng(seed)
        self.data_source = data_source
        self._neg_class_ratio = neg_class_ratio

        labels = np.array(data_source['labels'])
        self._pos_indices = np.flatnonzero(labels == 1)
        self._neg_indices = np.flatnonzero(labels == 0)
        self._neg_num_samples = int(len(self._neg_indices) * neg_class_ratio)
        self._pos_num_samples = len(self._pos_indices)

    @property
    def num_samples(self):
        return self._pos_num_samples + self._neg_num_samples

    def __iter__(self):
        """
        Resample the negative class each time an iterator is requested,
        i.e. once per epoch.
        """
        neg_samples = self._random_gen.choice(
            self._neg_indices, self._neg_num_samples, replace=False
        )
        samples = np.concatenate((neg_samples, self._pos_indices), axis=0)
        self._random_gen.shuffle(samples)
        if len(samples) != len(self):
            raise ValueError(
                f"Length of output samples ({len(samples)}) does not match "
                f"expected ({len(self)})"
            )
        return iter(samples.tolist())

    def __len__(self):
        return self.num_samples
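
# Minimal usage sketch for NegClassRandomSampler (assumes a dict-style or
# Hugging Face-style dataset exposing a 'labels' column, as the sampler
# expects; `dataset` and `num_epochs` are illustrative):
#
#   sampler = NegClassRandomSampler(dataset, neg_class_ratio=0.2, seed=42)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)
#   for epoch in range(num_epochs):
#       for batch in loader:  # negatives are resampled at each epoch start
#           ...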
|
|
|
|