Spaces:
Runtime error
Runtime error
File size: 36,487 Bytes
3e06e1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, Scale
from mmcv.ops import deform_conv2d
from mmengine import MessageHub
from mmengine.config import ConfigDict
from mmengine.model import bias_init_with_prob, normal_init
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.registry import MODELS, TASK_UTILS
from mmdet.structures.bbox import distance2bbox
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
OptInstanceList, reduce_mean)
from ..task_modules.prior_generators import anchor_inside_flags
from ..utils import (filter_scores_and_topk, images_to_levels, multi_apply,
sigmoid_geometric_mean, unmap)
from .atss_head import ATSSHead
class TaskDecomposition(nn.Module):
    """Layer-attention based task decomposition of the TOOD predictor.

    The module receives the channel-wise concatenation of the outputs of
    ``stacked_convs`` conv layers, predicts one attention weight per layer
    and uses these weights to modulate a 1x1 reduction convolution that
    maps the concatenated features back to ``feat_channels`` channels.

    Args:
        feat_channels (int): Number of feature channels in TOOD head.
        stacked_convs (int): Number of conv layers in TOOD head.
        la_down_rate (int): Downsample rate of layer attention.
            Defaults to 8.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            normalization layer. Defaults to None.
    """

    def __init__(self,
                 feat_channels: int,
                 stacked_convs: int,
                 la_down_rate: int = 8,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None) -> None:
        super().__init__()
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.in_channels = self.feat_channels * self.stacked_convs
        self.norm_cfg = norm_cfg

        # Squeeze-and-excite style branch: it outputs one sigmoid weight per
        # stacked conv layer.
        squeezed_channels = self.in_channels // la_down_rate
        squeeze = nn.Conv2d(self.in_channels, squeezed_channels, 1)
        excite = nn.Conv2d(
            squeezed_channels, self.stacked_convs, 1, padding=0)
        self.layer_attention = nn.Sequential(squeeze, nn.ReLU(inplace=True),
                                             excite, nn.Sigmoid())

        # 1x1 conv whose weight is re-weighted per sample in ``forward``.
        # NOTE(review): when ``norm_cfg`` is None this conv is created with
        # a bias, but ``forward`` only reads ``conv.weight`` - confirm that
        # ignoring the bias is intended.
        reduction_kwargs = dict(
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            bias=norm_cfg is None)
        self.reduction_conv = ConvModule(self.in_channels,
                                         self.feat_channels, 1,
                                         **reduction_kwargs)

    def init_weights(self) -> None:
        """Initialize the parameters."""
        attention_convs = (
            layer for layer in self.layer_attention.modules()
            if isinstance(layer, nn.Conv2d))
        for conv in attention_convs:
            normal_init(conv, std=0.001)
        normal_init(self.reduction_conv.conv, std=0.01)

    def forward(self,
                feat: Tensor,
                avg_feat: Optional[Tensor] = None) -> Tensor:
        """Forward function of task decomposition module.

        Args:
            feat (Tensor): 4D feature map whose channel dimension is
                reshaped to ``feat_channels * stacked_convs``.
            avg_feat (Tensor, optional): Globally pooled version of
                ``feat``. It is computed with adaptive average pooling to
                ``(1, 1)`` when not given. Defaults to None.

        Returns:
            Tensor: Feature map reshaped to
            ``(b, feat_channels, h, w)`` after the optional norm layer and
            the activation of the reduction conv.
        """
        batch, _, height, width = feat.shape
        if avg_feat is None:
            avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
        attention = self.layer_attention(avg_feat)

        # Fold the layer attention into the conv weight first and convolve
        # afterwards (as a batched matmul); this saves memory and FLOPs
        # compared with re-weighting the feature map itself.
        attention = attention.reshape(batch, 1, self.stacked_convs, 1)
        shared_weight = self.reduction_conv.conv.weight.reshape(
            1, self.feat_channels, self.stacked_convs, self.feat_channels)
        dynamic_weight = (attention * shared_weight).reshape(
            batch, self.feat_channels, self.in_channels)

        flat_feat = feat.reshape(batch, self.in_channels, height * width)
        out = torch.bmm(dynamic_weight, flat_feat)
        out = out.reshape(batch, self.feat_channels, height, width)

        if self.norm_cfg is not None:
            out = self.reduction_conv.norm(out)
        return self.reduction_conv.activate(out)
