import numpy as np
from scipy import stats
from sklearn import metrics
import torch


def d_prime(auc):
    """Convert an AUC value to the equivalent d-prime (sensitivity index)."""
    standard_normal = stats.norm()
    d_prime = standard_normal.ppf(auc) * np.sqrt(2.0)
    return d_prime
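
# Illustrative values (not part of the original module): chance-level AUC maps
# to d' of 0, and higher AUC gives a larger d'. For example, d_prime(0.5) == 0.0
# and d_prime(0.85) is about 1.47, since norm.ppf(0.85) * sqrt(2) ~= 1.466.
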
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
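
# Hedged usage sketch (assumes torch.distributed has already been initialized,
# e.g. via torch.distributed.init_process_group; `model` and `batch` below are
# placeholders, not names from this file):
#
#   local_preds = model(batch)                  # (local_batch, classes_num)
#   all_preds = concat_all_gather(local_preds)  # (world_size * local_batch, classes_num)
#
# Because all_gather carries no gradient, use this only for evaluation/metrics.
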
def calculate_stats(output, target):
    """Calculate statistics including mAP, AUC, etc.

    Args:
        output: 2d array, (samples_num, classes_num)
        target: 2d array, (samples_num, classes_num)

    Returns:
        stats: list of statistics for each class.
    """
    classes_num = target.shape[-1]
    stats = []

    # Accuracy, only used for single-label classification such as ESC-50,
    # not for multi-label datasets such as AudioSet
    acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1))

    # Class-wise statistics
    for k in range(classes_num):
        # Average precision
        avg_precision = metrics.average_precision_score(
            target[:, k], output[:, k], average=None
        )

        # AUC
        # auc = metrics.roc_auc_score(target[:, k], output[:, k], average=None)

        # Precisions, recalls
        (precisions, recalls, thresholds) = metrics.precision_recall_curve(
            target[:, k], output[:, k]
        )

        # FPR, TPR
        (fpr, tpr, thresholds) = metrics.roc_curve(target[:, k], output[:, k])

        save_every_steps = 1000  # Subsample the curves to reduce size
        stat_dict = {
            "precisions": precisions[0::save_every_steps],
            "recalls": recalls[0::save_every_steps],
            "AP": avg_precision,
            "fpr": fpr[0::save_every_steps],
            "fnr": 1.0 - tpr[0::save_every_steps],
            # 'auc': auc,
            # Note: acc is not class-wise; it is repeated here only to keep the
            # dictionary layout consistent with the other metrics.
            "acc": acc,
        }
        stats.append(stat_dict)

    return stats
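

if __name__ == "__main__":
    # Minimal single-process sketch (not in the original file) of how
    # calculate_stats and d_prime are typically combined: synthetic multi-label
    # targets and random scores stand in for real model outputs.
    rng = np.random.default_rng(0)
    samples_num, classes_num = 200, 10
    target = (rng.random((samples_num, classes_num)) > 0.8).astype(float)
    output = rng.random((samples_num, classes_num))

    stats = calculate_stats(output, target)
    mAP = np.mean([s["AP"] for s in stats])
    print(f"mAP over {classes_num} classes: {mAP:.4f}")

    # d_prime expects an AUC value; with random scores it should sit near 0.
    auc = metrics.roc_auc_score(target, output, average="macro")
    print(f"macro AUC: {auc:.4f}, d': {d_prime(auc):.4f}")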