ynhe
init
16dc4f2
raw
history blame
499 Bytes
import torch
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: Score tensor of shape ``(num_items, num_classes)``; larger
            values mean higher predicted likelihood.
        target: Integer class indices, one per item, of shape
            ``(num_items,)`` or ``(num_items, 1)``.
        topk: Iterable of k values to evaluate. ``max(topk)`` must not
            exceed ``output.size(1)`` (``Tensor.topk`` rejects it otherwise).

    Returns:
        A list with one 0-dim float tensor per entry of ``topk``, in the
        same order. Each value is the percentage (0-100) of items whose
        target class is among the k highest-scoring predictions.

    Raises:
        ZeroDivisionError: if ``output`` has zero rows.
    """
    maxk = max(topk)
    num_items = output.size(0)

    # Indices of the maxk highest scores per item, sorted descending,
    # transposed to (maxk, num_items) so row j holds every item's j-th guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()

    # Normalise target to (1, num_items) so it broadcasts against every row
    # of pred, whether the caller passed shape (N,) or (N, 1).
    correct = pred.eq(target.reshape(1, -1))

    res = []
    for k in topk:
        # `correct` inherits the non-contiguous strides of pred.t(), so use
        # reshape (which copies if needed) instead of view, which fails for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k * (100.0 / num_items))
    return res