Spaces:
Runtime error
Runtime error
from inspect import signature | |
import torch | |
from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms | |
class BBoxTestMixin(object):
    """Mixin class for test time augmentation of bboxes.

    The host class must provide the attributes accessed below but not
    defined here: ``forward``, ``get_bboxes``, either ``_get_bboxes`` or
    ``_get_bboxes_single``, ``test_cfg`` (read for ``score_thr``, ``nms``
    and ``max_per_img``) and ``num_classes``.
    """

    def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
        """Merge augmented detection bboxes and scores.

        Each augmentation's boxes are mapped back with ``bbox_mapping_back``
        and then all augmentations are concatenated along dim 0.

        Args:
            aug_bboxes (list[Tensor]): shape (n, 4*#class)
            aug_scores (list[Tensor] or None): shape (n, #class)
            img_metas (list[list[dict]]): one entry per augmentation. Only
                the first dict of each entry is read, and it must contain
                ``img_shape``, ``scale_factor``, ``flip`` and
                ``flip_direction``.

        Returns:
            Tensor or tuple[Tensor, Tensor]: the concatenated bboxes alone
            when ``aug_scores`` is None, otherwise ``(bboxes, scores)``.
        """
        recovered_bboxes = []
        for bboxes, img_info in zip(aug_bboxes, img_metas):
            # Every augmentation holds a single image; use its meta dict.
            meta = img_info[0]
            bboxes = bbox_mapping_back(bboxes, meta['img_shape'],
                                       meta['scale_factor'], meta['flip'],
                                       meta['flip_direction'])
            recovered_bboxes.append(bboxes)
        bboxes = torch.cat(recovered_bboxes, dim=0)
        if aug_scores is None:
            return bboxes
        scores = torch.cat(aug_scores, dim=0)
        return bboxes, scores

    def aug_test_bboxes(self, feats, img_metas, rescale=False):
        """Test det bboxes with test time augmentation.

        Each augmentation is run through ``forward`` and ``get_bboxes``. The
        per-augmentation boxes are merged with ``merge_aug_bboxes``, and a
        single ``multiclass_nms`` pass is applied using
        ``self.test_cfg.score_thr``, ``self.test_cfg.nms`` and
        ``self.test_cfg.max_per_img``.

        Args:
            feats (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains features for all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. each dict has image information.
            rescale (bool, optional): If True, the merged boxes are returned
                as produced by the merge (original image size). If False,
                they are multiplied by the ``scale_factor`` of the first
                augmentation. Defaults to False.

        Returns:
            list[ndarray]: bbox results of each class

        Raises:
            AssertionError: if ``get_bboxes`` or its per-image helper
                (``_get_bboxes`` if present, else ``_get_bboxes_single``)
                does not accept a ``with_nms`` argument.
        """
        # check with_nms argument: both the public entry point and its
        # worker must accept it for TTA to request un-suppressed boxes.
        gb_params = signature(self.get_bboxes).parameters
        if hasattr(self, '_get_bboxes'):
            gbs_sig = signature(self._get_bboxes)
        else:
            gbs_sig = signature(self._get_bboxes_single)
        gbs_params = gbs_sig.parameters
        if 'with_nms' not in gb_params or 'with_nms' not in gbs_params:
            # Raised explicitly instead of via ``assert`` so the guard is not
            # stripped under ``python -O``; the exception type is unchanged.
            raise AssertionError(f'{self.__class__.__name__}'
                                 ' does not support test-time augmentation')

        aug_bboxes = []
        aug_scores = []
        aug_factors = []  # score_factors for NMS
        for x, img_meta in zip(feats, img_metas):
            # only one image in the batch
            outs = self.forward(x)
            # Trailing positional args follow the head outputs; presumably
            # (img_metas, cfg, rescale, with_nms) -- the exact order depends
            # on the host head's get_bboxes signature.
            bbox_inputs = outs + (img_meta, self.test_cfg, False, False)
            bbox_outputs = self.get_bboxes(*bbox_inputs)[0]
            aug_bboxes.append(bbox_outputs[0])
            aug_scores.append(bbox_outputs[1])
            # bbox_outputs of some detectors (e.g., ATSS, FCOS, YOLOv3)
            # contains additional element to adjust scores before NMS
            if len(bbox_outputs) >= 3:
                aug_factors.append(bbox_outputs[2])

        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = self.merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas)
        merged_factors = torch.cat(aug_factors, dim=0) if aug_factors else None
        det_bboxes, det_labels = multiclass_nms(
            merged_bboxes,
            merged_scores,
            self.test_cfg.score_thr,
            self.test_cfg.nms,
            self.test_cfg.max_per_img,
            score_factors=merged_factors)

        if rescale:
            _det_bboxes = det_bboxes
        else:
            # Scale the box coordinates (first four columns) by the first
            # augmentation's scale_factor; clone so det_bboxes is untouched.
            _det_bboxes = det_bboxes.clone()
            _det_bboxes[:, :4] *= det_bboxes.new_tensor(
                img_metas[0][0]['scale_factor'])
        bbox_results = bbox2result(_det_bboxes, det_labels, self.num_classes)
        return bbox_results