Spaces:
Runtime error
Runtime error
# code from: https://github.com/benjiebob/WLDO/blob/master/wldo_regressor/metrics.py | |
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
IMG_RES = 256 # in WLDO it is 224 | |
class Metrics(): | |
def PCK_thresh( | |
pred_keypoints, gt_keypoints, | |
gtseg, has_seg, | |
thresh, idxs, biggs=False): | |
pred_keypoints, gt_keypoints, gtseg = pred_keypoints[has_seg], gt_keypoints[has_seg], gtseg[has_seg] | |
if idxs is None: | |
idxs = list(range(pred_keypoints.shape[1])) | |
idxs = np.array(idxs).astype(int) | |
pred_keypoints = pred_keypoints[:, idxs] | |
gt_keypoints = gt_keypoints[:, idxs] | |
if biggs: | |
keypoints_gt = ((gt_keypoints + 1.0) * 0.5) * IMG_RES | |
dist = torch.norm(pred_keypoints - keypoints_gt[:, :, [1, 0]], dim = -1) | |
else: | |
keypoints_gt = gt_keypoints # (0 to IMG_SIZE) | |
dist = torch.norm(pred_keypoints - keypoints_gt[:, :, :2], dim = -1) | |
seg_area = torch.sum(gtseg.reshape(gtseg.shape[0], -1), dim = -1).unsqueeze(-1) | |
hits = (dist / torch.sqrt(seg_area)) < thresh | |
total_visible = torch.sum(gt_keypoints[:, :, -1], dim = -1) | |
pck = torch.sum(hits.float() * gt_keypoints[:, :, -1], dim = -1) / total_visible | |
return pck | |
def PCK( | |
pred_keypoints, keypoints, | |
gtseg, has_seg, | |
thresh_range=[0.15], | |
idxs:list=None, | |
biggs=False): | |
"""Calc PCK with same method as in eval. | |
idxs = optional list of subset of keypoints to index from | |
""" | |
cumulative_pck = [] | |
for thresh in thresh_range: | |
pck = Metrics.PCK_thresh( | |
pred_keypoints, keypoints, | |
gtseg, has_seg, thresh, idxs, | |
biggs=biggs) | |
cumulative_pck.append(pck) | |
pck_mean = torch.stack(cumulative_pck, dim = 0).mean(dim=0) | |
return pck_mean | |
def IOU(synth_silhouettes, gt_seg, img_border_mask, mask): | |
for i in range(mask.shape[0]): | |
synth_silhouettes[i] *= mask[i] | |
# Do not penalize parts of the segmentation outside the img range | |
gt_seg = (gt_seg * img_border_mask) + synth_silhouettes * (1.0 - img_border_mask) | |
intersection = torch.sum((synth_silhouettes * gt_seg).reshape(synth_silhouettes.shape[0], -1), dim = -1) | |
union = torch.sum(((synth_silhouettes + gt_seg).reshape(synth_silhouettes.shape[0], -1) > 0.0).float(), dim = -1) | |
acc_IOU_SCORE = intersection / union | |
if torch.isnan(acc_IOU_SCORE).sum() > 0: | |
import pdb; pdb.set_trace() | |
return acc_IOU_SCORE |