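# NOTE: the lines below are file-viewer page residue captured with the source,
# kept here as comments so they do not break parsing:
# Your Name
# add
# c310e19
# raw
# history blame contribute delete
# No virus
# 4.84 kB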
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import itertools
import torch
from torch.utils.data.sampler import BatchSampler
from torch.utils.data.sampler import Sampler
class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield mini-batches of indices.

    Every mini-batch only contains elements that share the same group id,
    and elements of a group are emitted in chunks of ``batch_size``.
    The mini-batches are ordered so that they follow the ordering of the
    wrapped sampler as closely as possible.

    Arguments:
        sampler (Sampler): Base sampler.
        group_ids (sequence of int or 1-D tensor): Group id of every element
            of the dataset, indexed by dataset index. Its length is used as
            the dataset size.
        batch_size (int): Size of mini-batch.
        drop_uneven (bool): If ``True``, the sampler will drop the batches
            whose size is less than ``batch_size``.

    Raises:
        ValueError: If ``sampler`` is not a ``torch.utils.data.Sampler``.
    """

    def __init__(self, sampler, group_ids, batch_size, drop_uneven=False):
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        self.sampler = sampler
        self.group_ids = torch.as_tensor(group_ids)
        assert self.group_ids.dim() == 1, "group_ids must be one-dimensional"
        self.batch_size = batch_size
        self.drop_uneven = drop_uneven

        # sorted tensor with the distinct group ids present in the dataset
        self.groups = torch.unique(self.group_ids).sort(0)[0]

        # set by __len__ when it had to build the batches itself, so that
        # the next __iter__ hands out those batches instead of rebuilding
        self._can_reuse_batches = False

    def _prepare_batches(self):
        """Consume the wrapped sampler once and return the list of batches.

        Returns:
            list[list[int]]: batches of dataset indices; every batch holds
            elements of a single group.
        """
        dataset_size = len(self.group_ids)
        # get the sampled indices from the sampler. The dtype is forced so
        # that an empty sampler still produces a valid (integer) index tensor
        sampled_ids = torch.as_tensor(list(self.sampler), dtype=torch.int64)
        # potentially not all elements of the dataset were sampled
        # by the sampler (e.g., DistributedSampler).
        # construct a tensor which contains -1 if the element was
        # not sampled, and a non-negative number indicating the
        # order where the element was sampled.
        # for example, if sampled_ids = [3, 1] and dataset_size = 5,
        # the order is [-1, 1, -1, 0, -1]
        order = torch.full((dataset_size,), -1, dtype=torch.int64)
        order[sampled_ids] = torch.arange(len(sampled_ids))

        # get a mask with the elements that were sampled
        mask = order >= 0

        # find the elements that belong to each individual cluster
        clusters = [(self.group_ids == i) & mask for i in self.groups]
        # get relative order of the elements inside each cluster
        # that follows the order from the sampler
        relative_order = [order[cluster] for cluster in clusters]
        # with the relative order, find the absolute order in the
        # sampled space
        permutation_ids = [s[s.sort()[1]] for s in relative_order]
        # permute each cluster so that they follow the order from
        # the sampler
        permuted_clusters = [sampled_ids[idx] for idx in permutation_ids]

        # splits each cluster in batch_size, and merge as a list of tensors.
        # A group without any sampled element still yields one empty chunk
        # when split, so empty chunks are discarded: they are not batches
        # and have no first element to order by.
        splits = [c.split(self.batch_size) for c in permuted_clusters]
        merged = tuple(
            chunk
            for chunk in itertools.chain.from_iterable(splits)
            if chunk.numel() > 0
        )
        if not merged:
            # nothing was sampled at all
            return []

        # now each batch internally has the right order, but
        # they are grouped by clusters. Find the permutation between
        # different batches that brings them as close as possible to
        # the order that we have in the sampler. For that, we will consider
        # the ordering as coming from the first element of each batch, and
        # sort correspondingly
        first_element_of_batch = [t[0].item() for t in merged]
        # get an inverse mapping from sampled indices to the position where
        # they occur (as returned by the sampler)
        inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())}
        # from the first element in each batch, get a relative ordering
        first_index_of_batch = torch.as_tensor(
            [inv_sampled_ids_map[s] for s in first_element_of_batch]
        )

        # permute the batches so that they approximately follow the order
        # from the sampler
        permutation_order = first_index_of_batch.sort(0)[1].tolist()
        # finally, permute the batches
        batches = [merged[i].tolist() for i in permutation_order]

        if self.drop_uneven:
            batches = [b for b in batches if len(b) == self.batch_size]
        return batches

    def __iter__(self):
        """Return an iterator over the batches of one pass of the sampler."""
        if self._can_reuse_batches:
            # __len__ already consumed the sampler for this pass
            batches = self._batches
            self._can_reuse_batches = False
        else:
            batches = self._prepare_batches()
        self._batches = batches
        return iter(batches)

    def __len__(self):
        """Return the number of batches of the most recently built pass."""
        if not hasattr(self, "_batches"):
            # no pass was built yet: build one now and let the next
            # __iter__ reuse it rather than consuming the sampler again
            self._batches = self._prepare_batches()
            self._can_reuse_batches = True
        return len(self._batches)