import math
import random

from torch.utils.data import ConcatDataset, Dataset
from torch.utils.data.sampler import (
    BatchSampler,
    RandomSampler,
    Sampler,
    SequentialSampler,
)


class ScheduledSampler(Sampler):
| """A sampler that samples data from a given concat-dataset. |
| |
| Args: |
| concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets |
| batch_size (int): batch size |
| holistic_shuffle (bool): whether to shuffle the whole dataset or not |
| logger (logging.Logger): logger to print warning message |
| |
| Usage: |
| For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True: |
| >>> list(ScheduledSampler(ConcatDataset([0, 1, 2], [3, 4, 5], [6, 7, 8]]))) |
| [3, 4, 5, 0, 1, 2, 6, 7, 8] |
| """ |
|
|
    def __init__(
        self, concat_dataset, batch_size, holistic_shuffle, logger=None, type="train"
    ):
        # NOTE: the ``type`` argument shadows the builtin ``type()``; the error
        # messages below therefore report ``__class__`` instead of calling ``type(...)``.
        if not isinstance(concat_dataset, ConcatDataset):
            raise ValueError(
                "concat_dataset must be an instance of ConcatDataset, but got {}".format(
                    concat_dataset.__class__
                )
            )
        if not isinstance(batch_size, int):
            raise ValueError(
                "batch_size must be an integer, but got {}".format(batch_size.__class__)
            )
        if not isinstance(holistic_shuffle, bool):
            raise ValueError(
                "holistic_shuffle must be a boolean, but got {}".format(
                    holistic_shuffle.__class__
                )
            )

        self.concat_dataset = concat_dataset
        self.batch_size = batch_size
        self.holistic_shuffle = holistic_shuffle

        # Record every sub-dataset whose length is smaller than one batch; such
        # a dataset cannot fill a single full batch.
        affected_dataset_name = []
        affected_dataset_len = []
        for dataset in concat_dataset.datasets:
            dataset_len = len(dataset)
            dataset_name = dataset.get_dataset_name()
            if dataset_len < batch_size:
                affected_dataset_name.append(dataset_name)
                affected_dataset_len.append(dataset_len)

        self.type = type
        # Warn about too-small datasets during training; validation tolerates
        # them because partial batches are kept.
        for dataset_name, dataset_len in zip(
            affected_dataset_name, affected_dataset_len
        ):
            if type != "valid" and logger is not None:
                logger.warning(
                    "The {} dataset {} has a length of {}, which is smaller than "
                    "the batch size {}. This may cause unexpected behavior.".format(
                        type, dataset_name, dataset_len, batch_size
                    )
                )

    def __len__(self):
        # Each sub-dataset contributes floor(len(dataset) / batch_size) full
        # batches; return the total number of indices that will be sampled.
        num_of_batches = sum(
            [
                math.floor(len(dataset) / self.batch_size)
                for dataset in self.concat_dataset.datasets
            ]
        )
        return num_of_batches * self.batch_size

    def __iter__(self):
        # One iterator per sub-dataset: indices are shuffled within a dataset
        # when holistic_shuffle is enabled, otherwise kept in their original order.
        iters = []
        for dataset in self.concat_dataset.datasets:
            iters.append(
                iter(RandomSampler(dataset))
                if self.holistic_shuffle
                else iter(SequentialSampler(dataset))
            )
        # Offsets that map each sub-dataset's local indices to global indices
        # inside the ConcatDataset.
        init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
        output_batches = []
        for dataset_idx in range(len(self.concat_dataset.datasets)):
            cur_batch = []
            for idx in iters[dataset_idx]:
                cur_batch.append(idx + init_indices[dataset_idx])
                if len(cur_batch) == self.batch_size:
                    output_batches.append(cur_batch)
                    cur_batch = []
            # Keep the trailing incomplete batch for validation; drop it for training.
            if self.type == "valid" and len(cur_batch) > 0:
                output_batches.append(cur_batch)
                cur_batch = []

        # Shuffle at the batch level, so every batch still comes from a single dataset.
        random.shuffle(output_batches)
        output_indices = [item for sublist in output_batches for item in sublist]
        return iter(output_indices)


def build_samplers(concat_dataset: Dataset, cfg, logger, type):
    """Build a ScheduledSampler and the BatchSampler that wraps it for a DataLoader."""
    sampler = ScheduledSampler(
        concat_dataset,
        cfg.train.batch_size,
        cfg.train.sampler.holistic_shuffle,
        logger,
        type,
    )
    batch_sampler = BatchSampler(
        sampler,
        cfg.train.batch_size,
        # Never drop the last (possibly short) batch during validation.
        cfg.train.sampler.drop_last if type != "valid" else False,
    )
    return sampler, batch_sampler
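

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the training pipeline).
# It assumes a toy dataset class exposing ``get_dataset_name()`` and a config
# object shaped like ``cfg.train.batch_size`` / ``cfg.train.sampler.*``; the
# project's real config and dataset classes may differ.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import logging
    from types import SimpleNamespace

    from torch.utils.data import DataLoader

    class ToyDataset(Dataset):
        """A tiny in-memory dataset used only by this sketch."""

        def __init__(self, name, values):
            self.name = name
            self.values = values

        def __len__(self):
            return len(self.values)

        def __getitem__(self, idx):
            return self.values[idx]

        def get_dataset_name(self):
            return self.name

    # Hypothetical config object mirroring the fields build_samplers reads.
    cfg = SimpleNamespace(
        train=SimpleNamespace(
            batch_size=3,
            sampler=SimpleNamespace(holistic_shuffle=False, drop_last=True),
        )
    )
    concat = ConcatDataset(
        [
            ToyDataset("a", [0, 1, 2]),
            ToyDataset("b", [3, 4, 5]),
            ToyDataset("c", [6, 7, 8]),
        ]
    )
    sampler, batch_sampler = build_samplers(
        concat, cfg, logging.getLogger(__name__), "train"
    )
    # Every batch comes from a single sub-dataset; the batch order is shuffled.
    for batch in DataLoader(concat, batch_sampler=batch_sampler):
        print(batch)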