import logging
import sys

import torch

from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_bboxes,
                        merge_aug_masks, multiclass_nms)

logger = logging.getLogger(__name__)

if sys.version_info >= (3, 7):
    from mmdet.utils.contextmanagers import completed


class BBoxTestMixin(object):
    """Mixin with test-time helpers for the bbox branch of an RoI head.

    The methods read the following attributes of the host class:
    ``bbox_roi_extractor``, ``bbox_head``, ``with_shared_head``,
    ``shared_head`` and ``_bbox_forward``.
    """

    if sys.version_info >= (3, 7):

        async def async_test_bboxes(self,
                                    x,
                                    img_metas,
                                    proposals,
                                    rcnn_test_cfg,
                                    rescale=False,
                                    bbox_semaphore=None,
                                    global_lock=None):
            """Asynchronized test for box head without augmentation.

            Extracts RoI features for ``proposals``, runs the bbox head
            inside the ``completed`` async context manager and decodes the
            predictions with the meta info of ``img_metas[0]``.

            Args:
                x: Feature maps; only the first
                    ``len(self.bbox_roi_extractor.featmap_strides)`` levels
                    are passed to the RoI extractor.
                img_metas: Image meta info. Only ``img_metas[0]`` is read
                    (keys ``'img_shape'`` and ``'scale_factor'``).
                proposals: Region proposals, converted with ``bbox2roi``.
                rcnn_test_cfg: Test config. ``'async_sleep_interval'`` is
                    looked up with ``.get`` (default 0.017) and the config
                    is forwarded to ``bbox_head.get_bboxes`` as ``cfg``.
                rescale (bool): Forwarded to ``bbox_head.get_bboxes``.
                    Default: False.
                bbox_semaphore: Not used by this method.
                global_lock: Not used by this method.

            Returns:
                tuple: ``(det_bboxes, det_labels)`` as returned by
                ``self.bbox_head.get_bboxes``.
            """
            rois = bbox2roi(proposals)
            roi_feats = self.bbox_roi_extractor(
                x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
            if self.with_shared_head:
                roi_feats = self.shared_head(roi_feats)
            sleep_interval = rcnn_test_cfg.get('async_sleep_interval', 0.017)

            async with completed(
                    __name__, 'bbox_head_forward',
                    sleep_interval=sleep_interval):
                cls_score, bbox_pred = self.bbox_head(roi_feats)

            img_shape = img_metas[0]['img_shape']
            scale_factor = img_metas[0]['scale_factor']
            det_bboxes, det_labels = self.bbox_head.get_bboxes(
                rois,
                cls_score,
                bbox_pred,
                img_shape,
                scale_factor,
                rescale=rescale,
                cfg=rcnn_test_cfg)
            return det_bboxes, det_labels

    def simple_test_bboxes(self,
                           x,
                           img_metas,
                           proposals,
                           rcnn_test_cfg,
                           rescale=False):
        """Test only det bboxes without augmentation.

        When ``proposals`` is a list, each entry is padded with leading
        all-zero rows up to the largest proposal count so the entries can be
        stacked into one batch. The caller's list is left untouched; the
        padded copies are kept in a new local list. Outside of ONNX export,
        rows whose last RoI coordinate equals 0 are treated as padding and
        their predictions are zeroed.

        Args:
            x (tuple[Tensor]): Feature maps of all scale level.
            img_metas (list[dict]): Image meta info.
            proposals (Tensor or List[Tensor]): Region proposals.
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
            rescale (bool): If True, return boxes in original image space.
                Default: False.

        Returns:
            tuple[list[Tensor], list[Tensor]]: The first list contains
                the boxes of the corresponding image in a batch, each
                tensor has the shape (num_boxes, 5) and last dimension
                5 represent (tl_x, tl_y, br_x, br_y, score). Each Tensor
                in the second list is the labels with shape (num_boxes, ).
                The length of both lists should be equal to batch_size.
                When ``bbox_pred`` is a Tensor or None, the value is whatever
                ``self.bbox_head.get_bboxes`` returns for the batched inputs.
        """
        # get origin input shape to support onnx dynamic input shape
        if torch.onnx.is_in_onnx_export():
            assert len(
                img_metas
            ) == 1, 'Only support one input image while in exporting to ONNX'
            img_shapes = img_metas[0]['img_shape_for_onnx']
        else:
            img_shapes = tuple(meta['img_shape'] for meta in img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)

        # The length of proposals of different batches may be different.
        # In order to form a batch, a padding operation is required.
        if isinstance(proposals, list):
            max_size = max(proposal.size(0) for proposal in proposals)
            # Build a new padded list instead of writing into the caller's
            # list; the local name is rebound because the non-Tensor
            # ``bbox_pred`` branch below reads the padded proposals.
            proposals = [
                torch.cat((proposal.new_full(
                    (max_size - proposal.size(0), proposal.size(1)), 0),
                           proposal),
                          dim=0) for proposal in proposals
            ]
            rois = torch.stack(proposals, dim=0)
        else:
            rois = proposals

        # Prepend the batch index as the first RoI column.
        batch_index = torch.arange(
            rois.size(0), device=rois.device).float().view(-1, 1, 1).expand(
                rois.size(0), rois.size(1), 1)
        rois = torch.cat([batch_index, rois[..., :4]], dim=-1)
        batch_size = rois.shape[0]
        num_proposals_per_img = rois.shape[1]

        # Eliminate the batch dimension
        rois = rois.view(-1, 5)
        bbox_results = self._bbox_forward(x, rois)
        cls_score = bbox_results['cls_score']
        bbox_pred = bbox_results['bbox_pred']

        # Recover the batch dimension
        rois = rois.reshape(batch_size, num_proposals_per_img, -1)
        cls_score = cls_score.reshape(batch_size, num_proposals_per_img, -1)

        if not torch.onnx.is_in_onnx_export():
            # remove padding
            supplement_mask = rois[..., -1] == 0
            cls_score[supplement_mask, :] = 0

        # bbox_pred would be None in some detector when with_reg is False,
        # e.g. Grid R-CNN. In that case it is forwarded as None below.
        if bbox_pred is not None:
            # the bbox prediction of some detectors like SABL is not Tensor
            if isinstance(bbox_pred, torch.Tensor):
                bbox_pred = bbox_pred.reshape(batch_size,
                                              num_proposals_per_img, -1)
                if not torch.onnx.is_in_onnx_export():
                    bbox_pred[supplement_mask, :] = 0
            else:
                # TODO: Looking forward to a better way
                # For SABL
                bbox_preds = self.bbox_head.bbox_pred_split(
                    bbox_pred, num_proposals_per_img)
                # apply bbox post-processing to each image individually
                det_bboxes = []
                det_labels = []
                for i, proposal in enumerate(proposals):
                    # remove padding
                    supplement_mask = proposal[..., -1] == 0
                    for bbox in bbox_preds[i]:
                        bbox[supplement_mask] = 0
                    det_bbox, det_label = self.bbox_head.get_bboxes(
                        rois[i],
                        cls_score[i],
                        bbox_preds[i],
                        img_shapes[i],
                        scale_factors[i],
                        rescale=rescale,
                        cfg=rcnn_test_cfg)
                    det_bboxes.append(det_bbox)
                    det_labels.append(det_label)
                return det_bboxes, det_labels

        return self.bbox_head.get_bboxes(
            rois,
            cls_score,
            bbox_pred,
            img_shapes,
            scale_factors,
            rescale=rescale,
            cfg=rcnn_test_cfg)

    def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
        """Test det bboxes with test time augmentation.

        For every (features, meta) pair the first entry of ``proposal_list``
        is mapped with ``bbox_mapping`` using that augmentation's shape,
        scale and flip info, then scored by the bbox head without rescaling
        and with ``cfg=None``. The per-augmentation results are merged with
        ``merge_aug_bboxes`` and filtered with ``multiclass_nms``.

        Args:
            feats: Iterable of feature maps, one item per augmentation.
            img_metas: Iterable of meta info, one item per augmentation;
                only index 0 of each item is read.
            proposal_list: Proposals; only ``proposal_list[0][:, :4]`` is
                used.
            rcnn_test_cfg: Test config providing ``score_thr``, ``nms`` and
                ``max_per_img``; also passed to ``merge_aug_bboxes``.

        Returns:
            tuple: ``(det_bboxes, det_labels)`` from ``multiclass_nms``.
        """
        aug_bboxes = []
        aug_scores = []
        for x, img_meta in zip(feats, img_metas):
            # only one image in the batch
            img_shape = img_meta[0]['img_shape']
            scale_factor = img_meta[0]['scale_factor']
            flip = img_meta[0]['flip']
            flip_direction = img_meta[0]['flip_direction']
            # TODO more flexible
            proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
                                     scale_factor, flip, flip_direction)
            rois = bbox2roi([proposals])
            bbox_results = self._bbox_forward(x, rois)
            bboxes, scores = self.bbox_head.get_bboxes(
                rois,
                bbox_results['cls_score'],
                bbox_results['bbox_pred'],
                img_shape,
                scale_factor,
                rescale=False,
                cfg=None)
            aug_bboxes.append(bboxes)
            aug_scores.append(scores)
        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
        det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
                                                rcnn_test_cfg.score_thr,
                                                rcnn_test_cfg.nms,
                                                rcnn_test_cfg.max_per_img)
        return det_bboxes, det_labels


