import torch
from torchmetrics.classification import Accuracy, Recall, Precision, MatthewsCorrCoef, AUROC, F1Score
from torchmetrics.classification import BinaryAccuracy, BinaryRecall, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryMatthewsCorrCoef
from torchmetrics.classification import MultilabelAveragePrecision
from torchmetrics.regression import SpearmanCorrCoef, MeanSquaredError


def count_f1_max(pred, target):
    """

    F1 score with the optimal threshold, Copied from TorchDrug.



    This function first enumerates all possible thresholds for deciding positive and negative

    samples, and then pick the threshold with the maximal F1 score.



    Parameters:

        pred (Tensor): predictions of shape :math:`(B, N)`

        target (Tensor): binary targets of shape :math:`(B, N)`

    """

    order = pred.argsort(descending=True, dim=1)
    target = target.gather(1, order)
    precision = target.cumsum(1) / torch.ones_like(target).cumsum(1)
    recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10)
    is_start = torch.zeros_like(target).bool()
    is_start[:, 0] = 1
    is_start = torch.scatter(is_start, 1, order, is_start)

    all_order = pred.flatten().argsort(descending=True)
    order = (
        order
        + torch.arange(order.shape[0], device=order.device).unsqueeze(1)
        * order.shape[1]
    )
    order = order.flatten()
    inv_order = torch.zeros_like(order)
    inv_order[order] = torch.arange(order.shape[0], device=order.device)
    is_start = is_start.flatten()[all_order]
    all_order = inv_order[all_order]
    precision = precision.flatten()
    recall = recall.flatten()
    all_precision = precision[all_order] - torch.where(
        is_start, torch.zeros_like(precision), precision[all_order - 1]
    )
    all_precision = all_precision.cumsum(0) / is_start.cumsum(0)
    all_recall = recall[all_order] - torch.where(
        is_start, torch.zeros_like(recall), recall[all_order - 1]
    )
    all_recall = all_recall.cumsum(0) / pred.shape[0]
    all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + 1e-10)
    return all_f1.max()
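# Worked example (illustrative values, not part of the original code): with
#   pred   = torch.tensor([[0.9, 0.2, 0.7], [0.1, 0.8, 0.3]])
#   target = torch.tensor([[1., 0., 1.], [0., 1., 0.]])
# every positive label is scored above every negative one, so any threshold
# between 0.3 and 0.7 gives precision = recall = 1 for both samples, and
# count_f1_max(pred, target) returns ~1.0 (up to the 1e-10 smoothing terms).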


class MultilabelF1Max(MultilabelAveragePrecision):
    """Threshold-free F1-max metric for multi-label classification.

    Reuses the prediction/target accumulation of MultilabelAveragePrecision
    and only overrides compute() to report the maximal F1 score over all
    thresholds via count_f1_max.
    """

    def compute(self):
        return count_f1_max(torch.cat(self.preds), torch.cat(self.target))

def setup_metrics(args):
    """Set up metrics based on the problem type and the requested metric names."""
    metrics_dict = {}
    device = "cuda" if torch.cuda.is_available() else "cpu"

    for metric_name in args.metrics:
        metric_config = None
        if args.problem_type == 'regression':
            metric_config = _setup_regression_metrics(metric_name, device)
        elif args.problem_type == 'single_label_classification':
            if args.num_labels == 2:
                metric_config = _setup_binary_metrics(metric_name, device)
            else:
                metric_config = _setup_multiclass_metrics(metric_name, args.num_labels, device)
        elif args.problem_type == 'multi_label_classification':
            metric_config = _setup_multilabel_metrics(metric_name, args.num_labels, device)

        if metric_config:
            metrics_dict[metric_name] = metric_config['metric']

    # Track the loss alongside the other metrics when it is the monitored quantity.
    if args.monitor == 'loss':
        metrics_dict['loss'] = 'loss'

    return metrics_dict

def _setup_regression_metrics(metric_name, device):
    metrics_config = {
        'spearman_corr': {
            'metric': SpearmanCorrCoef().to(device),
        },
        'mse': {
            'metric': MeanSquaredError().to(device),
        }
    }
    return metrics_config.get(metric_name)

def _setup_multiclass_metrics(metric_name, num_classes, device):
    metrics_config = {
        'accuracy': {
            'metric': Accuracy(task='multiclass', num_classes=num_classes).to(device),
        },
        'recall': {
            'metric': Recall(task='multiclass', num_classes=num_classes).to(device),
        },
        'precision': {
            'metric': Precision(task='multiclass', num_classes=num_classes).to(device),
        },
        'f1': {
            'metric': F1Score(task='multiclass', num_classes=num_classes).to(device),
        },
        'mcc': {
            'metric': MatthewsCorrCoef(task='multiclass', num_classes=num_classes).to(device),
        },
        'auroc': {
            'metric': AUROC(task='multiclass', num_classes=num_classes).to(device),
        }
    }
    return metrics_config.get(metric_name)

def _setup_binary_metrics(metric_name, device):
    metrics_config = {
        'accuracy': {
            'metric': BinaryAccuracy().to(device),
        },
        'recall': {
            'metric': BinaryRecall().to(device),
        },
        'precision': {
            'metric': BinaryPrecision().to(device),
        },
        'f1': {
            'metric': BinaryF1Score().to(device),
        },
        'mcc': {
            'metric': BinaryMatthewsCorrCoef().to(device),
        },
        'auroc': {
            'metric': BinaryAUROC().to(device),
        }
    }
    return metrics_config.get(metric_name)

def _setup_multilabel_metrics(metric_name, num_labels, device):
    metrics_config = {
        'f1_max': {
            'metric': MultilabelF1Max(num_labels=num_labels).to(device),
        }
    }
    return metrics_config.get(metric_name)
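

if __name__ == "__main__":
    # Minimal smoke-test sketch. It assumes only torch/torchmetrics are
    # installed; the argparse.Namespace below is a stand-in for the real
    # training `args` object, and its field values are illustrative.
    import argparse

    torch.manual_seed(0)

    # count_f1_max expects scores and binary targets of shape (B, N).
    pred = torch.rand(4, 10)
    target = (torch.rand(4, 10) > 0.5).float()
    print("F1-max:", count_f1_max(pred, target).item())

    # MultilabelF1Max follows the usual torchmetrics update()/compute() flow.
    f1_max_metric = MultilabelF1Max(num_labels=10)
    f1_max_metric.update(pred, target.long())
    print("MultilabelF1Max:", f1_max_metric.compute().item())

    # setup_metrics reads problem_type, num_labels, metrics, and monitor.
    args = argparse.Namespace(
        problem_type="single_label_classification",
        num_labels=5,
        metrics=["accuracy", "f1", "mcc"],
        monitor="loss",
    )
    print("Configured metrics:", list(setup_metrics(args).keys()))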