|
|
|
|
|
"""Functions for computing metrics.""" |
|
|
|
import torch |
|
import numpy as np |
|
|
|
def topks_correct(preds, labels, ks):
    """
    Count how many samples have their label among the top-k predictions.

    Args:
        preds (tensor): prediction scores of shape (N, ClassNum).
        labels (tensor): ground-truth class indices of shape (N,).
        ks (list): top-k values to evaluate. For example, ks = [1, 5]
            corresponds to top-1 and top-5.

    Returns:
        topks_correct (list): one 0-dim float tensor per entry of `ks`; the
            `i`-th entry is the number of samples whose label appears among
            the top-`ks[i]` predictions.
    """
    assert preds.size(0) == labels.size(0), "Batch dim of predictions and labels must match"

    largest_k = max(ks)
    # Keep only the indices of the highest-scoring classes, then transpose to
    # (largest_k, N) so that row r holds every sample's rank-r prediction.
    ranked = torch.topk(preds, largest_k, dim=1, largest=True, sorted=True)[1].t()

    # Broadcast the labels across all rank rows and mark the matches.
    hits = ranked.eq(labels.view(1, -1).expand_as(ranked))

    # Rows [0, k) cover the top-k ranks; each sample matches at most once.
    return [hits[:k, :].float().sum() for k in ks]
|
|
|
|
|
def topk_errors(preds, labels, ks):
    """
    Compute the top-k error rate, in percent, for each k.

    Args:
        preds (array): prediction scores. Dimension is N x ClassNum.
        labels (array): ground-truth class indices. Dimension is N.
        ks (list): list of ks to calculate the top-k errors for.

    Returns:
        list: one entry per k, equal to 100 * (1 - correct_k / N).
    """
    correct_counts = topks_correct(preds, labels, ks)
    num_samples = preds.size(0)
    return [100.0 * (1.0 - count / num_samples) for count in correct_counts]
|
|
|
|
|
def topk_accuracies(preds, labels, ks):
    """
    Compute the top-k accuracy, in percent, for each k.

    Args:
        preds (array): prediction scores. Dimension is N x ClassNum.
        labels (array): ground-truth class indices. Dimension is N.
        ks (list): list of ks to calculate the top accuracies.

    Returns:
        list: one entry per k, equal to 100 * correct_k / N.
    """
    correct_counts = topks_correct(preds, labels, ks)
    num_samples = preds.size(0)
    return [100.0 * (count / num_samples) for count in correct_counts]
|
|
|
def multitask_topks_correct(preds, labels, ks=(1,)):
    """
    Count samples that are top-k correct on every task simultaneously.

    Args:
        preds: tuple(torch.FloatTensor), one tensor per task, each of shape
            [batch_size, class_count]. class_count can vary on a per task
            basis, i.e. preds[i].shape[1] can be different to
            preds[j].shape[1].
        labels: tuple(torch.LongTensor), one tensor per task, each of shape
            [batch_size], aligned element-wise with `preds`.
        ks: tuple(int), compute the counts at top-k for the values of k
            specified in this parameter.

    Returns:
        list of 0-dim float tensors, same length as `ks`; the `i`-th entry is
        the number of samples whose label lies within the top-`ks[i]`
        predictions for *every* task.
    """
    # zip() would silently drop unmatched tasks, making every sample look
    # incorrect, so fail loudly instead.
    assert len(preds) == len(labels), "Number of prediction and label tensors must match"

    max_k = int(np.max(ks))
    task_count = len(preds)
    batch_size = labels[0].size(0)

    # Allocate on the inputs' own device rather than on CUDA whenever it is
    # available; otherwise CPU inputs crash on GPU machines. Use int64 because
    # a uint8 counter wraps around once there are 256 or more tasks.
    all_correct = torch.zeros(
        max_k, batch_size, dtype=torch.long, device=labels[0].device
    )
    for output, label in zip(preds, labels):
        _, max_k_idx = output.topk(max_k, dim=1, largest=True, sorted=True)
        # (max_k, batch_size): row r holds each sample's rank-r class index.
        max_k_idx = max_k_idx.t()
        correct_for_task = max_k_idx.eq(label.view(1, -1).expand_as(max_k_idx))
        all_correct.add_(correct_for_task)

    # A label occurs at most once in a task's top-k, so summing the first k
    # rows gives, per sample, the number of tasks that are top-k correct.
    # A sample counts only if all tasks are correct.
    multitask_topks_correct = [
        torch.ge(all_correct[:k].sum(0), task_count).float().sum(0) for k in ks
    ]

    return multitask_topks_correct
|
|