"""
Module for defining utilities for training such as the negative class sampler
and focal loss function.
"""
import numpy as np
from sklearn.metrics import (
precision_recall_fscore_support,
precision_recall_curve,
auc,
PrecisionRecallDisplay
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Sampler
def compute_metrics(targs, scores, preds, plot_pr_curve: bool = True):
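    """
    Compute binary-classification metrics from ground-truth labels (`targs`),
    positive-class scores (`scores`), and hard predictions (`preds`):
    precision, recall, F1, and the area under the precision-recall curve.
    Optionally plots the PR curve.
    """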
precision, recall, f1, _ = precision_recall_fscore_support(targs, preds, average="binary")
prs, rcs, _ = precision_recall_curve(targs, scores)
if plot_pr_curve:
display = PrecisionRecallDisplay.from_predictions(
targs, scores, plot_chance_level=True
)
display.ax_.set_title("Precision-Recall curve of subsample")
display.figure_.show()
    try:
        # PR AUC: sklearn's auc(x, y) expects a monotonic x, so recall goes on
        # the x-axis and precision on the y-axis.
        pr_auc = auc(rcs, prs)
    except ValueError:
        print("Warning: curve is non-monotonic, returning None")
        pr_auc = None
return {
'precision': precision,
'recall': recall,
'f1': f1,
'pr_auc': pr_auc
}
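

# A minimal usage sketch for compute_metrics (illustrative only; `val_labels`
# and `val_probs` are hypothetical arrays, not defined in this module):
#
#     val_scores = val_probs[:, 1]                 # positive-class probabilities
#     val_preds = (val_scores >= 0.5).astype(int)  # hard predictions at a 0.5 threshold
#     metrics = compute_metrics(val_labels, val_scores, val_preds, plot_pr_curve=False)
#     print(metrics["f1"], metrics["pr_auc"])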
class FocalLoss(nn.Module):
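    """
    Multi-class focal loss with per-class weights set to the normalized inverse
    class frequency, so under-represented classes contribute more to the loss.
    `gamma` controls how strongly well-classified samples are down-weighted.
    """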
    def __init__(self, class_frequencies: torch.Tensor, gamma: float = 2.0):
        super().__init__()
        # Per-class weights: normalized inverse class frequency, so rarer
        # classes get proportionally larger weights.
        self.alpha = (1 / class_frequencies).to(
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.alpha = self.alpha / self.alpha.sum()
        self.gamma = gamma
    def forward(self, inputs, targets):
        # Keep the class weights on the same device and dtype as the logits.
        if self.alpha.device != inputs.device or self.alpha.dtype != inputs.dtype:
            self.alpha = self.alpha.to(device=inputs.device, dtype=inputs.dtype)
        # Per-sample weight: the weight of each sample's true class. Targets stay
        # integer class indices, as required by F.cross_entropy.
        alpha_targets = self.alpha[targets]
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # Focal modulation: down-weight easy samples via (1 - p_t) ** gamma.
        pt = torch.exp(-ce_loss)
        loss = (alpha_targets * (1 - pt) ** self.gamma * ce_loss).mean()
        return loss
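

# A minimal usage sketch for FocalLoss (illustrative only; the class counts,
# `logits`, and `labels` below are made-up placeholders, not part of this module):
#
#     class_frequencies = torch.tensor([900.0, 100.0])         # e.g. counts of classes 0 and 1
#     criterion = FocalLoss(class_frequencies, gamma=2.0)
#     logits = torch.randn(8, 2, requires_grad=True)            # (batch_size, num_classes)
#     labels = torch.randint(0, 2, (8,))                        # integer class indices
#     loss = criterion(logits, labels)
#     loss.backward()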
class NegClassRandomSampler(Sampler):
"""
Dataloader Sampler that subsamples the negative class after each epoch.
The idea is that we want to keep the positive samples but select a random
subset of negative samples each epoch for a fresh set.
With the current settings, the sampling is done without replacement, and we
end up with a roughly 20% data imbalance, which should hopefully be more
manageable.
"""
def __init__(self, data_source, neg_class_ratio: float = 0.2, seed: int = 42):
self._random_gen = np.random.default_rng(seed)
self.data_source = data_source
self._neg_class_ratio = neg_class_ratio
# Get indices of the positive and negative cases
self._pos_indices = np.argwhere(np.array(data_source['labels']) == 1).flatten()
self._neg_indices = np.argwhere(np.array(data_source['labels']) == 0).flatten()
self._neg_num_samples = int(len(self._neg_indices) * neg_class_ratio)
self._pos_num_samples = len(self._pos_indices)
@property
def num_samples(self):
return self._pos_num_samples + self._neg_num_samples
def __iter__(self):
"""
Each time an iteration of this is requested, the resampling is done.
"""
_neg_samples = self._random_gen.choice(self._neg_indices, self._neg_num_samples, replace=False)
_samples = np.concatenate((_neg_samples, self._pos_indices), axis=0)
self._random_gen.shuffle(_samples)
        if len(_samples) != len(self):
            raise ValueError(
                f"Length of output samples ({len(_samples)}) does not match expected ({len(self)})"
            )
return iter(_samples.tolist())
def __len__(self):
return self.num_samples
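

# A minimal usage sketch for NegClassRandomSampler (illustrative only;
# `train_dataset` is a hypothetical dataset exposing a 'labels' field of 0/1
# labels, not defined in this module):
#
#     sampler = NegClassRandomSampler(train_dataset, neg_class_ratio=0.2, seed=42)
#     loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, sampler=sampler)
#     # The DataLoader calls iter(sampler) at the start of each epoch, so a fresh
#     # random subset of negatives is drawn every epoch.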