# Copyright (c) OpenMMLab. All rights reserved. """Perform MMYOLO inference on large images (as satellite imagery) as: ```shell wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth # noqa: E501, E261. python demo/large_image_demo.py \ demo/large_image.jpg \ configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \ checkpoint/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth ``` """ import os import random from argparse import ArgumentParser from pathlib import Path import mmcv import numpy as np from mmdet.apis import inference_detector, init_detector from mmengine.config import Config, ConfigDict from mmengine.logging import print_log from mmengine.utils import ProgressBar try: from sahi.slicing import slice_image except ImportError: raise ImportError('Please run "pip install -U sahi" ' 'to install sahi first for large image inference.') from mmyolo.registry import VISUALIZERS from mmyolo.utils import switch_to_deploy from mmyolo.utils.large_image import merge_results_by_nms, shift_predictions from mmyolo.utils.misc import get_file_list def parse_args(): parser = ArgumentParser( description='Perform MMYOLO inference on large images.') parser.add_argument( 'img', help='Image path, include image file, dir and URL.') parser.add_argument('config', help='Config file') parser.add_argument('checkpoint', help='Checkpoint file') parser.add_argument( '--out-dir', default='./output', help='Path to output file') parser.add_argument( '--device', default='cuda:0', help='Device used for inference') parser.add_argument( '--show', action='store_true', help='Show the detection results') parser.add_argument( '--deploy', action='store_true', help='Switch model to deployment mode') parser.add_argument( '--tta', action='store_true', help='Whether to use test time augmentation') parser.add_argument( '--score-thr', type=float, default=0.3, help='Bbox score threshold') parser.add_argument( '--patch-size', type=int, default=640, help='The size of patches') parser.add_argument( '--patch-overlap-ratio', type=float, default=0.25, help='Ratio of overlap between two patches') parser.add_argument( '--merge-iou-thr', type=float, default=0.25, help='IoU threshould for merging results') parser.add_argument( '--merge-nms-type', type=str, default='nms', help='NMS type for merging results') parser.add_argument( '--batch-size', type=int, default=1, help='Batch size, must greater than or equal to 1') parser.add_argument( '--debug', action='store_true', help='Export debug results before merging') parser.add_argument( '--save-patch', action='store_true', help='Save the results of each patch. ' 'The `--debug` must be enabled.') args = parser.parse_args() return args def main(): args = parse_args() config = args.config if isinstance(config, (str, Path)): config = Config.fromfile(config) elif not isinstance(config, Config): raise TypeError('config must be a filename or Config object, ' f'but got {type(config)}') if 'init_cfg' in config.model.backbone: config.model.backbone.init_cfg = None if args.tta: assert 'tta_model' in config, 'Cannot find ``tta_model`` in config.' \ " Can't use tta !" assert 'tta_pipeline' in config, 'Cannot find ``tta_pipeline`` ' \ "in config. Can't use tta !" config.model = ConfigDict(**config.tta_model, module=config.model) test_data_cfg = config.test_dataloader.dataset while 'dataset' in test_data_cfg: test_data_cfg = test_data_cfg['dataset'] # batch_shapes_cfg will force control the size of the output image, # it is not compatible with tta. if 'batch_shapes_cfg' in test_data_cfg: test_data_cfg.batch_shapes_cfg = None test_data_cfg.pipeline = config.tta_pipeline # TODO: TTA mode will error if cfg_options is not set. # This is an mmdet issue and needs to be fixed later. # build the model from a config file and a checkpoint file model = init_detector( config, args.checkpoint, device=args.device, cfg_options={}) if args.deploy: switch_to_deploy(model) if not os.path.exists(args.out_dir) and not args.show: os.mkdir(args.out_dir) # init visualizer visualizer = VISUALIZERS.build(model.cfg.visualizer) visualizer.dataset_meta = model.dataset_meta # get file list files, source_type = get_file_list(args.img) # start detector inference print(f'Performing inference on {len(files)} images.... ' 'This may take a while.') progress_bar = ProgressBar(len(files)) for file in files: # read image img = mmcv.imread(file) # arrange slices height, width = img.shape[:2] sliced_image_object = slice_image( img, slice_height=args.patch_size, slice_width=args.patch_size, auto_slice_resolution=False, overlap_height_ratio=args.patch_overlap_ratio, overlap_width_ratio=args.patch_overlap_ratio, ) # perform sliced inference slice_results = [] start = 0 while True: # prepare batch slices end = min(start + args.batch_size, len(sliced_image_object)) images = [] for sliced_image in sliced_image_object.images[start:end]: images.append(sliced_image) # forward the model slice_results.extend(inference_detector(model, images)) if end >= len(sliced_image_object): break start += args.batch_size if source_type['is_dir']: filename = os.path.relpath(file, args.img).replace('/', '_') else: filename = os.path.basename(file) img = mmcv.imconvert(img, 'bgr', 'rgb') out_file = None if args.show else os.path.join(args.out_dir, filename) # export debug images if args.debug: # export sliced image results name, suffix = os.path.splitext(filename) shifted_instances = shift_predictions( slice_results, sliced_image_object.starting_pixels, src_image_shape=(height, width)) merged_result = slice_results[0].clone() merged_result.pred_instances = shifted_instances debug_file_name = name + '_debug' + suffix debug_out_file = None if args.show else os.path.join( args.out_dir, debug_file_name) visualizer.set_image(img.copy()) debug_grids = [] for starting_point in sliced_image_object.starting_pixels: start_point_x = starting_point[0] start_point_y = starting_point[1] end_point_x = start_point_x + args.patch_size end_point_y = start_point_y + args.patch_size debug_grids.append( [start_point_x, start_point_y, end_point_x, end_point_y]) debug_grids = np.array(debug_grids) debug_grids[:, 0::2] = np.clip(debug_grids[:, 0::2], 1, img.shape[1] - 1) debug_grids[:, 1::2] = np.clip(debug_grids[:, 1::2], 1, img.shape[0] - 1) palette = np.random.randint(0, 256, size=(len(debug_grids), 3)) palette = [tuple(c) for c in palette] line_styles = random.choices(['-', '-.', ':'], k=len(debug_grids)) visualizer.draw_bboxes( debug_grids, edge_colors=palette, alpha=1, line_styles=line_styles) visualizer.draw_bboxes( debug_grids, face_colors=palette, alpha=0.15) visualizer.draw_texts( list(range(len(debug_grids))), debug_grids[:, :2] + 5, colors='w') visualizer.add_datasample( debug_file_name, visualizer.get_image(), data_sample=merged_result, draw_gt=False, show=args.show, wait_time=0, out_file=debug_out_file, pred_score_thr=args.score_thr, ) if args.save_patch: debug_patch_out_dir = os.path.join(args.out_dir, f'{name}_patch') for i, slice_result in enumerate(slice_results): patch_out_file = os.path.join( debug_patch_out_dir, f'{filename}_slice_{i}_result.jpg') image = mmcv.imconvert(sliced_image_object.images[i], 'bgr', 'rgb') visualizer.add_datasample( 'patch_result', image, data_sample=slice_result, draw_gt=False, show=False, wait_time=0, out_file=patch_out_file, pred_score_thr=args.score_thr, ) image_result = merge_results_by_nms( slice_results, sliced_image_object.starting_pixels, src_image_shape=(height, width), nms_cfg={ 'type': args.merge_nms_type, 'iou_threshold': args.merge_iou_thr }) visualizer.add_datasample( filename, img, data_sample=image_result, draw_gt=False, show=args.show, wait_time=0, out_file=out_file, pred_score_thr=args.score_thr, ) progress_bar.update() if not args.show or (args.debug and args.save_patch): print_log( f'\nResults have been saved at {os.path.abspath(args.out_dir)}') if __name__ == '__main__': main()