# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from typing import List, Optional, Sequence, Tuple

import torch
import torch.nn as nn
from mmdet.models.utils import filter_scores_and_topk
from mmdet.structures.bbox import HorizontalBoxes, distance2bbox
from mmdet.structures.bbox.transforms import bbox_cxcywh_to_xyxy, scale_boxes
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
                         OptInstanceList, OptMultiConfig, reduce_mean)
from mmengine.config import ConfigDict
from mmengine.model import normal_init
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.registry import MODELS, TASK_UTILS
from ..utils import gt_instances_preprocess
from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule

try:
    from mmrotate.structures.bbox import RotatedBoxes, distance2obb
    MMROTATE_AVAILABLE = True
except ImportError:
    # mmrotate is optional; RTMDetRotatedHead raises a helpful ImportError
    # at construction time when it is missing.
    RotatedBoxes = None
    distance2obb = None
    MMROTATE_AVAILABLE = False


@MODELS.register_module()
class RTMDetRotatedSepBNHeadModule(RTMDetSepBNHeadModule):
    """Detection Head Module of RTMDet-R.

    Compared with the RTMDet detection head module, RTMDet-R adds a conv
    for angle prediction. An ``angle_out_dim`` arg is added, which is
    generated by the angle_coder module and controls the angle pred dim.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        num_base_priors (int): The number of priors (points) at a point
            on the feature grid. Defaults to 1.
        feat_channels (int): Number of hidden channels. Used in child
            classes. Defaults to 256.
        stacked_convs (int): Number of stacking convs of the head.
            Defaults to 2.
        featmap_strides (Sequence[int]): Downsample factor of each feature
            map. Defaults to (8, 16, 32).
        share_conv (bool): Whether to share conv layers between stages.
            Defaults to True.
        pred_kernel_size (int): Kernel size of ``nn.Conv2d``. Defaults to 1.
        angle_out_dim (int): Encoded length of the angle. When this module
            is built by :class:`RTMDetRotatedHead`, the value is filled in
            from that head's angle coder. Defaults to 1.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for
            normalization layer. Defaults to ``dict(type='BN')``.
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation
            layer. Defaults to ``dict(type='SiLU', inplace=True)``.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: int,
        widen_factor: float = 1.0,
        num_base_priors: int = 1,
        feat_channels: int = 256,
        stacked_convs: int = 2,
        featmap_strides: Sequence[int] = (8, 16, 32),
        share_conv: bool = True,
        pred_kernel_size: int = 1,
        angle_out_dim: int = 1,
        conv_cfg: OptConfigType = None,
        norm_cfg: ConfigType = dict(type='BN'),
        act_cfg: ConfigType = dict(type='SiLU', inplace=True),
        init_cfg: OptMultiConfig = None,
    ):
        # Must be set before super().__init__(): ``_init_layers`` reads it,
        # and the parent constructor presumably calls ``_init_layers``.
        self.angle_out_dim = angle_out_dim
        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            widen_factor=widen_factor,
            num_base_priors=num_base_priors,
            feat_channels=feat_channels,
            stacked_convs=stacked_convs,
            featmap_strides=featmap_strides,
            share_conv=share_conv,
            pred_kernel_size=pred_kernel_size,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            init_cfg=init_cfg)

    def _init_layers(self):
        """Initialize layers of the head.

        Builds the parent layers, then adds one angle-prediction conv per
        feature level with ``num_base_priors * angle_out_dim`` outputs.
        """
        super()._init_layers()
        self.rtm_ang = nn.ModuleList()
        for _ in range(len(self.featmap_strides)):
            self.rtm_ang.append(
                nn.Conv2d(
                    self.feat_channels,
                    self.num_base_priors * self.angle_out_dim,
                    self.pred_kernel_size,
                    padding=self.pred_kernel_size // 2))

    def init_weights(self) -> None:
        """Initialize weights of the head.

        Runs the parent initialization, then initializes every angle
        conv with ``normal_init(std=0.01)``.
        """
        super().init_weights()
        for rtm_ang in self.rtm_ang:
            normal_init(rtm_ang, std=0.01)

    def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each
                is a 4D-tensor.

        Returns:
            tuple: A tuple of classification scores, bbox predictions and
            angle predictions.

            - cls_scores (tuple[Tensor]): Classification scores for all
              scale levels, each is a 4D-tensor, the channels number is
              num_base_priors * num_classes.
            - bbox_preds (tuple[Tensor]): Box energies / deltas for all
              scale levels, each is a 4D-tensor, the channels number is
              num_base_priors * 4.
            - angle_preds (tuple[Tensor]): Angle prediction for all scale
              levels, each is a 4D-tensor, the channels number is
              num_base_priors * angle_out_dim.
        """
        cls_scores = []
        bbox_preds = []
        angle_preds = []
        for idx, x in enumerate(feats):
            cls_feat = x
            reg_feat = x

            for cls_layer in self.cls_convs[idx]:
                cls_feat = cls_layer(cls_feat)
            cls_score = self.rtm_cls[idx](cls_feat)

            for reg_layer in self.reg_convs[idx]:
                reg_feat = reg_layer(reg_feat)

            reg_dist = self.rtm_reg[idx](reg_feat)
            # The angle branch shares the regression tower features.
            angle_pred = self.rtm_ang[idx](reg_feat)

            cls_scores.append(cls_score)
            bbox_preds.append(reg_dist)
            angle_preds.append(angle_pred)
        return tuple(cls_scores), tuple(bbox_preds), tuple(angle_preds)