@MODELS.register_module()
class TOODHead(ATSSHead):
"""TOODHead used in `TOOD: Task-aligned One-stage Object Detection.
<https://arxiv.org/abs/2108.07755>`_.
TOOD uses Task-aligned head (T-head) and is optimized by Task Alignment
Learning (TAL).
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
num_dcn (int): Number of deformable convolution in the head.
Defaults to 0.
anchor_type (str): If set to ``anchor_free``, the head will use centers
to regress bboxes. If set to ``anchor_based``, the head will
regress bboxes based on anchors. Defaults to ``anchor_free``.
initial_loss_cls (:obj:`ConfigDict` or dict): Config of initial loss.
Example:
>>> self = TOODHead(11, 7)
>>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
>>> cls_score, bbox_pred = self.forward(feats)
>>> assert len(cls_score) == len(self.scales)
"""
def __init__(self,
             num_classes: int,
             in_channels: int,
             num_dcn: int = 0,
             anchor_type: str = 'anchor_free',
             initial_loss_cls: ConfigType = dict(
                 type='FocalLoss',
                 use_sigmoid=True,
                 activated=True,
                 gamma=2.0,
                 alpha=0.25,
                 loss_weight=1.0),
             **kwargs) -> None:
    """Build the head and, when a training config exists, its assigners.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        num_dcn (int): Number of deformable convolution in the head.
            Defaults to 0.
        anchor_type (str): Either ``anchor_free`` or ``anchor_based``.
            Defaults to ``anchor_free``.
        initial_loss_cls (:obj:`ConfigDict` or dict): Config of the
            classification loss used while ``epoch < initial_epoch``.
        **kwargs: Forwarded to the parent class constructor.
    """
    assert anchor_type in ('anchor_free', 'anchor_based')
    # These two attributes are read by ``_init_layers`` and ``forward``;
    # they are stored before the parent constructor is invoked.
    self.num_dcn = num_dcn
    self.anchor_type = anchor_type
    super().__init__(
        num_classes=num_classes, in_channels=in_channels, **kwargs)

    train_cfg = self.train_cfg
    if not train_cfg:
        # No training config: nothing else to set up.
        return

    # Settings of the initial stage (used while epoch < initial_epoch).
    self.initial_epoch = train_cfg['initial_epoch']
    self.initial_assigner = TASK_UTILS.build(train_cfg['initial_assigner'])
    self.initial_loss_cls = MODELS.build(initial_loss_cls)
    self.assigner = self.initial_assigner

    # Settings of the task-alignment stage.
    self.alignment_assigner = TASK_UTILS.build(train_cfg['assigner'])
    self.alpha = train_cfg['alpha']
    self.beta = train_cfg['beta']
def _init_layers(self) -> None:
    """Initialize layers of the head."""
    self.relu = nn.ReLU(inplace=True)

    # Stacked 3x3 convs; the first ``num_dcn`` of them use DCNv2.
    self.inter_convs = nn.ModuleList()
    for layer_idx in range(self.stacked_convs):
        use_dcn = layer_idx < self.num_dcn
        layer_conv_cfg = dict(
            type='DCNv2', deform_groups=4) if use_dcn else self.conv_cfg
        layer_in_channels = (
            self.feat_channels if layer_idx != 0 else self.in_channels)
        inter_conv = ConvModule(
            layer_in_channels,
            self.feat_channels,
            3,
            stride=1,
            padding=1,
            conv_cfg=layer_conv_cfg,
            norm_cfg=self.norm_cfg)
        self.inter_convs.append(inter_conv)

    # One decomposition module per task, both built from the same
    # arguments.
    decomp_args = (self.feat_channels, self.stacked_convs,
                   self.stacked_convs * 8, self.conv_cfg, self.norm_cfg)
    self.cls_decomp = TaskDecomposition(*decomp_args)
    self.reg_decomp = TaskDecomposition(*decomp_args)

    # Prediction convs applied to the decomposed features.
    cls_out = self.num_base_priors * self.cls_out_channels
    reg_out = self.num_base_priors * 4
    self.tood_cls = nn.Conv2d(self.feat_channels, cls_out, 3, padding=1)
    self.tood_reg = nn.Conv2d(self.feat_channels, reg_out, 3, padding=1)

    # Alignment branches applied to the concatenated stacked-conv
    # features: a 1-channel map for the class branch and an 8-channel
    # (4 * 2) offset map for the regression branch.
    cat_channels = self.feat_channels * self.stacked_convs
    mid_channels = self.feat_channels // 4
    self.cls_prob_module = nn.Sequential(
        nn.Conv2d(cat_channels, mid_channels, 1),
        nn.ReLU(inplace=True),
        nn.Conv2d(mid_channels, 1, 3, padding=1))
    self.reg_offset_module = nn.Sequential(
        nn.Conv2d(cat_channels, mid_channels, 1),
        nn.ReLU(inplace=True),
        nn.Conv2d(mid_channels, 4 * 2, 3, padding=1))

    # One learnable scale per prior-generator stride.
    level_scales = [Scale(1.0) for _ in self.prior_generator.strides]
    self.scales = nn.ModuleList(level_scales)
