# Modified from:
# https://github.com/anibali/pytorch-stacked-hourglass
# https://github.com/bearpaw/pytorch-pose
import math

import torch
import torch.nn.functional as F
from kornia.geometry.subpix import dsnt  # kornia 0.4.0

from .transforms import transform_preds

__all__ = ['get_preds', 'get_preds_soft', 'calc_dists', 'dist_acc', 'accuracy',
           'final_preds_untransformed', 'final_preds', 'AverageMeter']


def get_preds(scores, return_maxval=False):
    '''Get hard (argmax) predictions from score maps.

    Args:
        scores: 4-dim tensor of heatmaps, (batch, joints, height, width).
        return_maxval: if True, also return the per-joint peak value.

    Returns:
        preds: float tensor (batch, joints, 2) of 1-based (x, y) coordinates;
            joints whose peak value is <= 0 are zeroed out.
        maxval: (only when return_maxval) float tensor (batch, joints, 1).
    '''
    assert scores.dim() == 4, 'Score maps should be 4-dim'
    # Flatten each map and take its spatial argmax.
    maxval, idx = torch.max(scores.view(scores.size(0), scores.size(1), -1), 2)

    maxval = maxval.view(scores.size(0), scores.size(1), 1)
    idx = idx.view(scores.size(0), scores.size(1), 1) + 1  # 1-based flat index

    preds = idx.repeat(1, 1, 2).float()

    # Unravel the flat index into 1-based (x, y).
    preds[:, :, 0] = (preds[:, :, 0] - 1) % scores.size(3) + 1
    preds[:, :, 1] = torch.floor((preds[:, :, 1] - 1) / scores.size(3)) + 1

    # Zero out joints whose peak is not positive (values > 0 kept).
    pred_mask = maxval.gt(0).repeat(1, 1, 2).float()
    preds *= pred_mask

    if return_maxval:
        return preds, maxval
    else:
        return preds


def get_preds_soft(scores, return_maxval=False, norm_coords=False,
                   norm_and_unnorm_coords=False):
    '''Get soft (expectation) predictions from score maps.

    Predictions are made assuming a logit output map: each map is passed
    through a spatial softmax and the coordinate is its spatial expectation
    (DSNT). All returned coordinates are float tensors — unnormalized ones
    are 1-based pixel coordinates, normalized ones lie in [-1, 1].

    Args:
        scores: 4-dim tensor of logit heatmaps, (batch, joints, height, width).
        return_maxval: if True, also return a per-joint confidence score
            (probability mass in the 3x3 window around the prediction).
        norm_coords: return normalized ([-1, 1]) coordinates.
        norm_and_unnorm_coords: return both coordinate variants
            (takes precedence over norm_coords; only used with return_maxval).
    '''
    # Work on logit predictions: softmax each map into a probability map.
    # A float temperature keeps the division in floating point (the integer
    # tensor used previously relied on implicit type promotion).
    scores_norm = dsnt.spatial_softmax2d(scores, temperature=torch.tensor(1.0))

    # From unnormalized to normalized coordinates (from -1..1 to 0..64) see
    # https://github.com/kornia/kornia/blob/b9ffe7efcba7399daeeb8028f10c22941b55d32d/kornia/utils/grid.py#L7 (line 40):
    #   xs = (xs / (width - 1) - 0.5) * 2
    #   ys = (ys / (height - 1) - 0.5) * 2
    device = scores.device

    if return_maxval:
        preds_normalized = dsnt.spatial_expectation2d(
            scores_norm, normalized_coordinates=True)

        # Confidence: sum the probability mass of the 3x3 neighbourhood around
        # the predicted location, gathered with grid_sample on shifted copies.
        gs_input_single = scores_norm.reshape(
            (-1, 1, scores_norm.shape[2], scores_norm.shape[3]))  # e.g. (120, 1, 64, 64)
        half_pad = 2
        # Zero-pad only the two spatial dims so the shifted slices stay in bounds.
        gs_input_single_padded = F.pad(
            input=gs_input_single,
            pad=(half_pad, half_pad, half_pad, half_pad),
            mode='constant', value=0)
        # Channel c of gs_input_all holds the map shifted by one of the nine
        # (dy, dx) offsets in {-1, 0, 1}^2.
        gs_input_all = torch.zeros(
            (gs_input_single.shape[0], 9,
             gs_input_single.shape[2], gs_input_single.shape[3])).to(device)
        ind_tot = 0
        for ind0 in [-1, 0, 1]:
            for ind1 in [-1, 0, 1]:
                gs_input_all[:, ind_tot, :, :] = gs_input_single_padded[
                    :, 0,
                    half_pad + ind0:-half_pad + ind0,
                    half_pad + ind1:-half_pad + ind1]
                ind_tot += 1
        gs_grid = preds_normalized.reshape((-1, 2))[:, None, None, :]  # (120, 1, 1, 2)
        gs_output_all = F.grid_sample(
            gs_input_all, gs_grid, mode='nearest', padding_mode='zeros',
            align_corners=True).reshape(
                (gs_input_all.shape[0], gs_input_all.shape[1], 1))
        # Sum over the nine shifted samples -> 3x3 neighbourhood mass.
        gs_output = gs_output_all.sum(axis=1)
        gs_output_resh = gs_output.reshape(
            (scores_norm.shape[0], scores_norm.shape[1], 1))

        if norm_and_unnorm_coords:
            preds = dsnt.spatial_expectation2d(
                scores_norm, normalized_coordinates=False) + 1
            return preds, preds_normalized, gs_output_resh
        elif norm_coords:
            return preds_normalized, gs_output_resh
        else:
            preds = dsnt.spatial_expectation2d(
                scores_norm, normalized_coordinates=False) + 1
            return preds, gs_output_resh
    else:
        if norm_coords:
            preds_normalized = dsnt.spatial_expectation2d(
                scores_norm, normalized_coordinates=True)
            return preds_normalized
        else:
            preds = dsnt.spatial_expectation2d(
                scores_norm, normalized_coordinates=False) + 1
            return preds


