# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os.path as osp
from copy import deepcopy
from typing import Optional

import numpy as np
import torch
from mmengine.config import Config
from mmengine.dataset import pseudo_collate
from mmengine.structures import InstanceData, PixelData

from mmpose.structures import MultilevelPixelData, PoseDataSample
from mmpose.structures.bbox import bbox_xyxy2cs


def get_coco_sample(
        img_shape=(240, 320),
        img_fill: Optional[int] = None,
        num_instances=1,
        with_bbox_cs=True,
        with_img_mask=False,
        random_keypoints_visible=False,
        non_occlusion=False):
    """Create a dummy data sample in COCO style."""
    rng = np.random.RandomState(0)
    h, w = img_shape
    if img_fill is None:
        img = np.random.randint(0, 256, (h, w, 3), dtype=np.uint8)
    else:
        img = np.full((h, w, 3), img_fill, dtype=np.uint8)

    if non_occlusion:
        bbox = _rand_bboxes(rng, num_instances, w / num_instances, h)
        for i in range(num_instances):
            bbox[i, 0::2] += w / num_instances * i
    else:
        bbox = _rand_bboxes(rng, num_instances, w, h)

    keypoints = _rand_keypoints(rng, bbox, 17)
    if random_keypoints_visible:
        keypoints_visible = np.random.randint(
            0, 2, (num_instances, 17)).astype(np.float32)
    else:
        keypoints_visible = np.full((num_instances, 17), 1, dtype=np.float32)

    upper_body_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    lower_body_ids = [11, 12, 13, 14, 15, 16]
    flip_pairs = [[2, 1], [1, 2], [4, 3], [3, 4], [6, 5], [5, 6], [8, 7],
                  [7, 8], [10, 9], [9, 10], [12, 11], [11, 12], [14, 13],
                  [13, 14], [16, 15], [15, 16]]
    flip_indices = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
    dataset_keypoint_weights = np.array([
        1., 1., 1., 1., 1., 1., 1., 1.2, 1.2, 1.5, 1.5, 1., 1., 1.2, 1.2, 1.5,
        1.5
    ]).astype(np.float32)

    data = {
        'img': img,
        'img_shape': img_shape,
        'ori_shape': img_shape,
        'bbox': bbox,
        'keypoints': keypoints,
        'keypoints_visible': keypoints_visible,
        'upper_body_ids': upper_body_ids,
        'lower_body_ids': lower_body_ids,
        'flip_pairs': flip_pairs,
        'flip_indices': flip_indices,
        'dataset_keypoint_weights': dataset_keypoint_weights,
        'invalid_segs': [],
    }

    if with_bbox_cs:
        data['bbox_center'], data['bbox_scale'] = bbox_xyxy2cs(data['bbox'])

    if with_img_mask:
        data['img_mask'] = np.random.randint(0, 2, (h, w), dtype=np.uint8)

    return data
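
# A minimal usage sketch (an assumption about how this helper is typically
# consumed, not part of the original module): the returned dict mirrors the
# fields a COCO-style dataset pipeline produces, so unit tests can inspect it
# directly or feed it into data transforms.
#
#   sample = get_coco_sample(img_shape=(240, 320), num_instances=2)
#   assert sample['img'].shape == (240, 320, 3)
#   assert sample['bbox'].shape == (2, 4)
#   assert sample['keypoints'].shape == (2, 17, 2)
#   assert sample['bbox_center'].shape == (2, 2)  # from bbox_xyxy2cs
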
def get_packed_inputs(batch_size=2,
                      num_instances=1,
                      num_keypoints=17,
                      num_levels=1,
                      img_shape=(256, 192),
                      input_size=(192, 256),
                      heatmap_size=(48, 64),
                      simcc_split_ratio=2.0,
                      with_heatmap=True,
                      with_reg_label=True,
                      with_simcc_label=True):
    """Create a dummy batch of model inputs and data samples."""
    rng = np.random.RandomState(0)

    inputs_list = []
    for idx in range(batch_size):
        inputs = dict()

        # input
        h, w = img_shape
        image = rng.randint(0, 255, size=(3, h, w), dtype=np.uint8)
        inputs['inputs'] = torch.from_numpy(image)

        # meta
        img_meta = {
            'id': idx,
            'img_id': idx,
            'img_path': '<demo>.png',
            'img_shape': img_shape,
            'input_size': input_size,
            'flip': False,
            'flip_direction': None,
            'flip_indices': list(range(num_keypoints))
        }
        np.random.shuffle(img_meta['flip_indices'])
        data_sample = PoseDataSample(metainfo=img_meta)

        # gt_instance
        gt_instances = InstanceData()
        gt_instance_labels = InstanceData()

        bboxes = _rand_bboxes(rng, num_instances, w, h)
        bbox_centers, bbox_scales = bbox_xyxy2cs(bboxes)

        keypoints = _rand_keypoints(rng, bboxes, num_keypoints)
        keypoints_visible = np.ones((num_instances, num_keypoints),
                                    dtype=np.float32)

        # [N, K] -> [N, num_levels, K]
        # keep the first dimension as the num_instances
        if num_levels > 1:
            keypoint_weights = np.tile(keypoints_visible[:, None],
                                       (1, num_levels, 1))
        else:
            keypoint_weights = keypoints_visible.copy()

        gt_instances.bboxes = bboxes
        gt_instances.bbox_centers = bbox_centers
        gt_instances.bbox_scales = bbox_scales
        gt_instances.bbox_scores = np.ones((num_instances, ),
                                           dtype=np.float32)
        gt_instances.keypoints = keypoints
        gt_instances.keypoints_visible = keypoints_visible

        gt_instance_labels.keypoint_weights = torch.FloatTensor(
            keypoint_weights)

        if with_reg_label:
            gt_instance_labels.keypoint_labels = torch.FloatTensor(
                keypoints / input_size)

        if with_simcc_label:
            len_x = np.around(input_size[0] * simcc_split_ratio)
            len_y = np.around(input_size[1] * simcc_split_ratio)
            gt_instance_labels.keypoint_x_labels = torch.FloatTensor(
                _rand_simcc_label(rng, num_instances, num_keypoints, len_x))
            gt_instance_labels.keypoint_y_labels = torch.FloatTensor(
                _rand_simcc_label(rng, num_instances, num_keypoints, len_y))

        # gt_fields
        if with_heatmap:
            if num_levels == 1:
                gt_fields = PixelData()
                # generate single-level heatmaps
                W, H = heatmap_size
                heatmaps = rng.rand(num_keypoints, H, W)
                gt_fields.heatmaps = torch.FloatTensor(heatmaps)
            else:
                # generate multilevel heatmaps
                heatmaps = []
                for _ in range(num_levels):
                    W, H = heatmap_size
                    heatmaps_ = rng.rand(num_keypoints, H, W)
                    heatmaps.append(torch.FloatTensor(heatmaps_))
                # [num_levels*K, H, W]
                gt_fields = MultilevelPixelData()
                gt_fields.heatmaps = heatmaps

            data_sample.gt_fields = gt_fields

        data_sample.gt_instances = gt_instances
        data_sample.gt_instance_labels = gt_instance_labels

        inputs['data_samples'] = data_sample
        inputs_list.append(inputs)

    packed_inputs = pseudo_collate(inputs_list)
    return packed_inputs
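
# A hedged usage sketch (not part of the original module): `pseudo_collate`
# keeps the batch as lists rather than stacking tensors, so the packed dict
# holds `batch_size` image tensors and `PoseDataSample` objects that a model's
# data preprocessor can stack and normalize in unit tests.
#
#   packed = get_packed_inputs(batch_size=2, num_instances=1)
#   assert len(packed['inputs']) == 2        # list of (3, H, W) uint8 tensors
#   assert len(packed['data_samples']) == 2  # list of PoseDataSample
#   assert packed['data_samples'][0].gt_fields.heatmaps.shape == (17, 64, 48)
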
def _rand_keypoints(rng, bboxes, num_keypoints):
    n = bboxes.shape[0]
    relative_pos = rng.rand(n, num_keypoints, 2)
    keypoints = relative_pos * bboxes[:, None, :2] + (
        1 - relative_pos) * bboxes[:, None, 2:4]

    return keypoints


def _rand_simcc_label(rng, num_instances, num_keypoints, len_feats):
    simcc_label = rng.rand(num_instances, num_keypoints, int(len_feats))
    return simcc_label


def _rand_bboxes(rng, num_instances, img_w, img_h):
    cx, cy = rng.rand(num_instances, 2).T
    bw, bh = 0.2 + 0.8 * rng.rand(num_instances, 2).T

    tl_x = ((cx * img_w) - (img_w * bw / 2)).clip(0, img_w)
    tl_y = ((cy * img_h) - (img_h * bh / 2)).clip(0, img_h)
    br_x = ((cx * img_w) + (img_w * bw / 2)).clip(0, img_w)
    br_y = ((cy * img_h) + (img_h * bh / 2)).clip(0, img_h)

    bboxes = np.vstack([tl_x, tl_y, br_x, br_y]).T
    return bboxes
def get_repo_dir():
    """Return the path of the MMPose repo directory."""
    try:
        # Assume the function is invoked from the source mmpose repo
        repo_dir = osp.dirname(osp.dirname(osp.dirname(__file__)))
    except NameError:
        # For IPython development when __file__ is not defined
        import mmpose
        repo_dir = osp.dirname(osp.dirname(mmpose.__file__))

    return repo_dir
def get_config_file(fn: str):
    """Return full path of a config file from the given relative path."""
    repo_dir = get_repo_dir()
    if fn.startswith('configs'):
        fn_config = osp.join(repo_dir, fn)
    else:
        fn_config = osp.join(repo_dir, 'configs', fn)

    if not osp.isfile(fn_config):
        raise FileNotFoundError(f'Cannot find config file {fn_config}')

    return fn_config


def get_pose_estimator_cfg(fn: str):
    """Load model config from a config file."""
    fn_config = get_config_file(fn)
    config = Config.fromfile(fn_config)
    return deepcopy(config.model)
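
# A hedged end-to-end sketch of how these helpers are typically combined in a
# unit test. The registry calls (`init_default_scope`, `MODELS.build`) and the
# example config path are assumptions based on MMPose 1.x and may need
# adjusting for other versions.
#
#   from mmengine.registry import init_default_scope
#   from mmpose.registry import MODELS
#
#   cfg = get_pose_estimator_cfg(
#       'body_2d_keypoint/topdown_heatmap/coco/'
#       'td-hm_hrnet-w32_8xb64-210e_coco-256x192.py')
#   init_default_scope('mmpose')
#   model = MODELS.build(cfg)
#
#   packed = get_packed_inputs(batch_size=2)
#   data = model.data_preprocessor(packed, training=True)
#   losses = model(**data, mode='loss')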