def init_weights(self) -> None:
    """Initialize weights of the head."""
    for inter_conv in self.inter_convs:
        normal_init(inter_conv.conv, std=0.01)

    # Conv layers of the two alignment branches, each branch with its own
    # standard deviation.
    branch_stds = ((self.cls_prob_module, 0.01),
                   (self.reg_offset_module, 0.001))
    for branch, std in branch_stds:
        for layer in branch:
            if isinstance(layer, nn.Conv2d):
                normal_init(layer, std=std)

    # The last conv of the class branch and the class prediction conv are
    # initialized with a bias derived from a prior probability of 0.01.
    bias_cls = bias_init_with_prob(0.01)
    normal_init(self.cls_prob_module[-1], std=0.01, bias=bias_cls)

    self.cls_decomp.init_weights()
    self.reg_decomp.init_weights()

    normal_init(self.tood_cls, std=0.01, bias=bias_cls)
    normal_init(self.tood_reg, std=0.01)
def forward(self, feats: Tuple[Tensor]) -> Tuple[List[Tensor]]:
    """Forward features from the upstream network.

    Args:
        feats (tuple[Tensor]): Features from the upstream network, each is
            a 4D-tensor.

    Returns:
        tuple: Usually a tuple of classification scores and bbox prediction

            cls_scores (list[Tensor]): Classification scores for all scale
                levels, each is a 4D-tensor, the channels number is
                num_anchors * num_classes.
            bbox_preds (list[Tensor]): Decoded box for all scale levels,
                each is a 4D-tensor, the channels number is
                num_anchors * 4. In [tl_x, tl_y, br_x, br_y] format.
    """
    cls_scores = []
    bbox_preds = []
    levels = zip(feats, self.scales, self.prior_generator.strides)
    for level_idx, (x, scale, stride) in enumerate(levels):
        batch, _, height, width = x.shape
        anchor = self.prior_generator.single_level_grid_priors(
            (height, width), level_idx, device=x.device)
        # Repeat the priors of this level once per image in the batch.
        anchor = torch.cat([anchor] * batch)

        # extract task interactive features
        inter_feats = []
        hidden = x
        for inter_conv in self.inter_convs:
            hidden = inter_conv(hidden)
            inter_feats.append(hidden)
        feat = torch.cat(inter_feats, 1)

        # task decomposition (the pooled feature is shared by both tasks)
        avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
        cls_feat = self.cls_decomp(feat, avg_feat)
        reg_feat = self.reg_decomp(feat, avg_feat)

        # cls prediction and alignment
        cls_logits = self.tood_cls(cls_feat)
        cls_prob = self.cls_prob_module(feat)
        cls_score = sigmoid_geometric_mean(cls_logits, cls_prob)

        # reg prediction and alignment
        if self.anchor_type not in ('anchor_free', 'anchor_based'):
            raise NotImplementedError(
                f'Unknown anchor type: {self.anchor_type}.'
                f'Please use `anchor_free` or `anchor_based`.')
        raw_reg = self.tood_reg(reg_feat)
        if self.anchor_type == 'anchor_free':
            # Distances are decoded around the anchor centers, which are
            # divided by the stride of this level.
            reg_dist = scale(raw_reg.exp()).float()
            reg_dist = reg_dist.permute(0, 2, 3, 1).reshape(-1, 4)
            centers = self.anchor_center(anchor) / stride[0]
            decoded = distance2bbox(centers, reg_dist)
            reg_bbox = decoded.reshape(batch, height, width, 4).permute(
                0, 3, 1, 2)  # (b, c, h, w)
        else:
            # Boxes are decoded by the bbox coder and then divided by the
            # stride of this level.
            reg_dist = scale(raw_reg).float()
            reg_dist = reg_dist.permute(0, 2, 3, 1).reshape(-1, 4)
            decoded = self.bbox_coder.decode(anchor, reg_dist)
            reg_bbox = decoded.reshape(batch, height, width, 4).permute(
                0, 3, 1, 2) / stride[0]

        reg_offset = self.reg_offset_module(feat)
        bbox_pred = self.deform_sampling(reg_bbox.contiguous(),
                                         reg_offset.contiguous())

        # After deform_sampling, some boxes will become invalid (The
        # left-top point is at the right or bottom of the right-bottom
        # point), which will make the GIoULoss negative. Those boxes fall
        # back to the un-sampled prediction.
        x_flipped = bbox_pred[:, [0]] > bbox_pred[:, [2]]
        y_flipped = bbox_pred[:, [1]] > bbox_pred[:, [3]]
        invalid_mask = (x_flipped | y_flipped).expand_as(bbox_pred)
        bbox_pred = torch.where(invalid_mask, reg_bbox, bbox_pred)

        cls_scores.append(cls_score)
        bbox_preds.append(bbox_pred)
    return tuple(cls_scores), tuple(bbox_preds)
