Robert001's picture
first commit
b334e29
raw
history blame
No virus
1.59 kB
from ..builder import DETECTORS
from .two_stage import TwoStageDetector
@DETECTORS.register_module()
class CascadeRCNN(TwoStageDetector):
    r"""Implementation of `Cascade R-CNN: Delving into High Quality Object
    Detection <https://arxiv.org/abs/1712.00726>`_.

    See also the extended journal version, `Cascade R-CNN: High Quality
    Object Detection and Instance Segmentation
    <https://arxiv.org/abs/1906.09756>`_.

    This class adds no components of its own: construction is delegated to
    ``TwoStageDetector`` and only ``show_result`` is overridden, to unwrap
    dict-style results before drawing.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 rpn_head=None,
                 roi_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        # All arguments are forwarded by keyword, unchanged, to the parent.
        super().__init__(
            backbone=backbone,
            neck=neck,
            rpn_head=rpn_head,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)

    def show_result(self, data, result, **kwargs):
        """Show prediction results of the detector.

        If the results are dicts, only their ``'ensemble'`` entry is drawn.

        Args:
            data (str or np.ndarray): Image filename or loaded image.
            result (Tensor or tuple or dict): The results to draw over
                ``img``: ``bbox_result`` or ``(bbox_result, segm_result)``
                when the detector has a mask branch. Each part may instead
                be a dict, in which case its ``'ensemble'`` entry is used.
            **kwargs: Forwarded unchanged to the parent ``show_result``.

        Returns:
            np.ndarray: The image with bboxes drawn on it.
        """
        if self.with_mask:
            ms_bbox_result, ms_segm_result = result
            # Only the bbox part is type-checked; when it is a dict, the
            # segm part is assumed to be a dict with the same key as well.
            if isinstance(ms_bbox_result, dict):
                result = (ms_bbox_result['ensemble'],
                          ms_segm_result['ensemble'])
        elif isinstance(result, dict):
            result = result['ensemble']
        return super().show_result(data, result, **kwargs)