|
import torch |
|
|
|
|
|
|
|
|
|
|
|
def confusion_matrix( |
|
input: torch.Tensor, target: torch.Tensor, num_classes: int, normalized: bool = False |
|
) -> torch.Tensor: |
|
r"""Compute confusion matrix to evaluate the accuracy of a classification. |
|
|
|
Args: |
|
input: tensor with estimated targets returned by a |
|
classifier. The shape can be :math:`(B, *)` and must contain integer |
|
values between 0 and K-1. |
|
target: tensor with ground truth (correct) target |
|
values. The shape can be :math:`(B, *)` and must contain integer |
|
values between 0 and K-1, where targets are assumed to be provided as |
|
one-hot vectors. |
|
num_classes: total possible number of classes in target. |
|
normalized: whether to return the confusion matrix normalized. |
|
|
|
Returns: |
|
a tensor containing the confusion matrix with shape |
|
:math:`(B, K, K)` where K is the number of classes. |
|
|
|
Example: |
|
>>> logits = torch.tensor([[0, 1, 0]]) |
|
>>> target = torch.tensor([[0, 1, 0]]) |
|
>>> confusion_matrix(logits, target, num_classes=3) |
|
tensor([[[2., 0., 0.], |
|
[0., 1., 0.], |
|
[0., 0., 0.]]]) |
|
""" |
|
if not torch.is_tensor(input) and input.dtype is not torch.int64: |
|
raise TypeError("Input input type is not a torch.Tensor with " "torch.int64 dtype. Got {}".format(type(input))) |
|
|
|
if not torch.is_tensor(target) and target.dtype is not torch.int64: |
|
raise TypeError( |
|
"Input target type is not a torch.Tensor with " "torch.int64 dtype. Got {}".format(type(target)) |
|
) |
|
if not input.shape == target.shape: |
|
raise ValueError( |
|
"Inputs input and target must have the same shape. " "Got: {} and {}".format(input.shape, target.shape) |
|
) |
|
if not input.device == target.device: |
|
raise ValueError("Inputs must be in the same device. " "Got: {} - {}".format(input.device, target.device)) |
|
|
|
if not isinstance(num_classes, int) or num_classes < 2: |
|
raise ValueError("The number of classes must be an integer bigger " "than two. Got: {}".format(num_classes)) |
|
|
|
batch_size: int = input.shape[0] |
|
|
|
|
|
|
|
pre_bincount: torch.Tensor = input + target * num_classes |
|
pre_bincount_vec: torch.Tensor = pre_bincount.view(batch_size, -1) |
|
|
|
confusion_list = [] |
|
for iter_id in range(batch_size): |
|
pb: torch.Tensor = pre_bincount_vec[iter_id] |
|
bin_count: torch.Tensor = torch.bincount(pb, minlength=num_classes ** 2) |
|
confusion_list.append(bin_count) |
|
|
|
confusion_vec: torch.Tensor = torch.stack(confusion_list) |
|
confusion_mat: torch.Tensor = confusion_vec.view(batch_size, num_classes, num_classes).to(torch.float32) |
|
|
|
if normalized: |
|
norm_val: torch.Tensor = torch.sum(confusion_mat, dim=1, keepdim=True) |
|
confusion_mat = confusion_mat / (norm_val + 1e-6) |
|
|
|
return confusion_mat |
|
|