def deform_sampling(self, feat: Tensor, offset: Tensor) -> Tensor:
    """Sampling the feature x according to offset.

    Args:
        feat (Tensor): Feature, a 4D-tensor.
        offset (Tensor): Spatial offset for feature sampling.

    Returns:
        Tensor: Output of ``deform_conv2d`` applied to ``feat`` with an
        all-ones weight of shape ``(c, 1, 1, 1)``.
    """
    # it is an equivalent implementation of bilinear interpolation
    _, channels, _, _ = feat.shape
    ones_kernel = feat.new_ones(channels, 1, 1, 1)
    # Positional arguments after the weight: 1, 0, 1, then the channel
    # count twice (presumably stride, padding, dilation, groups and
    # deform_groups - verify against the mmcv signature).
    return deform_conv2d(feat, offset, ones_kernel, 1, 0, 1, channels,
                         channels)
def anchor_center(self, anchors: Tensor) -> Tensor:
    """Get anchor centers from anchors.

    Args:
        anchors (Tensor): Anchor list with shape (N, 4), "xyxy" format.

    Returns:
        Tensor: Anchor centers with shape (N, 2), "xy" format.
    """
    # The center is the midpoint of the two corners along each axis.
    x_min, y_min = anchors[:, 0], anchors[:, 1]
    x_max, y_max = anchors[:, 2], anchors[:, 3]
    center_x = (x_max + x_min) / 2
    center_y = (y_max + y_min) / 2
    return torch.stack((center_x, center_y), dim=-1)