class MaskTestMixin(object):
    """Mixin with test-time helpers for the mask branch of an RoI head.

    The methods read the following attributes of the host class:
    ``mask_roi_extractor``, ``mask_head``, ``with_shared_head``,
    ``shared_head``, ``test_cfg`` and ``_mask_forward``.
    """

    if sys.version_info >= (3, 7):

        async def async_test_mask(self,
                                  x,
                                  img_metas,
                                  det_bboxes,
                                  det_labels,
                                  rescale=False,
                                  mask_test_cfg=None):
            """Asynchronized test for mask head without augmentation.

            Args:
                x: Feature maps; only the first
                    ``len(self.mask_roi_extractor.featmap_strides)`` levels
                    are passed to the RoI extractor.
                img_metas: Image meta info. Only ``img_metas[0]`` is read
                    (keys ``'ori_shape'`` and ``'scale_factor'``).
                det_bboxes: Detected boxes of the single image.
                det_labels: Labels matching ``det_bboxes``.
                rescale (bool): If True, the first four box columns are
                    multiplied by the scale factor before RoI extraction.
                    Default: False.
                mask_test_cfg: Optional config; a truthy
                    ``'async_sleep_interval'`` entry overrides the default
                    sleep interval of 0.035.

            Returns:
                list: One empty list per class when there are no
                detections, otherwise the result of
                ``self.mask_head.get_seg_masks``.
            """
            # image shape of the first image in the batch (only one)
            ori_shape = img_metas[0]['ori_shape']
            scale_factor = img_metas[0]['scale_factor']
            if det_bboxes.shape[0] == 0:
                segm_result = [[] for _ in range(self.mask_head.num_classes)]
            else:
                if rescale and not isinstance(scale_factor,
                                              (float, torch.Tensor)):
                    scale_factor = det_bboxes.new_tensor(scale_factor)
                _bboxes = (
                    det_bboxes[:, :4] *
                    scale_factor if rescale else det_bboxes)
                mask_rois = bbox2roi([_bboxes])
                mask_feats = self.mask_roi_extractor(
                    x[:len(self.mask_roi_extractor.featmap_strides)],
                    mask_rois)

                if self.with_shared_head:
                    mask_feats = self.shared_head(mask_feats)
                if mask_test_cfg and mask_test_cfg.get('async_sleep_interval'):
                    sleep_interval = mask_test_cfg['async_sleep_interval']
                else:
                    sleep_interval = 0.035
                async with completed(
                        __name__,
                        'mask_head_forward',
                        sleep_interval=sleep_interval):
                    mask_pred = self.mask_head(mask_feats)
                segm_result = self.mask_head.get_seg_masks(
                    mask_pred, _bboxes, det_labels, self.test_cfg, ori_shape,
                    scale_factor, rescale)
            return segm_result

    def simple_test_mask(self,
                         x,
                         img_metas,
                         det_bboxes,
                         det_labels,
                         rescale=False):
        """Simple test for mask head without augmentation.

        When ``det_bboxes`` is a list, boxes and labels are padded with
        leading zeros to a common length and stacked into a batch. The
        caller's lists are not modified. After the mask head has run, rows
        whose last box coordinate equals 0 are treated as padding and
        dropped before post-processing each image.

        Args:
            x: Feature maps passed to ``self._mask_forward``.
            img_metas (list[dict]): Image meta info; ``'ori_shape'`` and
                ``'scale_factor'`` are read from every entry.
            det_bboxes (Tensor or list[Tensor]): Detected boxes per image;
                only the first four columns are used as RoI coordinates.
            det_labels (Tensor or list[Tensor]): Labels matching
                ``det_bboxes``.
            rescale (bool): If True, boxes are multiplied by the per-image
                scale factor to obtain RoIs at the testing scale, and the
                flag is forwarded to ``mask_head.get_seg_masks``.
                Default: False.

        Returns:
            list: One entry per image; a list of ``num_classes`` empty lists
            when the image has no detections, otherwise the result of
            ``self.mask_head.get_seg_masks``.
        """
        # image shapes of images in the batch
        ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)

        # The length of proposals of different batches may be different.
        # In order to form a batch, a padding operation is required.
        if isinstance(det_bboxes, list):
            max_size = max(bboxes.size(0) for bboxes in det_bboxes)
            # Collect padded copies in new lists so the caller's lists are
            # not overwritten.
            padded_bboxes = []
            padded_labels = []
            for bbox, label in zip(det_bboxes, det_labels):
                supplement_bbox = bbox.new_full(
                    (max_size - bbox.size(0), bbox.size(1)), 0)
                supplement_label = label.new_full(
                    (max_size - label.size(0), ), 0)
                padded_bboxes.append(
                    torch.cat((supplement_bbox, bbox), dim=0))
                padded_labels.append(
                    torch.cat((supplement_label, label), dim=0))
            det_bboxes = torch.stack(padded_bboxes, dim=0)
            det_labels = torch.stack(padded_labels, dim=0)

        batch_size = det_bboxes.size(0)
        num_proposals_per_img = det_bboxes.shape[1]

        # if det_bboxes is rescaled to the original image size, we need to
        # rescale it back to the testing scale to obtain RoIs.
        det_bboxes = det_bboxes[..., :4]
        if rescale:
            if isinstance(scale_factors[0], float):
                # Plain floats: keep the tuple for ``get_seg_masks`` below
                # and use a (batch, 1, 1) tensor for the multiplication.
                # A tuple has no ``unsqueeze``, so it cannot be used here
                # directly.
                scale_tensor = det_bboxes.new_tensor(scale_factors).view(
                    -1, 1, 1)
            else:
                scale_factors = det_bboxes.new_tensor(scale_factors)
                scale_tensor = scale_factors.unsqueeze(1)
            det_bboxes = det_bboxes * scale_tensor

        # Prepend the batch index as the first RoI column.
        batch_index = torch.arange(
            det_bboxes.size(0), device=det_bboxes.device).float().view(
                -1, 1, 1).expand(det_bboxes.size(0), det_bboxes.size(1), 1)
        mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
        mask_rois = mask_rois.view(-1, 5)
        mask_results = self._mask_forward(x, mask_rois)
        mask_pred = mask_results['mask_pred']

        # Recover the batch dimension
        mask_preds = mask_pred.reshape(batch_size, num_proposals_per_img,
                                       *mask_pred.shape[1:])

        # apply mask post-processing to each image individually
        segm_results = []
        for i in range(batch_size):
            mask_pred = mask_preds[i]
            det_bbox = det_bboxes[i]
            det_label = det_labels[i]

            # remove padding
            supplement_mask = det_bbox[..., -1] != 0
            mask_pred = mask_pred[supplement_mask]
            det_bbox = det_bbox[supplement_mask]
            det_label = det_label[supplement_mask]

            if det_label.shape[0] == 0:
                segm_results.append([[]
                                     for _ in range(self.mask_head.num_classes)
                                     ])
            else:
                segm_result = self.mask_head.get_seg_masks(
                    mask_pred, det_bbox, det_label, self.test_cfg,
                    ori_shapes[i], scale_factors[i], rescale)
                segm_results.append(segm_result)
        return segm_results

    def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
        """Test for mask head with test time augmentation.

        For every (features, meta) pair the detected boxes are mapped with
        ``bbox_mapping`` using that augmentation's shape, scale and flip
        info, and the sigmoid of the predicted masks is collected as a
        numpy array. The collected masks are merged with
        ``merge_aug_masks`` and post-processed with ``scale_factor=1.0``
        and ``rescale=False``.

        Args:
            feats: Iterable of feature maps, one item per augmentation.
            img_metas: Sequence of meta info, one item per augmentation;
                only index 0 of each item is read.
            det_bboxes (Tensor): Detected boxes; the first four columns are
                mapped to each augmentation.
            det_labels (Tensor): Labels matching ``det_bboxes``.

        Returns:
            list: One empty list per class when there are no detections,
            otherwise the result of ``self.mask_head.get_seg_masks``.
        """
        if det_bboxes.shape[0] == 0:
            segm_result = [[] for _ in range(self.mask_head.num_classes)]
        else:
            aug_masks = []
            for x, img_meta in zip(feats, img_metas):
                img_shape = img_meta[0]['img_shape']
                scale_factor = img_meta[0]['scale_factor']
                flip = img_meta[0]['flip']
                flip_direction = img_meta[0]['flip_direction']
                _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
                                       scale_factor, flip, flip_direction)
                mask_rois = bbox2roi([_bboxes])
                mask_results = self._mask_forward(x, mask_rois)
                # convert to numpy array to save memory
                aug_masks.append(
                    mask_results['mask_pred'].sigmoid().cpu().numpy())
            merged_masks = merge_aug_masks(aug_masks, img_metas, self.test_cfg)

            ori_shape = img_metas[0][0]['ori_shape']
            segm_result = self.mask_head.get_seg_masks(
                merged_masks,
                det_bboxes,
                det_labels,
                self.test_cfg,
                ori_shape,
                scale_factor=1.0,
                rescale=False)
        return segm_result