# Provenance (scraped from a hosted file page): uploaded by Robert001,
# commit b334e29 ("first commit"), file size 4.09 kB.
from inspect import signature
import torch
from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms
class BBoxTestMixin(object):
    """Mixin class for test time augmentation of bboxes."""

    def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
        """Undo each augmentation on its boxes and stack all results.

        Every entry of ``aug_bboxes`` is passed through
        ``bbox_mapping_back`` using the first meta dict of the matching
        entry in ``img_metas``.

        Args:
            aug_bboxes (list[Tensor]): Boxes per augmentation, each of shape
                (n, 4*#class).
            aug_scores (list[Tensor] or None): Scores per augmentation, each
                of shape (n, #class), or None to merge boxes only.
            img_metas (list[list[dict]]): Image metas per augmentation. Only
                the first dict of each inner list is read; it must provide
                'img_shape', 'scale_factor', 'flip' and 'flip_direction'.

        Returns:
            Tensor or tuple[Tensor, Tensor]: The boxes concatenated along
            dim 0 when ``aug_scores`` is None, otherwise
            ``(bboxes, scores)`` with both concatenated along dim 0.
        """

        def _restore(boxes, metas):
            # Only the first image of each augmentation is considered.
            meta = metas[0]
            return bbox_mapping_back(boxes, meta['img_shape'],
                                     meta['scale_factor'], meta['flip'],
                                     meta['flip_direction'])

        restored = [
            _restore(boxes, metas)
            for boxes, metas in zip(aug_bboxes, img_metas)
        ]
        merged_boxes = torch.cat(restored, dim=0)
        if aug_scores is None:
            return merged_boxes
        return merged_boxes, torch.cat(aug_scores, dim=0)

    def aug_test_bboxes(self, feats, img_metas, rescale=False):
        """Detect boxes under test-time augmentation and merge the results.

        Args:
            feats (list[Tensor]): One entry per test-time augmentation; each
                Tensor has shape NxCxHxW and holds features for every image
                in the batch.
            img_metas (list[list[dict]]): Outer list follows the test-time
                augmentations (multiscale, flip, etc.); inner list follows
                the images in the batch. Each dict carries image information.
            rescale (bool, optional): If False, the merged boxes are
                multiplied by ``img_metas[0][0]['scale_factor']`` before
                being returned. Defaults to False.

        Returns:
            list[ndarray]: Per-class bbox results built by ``bbox2result``.
        """
        # TTA needs both get_bboxes and the per-image helper to accept
        # `with_nms`, so NMS can be postponed until after merging.
        get_params = signature(self.get_bboxes).parameters
        if hasattr(self, '_get_bboxes'):
            single_fn = self._get_bboxes
        else:
            single_fn = self._get_bboxes_single
        single_params = signature(single_fn).parameters
        assert 'with_nms' in get_params and 'with_nms' in single_params, \
            f'{self.__class__.__name__} does not support test-time augmentation'

        aug_bboxes, aug_scores, aug_factors = [], [], []
        for feat, meta in zip(feats, img_metas):
            head_outs = self.forward(feat)
            # Trailing positional flags are assumed to be rescale=False and
            # with_nms=False for get_bboxes -- confirm against the head.
            call_args = head_outs + (meta, self.test_cfg, False, False)
            # Only one image in the batch, so keep the first result.
            det = self.get_bboxes(*call_args)[0]
            boxes, scores = det[0], det[1]
            aug_bboxes.append(boxes)
            aug_scores.append(scores)
            # Some heads (e.g. ATSS, FCOS, YOLOv3) return an extra element
            # used to adjust scores before NMS.
            if len(det) >= 3:
                aug_factors.append(det[2])

        # Merging maps the boxes back to the original image size.
        merged_bboxes, merged_scores = self.merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas)
        merged_factors = (None if not aug_factors else
                          torch.cat(aug_factors, dim=0))

        cfg = self.test_cfg
        det_bboxes, det_labels = multiclass_nms(
            merged_bboxes,
            merged_scores,
            cfg.score_thr,
            cfg.nms,
            cfg.max_per_img,
            score_factors=merged_factors)

        if not rescale:
            # Work on a copy so det_bboxes itself is left untouched.
            out_bboxes = det_bboxes.clone()
            scale = det_bboxes.new_tensor(img_metas[0][0]['scale_factor'])
            out_bboxes[:, :4] *= scale
        else:
            out_bboxes = det_bboxes
        return bbox2result(out_bboxes, det_labels, self.num_classes)