Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import logging | |
| from argparse import ArgumentParser | |
| from mmcv.image import imread | |
| from mmengine.logging import print_log | |
| from mmpose.apis import inference_topdown, init_model | |
| from mmpose.registry import VISUALIZERS | |
| from mmpose.structures import merge_data_samples | |
| def parse_args(): | |
| parser = ArgumentParser() | |
| parser.add_argument('img', help='Image file') | |
| parser.add_argument('config', help='Config file') | |
| parser.add_argument('checkpoint', help='Checkpoint file') | |
| parser.add_argument('--out-file', default=None, help='Path to output file') | |
| parser.add_argument( | |
| '--device', default='cuda:0', help='Device used for inference') | |
| parser.add_argument( | |
| '--draw-heatmap', | |
| action='store_true', | |
| help='Visualize the predicted heatmap') | |
| parser.add_argument( | |
| '--show-kpt-idx', | |
| action='store_true', | |
| default=False, | |
| help='Whether to show the index of keypoints') | |
| parser.add_argument( | |
| '--skeleton-style', | |
| default='mmpose', | |
| type=str, | |
| choices=['mmpose', 'openpose'], | |
| help='Skeleton style selection') | |
| parser.add_argument( | |
| '--kpt-thr', | |
| type=float, | |
| default=0.3, | |
| help='Visualizing keypoint thresholds') | |
| parser.add_argument( | |
| '--radius', | |
| type=int, | |
| default=3, | |
| help='Keypoint radius for visualization') | |
| parser.add_argument( | |
| '--thickness', | |
| type=int, | |
| default=1, | |
| help='Link thickness for visualization') | |
| parser.add_argument( | |
| '--alpha', type=float, default=0.8, help='The transparency of bboxes') | |
| parser.add_argument( | |
| '--show', | |
| action='store_true', | |
| default=False, | |
| help='whether to show img') | |
| args = parser.parse_args() | |
| return args | |
| def main(): | |
| args = parse_args() | |
| # build the model from a config file and a checkpoint file | |
| if args.draw_heatmap: | |
| cfg_options = dict(model=dict(test_cfg=dict(output_heatmaps=True))) | |
| else: | |
| cfg_options = None | |
| model = init_model( | |
| args.config, | |
| args.checkpoint, | |
| device=args.device, | |
| cfg_options=cfg_options) | |
| # init visualizer | |
| model.cfg.visualizer.radius = args.radius | |
| model.cfg.visualizer.alpha = args.alpha | |
| model.cfg.visualizer.line_width = args.thickness | |
| visualizer = VISUALIZERS.build(model.cfg.visualizer) | |
| visualizer.set_dataset_meta( | |
| model.dataset_meta, skeleton_style=args.skeleton_style) | |
| # inference a single image | |
| batch_results = inference_topdown(model, args.img) | |
| results = merge_data_samples(batch_results) | |
| # show the results | |
| img = imread(args.img, channel_order='rgb') | |
| visualizer.add_datasample( | |
| 'result', | |
| img, | |
| data_sample=results, | |
| draw_gt=False, | |
| draw_bbox=True, | |
| kpt_thr=args.kpt_thr, | |
| draw_heatmap=args.draw_heatmap, | |
| show_kpt_idx=args.show_kpt_idx, | |
| skeleton_style=args.skeleton_style, | |
| show=args.show, | |
| out_file=args.out_file) | |
| if args.out_file is not None: | |
| print_log( | |
| f'the output image has been saved at {args.out_file}', | |
| logger='current', | |
| level=logging.INFO) | |
| if __name__ == '__main__': | |
| main() | |