# Anchor-free dense detection head base class (mmdetection-style).
# Subclasses (FCOS, Fovea, RepPoints, ...) implement loss/get_bboxes/get_targets.
from abc import abstractmethod | |
import torch | |
import torch.nn as nn | |
from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init | |
from mmcv.runner import force_fp32 | |
from mmdet.core import multi_apply | |
from ..builder import HEADS, build_loss | |
from .base_dense_head import BaseDenseHead | |
from .dense_test_mixins import BBoxTestMixin | |
# NOTE(review): upstream mmdet registers this class with
# `@HEADS.register_module()` (the `HEADS` import is otherwise unused here) —
# confirm registration happens elsewhere before adding it back.
class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
    """Anchor-free head (FCOS, Fovea, RepPoints, etc.).

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
        stacked_convs (int): Number of stacking convs of the head.
        strides (tuple): Downsample factor of each feature map.
        dcn_on_last_conv (bool): If true, use dcn in the last layer of
            towers. Default: False.
        conv_bias (bool | str): If specified as `auto`, it will be decided by
            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
            None, otherwise False. Default: "auto".
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        train_cfg (dict): Training config of anchor head.
        test_cfg (dict): Testing config of anchor head.
    """  # noqa: W605

    # Checkpoint format version; bumped when state-dict keys change so that
    # `_load_from_state_dict` can translate old checkpoints.
    _version = 1

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 dcn_on_last_conv=False,
                 conv_bias='auto',
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox=dict(type='IoULoss', loss_weight=1.0),
                 conv_cfg=None,
                 norm_cfg=None,
                 train_cfg=None,
                 test_cfg=None):
        super(AnchorFreeHead, self).__init__()
        self.num_classes = num_classes
        # Sigmoid-based classification: no extra background channel.
        self.cls_out_channels = num_classes
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.dcn_on_last_conv = dcn_on_last_conv
        assert conv_bias == 'auto' or isinstance(conv_bias, bool)
        self.conv_bias = conv_bias
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        # Toggled by mmcv's fp16 hooks; read by @force_fp32 decorators.
        self.fp16_enabled = False
        self._init_layers()

    def _init_layers(self):
        """Initialize layers of the head."""
        self._init_cls_convs()
        self._init_reg_convs()
        self._init_predictor()

    def _init_cls_convs(self):
        """Initialize classification conv layers of the head."""
        self.cls_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            # Optionally replace the last tower conv with a deformable conv.
            if self.dcn_on_last_conv and i == self.stacked_convs - 1:
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))

    def _init_reg_convs(self):
        """Initialize bbox regression conv layers of the head."""
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            # Optionally replace the last tower conv with a deformable conv.
            if self.dcn_on_last_conv and i == self.stacked_convs - 1:
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))

    def _init_predictor(self):
        """Initialize predictor layers of the head."""
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
        self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)

    def init_weights(self):
        """Initialize weights of the head."""
        for m in self.cls_convs:
            if isinstance(m.conv, nn.Conv2d):
                normal_init(m.conv, std=0.01)
        for m in self.reg_convs:
            if isinstance(m.conv, nn.Conv2d):
                normal_init(m.conv, std=0.01)
        # Bias chosen so the initial classification prior is ~0.01,
        # which stabilizes focal-loss training.
        bias_cls = bias_init_with_prob(0.01)
        normal_init(self.conv_cls, std=0.01, bias=bias_cls)
        normal_init(self.conv_reg, std=0.01)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Hack some keys of the model state dict so that can load checkpoints
        of previous version."""
        version = local_metadata.get('version', None)
        if version is None:
            # the key is different in early versions
            # for example, 'fcos_cls' become 'conv_cls' now
            bbox_head_keys = [
                k for k in state_dict.keys() if k.startswith(prefix)
            ]
            ori_predictor_keys = []
            new_predictor_keys = []
            # e.g. 'fcos_cls' or 'fcos_reg'
            for key in bbox_head_keys:
                ori_predictor_keys.append(key)
                key = key.split('.')
                conv_name = None
                if key[1].endswith('cls'):
                    conv_name = 'conv_cls'
                elif key[1].endswith('reg'):
                    conv_name = 'conv_reg'
                elif key[1].endswith('centerness'):
                    conv_name = 'conv_centerness'
                # Keys matching none of the suffixes (e.g. 'cls_convs.*') are
                # intentionally left untouched. The previous code had
                # `assert NotImplementedError` here, which asserts a truthy
                # class object and therefore never fired; it has been removed
                # as a silent no-op.
                if conv_name is not None:
                    key[1] = conv_name
                    new_predictor_keys.append('.'.join(key))
                else:
                    ori_predictor_keys.pop(-1)
            for i in range(len(new_predictor_keys)):
                state_dict[new_predictor_keys[i]] = state_dict.pop(
                    ori_predictor_keys[i])
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually contain classification scores and bbox predictions.
                cls_scores (list[Tensor]): Box scores for each scale level,
                    each is a 4D-tensor, the channel number is
                    num_points * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for each scale
                    level, each is a 4D-tensor, the channel number is
                    num_points * 4.
        """
        # forward_single also returns cls_feat/reg_feat; only the first two
        # outputs (scores, bbox preds) are exposed here.
        return multi_apply(self.forward_single, feats)[:2]

    def forward_single(self, x):
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.

        Returns:
            tuple: Scores for each class, bbox predictions, features
                after classification and regression conv layers, some
                models needs these features like FCOS.
        """
        cls_feat = x
        reg_feat = x
        for cls_layer in self.cls_convs:
            cls_feat = cls_layer(cls_feat)
        cls_score = self.conv_cls(cls_feat)
        for reg_layer in self.reg_convs:
            reg_feat = reg_layer(reg_feat)
        bbox_pred = self.conv_reg(reg_feat)
        return cls_score, bbox_pred, cls_feat, reg_feat

    # Restored decorators: `abstractmethod` and `force_fp32` were imported
    # but unused because these decorators had been stripped. `force_fp32`
    # casts the listed tensor args back to fp32 when fp16 training is on;
    # it is a transparent pass-through otherwise.
    @abstractmethod
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
        """
        raise NotImplementedError

    @abstractmethod
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg=None,
                   rescale=None):
        """Transform network output for a batch into bbox predictions.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_points * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_points * 4, H, W)
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used
            rescale (bool): If True, return boxes in original image space
        """
        raise NotImplementedError

    @abstractmethod
    def get_targets(self, points, gt_bboxes_list, gt_labels_list):
        """Compute regression, classification and centerness targets for points
        in multiple images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
                each has shape (num_gt, 4).
            gt_labels_list (list[Tensor]): Ground truth labels of each box,
                each has shape (num_gt,).
        """
        raise NotImplementedError

    def _get_points_single(self,
                           featmap_size,
                           stride,
                           dtype,
                           device,
                           flatten=False):
        """Get points of a single scale level.

        Returns raw (y, x) grid coordinates; subclasses are expected to
        apply the stride / offset themselves.
        """
        h, w = featmap_size
        x_range = torch.arange(w, dtype=dtype, device=device)
        y_range = torch.arange(h, dtype=dtype, device=device)
        y, x = torch.meshgrid(y_range, x_range)
        if flatten:
            y = y.flatten()
            x = x.flatten()
        return y, x

    def get_points(self, featmap_sizes, dtype, device, flatten=False):
        """Get points according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            dtype (torch.dtype): Type of points.
            device (torch.device): Device of points.

        Returns:
            tuple: points of each image.
        """
        mlvl_points = []
        for i in range(len(featmap_sizes)):
            mlvl_points.append(
                self._get_points_single(featmap_sizes[i], self.strides[i],
                                        dtype, device, flatten))
        return mlvl_points

    def aug_test(self, feats, img_metas, rescale=False):
        """Test function with test time augmentation.

        Args:
            feats (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains features for all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. each dict has image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[ndarray]: bbox results of each class
        """
        return self.aug_test_bboxes(feats, img_metas, rescale=rescale)