# Copyright (c) OpenMMLab. All rights reserved.
import torch

from ...core.bbox.assigners import AscendMaxIoUAssigner
from ...core.bbox.samplers import PseudoSampler
from ...utils import (batch_images_to_levels, get_max_num_gt_division_factor,
                      masked_fill)
from ..builder import HEADS
from .anchor_head import AnchorHead
class AscendAnchorHead(AnchorHead):
    """Ascend Anchor-based head (RetinaNet, SSD, etc.).

    Unlike the per-image target computation of ``AnchorHead``, this head
    stacks the anchors, ground truth boxes and labels of all images into
    batched tensors (padding the ground truth to a common length) and
    computes the targets for the whole batch at once.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels. Used in child classes.
        anchor_generator (dict): Config dict for anchor generator
        bbox_coder (dict): Config of bounding box coder.
        reg_decoded_bbox (bool): If true, the regression loss would be
            applied directly on decoded bounding boxes, converting both
            the predicted boxes and regression targets to absolute
            coordinates format. Default False. It should be `True` when
            using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        train_cfg (dict): Training config of anchor head.
        test_cfg (dict): Testing config of anchor head.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """  # noqa: W605

    # NOTE(review): ``HEADS`` is imported by this module but no
    # ``@HEADS.register_module()`` decorator is visible on this class --
    # confirm whether the registration was intentionally omitted.

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 anchor_generator=dict(
                     type='AnchorGenerator',
                     scales=[8, 16, 32],
                     ratios=[0.5, 1.0, 2.0],
                     strides=[4, 8, 16, 32, 64]),
                 bbox_coder=dict(
                     type='DeltaXYWHBBoxCoder',
                     clip_border=True,
                     target_means=(.0, .0, .0, .0),
                     target_stds=(1.0, 1.0, 1.0, 1.0)),
                 reg_decoded_bbox=False,
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_bbox=dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(type='Normal', layer='Conv2d', std=0.01)):
        # All arguments are forwarded unchanged to ``AnchorHead``.
        super(AscendAnchorHead, self).__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            feat_channels=feat_channels,
            anchor_generator=anchor_generator,
            bbox_coder=bbox_coder,
            reg_decoded_bbox=reg_decoded_bbox,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)

    def get_batch_gt_bboxes(self, gt_bboxes_list, num_images, gt_nums, device,
                            max_gt_labels):
        """Pad the ground truth bboxes of all images into one tensor.

        Every image contributes ``max_gt_labels`` rows. Rows beyond the real
        number of ground truth boxes of an image keep the filler value taken
        from ``self.min_anchor``.

        Args:
            gt_bboxes_list (list[Tensor] | None): Ground truth bboxes of each
                image.
            num_images (int): The num of images.
            gt_nums (list[int]): The ground truth bboxes num of each image.
            device (torch.device | str): Device for returned tensors
            max_gt_labels (int): The max ground truth bboxes num of all image.

        Returns:
            Tensor | None: Ground truth bboxes of all image with shape
            (num_images, max_gt_labels, 4), or None when ``gt_bboxes_list``
            is None.
        """
        # Cache of pre-filled padding templates. Saving the static tensor is
        # related to Ascend and helps improve performance.
        if not hasattr(self, 'batch_gt_bboxes'):
            self.batch_gt_bboxes = {}
        # Filler coordinates written into the padded (excess) box rows.
        if not hasattr(self, 'min_anchor'):
            self.min_anchor = (-1354, -1344)

        if gt_bboxes_list is None:
            return None

        dtype = gt_bboxes_list[0].dtype
        # The template shape depends on the batch size as well as on the
        # padded length, so the key must contain both; keying on
        # ``max_gt_labels`` alone would hand a stale, wrongly shaped template
        # to a batch with a different number of images (or dtype/device).
        cache_key = (num_images, max_gt_labels, dtype, str(device))
        template = self.batch_gt_bboxes.get(cache_key)
        if template is None:
            template = torch.zeros((num_images, max_gt_labels, 4),
                                   dtype=dtype,
                                   device=device)
            template[:, :, :2] = self.min_anchor[0]
            template[:, :, 2:] = self.min_anchor[1]
            self.batch_gt_bboxes[cache_key] = template

        # Work on a copy so the cached template is never modified.
        batch_gt_bboxes = template.clone()
        for index_imgs, gt_bboxes in enumerate(gt_bboxes_list):
            batch_gt_bboxes[index_imgs, :gt_nums[index_imgs]] = gt_bboxes
        return batch_gt_bboxes

    def get_batch_gt_bboxes_ignore(self, gt_bboxes_ignore_list, num_images,
                                   gt_nums, device):
        """Ground truth bboxes to be ignored of all image.

        Ignored boxes are not supported yet, so the only accepted input is
        ``None``.

        Args:
            gt_bboxes_ignore_list (list[Tensor] | None): Ground truth bboxes
                to be ignored.
            num_images (int): The num of images. Currently unused.
            gt_nums (list[int]): The ground truth bboxes num of each image.
                Currently unused.
            device (torch.device | str): Device for returned tensors.
                Currently unused.

        Returns:
            None: Always None, since ignored boxes are not supported.

        Raises:
            RuntimeError: If ``gt_bboxes_ignore_list`` is not None.
        """
        # TODO: support gt_bboxes_ignore_list
        if gt_bboxes_ignore_list is None:
            batch_gt_bboxes_ignore = None
        else:
            raise RuntimeError('gt_bboxes_ignore not support yet')
        return batch_gt_bboxes_ignore

    def get_batch_gt_labels(self, gt_labels_list, num_images, gt_nums, device,
                            max_gt_labels):
        """Pad the ground truth labels of all images into one tensor.

        Every image contributes ``max_gt_labels`` entries. Entries beyond the
        real number of ground truth labels of an image stay zero.

        Args:
            gt_labels_list (list[Tensor] | None): Ground truth labels.
            num_images (int): The num of images.
            gt_nums (list[int]): The ground truth bboxes num of each image.
            device (torch.device | str): Device for returned tensors
            max_gt_labels (int): The max ground truth bboxes num of all image.

        Returns:
            Tensor | None: Ground truth labels of all image with shape
            (num_images, max_gt_labels), or None when ``gt_labels_list`` is
            None.
        """
        if gt_labels_list is None:
            return None

        batch_gt_labels = torch.zeros((num_images, max_gt_labels),
                                      dtype=gt_labels_list[0].dtype,
                                      device=device)
        for index_imgs, gt_labels in enumerate(gt_labels_list):
            batch_gt_labels[index_imgs, :gt_nums[index_imgs]] = gt_labels
        return batch_gt_labels

    def _get_targets_concat(self,
                            batch_anchors,
                            batch_valid_flags,
                            batch_gt_bboxes,
                            batch_gt_bboxes_ignore,
                            batch_gt_labels,
                            img_metas,
                            label_channels=1,
                            unmap_outputs=True):
        """Compute regression and classification targets for anchors in all
        images.

        Args:
            batch_anchors (Tensor): anchors of all image, which are
                concatenated into a single tensor of
                shape (num_imgs, num_anchors ,4).
            batch_valid_flags (Tensor): valid flags of all image,
                which are concatenated into a single tensor of
                shape (num_imgs, num_anchors,).
            batch_gt_bboxes (Tensor): Ground truth bboxes of all image,
                shape (num_imgs, max_gt_nums, 4).
            batch_gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_imgs, num_ignored_gts, 4).
            batch_gt_labels (Tensor): Ground truth labels of each box,
                shape (num_imgs, max_gt_nums,).
            img_metas (list[dict]): Meta info of each image. Unused here.
            label_channels (int): Channel of label. Unused here.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors. Unused here.

        Returns:
            tuple:
                batch_labels (Tensor): Labels of all level
                batch_label_weights (Tensor): Label weights of all level
                batch_bbox_targets (Tensor): BBox targets of all level
                batch_bbox_weights (Tensor): BBox weights of all level
                batch_pos_mask (Tensor): Positive samples mask in all images
                batch_neg_mask (Tensor): Negative samples mask in all images
                sampling_result (Sampling): The result of sampling,
                    default: None.
        """
        num_imgs, num_anchors, _ = batch_anchors.size()
        # assign gt and sample batch_anchors
        assign_result = self.assigner.assign(
            batch_anchors,
            batch_gt_bboxes,
            batch_gt_bboxes_ignore,
            None if self.sampling else batch_gt_labels,
            batch_bboxes_ignore_mask=batch_valid_flags)
        # TODO: support sampling_result
        sampling_result = None
        batch_pos_mask = assign_result.batch_pos_mask
        batch_neg_mask = assign_result.batch_neg_mask
        batch_anchor_gt_indes = assign_result.batch_anchor_gt_indes
        batch_anchor_gt_labels = assign_result.batch_anchor_gt_labels

        # For every anchor, look up the gt box selected by the assigner.
        batch_anchor_gt_bboxes = torch.zeros(
            batch_anchors.size(),
            dtype=batch_anchors.dtype,
            device=batch_anchors.device)
        for index_imgs in range(num_imgs):
            batch_anchor_gt_bboxes[index_imgs] = torch.index_select(
                batch_gt_bboxes[index_imgs], 0,
                batch_anchor_gt_indes[index_imgs])

        batch_bbox_targets = torch.zeros_like(batch_anchors)
        batch_bbox_weights = torch.zeros_like(batch_anchors)
        # Labels start out as ``self.num_classes`` for every anchor and are
        # overwritten for positive anchors below.
        batch_labels = batch_anchors.new_full((num_imgs, num_anchors),
                                              self.num_classes,
                                              dtype=torch.int)
        batch_label_weights = batch_anchors.new_zeros((num_imgs, num_anchors),
                                                      dtype=torch.float)

        if not self.reg_decoded_bbox:
            batch_pos_bbox_targets = self.bbox_coder.encode(
                batch_anchors, batch_anchor_gt_bboxes)
        else:
            batch_pos_bbox_targets = batch_anchor_gt_bboxes

        # Regression targets and weights are only set on positive anchors.
        batch_bbox_targets = masked_fill(batch_bbox_targets,
                                         batch_pos_mask.unsqueeze(2),
                                         batch_pos_bbox_targets)
        batch_bbox_weights = masked_fill(batch_bbox_weights,
                                         batch_pos_mask.unsqueeze(2), 1.0)

        if batch_gt_labels is None:
            batch_labels = masked_fill(batch_labels, batch_pos_mask, 0.0)
        else:
            batch_labels = masked_fill(batch_labels, batch_pos_mask,
                                       batch_anchor_gt_labels)

        # Positive anchors get weight 1.0 unless a positive ``pos_weight`` is
        # configured; negative anchors always get weight 1.0.
        if self.train_cfg.pos_weight <= 0:
            batch_label_weights = masked_fill(batch_label_weights,
                                              batch_pos_mask, 1.0)
        else:
            batch_label_weights = masked_fill(batch_label_weights,
                                              batch_pos_mask,
                                              self.train_cfg.pos_weight)
        batch_label_weights = masked_fill(batch_label_weights, batch_neg_mask,
                                          1.0)
        return (batch_labels, batch_label_weights, batch_bbox_targets,
                batch_bbox_weights, batch_pos_mask, batch_neg_mask,
                sampling_result)

    def get_targets(self,
                    anchor_list,
                    valid_flag_list,
                    gt_bboxes_list,
                    img_metas,
                    gt_bboxes_ignore_list=None,
                    gt_labels_list=None,
                    label_channels=1,
                    unmap_outputs=True,
                    return_sampling_results=False,
                    return_level=True):
        """Compute regression and classification targets for anchors in
        multiple images.

        Args:
            anchor_list (list[list[Tensor]]): Multi level anchors of each
                image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, 4).
            valid_flag_list (list[list[Tensor]]): Multi level valid flags of
                each image. The outer list indicates images, and the inner list
                corresponds to feature levels of the image. Each element of
                the inner list is a tensor of shape (num_anchors, )
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
                ignored. Must be None.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
            label_channels (int): Channel of label.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors. Must be True.
            return_sampling_results (bool): Whether to return the result of
                sample. Must be False.
            return_level (bool): Whether to map outputs back to the levels
                of feature map sizes.

        Returns:
            tuple: Usually returns a tuple containing learning targets.
            When ``return_level`` is True:

                - labels_list (list[Tensor]): Labels of each level.
                - label_weights_list (list[Tensor]): Label weights of each
                  level.
                - bbox_targets_list (list[Tensor]): BBox targets of each level.
                - bbox_weights_list (list[Tensor]): BBox weights of each level.
                - num_total_pos (Tensor): Number of positive samples in all
                  images.
                - num_total_neg (Tensor): Number of negative samples in all
                  images.

            additional_returns: This function enables user-defined returns from
                `self._get_targets_concat`. These returns are currently refined
                to properties at each feature map (i.e. having HxW dimension).
                The results will be concatenated after the end

            Otherwise the un-split batch tensors are returned as
            (batch_labels, batch_label_weights, batch_bbox_targets,
            batch_bbox_weights, batch_pos_mask, batch_neg_mask,
            sampling_result, num_total_pos, num_total_neg, batch_anchors).
        """
        # Only this restricted configuration is supported by the batched
        # implementation.
        assert gt_bboxes_ignore_list is None
        assert unmap_outputs is True
        assert return_sampling_results is False
        assert self.train_cfg.allowed_border < 0
        assert isinstance(self.assigner, AscendMaxIoUAssigner)
        assert isinstance(self.sampler, PseudoSampler)
        num_imgs = len(img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        device = anchor_list[0][0].device
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]

        # Flatten the levels of every image, then stack the images into
        # (num_imgs, num_anchors, ...) batch tensors.
        batch_anchor_list = []
        batch_valid_flag_list = []
        for i in range(num_imgs):
            assert len(anchor_list[i]) == len(valid_flag_list[i])
            batch_anchor_list.append(torch.cat(anchor_list[i]))
            batch_valid_flag_list.append(torch.cat(valid_flag_list[i]))
        batch_anchors = torch.stack(batch_anchor_list, 0)
        batch_valid_flags = torch.stack(batch_valid_flag_list, 0)

        # Pad the ground truth of all images to a common length.
        gt_nums = [len(gt_bbox) for gt_bbox in gt_bboxes_list]
        max_gt_nums = get_max_num_gt_division_factor(gt_nums)
        batch_gt_bboxes = self.get_batch_gt_bboxes(gt_bboxes_list, num_imgs,
                                                   gt_nums, device,
                                                   max_gt_nums)
        batch_gt_bboxes_ignore = self.get_batch_gt_bboxes_ignore(
            gt_bboxes_ignore_list, num_imgs, gt_nums, device)
        batch_gt_labels = self.get_batch_gt_labels(gt_labels_list, num_imgs,
                                                   gt_nums, device,
                                                   max_gt_nums)

        results = self._get_targets_concat(
            batch_anchors,
            batch_valid_flags,
            batch_gt_bboxes,
            batch_gt_bboxes_ignore,
            batch_gt_labels,
            img_metas,
            label_channels=label_channels,
            unmap_outputs=unmap_outputs)

        (batch_labels, batch_label_weights, batch_bbox_targets,
         batch_bbox_weights, batch_pos_mask, batch_neg_mask,
         sampling_result) = results[:7]
        rest_results = list(results[7:])  # user-added return values

        # sampled anchors of all images; every image counts at least one
        # positive and one negative sample.
        min_num = torch.ones((num_imgs, ),
                             dtype=torch.long,
                             device=batch_pos_mask.device)
        num_total_pos = torch.sum(
            torch.max(torch.sum(batch_pos_mask, dim=1), min_num))
        num_total_neg = torch.sum(
            torch.max(torch.sum(batch_neg_mask, dim=1), min_num))

        if return_level is True:
            labels_list = batch_images_to_levels(batch_labels,
                                                 num_level_anchors)
            label_weights_list = batch_images_to_levels(
                batch_label_weights, num_level_anchors)
            bbox_targets_list = batch_images_to_levels(batch_bbox_targets,
                                                       num_level_anchors)
            bbox_weights_list = batch_images_to_levels(batch_bbox_weights,
                                                       num_level_anchors)
            res = (labels_list, label_weights_list, bbox_targets_list,
                   bbox_weights_list, num_total_pos, num_total_neg)
            if return_sampling_results:
                res = res + (sampling_result, )
            for i, r in enumerate(rest_results):  # user-added return values
                rest_results[i] = batch_images_to_levels(r, num_level_anchors)
            return res + tuple(rest_results)
        else:
            res = (batch_labels, batch_label_weights, batch_bbox_targets,
                   batch_bbox_weights, batch_pos_mask, batch_neg_mask,
                   sampling_result, num_total_pos, num_total_neg,
                   batch_anchors)
            return res