# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from typing import Any, Optional
import warnings

import numpy as np
import torch
from torch.utils.data.sampler import Sampler

import dinov2.distributed as distributed


class EpochSampler(Sampler):
    """Sampler that draws `size` indices per epoch from a dataset of `sample_count`
    samples, tiling the dataset when `size > sample_count` and sharding the
    resulting indices across ranks via `start`/`step`."""

    def __init__(
        self,
        *,
        size: int,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
    ):
        self._size = size
        self._sample_count = sample_count
        self._shuffle = shuffle
        self._seed = seed
        self._start = distributed.get_global_rank() if start is None else start
        self._step = distributed.get_global_size() if step is None else step
        self._epoch = 0

    def __iter__(self):
        count = (self._size + self._sample_count - 1) // self._sample_count
        tiled_indices = np.tile(np.arange(self._sample_count), count)
        if self._shuffle:
            seed = self._seed * self._epoch if self._seed != 0 else self._epoch
            rng = np.random.default_rng(seed)
            iterable = rng.choice(tiled_indices, self._size, replace=False)
        else:
            iterable = tiled_indices[: self._size]
        yield from itertools.islice(iterable, self._start, None, self._step)

    def __len__(self):
        return (self._size - self._start + self._step - 1) // self._step

    def set_epoch(self, epoch):
        self._epoch = epoch


def _get_numpy_dtype(size: int) -> Any:
    return np.int32 if size <= 2**31 else np.int64


def _get_torch_dtype(size: int) -> Any:
    return torch.int32 if size <= 2**31 else torch.int64


def _generate_randperm_indices(*, size: int, generator: torch.Generator):
    """Generate the indices of a random permutation."""
    dtype = _get_torch_dtype(size)
    # This is actually matching PyTorch's CPU implementation, see:
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
    perm = torch.arange(size, dtype=dtype)
    for i in range(size):
        j = torch.randint(i, size, size=(1,), generator=generator).item()

        # Always swap even if no-op
        value = perm[j].item()
        perm[j] = perm[i].item()
        perm[i] = value
        yield value


class InfiniteSampler(Sampler):
    """Sampler yielding an endless, optionally shuffled stream of dataset indices,
    sharded across ranks via `start`/`step`; `advance` skips already-consumed
    samples when resuming."""

    def __init__(
        self,
        *,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
        advance: int = 0,
    ):
        self._sample_count = sample_count
        self._seed = seed
        self._shuffle = shuffle
        self._start = distributed.get_global_rank() if start is None else start
        self._step = distributed.get_global_size() if step is None else step
        self._advance = advance

    def __iter__(self):
        if self._shuffle:
            iterator = self._shuffled_iterator()
        else:
            iterator = self._iterator()

        yield from itertools.islice(iterator, self._advance, None)

    def _iterator(self):
        assert not self._shuffle

        while True:
            iterable = range(self._sample_count)
            yield from itertools.islice(iterable, self._start, None, self._step)

    def _shuffled_iterator(self):
        assert self._shuffle

        # Instantiate a generator here (rather than in the ctor) to keep the class
        # picklable (requirement of mp.spawn)
        generator = torch.Generator().manual_seed(self._seed)

        while True:
            iterable = _generate_randperm_indices(size=self._sample_count, generator=generator)
            yield from itertools.islice(iterable, self._start, None, self._step)
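# A minimal sanity-check sketch (an illustrative addition; this helper is an
# assumption, not part of the original module) for the Fisher-Yates generator
# used by InfiniteSampler._shuffled_iterator: exhausting
# _generate_randperm_indices must yield every index in range(size) exactly once,
# since slot i of the permutation is final after iteration i.
def _check_randperm_indices(size: int = 8, seed: int = 0) -> bool:
    generator = torch.Generator().manual_seed(seed)
    indices = list(_generate_randperm_indices(size=size, generator=generator))
    return sorted(indices) == list(range(size))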
# The following function is somewhat equivalent to _new_shuffle_tensor_slice below,
# but avoids a full in-place random permutation generation.
def _shuffle_tensor_slice(
    *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
) -> np.ndarray:
    stop = len(tensor)
    count = stop // step
    drop_count = stop - step * count
    if drop_count:
        warnings.warn(f"# of dropped samples: {drop_count}")

    dtype = _get_numpy_dtype(stop)
    result = np.empty(count, dtype=dtype)

    for i in range(count):
        j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0

        result[i] = result[j]
        result[j] = tensor[start + i * step].item()

    return result


def _new_shuffle_tensor_slice(
    *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
) -> np.ndarray:
    stop = len(tensor)
    count = stop // step
    dtype = torch.int64  # Needed for using randperm result as indices
    drop_count = stop - step * count
    if drop_count:
        warnings.warn(f"# of dropped samples: {drop_count}")

    indices = torch.randperm(count, dtype=dtype, generator=generator)
    return tensor[start::step][indices].numpy()


def _make_seed(seed: int, start: int, iter_count: int) -> int:
    # NOTE: Tried a few variants (including iter_count << 32), this one worked best.
    return seed + start + (iter_count << 24)


class ShardedInfiniteSampler(Sampler):
    """Like InfiniteSampler, but shuffles a single global permutation once, then
    re-shuffles each rank's slice of it on every pass with a per-pass seed, so
    whole permutations can be skipped cheaply when resuming via `advance`."""

    def __init__(
        self,
        *,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
        advance: int = 0,
        use_new_shuffle_tensor_slice: bool = False,
    ):
        self._sample_count = sample_count
        self._seed = seed
        self._shuffle = shuffle
        self._start = distributed.get_global_rank() if start is None else start
        self._step = distributed.get_global_size() if step is None else step
        self._advance = advance
        self._iter_count = 0
        self._shuffle_tensor_slice_fn = (
            _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
        )

    def __iter__(self):
        iter_count = self._advance // self._sample_count
        if iter_count > 0:
            self._advance -= iter_count * self._sample_count
            self._iter_count += iter_count

        if self._shuffle:
            iterator = self._shuffled_iterator()
        else:
            iterator = self._iterator()

        yield from itertools.islice(iterator, self._advance, None)

    def _iterator(self):
        assert not self._shuffle

        while True:
            iterable = range(self._sample_count)
            yield from itertools.islice(iterable, self._start, None, self._step)

    def _shuffled_iterator(self):
        assert self._shuffle

        # Instantiate a generator here (rather than in the ctor) to keep the class
        # picklable (requirement of mp.spawn)
        generator = torch.Generator()

        # Always shuffle everything first
        generator.manual_seed(self._seed)
        dtype = _get_torch_dtype(self._sample_count)
        perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator)

        while True:
            # Re-seed on each iteration to allow skipping whole permutations
            seed = _make_seed(self._seed, self._start, self._iter_count)
            generator.manual_seed(seed)
            iterable = self._shuffle_tensor_slice_fn(
                tensor=perm, start=self._start, step=self._step, generator=generator
            )
            yield from iterable
            self._iter_count += 1
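

# A minimal usage sketch (an illustrative addition, not part of the original
# module). Passing `start` and `step` explicitly avoids the torch.distributed
# rank/world-size lookup, so this runs in a plain single-process interpreter;
# the sample count and number of drawn indices are placeholder values.
if __name__ == "__main__":
    sampler = ShardedInfiniteSampler(
        sample_count=100,
        shuffle=True,
        seed=0,
        start=0,  # rank 0 ...
        step=1,  # ... of a single-process world
        advance=0,
    )
    stream = iter(sampler)
    print([next(stream) for _ in range(8)])  # first few indices of the infinite stream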