Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import itertools | |
from typing import Any, Optional | |
import warnings | |
import numpy as np | |
import torch | |
from torch.utils.data.sampler import Sampler | |
import dinov2.distributed as distributed | |
class EpochSampler(Sampler):
    """Finite sampler yielding one epoch of dataset indices for a single rank.

    ``range(sample_count)`` is repeated until it covers at least ``size`` entries.
    Without shuffling, the first ``size`` of those entries are used in order; with
    shuffling, ``size`` entries are drawn without replacement from the repeated
    array by a NumPy generator seeded from ``seed`` and the current epoch. This
    rank then keeps every ``step``-th entry starting at ``start``.

    Args:
        size: Number of entries in one epoch, across all ranks.
        sample_count: Number of distinct dataset indices.
        shuffle: Whether to reorder the epoch's entries.
        seed: Base seed combined with the epoch number when shuffling.
        start: Offset of this rank's entries; defaults to the global rank.
        step: Stride between this rank's entries; defaults to the global world size.
    """

    def __init__(
        self,
        *,
        size: int,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
    ):
        self._size = size
        self._sample_count = sample_count
        self._shuffle = shuffle
        self._seed = seed
        if start is None:
            start = distributed.get_global_rank()
        if step is None:
            step = distributed.get_global_size()
        self._start = start
        self._step = step
        self._epoch = 0

    def __iter__(self):
        reps = (self._size + self._sample_count - 1) // self._sample_count
        repeated = np.tile(np.arange(self._sample_count), reps)
        if not self._shuffle:
            epoch_indices = repeated[: self._size]
        else:
            # NOTE(review): with a nonzero seed the RNG seed is seed * epoch, so epoch 0
            # always uses seed 0 regardless of `seed`; kept as-is for reproducibility.
            rng_seed = self._epoch if self._seed == 0 else self._seed * self._epoch
            epoch_indices = np.random.default_rng(rng_seed).choice(repeated, self._size, replace=False)
        yield from itertools.islice(epoch_indices, self._start, None, self._step)

    def __len__(self):
        # Number of items islice(<size entries>, start, None, step) yields:
        # ceil((size - start) / step) for start < size.
        return (self._size - self._start + self._step - 1) // self._step

    def set_epoch(self, epoch):
        """Set the epoch used to derive the shuffle seed on the next iteration."""
        self._epoch = epoch
def _get_numpy_dtype(size: int) -> Any:
    """Pick np.int32 or np.int64 so that every index in ``range(size)`` fits."""
    # The largest index is size - 1, and int32 holds values up to 2**31 - 1.
    if size > 2**31:
        return np.int64
    return np.int32
def _get_torch_dtype(size: int) -> Any:
    """Pick torch.int32 or torch.int64 so that every index in ``range(size)`` fits."""
    # The largest index is size - 1, and int32 holds values up to 2**31 - 1.
    if size > 2**31:
        return torch.int64
    return torch.int32
def _generate_randperm_indices(*, size: int, generator: torch.Generator):
    """Lazily yield a random permutation of ``range(size)``.

    Runs a Fisher-Yates shuffle over ``torch.arange(size)`` one position at a
    time, drawing each swap partner with ``torch.randint`` from ``generator``,
    and yields each entry (as a Python int) as soon as it is fixed. Per the
    original note this matches PyTorch's CPU implementation, see:
    https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
    """
    table = torch.arange(size, dtype=_get_torch_dtype(size))
    for pos in range(size):
        other = torch.randint(pos, size, size=(1,), generator=generator).item()
        # Swap unconditionally, even when other == pos (a no-op swap).
        picked, displaced = table[other].item(), table[pos].item()
        table[pos] = picked
        table[other] = displaced
        yield picked
