import math
import random
from typing import Callable, List, Union

from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler


class SubsetSampler(Sampler):
    """Samples elements sequentially from a given list of indices.

    Args:
        indices (list): a sequence of indices
    """

    def __init__(self, indices):
        super().__init__(indices)
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in range(len(self.indices)))

    def __len__(self):
        return len(self.indices)


class PerfectBatchSampler(Sampler):
    """Samples mini-batches of indices with a balanced number of samples per class.

    Args:
        dataset_items (list): dataset items to sample from.
        classes (list): list of classes of dataset_items to sample from.
        batch_size (int): total number of samples to be sampled in a mini-batch.
        num_classes_in_batch (int): number of distinct classes represented in each mini-batch.
        num_gpus (int): number of GPUs used in data-parallel mode.
        shuffle (bool): if True, samples randomly, otherwise samples sequentially.
        drop_last (bool): if True, drops the last incomplete batch.
        label_key (str): key under which each dataset item stores its class label.
    """

    def __init__(
        self,
        dataset_items,
        classes,
        batch_size,
        num_classes_in_batch,
        num_gpus=1,
        shuffle=True,
        drop_last=False,
        label_key="class_name",
    ):
        super().__init__(dataset_items)
        assert (
            batch_size % (num_classes_in_batch * num_gpus) == 0
        ), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."

        # Group dataset indices by class label.
        label_indices = {}
        for idx, item in enumerate(dataset_items):
            label = item[label_key]
            if label not in label_indices:
                label_indices[label] = [idx]
            else:
                label_indices[label].append(idx)

        # One per-class sampler; random or sequential depending on `shuffle`.
        if shuffle:
            self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes]
        else:
            self._samplers = [SubsetSampler(label_indices[key]) for key in classes]

        self._batch_size = batch_size
        self._drop_last = drop_last
        self._dp_devices = num_gpus
        self._num_classes_in_batch = num_classes_in_batch

    def __iter__(self):
        batch = []
        # If fewer classes fit in a batch than exist overall, pick a random
        # subset of class samplers for the current batch.
        if self._num_classes_in_batch != len(self._samplers):
            valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
        else:
            valid_samplers_idx = None

        iters = [iter(s) for s in self._samplers]
        done = False

        while True:
            b = []
            # Draw one index from each active class sampler.
            for i, it in enumerate(iters):
                if valid_samplers_idx is not None and i not in valid_samplers_idx:
                    continue
                idx = next(it, None)
                if idx is None:
                    # One of the class samplers is exhausted; stop iterating.
                    done = True
                    break
                b.append(idx)
            if done:
                break
            batch += b
            if len(batch) == self._batch_size:
                yield batch
                batch = []
                # Re-draw the set of active classes for the next batch.
                if valid_samplers_idx is not None:
                    valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)

        if not self._drop_last:
            if len(batch) > 0:
                groups = len(batch) // self._num_classes_in_batch
                if groups % self._dp_devices == 0:
                    yield batch
                else:
                    # Trim the leftover batch so each data-parallel device
                    # still receives the same number of class groups.
                    batch = batch[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch]
                    if len(batch) > 0:
                        yield batch

    def __len__(self):
        class_batch_size = self._batch_size // self._num_classes_in_batch
        return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)


def identity(x):
    return x
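# Hypothetical usage sketch for `PerfectBatchSampler` (not part of the original
# module), assuming dataset items are dicts exposing the configured
# `label_key`. With two classes and `num_classes_in_batch=2`, every yielded
# batch holds two samples per class:
#
# >>> items = [{"class_name": c} for c in ["a", "b"] * 8]
# >>> sampler = PerfectBatchSampler(
# ...     dataset_items=items,
# ...     classes=["a", "b"],
# ...     batch_size=4,  # 2 classes x 2 samples per class
# ...     num_classes_in_batch=2,
# ... )
# >>> all(len(batch) == 4 for batch in sampler)
# True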
class SortedSampler(Sampler):
    """Samples elements sequentially, always in the same order.

    Taken from https://github.com/PetrochukM/PyTorch-NLP

    Args:
        data (iterable): Iterable data.
        sort_key (callable): Specifies a function of one argument that is used to
            extract a numerical comparison key from each list element.

    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
    """

    def __init__(self, data, sort_key: Callable = identity):
        super().__init__(data)
        self.data = data
        self.sort_key = sort_key
        # Sort indices by the key extracted from each element.
        zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
        zip_ = sorted(zip_, key=lambda r: r[1])
        self.sorted_indexes = [item[0] for item in zip_]

    def __iter__(self):
        return iter(self.sorted_indexes)

    def __len__(self):
        return len(self.data)


class BucketBatchSampler(BatchSampler):
    """Bucket batch sampler that groups similarly sized samples into the same batch.

    Draws large "buckets" of indices from the base sampler, sorts each bucket by
    `sort_key`, splits it into mini-batches, and yields those mini-batches in
    random order. This reduces padding within a batch while keeping the overall
    sampling close to random.

    Adapted from https://github.com/PetrochukM/PyTorch-NLP

    Args:
        sampler (torch.utils.data.sampler.Sampler): Base sampler providing indices.
        batch_size (int): Size of a mini-batch.
        drop_last (bool): If `True`, the sampler drops the last batch if its size
            would be less than `batch_size`.
        data (list): List of data samples.
        sort_key (callable, optional): Callable to specify a comparison key for sorting.
        bucket_size_multiplier (int, optional): Buckets are of size
            `batch_size * bucket_size_multiplier`.

    Example:
        >>> sampler = WeightedRandomSampler(weights, len(weights))
        >>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True)
    """

    def __init__(
        self,
        sampler,
        data,
        batch_size,
        drop_last,
        sort_key: Union[Callable, List] = identity,
        bucket_size_multiplier=100,
    ):
        super().__init__(sampler, batch_size, drop_last)
        self.data = data
        self.sort_key = sort_key
        # Cap the bucket size at the sampler length when it is known.
        _bucket_size = batch_size * bucket_size_multiplier
        if hasattr(sampler, "__len__"):
            _bucket_size = min(_bucket_size, len(sampler))
        self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)

    def __iter__(self):
        for idxs in self.bucket_sampler:
            # Sort the bucket by the key, then shuffle the order of the
            # resulting mini-batches so batch order stays random.
            bucket_data = [self.data[idx] for idx in idxs]
            sorted_sampler = SortedSampler(bucket_data, self.sort_key)
            for batch_idx in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
                sorted_idxs = [idxs[i] for i in batch_idx]
                yield sorted_idxs

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        return math.ceil(len(self.sampler) / self.batch_size)
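# Hypothetical usage sketch for `BucketBatchSampler` (not part of the original
# module), assuming variable-length sequences and a length-based `sort_key`.
# Each yielded batch then contains sequences of similar length, minimising
# padding:
#
# >>> import random as _random
# >>> from torch.utils.data.sampler import SequentialSampler
# >>> data_items = [[0] * _random.randint(1, 50) for _ in range(256)]
# >>> batch_sampler = BucketBatchSampler(
# ...     SequentialSampler(data_items), data=data_items, batch_size=32,
# ...     drop_last=True, sort_key=len,
# ... )
# >>> len(batch_sampler)
# 8
# >>> all(len(batch) == 32 for batch in batch_sampler)
# True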