Spaces:
Running on Zero
Running on Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import cv2 | |
| import numpy as np | |
| def visualize_keypoints( | |
| image: np.ndarray, # RGB uint8 H,W,3 | |
| keypoints, # list[(J,2)] | |
| keypoints_visible, # list[(J,), {0/1}] | |
| keypoint_scores, # list[(J,)] | |
| *, | |
| radius: int = 4, | |
| thickness: int = -1, | |
| color=(255, 0, 0), | |
| kpt_thr: float = 0.3, | |
| skeleton: list | None = None, # [(i,j)] | |
| kpt_color: list | tuple | np.ndarray | None = None, | |
| link_color: list | tuple | np.ndarray | None = None, | |
| show_kpt_idx: bool = False, | |
| ) -> np.ndarray: | |
| img = image.copy() | |
| H, W = img.shape[:2] | |
| # defaults | |
| if skeleton is None: | |
| skeleton = [] # points only | |
| if kpt_color is None: | |
| kpt_color = color | |
| if link_color is None: | |
| link_color = (0, 255, 0) | |
| # robust color normalization: supports tuple, list-of-tuples, np.ndarray (N,3) or (3,) | |
| def _as_color_list(c, n): | |
| # torch -> numpy | |
| if hasattr(c, "detach"): | |
| c = c.detach().cpu().numpy() | |
| # numpy -> array | |
| if isinstance(c, np.ndarray): | |
| if c.ndim == 2 and c.shape[1] == 3: # (N,3) palette | |
| return [tuple(int(v) for v in row) for row in c.tolist()] | |
| if c.size == 3: # single (3,) | |
| return [tuple(int(v) for v in c.tolist())] * max(1, n) | |
| # python containers | |
| if isinstance(c, (list, tuple)): | |
| if n and len(c) == n and isinstance(c[0], (list, tuple, np.ndarray)): | |
| out = [] | |
| for cc in c: | |
| cc = np.asarray(cc).reshape(-1) | |
| assert cc.size == 3, "Each color must be length-3" | |
| out.append(tuple(int(v) for v in cc.tolist())) | |
| return out | |
| # single triplet | |
| c_arr = np.asarray(c).reshape(-1) | |
| if c_arr.size == 3: | |
| return [tuple(int(v) for v in c_arr.tolist())] * max(1, n) | |
| # fallback: red | |
| return [(255, 0, 0)] * max(1, n) | |
| J = keypoints[0].shape[0] if keypoints else 0 | |
| kpt_colors = _as_color_list(kpt_color, J) | |
| link_colors = _as_color_list(link_color, len(skeleton)) | |
| def in_bounds(x, y): | |
| return 0 <= x < W and 0 <= y < H | |
| for kpts, vis, score in zip(keypoints, keypoints_visible, keypoint_scores): | |
| kpts = np.asarray(kpts, float) | |
| vis = np.asarray(vis).reshape(-1).astype(bool) | |
| score = np.asarray(score).reshape(-1) | |
| # links (draw in RGB; NO channel flip) | |
| for lk, (i, j) in enumerate(skeleton): | |
| if i >= len(kpts) or j >= len(kpts): | |
| continue | |
| if not (vis[i] and vis[j]): | |
| continue | |
| if score[i] < kpt_thr or score[j] < kpt_thr: | |
| continue | |
| x1, y1 = map(int, np.round(kpts[i])) | |
| x2, y2 = map(int, np.round(kpts[j])) | |
| if not (in_bounds(x1, y1) and in_bounds(x2, y2)): | |
| continue | |
| cv2.line( | |
| img, | |
| (x1, y1), | |
| (x2, y2), | |
| link_colors[lk % len(link_colors)], | |
| thickness=max(1, thickness), | |
| lineType=cv2.LINE_AA, | |
| ) | |
| # points | |
| for j_idx, (xy, v, s) in enumerate(zip(kpts, vis, score)): | |
| if not v or s < kpt_thr: | |
| continue | |
| x, y = map(int, np.round(xy)) | |
| if not in_bounds(x, y): | |
| continue | |
| c = kpt_colors[min(j_idx, len(kpt_colors) - 1)] | |
| cv2.circle(img, (x, y), radius, c, thickness=-1, lineType=cv2.LINE_AA) | |
| if show_kpt_idx: | |
| cv2.putText( | |
| img, | |
| str(j_idx), | |
| (x + radius, y - radius), | |
| cv2.FONT_HERSHEY_SIMPLEX, | |
| 0.4, | |
| c, | |
| 1, | |
| cv2.LINE_AA, | |
| ) | |
| return img | |