|
import numpy as np |
|
from scipy import stats |
|
from sklearn import metrics |
|
import torch |
|
|
|
|
|
def d_prime(auc):
    """Convert an AUC value to the sensitivity index d'.

    Uses the standard relation d' = sqrt(2) * Phi^{-1}(auc), where
    Phi^{-1} is the inverse CDF (ppf) of the standard normal distribution.

    Args:
        auc: area under the ROC curve, in (0, 1).

    Returns:
        The corresponding d' value (0 at auc=0.5, +/-inf at auc=1/0).
    """
    return np.sqrt(2.0) * stats.norm.ppf(auc)
|
|
|
|
|
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()
    # One receive buffer per rank; all_gather fills them in rank order.
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    # Stack every rank's contribution along the batch dimension.
    return torch.cat(gathered, dim=0)
|
|
|
|
|
def calculate_stats(output, target):
    """Calculate per-class statistics including AP, AUC, curve samples and accuracy.

    Args:
        output: 2d array, (samples_num, classes_num) of predicted scores.
        target: 2d array, (samples_num, classes_num) of binary ground-truth labels.

    Returns:
        stats: list with one dict per class, each containing:
            'precisions', 'recalls': subsampled precision-recall curve points
            'AP': average precision for the class
            'fpr', 'fnr': subsampled ROC curve points (fnr = 1 - tpr)
            'auc': area under the full ROC curve for the class
            'acc': top-1 accuracy over all classes (same value in every dict)
    """
    classes_num = target.shape[-1]
    stats = []

    # Top-1 accuracy is a global multi-class metric; compute it once and
    # replicate it in every per-class dict for caller convenience.
    acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1))

    # Keep only every N-th curve point so stored curves stay small.
    # (Renamed from `save_every_steps`: it is a subsampling stride, not a step
    # counter.)
    subsample_stride = 1000

    for k in range(classes_num):
        avg_precision = metrics.average_precision_score(
            target[:, k], output[:, k], average=None
        )

        # Thresholds are unused; discard them explicitly.
        precisions, recalls, _ = metrics.precision_recall_curve(
            target[:, k], output[:, k]
        )

        fpr, tpr, _ = metrics.roc_curve(target[:, k], output[:, k])

        # NOTE: was previously named `dict`, shadowing the builtin.
        class_stats = {
            "precisions": precisions[0::subsample_stride],
            "recalls": recalls[0::subsample_stride],
            "AP": avg_precision,
            "fpr": fpr[0::subsample_stride],
            "fnr": 1.0 - tpr[0::subsample_stride],
            # AUC from the full (un-subsampled) ROC curve; the docstring
            # promised AUC but the original never computed it.
            "auc": metrics.auc(fpr, tpr),
            "acc": acc,
        }
        stats.append(class_stats)

    return stats
|
|