import math
import random
from typing import Callable, List, Union

from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler

class SubsetSampler(Sampler):
    """
    Samples elements sequentially from a given list of indices.

    Args:
        indices (list): a sequence of indices
    """

    def __init__(self, indices):
        super().__init__(indices)
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in range(len(self.indices)))

    def __len__(self):
        return len(self.indices)
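
# Usage sketch (illustrative, not part of the original module; the index list
# below is a made-up example). SubsetSampler yields the given indices in their
# original order, which makes evaluation splits deterministic.
def _subset_sampler_example():
    sampler = SubsetSampler([4, 2, 7])
    assert list(sampler) == [4, 2, 7]
    assert len(sampler) == 3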

class PerfectBatchSampler(Sampler):
    """
    Samples mini-batches of indices with balanced class representation, i.e.
    each batch contains the same number of samples from each selected class.

    Args:
        dataset_items (list): dataset items to sample from.
        classes (list): list of classes of dataset_items to sample from.
        batch_size (int): total number of samples to be sampled in a mini-batch.
        num_classes_in_batch (int): number of distinct classes represented in each mini-batch.
        num_gpus (int): number of GPUs in data parallel mode.
        shuffle (bool): if True, samples randomly, otherwise samples sequentially.
        drop_last (bool): if True, drops the last incomplete batch.
        label_key (str): key of the class label in each dataset item.
    """

    def __init__(
        self,
        dataset_items,
        classes,
        batch_size,
        num_classes_in_batch,
        num_gpus=1,
        shuffle=True,
        drop_last=False,
        label_key="class_name",
    ):
        super().__init__(dataset_items)
        assert (
            batch_size % (num_classes_in_batch * num_gpus) == 0
        ), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."

        # Group dataset indices by class label.
        label_indices = {}
        for idx, item in enumerate(dataset_items):
            label = item[label_key]
            if label not in label_indices:
                label_indices[label] = [idx]
            else:
                label_indices[label].append(idx)

        # One sub-sampler per class: random when shuffling, sequential otherwise.
        if shuffle:
            self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes]
        else:
            self._samplers = [SubsetSampler(label_indices[key]) for key in classes]

        self._batch_size = batch_size
        self._drop_last = drop_last
        self._dp_devices = num_gpus
        self._num_classes_in_batch = num_classes_in_batch

    def __iter__(self):
        batch = []
        # If batches use only a subset of the classes, pick that subset at random.
        if self._num_classes_in_batch != len(self._samplers):
            valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
        else:
            valid_samplers_idx = None

        iters = [iter(s) for s in self._samplers]
        done = False

        while True:
            b = []
            # Draw one index from each active class sampler per round.
            for i, it in enumerate(iters):
                if valid_samplers_idx is not None and i not in valid_samplers_idx:
                    continue
                idx = next(it, None)
                if idx is None:
                    # One class ran out of samples; stop producing full batches.
                    done = True
                    break
                b.append(idx)
            if done:
                break
            batch += b
            if len(batch) == self._batch_size:
                yield batch
                batch = []
                if valid_samplers_idx is not None:
                    # Re-draw the class subset for the next batch.
                    valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)

        if not self._drop_last:
            if len(batch) > 0:
                groups = len(batch) // self._num_classes_in_batch
                if groups % self._dp_devices == 0:
                    yield batch
                else:
                    # Trim the leftover batch so it splits evenly across data parallel devices.
                    batch = batch[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch]
                    if len(batch) > 0:
                        yield batch

    def __len__(self):
        class_batch_size = self._batch_size // self._num_classes_in_batch
        return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)
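
# Usage sketch (illustrative, not part of the original module; the items and
# class names are made-up examples). With two classes and batch_size=4, every
# yielded batch holds exactly two "a" items and two "b" items.
def _perfect_batch_sampler_example():
    items = [{"class_name": "a"}] * 4 + [{"class_name": "b"}] * 4
    sampler = PerfectBatchSampler(
        items,
        classes=["a", "b"],
        batch_size=4,
        num_classes_in_batch=2,
        shuffle=False,
    )
    for batch in sampler:
        labels = [items[i]["class_name"] for i in batch]
        assert labels.count("a") == labels.count("b") == 2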

def identity(x):
    return x

class SortedSampler(Sampler):
    """Samples elements sequentially, always in the same order.

    Taken from https://github.com/PetrochukM/PyTorch-NLP

    Args:
        data (iterable): Iterable data.
        sort_key (callable): Specifies a function of one argument that is used to extract a
            numerical comparison key from each list element.

    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
    """

    def __init__(self, data, sort_key: Callable = identity):
        super().__init__(data)
        self.data = data
        self.sort_key = sort_key
        # Precompute the index order induced by sorting on sort_key.
        zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
        zip_ = sorted(zip_, key=lambda r: r[1])
        self.sorted_indexes = [item[0] for item in zip_]

    def __iter__(self):
        return iter(self.sorted_indexes)

    def __len__(self):
        return len(self.data)
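
# Usage sketch (illustrative, not part of the original module; the strings are
# made-up examples). Sorting by length via sort_key=len is the typical setup
# when batching variable-length sequences.
def _sorted_sampler_example():
    data = ["aaa", "a", "aa"]
    sampler = SortedSampler(data, sort_key=len)
    assert list(sampler) == [1, 2, 0]  # indices of "a", "aa", "aaa"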

class BucketBatchSampler(BatchSampler):
    """Bucket batch sampler.

    Adapted from https://github.com/PetrochukM/PyTorch-NLP

    Args:
        sampler (torch.utils.data.sampler.Sampler): base sampler providing indices into `data`.
        data (list): List of data samples.
        batch_size (int): Size of a mini-batch.
        drop_last (bool): If `True`, the sampler will drop the last batch if its size would be less
            than `batch_size`.
        sort_key (callable, optional): Callable to specify a comparison key for sorting.
        bucket_size_multiplier (int, optional): Buckets are of size
            `batch_size * bucket_size_multiplier`.

    Example:
        >>> sampler = WeightedRandomSampler(weights, len(weights))
        >>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True)
    """

    def __init__(
        self,
        sampler,
        data,
        batch_size,
        drop_last,
        sort_key: Union[Callable, List] = identity,
        bucket_size_multiplier=100,
    ):
        super().__init__(sampler, batch_size, drop_last)
        self.data = data
        self.sort_key = sort_key
        # Each bucket spans several batches; cap its size at the sampler length.
        _bucket_size = batch_size * bucket_size_multiplier
        if hasattr(sampler, "__len__"):
            _bucket_size = min(_bucket_size, len(sampler))
        self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)

    def __iter__(self):
        for idxs in self.bucket_sampler:
            # Sort each bucket with sort_key, then yield its batches in random order.
            bucket_data = [self.data[idx] for idx in idxs]
            sorted_sampler = SortedSampler(bucket_data, self.sort_key)
            for batch_idx in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
                sorted_idxs = [idxs[i] for i in batch_idx]
                yield sorted_idxs

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        return math.ceil(len(self.sampler) / self.batch_size)
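
# Usage sketch (illustrative, not part of the original module; the data is a
# made-up example). Items are grouped into large buckets, each bucket is sorted
# by sort_key, and its batches are then yielded in random order, so every batch
# holds similar-length items while the overall ordering stays randomized.
def _bucket_batch_sampler_example():
    from torch.utils.data.sampler import SequentialSampler

    data = ["a" * random.randint(1, 50) for _ in range(100)]
    base = SequentialSampler(data)
    batches = BucketBatchSampler(base, data=data, batch_size=10, drop_last=True, sort_key=len)
    for batch in batches:
        lengths = [len(data[i]) for i in batch]
        assert len(batch) == 10
        assert lengths == sorted(lengths)  # each batch is internally sorted by length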