KyanChen's picture
Upload 89 files
3094730
raw
history blame
15.1 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Tuple
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, is_norm
from mmdet.models.task_modules.samplers import PseudoSampler
from mmdet.structures.bbox import distance2bbox
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
OptInstanceList, OptMultiConfig, reduce_mean)
from mmengine.model import (BaseModule, bias_init_with_prob, constant_init,
normal_init)
from torch import Tensor
from mmyolo.registry import MODELS, TASK_UTILS
from ..utils import gt_instances_preprocess
from .yolov5_head import YOLOv5Head
@MODELS.register_module()
class RTMDetSepBNHeadModule(BaseModule):
    """Detection Head of RTMDet.

    Every feature level gets its own stack of classification / regression
    ``ConvModule`` layers and its own pair of prediction convolutions. When
    ``share_conv`` is True, only the ``conv`` of each stacked ``ConvModule``
    is shared across levels; the remaining parts of each ``ConvModule``
    (e.g. its norm layer) stay separate per level.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
            Multiplied by ``widen_factor``.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        num_base_priors (int): The number of priors (points) at a point
            on the feature grid. Defaults to 1.
        feat_channels (int): Number of hidden channels. Multiplied by
            ``widen_factor``. Defaults to 256.
        stacked_convs (int): Number of stacking convs of the head.
            Defaults to 2.
        featmap_strides (Sequence[int]): Downsample factor of each feature
            map. Within this class only its length (the number of levels)
            is used. Defaults to (8, 16, 32).
        share_conv (bool): Whether to share the conv weights of the stacked
            layers between feature levels. Defaults to True.
        pred_kernel_size (int): Kernel size of the prediction
            ``nn.Conv2d`` layers. Defaults to 1.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to ``dict(type='BN')``.
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
            Defaults to ``dict(type='SiLU', inplace=True)``.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        widen_factor: float = 1.0,
        num_base_priors: int = 1,
        feat_channels: int = 256,
        stacked_convs: int = 2,
        # Immutable default: a list default would be shared between calls.
        featmap_strides: Sequence[int] = (8, 16, 32),
        share_conv: bool = True,
        pred_kernel_size: int = 1,
        conv_cfg: OptConfigType = None,
        norm_cfg: ConfigType = dict(type='BN'),
        act_cfg: ConfigType = dict(type='SiLU', inplace=True),
        init_cfg: OptMultiConfig = None,
    ):
        super().__init__(init_cfg=init_cfg)
        self.share_conv = share_conv
        self.num_classes = num_classes
        self.pred_kernel_size = pred_kernel_size
        self.feat_channels = int(feat_channels * widen_factor)
        self.stacked_convs = stacked_convs
        self.num_base_priors = num_base_priors
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.featmap_strides = featmap_strides
        self.in_channels = int(in_channels * widen_factor)
        self._init_layers()

    def _init_layers(self):
        """Initialize layers of the head.

        Builds, for each feature level, ``stacked_convs`` cls/reg
        ``ConvModule`` layers plus one cls and one reg prediction conv, then
        optionally ties the stacked conv weights of every level to level 0.
        """
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.rtm_cls = nn.ModuleList()
        self.rtm_reg = nn.ModuleList()
        num_levels = len(self.featmap_strides)
        for _ in range(num_levels):
            cls_convs = nn.ModuleList()
            reg_convs = nn.ModuleList()
            for i in range(self.stacked_convs):
                # Only the first layer of a stack sees the input channels.
                chn = self.in_channels if i == 0 else self.feat_channels
                cls_convs.append(
                    ConvModule(
                        chn,
                        self.feat_channels,
                        3,
                        stride=1,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
                reg_convs.append(
                    ConvModule(
                        chn,
                        self.feat_channels,
                        3,
                        stride=1,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
            self.cls_convs.append(cls_convs)
            self.reg_convs.append(reg_convs)

            # ``padding = k // 2`` keeps the spatial size for odd kernels.
            self.rtm_cls.append(
                nn.Conv2d(
                    self.feat_channels,
                    self.num_base_priors * self.num_classes,
                    self.pred_kernel_size,
                    padding=self.pred_kernel_size // 2))
            self.rtm_reg.append(
                nn.Conv2d(
                    self.feat_channels,
                    self.num_base_priors * 4,
                    self.pred_kernel_size,
                    padding=self.pred_kernel_size // 2))

        if self.share_conv:
            # Replace only the ``conv`` sub-module so every level reuses
            # level 0's conv weights while keeping its own norm/activation.
            # Level 0 is the source, so start from level 1.
            for n in range(1, num_levels):
                for i in range(self.stacked_convs):
                    self.cls_convs[n][i].conv = self.cls_convs[0][i].conv
                    self.reg_convs[n][i].conv = self.reg_convs[0][i].conv

    def init_weights(self) -> None:
        """Initialize weights of the head.

        All ``nn.Conv2d`` layers get ``normal_init(std=0.01)``, norm layers
        get ``constant_init(1)``, and the classification prediction convs get
        a bias derived from a prior probability of 0.01.
        """
        # Use prior in model initialization to improve stability
        super().init_weights()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, mean=0, std=0.01)
            if is_norm(m):
                constant_init(m, 1)
        bias_cls = bias_init_with_prob(0.01)
        for rtm_cls, rtm_reg in zip(self.rtm_cls, self.rtm_reg):
            normal_init(rtm_cls, std=0.01, bias=bias_cls)
            normal_init(rtm_reg, std=0.01)

    def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor. The i-th feature is processed by the i-th level's
                layers.

        Returns:
            tuple: A tuple of classification scores and bbox predictions.

            - cls_scores (tuple[Tensor]): Classification scores for all scale
              levels, each is a 4D-tensor, the channels number is
              num_base_priors * num_classes.
            - bbox_preds (tuple[Tensor]): Box energies / deltas for all scale
              levels, each is a 4D-tensor, the channels number is
              num_base_priors * 4.
        """
        cls_scores = []
        bbox_preds = []
        for idx, x in enumerate(feats):
            cls_feat = x
            reg_feat = x

            for cls_layer in self.cls_convs[idx]:
                cls_feat = cls_layer(cls_feat)
            cls_score = self.rtm_cls[idx](cls_feat)

            for reg_layer in self.reg_convs[idx]:
                reg_feat = reg_layer(reg_feat)
            reg_dist = self.rtm_reg[idx](reg_feat)

            cls_scores.append(cls_score)
            bbox_preds.append(reg_dist)
        return tuple(cls_scores), tuple(bbox_preds)
