# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Iterator, Optional, Sequence, Sized

import torch
from mmengine.dist import get_dist_info, sync_random_seed
from mmengine.registry import DATA_SAMPLERS
from torch.utils.data import Sampler


@DATA_SAMPLERS.register_module()
class MultiDataSampler(Sampler):
    """A weighted multi-dataset sampler for both distributed and
    non-distributed environments.

    It draws indices from a concatenated dataset so that each sub-dataset is
    sampled according to ``dataset_ratio``. It differs from the PyTorch
    ``DistributedSampler`` as below:

    1. This sampler supports non-distributed environments.

    2. The round-up behaviors are a little different.

       - If ``round_up=True``, this sampler will add extra samples to make
         the number of samples evenly divisible by the world size. This
         behavior is the same as the ``DistributedSampler`` with
         ``drop_last=False``.
       - If ``round_up=False``, this sampler won't remove or add any samples,
         while the ``DistributedSampler`` with ``drop_last=True`` will remove
         tail samples.

    Args:
        dataset (Sized): The concatenated dataset. It must expose a
            ``datasets`` attribute holding the sub-datasets.
        dataset_ratio (Sequence[int]): The sampling ratios of the different
            sub-datasets.
        seed (int, optional): Random seed used to sample the indices. This
            number should be identical across all processes in the
            distributed group. Defaults to None.
        round_up (bool): Whether to add extra samples to make the number of
            samples evenly divisible by the world size. Defaults to True.
    """

    def __init__(self,
                 dataset: Sized,
                 dataset_ratio: Sequence[int],
                 seed: Optional[int] = None,
                 round_up: bool = True) -> None:
        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.dataset_ratio = dataset_ratio

        if seed is None:
            seed = sync_random_seed()
        self.seed = seed
        self.epoch = 0
        self.round_up = round_up

        if self.round_up:
            self.num_samples = math.ceil(len(self.dataset) / world_size)
            self.total_size = self.num_samples * self.world_size
        else:
            self.num_samples = math.ceil(
                (len(self.dataset) - rank) / world_size)
            self.total_size = len(self.dataset)

        self.sizes = [len(dataset) for dataset in self.dataset.datasets]

        # Give every sample of sub-dataset i the weight
        # max(sizes) / sizes[i] * ratio[i] / sum(ratios), so the total
        # probability mass of each sub-dataset is proportional to its ratio,
        # regardless of its size.
        dataset_weight = [
            torch.ones(s) * max(self.sizes) / s * r / sum(self.dataset_ratio)
            for r, s in zip(self.dataset_ratio, self.sizes)
        ]
        self.weights = torch.cat(dataset_weight)

    def __iter__(self) -> Iterator[int]:
        """Iterate the indices."""
        # deterministically shuffle based on epoch and seed
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        indices = torch.multinomial(
            self.weights, len(self.weights), generator=g,
            replacement=True).tolist()

        # add extra samples to make it evenly divisible
        if self.round_up:
            indices = (
                indices *
                int(self.total_size / len(indices) + 1))[:self.total_size]

        # subsample for the current rank
        indices = indices[self.rank:self.total_size:self.world_size]

        return iter(indices)

    def __len__(self) -> int:
        """The number of samples in this rank."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch for this sampler.

        This ensures all replicas use a different random ordering for each
        epoch. Otherwise, the next iteration of this sampler will yield the
        same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
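

# A minimal usage sketch, not part of the original module: it assumes two toy
# ``TensorDataset`` objects combined with ``torch.utils.data.ConcatDataset``,
# since the sampler reads ``dataset.datasets`` to compute per-dataset sizes.
# The datasets, sizes, and the 2:1 ratio below are illustrative only.
if __name__ == '__main__':
    from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

    dataset_a = TensorDataset(torch.arange(8).unsqueeze(1))
    dataset_b = TensorDataset(torch.arange(100, 104).unsqueeze(1))
    concat = ConcatDataset([dataset_a, dataset_b])

    # Draw dataset_a samples twice as often (in total) as dataset_b samples.
    sampler = MultiDataSampler(concat, dataset_ratio=[2, 1], seed=0)
    loader = DataLoader(concat, batch_size=4, sampler=sampler)

    sampler.set_epoch(0)
    for batch in loader:
        print(batch)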