|
|
|
import torch |
|
|
|
from ..builder import BBOX_ASSIGNERS |
|
from ..iou_calculators import build_iou_calculator |
|
from .assign_result import AssignResult |
|
from .base_assigner import BaseAssigner |
|
|
|
|
|
@BBOX_ASSIGNERS.register_module() |
|
class GridAssigner(BaseAssigner): |
|
"""Assign a corresponding gt bbox or background to each bbox. |
|
|
|
Each proposals will be assigned with `-1`, `0`, or a positive integer |
|
indicating the ground truth index. |
|
|
|
- -1: don't care |
|
- 0: negative sample, no assigned gt |
|
- positive integer: positive sample, index (1-based) of assigned gt |
|
|
|
Args: |
|
pos_iou_thr (float): IoU threshold for positive bboxes. |
|
neg_iou_thr (float or tuple): IoU threshold for negative bboxes. |
|
min_pos_iou (float): Minimum iou for a bbox to be considered as a |
|
positive bbox. Positive samples can have smaller IoU than |
|
pos_iou_thr due to the 4th step (assign max IoU sample to each gt). |
|
gt_max_assign_all (bool): Whether to assign all bboxes with the same |
|
highest overlap with some gt to that gt. |
|
""" |
|
|
|
def __init__(self, |
|
pos_iou_thr, |
|
neg_iou_thr, |
|
min_pos_iou=.0, |
|
gt_max_assign_all=True, |
|
iou_calculator=dict(type='BboxOverlaps2D')): |
|
self.pos_iou_thr = pos_iou_thr |
|
self.neg_iou_thr = neg_iou_thr |
|
self.min_pos_iou = min_pos_iou |
|
self.gt_max_assign_all = gt_max_assign_all |
|
self.iou_calculator = build_iou_calculator(iou_calculator) |
|
|
|
def assign(self, bboxes, box_responsible_flags, gt_bboxes, gt_labels=None): |
|
"""Assign gt to bboxes. The process is very much like the max iou |
|
assigner, except that positive samples are constrained within the cell |
|
that the gt boxes fell in. |
|
|
|
This method assign a gt bbox to every bbox (proposal/anchor), each bbox |
|
will be assigned with -1, 0, or a positive number. -1 means don't care, |
|
0 means negative sample, positive number is the index (1-based) of |
|
assigned gt. |
|
The assignment is done in following steps, the order matters. |
|
|
|
1. assign every bbox to -1 |
|
2. assign proposals whose iou with all gts <= neg_iou_thr to 0 |
|
3. for each bbox within a cell, if the iou with its nearest gt > |
|
pos_iou_thr and the center of that gt falls inside the cell, |
|
assign it to that bbox |
|
4. for each gt bbox, assign its nearest proposals within the cell the |
|
gt bbox falls in to itself. |
|
|
|
Args: |
|
bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4). |
|
box_responsible_flags (Tensor): flag to indicate whether box is |
|
responsible for prediction, shape(n, ) |
|
gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4). |
|
gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ). |
|
|
|
Returns: |
|
:obj:`AssignResult`: The assign result. |
|
""" |
|
num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0) |
|
|
|
|
|
overlaps = self.iou_calculator(gt_bboxes, bboxes) |
|
|
|
|
|
assigned_gt_inds = overlaps.new_full((num_bboxes, ), |
|
-1, |
|
dtype=torch.long) |
|
|
|
if num_gts == 0 or num_bboxes == 0: |
|
|
|
max_overlaps = overlaps.new_zeros((num_bboxes, )) |
|
if num_gts == 0: |
|
|
|
assigned_gt_inds[:] = 0 |
|
if gt_labels is None: |
|
assigned_labels = None |
|
else: |
|
assigned_labels = overlaps.new_full((num_bboxes, ), |
|
-1, |
|
dtype=torch.long) |
|
return AssignResult( |
|
num_gts, |
|
assigned_gt_inds, |
|
max_overlaps, |
|
labels=assigned_labels) |
|
|
|
|
|
|
|
|
|
|
|
max_overlaps, argmax_overlaps = overlaps.max(dim=0) |
|
|
|
if isinstance(self.neg_iou_thr, float): |
|
assigned_gt_inds[(max_overlaps >= 0) |
|
& (max_overlaps <= self.neg_iou_thr)] = 0 |
|
elif isinstance(self.neg_iou_thr, (tuple, list)): |
|
assert len(self.neg_iou_thr) == 2 |
|
assigned_gt_inds[(max_overlaps > self.neg_iou_thr[0]) |
|
& (max_overlaps <= self.neg_iou_thr[1])] = 0 |
|
|
|
|
|
|
|
|
|
|
|
overlaps[:, ~box_responsible_flags.type(torch.bool)] = -1. |
|
|
|
|
|
|
|
max_overlaps, argmax_overlaps = overlaps.max(dim=0) |
|
|
|
|
|
|
|
|
|
gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1) |
|
|
|
pos_inds = (max_overlaps > |
|
self.pos_iou_thr) & box_responsible_flags.type(torch.bool) |
|
assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1 |
|
|
|
|
|
for i in range(num_gts): |
|
if gt_max_overlaps[i] > self.min_pos_iou: |
|
if self.gt_max_assign_all: |
|
max_iou_inds = (overlaps[i, :] == gt_max_overlaps[i]) & \ |
|
box_responsible_flags.type(torch.bool) |
|
assigned_gt_inds[max_iou_inds] = i + 1 |
|
elif box_responsible_flags[gt_argmax_overlaps[i]]: |
|
assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1 |
|
|
|
|
|
if gt_labels is not None: |
|
assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1) |
|
pos_inds = torch.nonzero( |
|
assigned_gt_inds > 0, as_tuple=False).squeeze() |
|
if pos_inds.numel() > 0: |
|
assigned_labels[pos_inds] = gt_labels[ |
|
assigned_gt_inds[pos_inds] - 1] |
|
|
|
else: |
|
assigned_labels = None |
|
|
|
return AssignResult( |
|
num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels) |
|
|