import torch

# Inspired by:
# https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py#L68-L73


def confusion_matrix(
    input: torch.Tensor, target: torch.Tensor, num_classes: int, normalized: bool = False
) -> torch.Tensor:
    r"""Compute confusion matrix to evaluate the accuracy of a classification.

    Args:
        input: tensor with estimated targets returned by a
          classifier. The shape can be :math:`(B, *)` and must contain integer
          values between 0 and K-1.
        target: tensor with ground truth (correct) target
          values. The shape can be :math:`(B, *)` and must contain integer
          values between 0 and K-1; targets are given as class indices,
          not as one-hot vectors.
        num_classes: total possible number of classes in target.
        normalized: whether to return the confusion matrix normalized.

    Returns:
        a tensor containing the confusion matrix with shape
        :math:`(B, K, K)`, where K is the number of classes. Rows index the
        ground truth class and columns the predicted class.

    Example:
        >>> preds = torch.tensor([[0, 1, 0]])
        >>> target = torch.tensor([[0, 1, 0]])
        >>> confusion_matrix(preds, target, num_classes=3)
        tensor([[[2., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 0.]]])
    """
    if not torch.is_tensor(input) or input.dtype is not torch.int64:
        raise TypeError("Input 'input' must be a torch.Tensor with torch.int64 dtype. Got {}".format(type(input)))

    if not torch.is_tensor(target) or target.dtype is not torch.int64:
        raise TypeError("Input 'target' must be a torch.Tensor with torch.int64 dtype. Got {}".format(type(target)))
    if not input.shape == target.shape:
        raise ValueError("Inputs input and target must have the same shape. Got: {} and {}".format(input.shape, target.shape))
    if not input.device == target.device:
        raise ValueError("Inputs must be on the same device. Got: {} - {}".format(input.device, target.device))

    if not isinstance(num_classes, int) or num_classes < 2:
        raise ValueError("The number of classes must be an integer greater than or equal to two. Got: {}".format(num_classes))

    batch_size: int = input.shape[0]

    # trick to bincount two arrays jointly: encode every (target, prediction)
    # pair as a single index target * num_classes + prediction
    # NOTE: torch.bincount does not implement a batched version
    pre_bincount: torch.Tensor = input + target * num_classes
    pre_bincount_vec: torch.Tensor = pre_bincount.view(batch_size, -1)
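    # e.g. with num_classes=3, a sample with target=1 and prediction=2 is encoded
    # as index 1 * 3 + 2 = 5, which lands in row 1, column 2 of the final KxK view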

    confusion_list = []
    for iter_id in range(batch_size):
        pb: torch.Tensor = pre_bincount_vec[iter_id]
        bin_count: torch.Tensor = torch.bincount(pb, minlength=num_classes ** 2)
        confusion_list.append(bin_count)

    confusion_vec: torch.Tensor = torch.stack(confusion_list)
    confusion_mat: torch.Tensor = confusion_vec.view(batch_size, num_classes, num_classes).to(torch.float32)  # BxKxK

    if normalized:
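        # torch.sum over dim=1 (the target/row axis) gives per-column totals, so
        # each predicted-class column is scaled to sum to roughly one; the small
        # epsilon below avoids division by zero for classes that were never predicted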
        norm_val: torch.Tensor = torch.sum(confusion_mat, dim=1, keepdim=True)
        confusion_mat = confusion_mat / (norm_val + 1e-6)

    return confusion_mat
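

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The tensors `preds`
# and `labels` below are illustrative placeholders: a batch of two samples,
# four elements each, with three classes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    preds = torch.tensor([[0, 1, 2, 2], [1, 1, 0, 2]])   # predicted class indices, shape (B, N)
    labels = torch.tensor([[0, 1, 2, 1], [1, 0, 0, 2]])  # ground-truth class indices, shape (B, N)

    cm = confusion_matrix(preds, labels, num_classes=3)
    print(cm)  # raw counts, shape (B, K, K)

    cm_norm = confusion_matrix(preds, labels, num_classes=3, normalized=True)
    print(cm_norm)  # each non-empty column sums to approximately one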