# barc_gradio/src/metrics/metrics.py
# Nadine Rueegg
# initial commit for barc
# 7629b39
# code from: https://github.com/benjiebob/WLDO/blob/master/wldo_regressor/metrics.py
import torch
import torch.nn.functional as F
import numpy as np
# Image resolution used by Metrics.PCK_thresh when biggs=True: ground-truth
# keypoints are rescaled as (kp + 1) / 2 * IMG_RES before being compared
# with the predictions.
IMG_RES = 256 # in WLDO it is 224
class Metrics:
    """Batched evaluation metrics: keypoint PCK and silhouette IoU.

    All methods are static; the class only serves as a namespace. Inputs are
    indexed and reduced with torch operations, so they are expected to be
    torch tensors with the batch on the first axis.
    """

    @staticmethod
    def PCK_thresh(
            pred_keypoints, gt_keypoints,
            gtseg, has_seg,
            thresh, idxs, biggs=False):
        """Compute the per-sample PCK at a single threshold.

        A keypoint counts as a hit when its distance to the ground truth,
        divided by the square root of the summed ground-truth segmentation,
        is strictly below ``thresh``. Hits are weighted by the last channel
        of ``gt_keypoints`` and divided by the sum of those weights.

        Args:
            pred_keypoints: predicted 2D keypoints, indexed as
                (batch, keypoint, coordinate).
            gt_keypoints: ground-truth keypoints; the last channel is used
                as the per-keypoint weight (presumably a 0/1 visibility
                flag -- confirm against the dataset).
            gtseg: ground-truth segmentation; everything but the batch axis
                is flattened and summed to obtain the area.
            has_seg: index (e.g. boolean mask) selecting the samples to
                evaluate; applied to all three tensors above.
            thresh: normalised distance threshold.
            idxs: optional subset of keypoint indices; ``None`` uses all.
            biggs: if True, ``gt_keypoints`` is rescaled with
                ``(kp + 1) / 2 * IMG_RES`` and its first two channels are
                swapped before the comparison.

        Returns:
            One score per selected sample. The score is NaN for a sample
            whose weights sum to zero (division by zero is not guarded).
        """
        # Keep only the samples selected by has_seg.
        pred_keypoints = pred_keypoints[has_seg]
        gt_keypoints = gt_keypoints[has_seg]
        gtseg = gtseg[has_seg]

        if idxs is None:
            idxs = list(range(pred_keypoints.shape[1]))
        idxs = np.array(idxs).astype(int)

        pred_keypoints = pred_keypoints[:, idxs]
        gt_keypoints = gt_keypoints[:, idxs]
        # The weights always come from the unscaled ground truth.
        visibility = gt_keypoints[:, :, -1]

        if biggs:
            # Map from [-1, 1] to [0, IMG_RES]; channels 0/1 are swapped.
            keypoints_gt = ((gt_keypoints + 1.0) * 0.5) * IMG_RES
            dist = torch.norm(
                pred_keypoints - keypoints_gt[:, :, [1, 0]], dim=-1)
        else:
            # Ground truth is used as given (0 to IMG_SIZE).
            dist = torch.norm(
                pred_keypoints - gt_keypoints[:, :, :2], dim=-1)

        # Normalise distances by sqrt(segmentation area), one area per sample.
        seg_area = torch.sum(
            gtseg.reshape(gtseg.shape[0], -1), dim=-1).unsqueeze(-1)
        hits = (dist / torch.sqrt(seg_area)) < thresh

        total_visible = torch.sum(visibility, dim=-1)
        pck = torch.sum(hits.float() * visibility, dim=-1) / total_visible
        return pck

    @staticmethod
    def PCK(
            pred_keypoints, keypoints,
            gtseg, has_seg,
            thresh_range=(0.15,),
            idxs: list = None,
            biggs=False):
        """Calc PCK with same method as in eval, averaged over thresholds.

        Args:
            pred_keypoints, keypoints, gtseg, has_seg, biggs: forwarded
                unchanged to :meth:`PCK_thresh`.
            thresh_range: iterable of thresholds; the result is the mean of
                the per-threshold scores. Must not be empty.
            idxs: optional list of subset of keypoints to index from.

        Returns:
            One averaged PCK value per selected sample.
        """
        per_threshold = [
            Metrics.PCK_thresh(
                pred_keypoints, keypoints,
                gtseg, has_seg, thresh, idxs,
                biggs=biggs)
            for thresh in thresh_range
        ]
        return torch.stack(per_threshold, dim=0).mean(dim=0)

    @staticmethod
    def IOU(synth_silhouettes, gt_seg, img_border_mask, mask):
        """Compute the per-sample IoU of rendered and ground-truth silhouettes.

        Args:
            synth_silhouettes: predicted silhouettes, batch on the first
                axis. The tensor is not modified.
            gt_seg: ground-truth segmentation, broadcastable against
                ``synth_silhouettes``.
            img_border_mask: weights in which 1 keeps the ground truth and 0
                replaces it with the prediction, so those pixels are not
                penalised.
            mask: per-sample factor; ``mask[i]`` multiplies sample ``i`` of
                the silhouettes (a zero blanks the prediction).

        Returns:
            One IoU value per sample. A sample with an empty union yields
            NaN; a ``RuntimeWarning`` is emitted in that case.
        """
        import warnings

        # Work on a copy so the caller's silhouettes are left untouched.
        silhouettes = synth_silhouettes.clone()
        for i in range(mask.shape[0]):
            silhouettes[i] *= mask[i]

        # Do not penalize parts of the segmentation outside the img range
        gt_seg = (gt_seg * img_border_mask) \
            + silhouettes * (1.0 - img_border_mask)

        batch = silhouettes.shape[0]
        intersection = torch.sum(
            (silhouettes * gt_seg).reshape(batch, -1), dim=-1)
        union = torch.sum(
            ((silhouettes + gt_seg).reshape(batch, -1) > 0.0).float(),
            dim=-1)
        acc_IOU_SCORE = intersection / union

        # Report NaN scores instead of dropping into an interactive debugger,
        # which would block any non-interactive evaluation run.
        if torch.isnan(acc_IOU_SCORE).any():
            warnings.warn(
                "Metrics.IOU produced NaN for at least one sample "
                "(e.g. empty union of prediction and ground truth).",
                RuntimeWarning)
        return acc_IOU_SCORE