"""
Module for defining utilities for training such as the negative class sampler
and focal loss function.
"""

import numpy as np
from sklearn.metrics import (
    precision_recall_fscore_support,
    precision_recall_curve,
    auc,
    PrecisionRecallDisplay
)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Sampler


def compute_metrics(targs, scores, preds, plot_pr_curve: bool = True):
    """
    Compute binary-classification metrics (precision, recall, F1, and PR AUC)
    from targets, positive-class scores, and hard predictions; optionally
    plot the precision-recall curve.
    """
    precision, recall, f1, _ = precision_recall_fscore_support(
        targs, preds, average="binary"
    )
    prs, rcs, _ = precision_recall_curve(targs, scores)

    if plot_pr_curve:
        display = PrecisionRecallDisplay.from_predictions(
            targs, scores, plot_chance_level=True
        )
        display.ax_.set_title("Precision-Recall curve of subsample")
        display.figure_.show()

    # auc(x, y) requires a monotonic x-axis. precision_recall_curve returns
    # (precision, recall, thresholds); recall is monotonically decreasing,
    # so it must be the first argument. Passing precision first raises
    # ValueError for non-monotonic curves.
    pr_auc = auc(rcs, prs)

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'pr_auc': pr_auc
    }
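

# A minimal usage sketch for compute_metrics (illustrative values, not from
# this project; `scores` are positive-class probabilities, `preds` are hard
# 0/1 labels):
#
#     metrics = compute_metrics(
#         targs=[0, 1, 1, 0],
#         scores=[0.1, 0.9, 0.4, 0.2],
#         preds=[0, 1, 0, 0],
#         plot_pr_curve=False,
#     )
#     print(metrics["pr_auc"])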


class FocalLoss(nn.Module):
    """
    Focal loss (Lin et al., 2017) for class-imbalanced classification:
    FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t), where the per-class
    weights alpha are derived from inverse class frequencies.
    """

    def __init__(self, class_frequencies: torch.Tensor, gamma: float = 2.0):
        super().__init__()
        # Weight each class inversely to its frequency, normalised to sum to 1.
        alpha = 1.0 / class_frequencies
        self.alpha = (alpha / alpha.sum()).to(
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.gamma = gamma

    def forward(self, inputs, targets):
        # Keep targets as integer class indices: F.cross_entropy interprets
        # floating-point targets as class probabilities, so casting them to
        # the input dtype (as older focal loss recipes do) would be a bug.
        if self.alpha.device != inputs.device or self.alpha.dtype != inputs.dtype:
            self.alpha = self.alpha.to(device=inputs.device, dtype=inputs.dtype)
        alpha_targets = self.alpha[targets]
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # pt is the predicted probability of the true class; (1 - pt) ** gamma
        # down-weights easy, well-classified examples.
        pt = torch.exp(-ce_loss)
        return (alpha_targets * (1 - pt) ** self.gamma * ce_loss).mean()
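

# A minimal usage sketch for FocalLoss (the frequencies below are illustrative
# counts for a binary task, and `model`/`features`/`labels` are placeholders):
#
#     class_frequencies = torch.tensor([900.0, 100.0])  # negatives, positives
#     criterion = FocalLoss(class_frequencies, gamma=2.0)
#     logits = criterion_input = model(features)   # shape (batch, 2)
#     loss = criterion(logits, labels)             # labels: LongTensor of 0/1
#     loss.backward()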


class NegClassRandomSampler(Sampler):
    """
    Dataloader Sampler that subsamples the negative class after each epoch.
    The idea is that we want to keep the positive samples but select a random
    subset of negative samples each epoch for a fresh set.

    With the current settings, the sampling is done without replacement, and we
    end up with a roughly 20% data imbalance, which should hopefully be more
    manageable.
    """

    def __init__(self, data_source, neg_class_ratio: float = 0.2, seed: int = 42):
        self._random_gen = np.random.default_rng(seed)
        self.data_source = data_source
        self._neg_class_ratio = neg_class_ratio

        # Get indices of the positive and negative cases
        self._pos_indices = np.argwhere(np.array(data_source['labels']) == 1).flatten()
        self._neg_indices = np.argwhere(np.array(data_source['labels']) == 0).flatten()
        self._neg_num_samples = int(len(self._neg_indices) * neg_class_ratio)
        self._pos_num_samples = len(self._pos_indices)

    @property
    def num_samples(self):
        return self._pos_num_samples + self._neg_num_samples

    def __iter__(self):
        """
        Resample the negative class each time an iterator is requested,
        i.e. once per epoch.
        """
        _neg_samples = self._random_gen.choice(
            self._neg_indices, self._neg_num_samples, replace=False
        )
        _samples = np.concatenate((_neg_samples, self._pos_indices), axis=0)
        self._random_gen.shuffle(_samples)
        if len(_samples) != len(self):
            raise ValueError(
                f"Length of output samples ({len(_samples)}) does not match "
                f"expected ({len(self)})"
            )
        return iter(_samples.tolist())

    def __len__(self):
        return self.num_samples
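

# A minimal usage sketch for NegClassRandomSampler (assumes a dict-like
# dataset exposing a "labels" column, e.g. a Hugging Face datasets.Dataset;
# `train_dataset` and `num_epochs` are placeholders):
#
#     from torch.utils.data import DataLoader
#
#     sampler = NegClassRandomSampler(train_dataset, neg_class_ratio=0.2)
#     loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
#     for epoch in range(num_epochs):
#         for batch in loader:  # a fresh negative subset is drawn each epoch
#             ...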