"""
Module for defining utilities for training such as the negative class sampler
and focal loss function.
"""
import numpy as np
from sklearn.metrics import (
precision_recall_fscore_support,
precision_recall_curve,
auc,
PrecisionRecallDisplay
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Sampler


def compute_metrics(targs, scores, preds, plot_pr_curve: bool = True):
    """Compute binary precision, recall, F1, and the area under the PR curve."""
    precision, recall, f1, _ = precision_recall_fscore_support(targs, preds, average="binary")
    prs, rcs, _ = precision_recall_curve(targs, scores)
    if plot_pr_curve:
        display = PrecisionRecallDisplay.from_predictions(
            targs, scores, plot_chance_level=True
        )
        display.ax_.set_title("Precision-Recall curve of subsample")
        display.figure_.show()
    try:
        # auc expects the x-coordinates first; recall is monotonic while
        # precision is not, so the PR-AUC is auc(recall, precision).
        pr_auc = auc(rcs, prs)
    except ValueError:
        print("Warning: curve is non-monotonic, returning None")
        pr_auc = None
    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'pr_auc': pr_auc
    }
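
# Usage sketch (illustrative values; ``scores`` are predicted probabilities
# of the positive class and ``preds`` a hard thresholding of them):
#
#   targs = np.array([0, 1, 1, 0, 1])
#   scores = np.array([0.2, 0.9, 0.6, 0.4, 0.7])
#   preds = (scores >= 0.5).astype(int)
#   metrics = compute_metrics(targs, scores, preds, plot_pr_curve=False)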


class FocalLoss(nn.Module):
    """
    Focal loss (Lin et al., 2017) with per-class alpha weights taken from
    inverse class frequencies, so that rare classes and hard examples are
    up-weighted relative to easy, frequent ones.
    """

    def __init__(self, class_frequencies: torch.Tensor, gamma: float = 2.0):
        super().__init__()
        # Inverse-frequency class weights, normalised to sum to 1.
        alpha = 1.0 / class_frequencies
        self.alpha = (alpha / alpha.sum()).to(
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.gamma = gamma

    def forward(self, inputs, targets):
        # Targets are class indices, so keep them integral: they are used
        # both to index the alpha weights and for index-style cross entropy.
        targets = targets.long()
        if self.alpha.type() != inputs.data.type():
            self.alpha = self.alpha.type_as(inputs.data)
        alpha_targets = self.alpha[targets]
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # model probability of the true class
        return (alpha_targets * (1 - pt) ** self.gamma * ce_loss).mean()
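
# Usage sketch (hypothetical model and batch; ``class_frequencies`` holds the
# per-class sample counts from the training set):
#
#   criterion = FocalLoss(class_frequencies=torch.tensor([900.0, 100.0]))
#   logits = model(batch)              # shape: (batch_size, 2)
#   loss = criterion(logits, labels)   # labels: (batch_size,) of {0, 1}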


class NegClassRandomSampler(Sampler):
    """
    DataLoader sampler that subsamples the negative class each epoch.

    The idea is to keep every positive sample but draw a fresh random subset
    of negative samples each epoch. Sampling is done without replacement,
    and with the default ratio we keep roughly 20% of the negatives, which
    should make the class imbalance more manageable.
    """

    def __init__(self, data_source, neg_class_ratio: float = 0.2, seed: int = 42):
        self._random_gen = np.random.default_rng(seed)
        self.data_source = data_source
        self._neg_class_ratio = neg_class_ratio
        # Get indices of the positive and negative cases
        self._pos_indices = np.argwhere(np.array(data_source['labels']) == 1).flatten()
        self._neg_indices = np.argwhere(np.array(data_source['labels']) == 0).flatten()
        self._neg_num_samples = int(len(self._neg_indices) * neg_class_ratio)
        self._pos_num_samples = len(self._pos_indices)

    @property
    def num_samples(self):
        return self._pos_num_samples + self._neg_num_samples

    def __iter__(self):
        """
        Resample the negative class each time an iterator is requested,
        i.e. once per epoch.
        """
        _neg_samples = self._random_gen.choice(
            self._neg_indices, self._neg_num_samples, replace=False
        )
        _samples = np.concatenate((_neg_samples, self._pos_indices), axis=0)
        self._random_gen.shuffle(_samples)
        if len(_samples) != len(self):
            raise ValueError(
                f"Length of output samples ({len(_samples)}) does not match "
                f"expected ({len(self)})"
            )
        return iter(_samples.tolist())

    def __len__(self):
        return self.num_samples
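
# Usage sketch (hypothetical ``train_dataset``; the sampler assumes any
# dict-like data_source exposing a 'labels' sequence):
#
#   from torch.utils.data import DataLoader
#
#   sampler = NegClassRandomSampler(train_dataset, neg_class_ratio=0.2)
#   loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
#   # A fresh negative subset is drawn each epoch, since the DataLoader
#   # calls iter(sampler) at the start of every epoch.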