Spaces:
Build error
Build error
# ------------------------------------------------------------------------------
# Adapted from https://github.com/princeton-vl/pose-ae-train/
# Original licence: Copyright (c) 2017, umich-vl, under BSD 3-Clause License.
# ------------------------------------------------------------------------------
import numpy as np | |
import torch | |
from munkres import Munkres | |
from mmpose.core.evaluation import post_dark_udp | |
def _py_max_match(scores):
    """Solve the assignment problem for a cost matrix with Munkres.

    Args:
        scores (np.ndarray): cost matrix handed to ``Munkres.compute``.

    Returns:
        np.ndarray: integer array built from the solver output; the caller
            iterates it as ``(row, col)`` pairs.
    """
    assignment = Munkres().compute(scores)
    return np.array(assignment).astype(int)
def _match_by_tag(inp, params):
    """Group keypoint candidates into poses by comparing their tags.

    Joints are visited in ``params.joint_order``. Candidates of the first
    joint that yields any detection each start a new pose; candidates of
    later joints are assigned to existing poses with the Munkres algorithm,
    or start a new pose when no existing pose is close enough in tag space.

    Note:
        number of keypoints: K
        max number of people in an image: M (M=30 by default)
        dim of tags: L
            If use flip testing, L=2; else L=1.

    Args:
        inp (tuple):
            tag_k (np.ndarray[KxMxL]): tag corresponding to the
                top k values of feature map per keypoint.
            loc_k (np.ndarray[KxMx2]): top k locations of the
                feature maps for keypoint.
            val_k (np.ndarray[KxM]): top k value of the
                feature maps per keypoint.
        params (_Params): grouping parameters.

    Returns:
        np.ndarray: float32 array of grouped poses, one entry per pose,
            each of shape (K, 3 + L) holding (x, y, score, tag...) per joint.
    """
    assert isinstance(params, _Params), 'params should be class _Params()'

    tag_k, loc_k, val_k = inp
    # Template for a new pose: every joint starts as an all-zero row.
    empty_pose = np.zeros((params.num_joints, 3 + tag_k.shape[2]),
                          dtype=np.float32)

    poses = {}      # pose key -> (K, 3 + L) joint array
    pose_tags = {}  # pose key -> list of tags collected for that pose

    for step in range(params.num_joints):
        joint_idx = params.joint_order[step]

        cand_tags = tag_k[joint_idx]
        cands = np.concatenate(
            (loc_k[joint_idx], val_k[joint_idx, :, None], cand_tags), axis=1)

        # Drop candidates whose score does not exceed the threshold.
        keep = cands[:, 2] > params.detection_threshold
        cand_tags = cand_tags[keep]
        cands = cands[keep]
        if cands.shape[0] == 0:
            continue

        if not poses:
            # Nothing grouped yet: every candidate seeds its own pose,
            # keyed by the first tag component.
            for tag, cand in zip(cand_tags, cands):
                key = tag[0]
                poses.setdefault(key, np.copy(empty_pose))[joint_idx] = cand
                pose_tags[key] = [tag]
            continue

        known_keys = list(poses)[:params.max_num_people]
        if (params.ignore_too_much
                and len(known_keys) == params.max_num_people):
            continue
        mean_tags = np.array(
            [np.mean(pose_tags[key], axis=0) for key in known_keys])

        # Tag distance between every candidate and every existing pose.
        diff = cands[:, None, 3:] - mean_tags[None, :, :]
        tag_dist = np.linalg.norm(diff, ord=2, axis=2)

        # ``tag_dist`` is kept for the threshold test below; the solver gets
        # its own matrix, optionally biased by the detection score.
        if params.use_detection_val:
            cost = np.round(tag_dist) * 100 - cands[:, 2:3]
        else:
            cost = np.copy(tag_dist)

        num_added, num_grouped = diff.shape[0], diff.shape[1]
        if num_added > num_grouped:
            # More candidates than poses: pad with very expensive dummy
            # columns so every candidate receives an assignment.
            padding = np.zeros((num_added, num_added - num_grouped),
                               dtype=np.float32) + 1e10
            cost = np.concatenate((cost, padding), axis=1)

        for row, col in _py_max_match(cost):
            matched = (row < num_added and col < num_grouped
                       and tag_dist[row][col] < params.tag_threshold)
            if matched:
                key = known_keys[col]
                poses[key][joint_idx] = cands[row]
                pose_tags[key].append(cand_tags[row])
            else:
                # Too far from its assigned pose (or assigned to padding).
                key = cand_tags[row][0]
                poses.setdefault(key, np.copy(empty_pose))[joint_idx] = \
                    cands[row]
                pose_tags[key] = [cand_tags[row]]

    return np.array(list(poses.values())).astype(np.float32)