def loss_by_feat_single(self, anchors: Tensor, cls_score: Tensor,
                        bbox_pred: Tensor, labels: Tensor,
                        label_weights: Tensor, bbox_targets: Tensor,
                        alignment_metrics: Tensor,
                        stride: Tuple[int, int]) -> tuple:
    """Calculate the loss of a single scale level based on the features
    extracted by the detection head.

    Args:
        anchors (Tensor): Box reference for each scale level with shape
            (N, num_total_anchors, 4).
        cls_score (Tensor): Box scores for each scale level
            Has shape (N, num_anchors * num_classes, H, W).
        bbox_pred (Tensor): Decoded bboxes for each scale
            level with shape (N, num_anchors * 4, H, W).
        labels (Tensor): Labels of each anchors with shape
            (N, num_total_anchors).
        label_weights (Tensor): Label weights of each anchor with shape
            (N, num_total_anchors).
        bbox_targets (Tensor): BBox regression targets of each anchor with
            shape (N, num_total_anchors, 4).
        alignment_metrics (Tensor): Alignment metrics with shape
            (N, num_total_anchors).
        stride (Tuple[int, int]): Downsample stride of the feature map.

    Returns:
        tuple: A 4-tuple ``(loss_cls, loss_bbox, cls_avg_factor,
        bbox_avg_factor)``. Both losses are computed with
        ``avg_factor=1.0``; the last two entries are the sum of the
        alignment metrics and the sum of the positive bbox weights, which
        ``loss_by_feat`` uses to normalize the losses across levels.
    """
    assert stride[0] == stride[1], 'h stride is not equal to w stride!'
    anchors = anchors.reshape(-1, 4)
    cls_score = cls_score.permute(0, 2, 3, 1).reshape(
        -1, self.cls_out_channels).contiguous()
    bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
    bbox_targets = bbox_targets.reshape(-1, 4)
    labels = labels.reshape(-1)
    alignment_metrics = alignment_metrics.reshape(-1)
    label_weights = label_weights.reshape(-1)

    # During the initial epochs the plain labels and the initial loss are
    # used; afterwards the loss receives (labels, alignment_metrics).
    in_initial_stage = self.epoch < self.initial_epoch
    if in_initial_stage:
        targets = labels
        cls_loss_func = self.initial_loss_cls
    else:
        targets = (labels, alignment_metrics)
        cls_loss_func = self.loss_cls

    loss_cls = cls_loss_func(
        cls_score, targets, label_weights, avg_factor=1.0)

    # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
    bg_class_ind = self.num_classes
    pos_inds = ((labels >= 0)
                & (labels < bg_class_ind)).nonzero().squeeze(1)

    if len(pos_inds) > 0:
        pos_bbox_targets = bbox_targets[pos_inds]
        pos_bbox_pred = bbox_pred[pos_inds]
        pos_anchors = anchors[pos_inds]

        # Predictions are used as-is while the targets are divided by the
        # stride before both are compared.
        pos_decode_bbox_pred = pos_bbox_pred
        pos_decode_bbox_targets = pos_bbox_targets / stride[0]

        # regression loss
        if in_initial_stage:
            pos_bbox_weight = self.centerness_target(
                pos_anchors, pos_bbox_targets)
        else:
            pos_bbox_weight = alignment_metrics[pos_inds]

        loss_bbox = self.loss_bbox(
            pos_decode_bbox_pred,
            pos_decode_bbox_targets,
            weight=pos_bbox_weight,
            avg_factor=1.0)
    else:
        # No positive sample: keep a zero loss that is still connected to
        # the predictions.
        loss_bbox = bbox_pred.sum() * 0
        pos_bbox_weight = bbox_targets.new_tensor(0.)

    return (loss_cls, loss_bbox, alignment_metrics.sum(),
            pos_bbox_weight.sum())
def loss_by_feat(
        self,
        cls_scores: List[Tensor],
        bbox_preds: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None) -> dict:
    """Calculate the loss based on the features extracted by the detection
    head.

    Args:
        cls_scores (list[Tensor]): Box scores for each scale level
            Has shape (N, num_anchors * num_classes, H, W)
        bbox_preds (list[Tensor]): Decoded box for each scale
            level with shape (N, num_anchors * 4, H, W) in
            [tl_x, tl_y, br_x, br_y] format.
        batch_gt_instances (list[:obj:`InstanceData`]): Batch of
            gt_instance. It usually includes ``bboxes`` and ``labels``
            attributes.
        batch_img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.
        batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
            Batch of gt_instances_ignore. It includes ``bboxes`` attribute
            data that is ignored during training and testing.
            Defaults to None.

    Returns:
        dict[str, Tensor]: A dictionary of loss components with the keys
        ``loss_cls`` and ``loss_bbox``; each value is a list with one
        entry per scale level.
    """
    num_imgs = len(batch_img_metas)
    strides = self.prior_generator.strides
    featmap_sizes = [score.size()[-2:] for score in cls_scores]
    assert len(featmap_sizes) == self.prior_generator.num_levels

    device = cls_scores[0].device
    anchor_list, valid_flag_list = self.get_anchors(
        featmap_sizes, batch_img_metas, device=device)

    # Flatten the predictions of all levels into one tensor per image
    # batch; boxes are multiplied by the stride of their level.
    per_level_scores = [
        score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                          self.cls_out_channels)
        for score in cls_scores
    ]
    flatten_cls_scores = torch.cat(per_level_scores, 1)
    per_level_bboxes = [
        pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) * stride[0]
        for pred, stride in zip(bbox_preds, strides)
    ]
    flatten_bbox_preds = torch.cat(per_level_bboxes, 1)

    (anchor_list, labels_list, label_weights_list, bbox_targets_list,
     alignment_metrics_list) = self.get_targets(
         flatten_cls_scores,
         flatten_bbox_preds,
         anchor_list,
         valid_flag_list,
         batch_gt_instances,
         batch_img_metas,
         batch_gt_instances_ignore=batch_gt_instances_ignore)

    level_outputs = multi_apply(
        self.loss_by_feat_single,
        anchor_list,
        cls_scores,
        bbox_preds,
        labels_list,
        label_weights_list,
        bbox_targets_list,
        alignment_metrics_list,
        strides)
    losses_cls, losses_bbox, cls_avg_factors, bbox_avg_factors = \
        level_outputs

    # The per-level losses were computed with avg_factor=1.0; normalize
    # them here by the (reduced) sums, clamped to at least 1.
    cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item()
    losses_cls = [loss / cls_avg_factor for loss in losses_cls]

    bbox_avg_factor = reduce_mean(
        sum(bbox_avg_factors)).clamp_(min=1).item()
    losses_bbox = [loss / bbox_avg_factor for loss in losses_bbox]

    return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
