# [web-capture residue: "Spaces: Runtime error / Runtime error"] -- not part of the module source.
# Copyright (c) OpenMMLab. All rights reserved.
import torch | |
import torch.nn as nn | |
from mmcv.runner import BaseModule | |
from mmdet.models.builder import HEADS | |
from ...core import bbox_cxcywh_to_xyxy | |
class EmbeddingRPNHead(BaseModule):
    """RPNHead in the `Sparse R-CNN <https://arxiv.org/abs/2011.12450>`_ .

    Unlike traditional RPNHead, this module does not need FPN input, but just
    decodes the learned `init_proposal_bboxes` to image scale and expands
    `init_proposal_bboxes` and `init_proposal_features` along a new leading
    batch dimension.

    Args:
        num_proposals (int): Number of init_proposals. Default 100.
        proposal_feature_channel (int): Channel number of
            init_proposal_feature. Defaults to 256.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Must be left as None (any other value trips an assertion).
            Default: None
        **kwargs: Accepted but not used by this class.
    """

    # NOTE(review): `HEADS` is imported by this file, but no
    # `@HEADS.register_module()` decorator is visible on this class --
    # confirm whether the registration line was lost.

    def __init__(self,
                 num_proposals=100,
                 proposal_feature_channel=256,
                 init_cfg=None,
                 **kwargs):
        # Kept as an assert (not a raise) so the exception type seen by
        # existing callers stays AssertionError.
        assert init_cfg is None, 'To prevent abnormal initialization ' \
                                 'behavior, init_cfg is not allowed to be set'
        super(EmbeddingRPNHead, self).__init__(init_cfg)
        self.num_proposals = num_proposals
        self.proposal_feature_channel = proposal_feature_channel
        self._init_layers()

    def _init_layers(self):
        """Initialize a sparse set of proposal boxes and proposal features.

        Both are stored as embedding tables with `num_proposals` rows: 4
        columns for the boxes and `proposal_feature_channel` columns for the
        features.
        """
        self.init_proposal_bboxes = nn.Embedding(self.num_proposals, 4)
        self.init_proposal_features = nn.Embedding(
            self.num_proposals, self.proposal_feature_channel)

    def init_weights(self):
        """Initialize the init_proposal_bboxes as normalized
        [c_x, c_y, w, h], covering the entire image.

        After the base-class initialization, the first two columns (center)
        are set to 0.5 and the remaining two columns (width, height) to 1.
        """
        super(EmbeddingRPNHead, self).init_weights()
        nn.init.constant_(self.init_proposal_bboxes.weight[:, :2], 0.5)
        nn.init.constant_(self.init_proposal_bboxes.weight[:, 2:], 1)

    def _decode_init_proposals(self, imgs, img_metas):
        """Decode init_proposal_bboxes according to the size of images and
        expand dimension of init_proposal_features to batch_size.

        Args:
            imgs (list[Tensor]): List of FPN features. Only the first entry
                is used: its length gives the batch size and it supplies the
                dtype/device for the image-scale tensor.
            img_metas (list[dict]): List of meta-information of
                images. Need the img_shape to decode the init_proposals;
                its first two entries are read as (height, width).

        Returns:
            Tuple(Tensor):

                - proposals (Tensor): Decoded proposal bboxes,
                  has shape (batch_size, num_proposals, 4).
                - init_proposal_features (Tensor): Expanded proposal
                  features, has shape
                  (batch_size, num_proposals, proposal_feature_channel).
                - imgs_whwh (Tensor): Tensor with shape
                  (batch_size, 1, 4), the last dimension means
                  [img_width, img_height, img_width, img_height].
        """
        first_level = imgs[0]
        num_imgs = len(first_level)

        # Learned boxes are kept as [c_x, c_y, w, h]; convert them to corner
        # format before scaling.
        proposals = bbox_cxcywh_to_xyxy(
            self.init_proposal_bboxes.weight.clone())

        # Only height and width are needed, so slice instead of unpacking a
        # fixed 3-tuple: this accepts both (h, w, c) and (h, w) shapes.
        whwh_rows = []
        for meta in img_metas:
            h, w = meta['img_shape'][:2]
            whwh_rows.append([w, h, w, h])

        # Build the (batch_size, 4) scale tensor in one call, then insert a
        # singleton axis -> (batch_size, 1, 4) so it broadcasts against the
        # (num_proposals, 4) boxes.
        imgs_whwh = first_level.new_tensor(whwh_rows)[:, None, :]

        # (num_proposals, 4) * (batch_size, 1, 4)
        #   -> (batch_size, num_proposals, 4)
        proposals = proposals * imgs_whwh

        # Add a leading axis and expand it to the batch size.
        init_proposal_features = self.init_proposal_features.weight.clone()
        init_proposal_features = init_proposal_features[None].expand(
            num_imgs, *init_proposal_features.size())
        return proposals, init_proposal_features, imgs_whwh

    def forward_dummy(self, img, img_metas):
        """Dummy forward function.

        Used in flops calculation. Delegates to `_decode_init_proposals`.
        """
        return self._decode_init_proposals(img, img_metas)

    def forward_train(self, img, img_metas):
        """Forward function in training stage.

        Delegates to `_decode_init_proposals`.
        """
        return self._decode_init_proposals(img, img_metas)

    def simple_test_rpn(self, img, img_metas):
        """Forward function in testing stage.

        Delegates to `_decode_init_proposals`.
        """
        return self._decode_init_proposals(img, img_metas)

    def simple_test(self, img, img_metas):
        """Forward function in testing stage.

        Not implemented for this head; always raises NotImplementedError.
        """
        raise NotImplementedError

    def aug_test_rpn(self, feats, img_metas):
        """Test-time augmentation entry point.

        Not supported by this head; always raises NotImplementedError.
        """
        raise NotImplementedError(
            'EmbeddingRPNHead does not support test-time augmentation')