class _Params:
    """Container for the keypoint-grouping parameters.

    Args:
        cfg (Config): config providing ``num_joints``, ``max_num_people``,
            ``detection_threshold``, ``tag_threshold``,
            ``use_detection_val`` and ``ignore_too_much``.
    """

    def __init__(self, cfg):
        for name in ('num_joints', 'max_num_people', 'detection_threshold',
                     'tag_threshold', 'use_detection_val',
                     'ignore_too_much'):
            setattr(self, name, cfg[name])

        # Order in which joints are visited during grouping. The 17-joint
        # layout uses a fixed custom order (0-based indices); any other
        # layout is visited in index order.
        if self.num_joints == 17:
            self.joint_order = [
                0, 1, 2, 3, 4, 5, 6, 11, 12, 7, 8, 9, 10, 13, 14, 15, 16
            ]
        else:
            self.joint_order = list(np.arange(self.num_joints))
class HeatmapParser:
    """The heatmap parser for post processing.

    Args:
        cfg (dict): config. ``tag_per_joint``, ``nms_kernel`` and
            ``nms_padding`` are required in addition to the keys read by
            ``_Params``; ``use_udp`` and ``score_per_joint`` are optional
            and default to False.
    """

    def __init__(self, cfg):
        self.params = _Params(cfg)
        self.tag_per_joint = cfg['tag_per_joint']
        # Stride-1 max pooling, used by ``nms`` to find local maxima.
        self.pool = torch.nn.MaxPool2d(cfg['nms_kernel'], 1,
                                       cfg['nms_padding'])
        self.use_udp = cfg.get('use_udp', False)
        self.score_per_joint = cfg.get('score_per_joint', False)

    def nms(self, heatmaps):
        """Non-Maximum Suppression for heatmaps.

        Args:
            heatmaps (torch.Tensor): Heatmaps before nms.

        Returns:
            torch.Tensor: Heatmaps after nms; every position that is not
                equal to the pooled maximum of its window is set to zero.
        """
        maxm = self.pool(heatmaps)
        maxm = torch.eq(maxm, heatmaps).float()
        heatmaps = heatmaps * maxm
        return heatmaps

    def match(self, tag_k, loc_k, val_k):
        """Group keypoints to human poses in a batch.

        Args:
            tag_k (np.ndarray[NxKxMxL]): tag corresponding to the
                top k values of feature map per keypoint.
            loc_k (np.ndarray[NxKxMx2]): top k locations of the
                feature maps for keypoint.
            val_k (np.ndarray[NxKxM]): top k value of the
                feature maps per keypoint.

        Returns:
            list: one grouping result of ``_match_by_tag`` per batch
                element.
        """
        return [
            _match_by_tag(sample, self.params)
            for sample in zip(tag_k, loc_k, val_k)
        ]

    def top_k(self, heatmaps, tags):
        """Find top_k values in an image.

        Note:
            batch size: N
            number of keypoints: K
            heatmap height: H
            heatmap width: W
            max number of people: M
            dim of tags: L
                If use flip testing, L=2; else L=1.

        Args:
            heatmaps (torch.Tensor[NxKxHxW])
            tags (torch.Tensor[NxKxHxWxL])

        Returns:
            dict: A dict containing top_k values.

            - tag_k (np.ndarray[NxKxMxL]):
                tag corresponding to the top k values of
                feature map per keypoint.
            - loc_k (np.ndarray[NxKxMx2]):
                top k location of feature map per keypoint.
            - val_k (np.ndarray[NxKxM]):
                top k value of feature map per keypoint.
        """
        heatmaps = self.nms(heatmaps)
        N, K, H, W = heatmaps.size()
        heatmaps = heatmaps.view(N, K, -1)
        val_k, ind = heatmaps.topk(self.params.max_num_people, dim=2)

        tags = tags.view(tags.size(0), tags.size(1), W * H, -1)
        if not self.tag_per_joint:
            # A single shared tag map is broadcast to every joint.
            tags = tags.expand(-1, self.params.num_joints, -1, -1)

        tag_k = torch.stack(
            [torch.gather(tags[..., i], 2, ind) for i in range(tags.size(3))],
            dim=3)

        # Convert flat indices back to (x, y) heatmap coordinates.
        x = ind % W
        y = ind // W
        ind_k = torch.stack((x, y), dim=3)

        results = {
            'tag_k': tag_k.cpu().numpy(),
            'loc_k': ind_k.cpu().numpy(),
            'val_k': val_k.cpu().numpy()
        }
        return results

    @staticmethod
    def adjust(results, heatmaps):
        """Adjust the coordinates for better accuracy.

        Every joint with a positive score is moved a quarter pixel towards
        the higher of its two neighbours along each axis and then offset by
        half a pixel. ``results`` is modified in place.

        Note:
            batch size: N
            number of keypoints: K
            heatmap height: H
            heatmap width: W

        Args:
            results (list(np.ndarray)): Keypoint predictions.
            heatmaps (torch.Tensor[NxKxHxW]): Heatmaps.

        Returns:
            list(np.ndarray): the same ``results`` object, adjusted.
        """
        _, _, H, W = heatmaps.shape
        for batch_id, people in enumerate(results):
            for people_id, people_i in enumerate(people):
                for joint_id, joint in enumerate(people_i):
                    if joint[2] > 0:
                        x, y = joint[0:2]
                        xx, yy = int(x), int(y)
                        tmp = heatmaps[batch_id][joint_id]
                        if tmp[min(H - 1, yy + 1), xx] > tmp[max(0, yy - 1),
                                                             xx]:
                            y += 0.25
                        else:
                            y -= 0.25

                        if tmp[yy, min(W - 1, xx + 1)] > tmp[yy,
                                                             max(0, xx - 1)]:
                            x += 0.25
                        else:
                            x -= 0.25
                        results[batch_id][people_id, joint_id,
                                          0:2] = (x + 0.5, y + 0.5)
        return results

    @staticmethod
    def refine(heatmap, tag, keypoints, use_udp=False):
        """Given initial keypoint predictions, we identify missing joints.

        The mean tag of the joints already detected for this person is used
        to look, in every joint heatmap, for the best-scoring location whose
        tag is close to that mean. Joints that were not detected before
        (score == 0) are filled in from this search. ``keypoints`` is
        modified in place.

        Note:
            number of keypoints: K
            heatmap height: H
            heatmap width: W
            dim of tags: L
                If use flip testing, L=2; else L=1.

        Args:
            heatmap: np.ndarray(K, H, W).
            tag: np.ndarray(K, H, W) | np.ndarray(K, H, W, L)
            keypoints: np.ndarray of size (K, 3 + L)
                        last dim is (x, y, score, tag).
            use_udp: bool-unbiased data processing

        Returns:
            np.ndarray: The refined keypoints.
        """
        K, H, W = heatmap.shape
        if len(tag.shape) == 3:
            tag = tag[..., None]

        # Collect the tag value at every already-detected keypoint.
        tags = []
        for i in range(K):
            if keypoints[i, 2] > 0:
                x, y = keypoints[i][:2].astype(int)
                x = np.clip(x, 0, W - 1)
                y = np.clip(y, 0, H - 1)
                tags.append(tag[i, y, x])

        if not tags:
            # No detected joint means there is no reference tag to search
            # with, so nothing can be refined.
            return keypoints

        # mean tag of current detected people
        prev_tag = np.mean(tags, axis=0)

        results = []
        for _heatmap, _tag in zip(heatmap, tag):
            # distance of all tag values with mean tag of
            # current detected people
            distance_tag = (((_tag -
                              prev_tag[None, None, :])**2).sum(axis=2)**0.5)
            norm_heatmap = _heatmap - np.round(distance_tag)

            # find maximum position
            yy, xx = np.unravel_index(np.argmax(norm_heatmap), _heatmap.shape)
            # detection score at maximum position
            val = _heatmap[yy, xx]

            x, y = float(xx), float(yy)
            if not use_udp:
                # offset by 0.5
                x += 0.5
                y += 0.5

            # add a quarter offset towards the higher neighbour
            if _heatmap[yy, min(W - 1, xx + 1)] > _heatmap[yy, max(0, xx - 1)]:
                x += 0.25
            else:
                x -= 0.25

            if _heatmap[min(H - 1, yy + 1), xx] > _heatmap[max(0, yy - 1), xx]:
                y += 0.25
            else:
                y -= 0.25

            results.append((x, y, val))
        results = np.array(results)

        for i in range(K):
            # add keypoint if it is not detected
            if results[i, 2] > 0 and keypoints[i, 2] == 0:
                keypoints[i, :3] = results[i, :3]

        return keypoints

    def parse(self, heatmaps, tags, adjust=True, refine=True):
        """Group keypoints into poses given heatmap and tag.

        Note:
            batch size: N
            number of keypoints: K
            heatmap height: H
            heatmap width: W
            dim of tags: L
                If use flip testing, L=2; else L=1.

            Scores and refinement only use the first batch element
            (``results[0]``, ``heatmaps[0]``, ``tags[0]``).

        Args:
            heatmaps (torch.Tensor[NxKxHxW]): model output heatmaps.
            tags (torch.Tensor[NxKxHxWxL]): model output tagmaps.
            adjust (bool): whether to apply sub-pixel coordinate adjustment
                (``post_dark_udp`` when ``use_udp`` is set, otherwise
                ``self.adjust``).
            refine (bool): whether to fill in undetected joints with
                ``self.refine``.

        Returns:
            tuple: A tuple containing keypoint grouping results.

            - results (list(np.ndarray)): Pose results.
            - scores (list/list(np.ndarray)): Score of people.
        """
        results = self.match(**self.top_k(heatmaps, tags))

        if adjust:
            if self.use_udp:
                for i in range(len(results)):
                    if results[i].shape[0] > 0:
                        results[i][..., :2] = post_dark_udp(
                            results[i][..., :2].copy(), heatmaps[i:i + 1, :])
            else:
                results = self.adjust(results, heatmaps)

        if self.score_per_joint:
            scores = [i[:, 2] for i in results[0]]
        else:
            scores = [i[:, 2].mean() for i in results[0]]

        if refine:
            people = results[0]
            if len(people) > 0:
                # Convert once; these do not change between people.
                heatmap_numpy = heatmaps[0].cpu().numpy()
                tag_numpy = tags[0].cpu().numpy()
                if not self.tag_per_joint:
                    tag_numpy = np.tile(tag_numpy,
                                        (self.params.num_joints, 1, 1, 1))
                # for every detected person
                for i in range(len(people)):
                    people[i] = self.refine(
                        heatmap_numpy,
                        tag_numpy,
                        people[i],
                        use_udp=self.use_udp)
            results = [people]

        return results, scores