def _predict_by_feat_single(self,
                            cls_score_list: List[Tensor],
                            bbox_pred_list: List[Tensor],
                            score_factor_list: List[Tensor],
                            mlvl_priors: List[Tensor],
                            img_meta: dict,
                            cfg: Optional[ConfigDict] = None,
                            rescale: bool = False,
                            with_nms: bool = True) -> InstanceData:
    """Transform a single image's features extracted from the head into
    bbox results.

    Args:
        cls_score_list (list[Tensor]): Box scores from all scale
            levels of a single image, each item has shape
            (num_priors * num_classes, H, W).
        bbox_pred_list (list[Tensor]): Box predictions from all scale
            levels of a single image, each item has shape
            (num_priors * 4, H, W). They are multiplied by the stride of
            their level here.
        score_factor_list (list[Tensor]): Score factor from all scale
            levels of a single image. Not read by this method.
        mlvl_priors (list[Tensor]): Each element in the list is
            the priors of a single level in feature pyramid. They are
            only passed through the score filtering step.
        img_meta (dict): Image meta info.
        cfg (:obj:`ConfigDict`, optional): Test / postprocessing
            configuration, if None, test_cfg would be used.
        rescale (bool): Forwarded to ``_bbox_post_process``.
            Defaults to False.
        with_nms (bool): Forwarded to ``_bbox_post_process``.
            Defaults to True.

    Returns:
        :obj:`InstanceData`: The result of ``_bbox_post_process`` applied
        to an instance holding the concatenated ``bboxes``, ``scores``
        and ``labels`` of all levels.
    """
    if cfg is None:
        cfg = self.test_cfg
    nms_pre = cfg.get('nms_pre', -1)

    mlvl_bboxes = []
    mlvl_scores = []
    mlvl_labels = []
    per_level = zip(cls_score_list, bbox_pred_list, mlvl_priors,
                    self.prior_generator.strides)
    for cls_score, bbox_pred, priors, stride in per_level:
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]

        bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) * stride[0]
        scores = cls_score.permute(1, 2, 0).reshape(
            -1, self.cls_out_channels)

        # After https://github.com/open-mmlab/mmdetection/pull/6268/,
        # this operation keeps fewer bboxes under the same `nms_pre`.
        # There is no difference in performance for most models. If you
        # find a slight drop in performance, you can set a larger
        # `nms_pre` than before.
        candidates = dict(bbox_pred=bbox_pred, priors=priors)
        scores, labels, keep_idxs, filtered_results = \
            filter_scores_and_topk(scores, cfg.score_thr, nms_pre,
                                   candidates)

        mlvl_bboxes.append(filtered_results['bbox_pred'])
        mlvl_scores.append(scores)
        mlvl_labels.append(labels)

    results = InstanceData()
    results.bboxes = torch.cat(mlvl_bboxes)
    results.scores = torch.cat(mlvl_scores)
    results.labels = torch.cat(mlvl_labels)

    return self._bbox_post_process(
        results=results,
        cfg=cfg,
        rescale=rescale,
        with_nms=with_nms,
        img_meta=img_meta)