class InfiniteSampler(Sampler):
    """Endless sampler yielding this rank's share of ``range(sample_count)`` pass after pass.

    Each pass walks ``range(sample_count)`` (in order, or in a fresh random
    permutation when ``shuffle`` is set) and keeps every ``step``-th index
    starting at ``start``. The first ``advance`` indices of the resulting
    per-rank stream are skipped.

    Args:
        sample_count: Number of dataset indices in one pass.
        shuffle: Whether each pass uses a new random permutation.
        seed: Seed of the ``torch.Generator`` driving the permutations.
        start: Offset of this rank's share; defaults to the global rank.
        step: Stride between this rank's indices; defaults to the global world size.
        advance: Number of leading per-rank indices to skip.

    Raises:
        ValueError: On the first ``next()`` of an iterator when this rank's share
            of a pass is empty (``start >= sample_count``).
    """

    def __init__(
        self,
        *,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
        advance: int = 0,
    ):
        self._sample_count = sample_count
        self._seed = seed
        self._shuffle = shuffle
        self._start = distributed.get_global_rank() if start is None else start
        self._step = distributed.get_global_size() if step is None else step
        self._advance = advance

    def __iter__(self):
        # With an empty per-pass share the `while True` loops below would spin
        # forever without yielding anything, hanging the caller.
        if self._start >= self._sample_count:
            raise ValueError(
                f"InfiniteSampler shard is empty: start={self._start} >= "
                f"sample_count={self._sample_count}"
            )
        if self._shuffle:
            iterator = self._shuffled_iterator()
        else:
            iterator = self._iterator()
        yield from itertools.islice(iterator, self._advance, None)

    def _iterator(self):
        assert not self._shuffle
        while True:
            iterable = range(self._sample_count)
            yield from itertools.islice(iterable, self._start, None, self._step)

    def _shuffled_iterator(self):
        assert self._shuffle
        # Instantiate a generator here (rather than in the ctor) to keep the class
        # picklable (requirement of mp.spawn)
        generator = torch.Generator().manual_seed(self._seed)
        while True:
            # The generator is not re-seeded, so each pass draws a new permutation.
            iterable = _generate_randperm_indices(size=self._sample_count, generator=generator)
            yield from itertools.islice(iterable, self._start, None, self._step)
# Produces the same kind of result as _new_shuffle_tensor_slice below (a shuffled
# strided slice of `tensor`), but fills the output one element at a time instead of
# materializing the slice and calling torch.randperm. The two consume the generator
# differently, so they need not return the same order for the same seed.
def _shuffle_tensor_slice(
    *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
) -> np.ndarray:
    """Return a shuffled copy of ``tensor[start + k * step]`` for ``k < len(tensor) // step``.

    Warns when ``len(tensor)`` is not a multiple of ``step``, reporting how many
    trailing entries are left out.

    Args:
        tensor: 1-D tensor of indices to slice (read via ``.item()``).
        start: Offset of the first element taken.
        step: Stride between taken elements.
        generator: Source of randomness for the shuffle.

    Returns:
        A NumPy array of ``len(tensor) // step`` shuffled values.
    """
    total = len(tensor)
    n_out = total // step
    leftover = total - n_out * step
    if leftover:
        warnings.warn(f"# of dropped samples: {leftover}")
    out = np.empty(n_out, dtype=_get_numpy_dtype(total))
    # Inside-out Fisher-Yates: place each new element at a random slot in 0..k
    # and move that slot's previous occupant to position k.
    for k in range(n_out):
        slot = 0 if k == 0 else torch.randint(0, k + 1, size=(1,), generator=generator).item()
        out[k] = out[slot]
        out[slot] = tensor[start + k * step].item()
    return out
def _new_shuffle_tensor_slice(
    *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
) -> np.ndarray:
    """Return the first ``len(tensor) // step`` entries of ``tensor[start::step]``, shuffled.

    Materializes the strided slice and reorders it with ``torch.randperm``.
    Warns when ``len(tensor)`` is not a multiple of ``step``, reporting how many
    trailing entries are left out.

    Args:
        tensor: 1-D tensor to slice.
        start: Offset of the first element taken.
        step: Stride between taken elements.
        generator: Source of randomness for ``torch.randperm``.

    Returns:
        A NumPy array of ``len(tensor) // step`` shuffled values.
    """
    stop = len(tensor)
    count = stop // step
    drop_count = stop - step * count
    if drop_count:
        warnings.warn(f"# of dropped samples: {drop_count}")
    # Kept as int64 since the randperm result is used as an index tensor.
    indices = torch.randperm(count, dtype=torch.int64, generator=generator)
    return tensor[start::step][indices].numpy()
