# Copyright (c) OpenMMLab. All rights reserved. import itertools import numpy as np import torch from mmcv.runner import get_dist_info from torch.utils.data.sampler import Sampler from mmdet.core.utils import sync_random_seed class InfiniteGroupBatchSampler(Sampler): """Similar to `BatchSampler` warping a `GroupSampler. It is designed for iteration-based runners like `IterBasedRunner` and yields a mini-batch indices each time, all indices in a batch should be in the same group. The implementation logic is referred to https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py Args: dataset (object): The dataset. batch_size (int): When model is :obj:`DistributedDataParallel`, it is the number of training samples on each GPU. When model is :obj:`DataParallel`, it is `num_gpus * samples_per_gpu`. Default : 1. world_size (int, optional): Number of processes participating in distributed training. Default: None. rank (int, optional): Rank of current process. Default: None. seed (int): Random seed. Default: 0. shuffle (bool): Whether shuffle the indices of a dummy `epoch`, it should be noted that `shuffle` can not guarantee that you can generate sequential indices because it need to ensure that all indices in a batch is in a group. Default: True. """ # noqa: W605 def __init__(self, dataset, batch_size=1, world_size=None, rank=None, seed=0, shuffle=True): _rank, _world_size = get_dist_info() if world_size is None: world_size = _world_size if rank is None: rank = _rank self.rank = rank self.world_size = world_size self.dataset = dataset self.batch_size = batch_size # In distributed sampling, different ranks should sample # non-overlapped data in the dataset. Therefore, this function # is used to make sure that each rank shuffles the data indices # in the same order based on the same seed. Then different ranks # could use different indices to select non-overlapped data from the # same data list. self.seed = sync_random_seed(seed) self.shuffle = shuffle assert hasattr(self.dataset, 'flag') self.flag = self.dataset.flag self.group_sizes = np.bincount(self.flag) # buffer used to save indices of each group self.buffer_per_group = {k: [] for k in range(len(self.group_sizes))} self.size = len(dataset) self.indices = self._indices_of_rank() def _infinite_indices(self): """Infinitely yield a sequence of indices.""" g = torch.Generator() g.manual_seed(self.seed) while True: if self.shuffle: yield from torch.randperm(self.size, generator=g).tolist() else: yield from torch.arange(self.size).tolist() def _indices_of_rank(self): """Slice the infinite indices by rank.""" yield from itertools.islice(self._infinite_indices(), self.rank, None, self.world_size) def __iter__(self): # once batch size is reached, yield the indices for idx in self.indices: flag = self.flag[idx] group_buffer = self.buffer_per_group[flag] group_buffer.append(idx) if len(group_buffer) == self.batch_size: yield group_buffer[:] del group_buffer[:] def __len__(self): """Length of base dataset.""" return self.size def set_epoch(self, epoch): """Not supported in `IterationBased` runner.""" raise NotImplementedError class InfiniteBatchSampler(Sampler): """Similar to `BatchSampler` warping a `DistributedSampler. It is designed iteration-based runners like `IterBasedRunner` and yields a mini-batch indices each time. The implementation logic is referred to https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py Args: dataset (object): The dataset. batch_size (int): When model is :obj:`DistributedDataParallel`, it is the number of training samples on each GPU, When model is :obj:`DataParallel`, it is `num_gpus * samples_per_gpu`. Default : 1. world_size (int, optional): Number of processes participating in distributed training. Default: None. rank (int, optional): Rank of current process. Default: None. seed (int): Random seed. Default: 0. shuffle (bool): Whether shuffle the dataset or not. Default: True. """ # noqa: W605 def __init__(self, dataset, batch_size=1, world_size=None, rank=None, seed=0, shuffle=True): _rank, _world_size = get_dist_info() if world_size is None: world_size = _world_size if rank is None: rank = _rank self.rank = rank self.world_size = world_size self.dataset = dataset self.batch_size = batch_size # In distributed sampling, different ranks should sample # non-overlapped data in the dataset. Therefore, this function # is used to make sure that each rank shuffles the data indices # in the same order based on the same seed. Then different ranks # could use different indices to select non-overlapped data from the # same data list. self.seed = sync_random_seed(seed) self.shuffle = shuffle self.size = len(dataset) self.indices = self._indices_of_rank() def _infinite_indices(self): """Infinitely yield a sequence of indices.""" g = torch.Generator() g.manual_seed(self.seed) while True: if self.shuffle: yield from torch.randperm(self.size, generator=g).tolist() else: yield from torch.arange(self.size).tolist() def _indices_of_rank(self): """Slice the infinite indices by rank.""" yield from itertools.islice(self._infinite_indices(), self.rank, None, self.world_size) def __iter__(self): # once batch size is reached, yield the indices batch_buffer = [] for idx in self.indices: batch_buffer.append(idx) if len(batch_buffer) == self.batch_size: yield batch_buffer batch_buffer = [] def __len__(self): """Length of base dataset.""" return self.size def set_epoch(self, epoch): """Not supported in `IterationBased` runner.""" raise NotImplementedError