@MODELS.register_module()
class RTMDetHead(YOLOv5Head):
    """RTMDet head.

    Args:
        head_module (ConfigType): Base module used for RTMDetHead.
        prior_generator (ConfigType): Points generator feature maps in
            2D points-based detectors.
        bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            anchor head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            anchor head. Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 head_module: ConfigType,
                 prior_generator: ConfigType = dict(
                     type='mmdet.MlvlPointGenerator',
                     offset=0,
                     strides=[8, 16, 32]),
                 bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
                 loss_cls: ConfigType = dict(
                     type='mmdet.QualityFocalLoss',
                     use_sigmoid=True,
                     beta=2.0,
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='mmdet.GIoULoss', loss_weight=2.0),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            head_module=head_module,
            prior_generator=prior_generator,
            bbox_coder=bbox_coder,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)

        # A sigmoid classifier needs no extra background channel.
        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        if self.use_sigmoid_cls:
            self.cls_out_channels = self.num_classes
        else:
            self.cls_out_channels = self.num_classes + 1
        # rtmdet doesn't need loss_obj
        self.loss_obj = None

    def special_init(self):
        """Since YOLO series algorithms will inherit from YOLOv5Head, but
        different algorithms have special initialization process.

        The special_init function is designed to deal with this situation.
        Here it builds the assigner and sampler when ``train_cfg`` is given,
        and resets the prior cache used by :meth:`loss_by_feat`.
        """
        if self.train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
            if self.train_cfg.get('sampler', None) is not None:
                self.sampler = TASK_UTILS.build(
                    self.train_cfg.sampler, default_args=dict(context=self))
            else:
                self.sampler = PseudoSampler(context=self)

        # Prior cache for loss_by_feat. Defined unconditionally so the
        # attributes always exist, even when no train_cfg is provided.
        self.featmap_sizes_train = None
        self.flatten_priors_train = None

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        """Forward features from the upstream network.

        Args:
            x (Tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            Tuple[List]: Whatever ``self.head_module`` returns. For
            ``RTMDetSepBNHeadModule`` this is a tuple of multi-level
            classification scores and bbox predictions (no objectness).
        """
        return self.head_module(x)

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box distance predictions for each scale
                level with shape (N, num_anchors * 4, H, W). They are scaled
                by the prior stride and decoded into boxes around the prior
                points with ``distance2bbox``.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing. Not used by
                this implementation. Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary with ``loss_cls`` and
            ``loss_bbox``.
        """
        num_imgs = len(batch_img_metas)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        # Padded per-image GT tensor: column 0 is the label, the rest xyxy.
        gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
        gt_labels = gt_info[:, :, :1]
        gt_bboxes = gt_info[:, :, 1:]  # xyxy
        # Rows whose box coordinates sum to 0 are treated as padding.
        pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()

        device = cls_scores[0].device

        # Regenerate the cached priors when the feature map sizes change or
        # when the cache lives on a different device than the predictions
        # (e.g. after the model was moved); a stale-device cache would make
        # the arithmetic below fail.
        if (featmap_sizes != self.featmap_sizes_train
                or self.flatten_priors_train.device != device):
            self.featmap_sizes_train = featmap_sizes
            mlvl_priors_with_stride = self.prior_generator.grid_priors(
                featmap_sizes, device=device, with_stride=True)
            self.flatten_priors_train = torch.cat(
                mlvl_priors_with_stride, dim=0)

        # (N, C, H, W) per level -> (N, sum(H*W), C) over all levels.
        flatten_cls_scores = torch.cat([
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.cls_out_channels)
            for cls_score in cls_scores
        ], 1).contiguous()

        flatten_bboxes = torch.cat([
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ], 1)
        # Predictions are in stride units: scale by the priors' last column
        # (a stride value), then decode around the prior points (first two
        # columns) into boxes.
        flatten_bboxes = flatten_bboxes * self.flatten_priors_train[..., -1,
                                                                    None]
        flatten_bboxes = distance2bbox(self.flatten_priors_train[..., :2],
                                       flatten_bboxes)

        # The assigner only needs the current predictions, not gradients.
        assigned_result = self.assigner(flatten_bboxes.detach(),
                                        flatten_cls_scores.detach(),
                                        self.flatten_priors_train, gt_labels,
                                        gt_bboxes, pad_bbox_flag)

        labels = assigned_result['assigned_labels'].reshape(-1)
        label_weights = assigned_result['assigned_labels_weights'].reshape(-1)
        bbox_targets = assigned_result['assigned_bboxes'].reshape(-1, 4)
        assign_metrics = assigned_result['assign_metrics'].reshape(-1)
        # Use the same channel count as the flattening above.
        cls_preds = flatten_cls_scores.reshape(-1, self.cls_out_channels)
        flat_bbox_preds = flatten_bboxes.reshape(-1, 4)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)
        # Normalize both losses by the summed assignment metric; reduce_mean
        # presumably averages it across distributed workers — clamped to >= 1.
        avg_factor = reduce_mean(assign_metrics.sum()).clamp_(min=1).item()

        loss_cls = self.loss_cls(
            cls_preds, (labels, assign_metrics),
            label_weights,
            avg_factor=avg_factor)

        if len(pos_inds) > 0:
            loss_bbox = self.loss_bbox(
                flat_bbox_preds[pos_inds],
                bbox_targets[pos_inds],
                weight=assign_metrics[pos_inds],
                avg_factor=avg_factor)
        else:
            # No positives: a zero that stays connected to the graph.
            loss_bbox = flat_bbox_preds.sum() * 0

        return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)