Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Tuple | |
import torch | |
import torch.nn as nn | |
from mmcv.cnn import ConvModule, is_norm | |
from mmdet.models.task_modules.samplers import PseudoSampler | |
from mmdet.structures.bbox import distance2bbox | |
from mmdet.utils import (ConfigType, InstanceList, OptConfigType, | |
OptInstanceList, OptMultiConfig, reduce_mean) | |
from mmengine.model import (BaseModule, bias_init_with_prob, constant_init, | |
normal_init) | |
from torch import Tensor | |
from mmyolo.registry import MODELS, TASK_UTILS | |
from ..utils import gt_instances_preprocess | |
from .yolov5_head import YOLOv5Head | |
class RTMDetSepBNHeadModule(BaseModule):
    """Detection Head of RTMDet.

    For every feature level the module builds a classification tower and a
    regression tower of ``stacked_convs`` :class:`ConvModule` layers, each
    followed by a plain ``nn.Conv2d`` prediction layer. When ``share_conv``
    is enabled, only the ``.conv`` sub-module of each tower layer is shared
    across levels; every other part of the per-level ``ConvModule`` stays
    separate.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map,
            before scaling by ``widen_factor``.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. It is applied to both
            ``in_channels`` and ``feat_channels``. Defaults to 1.0.
        num_base_priors (int): The number of priors (points) at a point
            on the feature grid. Defaults to 1.
        feat_channels (int): Number of hidden channels, before scaling by
            ``widen_factor``. Defaults to 256.
        stacked_convs (int): Number of stacking convs of the head.
            Defaults to 2.
        featmap_strides (Sequence[int]): Downsample factor of each feature
            map. Its length determines how many per-level branches are
            built. Defaults to (8, 16, 32).
        share_conv (bool): Whether to share conv layers between stages.
            Defaults to True.
        pred_kernel_size (int): Kernel size of ``nn.Conv2d``. Defaults to 1.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to ``dict(type='BN')``.
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
            Default: dict(type='SiLU', inplace=True).
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        widen_factor: float = 1.0,
        num_base_priors: int = 1,
        feat_channels: int = 256,
        stacked_convs: int = 2,
        # Immutable default: avoids one list object being shared by every
        # instance, and matches the documented default.
        featmap_strides: Sequence[int] = (8, 16, 32),
        share_conv: bool = True,
        pred_kernel_size: int = 1,
        conv_cfg: OptConfigType = None,
        norm_cfg: ConfigType = dict(type='BN'),
        act_cfg: ConfigType = dict(type='SiLU', inplace=True),
        init_cfg: OptMultiConfig = None,
    ):
        super().__init__(init_cfg=init_cfg)
        self.share_conv = share_conv
        self.num_classes = num_classes
        self.pred_kernel_size = pred_kernel_size
        self.feat_channels = int(feat_channels * widen_factor)
        self.stacked_convs = stacked_convs
        self.num_base_priors = num_base_priors

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.featmap_strides = featmap_strides

        self.in_channels = int(in_channels * widen_factor)

        self._init_layers()

    def _init_layers(self):
        """Initialize layers of the head.

        Creates ``cls_convs`` / ``reg_convs`` (one ``nn.ModuleList`` of
        ``ConvModule`` per level) and ``rtm_cls`` / ``rtm_reg`` (one
        prediction ``nn.Conv2d`` per level).
        """
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()

        self.rtm_cls = nn.ModuleList()
        self.rtm_reg = nn.ModuleList()

        num_levels = len(self.featmap_strides)
        pred_padding = self.pred_kernel_size // 2

        for _ in range(num_levels):
            cls_convs = nn.ModuleList()
            reg_convs = nn.ModuleList()
            for i in range(self.stacked_convs):
                # Only the first conv of a tower sees the input width.
                chn = self.in_channels if i == 0 else self.feat_channels
                cls_convs.append(
                    ConvModule(
                        chn,
                        self.feat_channels,
                        3,
                        stride=1,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
                reg_convs.append(
                    ConvModule(
                        chn,
                        self.feat_channels,
                        3,
                        stride=1,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
            self.cls_convs.append(cls_convs)
            self.reg_convs.append(reg_convs)

            self.rtm_cls.append(
                nn.Conv2d(
                    self.feat_channels,
                    self.num_base_priors * self.num_classes,
                    self.pred_kernel_size,
                    padding=pred_padding))
            self.rtm_reg.append(
                nn.Conv2d(
                    self.feat_channels,
                    self.num_base_priors * 4,
                    self.pred_kernel_size,
                    padding=pred_padding))

        if self.share_conv:
            # Point the ``.conv`` of every later level at level 0's conv so
            # the convolution weights are shared. Level 0 already owns them,
            # so the loop starts at 1.
            for n in range(1, num_levels):
                for i in range(self.stacked_convs):
                    self.cls_convs[n][i].conv = self.cls_convs[0][i].conv
                    self.reg_convs[n][i].conv = self.reg_convs[0][i].conv

    def init_weights(self) -> None:
        """Initialize weights of the head."""
        super().init_weights()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, mean=0, std=0.01)
            if is_norm(m):
                constant_init(m, 1)
        # Use prior in model initialization to improve stability
        bias_cls = bias_init_with_prob(0.01)
        # The prediction layers are re-initialised after the generic pass so
        # that the classification bias carries the prior.
        for rtm_cls, rtm_reg in zip(self.rtm_cls, self.rtm_reg):
            normal_init(rtm_cls, std=0.01, bias=bias_cls)
            normal_init(rtm_reg, std=0.01)

    def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor. Level ``idx`` is routed through the ``idx``-th
                branch, so there must not be more levels than
                ``len(featmap_strides)``.

        Returns:
            tuple: Usually a tuple of classification scores and bbox prediction

            - cls_scores (tuple[Tensor]): Classification scores for all scale
              levels, each is a 4D-tensor, the channels number is
              num_base_priors * num_classes.
            - bbox_preds (tuple[Tensor]): Box energies / deltas for all scale
              levels, each is a 4D-tensor, the channels number is
              num_base_priors * 4.
        """
        cls_scores = []
        bbox_preds = []
        for idx, x in enumerate(feats):
            cls_feat = x
            reg_feat = x

            for cls_layer in self.cls_convs[idx]:
                cls_feat = cls_layer(cls_feat)
            cls_score = self.rtm_cls[idx](cls_feat)

            for reg_layer in self.reg_convs[idx]:
                reg_feat = reg_layer(reg_feat)
            reg_dist = self.rtm_reg[idx](reg_feat)

            cls_scores.append(cls_score)
            bbox_preds.append(reg_dist)
        return tuple(cls_scores), tuple(bbox_preds)