@MODELS.register_module()
class RTMDetRotatedHead(RTMDetHead):
    """RTMDet-R head.

    Compared with RTMDetHead, RTMDetRotatedHead adds some args to support
    rotated object detection.

    - `angle_version` used to limit angle_range during training.
    - `angle_coder` used to encode and decode angle, which is similar
      to bbox_coder.
    - `use_hbbox_loss` and `loss_angle` allow custom regression loss
      calculation for rotated box.

      There are three combination options for regression:

      1. `use_hbbox_loss=False` and loss_angle is None.

      .. code:: text

        bbox_pred────(tblr)───┐
                              ▼
        angle_pred          decode──►rbox_pred──(xywha)─►loss_bbox
            │                 ▲
            └────►decode──(a)─┘

      2. `use_hbbox_loss=False` and loss_angle is specified.
         An angle loss is added on angle_pred.

      .. code:: text

        bbox_pred────(tblr)───┐
                              ▼
        angle_pred          decode──►rbox_pred──(xywha)─►loss_bbox
            │                 ▲
            ├────►decode──(a)─┘
            │
            └───────────────────────────────────────────►loss_angle

      3. `use_hbbox_loss=True` and loss_angle is specified.
         In this case the loss_angle must be set.

      .. code:: text

        bbox_pred──(tblr)──►decode──►hbox_pred──(xyxy)──►loss_bbox

        angle_pred──────────────────────────────────────►loss_angle

    - There's a `decode_with_angle` flag in test_cfg (defaults to True),
      which mirrors the training process.

      When `decode_with_angle=True`:

      .. code:: text

        bbox_pred────(tblr)───┐
                              ▼
        angle_pred          decode──(xywha)──►rbox_pred
            │                 ▲
            └────►decode──(a)─┘

      When `decode_with_angle=False`:

      .. code:: text

        bbox_pred──(tblr)─►decode
                              │ (xyxy)
                              ▼
                           format───(xywh)──►concat──(xywha)──►rbox_pred
                                               ▲
        angle_pred────────►decode────(a)───────┘

    Args:
        head_module (ConfigType): Base module used for RTMDetRotatedHead.
            The config is copied before ``angle_out_dim`` is injected, so
            the caller's config is left unchanged.
        prior_generator (ConfigType): Points generator feature maps in 2D
            points-based detectors. Defaults to
            ``dict(type='mmdet.MlvlPointGenerator', strides=[8, 16, 32],
            offset=0)``.
        bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
            Defaults to ``dict(type='DistanceAnglePointCoder')``.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification
            loss. Defaults to ``mmdet.QualityFocalLoss``.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization
            loss. Defaults to ``mmrotate.RotatedIoULoss``.
        angle_version (str): Angle representations. Defaults to 'le90'.
        use_hbbox_loss (bool): If true, use horizontal bbox loss and
            loss_angle should not be None. Defaults to False.
        angle_coder (:obj:`ConfigDict` or dict): Config of angle coder.
            Defaults to ``dict(type='mmrotate.PseudoAngleCoder')``.
        loss_angle (:obj:`ConfigDict` or dict, optional): Config of angle
            loss. Defaults to None.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            anchor head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            anchor head. Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.

    Raises:
        ImportError: If mmrotate is not installed.
    """

    def __init__(
            self,
            head_module: ConfigType,
            prior_generator: ConfigType = dict(
                type='mmdet.MlvlPointGenerator', strides=[8, 16, 32],
                offset=0),
            bbox_coder: ConfigType = dict(type='DistanceAnglePointCoder'),
            loss_cls: ConfigType = dict(
                type='mmdet.QualityFocalLoss',
                use_sigmoid=True,
                beta=2.0,
                loss_weight=1.0),
            loss_bbox: ConfigType = dict(
                type='mmrotate.RotatedIoULoss', mode='linear',
                loss_weight=2.0),
            angle_version: str = 'le90',
            use_hbbox_loss: bool = False,
            angle_coder: ConfigType = dict(type='mmrotate.PseudoAngleCoder'),
            loss_angle: OptConfigType = None,
            train_cfg: OptConfigType = None,
            test_cfg: OptConfigType = None,
            init_cfg: OptMultiConfig = None):
        if not MMROTATE_AVAILABLE:
            raise ImportError(
                'Please run "mim install -r requirements/mmrotate.txt" '
                'to install mmrotate first for rotated detection.')

        self.angle_version = angle_version
        self.use_hbbox_loss = use_hbbox_loss
        if self.use_hbbox_loss:
            assert loss_angle is not None, \
                ('When use hbbox loss, loss_angle needs to be specified')
        self.angle_coder = TASK_UTILS.build(angle_coder)
        self.angle_out_dim = self.angle_coder.encode_size

        # Work on a private copy: writing ``angle_out_dim`` into the
        # caller's config would make a second head built from the same
        # config emit the warning below spuriously.
        head_module = copy.deepcopy(head_module)
        if head_module.get('angle_out_dim') is not None:
            warnings.warn('angle_out_dim will be overridden by angle_coder '
                          'and does not need to be set manually')
        head_module['angle_out_dim'] = self.angle_out_dim

        super().__init__(
            head_module=head_module,
            prior_generator=prior_generator,
            bbox_coder=bbox_coder,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)

        self.loss_angle = (
            MODELS.build(loss_angle) if loss_angle is not None else None)

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        angle_preds: List[Tensor],
                        objectnesses: Optional[List[Tensor]] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = True,
                        with_nms: bool = True) -> List[InstanceData]:
        """Transform a batch of output features extracted by the head into
        bbox results.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            angle_preds (list[Tensor]): Box angle for each scale level
                with shape (N, num_points * angle_out_dim, H, W).
            objectnesses (list[Tensor], Optional): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, 1, H, W).
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Required in practice; one dict per image. Defaults to None.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to True.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes: Rotated boxes with shape (num_instances, 5), the last
              dimension arranged as (x, y, w, h, angle). They are wrapped in
              ``RotatedBoxes`` before post-processing.
        """
        assert len(cls_scores) == len(bbox_preds) == len(angle_preds)
        with_objectnesses = objectnesses is not None
        if with_objectnesses:
            assert len(cls_scores) == len(objectnesses)

        cfg = self.test_cfg if cfg is None else cfg
        # Copied because ``multi_label`` / ``max_per_img`` are rewritten.
        cfg = copy.deepcopy(cfg)

        multi_label = cfg.multi_label
        multi_label &= self.num_classes > 1
        cfg.multi_label = multi_label

        # Whether to decode rbox with angle.
        # Different settings lead to different final results.
        # Defaults to True.
        decode_with_angle = cfg.get('decode_with_angle', True)

        # These settings do not change between images; read them once.
        score_thr = cfg.get('score_thr', -1)
        nms_pre = cfg.get('nms_pre', 100000)
        yolox_style = cfg.get('yolox_style', False)

        num_imgs = len(batch_img_metas)
        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]

        # If the shape does not change, use the previous mlvl_priors
        if featmap_sizes != self.featmap_sizes:
            self.mlvl_priors = self.prior_generator.grid_priors(
                featmap_sizes,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device)
            self.featmap_sizes = featmap_sizes
        flatten_priors = torch.cat(self.mlvl_priors)

        mlvl_strides = [
            flatten_priors.new_full(
                (featmap_size.numel() * self.num_base_priors, ), stride)
            for featmap_size, stride in zip(featmap_sizes,
                                            self.featmap_strides)
        ]
        flatten_stride = torch.cat(mlvl_strides)

        # flatten cls_scores, bbox_preds, angle_preds and objectness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.num_classes)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_angle_preds = [
            angle_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                   self.angle_out_dim)
            for angle_pred in angle_preds
        ]

        flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
        flatten_angle_preds = torch.cat(flatten_angle_preds, dim=1)
        flatten_angle_preds = self.angle_coder.decode(
            flatten_angle_preds, keepdim=True)

        if decode_with_angle:
            flatten_rbbox_preds = torch.cat(
                [flatten_bbox_preds, flatten_angle_preds], dim=-1)
            flatten_decoded_bboxes = self.bbox_coder.decode(
                flatten_priors[None], flatten_rbbox_preds, flatten_stride)
        else:
            flatten_decoded_hbboxes = self.bbox_coder.decode(
                flatten_priors[None], flatten_bbox_preds, flatten_stride)
            flatten_decoded_hbboxes = HorizontalBoxes.xyxy_to_cxcywh(
                flatten_decoded_hbboxes)
            flatten_decoded_bboxes = torch.cat(
                [flatten_decoded_hbboxes, flatten_angle_preds], dim=-1)

        if with_objectnesses:
            flatten_objectness = [
                objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
                for objectness in objectnesses
            ]
            flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
        else:
            flatten_objectness = [None for _ in range(num_imgs)]

        results_list = []
        for (bboxes, scores, objectness,
             img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
                              flatten_objectness, batch_img_metas):
            pad_param = img_meta.get('pad_param')

            # yolox_style does not require the following operations
            if (objectness is not None and score_thr > 0
                    and not yolox_style):
                conf_inds = objectness > score_thr
                bboxes = bboxes[conf_inds, :]
                scores = scores[conf_inds, :]
                objectness = objectness[conf_inds]

            if objectness is not None:
                # conf = obj_conf * cls_conf
                scores *= objectness[:, None]

            if scores.shape[0] == 0:
                empty_results = InstanceData()
                empty_results.bboxes = RotatedBoxes(bboxes)
                empty_results.scores = scores[:, 0]
                empty_results.labels = scores[:, 0].int()
                results_list.append(empty_results)
                continue

            if not cfg.multi_label:
                scores, labels = scores.max(1, keepdim=True)
                scores, _, keep_idxs, results = filter_scores_and_topk(
                    scores,
                    score_thr,
                    nms_pre,
                    results=dict(labels=labels[:, 0]))
                labels = results['labels']
            else:
                scores, labels, keep_idxs, _ = filter_scores_and_topk(
                    scores, score_thr, nms_pre)

            results = InstanceData(
                scores=scores,
                labels=labels,
                bboxes=RotatedBoxes(bboxes[keep_idxs]))

            if rescale:
                if pad_param is not None:
                    # Undo the letterbox padding (indices 2 / 0 are used as
                    # the x / y offsets) before undoing the resize.
                    results.bboxes.translate_([-pad_param[2], -pad_param[0]])

                scale_factor = [1 / s for s in img_meta['scale_factor']]
                results.bboxes = scale_boxes(results.bboxes, scale_factor)

            if yolox_style:
                # do not need max_per_img
                cfg.max_per_img = len(results)

            results = self._bbox_post_process(
                results=results,
                cfg=cfg,
                rescale=False,
                with_nms=with_nms,
                img_meta=img_meta)

            results_list.append(results)
        return results_list

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            angle_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_anchors * num_classes, H, W)
            bbox_preds (list[Tensor]): Box distance predictions for each
                scale level with shape (N, num_anchors * 4, H, W). They are
                multiplied by the per-prior stride and then decoded together
                with the angle into rotated boxes around the prior points.
            angle_preds (list[Tensor]): Angle prediction for each scale
                level with shape (N, num_anchors * angle_out_dim, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Not used by this implementation. Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components
            (``loss_cls``, ``loss_bbox`` and, when an angle loss is
            configured, ``loss_angle``).
        """
        num_imgs = len(batch_img_metas)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
        gt_labels = gt_info[:, :, :1]
        gt_bboxes = gt_info[:, :, 1:]  # xywha
        # Padded gt rows are assumed to be all zeros. Only width and height
        # are tested: the angle (and a shifted centre) may be negative, so
        # summing all five values could mask a real box as padding.
        pad_bbox_flag = (gt_bboxes[..., 2:4].sum(-1, keepdim=True) >
                         0).float()

        device = cls_scores[0].device

        # If the shape does not equal, generate new one
        if featmap_sizes != self.featmap_sizes_train:
            self.featmap_sizes_train = featmap_sizes
            mlvl_priors_with_stride = self.prior_generator.grid_priors(
                featmap_sizes, device=device, with_stride=True)
            self.flatten_priors_train = torch.cat(
                mlvl_priors_with_stride, dim=0)

        flatten_cls_scores = torch.cat([
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                  self.cls_out_channels)
            for cls_score in cls_scores
        ], 1).contiguous()

        flatten_tblrs = torch.cat([
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ], 1)
        # Scale distances by the last prior column, which is a stride
        # because the priors were generated with ``with_stride=True``.
        flatten_tblrs = flatten_tblrs * self.flatten_priors_train[..., -1,
                                                                  None]
        flatten_angles = torch.cat([
            angle_pred.permute(0, 2, 3, 1).reshape(
                num_imgs, -1, self.angle_out_dim) for angle_pred in angle_preds
        ], 1)
        flatten_decoded_angle = self.angle_coder.decode(
            flatten_angles, keepdim=True)
        flatten_tblra = torch.cat([flatten_tblrs, flatten_decoded_angle],
                                  dim=-1)
        flatten_rbboxes = distance2obb(
            self.flatten_priors_train[..., :2],
            flatten_tblra,
            angle_version=self.angle_version)
        if self.use_hbbox_loss:
            flatten_hbboxes = distance2bbox(self.flatten_priors_train[..., :2],
                                            flatten_tblrs)

        assigned_result = self.assigner(flatten_rbboxes.detach(),
                                        flatten_cls_scores.detach(),
                                        self.flatten_priors_train, gt_labels,
                                        gt_bboxes, pad_bbox_flag)

        labels = assigned_result['assigned_labels'].reshape(-1)
        label_weights = assigned_result['assigned_labels_weights'].reshape(-1)
        bbox_targets = assigned_result['assigned_bboxes'].reshape(-1, 5)
        assign_metrics = assigned_result['assign_metrics'].reshape(-1)
        cls_preds = flatten_cls_scores.reshape(-1, self.num_classes)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)
        avg_factor = reduce_mean(assign_metrics.sum()).clamp_(min=1).item()

        loss_cls = self.loss_cls(
            cls_preds, (labels, assign_metrics),
            label_weights,
            avg_factor=avg_factor)

        pos_bbox_targets = bbox_targets[pos_inds]

        if self.use_hbbox_loss:
            bbox_preds = flatten_hbboxes.reshape(-1, 4)
            pos_bbox_targets = bbox_cxcywh_to_xyxy(pos_bbox_targets[:, :4])
        else:
            bbox_preds = flatten_rbboxes.reshape(-1, 5)
        angle_preds = flatten_angles.reshape(-1, self.angle_out_dim)

        if len(pos_inds) > 0:
            loss_bbox = self.loss_bbox(
                bbox_preds[pos_inds],
                pos_bbox_targets,
                weight=assign_metrics[pos_inds],
                avg_factor=avg_factor)
            # Zero loss that stays attached to the graph.
            loss_angle = angle_preds.sum() * 0
            if self.loss_angle is not None:
                pos_angle_targets = bbox_targets[pos_inds][:, 4:5]
                pos_angle_targets = self.angle_coder.encode(pos_angle_targets)
                loss_angle = self.loss_angle(
                    angle_preds[pos_inds],
                    pos_angle_targets,
                    weight=assign_metrics[pos_inds],
                    avg_factor=avg_factor)
        else:
            # No positives: keep zero losses connected to the predictions.
            loss_bbox = bbox_preds.sum() * 0
            loss_angle = angle_preds.sum() * 0

        losses = dict()
        losses['loss_cls'] = loss_cls
        losses['loss_bbox'] = loss_bbox
        if self.loss_angle is not None:
            losses['loss_angle'] = loss_angle

        return losses