Spaces:
Runtime error
Runtime error
File size: 5,995 Bytes
51f6859 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from ..builder import BBOX_ASSIGNERS
from .assign_result import AssignResult
from .base_assigner import BaseAssigner
@BBOX_ASSIGNERS.register_module()
class PointAssigner(BaseAssigner):
    """Assign a corresponding gt bbox or background to each point.

    Each point will be assigned with `0`, or a positive integer
    indicating the ground truth index.

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        scale (int | float): Divisor applied to the gt width and height
            before taking ``log2`` to select the pyramid level of a gt.
            Defaults to 4.
        pos_num (int): Maximum number of nearest points (within the
            selected level) considered as positive candidates for each gt.
            Defaults to 3.
    """

    def __init__(self, scale=4, pos_num=3):
        self.scale = scale
        self.pos_num = pos_num

    def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
        """Assign gt to points.

        This method assigns a gt bbox to every point. Each point is
        assigned either 0 (background) or a positive number, which is the
        1-based index of the assigned gt.

        The assignment is done in following steps, the order matters.

        1. assign every point to the background (0)
        2. A point is assigned to some gt bbox if
            (i) the point is within the k closest points to the gt bbox
                center on the gt's pyramid level
            (ii) the distance between this point and the gt is smaller
                than the distance to any previously assigned gt

        Args:
            points (Tensor): points to be assigned, shape(n, 3) while last
                dimension stands for (x, y, stride).
            gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4), read as
                (x1, y1, x2, y2).
            gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
                labelled as `ignored`, e.g., crowd boxes in COCO.
                NOTE: currently unused.
            gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).

        Returns:
            :obj:`AssignResult`: The assign result. ``labels`` is ``None``
            when ``gt_labels`` is ``None``; otherwise it holds the gt label
            for positive points and -1 for background points.
        """
        num_points = points.shape[0]
        num_gts = gt_bboxes.shape[0]

        if num_gts == 0 or num_points == 0:
            # If no truth assign everything to the background
            assigned_gt_inds = points.new_full((num_points, ),
                                               0,
                                               dtype=torch.long)
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = points.new_full((num_points, ),
                                                  -1,
                                                  dtype=torch.long)
            return AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels)

        points_xy = points[:, :2]
        points_stride = points[:, 2]
        points_lvl = torch.log2(
            points_stride).int()  # [3...,4...,5...,6...,7...]
        lvl_min, lvl_max = points_lvl.min(), points_lvl.max()

        # assign gt box
        gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
        # clamp avoids log2(0) / division by zero for degenerate boxes
        gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6)
        scale = self.scale
        gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) +
                          torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int()
        # restrict the gt level to the range of levels present in `points`
        gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max)

        # stores the assigned gt index of each point
        assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long)
        # stores the assigned gt dist (to this point) of each point
        assigned_gt_dist = points.new_full((num_points, ), float('inf'))
        # Keep the index helper on the same device as `points`, since it is
        # indexed below with masks/indices that live on that device.
        points_range = torch.arange(num_points, device=points.device)

        for idx in range(num_gts):
            gt_lvl = gt_bboxes_lvl[idx]
            # get the index of points in this level
            lvl_idx = gt_lvl == points_lvl
            points_index = points_range[lvl_idx]
            # get the points in this level
            lvl_points = points_xy[lvl_idx, :]

            # topk requires k <= number of candidates; a level may hold
            # fewer than pos_num points (or none, when the clamped gt level
            # lies between two existing levels).
            num_candidates = min(self.pos_num, lvl_points.shape[0])
            if num_candidates == 0:
                continue

            # get the center point of gt
            gt_point = gt_bboxes_xy[[idx], :]
            # get width and height of gt
            gt_wh = gt_bboxes_wh[[idx], :]
            # compute the distance between gt center and
            # all points in this level, normalized by the gt size
            points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1)
            # find the nearest k points to gt center in this level
            min_dist, min_dist_index = torch.topk(
                points_gt_dist, num_candidates, largest=False)
            # the index of nearest k points to gt center in this level
            min_dist_points_index = points_index[min_dist_index]
            # The less_than_recorded_index stores the index
            # of min_dist that is less than the assigned_gt_dist. Where
            # assigned_gt_dist stores the dist from previous assigned gt
            # (if exist) to each point.
            less_than_recorded_index = min_dist < assigned_gt_dist[
                min_dist_points_index]
            # The min_dist_points_index stores the index of points satisfy:
            #   (1) it is k nearest to current gt center in this level.
            #   (2) it is closer to current gt center than other gt center.
            min_dist_points_index = min_dist_points_index[
                less_than_recorded_index]
            # assign the result (gt indices are stored 1-based)
            assigned_gt_inds[min_dist_points_index] = idx + 1
            assigned_gt_dist[min_dist_points_index] = min_dist[
                less_than_recorded_index]

        if gt_labels is not None:
            assigned_labels = assigned_gt_inds.new_full((num_points, ), -1)
            # squeeze(1) keeps pos_inds 1-D even with a single positive
            pos_inds = torch.nonzero(
                assigned_gt_inds > 0, as_tuple=False).squeeze(1)
            if pos_inds.numel() > 0:
                assigned_labels[pos_inds] = gt_labels[
                    assigned_gt_inds[pos_inds] - 1]
        else:
            assigned_labels = None

        return AssignResult(
            num_gts, assigned_gt_inds, None, labels=assigned_labels)
|