# code from: https://github.com/benjiebob/WLDO/blob/master/wldo_regressor/metrics.py


import torch
import torch.nn.functional as F
import numpy as np

IMG_RES = 256   # image resolution in pixels; the original WLDO code uses 224

class Metrics:

    @staticmethod
    def PCK_thresh(
        pred_keypoints, gt_keypoints,
        gtseg, has_seg,
        thresh, idxs, biggs=False):
        """PCK at a single threshold, normalized by sqrt(silhouette area).

        pred_keypoints: (B, K, 2) predicted keypoints in pixel coordinates.
        gt_keypoints:   (B, K, 3) ground-truth keypoints; last channel is the
                        visibility flag.
        gtseg:          ground-truth silhouettes, one per sample.
        has_seg:        (B,) boolean mask selecting samples with a valid silhouette.
        thresh:         hit threshold, as a fraction of sqrt(silhouette area).
        idxs:           optional list of keypoint indices to evaluate (None = all).
        biggs:          if True, gt keypoints are in normalized [-1, 1] coordinates
                        and are rescaled to pixels first.
        """
        # Keep only the samples that have a ground-truth segmentation.
        pred_keypoints, gt_keypoints, gtseg = \
            pred_keypoints[has_seg], gt_keypoints[has_seg], gtseg[has_seg]

        if idxs is None:
            idxs = list(range(pred_keypoints.shape[1]))
        idxs = np.array(idxs).astype(int)

        pred_keypoints = pred_keypoints[:, idxs]
        gt_keypoints = gt_keypoints[:, idxs]

        if biggs:
            # Rescale gt from [-1, 1] to pixel coordinates and swap the first two
            # coordinate channels to match the predictions.
            keypoints_gt = ((gt_keypoints + 1.0) * 0.5) * IMG_RES
            dist = torch.norm(pred_keypoints - keypoints_gt[:, :, [1, 0]], dim=-1)
        else:
            # Ground truth is already in pixel coordinates (0 to IMG_RES).
            keypoints_gt = gt_keypoints
            dist = torch.norm(pred_keypoints - keypoints_gt[:, :, :2], dim=-1)

        # Normalize keypoint distances by the square root of the silhouette area.
        seg_area = torch.sum(gtseg.reshape(gtseg.shape[0], -1), dim=-1).unsqueeze(-1)
        hits = (dist / torch.sqrt(seg_area)) < thresh

        # Average hits over the visible keypoints only.
        total_visible = torch.sum(gt_keypoints[:, :, -1], dim=-1)
        pck = torch.sum(hits.float() * gt_keypoints[:, :, -1], dim=-1) / total_visible

        return pck

    @staticmethod
    def PCK(
        pred_keypoints, keypoints,
        gtseg, has_seg,
        thresh_range=(0.15,),
        idxs: list = None,
        biggs=False):
        """Calc PCK with the same method as in the evaluation, averaged over
        all thresholds in thresh_range.

        idxs: optional list of keypoint indices to evaluate (None = all).
        """
        cumulative_pck = []
        for thresh in thresh_range:
            pck = Metrics.PCK_thresh(
                pred_keypoints, keypoints,
                gtseg, has_seg, thresh, idxs,
                biggs=biggs)
            cumulative_pck.append(pck)
        # Mean over thresholds, keeping the per-sample dimension.
        pck_mean = torch.stack(cumulative_pck, dim=0).mean(dim=0)
        return pck_mean

    @staticmethod
    def IOU(synth_silhouettes, gt_seg, img_border_mask, mask):
        """Silhouette intersection-over-union, per sample.

        synth_silhouettes / gt_seg / img_border_mask / mask: (B, H, W) tensors;
        mask marks valid predictions, img_border_mask the valid image region.
        """
        # Zero out predictions for invalid samples.
        for i in range(mask.shape[0]):
            synth_silhouettes[i] *= mask[i]

        # Do not penalize parts of the segmentation outside the image range:
        # outside the border mask the prediction is copied into the ground truth.
        gt_seg = (gt_seg * img_border_mask) + synth_silhouettes * (1.0 - img_border_mask)

        intersection = torch.sum((synth_silhouettes * gt_seg).reshape(synth_silhouettes.shape[0], -1), dim=-1)
        union = torch.sum(((synth_silhouettes + gt_seg).reshape(synth_silhouettes.shape[0], -1) > 0.0).float(), dim=-1)
        acc_IOU_SCORE = intersection / union

        # Drop into the debugger if any IoU came out NaN (e.g. an empty union).
        if torch.isnan(acc_IOU_SCORE).sum() > 0:
            import pdb; pdb.set_trace()
        return acc_IOU_SCORE
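

# -----------------------------------------------------------------------------
# Usage sketch (not part of the original WLDO file): a minimal smoke test with
# random tensors. The batch size B=2, keypoint count K=20, and all shapes are
# illustrative assumptions chosen to match how the tensors are used above.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    B, K = 2, 20

    # Keypoints in pixel coordinates; the last gt channel is the visibility flag.
    pred_kp = torch.rand(B, K, 2) * IMG_RES
    gt_kp = torch.cat([torch.rand(B, K, 2) * IMG_RES, torch.ones(B, K, 1)], dim=-1)

    gt_seg = (torch.rand(B, IMG_RES, IMG_RES) > 0.5).float()
    has_seg = torch.ones(B, dtype=torch.bool)

    pck = Metrics.PCK(pred_kp, gt_kp, gt_seg, has_seg, thresh_range=(0.15,))
    print("PCK per sample:", pck)

    # Silhouette IoU with every sample and the full image region marked valid.
    synth = (torch.rand(B, IMG_RES, IMG_RES) > 0.5).float()
    border = torch.ones(B, IMG_RES, IMG_RES)
    mask = torch.ones(B, IMG_RES, IMG_RES)
    iou = Metrics.IOU(synth, gt_seg, border, mask)
    print("IoU per sample:", iou)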