YOLO-World3 / third_party /mmyolo /demo /large_image_demo.py
stevengrove
initial commit
186701e
raw
history blame
No virus
10.6 kB
# Copyright (c) OpenMMLab. All rights reserved.
"""Perform MMYOLO inference on large images (as satellite imagery) as:
```shell
wget -P checkpoint https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth # noqa: E501, E261.
python demo/large_image_demo.py \
demo/large_image.jpg \
configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py \
checkpoint/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth
```
"""
import os
import random
from argparse import ArgumentParser
from pathlib import Path
import mmcv
import numpy as np
from mmdet.apis import inference_detector, init_detector
from mmengine.config import Config, ConfigDict
from mmengine.logging import print_log
from mmengine.utils import ProgressBar
try:
from sahi.slicing import slice_image
except ImportError:
raise ImportError('Please run "pip install -U sahi" '
'to install sahi first for large image inference.')
from mmyolo.registry import VISUALIZERS
from mmyolo.utils import switch_to_deploy
from mmyolo.utils.large_image import merge_results_by_nms, shift_predictions
from mmyolo.utils.misc import get_file_list
def parse_args():
    """Parse the command-line arguments of the large-image inference demo.

    Three positional arguments are required (``img``, ``config`` and
    ``checkpoint``); everything else is optional and has a default.

    Returns:
        argparse.Namespace: The parsed arguments.

    Raises:
        SystemExit: Raised by ``argparse`` when the command line is invalid,
            including when ``--batch-size`` is smaller than 1.
    """
    parser = ArgumentParser(
        description='Perform MMYOLO inference on large images.')
    parser.add_argument(
        'img', help='Image path, include image file, dir and URL.')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--out-dir', default='./output', help='Path to output directory')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--show', action='store_true', help='Show the detection results')
    parser.add_argument(
        '--deploy',
        action='store_true',
        help='Switch model to deployment mode')
    parser.add_argument(
        '--tta',
        action='store_true',
        help='Whether to use test time augmentation')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument(
        '--patch-size', type=int, default=640, help='The size of patches')
    parser.add_argument(
        '--patch-overlap-ratio',
        type=float,
        default=0.25,
        help='Ratio of overlap between two patches')
    parser.add_argument(
        '--merge-iou-thr',
        type=float,
        default=0.25,
        help='IoU threshold for merging results')
    parser.add_argument(
        '--merge-nms-type',
        type=str,
        default='nms',
        help='NMS type for merging results')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=1,
        help='Batch size, must be greater than or equal to 1')
    parser.add_argument(
        '--debug',
        action='store_true',
        help='Export debug results before merging')
    parser.add_argument(
        '--save-patch',
        action='store_true',
        help='Save the results of each patch. '
        'The `--debug` must be enabled.')
    args = parser.parse_args()

    # Enforce the documented constraint up front: a batch size below 1 would
    # keep the batching loop of the demo from ever advancing.
    if args.batch_size < 1:
        parser.error('--batch-size must be greater than or equal to 1, '
                     f'but got {args.batch_size}')
    return args
def _infer_in_batches(model, patches, batch_size):
    """Run the detector over ``patches`` in chunks of ``batch_size``.

    Args:
        model: Detector handed to ``inference_detector`` unchanged.
        patches (list): Image patches to run inference on.
        batch_size (int): Number of patches per forward call. Must be
            greater than or equal to 1 (validated by the caller).

    Returns:
        list: Everything returned by the successive ``inference_detector``
        calls, concatenated in call order.
    """
    results = []
    for begin in range(0, len(patches), batch_size):
        results.extend(
            inference_detector(model, patches[begin:begin + batch_size]))
    return results


