# Scraped file-viewer header (kept as a comment so the module parses):
# Author: Robert001 -- commit b334e29 ("first commit")
# raw | history | blame -- "No virus" -- 2.05 kB
from ..builder import DETECTORS
from .two_stage import TwoStageDetector
@DETECTORS.register_module()
class FastRCNN(TwoStageDetector):
    """Implementation of `Fast R-CNN <https://arxiv.org/abs/1504.08083>`_

    Unlike other two-stage detectors, the constructor accepts no RPN head:
    region proposals are supplied by the caller (see :meth:`forward_test`).
    """

    def __init__(self,
                 backbone,
                 roi_head,
                 train_cfg,
                 test_cfg,
                 neck=None,
                 pretrained=None):
        """Build the detector by forwarding every argument to the parent.

        Args:
            backbone: Backbone configuration, forwarded unchanged to
                ``TwoStageDetector``.
            roi_head: ROI head configuration, forwarded unchanged.
            train_cfg: Training configuration, forwarded unchanged.
            test_cfg: Testing configuration, forwarded unchanged.
            neck: Optional neck configuration. Defaults to None.
            pretrained: Optional pretrained-weights spec. Defaults to None.
        """
        super(FastRCNN, self).__init__(
            backbone=backbone,
            neck=neck,
            roi_head=roi_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained)

    def forward_test(self, imgs, img_metas, proposals, **kwargs):
        """Run inference on a batch of images using precomputed proposals.

        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains all images in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
            proposals (List[List[Tensor]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. The Tensor should have a shape Px4, where
                P is the number of proposals.
            **kwargs: Forwarded unchanged to ``self.simple_test``.

        Returns:
            The result of ``self.simple_test`` for the single augmentation.

        Raises:
            TypeError: If ``imgs`` or ``img_metas`` is not a list.
            ValueError: If ``imgs`` and ``img_metas`` differ in length, or
                if no augmentation is given at all.
            NotImplementedError: If more than one augmentation is given;
                test-time augmentation is not supported.
        """
        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError(f'{name} must be a list, but got {type(var)}')

        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(f'num of augmentations ({len(imgs)}) '
                             f'!= num of image meta ({len(img_metas)})')

        if num_augs == 0:
            # Without this guard, empty input would fall through to the TTA
            # branch and be misreported as an unsupported augmentation count.
            raise ValueError('imgs must contain at least one augmentation')

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], proposals[0],
                                    **kwargs)

        # TODO: support test-time augmentation
        # Must be ``raise``, not ``assert``: asserting the exception class is
        # always truthy, which would silently return None for TTA input.
        raise NotImplementedError(
            'FastRCNN does not support test-time augmentation yet '
            f'(got {num_augs} augmentations)')