# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import torch class BalancedPositiveNegativeSampler(object): """ This class samples batches, ensuring that they contain a fixed proportion of positives """ def __init__(self, batch_size_per_image, positive_fraction): """ Arguments: batch_size_per_image (int): number of elements to be selected per image positive_fraction (float): percentace of positive elements per batch """ self.batch_size_per_image = batch_size_per_image self.positive_fraction = positive_fraction def __call__(self, matched_idxs): """ Arguments: matched idxs: list of tensors containing -1, 0 or positive values. Each tensor corresponds to a specific image. -1 values are ignored, 0 are considered as negatives and > 0 as positives. Returns: pos_idx (list[tensor]) neg_idx (list[tensor]) Returns two lists of binary masks for each image. The first list contains the positive elements that were selected, and the second list the negative example. """ pos_idx = [] neg_idx = [] for matched_idxs_per_image in matched_idxs: positive = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1) negative = torch.nonzero(matched_idxs_per_image == 0).squeeze(1) num_pos = int(self.batch_size_per_image * self.positive_fraction) # protect against not enough positive examples num_pos = min(positive.numel(), num_pos) num_neg = self.batch_size_per_image - num_pos # protect against not enough negative examples num_neg = min(negative.numel(), num_neg) # randomly select positive and negative examples perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos] perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg] pos_idx_per_image = positive[perm1] neg_idx_per_image = negative[perm2] # create binary mask from indices pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.bool) neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.bool) pos_idx_per_image_mask[pos_idx_per_image] = 1 neg_idx_per_image_mask[neg_idx_per_image] = 1 pos_idx.append(pos_idx_per_image_mask) neg_idx.append(neg_idx_per_image_mask) return pos_idx, neg_idx