from inspect import signature

import torch

from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms


class BBoxTestMixin(object):
    """Mixin providing bbox test-time augmentation (TTA) helpers."""

    def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
        """Map augmented bboxes back and concatenate them (and the scores).

        Every entry of ``aug_bboxes`` is passed through
        ``bbox_mapping_back`` together with the ``img_shape``,
        ``scale_factor``, ``flip`` and ``flip_direction`` values read from
        the first meta dict of the matching ``img_metas`` entry. The mapped
        boxes are then concatenated along dim 0.

        Args:
            aug_bboxes (list[Tensor]): shape (n, 4*#class)
            aug_scores (list[Tensor] or None): shape (n, #class)
            img_metas (list[list[dict]]): one inner list per augmentation;
                only the first dict of each inner list is read.

        Returns:
            Tensor | tuple[Tensor, Tensor]: the merged bboxes alone when
            ``aug_scores`` is None, otherwise ``(bboxes, scores)``.
        """
        restored = []
        for boxes, metas in zip(aug_bboxes, img_metas):
            # Only the first image's meta of each augmentation is consulted.
            meta = metas[0]
            restored.append(
                bbox_mapping_back(boxes, meta['img_shape'],
                                  meta['scale_factor'], meta['flip'],
                                  meta['flip_direction']))
        merged_bboxes = torch.cat(restored, dim=0)

        if aug_scores is None:
            return merged_bboxes
        return merged_bboxes, torch.cat(aug_scores, dim=0)

    def aug_test_bboxes(self, feats, img_metas, rescale=False):
        """Test det bboxes with test time augmentation.

        Args:
            feats (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains features for all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. each dict has image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[ndarray]: bbox results of each class, as produced by
            ``bbox2result``.

        Raises:
            AssertionError: if ``get_bboxes`` or the per-image helper
                (``_get_bboxes`` when present, else ``_get_bboxes_single``)
                has no ``with_nms`` parameter.
        """
        # TTA needs to defer NMS until after merging, so both the batched
        # entry point and its per-image helper must accept ``with_nms``.
        outer_params = signature(self.get_bboxes).parameters
        if hasattr(self, '_get_bboxes'):
            inner_params = signature(self._get_bboxes).parameters
        else:
            inner_params = signature(self._get_bboxes_single).parameters
        assert ('with_nms' in outer_params) and \
            ('with_nms' in inner_params), \
            f'{self.__class__.__name__}' \
            ' does not support test-time augmentation'

        per_aug_bboxes = []
        per_aug_scores = []
        per_aug_factors = []  # score_factors for NMS
        for feat, aug_meta in zip(feats, img_metas):
            # only one image in the batch
            head_outs = self.forward(feat)
            # Trailing positional flags are passed as False, False; given
            # the with_nms check above they presumably correspond to
            # rescale and with_nms (order not visible here).
            call_args = head_outs + (aug_meta, self.test_cfg, False, False)
            result = self.get_bboxes(*call_args)[0]
            per_aug_bboxes.append(result[0])
            per_aug_scores.append(result[1])
            # bbox_outputs of some detectors (e.g., ATSS, FCOS, YOLOv3)
            # contains additional element to adjust scores before NMS
            if len(result) >= 3:
                per_aug_factors.append(result[2])

        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = self.merge_aug_bboxes(
            per_aug_bboxes, per_aug_scores, img_metas)
        if per_aug_factors:
            merged_factors = torch.cat(per_aug_factors, dim=0)
        else:
            merged_factors = None

        test_cfg = self.test_cfg
        det_bboxes, det_labels = multiclass_nms(
            merged_bboxes,
            merged_scores,
            test_cfg.score_thr,
            test_cfg.nms,
            test_cfg.max_per_img,
            score_factors=merged_factors)

        final_bboxes = det_bboxes
        if not rescale:
            # Work on a copy and multiply the first four columns by the
            # scale factor of the first augmentation's first image.
            final_bboxes = det_bboxes.clone()
            final_bboxes[:, :4] *= det_bboxes.new_tensor(
                img_metas[0][0]['scale_factor'])
        return bbox2result(final_bboxes, det_labels, self.num_classes)