# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Functions for computing metrics."""
import torch
import numpy as np


def topks_correct(preds, labels, ks):
    """
    Given the predictions, labels, and a list of top-k values, compute the
    number of correct predictions for each top-k value.
    Args:
        preds (array): array of predictions. Dimension is N x ClassNum,
            where N is the batch size.
        labels (array): array of labels. Dimension is N.
        ks (list): list of top-k values. For example, ks = [1, 5] corresponds
            to top-1 and top-5.
    Returns:
        topks_correct (list): list of numbers, where the `i`-th entry
            corresponds to the number of top-`ks[i]` correct predictions.
    """
    assert preds.size(0) == labels.size(
        0
    ), "Batch dim of predictions and labels must match"
    # Find the top max_k predictions for each sample
    _top_max_k_vals, top_max_k_inds = torch.topk(
        preds, max(ks), dim=1, largest=True, sorted=True
    )
    # (batch_size, max_k) -> (max_k, batch_size).
    top_max_k_inds = top_max_k_inds.t()
    # (batch_size, ) -> (max_k, batch_size).
    rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
    # (i, j) = 1 if top i-th prediction for the j-th sample is correct.
    top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
    # Compute the number of topk correct predictions for each k.
    topks_correct = [top_max_k_correct[:k, :].float().sum() for k in ks]
    return topks_correct
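
# Usage sketch (illustrative values, not part of the original module): for a
# batch of 2 samples over 3 classes,
#
#   preds = torch.tensor([[0.1, 0.7, 0.2],
#                         [0.3, 0.6, 0.1]])
#   labels = torch.tensor([1, 0])
#   topks_correct(preds, labels, [1, 2])
#
# returns [tensor(1.), tensor(2.)]: only the first sample is top-1 correct,
# while both samples have their label within the top-2 predictions.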


def topk_errors(preds, labels, ks):
    """
    Computes the top-k error for each k.
    Args:
        preds (array): array of predictions. Dimension is N x ClassNum.
        labels (array): array of labels. Dimension is N.
        ks (list): list of ks to calculate the top-k errors for.
    """
    num_topks_correct = topks_correct(preds, labels, ks)
    return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct]
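
# Usage sketch (same illustrative tensors as above):
# topk_errors(preds, labels, [1, 2]) would return [tensor(50.), tensor(0.)],
# i.e. a 50% top-1 error and a 0% top-2 error for that 2-sample batch.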


def topk_accuracies(preds, labels, ks):
    """
    Computes the top-k accuracy for each k.
    Args:
        preds (array): array of predictions. Dimension is N x ClassNum.
        labels (array): array of labels. Dimension is N.
        ks (list): list of ks to calculate the top-k accuracies for.
    """
    num_topks_correct = topks_correct(preds, labels, ks)
    return [(x / preds.size(0)) * 100.0 for x in num_topks_correct]
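
# Usage sketch (same illustrative tensors as above):
# topk_accuracies(preds, labels, [1, 2]) would return
# [tensor(50.), tensor(100.)], the complement of the top-k errors above.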


def multitask_topks_correct(preds, labels, ks=(1,)):
    """
    Args:
        preds: tuple(torch.FloatTensor), each tensor should be of shape
            [batch_size, class_count]; class_count can vary on a per-task
            basis, i.e. preds[i].shape[1] can differ from preds[j].shape[1].
        labels: tuple(torch.LongTensor), each tensor should be of shape
            [batch_size].
        ks: tuple(int), compute accuracy at top-k for the values of k specified
            in this parameter.
    Returns:
        list of the same length as ks, where the `i`-th entry is the number of
        samples whose labels fall within the top-`ks[i]` predictions for every
        task simultaneously.
    """
    max_k = int(np.max(ks))
    task_count = len(preds)
    batch_size = labels[0].size(0)
    all_correct = torch.zeros(max_k, batch_size).type(torch.ByteTensor)
    if torch.cuda.is_available():
        all_correct = all_correct.cuda()
    for output, label in zip(preds, labels):
        _, max_k_idx = output.topk(max_k, dim=1, largest=True, sorted=True)
        # Flip batch_size, class_count as .view doesn't work on non-contiguous
        max_k_idx = max_k_idx.t()
        correct_for_task = max_k_idx.eq(label.view(1, -1).expand_as(max_k_idx))
        all_correct.add_(correct_for_task)
    multitask_topks_correct = [
        torch.ge(all_correct[:k].float().sum(0), task_count).float().sum(0) for k in ks
    ]
    return multitask_topks_correct
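
# Usage sketch (illustrative two-task setup, e.g. verb/noun prediction; the
# tensors below are assumptions, not from the original module). Note that when
# CUDA is available, `all_correct` is created on the GPU, so the inputs would
# also have to live on the GPU.
#
#   verb_preds = torch.tensor([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]])
#   noun_preds = torch.tensor([[0.2, 0.1, 0.5, 0.2], [0.1, 0.2, 0.3, 0.4]])
#   verb_labels = torch.tensor([1, 0])
#   noun_labels = torch.tensor([2, 0])
#   multitask_topks_correct((verb_preds, noun_preds),
#                           (verb_labels, noun_labels), ks=(1,))
#
# returns [tensor(1.)]: only the first sample has both its verb and noun labels
# in the respective top-1 predictions.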