RSPrompter / mmdet /models /task_modules /samplers /instance_balanced_pos_sampler.py
KyanChen's picture
Upload 787 files
3e06e1c
raw
history blame
2.32 kB
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmdet.registry import TASK_UTILS
from .random_sampler import RandomSampler
@TASK_UTILS.register_module()
class InstanceBalancedPosSampler(RandomSampler):
"""Instance balanced sampler that samples equal number of positive samples
for each instance."""
def _sample_pos(self, assign_result, num_expected, **kwargs):
"""Sample positive boxes.
Args:
assign_result (:obj:`AssignResult`): The assigned results of boxes.
num_expected (int): The number of expected positive samples
Returns:
Tensor or ndarray: sampled indices.
"""
pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
if pos_inds.numel() != 0:
pos_inds = pos_inds.squeeze(1)
if pos_inds.numel() <= num_expected:
return pos_inds
else:
unique_gt_inds = assign_result.gt_inds[pos_inds].unique()
num_gts = len(unique_gt_inds)
num_per_gt = int(round(num_expected / float(num_gts)) + 1)
sampled_inds = []
for i in unique_gt_inds:
inds = torch.nonzero(
assign_result.gt_inds == i.item(), as_tuple=False)
if inds.numel() != 0:
inds = inds.squeeze(1)
else:
continue
if len(inds) > num_per_gt:
inds = self.random_choice(inds, num_per_gt)
sampled_inds.append(inds)
sampled_inds = torch.cat(sampled_inds)
if len(sampled_inds) < num_expected:
num_extra = num_expected - len(sampled_inds)
extra_inds = np.array(
list(set(pos_inds.cpu()) - set(sampled_inds.cpu())))
if len(extra_inds) > num_extra:
extra_inds = self.random_choice(extra_inds, num_extra)
extra_inds = torch.from_numpy(extra_inds).to(
assign_result.gt_inds.device).long()
sampled_inds = torch.cat([sampled_inds, extra_inds])
elif len(sampled_inds) > num_expected:
sampled_inds = self.random_choice(sampled_inds, num_expected)
return sampled_inds