import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import xavier_init
from mmcv.runner import force_fp32

from mmdet.core import (build_anchor_generator, build_assigner,
                        build_bbox_coder, build_sampler, multi_apply)
from ..builder import HEADS
from ..losses import smooth_l1_loss
from .anchor_head import AnchorHead


@HEADS.register_module()
class SSDHead(AnchorHead):
    """SSD head used in https://arxiv.org/abs/1512.02325.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (tuple[int]): Number of channels in the input feature
            maps, one entry per scale level.
        anchor_generator (dict): Config dict for anchor generator.
        bbox_coder (dict): Config of bounding box coder.
        reg_decoded_bbox (bool): If true, the regression loss would be
            applied directly on decoded bounding boxes, converting both
            the predicted boxes and regression targets to absolute
            coordinates format. Default False. It should be `True` when
            using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
        train_cfg (dict): Training config of anchor head.
        test_cfg (dict): Testing config of anchor head.
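
    Example:
        >>> # A minimal sketch assuming the default SSD300 configuration
        >>> # above; the spatial sizes correspond to a 300x300 input.
        >>> import torch
        >>> self = SSDHead(num_classes=80)
        >>> feats = [
        ...     torch.rand(1, c, s, s) for c, s in zip(
        ...         self.in_channels, (38, 19, 10, 5, 3, 1))
        ... ]
        >>> cls_scores, bbox_preds = self.forward(feats)
        >>> assert len(cls_scores) == len(bbox_preds) == 6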
    """

    def __init__(self,
                 num_classes=80,
                 in_channels=(512, 1024, 512, 256, 256, 256),
                 anchor_generator=dict(
                     type='SSDAnchorGenerator',
                     scale_major=False,
                     input_size=300,
                     strides=[8, 16, 32, 64, 100, 300],
                     ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
                     basesize_ratio_range=(0.1, 0.9)),
                 bbox_coder=dict(
                     type='DeltaXYWHBBoxCoder',
                     clip_border=True,
                     target_means=[.0, .0, .0, .0],
                     target_stds=[1.0, 1.0, 1.0, 1.0],
                 ),
                 reg_decoded_bbox=False,
                 train_cfg=None,
                 test_cfg=None):
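        # Skip AnchorHead.__init__ on purpose and initialize from its
        # parent class instead: SSDHead builds its own per-level conv
        # predictors below rather than AnchorHead's shared conv layers.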
        super(AnchorHead, self).__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.cls_out_channels = num_classes + 1
        self.anchor_generator = build_anchor_generator(anchor_generator)
        num_anchors = self.anchor_generator.num_base_anchors
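
        # Build one 3x3 regression conv and one 3x3 classification conv
        # per feature level; output channels scale with the number of
        # base anchors at that level.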
        reg_convs = []
        cls_convs = []
        for i in range(len(in_channels)):
            reg_convs.append(
                nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * 4,
                    kernel_size=3,
                    padding=1))
            cls_convs.append(
                nn.Conv2d(
                    in_channels[i],
                    num_anchors[i] * (num_classes + 1),
                    kernel_size=3,
                    padding=1))
        self.reg_convs = nn.ModuleList(reg_convs)
        self.cls_convs = nn.ModuleList(cls_convs)

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.reg_decoded_bbox = reg_decoded_bbox
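        # SSD classifies with a softmax over num_classes + 1 categories
        # (no sigmoid) and handles class imbalance with hard negative
        # mining rather than focal loss.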
        self.use_sigmoid_cls = False
        self.cls_focal_loss = False
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        self.sampling = False
        if self.train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
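            # SSD does not sample proposals (sampling=False); a
            # PseudoSampler is used just to wrap the assign results.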
            sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)
        self.fp16_enabled = False

    def init_weights(self):
        """Initialize weights of the head."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform', bias=0)

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple:
                cls_scores (list[Tensor]): Classification scores for all scale
                    levels, each is a 4D-tensor, the number of channels is
                    num_anchors * (num_classes + 1), where the extra class
                    is the background.
                bbox_preds (list[Tensor]): Box energies / deltas for all scale
                    levels, each is a 4D-tensor, the number of channels is
                    num_anchors * 4.
        """
        cls_scores = []
        bbox_preds = []
        for feat, reg_conv, cls_conv in zip(feats, self.reg_convs,
                                            self.cls_convs):
            cls_scores.append(cls_conv(feat))
            bbox_preds.append(reg_conv(feat))
        return cls_scores, bbox_preds

    def loss_single(self, cls_score, bbox_pred, anchor, labels, label_weights,
                    bbox_targets, bbox_weights, num_total_samples):
        """Compute loss of a single image.

        Args:
            cls_score (Tensor): Box scores for each image with shape
                (num_total_anchors, num_classes + 1).
            bbox_pred (Tensor): Box energies / deltas for each image
                with shape (num_total_anchors, 4).
            anchor (Tensor): Box reference for each image with shape
                (num_total_anchors, 4).
            labels (Tensor): Labels of each anchor with shape
                (num_total_anchors,).
            label_weights (Tensor): Label weights of each anchor with shape
                (num_total_anchors,).
            bbox_targets (Tensor): BBox regression targets of each anchor
                with shape (num_total_anchors, 4).
            bbox_weights (Tensor): BBox regression loss weights of each
                anchor with shape (num_total_anchors, 4).
            num_total_samples (int): If sampling is used, this equals the
                total number of anchors; otherwise, it is the number of
                positive anchors.

        Returns:
            tuple[Tensor, Tensor]: The classification loss and the
                regression loss of the single image.
        """
        loss_cls_all = F.cross_entropy(
            cls_score, labels, reduction='none') * label_weights
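        # FG cat_id: [0, num_classes - 1]; BG cat_id: num_classes.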
        pos_inds = ((labels >= 0) &
                    (labels < self.num_classes)).nonzero().reshape(-1)
        neg_inds = (labels == self.num_classes).nonzero().view(-1)
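
        # Hard negative mining: keep at most neg_pos_ratio negatives per
        # positive, choosing the negatives with the largest loss.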
        num_pos_samples = pos_inds.size(0)
        num_neg_samples = self.train_cfg.neg_pos_ratio * num_pos_samples
        if num_neg_samples > neg_inds.size(0):
            num_neg_samples = neg_inds.size(0)
        topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
        loss_cls_pos = loss_cls_all[pos_inds].sum()
        loss_cls_neg = topk_loss_cls_neg.sum()
        loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples

        if self.reg_decoded_bbox:
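            # When a regression loss like `IoULoss` is applied directly
            # on decoded boxes, decode the predicted deltas back to
            # absolute bounding boxes first.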
            bbox_pred = self.bbox_coder.decode(anchor, bbox_pred)

        loss_bbox = smooth_l1_loss(
            bbox_pred,
            bbox_targets,
            bbox_weights,
            beta=self.train_cfg.smoothl1_beta,
            avg_factor=num_total_samples)
        return loss_cls[None], loss_bbox

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                with shape (N, num_anchors * (num_classes + 1), H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            gt_bboxes (list[Tensor]): Each item is the ground-truth boxes of
                one image in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.anchor_generator.num_levels

        device = cls_scores[0].device

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=1,
            unmap_outputs=False)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets

        num_images = len(img_metas)
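        # Flatten per-level predictions and targets to per-image tensors
        # of shape (num_images, num_total_anchors, C) so that
        # `loss_single` can be applied image by image via `multi_apply`.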
        all_cls_scores = torch.cat([
            s.permute(0, 2, 3, 1).reshape(
                num_images, -1, self.cls_out_channels) for s in cls_scores
        ], 1)
        all_labels = torch.cat(labels_list, -1).view(num_images, -1)
        all_label_weights = torch.cat(label_weights_list,
                                      -1).view(num_images, -1)
        all_bbox_preds = torch.cat([
            b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
            for b in bbox_preds
        ], -2)
        all_bbox_targets = torch.cat(bbox_targets_list,
                                     -2).view(num_images, -1, 4)
        all_bbox_weights = torch.cat(bbox_weights_list,
                                     -2).view(num_images, -1, 4)
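
        # Concatenate all level anchors of each image into one tensor.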
        all_anchors = []
        for i in range(num_images):
            all_anchors.append(torch.cat(anchor_list[i]))
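
        # Check for NaN and Inf before computing losses.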
        assert torch.isfinite(all_cls_scores).all().item(), \
            'classification scores become infinite or NaN!'
        assert torch.isfinite(all_bbox_preds).all().item(), \
            'bbox predictions become infinite or NaN!'
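
        # `num_total_samples` is the number of positives (num_total_pos):
        # with hard negative mining and no sampling, losses are averaged
        # over positive anchors only.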
        losses_cls, losses_bbox = multi_apply(
            self.loss_single,
            all_cls_scores,
            all_bbox_preds,
            all_anchors,
            all_labels,
            all_label_weights,
            all_bbox_targets,
            all_bbox_weights,
            num_total_samples=num_total_pos)
        return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)