File size: 1,033 Bytes
2366e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def compute_f1_score(preds, gts, ignores=None):
    """Compute the per-class F1-score of a batch of predictions.

    A ``C x C`` confusion matrix is built from the ground-truth labels and
    the arg-max of the predictions; precision, recall and F1 are then
    derived for every class from that matrix.

    Args:
        preds (Tensor): The predicted probability NxC map
            with N and C being the sample number and class
            number respectively.
        gts (Tensor): The ground truth vector of size N, holding integer
            class indices.
        ignores (list[int], optional): The indices of classes that are
            left out of the returned scores. Defaults to None, meaning no
            class is ignored.
            Note: all samples still take part in the computation; the
            ignored classes are only dropped from the result.

    Returns:
        numpy.ndarray: The F1-scores of the non-ignored classes, ordered
        by ascending class index.
    """
    # ``None`` stands in for "ignore nothing" so that no mutable object is
    # shared between calls through the default argument.
    ignored = set() if ignores is None else set(ignores)

    C = preds.size(1)
    # Encode every (ground truth, prediction) pair as the single index
    # ``gt * C + pred``; counting those indices and reshaping yields the
    # confusion matrix with rows = ground truth and columns = prediction.
    hist = torch.bincount(
        gts * C + preds.argmax(1), minlength=C**2).view(C, C).float()
    diag = torch.diag(hist)
    # Clamp the denominators so that classes with no samples (or no
    # predictions) produce a score of 0 instead of a division by zero.
    recalls = diag / hist.sum(1).clamp(min=1)
    precisions = diag / hist.sum(0).clamp(min=1)
    f1 = 2 * recalls * precisions / (recalls + precisions).clamp(min=1e-8)

    # Build the index tensor on the same device as the scores.
    classes = torch.tensor(
        sorted(set(range(C)) - ignored), dtype=torch.long, device=f1.device)
    return f1[classes].cpu().numpy()