Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Sequence, Tuple | |
import torch | |
from mmcv.ops import batched_nms | |
from mmdet.structures import DetDataSample, SampleList | |
from mmengine.structures import InstanceData | |
def shift_rbboxes(bboxes: torch.Tensor, offset: Sequence[int]):
    """Translate rotated bboxes by a 2-D offset.

    Only the first two columns (x, y) are moved; the remaining columns
    (w, h, a) are copied unchanged. The input tensor is left untouched.

    Args:
        bboxes (Tensor): Rotated bboxes of shape (n, 5), laid out as
            (x, y, w, h, a).
        offset (Sequence[int]): The (dx, dy) translation, shape (2, ).

    Returns:
        Tensor: A new tensor holding the translated bboxes, with the same
        dtype and device as ``bboxes``.
    """
    result = bboxes.clone()
    # new_tensor keeps bboxes' dtype/device, so the add needs no promotion.
    delta = bboxes.new_tensor(offset)
    result[:, :2] += delta
    return result
def shift_predictions(det_data_samples: SampleList,
                      offsets: Sequence[Tuple[int, int]],
                      src_image_shape: Tuple[int, int]) -> InstanceData:
    """Shift patch predictions into the coordinate frame of the full image.

    Args:
        det_data_samples (List[:obj:`DetDataSample`]): A list of patch
            results, one per patch.
        offsets (Sequence[Tuple[int, int]]): Positions of the top-left
            corners of the patches, aligned element-wise with
            ``det_data_samples``.
        src_image_shape (Tuple[int, int]): A (height, width) tuple of the
            full image's shape; forwarded to ``sahi``'s ``shift_masks``.

    Returns:
        :obj:`InstanceData`: The shifted ``pred_instances`` of every patch,
        concatenated into a single instance set.

    Raises:
        ImportError: If ``sahi`` is not installed.
        NotImplementedError: If a patch's bboxes have a last dimension other
            than 4 (horizontal) or 5 (rotated).
    """
    try:
        from sahi.slicing import shift_bboxes, shift_masks
    except ImportError:
        raise ImportError('Please run "pip install -U sahi" '
                          'to install sahi first for large image inference.')

    assert len(det_data_samples) == len(
        offsets), 'The `results` should has the ' 'same length with `offsets`.'

    shifted_predictions = []
    for det_data_sample, offset in zip(det_data_samples, offsets):
        pred_inst = det_data_sample.pred_instances.clone()

        # Dispatch on the box encoding: 4 columns are horizontal boxes
        # (shifted by sahi), 5 columns are rotated (x, y, w, h, a) boxes.
        box_dim = pred_inst.bboxes.size(-1)
        if box_dim == 4:
            pred_inst.bboxes = shift_bboxes(pred_inst.bboxes, offset)
        elif box_dim == 5:
            pred_inst.bboxes = shift_rbboxes(pred_inst.bboxes, offset)
        else:
            raise NotImplementedError(
                f'Unsupported bbox dimension {box_dim}; expected 4 '
                '(horizontal) or 5 (rotated).')

        # Masks are stored on ``pred_instances``, not on the data sample
        # itself. Testing membership on ``det_data_sample`` never matches
        # and silently leaves masks in patch coordinates.
        if 'masks' in pred_inst:
            pred_inst.masks = shift_masks(pred_inst.masks, offset,
                                          src_image_shape)

        shifted_predictions.append(pred_inst.clone())

    return InstanceData.cat(shifted_predictions)
def merge_results_by_nms(results: SampleList,
                         offsets: Sequence[Tuple[int, int]],
                         src_image_shape: Tuple[int, int],
                         nms_cfg: dict) -> DetDataSample:
    """Merge patch results into a single full-image result using NMS.

    Args:
        results (List[:obj:`DetDataSample`]): A list of patch results.
        offsets (Sequence[Tuple[int, int]]): Positions of the top-left
            corners of the patches, aligned element-wise with ``results``.
        src_image_shape (Tuple[int, int]): A (height, width) tuple of the
            full image's shape.
        nms_cfg (dict): Config forwarded to :func:`mmcv.ops.batched_nms`;
            it should specify the nms type and other parameters such as
            ``iou_threshold``.

    Returns:
        :obj:`DetDataSample`: A clone of ``results[0]`` whose
        ``pred_instances`` holds the merged (post-NMS) predictions.
    """
    shifted_instances = shift_predictions(results, offsets, src_image_shape)
    merged_result = results[0].clone()

    if len(shifted_instances) == 0:
        # No patch produced a detection, so there is nothing to suppress.
        # NOTE(review): mmcv's batched_nms derives per-class offsets from
        # ``boxes.max()``, which raises on empty tensors; confirm against the
        # installed mmcv version before removing this guard.
        merged_result.pred_instances = shifted_instances
        return merged_result

    _, keeps = batched_nms(
        boxes=shifted_instances.bboxes,
        scores=shifted_instances.scores,
        idxs=shifted_instances.labels,
        nms_cfg=nms_cfg)
    merged_result.pred_instances = shifted_instances[keeps]
    return merged_result