# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import itertools

import torch
from torch.utils.data.sampler import BatchSampler
from torch.utils.data.sampler import Sampler


class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield a mini-batch of indices.
    It enforces that all elements of a mini-batch come from the same group,
    and it tries to yield the mini-batches in an ordering that is as close
    as possible to the ordering of the original sampler.

    Arguments:
        sampler (Sampler): Base sampler.
        group_ids (list[int] or Tensor): Group id of each dataset element;
            must be one-dimensional, with one entry per dataset element.
        batch_size (int): Size of mini-batch.
        drop_uneven (bool): If ``True``, the sampler will drop the batches
            whose size is less than ``batch_size``.
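
    Example (illustrative; the ``group_ids`` values here are made up)::

        >>> import torch
        >>> sampler = torch.utils.data.SequentialSampler(range(6))
        >>> group_ids = [0, 0, 1, 1, 0, 1]
        >>> list(GroupedBatchSampler(sampler, group_ids, batch_size=2))
        [[0, 1], [2, 3], [4], [5]]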
""" | |

    def __init__(self, sampler, group_ids, batch_size, drop_uneven=False):
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        self.sampler = sampler
        self.group_ids = torch.as_tensor(group_ids)
        assert self.group_ids.dim() == 1
        self.batch_size = batch_size
        self.drop_uneven = drop_uneven

        self.groups = torch.unique(self.group_ids).sort(0)[0]

        self._can_reuse_batches = False

    def _prepare_batches(self):
        dataset_size = len(self.group_ids)
        # get the sampled indices from the sampler
        sampled_ids = torch.as_tensor(list(self.sampler))
        # potentially not all elements of the dataset were sampled
        # by the sampler (e.g., DistributedSampler).
        # construct a tensor which contains -1 if the element was
        # not sampled, and a non-negative number indicating the
        # order in which the element was sampled.
        # for example, if sampled_ids = [3, 1] and dataset_size = 5,
        # the order is [-1, 1, -1, 0, -1]
        order = torch.full((dataset_size,), -1, dtype=torch.int64)
        order[sampled_ids] = torch.arange(len(sampled_ids))

        # get a mask with the elements that were sampled
        mask = order >= 0

        # find the elements that belong to each individual cluster
        clusters = [(self.group_ids == i) & mask for i in self.groups]
        # get relative order of the elements inside each cluster
        # that follows the order from the sampler
        relative_order = [order[cluster] for cluster in clusters]
        # with the relative order, find the absolute order in the
        # sampled space
        permutation_ids = [s[s.sort()[1]] for s in relative_order]
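        # note: s[s.sort()[1]] is just the sorted values of s (equivalent to
        # s.sort()[0]); each value is a position into sampled_ids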
        # permute each cluster so that they follow the order from
        # the sampler
        permuted_clusters = [sampled_ids[idx] for idx in permutation_ids]

        # splits each cluster in batch_size, and merge as a list of tensors
        splits = [c.split(self.batch_size) for c in permuted_clusters]
        merged = tuple(itertools.chain.from_iterable(splits))

        # now each batch internally has the right order, but
        # they are grouped by clusters. Find the permutation between
        # different batches that brings them as close as possible to
        # the order that we have in the sampler. For that, we will consider the
        # ordering as coming from the first element of each batch, and sort
        # accordingly
        first_element_of_batch = [t[0].item() for t in merged]
        # get an inverse mapping from sampled indices to the position where
        # they occur (as returned by the sampler)
        inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())}
        # from the first element in each batch, get a relative ordering
        first_index_of_batch = torch.as_tensor(
            [inv_sampled_ids_map[s] for s in first_element_of_batch]
        )
        # permute the batches so that they approximately follow the order
        # from the sampler
        permutation_order = first_index_of_batch.sort(0)[1].tolist()
        # finally, permute the batches
        batches = [merged[i].tolist() for i in permutation_order]

        if self.drop_uneven:
            kept = []
            for batch in batches:
                if len(batch) == self.batch_size:
                    kept.append(batch)
            batches = kept

        return batches

    def __iter__(self):
        if self._can_reuse_batches:
            batches = self._batches
            self._can_reuse_batches = False
        else:
            batches = self._prepare_batches()
        self._batches = batches
        return iter(batches)

    def __len__(self):
        if not hasattr(self, "_batches"):
            self._batches = self._prepare_batches()
            self._can_reuse_batches = True
        return len(self._batches)
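

# Minimal usage sketch (not part of the original module): builds the sampler
# over a toy dataset split into two made-up groups (e.g., aspect-ratio bins)
# and prints the resulting batches. The group_ids values are illustrative.
if __name__ == "__main__":
    from torch.utils.data.sampler import SequentialSampler

    group_ids = [0, 1, 0, 1, 0, 0, 1, 1]  # hypothetical per-element group ids
    sampler = SequentialSampler(range(len(group_ids)))

    batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size=2)
    for batch in batch_sampler:
        # every batch contains indices from a single group
        assert len({group_ids[i] for i in batch}) == 1
        print(batch)  # [0, 2], [1, 3], [4, 5], [6, 7]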