# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
from typing import Sequence

import mmcv
from mmdet.apis import inference_detector, init_detector
from mmengine import Config, DictAction
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar

from mmyolo.registry import VISUALIZERS
from mmyolo.utils.misc import auto_arrange_images, get_file_list


def parse_args():
    parser = argparse.ArgumentParser(description='Visualize feature map')
    parser.add_argument(
        'img', help='Image path, including an image file, a directory or a URL.')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--out-dir', default='./output', help='Path to the output directory')
    parser.add_argument(
        '--target-layers',
        default=['backbone'],
        nargs='+',
        type=str,
        help='The target layers to get the feature map from; if not set, '
        'the tool will use the backbone')
    parser.add_argument(
        '--preview-model',
        default=False,
        action='store_true',
        help='To preview all the model layers')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument(
        '--show', action='store_true', help='Show the featmap results')
    parser.add_argument(
        '--channel-reduction',
        default='select_max',
        help='Reduce multiple channels to a single channel')
    parser.add_argument(
        '--topk',
        type=int,
        default=4,
        help='Select the top-k channels to show, ranked by the sum of '
        'each channel')
    parser.add_argument(
        '--arrangement',
        nargs='+',
        type=int,
        default=[2, 2],
        help='The arrangement of the feature maps when channel_reduction is '
        'None and topk > 0')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='Override some settings in the used config; the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    args = parser.parse_args()
    return args


class ActivationsWrapper:
    """Collect the raw outputs of the given layers during inference."""

    def __init__(self, model, target_layers):
        self.model = model
        self.activations = []
        self.handles = []
        self.image = None
        # Register a forward hook on every requested layer so its output is
        # captured while `inference_detector` runs.
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(self.save_activation))

    def save_activation(self, module, input, output):
        self.activations.append(output)

    def __call__(self, img_path):
        self.activations = []
        results = inference_detector(self.model, img_path)
        return results, self.activations

    def release(self):
        for handle in self.handles:
            handle.remove()


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    init_default_scope(cfg.get('default_scope', 'mmyolo'))

    channel_reduction = args.channel_reduction
    if channel_reduction == 'None':
        channel_reduction = None
    assert len(args.arrangement) == 2

    model = init_detector(args.config, args.checkpoint, device=args.device)

    if not os.path.exists(args.out_dir) and not args.show:
        os.mkdir(args.out_dir)

    if args.preview_model:
        print(model)
        print('\nThis flag only shows the model; if you want to continue, '
              'please remove `--preview-model` to get the feature map.')
        return

    # Resolve the dotted layer names (e.g. 'backbone' or 'neck') to the
    # corresponding module objects.
    target_layers = []
    for target_layer in args.target_layers:
        try:
            target_layers.append(eval(f'model.{target_layer}'))
        except Exception as e:
            print(model)
            raise RuntimeError('layer does not exist', e)

    activations_wrapper = ActivationsWrapper(model, target_layers)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.dataset_meta = model.dataset_meta

    # get file list
    image_list, source_type = get_file_list(args.img)

    progress_bar = ProgressBar(len(image_list))
    for image_path in image_list:
        result, featmaps = activations_wrapper(image_path)
        if not isinstance(featmaps, Sequence):
            featmaps = [featmaps]

        # A hooked layer may return a tuple/list of feature maps
        # (e.g. multi-scale outputs), so flatten them into one list.
        flatten_featmaps = []
        for featmap in featmaps:
            if isinstance(featmap, Sequence):
                flatten_featmaps.extend(featmap)
            else:
                flatten_featmaps.append(featmap)

        img = mmcv.imread(image_path)
        img = mmcv.imconvert(img, 'bgr', 'rgb')

        if source_type['is_dir']:
            filename = os.path.relpath(image_path, args.img).replace('/', '_')
        else:
            filename = os.path.basename(image_path)
        out_file = None if args.show else os.path.join(args.out_dir, filename)

        # Draw the detection results first, then overlay each feature map
        # on top of the drawn image.
        shown_imgs = []
        visualizer.add_datasample(
            'result',
            img,
            data_sample=result,
            draw_gt=False,
            show=False,
            wait_time=0,
            out_file=None,
            pred_score_thr=args.score_thr)
        drawn_img = visualizer.get_image()

        for featmap in flatten_featmaps:
            shown_img = visualizer.draw_featmap(
                featmap[0],
                drawn_img,
                channel_reduction=channel_reduction,
                topk=args.topk,
                arrangement=args.arrangement)
            shown_imgs.append(shown_img)

        # Tile all feature-map visualizations into a single image.
        shown_imgs = auto_arrange_images(shown_imgs)

        progress_bar.update()
        if out_file:
            # Convert RGB back to BGR before writing to disk.
            mmcv.imwrite(shown_imgs[..., ::-1], out_file)

        if args.show:
            visualizer.show(shown_imgs)

    if not args.show:
        print(f'All done!'
              f'\nResults have been saved at {os.path.abspath(args.out_dir)}')


# Please refer to the usage tutorial:
# https://github.com/open-mmlab/mmyolo/blob/main/docs/zh_cn/user_guides/visualization.md # noqa
if __name__ == '__main__':
    main()