|
|
|
|
|
|
|
|
|
|
|
|
|
import itertools |
|
from typing import Any, Optional |
|
import warnings |
|
|
|
import numpy as np |
|
import torch |
|
from torch.utils.data.sampler import Sampler |
|
|
|
import dinov2.distributed as distributed |
|
|
|
|
|
class EpochSampler(Sampler):
    """Epoch-based sampler that tiles ``sample_count`` indices to yield ``size`` draws.

    Each shard (``start``/``step``) receives an interleaved slice of the epoch's
    index sequence; by default the shard layout follows the distributed
    rank/world-size. Call :meth:`set_epoch` before each epoch so the shuffle
    is reseeded per epoch.
    """

    def __init__(
        self,
        *,
        size: int,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
    ):
        self._size = size
        self._sample_count = sample_count
        self._shuffle = shuffle
        self._seed = seed
        # Default sharding: one shard per distributed rank.
        self._start = start if start is not None else distributed.get_global_rank()
        self._step = step if step is not None else distributed.get_global_size()
        self._epoch = 0

    def __iter__(self):
        # Tile [0, sample_count) enough times to provide at least `size` indices.
        repeat_count = -(-self._size // self._sample_count)  # ceil(size / sample_count)
        tiled = np.tile(np.arange(self._sample_count), repeat_count)
        if not self._shuffle:
            epoch_indices = tiled[: self._size]
        else:
            # NOTE(review): `seed * epoch` is 0 at epoch 0 for every seed, and
            # distinct (seed, epoch) pairs can collide — kept as-is to preserve
            # existing shuffle streams.
            effective_seed = self._epoch if self._seed == 0 else self._seed * self._epoch
            rng = np.random.default_rng(effective_seed)
            epoch_indices = rng.choice(tiled, self._size, replace=False)

        yield from itertools.islice(epoch_indices, self._start, None, self._step)

    def __len__(self):
        # Length of the (start, step) slice of a `size`-element sequence.
        return (self._size - self._start + self._step - 1) // self._step

    def set_epoch(self, epoch):
        # Stored epoch feeds the shuffle seed on the next __iter__.
        self._epoch = epoch
|
|
|
|
|
def _get_numpy_dtype(size: int) -> Any: |
|
return np.int32 if size <= 2**31 else np.int64 |
|
|
|
|
|
def _get_torch_dtype(size: int) -> Any: |
|
return torch.int32 if size <= 2**31 else torch.int64 |
|
|
|
|
|
def _generate_randperm_indices(*, size: int, generator: torch.Generator): |
|
"""Generate the indices of a random permutation.""" |
|
dtype = _get_torch_dtype(size) |
|
|
|
perm = torch.arange(size, dtype=dtype) |
|
for i in range(size): |
|
j = torch.randint(i, size, size=(1,), generator=generator).item() |
|
|
|
|
|
value = perm[j].item() |
|
perm[j] = perm[i].item() |
|
perm[i] = value |
|
yield value |
|
|
|
|
|
class InfiniteSampler(Sampler):
    """Endlessly cycle over dataset indices, optionally shuffled.

    Each shard takes every ``step``-th index beginning at ``start`` (defaults
    follow the distributed rank/world-size). ``advance`` drops that many
    leading samples from the stream, e.g. to resume a run.
    """

    def __init__(
        self,
        *,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
        advance: int = 0,
    ):
        self._sample_count = sample_count
        self._seed = seed
        self._shuffle = shuffle
        self._start = start if start is not None else distributed.get_global_rank()
        self._step = step if step is not None else distributed.get_global_size()
        self._advance = advance

    def __iter__(self):
        source = self._shuffled_iterator() if self._shuffle else self._iterator()
        yield from itertools.islice(source, self._advance, None)

    def _iterator(self):
        assert not self._shuffle
        # Sequential order: repeat the sharded slice of 0..sample_count-1 forever.
        while True:
            yield from itertools.islice(range(self._sample_count), self._start, None, self._step)

    def _shuffled_iterator(self):
        assert self._shuffle
        # One generator reused across cycles: each cycle continues the RNG
        # stream, so successive permutations differ.
        generator = torch.Generator().manual_seed(self._seed)
        while True:
            permutation = _generate_randperm_indices(size=self._sample_count, generator=generator)
            yield from itertools.islice(permutation, self._start, None, self._step)
|
|
|
|
|
|
|
|
|
def _shuffle_tensor_slice( |
|
*, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator |
|
) -> np.ndarray: |
|
stop = len(tensor) |
|
count = stop // step |
|
drop_count = stop - step * count |
|
if drop_count: |
|
warnings.warn(f"# of dropped samples: {drop_count}") |
|
|
|
dtype = _get_numpy_dtype(stop) |
|
result = np.empty(count, dtype=dtype) |
|
|
|
for i in range(count): |
|
j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0 |
|
|
|
result[i] = result[j] |
|
result[j] = tensor[start + i * step].item() |
|
|
|
return result |
|
|
|
|
|
def _new_shuffle_tensor_slice( |
|
*, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator |
|
) -> np.ndarray: |
|
stop = len(tensor) |
|
count = stop // step |
|
dtype = torch.int64 |
|
count = stop // step |
|
drop_count = stop - step * count |
|
if drop_count: |
|
warnings.warn(f"# of dropped samples: {drop_count}") |
|
indices = torch.randperm(count, dtype=dtype, generator=generator) |
|
return tensor[start::step][indices].numpy() |
|
|
|
|
|
def _make_seed(seed: int, start: int, iter_count: int) -> int: |
|
|
|
return seed + start + (iter_count << 24) |
|
|
|
|
|
class ShardedInfiniteSampler(Sampler): |
|
def __init__( |
|
self, |
|
*, |
|
sample_count: int, |
|
shuffle: bool = False, |
|
seed: int = 0, |
|
start: Optional[int] = None, |
|
step: Optional[int] = None, |
|
advance: int = 0, |
|
use_new_shuffle_tensor_slice: bool = False, |
|
): |
|
self._sample_count = sample_count |
|
self._seed = seed |
|
self._shuffle = shuffle |
|
self._start = distributed.get_global_rank() if start is None else start |
|
self._step = distributed.get_global_size() if step is None else step |
|
self._advance = advance |
|
self._iter_count = 0 |
|
self._shuffle_tensor_slice_fn = ( |
|
_new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice |
|
) |
|
|
|
def __iter__(self): |
|
iter_count = self._advance // self._sample_count |
|
if iter_count > 0: |
|
self._advance -= iter_count * self._sample_count |
|
self._iter_count += iter_count |
|
|
|
if self._shuffle: |
|
iterator = self._shuffled_iterator() |
|
else: |
|
iterator = self._iterator() |
|
|
|
yield from itertools.islice(iterator, self._advance, None) |
|
|
|
def _iterator(self): |
|
assert not self._shuffle |
|
|
|
while True: |
|
iterable = range(self._sample_count) |
|
yield from itertools.islice(iterable, self._start, None, self._step) |
|
|
|
def _shuffled_iterator(self): |
|
assert self._shuffle |
|
|
|
|
|
|
|
generator = torch.Generator() |
|
|
|
|
|
generator.manual_seed(self._seed) |
|
dtype = _get_torch_dtype(self._sample_count) |
|
perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator) |
|
|
|
while True: |
|
|
|
seed = _make_seed(self._seed, self._start, self._iter_count) |
|
generator.manual_seed(seed) |
|
|
|
iterable = self._shuffle_tensor_slice_fn( |
|
tensor=perm, start=self._start, step=self._step, generator=generator |
|
) |
|
yield from iterable |
|
self._iter_count += 1 |
|
|