geopavlakos's picture
Initial commit
d7a991a
raw
history blame
13.6 kB
# ------------------------------------------------------------------------------
# Adapted from https://github.com/princeton-vl/pose-ae-train/
# Original licence: Copyright (c) 2017, umich-vl, under BSD 3-Clause License.
# ------------------------------------------------------------------------------
import numpy as np
import torch
from munkres import Munkres
from mmpose.core.evaluation import post_dark_udp
def _py_max_match(scores):
    """Solve the assignment problem for a cost matrix with Munkres.

    Args:
        scores (np.ndarray): cost matrix.

    Returns:
        np.ndarray: integer array holding the (row, col) index pairs
            produced by ``Munkres.compute``.
    """
    assignment = Munkres().compute(scores)
    return np.array(assignment).astype(int)
def _match_by_tag(inp, params):
    """Group keypoint candidates into poses by comparing their tags.

    Joints are visited in ``params.joint_order``. Candidates of the first
    joint that has any detection open one group each; candidates of later
    joints are assigned to existing groups with the Munkres algorithm, or
    open a new group when no existing group is close enough in tag space.

    Note:
        number of keypoints: K
        max number of people in an image: M (M=30 by default)
        dim of tags: L
            If use flip testing, L=2; else L=1.

    Args:
        inp (tuple):
            tag_k (np.ndarray[KxMxL]): tag corresponding to the
                top k values of feature map per keypoint.
            loc_k (np.ndarray[KxMx2]): top k locations of the
                feature maps for keypoint.
            val_k (np.ndarray[KxM]): top k value of the
                feature maps per keypoint.
        params (_Params): grouping parameters.

    Returns:
        np.ndarray: float32 array with one entry per group; each entry has
            shape (K, 3 + L) and rows laid out as (x, y, score, tag...).
    """
    assert isinstance(params, _Params), 'params should be class _Params()'

    tag_k, loc_k, val_k = inp
    # Template for a new group: every joint starts as all zeros.
    blank_group = np.zeros((params.num_joints, 3 + tag_k.shape[2]),
                           dtype=np.float32)

    joint_dict = {}  # group key -> (K, 3 + L) joints of that group
    tag_dict = {}  # group key -> list of tags assigned to that group

    def _start_group(tag, joint, joint_idx):
        # A group is keyed by the first tag component of its seed joint.
        key = tag[0]
        joint_dict.setdefault(key, np.copy(blank_group))[joint_idx] = joint
        tag_dict[key] = [tag]

    for step in range(params.num_joints):
        idx = params.joint_order[step]

        tags = tag_k[idx]
        candidates = np.concatenate(
            (loc_k[idx], val_k[idx, :, None], tags), 1)
        keep = candidates[:, 2] > params.detection_threshold
        tags = tags[keep]
        candidates = candidates[keep]
        if candidates.shape[0] == 0:
            continue

        # Nothing grouped yet: every candidate opens its own group.
        if step == 0 or not joint_dict:
            for tag, joint in zip(tags, candidates):
                _start_group(tag, joint, idx)
            continue

        grouped_keys = list(joint_dict.keys())[:params.max_num_people]
        if (params.ignore_too_much
                and len(grouped_keys) == params.max_num_people):
            continue
        # Each group is represented by the mean of the tags it has so far.
        grouped_tags = [np.mean(tag_dict[key], axis=0)
                        for key in grouped_keys]

        diff = candidates[:, None, 3:] - np.array(grouped_tags)[None, :, :]
        cost = np.linalg.norm(diff, ord=2, axis=2)
        # Untouched copy of the raw tag distances for the threshold test
        # below; ``cost`` itself is what gets handed to the matcher.
        raw_dist = np.copy(cost)

        if params.use_detection_val:
            cost = np.round(cost) * 100 - candidates[:, 2:3]

        num_added = diff.shape[0]
        num_grouped = diff.shape[1]
        if num_added > num_grouped:
            # Pad with very expensive dummy columns so that every
            # candidate receives an assignment.
            filler = np.zeros((num_added, num_added - num_grouped),
                              dtype=np.float32) + 1e10
            cost = np.concatenate((cost, filler), axis=1)

        for row, col in _py_max_match(cost):
            accepted = (row < num_added and col < num_grouped
                        and raw_dist[row][col] < params.tag_threshold)
            if accepted:
                key = grouped_keys[col]
                joint_dict[key][idx] = candidates[row]
                tag_dict[key].append(tags[row])
            else:
                _start_group(tags[row], candidates[row], idx)

    return np.array(list(joint_dict.values())).astype(np.float32)