def calc_dists(preds, target, normalize):
    '''Per-joint distances between predictions and targets.

    Args:
        preds: (batch, joints, 2) predicted coordinates.
        target: (batch, joints, 2) ground-truth coordinates.
        normalize: (batch,) per-sample normalization factors.

    Returns:
        (joints, batch) tensor of normalized Euclidean distances; entries for
        targets with a coordinate <= 1 (treated as unannotated) are set to -1.
    '''
    preds = preds.float()
    target = target.float()
    dists = torch.zeros(preds.size(1), preds.size(0))
    for n in range(preds.size(0)):
        for c in range(preds.size(1)):
            if target[n, c, 0] > 1 and target[n, c, 1] > 1:
                dists[c, n] = torch.dist(preds[n, c, :], target[n, c, :]) / normalize[n]
            else:
                dists[c, n] = -1
    return dists
def dist_acc(dist, thr=0.5):
    '''Return the fraction of valid distances below *thr*.

    Entries equal to -1 mark unannotated joints and are ignored; if no
    valid entry remains, -1 is returned.
    '''
    valid = dist[dist != -1]
    if len(valid) == 0:
        return -1
    return 1.0 * (valid < thr).sum().item() / len(valid)


def accuracy(output, target, idxs=None, thr=0.5):
    '''Calculate accuracy according to PCK, using the ground-truth heatmap
    rather than x,y locations.

    The first returned value is the average accuracy across *idxs*,
    followed by the individual per-joint accuracies.
    '''
    if idxs is None:
        idxs = list(range(target.shape[-3]))
    preds = get_preds_soft(output)
    gts = get_preds(target)
    norm = torch.ones(preds.size(0)) * output.size(3) / 10
    dists = calc_dists(preds, gts, norm)

    acc = torch.zeros(len(idxs) + 1)
    avg_acc = 0
    cnt = 0
    for i, joint in enumerate(idxs):
        acc[i + 1] = dist_acc(dists[joint], thr=thr)
        if acc[i + 1] >= 0:  # -1 means the joint had no valid annotations
            avg_acc = avg_acc + acc[i + 1]
            cnt += 1
    if cnt != 0:
        acc[0] = avg_acc / cnt
    return acc


def final_preds_untransformed(output, res):
    '''Soft predictions in heatmap space, refined and converted to 0-based.'''
    coords = get_preds_soft(output)  # float, 1-based coordinates

    # Pose-processing: nudge each coordinate a quarter pixel towards the
    # larger of its two heatmap neighbours along each axis.
    for n in range(coords.size(0)):
        for p in range(coords.size(1)):
            hm = output[n][p]
            px = int(math.floor(coords[n][p][0]))
            py = int(math.floor(coords[n][p][1]))
            if 1 < px < res[0] and 1 < py < res[1]:
                diff = torch.Tensor([
                    hm[py - 1][px] - hm[py - 1][px - 2],
                    hm[py][px - 1] - hm[py - 2][px - 1],
                ])
                coords[n][p] += diff.sign() * .25
    coords += 0.5

    if coords.dim() < 3:
        coords = coords.unsqueeze(0)

    coords -= 1  # Convert from 1-based to 0-based coordinates

    return coords


def final_preds(output, center, scale, res):
    '''Predictions mapped from heatmap space back to the original image.'''
    coords = final_preds_untransformed(output, res)
    preds = coords.clone()

    # Transform each sample back to image coordinates.
    for i in range(coords.size(0)):
        preds[i] = transform_preds(coords[i], center[i], scale[i], res)

    if preds.dim() < 3:
        preds = preds.unsqueeze(0)

    return preds


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all running statistics to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count