def get_targets(self,
                cls_scores: List[List[Tensor]],
                bbox_preds: List[List[Tensor]],
                anchor_list: List[List[Tensor]],
                valid_flag_list: List[List[Tensor]],
                batch_gt_instances: InstanceList,
                batch_img_metas: List[dict],
                batch_gt_instances_ignore: OptInstanceList = None,
                unmap_outputs: bool = True) -> tuple:
    """Compute regression and classification targets for anchors in
    multiple images.

    Args:
        cls_scores (list[list[Tensor]]): Classification predictions of
            images, a 3D-Tensor with shape [num_imgs, num_priors,
            num_classes].
        bbox_preds (list[list[Tensor]]): Decoded bboxes predictions of one
            image, a 3D-Tensor with shape [num_imgs, num_priors, 4] in
            [tl_x, tl_y, br_x, br_y] format.
        anchor_list (list[list[Tensor]]): Multi level anchors of each
            image. The outer list indicates images, and the inner list
            corresponds to feature levels of the image. Each element of
            the inner list is a tensor of shape (num_anchors, 4).
            The entries are replaced in place by their concatenation.
        valid_flag_list (list[list[Tensor]]): Multi level valid flags of
            each image. The outer list indicates images, and the inner list
            corresponds to feature levels of the image. Each element of
            the inner list is a tensor of shape (num_anchors, ).
            The entries are replaced in place by their concatenation.
        batch_gt_instances (list[:obj:`InstanceData`]): Batch of
            gt_instance. It usually includes ``bboxes`` and ``labels``
            attributes.
        batch_img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.
        batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
            Batch of gt_instances_ignore. It includes ``bboxes`` attribute
            data that is ignored during training and testing.
            Defaults to None.
        unmap_outputs (bool): Whether to map outputs back to the original
            set of anchors.

    Returns:
        tuple: a tuple containing learning targets.

            - anchors_list (list[list[Tensor]]): Anchors of each level.
            - labels_list (list[Tensor]): Labels of each level.
            - label_weights_list (list[Tensor]): Label weights of each
              level.
            - bbox_targets_list (list[Tensor]): BBox targets of each level.
            - norm_alignment_metrics_list (list[Tensor]): Normalized
              alignment metrics of each level.
    """
    num_imgs = len(batch_img_metas)
    assert len(anchor_list) == len(valid_flag_list) == num_imgs

    # anchor number of multi levels
    num_level_anchors = [
        level_anchors.size(0) for level_anchors in anchor_list[0]
    ]
    num_level_anchors_list = [num_level_anchors] * num_imgs

    # concat all level anchors and flags to a single tensor
    for img_idx in range(num_imgs):
        assert len(anchor_list[img_idx]) == len(valid_flag_list[img_idx])
        anchor_list[img_idx] = torch.cat(anchor_list[img_idx])
        valid_flag_list[img_idx] = torch.cat(valid_flag_list[img_idx])

    # compute targets for each image
    if batch_gt_instances_ignore is None:
        batch_gt_instances_ignore = [None] * num_imgs

    # anchor_list: list(b * [-1, 4])

    # get epoch information from message hub; it is stored on the head
    # because ``loss_by_feat_single`` reads it as well.
    self.epoch = MessageHub.get_current_instance().get_info('epoch')

    if self.epoch < self.initial_epoch:
        # Initial stage: use the parent class implementation and take the
        # first channel of the bbox weights as assign metrics.
        (all_anchors, all_labels, all_label_weights, all_bbox_targets,
         all_bbox_weights, _pos_inds_list, _neg_inds_list,
         _sampling_result) = multi_apply(
             super()._get_targets_single,
             anchor_list,
             valid_flag_list,
             num_level_anchors_list,
             batch_gt_instances,
             batch_img_metas,
             batch_gt_instances_ignore,
             unmap_outputs=unmap_outputs)
        all_assign_metrics = [
            bbox_weight[..., 0] for bbox_weight in all_bbox_weights
        ]
    else:
        # Alignment stage: targets depend on the current predictions.
        (all_anchors, all_labels, all_label_weights, all_bbox_targets,
         all_assign_metrics) = multi_apply(
             self._get_targets_single,
             cls_scores,
             bbox_preds,
             anchor_list,
             valid_flag_list,
             batch_gt_instances,
             batch_img_metas,
             batch_gt_instances_ignore,
             unmap_outputs=unmap_outputs)

    # split targets to a list w.r.t. multiple levels
    per_image_targets = (all_anchors, all_labels, all_label_weights,
                         all_bbox_targets, all_assign_metrics)
    return tuple(
        images_to_levels(target, num_level_anchors)
        for target in per_image_targets)
