Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from collections import namedtuple | |
| from itertools import product | |
| from typing import Any, List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| from munkres import Munkres | |
| from torch import Tensor | |
| from mmpose.registry import KEYPOINT_CODECS | |
| from mmpose.utils.tensor_utils import to_numpy | |
| from .base import BaseKeypointCodec | |
| from .utils import (batch_heatmap_nms, generate_gaussian_heatmaps, | |
| generate_udp_gaussian_heatmaps, refine_keypoints, | |
| refine_keypoints_dark_udp) | |
| def _group_keypoints_by_tags(vals: np.ndarray, | |
| tags: np.ndarray, | |
| locs: np.ndarray, | |
| keypoint_order: List[int], | |
| val_thr: float, | |
| tag_thr: float = 1.0, | |
| max_groups: Optional[int] = None) -> np.ndarray: | |
| """Group the keypoints by tags using Munkres algorithm. | |
| Note: | |
| - keypoint number: K | |
| - candidate number: M | |
| - tag dimenssion: L | |
| - coordinate dimension: D | |
| - group number: G | |
| Args: | |
| vals (np.ndarray): The heatmap response values of keypoints in shape | |
| (K, M) | |
| tags (np.ndarray): The tags of the keypoint candidates in shape | |
| (K, M, L) | |
| locs (np.ndarray): The locations of the keypoint candidates in shape | |
| (K, M, D) | |
| keypoint_order (List[int]): The grouping order of the keypoints. | |
| The groupping usually starts from a keypoints around the head and | |
| torso, and gruadually moves out to the limbs | |
| val_thr (float): The threshold of the keypoint response value | |
| tag_thr (float): The maximum allowed tag distance when matching a | |
| keypoint to a group. A keypoint with larger tag distance to any | |
| of the existing groups will initializes a new group | |
| max_groups (int, optional): The maximum group number. ``None`` means | |
| no limitation. Defaults to ``None`` | |
| Returns: | |
| np.ndarray: grouped keypoints in shape (G, K, D+1), where the last | |
| dimenssion is the concatenated keypoint coordinates and scores. | |
| """ | |
| K, M, D = locs.shape | |
| assert vals.shape == tags.shape[:2] == (K, M) | |
| assert len(keypoint_order) == K | |
| # Build Munkres instance | |
| munkres = Munkres() | |
| # Build a group pool, each group contains the keypoints of an instance | |
| groups = [] | |
| Group = namedtuple('Group', field_names=['kpts', 'scores', 'tag_list']) | |
| def _init_group(): | |
| """Initialize a group, which is composed of the keypoints, keypoint | |
| scores and the tag of each keypoint.""" | |
| _group = Group( | |
| kpts=np.zeros((K, D), dtype=np.float32), | |
| scores=np.zeros(K, dtype=np.float32), | |
| tag_list=[]) | |
| return _group | |
| for i in keypoint_order: | |
| # Get all valid candidate of the i-th keypoints | |
| valid = vals[i] > val_thr | |
| if not valid.any(): | |
| continue | |
| tags_i = tags[i, valid] # (M', L) | |
| vals_i = vals[i, valid] # (M',) | |
| locs_i = locs[i, valid] # (M', D) | |
| if len(groups) == 0: # Initialize the group pool | |
| for tag, val, loc in zip(tags_i, vals_i, locs_i): | |
| group = _init_group() | |
| group.kpts[i] = loc | |
| group.scores[i] = val | |
| group.tag_list.append(tag) | |
| groups.append(group) | |
| else: # Match keypoints to existing groups | |
| groups = groups[:max_groups] | |
| group_tags = [np.mean(g.tag_list, axis=0) for g in groups] | |
| # Calculate distance matrix between group tags and tag candidates | |
| # of the i-th keypoint | |
| # Shape: (M', 1, L) , (1, G, L) -> (M', G, L) | |
| diff = tags_i[:, None] - np.array(group_tags)[None] | |
| dists = np.linalg.norm(diff, ord=2, axis=2) | |
| num_kpts, num_groups = dists.shape[:2] | |
| # Experimental cost function for keypoint-group matching | |
| costs = np.round(dists) * 100 - vals_i[..., None] | |
| if num_kpts > num_groups: | |
| padding = np.full((num_kpts, num_kpts - num_groups), | |
| 1e10, | |
| dtype=np.float32) | |
| costs = np.concatenate((costs, padding), axis=1) | |
| # Match keypoints and groups by Munkres algorithm | |
| matches = munkres.compute(costs) | |
| for kpt_idx, group_idx in matches: | |
| if group_idx < num_groups and dists[kpt_idx, | |
| group_idx] < tag_thr: | |
| # Add the keypoint to the matched group | |
| group = groups[group_idx] | |
| else: | |
| # Initialize a new group with unmatched keypoint | |
| group = _init_group() | |
| groups.append(group) | |
| group.kpts[i] = locs_i[kpt_idx] | |
| group.scores[i] = vals_i[kpt_idx] | |
| group.tag_list.append(tags_i[kpt_idx]) | |
| groups = groups[:max_groups] | |
| if groups: | |
| grouped_keypoints = np.stack( | |
| [np.r_['1', g.kpts, g.scores[:, None]] for g in groups]) | |
| else: | |
| grouped_keypoints = np.empty((0, K, D + 1)) | |
| return grouped_keypoints | |
| class AssociativeEmbedding(BaseKeypointCodec): | |
| """Encode/decode keypoints with the method introduced in "Associative | |
| Embedding". This is an asymmetric codec, where the keypoints are | |
| represented as gaussian heatmaps and position indices during encoding, and | |
| restored from predicted heatmaps and group tags. | |
| See the paper `Associative Embedding: End-to-End Learning for Joint | |
| Detection and Grouping`_ by Newell et al (2017) for details | |
| Note: | |
| - instance number: N | |
| - keypoint number: K | |
| - keypoint dimension: D | |
| - embedding tag dimension: L | |
| - image size: [w, h] | |
| - heatmap size: [W, H] | |
| Encoded: | |
| - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W) | |
| where [W, H] is the `heatmap_size` | |
| - keypoint_indices (np.ndarray): The keypoint position indices in shape | |
| (N, K, 2). Each keypoint's index is [i, v], where i is the position | |
| index in the heatmap (:math:`i=y*w+x`) and v is the visibility | |
| - keypoint_weights (np.ndarray): The target weights in shape (N, K) | |
| Args: | |
| input_size (tuple): Image size in [w, h] | |
| heatmap_size (tuple): Heatmap size in [W, H] | |
| sigma (float): The sigma value of the Gaussian heatmap | |
| use_udp (bool): Whether use unbiased data processing. See | |
| `UDP (CVPR 2020)`_ for details. Defaults to ``False`` | |
| decode_keypoint_order (List[int]): The grouping order of the | |
| keypoint indices. The groupping usually starts from a keypoints | |
| around the head and torso, and gruadually moves out to the limbs | |
| decode_keypoint_thr (float): The threshold of keypoint response value | |
| in heatmaps. Defaults to 0.1 | |
| decode_tag_thr (float): The maximum allowed tag distance when matching | |
| a keypoint to a group. A keypoint with larger tag distance to any | |
| of the existing groups will initializes a new group. Defaults to | |
| 1.0 | |
| decode_nms_kernel (int): The kernel size of the NMS during decoding, | |
| which should be an odd integer. Defaults to 5 | |
| decode_gaussian_kernel (int): The kernel size of the Gaussian blur | |
| during decoding, which should be an odd integer. It is only used | |
| when ``self.use_udp==True``. Defaults to 3 | |
| decode_topk (int): The number top-k candidates of each keypoints that | |
| will be retrieved from the heatmaps during dedocding. Defaults to | |
| 20 | |
| decode_max_instances (int, optional): The maximum number of instances | |
| to decode. ``None`` means no limitation to the instance number. | |
| Defaults to ``None`` | |
| .. _`Associative Embedding: End-to-End Learning for Joint Detection and | |
| Grouping`: https://arxiv.org/abs/1611.05424 | |
| .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524 | |
| """ | |
| def __init__( | |
| self, | |
| input_size: Tuple[int, int], | |
| heatmap_size: Tuple[int, int], | |
| sigma: Optional[float] = None, | |
| use_udp: bool = False, | |
| decode_keypoint_order: List[int] = [], | |
| decode_nms_kernel: int = 5, | |
| decode_gaussian_kernel: int = 3, | |
| decode_keypoint_thr: float = 0.1, | |
| decode_tag_thr: float = 1.0, | |
| decode_topk: int = 20, | |
| decode_max_instances: Optional[int] = None, | |
| ) -> None: | |
| super().__init__() | |
| self.input_size = input_size | |
| self.heatmap_size = heatmap_size | |
| self.use_udp = use_udp | |
| self.decode_nms_kernel = decode_nms_kernel | |
| self.decode_gaussian_kernel = decode_gaussian_kernel | |
| self.decode_keypoint_thr = decode_keypoint_thr | |
| self.decode_tag_thr = decode_tag_thr | |
| self.decode_topk = decode_topk | |
| self.decode_max_instances = decode_max_instances | |
| self.dedecode_keypoint_order = decode_keypoint_order.copy() | |
| if self.use_udp: | |
| self.scale_factor = ((np.array(input_size) - 1) / | |
| (np.array(heatmap_size) - 1)).astype( | |
| np.float32) | |
| else: | |
| self.scale_factor = (np.array(input_size) / | |
| heatmap_size).astype(np.float32) | |
| if sigma is None: | |
| sigma = (heatmap_size[0] * heatmap_size[1])**0.5 / 64 | |
| self.sigma = sigma | |
| def encode( | |
| self, | |
| keypoints: np.ndarray, | |
| keypoints_visible: Optional[np.ndarray] = None | |
| ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | |
| """Encode keypoints into heatmaps and position indices. Note that the | |
| original keypoint coordinates should be in the input image space. | |
| Args: | |
| keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) | |
| keypoints_visible (np.ndarray): Keypoint visibilities in shape | |
| (N, K) | |
| Returns: | |
| dict: | |
| - heatmaps (np.ndarray): The generated heatmap in shape | |
| (K, H, W) where [W, H] is the `heatmap_size` | |
| - keypoint_indices (np.ndarray): The keypoint position indices | |
| in shape (N, K, 2). Each keypoint's index is [i, v], where i | |
| is the position index in the heatmap (:math:`i=y*w+x`) and v | |
| is the visibility | |
| - keypoint_weights (np.ndarray): The target weights in shape | |
| (N, K) | |
| """ | |
| if keypoints_visible is None: | |
| keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) | |
| # keypoint coordinates in heatmap | |
| _keypoints = keypoints / self.scale_factor | |
| if self.use_udp: | |
| heatmaps, keypoint_weights = generate_udp_gaussian_heatmaps( | |
| heatmap_size=self.heatmap_size, | |
| keypoints=_keypoints, | |
| keypoints_visible=keypoints_visible, | |
| sigma=self.sigma) | |
| else: | |
| heatmaps, keypoint_weights = generate_gaussian_heatmaps( | |
| heatmap_size=self.heatmap_size, | |
| keypoints=_keypoints, | |
| keypoints_visible=keypoints_visible, | |
| sigma=self.sigma) | |
| keypoint_indices = self._encode_keypoint_indices( | |
| heatmap_size=self.heatmap_size, | |
| keypoints=_keypoints, | |
| keypoints_visible=keypoints_visible) | |
| encoded = dict( | |
| heatmaps=heatmaps, | |
| keypoint_indices=keypoint_indices, | |
| keypoint_weights=keypoint_weights) | |
| return encoded | |
| def _encode_keypoint_indices(self, heatmap_size: Tuple[int, int], | |
| keypoints: np.ndarray, | |
| keypoints_visible: np.ndarray) -> np.ndarray: | |
| w, h = heatmap_size | |
| N, K, _ = keypoints.shape | |
| keypoint_indices = np.zeros((N, K, 2), dtype=np.int64) | |
| for n, k in product(range(N), range(K)): | |
| x, y = (keypoints[n, k] + 0.5).astype(np.int64) | |
| index = y * w + x | |
| vis = (keypoints_visible[n, k] > 0.5 and 0 <= x < w and 0 <= y < h) | |
| keypoint_indices[n, k] = [index, vis] | |
| return keypoint_indices | |
| def decode(self, encoded: Any) -> Tuple[np.ndarray, np.ndarray]: | |
| raise NotImplementedError() | |
| def _get_batch_topk(self, batch_heatmaps: Tensor, batch_tags: Tensor, | |
| k: int): | |
| """Get top-k response values from the heatmaps and corresponding tag | |
| values from the tagging heatmaps. | |
| Args: | |
| batch_heatmaps (Tensor): Keypoint detection heatmaps in shape | |
| (B, K, H, W) | |
| batch_tags (Tensor): Tagging heatmaps in shape (B, C, H, W), where | |
| the tag dim C is 2*K when using flip testing, or K otherwise | |
| k (int): The number of top responses to get | |
| Returns: | |
| tuple: | |
| - topk_vals (Tensor): Top-k response values of each heatmap in | |
| shape (B, K, Topk) | |
| - topk_tags (Tensor): The corresponding embedding tags of the | |
| top-k responses, in shape (B, K, Topk, L) | |
| - topk_locs (Tensor): The location of the top-k responses in each | |
| heatmap, in shape (B, K, Topk, 2) where last dimension | |
| represents x and y coordinates | |
| """ | |
| B, K, H, W = batch_heatmaps.shape | |
| L = batch_tags.shape[1] // K | |
| # shape of topk_val, top_indices: (B, K, TopK) | |
| topk_vals, topk_indices = batch_heatmaps.flatten(-2, -1).topk( | |
| k, dim=-1) | |
| topk_tags_per_kpts = [ | |
| torch.gather(_tag, dim=2, index=topk_indices) | |
| for _tag in torch.unbind(batch_tags.view(B, L, K, H * W), dim=1) | |
| ] | |
| topk_tags = torch.stack(topk_tags_per_kpts, dim=-1) # (B, K, TopK, L) | |
| topk_locs = torch.stack([topk_indices % W, topk_indices // W], | |
| dim=-1) # (B, K, TopK, 2) | |
| return topk_vals, topk_tags, topk_locs | |
| def _group_keypoints(self, batch_vals: np.ndarray, batch_tags: np.ndarray, | |
| batch_locs: np.ndarray): | |
| """Group keypoints into groups (each represents an instance) by tags. | |
| Args: | |
| batch_vals (Tensor): Heatmap response values of keypoint | |
| candidates in shape (B, K, Topk) | |
| batch_tags (Tensor): Tags of keypoint candidates in shape | |
| (B, K, Topk, L) | |
| batch_locs (Tensor): Locations of keypoint candidates in shape | |
| (B, K, Topk, 2) | |
| Returns: | |
| List[np.ndarray]: Grouping results of a batch, each element is a | |
| np.ndarray (in shape [N, K, D+1]) that contains the groups | |
| detected in an image, including both keypoint coordinates and | |
| scores. | |
| """ | |
| def _group_func(inputs: Tuple): | |
| vals, tags, locs = inputs | |
| return _group_keypoints_by_tags( | |
| vals, | |
| tags, | |
| locs, | |
| keypoint_order=self.dedecode_keypoint_order, | |
| val_thr=self.decode_keypoint_thr, | |
| tag_thr=self.decode_tag_thr, | |
| max_groups=self.decode_max_instances) | |
| _results = map(_group_func, zip(batch_vals, batch_tags, batch_locs)) | |
| results = list(_results) | |
| return results | |
| def _fill_missing_keypoints(self, keypoints: np.ndarray, | |
| keypoint_scores: np.ndarray, | |
| heatmaps: np.ndarray, tags: np.ndarray): | |
| """Fill the missing keypoints in the initial predictions. | |
| Args: | |
| keypoints (np.ndarray): Keypoint predictions in shape (N, K, D) | |
| keypoint_scores (np.ndarray): Keypint score predictions in shape | |
| (N, K), in which 0 means the corresponding keypoint is | |
| missing in the initial prediction | |
| heatmaps (np.ndarry): Heatmaps in shape (K, H, W) | |
| tags (np.ndarray): Tagging heatmaps in shape (C, H, W) where | |
| C=L*K | |
| Returns: | |
| tuple: | |
| - keypoints (np.ndarray): Keypoint predictions with missing | |
| ones filled | |
| - keypoint_scores (np.ndarray): Keypoint score predictions with | |
| missing ones filled | |
| """ | |
| N, K = keypoints.shape[:2] | |
| H, W = heatmaps.shape[1:] | |
| L = tags.shape[0] // K | |
| keypoint_tags = [tags[k::K] for k in range(K)] | |
| for n in range(N): | |
| # Calculate the instance tag (mean tag of detected keypoints) | |
| _tag = [] | |
| for k in range(K): | |
| if keypoint_scores[n, k] > 0: | |
| x, y = keypoints[n, k, :2].astype(np.int64) | |
| x = np.clip(x, 0, W - 1) | |
| y = np.clip(y, 0, H - 1) | |
| _tag.append(keypoint_tags[k][:, y, x]) | |
| tag = np.mean(_tag, axis=0) | |
| tag = tag.reshape(L, 1, 1) | |
| # Search maximum response of the missing keypoints | |
| for k in range(K): | |
| if keypoint_scores[n, k] > 0: | |
| continue | |
| dist_map = np.linalg.norm( | |
| keypoint_tags[k] - tag, ord=2, axis=0) | |
| cost_map = np.round(dist_map) * 100 - heatmaps[k] # H, W | |
| y, x = np.unravel_index(np.argmin(cost_map), shape=(H, W)) | |
| keypoints[n, k] = [x, y] | |
| keypoint_scores[n, k] = heatmaps[k, y, x] | |
| return keypoints, keypoint_scores | |
| def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor | |
| ) -> Tuple[List[np.ndarray], List[np.ndarray]]: | |
| """Decode the keypoint coordinates from a batch of heatmaps and tagging | |
| heatmaps. The decoded keypoint coordinates are in the input image | |
| space. | |
| Args: | |
| batch_heatmaps (Tensor): Keypoint detection heatmaps in shape | |
| (B, K, H, W) | |
| batch_tags (Tensor): Tagging heatmaps in shape (B, C, H, W), where | |
| :math:`C=L*K` | |
| Returns: | |
| tuple: | |
| - batch_keypoints (List[np.ndarray]): Decoded keypoint coordinates | |
| of the batch, each is in shape (N, K, D) | |
| - batch_scores (List[np.ndarray]): Decoded keypoint scores of the | |
| batch, each is in shape (N, K). It usually represents the | |
| confidience of the keypoint prediction | |
| """ | |
| B, _, H, W = batch_heatmaps.shape | |
| assert batch_tags.shape[0] == B and batch_tags.shape[2:4] == (H, W), ( | |
| f'Mismatched shapes of heatmap ({batch_heatmaps.shape}) and ' | |
| f'tagging map ({batch_tags.shape})') | |
| # Heatmap NMS | |
| batch_heatmaps = batch_heatmap_nms(batch_heatmaps, | |
| self.decode_nms_kernel) | |
| # Get top-k in each heatmap and and convert to numpy | |
| batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy( | |
| self._get_batch_topk( | |
| batch_heatmaps, batch_tags, k=self.decode_topk)) | |
| # Group keypoint candidates into groups (instances) | |
| batch_groups = self._group_keypoints(batch_topk_vals, batch_topk_tags, | |
| batch_topk_locs) | |
| # Convert to numpy | |
| batch_heatmaps_np = to_numpy(batch_heatmaps) | |
| batch_tags_np = to_numpy(batch_tags) | |
| # Refine the keypoint prediction | |
| batch_keypoints = [] | |
| batch_keypoint_scores = [] | |
| for i, (groups, heatmaps, tags) in enumerate( | |
| zip(batch_groups, batch_heatmaps_np, batch_tags_np)): | |
| keypoints, scores = groups[..., :-1], groups[..., -1] | |
| if keypoints.size > 0: | |
| # identify missing keypoints | |
| keypoints, scores = self._fill_missing_keypoints( | |
| keypoints, scores, heatmaps, tags) | |
| # refine keypoint coordinates according to heatmap distribution | |
| if self.use_udp: | |
| keypoints = refine_keypoints_dark_udp( | |
| keypoints, | |
| heatmaps, | |
| blur_kernel_size=self.decode_gaussian_kernel) | |
| else: | |
| keypoints = refine_keypoints(keypoints, heatmaps) | |
| batch_keypoints.append(keypoints) | |
| batch_keypoint_scores.append(scores) | |
| # restore keypoint scale | |
| batch_keypoints = [ | |
| kpts * self.scale_factor for kpts in batch_keypoints | |
| ] | |
| return batch_keypoints, batch_keypoint_scores | |