# RSPrompter / mmpl /utils /large_image.py
# KyanChen's picture
# Upload 159 files
# 1c3eb47
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Tuple
import torch
from mmcv.ops import batched_nms
from mmdet.structures import DetDataSample, SampleList
from mmengine.structures import InstanceData
def shift_rbboxes(bboxes: torch.Tensor, offset: Sequence[int]):
    """Translate rotated boxes by adding an offset to their first two columns.

    Args:
        bboxes (Tensor): Rotated boxes to translate, with shape (n, 5)
            laid out as (x, y, w, h, a).
        offset (Sequence[int]): Translation to apply, with shape (2, ).

    Returns:
        Tensor: A new tensor with the translated boxes. Only columns 0 and 1
        differ from ``bboxes``; the input tensor is not written to.
    """
    # ``new_tensor`` builds the offset with the same dtype/device as the boxes.
    delta = bboxes.new_tensor(offset)
    translated = bboxes.clone()
    translated[:, :2] = bboxes[:, :2] + delta
    return translated
def shift_predictions(det_data_samples: SampleList,
                      offsets: Sequence[Tuple[int, int]],
                      src_image_shape: Tuple[int, int]) -> InstanceData:
    """Shift per-patch predictions into the coordinates of the large image.

    Args:
        det_data_samples (List[:obj:`DetDataSample`]): A list of patch
            results.
        offsets (Sequence[Tuple[int, int]]): Positions of the left top points
            of patches, paired one-to-one with ``det_data_samples``.
        src_image_shape (Tuple[int, int]): A (height, width) tuple of the
            large image.

    Returns:
        :obj:`InstanceData`: The shifted ``pred_instances`` of every patch,
        concatenated with ``InstanceData.cat``.

    Raises:
        ImportError: If ``sahi`` cannot be imported.
        NotImplementedError: If the last dimension of a patch's ``bboxes`` is
            neither 4 nor 5.
    """
    try:
        from sahi.slicing import shift_bboxes, shift_masks
    except ImportError as err:
        # Chain the original failure so the underlying cause stays visible.
        raise ImportError(
            'Please run "pip install -U sahi" '
            'to install sahi first for large image inference.') from err

    assert len(det_data_samples) == len(offsets), (
        'The `results` should has the same length with `offsets`.')

    shifted_predictions = []
    for det_data_sample, offset in zip(det_data_samples, offsets):
        # Operate on a clone of the patch's instances, not the sample's own
        # object.
        pred_inst = det_data_sample.pred_instances.clone()

        # The size of the last bbox dimension selects the box type.
        box_dim = pred_inst.bboxes.size(-1)
        if box_dim == 4:
            # Horizontal bboxes
            shifted_bboxes = shift_bboxes(pred_inst.bboxes, offset)
        elif box_dim == 5:
            # Rotated bboxes
            shifted_bboxes = shift_rbboxes(pred_inst.bboxes, offset)
        else:
            raise NotImplementedError(
                'Unsupported bbox dimension {}; expected 4 (horizontal) '
                'or 5 (rotated).'.format(box_dim))
        pred_inst.bboxes = shifted_bboxes

        # NOTE(review): this membership test is made against the sample, not
        # against `pred_inst`. If `masks` only ever lives on `pred_instances`,
        # this branch never runs and masks stay in patch coordinates. Left
        # unchanged on purpose: confirm how sahi's `shift_masks` handles
        # tensor and empty inputs before switching the check to `pred_inst`.
        if 'masks' in det_data_sample:
            pred_inst.masks = shift_masks(pred_inst.masks, offset,
                                          src_image_shape)

        # `pred_inst` is already a private clone and is not touched again,
        # so it is appended directly instead of being cloned a second time.
        shifted_predictions.append(pred_inst)

    return InstanceData.cat(shifted_predictions)
def merge_results_by_nms(results: SampleList,
                         offsets: Sequence[Tuple[int, int]],
                         src_image_shape: Tuple[int, int],
                         nms_cfg: dict) -> DetDataSample:
    """Combine the predictions of all patches and de-duplicate them with NMS.

    Args:
        results (List[:obj:`DetDataSample`]): A list of patch results.
        offsets (Sequence[Tuple[int, int]]): Positions of the left top points
            of patches.
        src_image_shape (Tuple[int, int]): A (height, width) tuple of the
            large image.
        nms_cfg (dict): Forwarded to ``batched_nms``; it should specify the
            nms type and other parameters like `iou_threshold`.

    Returns:
        :obj:`DetDataSample`: A clone of ``results[0]`` whose
        ``pred_instances`` are the instances kept by NMS.
    """
    # Bring every patch's instances into the large image's coordinates.
    all_instances = shift_predictions(results, offsets, src_image_shape)

    # Only the kept indices are used; the returned detections are discarded.
    _, keep_indices = batched_nms(
        boxes=all_instances.bboxes,
        scores=all_instances.scores,
        idxs=all_instances.labels,
        nms_cfg=nms_cfg)
    surviving = all_instances[keep_indices]

    # The first patch result serves as the template for the merged sample.
    merged = results[0].clone()
    merged.pred_instances = surviving
    return merged