class RTMDetHead(YOLOv5Head):
    """RTMDet head.

    Wraps a head module (which produces per-level classification scores and
    bbox predictions) and computes the classification and box losses from
    them. No objectness loss is used.

    Args:
        head_module(ConfigType): Base module used for RTMDetHead
        prior_generator: Points generator feature maps in
            2D points-based detectors.
        bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            anchor head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            anchor head. Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 head_module: ConfigType,
                 prior_generator: ConfigType = dict(
                     type='mmdet.MlvlPointGenerator',
                     offset=0,
                     strides=[8, 16, 32]),
                 bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
                 loss_cls: ConfigType = dict(
                     type='mmdet.QualityFocalLoss',
                     use_sigmoid=True,
                     beta=2.0,
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='mmdet.GIoULoss', loss_weight=2.0),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):

        super().__init__(
            head_module=head_module,
            prior_generator=prior_generator,
            bbox_coder=bbox_coder,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)

        # A sigmoid classifier has no explicit background channel; otherwise
        # one extra channel is reserved for it.
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        if self.use_sigmoid_cls:
            self.cls_out_channels = self.num_classes
        else:
            self.cls_out_channels = self.num_classes + 1
        # rtmdet doesn't need loss_obj
        self.loss_obj = None

    def special_init(self):
        """Since YOLO series algorithms will inherit from YOLOv5Head, but
        different algorithms have special initialization process.

        The special_init function is designed to deal with this situation.
        When a training config is present, this builds the assigner and the
        sampler (falling back to ``PseudoSampler`` if none is configured) and
        resets the cached training priors.
        """
        if self.train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
            if self.train_cfg.get('sampler', None) is not None:
                self.sampler = TASK_UTILS.build(
                    self.train_cfg.sampler, default_args=dict(context=self))
            else:
                self.sampler = PseudoSampler(context=self)

            # Cache used by ``loss_by_feat`` so priors are regenerated only
            # when the feature-map sizes change.
            self.featmap_sizes_train = None
            self.flatten_priors_train = None

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        """Forward features from the upstream network.

        Args:
            x (Tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            Tuple[List]: The output of ``self.head_module(x)``, i.e. the
            multi-level classification scores and bbox predictions.
        """
        return self.head_module(x)

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Raw box predictions for each scale
                level with shape (N, num_anchors * 4, H, W). They are
                scaled per prior and decoded into boxes with
                ``distance2bbox`` inside this method.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Not used by this method. Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components with keys
            ``loss_cls`` and ``loss_bbox``.
        """
        num_imgs = len(batch_img_metas)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
        gt_labels = gt_info[:, :, :1]
        gt_bboxes = gt_info[:, :, 1:]  # xyxy
        # Rows whose box coordinates sum to a positive value are flagged as
        # real ground truth. Presumably the remaining rows are zero padding
        # added by ``gt_instances_preprocess`` - TODO confirm.
        pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()

        device = cls_scores[0].device

        # If the shape does not equal, generate new one
        if featmap_sizes != self.featmap_sizes_train:
            self.featmap_sizes_train = featmap_sizes
            mlvl_priors_with_stride = self.prior_generator.grid_priors(
                featmap_sizes, device=device, with_stride=True)
            self.flatten_priors_train = torch.cat(
                mlvl_priors_with_stride, dim=0)

        # (N, C, H, W) per level -> (N, sum(H * W * priors), C)
        flatten_cls_scores = torch.cat([
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.cls_out_channels)
            for cls_score in cls_scores
        ], 1).contiguous()

        flatten_bboxes = torch.cat([
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ], 1)
        # Scale by the last prior column (priors are requested with
        # ``with_stride=True``; assumed to be the stride - TODO confirm),
        # then decode around the prior centre.
        flatten_bboxes = flatten_bboxes * self.flatten_priors_train[..., -1,
                                                                    None]
        flatten_bboxes = distance2bbox(self.flatten_priors_train[..., :2],
                                       flatten_bboxes)

        # Assignment works on detached predictions so it does not contribute
        # gradients.
        assigned_result = self.assigner(flatten_bboxes.detach(),
                                        flatten_cls_scores.detach(),
                                        self.flatten_priors_train, gt_labels,
                                        gt_bboxes, pad_bbox_flag)

        labels = assigned_result['assigned_labels'].reshape(-1)
        label_weights = assigned_result['assigned_labels_weights'].reshape(-1)
        bbox_targets = assigned_result['assigned_bboxes'].reshape(-1, 4)
        assign_metrics = assigned_result['assign_metrics'].reshape(-1)

        # Flatten with the same channel count used to build
        # ``flatten_cls_scores`` so that rows stay aligned with ``labels``.
        cls_preds = flatten_cls_scores.reshape(-1, self.cls_out_channels)
        flat_bbox_preds = flatten_bboxes.reshape(-1, 4)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)
        # Normalise by the summed assignment metric, floored at 1.
        avg_factor = reduce_mean(assign_metrics.sum()).clamp_(min=1).item()

        loss_cls = self.loss_cls(
            cls_preds, (labels, assign_metrics),
            label_weights,
            avg_factor=avg_factor)

        if pos_inds.numel() > 0:
            loss_bbox = self.loss_bbox(
                flat_bbox_preds[pos_inds],
                bbox_targets[pos_inds],
                weight=assign_metrics[pos_inds],
                avg_factor=avg_factor)
        else:
            # No positive samples: return a zero that is still connected to
            # the predictions.
            loss_bbox = flat_bbox_preds.sum() * 0

        return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)