def main():
    """Run sliced inference on every input image and visualize the results.

    For each image returned by ``get_file_list`` the image is cut into
    overlapping square patches with ``slice_image``, the detector is run on
    the patches in batches, and the per-patch results are merged with
    ``merge_results_by_nms`` before being drawn by the visualizer. Results are
    shown on screen when ``--show`` is given and written below ``--out-dir``
    otherwise. With ``--debug`` an extra image with the patch grid and the
    un-merged (shifted) predictions is produced, and with ``--save-patch`` the
    result of every single patch is written as well.

    Raises:
        ValueError: If ``--batch-size`` is smaller than 1.
        TypeError: If the config is neither a path nor a ``Config`` object.
        AssertionError: If ``--tta`` is requested but the config lacks
            ``tta_model`` or ``tta_pipeline``.
    """
    args = parse_args()

    # Fail fast, before the model is loaded: a batch size below 1 cannot make
    # progress through the patches.
    if args.batch_size < 1:
        raise ValueError('--batch-size must be greater than or equal to 1, '
                         f'but got {args.batch_size}')

    config = args.config
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if 'init_cfg' in config.model.backbone:
        config.model.backbone.init_cfg = None

    if args.tta:
        assert 'tta_model' in config, 'Cannot find ``tta_model`` in config.' \
                                      " Can't use tta !"
        assert 'tta_pipeline' in config, 'Cannot find ``tta_pipeline`` ' \
                                         "in config. Can't use tta !"
        config.model = ConfigDict(**config.tta_model, module=config.model)

        # Descend through nested dataset wrappers to the innermost dataset
        # config.
        test_data_cfg = config.test_dataloader.dataset
        while 'dataset' in test_data_cfg:
            test_data_cfg = test_data_cfg['dataset']

        # batch_shapes_cfg will force control the size of the output image,
        # it is not compatible with tta.
        if 'batch_shapes_cfg' in test_data_cfg:
            test_data_cfg.batch_shapes_cfg = None
        test_data_cfg.pipeline = config.tta_pipeline

    # TODO: TTA mode will error if cfg_options is not set.
    #  This is an mmdet issue and needs to be fixed later.
    # build the model from a config file and a checkpoint file
    model = init_detector(
        config, args.checkpoint, device=args.device, cfg_options={})

    if args.deploy:
        switch_to_deploy(model)

    # makedirs also creates missing parent directories and, with
    # exist_ok=True, does not fail when the directory already exists.
    if not args.show:
        os.makedirs(args.out_dir, exist_ok=True)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.dataset_meta = model.dataset_meta

    # get file list
    files, source_type = get_file_list(args.img)

    # start detector inference
    print(f'Performing inference on {len(files)} images.... '
          'This may take a while.')
    progress_bar = ProgressBar(len(files))
    for file in files:
        # read image
        img = mmcv.imread(file)

        # arrange slices
        height, width = img.shape[:2]
        sliced_image_object = slice_image(
            img,
            slice_height=args.patch_size,
            slice_width=args.patch_size,
            auto_slice_resolution=False,
            overlap_height_ratio=args.patch_overlap_ratio,
            overlap_width_ratio=args.patch_overlap_ratio,
        )
        # Read the patch list once and reuse it for inference and for the
        # optional per-patch export below.
        patches = sliced_image_object.images

        # perform sliced inference
        slice_results = _infer_in_batches(model, patches, args.batch_size)

        if source_type['is_dir']:
            # Flatten the path relative to the input directory into a single
            # file name, whatever path separator the platform uses.
            filename = os.path.relpath(file, args.img).replace(os.sep, '_')
            if os.altsep:
                filename = filename.replace(os.altsep, '_')
        else:
            filename = os.path.basename(file)

        img = mmcv.imconvert(img, 'bgr', 'rgb')
        out_file = None if args.show else os.path.join(args.out_dir, filename)

        # export debug images
        if args.debug:
            # export sliced image results
            name, suffix = os.path.splitext(filename)

            shifted_instances = shift_predictions(
                slice_results,
                sliced_image_object.starting_pixels,
                src_image_shape=(height, width))
            merged_result = slice_results[0].clone()
            merged_result.pred_instances = shifted_instances

            debug_file_name = name + '_debug' + suffix
            debug_out_file = None if args.show else os.path.join(
                args.out_dir, debug_file_name)
            visualizer.set_image(img.copy())

            # One [x1, y1, x2, y2] box per patch, built from the patch's
            # starting pixel and the nominal patch size.
            debug_grids = np.array([
                [
                    starting_point[0], starting_point[1],
                    starting_point[0] + args.patch_size,
                    starting_point[1] + args.patch_size
                ] for starting_point in sliced_image_object.starting_pixels
            ])
            # Clamp x coordinates to [1, width - 1] and y coordinates to
            # [1, height - 1].
            debug_grids[:, 0::2] = np.clip(debug_grids[:, 0::2], 1,
                                           img.shape[1] - 1)
            debug_grids[:, 1::2] = np.clip(debug_grids[:, 1::2], 1,
                                           img.shape[0] - 1)

            # Random color and line style per patch so that overlapping
            # patches can be told apart.
            palette = np.random.randint(0, 256, size=(len(debug_grids), 3))
            palette = [tuple(c) for c in palette]
            line_styles = random.choices(['-', '-.', ':'], k=len(debug_grids))
            visualizer.draw_bboxes(
                debug_grids,
                edge_colors=palette,
                alpha=1,
                line_styles=line_styles)
            visualizer.draw_bboxes(
                debug_grids, face_colors=palette, alpha=0.15)

            # Label each patch with its index near its top-left corner.
            visualizer.draw_texts(
                list(range(len(debug_grids))),
                debug_grids[:, :2] + 5,
                colors='w')

            visualizer.add_datasample(
                debug_file_name,
                visualizer.get_image(),
                data_sample=merged_result,
                draw_gt=False,
                show=args.show,
                wait_time=0,
                out_file=debug_out_file,
                pred_score_thr=args.score_thr,
            )

            if args.save_patch:
                debug_patch_out_dir = os.path.join(args.out_dir,
                                                   f'{name}_patch')
                for i, (patch, slice_result) in enumerate(
                        zip(patches, slice_results)):
                    patch_out_file = os.path.join(
                        debug_patch_out_dir,
                        f'{filename}_slice_{i}_result.jpg')
                    image = mmcv.imconvert(patch, 'bgr', 'rgb')

                    visualizer.add_datasample(
                        'patch_result',
                        image,
                        data_sample=slice_result,
                        draw_gt=False,
                        show=False,
                        wait_time=0,
                        out_file=patch_out_file,
                        pred_score_thr=args.score_thr,
                    )

        image_result = merge_results_by_nms(
            slice_results,
            sliced_image_object.starting_pixels,
            src_image_shape=(height, width),
            nms_cfg={
                'type': args.merge_nms_type,
                'iou_threshold': args.merge_iou_thr
            })

        visualizer.add_datasample(
            filename,
            img,
            data_sample=image_result,
            draw_gt=False,
            show=args.show,
            wait_time=0,
            out_file=out_file,
            pred_score_thr=args.score_thr,
        )
        progress_bar.update()

    # Patch results are written to disk even in --show mode, so report the
    # output directory in that case too.
    if not args.show or (args.debug and args.save_patch):
        print_log(
            f'\nResults have been saved at {os.path.abspath(args.out_dir)}')
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()