def _make_seed(seed: int, start: int, iter_count: int) -> int:
    """Derive the RNG seed for one pass of one shard.

    The pass counter is shifted left by 24 bits and added to ``seed + start``.
    """
    # NOTE: Tried a few variants (including iter_count << 32), this one worked best.
    pass_component = iter_count << 24
    return seed + start + pass_component
class ShardedInfiniteSampler(Sampler):
    """Endless sampler that shards one fixed global permutation across ranks.

    With ``shuffle``, ``range(sample_count)`` is permuted once from ``seed``;
    each pass then takes this rank's strided slice of that permutation
    (``start``/``step``) and reshuffles it with a generator re-seeded from
    ``(seed, start, pass index)``, so a pass can be generated without replaying
    earlier ones. Without ``shuffle``, each pass is this rank's strided slice of
    ``range(sample_count)`` in order.

    Args:
        sample_count: Number of dataset indices.
        shuffle: Whether to permute and reshuffle as described above.
        seed: Base seed for the global permutation and per-pass reshuffles.
        start: Offset of this rank's slice; defaults to the global rank.
        step: Stride of this rank's slice; defaults to the global world size.
        advance: Number of leading items of this rank's stream to skip. Whole
            passes are skipped arithmetically instead of being generated.
        use_new_shuffle_tensor_slice: Reshuffle each slice with
            ``_new_shuffle_tensor_slice`` instead of ``_shuffle_tensor_slice``.

    Note:
        Iterating mutates ``_advance`` and ``_iter_count``, so the pass counter
        carries over between iterators created from the same instance.

    Raises:
        ValueError: On the first ``next()`` of an iterator when a pass would give
            this rank no items.
    """

    def __init__(
        self,
        *,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
        advance: int = 0,
        use_new_shuffle_tensor_slice: bool = False,
    ):
        self._sample_count = sample_count
        self._seed = seed
        self._shuffle = shuffle
        self._start = distributed.get_global_rank() if start is None else start
        self._step = distributed.get_global_size() if step is None else step
        self._advance = advance
        self._iter_count = 0
        self._shuffle_tensor_slice_fn = (
            _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
        )

    def _samples_per_pass(self) -> int:
        """Return how many indices this rank yields per pass over the data."""
        if self._shuffle:
            # Both shuffle-slice functions emit `sample_count // step` entries per pass.
            return self._sample_count // self._step
        return len(range(self._start, self._sample_count, self._step))

    def __iter__(self):
        per_pass = self._samples_per_pass()
        if per_pass == 0:
            # Otherwise the `while True` loops below never yield and the caller hangs
            # (e.g. shuffle with sample_count < step).
            raise ValueError(
                f"ShardedInfiniteSampler shard is empty: sample_count={self._sample_count}, "
                f"start={self._start}, step={self._step}, shuffle={self._shuffle}"
            )
        # `advance` counts items of this rank's stream (the islice below consumes it),
        # so whole passes must be measured in per-rank items, not in `sample_count`.
        skipped_passes = self._advance // per_pass
        if skipped_passes > 0:
            self._advance -= skipped_passes * per_pass
            self._iter_count += skipped_passes
        if self._shuffle:
            iterator = self._shuffled_iterator()
        else:
            iterator = self._iterator()
        yield from itertools.islice(iterator, self._advance, None)

    def _iterator(self):
        assert not self._shuffle
        while True:
            iterable = range(self._sample_count)
            yield from itertools.islice(iterable, self._start, None, self._step)

    def _shuffled_iterator(self):
        assert self._shuffle
        # Instantiate a generator here (rather than in the ctor) to keep the class
        # picklable (requirement of mp.spawn)
        generator = torch.Generator()
        # Always shuffle everything first; this permutation depends only on `seed`,
        # so ranks constructed with the same seed slice the same permutation.
        generator.manual_seed(self._seed)
        dtype = _get_torch_dtype(self._sample_count)
        perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator)
        while True:
            # Re-seed on each pass to allow skipping whole permutations
            seed = _make_seed(self._seed, self._start, self._iter_count)
            generator.manual_seed(seed)
            iterable = self._shuffle_tensor_slice_fn(
                tensor=perm, start=self._start, step=self._step, generator=generator
            )
            yield from iterable
            self._iter_count += 1