Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from ..builder import HEADS | |
from .ascend_anchor_head import AscendAnchorHead | |
from .retina_head import RetinaHead | |
# NOTE(review): ``HEADS`` is imported at the top of this file but is not used
# anywhere in the visible source -- confirm whether this class is supposed to
# be registered with ``HEADS`` before relying on config-based lookup.
class AscendRetinaHead(RetinaHead, AscendAnchorHead):
    r"""RetinaNet anchor head combined with :class:`AscendAnchorHead`.

    This is the anchor-based head used in `RetinaNet
    <https://arxiv.org/pdf/1708.02002.pdf>`_. It consists of two
    subnetworks: one classifies anchor boxes, the other regresses deltas
    for the anchors.

    Construction is forwarded unchanged to the parent constructor, and
    :meth:`get_targets` is dispatched explicitly to
    ``AscendAnchorHead.get_targets``.

    The example below instantiates the parent ``RetinaHead``; its
    ``forward_single`` is not overridden in this class.

    Example:
        >>> import torch
        >>> self = RetinaHead(11, 7)
        >>> x = torch.rand(1, 7, 32, 32)
        >>> cls_score, bbox_pred = self.forward_single(x)
        >>> # Each anchor predicts a score for each class except background
        >>> cls_per_anchor = cls_score.shape[1] / self.num_anchors
        >>> box_per_anchor = bbox_pred.shape[1] / self.num_anchors
        >>> assert cls_per_anchor == (self.num_classes)
        >>> assert box_per_anchor == 4
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 conv_cfg=None,
                 norm_cfg=None,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     octave_base_scale=4,
                     scales_per_octave=3,
                     ratios=[0.5, 1.0, 2.0],
                     strides=[8, 16, 32, 64, 128]),
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='retina_cls',
                         std=0.01,
                         bias_prob=0.01)),
                 **kwargs):
        """Build the head by passing every argument on to the parent class.

        No argument is inspected or modified here; each one is handed to
        the next ``__init__`` in the MRO under the same keyword.

        Args:
            num_classes: Forwarded as ``num_classes``.
            in_channels: Forwarded as ``in_channels``.
            stacked_convs: Forwarded as ``stacked_convs``. Defaults to 4.
            conv_cfg: Forwarded as ``conv_cfg``. Defaults to None.
            norm_cfg: Forwarded as ``norm_cfg``. Defaults to None.
            anchor_generator (dict): Anchor generator config, forwarded as
                ``anchor_generator``.
            init_cfg (dict): Initialization config, forwarded as
                ``init_cfg``.
            **kwargs: Any further keyword arguments, forwarded untouched.
        """
        # Zero-argument ``super()`` resolves to the same parent as the
        # explicit ``super(AscendRetinaHead, self)`` form.
        super().__init__(num_classes=num_classes,
                         in_channels=in_channels,
                         stacked_convs=stacked_convs,
                         conv_cfg=conv_cfg,
                         norm_cfg=norm_cfg,
                         anchor_generator=anchor_generator,
                         init_cfg=init_cfg,
                         **kwargs)

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True,
                    return_sampling_results=False,
                    return_level=True):
        """Compute regression and classification targets for anchors in
        multiple images.

        The work is done by ``AscendAnchorHead.get_targets``; this method
        only forwards its arguments, in order, and returns the result.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indexes images and the inner list
                indexes feature levels. Each inner element is a tensor of
                shape (num_anchors, 4).
            valid_flag_list (list[list[Tensor]]): Multi level valid flags
                of each image, nested like ``anchor_list``. Each inner
                element is a tensor of shape (num_anchors, ).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each
                image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the
                original set of anchors.
            return_sampling_results (bool): Whether to return the result of
                sample.
            return_level (bool): Whether to map outputs back to the levels
                of feature map sizes.

        Returns:
            tuple: Usually a tuple containing learning targets.

                - labels_list (list[Tensor]): Labels of each level.
                - label_weights_list (list[Tensor]): Label weights of each
                  level.
                - bbox_targets_list (list[Tensor]): BBox targets of each
                  level.
                - bbox_weights_list (list[Tensor]): BBox weights of each
                  level.
                - num_total_pos (int): Number of positive samples in all
                  images.
                - num_total_neg (int): Number of negative samples in all
                  images.

            additional_returns: User-defined returns from
                ``self._get_targets_single``. These are currently refined
                to properties at each feature map (i.e. having HxW
                dimension) and are concatenated after the end.
        """
        # Kept positional, in the original order, so nothing is assumed
        # about the keyword names of the delegate's signature.
        forwarded = (anchor_list, valid_flag_list, gt_bboxes_list, img_metas,
                     gt_bboxes_ignore_list, gt_labels_list, label_channels,
                     unmap_outputs, return_sampling_results, return_level)
        return AscendAnchorHead.get_targets(self, *forwarded)