class _Params:
    """Plain holder for the keypoint-grouping parameters.

    Args:
        cfg (Config): config providing ``num_joints``, ``max_num_people``,
            ``detection_threshold``, ``tag_threshold``,
            ``use_detection_val`` and ``ignore_too_much``.
    """

    _CFG_KEYS = (
        'num_joints',
        'max_num_people',
        'detection_threshold',
        'tag_threshold',
        'use_detection_val',
        'ignore_too_much',
    )

    def __init__(self, cfg):
        # Copy every required entry onto the instance under the same name.
        for key in self._CFG_KEYS:
            setattr(self, key, cfg[key])

        if self.num_joints == 17:
            # Fixed 0-based visiting order used for 17-joint skeletons.
            self.joint_order = [
                0, 1, 2, 3, 4, 5, 6, 11, 12, 7, 8, 9, 10, 13, 14, 15, 16
            ]
        else:
            # Any other joint count is visited in natural order.
            self.joint_order = list(np.arange(self.num_joints))
class HeatmapParser:
    """The heatmap parser for post processing.

    Args:
        cfg (Config): config. ``tag_per_joint``, ``nms_kernel`` and
            ``nms_padding`` are required in addition to the keys read by
            ``_Params``; ``use_udp`` and ``score_per_joint`` are optional
            and default to False.
    """

    def __init__(self, cfg):
        self.params = _Params(cfg)
        self.tag_per_joint = cfg['tag_per_joint']
        # Stride-1 max pooling; ``nms`` compares its output with the input
        # to keep only local maxima.
        self.pool = torch.nn.MaxPool2d(cfg['nms_kernel'], 1,
                                       cfg['nms_padding'])
        self.use_udp = cfg.get('use_udp', False)
        self.score_per_joint = cfg.get('score_per_joint', False)

    def nms(self, heatmaps):
        """Non-Maximum Suppression for heatmaps.

        Args:
            heatmaps (torch.Tensor): Heatmaps before nms.

        Returns:
            torch.Tensor: Heatmaps after nms. Every position that does not
                equal the max-pooled value of its window is set to zero.
        """
        maxm = self.pool(heatmaps)
        maxm = torch.eq(maxm, heatmaps).float()
        return heatmaps * maxm

    def match(self, tag_k, loc_k, val_k):
        """Group keypoints to human poses in a batch.

        Args:
            tag_k (np.ndarray[NxKxMxL]): tag corresponding to the
                top k values of feature map per keypoint.
            loc_k (np.ndarray[NxKxMx2]): top k locations of the
                feature maps for keypoint.
            val_k (np.ndarray[NxKxM]): top k value of the
                feature maps per keypoint.

        Returns:
            list: one grouping result of ``_match_by_tag`` per batch
                element.
        """
        return [
            _match_by_tag(sample, self.params)
            for sample in zip(tag_k, loc_k, val_k)
        ]

    def top_k(self, heatmaps, tags):
        """Find top_k values in an image.

        Note:
            batch size: N
            number of keypoints: K
            heatmap height: H
            heatmap width: W
            max number of people: M
            dim of tags: L
                If use flip testing, L=2; else L=1.

        Args:
            heatmaps (torch.Tensor[NxKxHxW])
            tags (torch.Tensor[NxKxHxWxL])

        Returns:
            dict: A dict containing top_k values.

            - tag_k (np.ndarray[NxKxMxL]):
                tag corresponding to the top k values of
                feature map per keypoint.
            - loc_k (np.ndarray[NxKxMx2]):
                top k location (x, y) of feature map per keypoint.
            - val_k (np.ndarray[NxKxM]):
                top k value of feature map per keypoint.
        """
        heatmaps = self.nms(heatmaps)
        N, K, H, W = heatmaps.size()
        # Flatten the spatial dims so topk runs over all pixels at once.
        heatmaps = heatmaps.view(N, K, -1)
        val_k, ind = heatmaps.topk(self.params.max_num_people, dim=2)

        tags = tags.view(tags.size(0), tags.size(1), W * H, -1)
        if not self.tag_per_joint:
            # A single shared tag map is broadcast to every joint.
            tags = tags.expand(-1, self.params.num_joints, -1, -1)

        # Pick, for every tag dimension, the tag at each top-k location.
        tag_k = torch.stack(
            [torch.gather(tags[..., i], 2, ind) for i in range(tags.size(3))],
            dim=3)

        # Convert flat indices back to (x, y) pixel coordinates.
        x = ind % W
        y = ind // W
        ind_k = torch.stack((x, y), dim=3)

        return {
            'tag_k': tag_k.cpu().numpy(),
            'loc_k': ind_k.cpu().numpy(),
            'val_k': val_k.cpu().numpy()
        }

    @staticmethod
    def adjust(results, heatmaps):
        """Adjust the coordinates for better accuracy.

        Every joint with a positive score is moved a quarter pixel towards
        the higher of its two neighbours along each axis and then shifted
        by half a pixel.

        Note:
            batch size: N
            number of keypoints: K
            heatmap height: H
            heatmap width: W

        Args:
            results (list(np.ndarray)): Keypoint predictions.
            heatmaps (torch.Tensor[NxKxHxW]): Heatmaps.

        Returns:
            list(np.ndarray): ``results``, updated in place.
        """
        _, _, H, W = heatmaps.shape
        for batch_id, people in enumerate(results):
            for people_id, people_i in enumerate(people):
                for joint_id, joint in enumerate(people_i):
                    if joint[2] > 0:
                        x, y = joint[0:2]
                        xx, yy = int(x), int(y)
                        tmp = heatmaps[batch_id][joint_id]
                        # Neighbour indices are clamped at the borders.
                        if tmp[min(H - 1, yy + 1), xx] > tmp[max(0, yy - 1),
                                                             xx]:
                            y += 0.25
                        else:
                            y -= 0.25

                        if tmp[yy, min(W - 1, xx + 1)] > tmp[yy,
                                                             max(0, xx - 1)]:
                            x += 0.25
                        else:
                            x -= 0.25
                        results[batch_id][people_id, joint_id,
                                          0:2] = (x + 0.5, y + 0.5)
        return results

    @staticmethod
    def refine(heatmap, tag, keypoints, use_udp=False):
        """Given initial keypoint predictions, we identify missing joints.

        The mean tag of the already detected joints identifies the person.
        For every joint the heatmap position that is both strong and close
        to that mean tag is located, and joints whose score is exactly 0
        are filled in from it.

        Note:
            number of keypoints: K
            heatmap height: H
            heatmap width: W
            dim of tags: L
                If use flip testing, L=2; else L=1.

        Args:
            heatmap: np.ndarray(K, H, W).
            tag: np.ndarray(K, H, W) |  np.ndarray(K, H, W, L)
            keypoints: np.ndarray of size (K, 3 + L)
                        last dim is (x, y, score, tag).
            use_udp: bool-unbiased data processing

        Returns:
            np.ndarray: The refined keypoints. This is the ``keypoints``
                array itself, modified in place.
        """
        K, H, W = heatmap.shape
        if len(tag.shape) == 3:
            tag = tag[..., None]

        tags = []
        for i in range(K):
            if keypoints[i, 2] > 0:
                # save tag value of detected keypoint
                x, y = keypoints[i][:2].astype(int)
                x = np.clip(x, 0, W - 1)
                y = np.clip(y, 0, H - 1)
                tags.append(tag[i, y, x])

        if not tags:
            # No detected joint means there is no reference tag to refine
            # around; averaging an empty list would only produce NaN.
            return keypoints

        # mean tag of current detected people
        prev_tag = np.mean(tags, axis=0)

        results = []
        for _heatmap, _tag in zip(heatmap, tag):
            # distance of all tag values with mean tag of
            # current detected people
            distance_tag = (((_tag -
                              prev_tag[None, None, :])**2).sum(axis=2)**0.5)
            # Penalise positions whose tag is far from the person's tag.
            norm_heatmap = _heatmap - np.round(distance_tag)

            # find maximum position
            yy, xx = np.unravel_index(np.argmax(norm_heatmap), _heatmap.shape)
            # detection score at maximum position
            val = _heatmap[yy, xx]

            x, y = float(xx), float(yy)
            if not use_udp:
                # offset by 0.5
                x += 0.5
                y += 0.5

            # add a quarter offset towards the higher neighbour
            if _heatmap[yy, min(W - 1, xx + 1)] > _heatmap[yy, max(0, xx - 1)]:
                x += 0.25
            else:
                x -= 0.25

            if _heatmap[min(H - 1, yy + 1), xx] > _heatmap[max(0, yy - 1), xx]:
                y += 0.25
            else:
                y -= 0.25

            results.append((x, y, val))
        results = np.array(results)

        # add keypoint if it is not detected
        missing = (results[:, 2] > 0) & (keypoints[:, 2] == 0)
        keypoints[missing, :3] = results[missing, :3]

        return keypoints

    def parse(self, heatmaps, tags, adjust=True, refine=True):
        """Group keypoints into poses given heatmap and tag.

        Scores and refinement are computed for the first batch element
        only.

        Note:
            batch size: N
            number of keypoints: K
            heatmap height: H
            heatmap width: W
            dim of tags: L
                If use flip testing, L=2; else L=1.

        Args:
            heatmaps (torch.Tensor[NxKxHxW]): model output heatmaps.
            tags (torch.Tensor[NxKxHxWxL]): model output tagmaps.
            adjust (bool): whether to apply sub-pixel adjustment
                (``post_dark_udp`` when ``use_udp`` is set, otherwise
                ``adjust``).
            refine (bool): whether to fill in missing joints with
                ``refine``.

        Returns:
            tuple: A tuple containing keypoint grouping results.

            - results (list(np.ndarray)): Pose results.
            - scores (list/list(np.ndarray)): Score of people.
        """
        results = self.match(**self.top_k(heatmaps, tags))

        if adjust:
            if self.use_udp:
                for i, people in enumerate(results):
                    if people.shape[0] > 0:
                        people[..., :2] = post_dark_udp(
                            people[..., :2].copy(), heatmaps[i:i + 1, :])
            else:
                results = self.adjust(results, heatmaps)

        # Scores are taken before refinement.
        if self.score_per_joint:
            scores = [person[:, 2] for person in results[0]]
        else:
            scores = [person[:, 2].mean() for person in results[0]]

        if refine:
            people = results[0]
            if len(people) > 0:
                # The maps are identical for every person, so convert them
                # once instead of once per person.
                heatmap_numpy = heatmaps[0].cpu().numpy()
                tag_numpy = tags[0].cpu().numpy()
                if not self.tag_per_joint:
                    tag_numpy = np.tile(tag_numpy,
                                        (self.params.num_joints, 1, 1, 1))
                # for every detected person
                for i in range(len(people)):
                    people[i] = self.refine(
                        heatmap_numpy,
                        tag_numpy,
                        people[i],
                        use_udp=self.use_udp)
            results = [people]

        return results, scores