def _get_targets_single(self,
                        cls_scores: Tensor,
                        bbox_preds: Tensor,
                        flat_anchors: Tensor,
                        valid_flags: Tensor,
                        gt_instances: InstanceData,
                        img_meta: dict,
                        gt_instances_ignore: Optional[InstanceData] = None,
                        unmap_outputs: bool = True) -> tuple:
    """Compute regression, classification targets for anchors in a single
    image.

    Args:
        cls_scores (Tensor): Box scores for each image.
        bbox_preds (Tensor): Box predictions for each image.
        flat_anchors (Tensor): Multi-level anchors of the image, which are
            concatenated into a single tensor of shape (num_anchors ,4)
        valid_flags (Tensor): Multi level valid flags of the image,
            which are concatenated into a single tensor of
            shape (num_anchors,).
        gt_instances (:obj:`InstanceData`): Ground truth of instance
            annotations. It usually includes ``bboxes`` and ``labels``
            attributes.
        img_meta (dict): Meta information for current image.
        gt_instances_ignore (:obj:`InstanceData`, optional): Instances
            to be ignored during training. It includes ``bboxes`` attribute
            data that is ignored during training and testing.
            Defaults to None.
        unmap_outputs (bool): Whether to map outputs back to the original
            set of anchors.

    Returns:
        tuple: N is the number of total anchors in the image.

            anchors (Tensor): All anchors in the image with shape (N, 4).
            labels (Tensor): Labels of all anchors in the image with shape
                (N,).
            label_weights (Tensor): Label weights of all anchor in the
                image with shape (N,).
            bbox_targets (Tensor): BBox targets of all anchors in the
                image with shape (N, 4).
            norm_alignment_metrics (Tensor): Normalized alignment metrics
                of all priors in the image with shape (N,).

    Raises:
        ValueError: If no anchor is flagged as inside the image.
    """
    inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                       img_meta['img_shape'][:2],
                                       self.train_cfg['allowed_border'])
    if not inside_flags.any():
        raise ValueError(
            'There is no valid anchor inside the image boundary. Please '
            'check the image size and anchor sizes, or set '
            '``allowed_border`` to -1 to skip the condition.')

    # assign gt and sample anchors (only anchors inside the image)
    anchors = flat_anchors[inside_flags, :]
    pred_instances = InstanceData(
        priors=anchors,
        scores=cls_scores[inside_flags, :],
        bboxes=bbox_preds[inside_flags, :])
    assign_result = self.alignment_assigner.assign(pred_instances,
                                                   gt_instances,
                                                   gt_instances_ignore,
                                                   self.alpha, self.beta)
    assign_ious = assign_result.max_overlaps
    assign_metrics = assign_result.assign_metrics
    sampling_result = self.sampler.sample(assign_result, pred_instances,
                                          gt_instances)

    # Default targets: every anchor is labelled ``num_classes`` (the
    # background id) with zero weight, zero box target and zero metric.
    num_valid_anchors = anchors.shape[0]
    bbox_targets = torch.zeros_like(anchors)
    labels = anchors.new_full((num_valid_anchors, ),
                              self.num_classes,
                              dtype=torch.long)
    label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
    norm_alignment_metrics = anchors.new_zeros(
        num_valid_anchors, dtype=torch.float)

    pos_inds = sampling_result.pos_inds
    neg_inds = sampling_result.neg_inds
    if len(pos_inds) > 0:
        # point-based: the target is the assigned gt box itself
        bbox_targets[pos_inds, :] = sampling_result.pos_gt_bboxes
        labels[pos_inds] = sampling_result.pos_gt_labels

        pos_weight = self.train_cfg['pos_weight']
        label_weights[pos_inds] = 1.0 if pos_weight <= 0 else pos_weight
    if len(neg_inds) > 0:
        label_weights[neg_inds] = 1.0

    # Normalize the alignment metrics per gt box: divide by the largest
    # metric of that gt and multiply by its largest IoU.
    pos_assigned_gt_inds = sampling_result.pos_assigned_gt_inds
    for gt_idx in torch.unique(pos_assigned_gt_inds):
        gt_pos_inds = pos_inds[pos_assigned_gt_inds == gt_idx]
        gt_metrics = assign_metrics[gt_pos_inds]
        gt_ious = assign_ious[gt_pos_inds]
        scaled_metrics = gt_metrics / (gt_metrics.max() +
                                       10e-8) * gt_ious.max()
        norm_alignment_metrics[gt_pos_inds] = scaled_metrics

    # map up to original set of anchors
    if unmap_outputs:
        num_total_anchors = flat_anchors.size(0)
        anchors = unmap(anchors, num_total_anchors, inside_flags)
        labels = unmap(
            labels, num_total_anchors, inside_flags, fill=self.num_classes)
        label_weights = unmap(label_weights, num_total_anchors,
                              inside_flags)
        bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
        norm_alignment_metrics = unmap(norm_alignment_metrics,
                                       num_total_anchors, inside_flags)

    return (anchors, labels, label_weights, bbox_targets,
            norm_alignment_metrics)
|