Spaces:
Runtime error
Runtime error
import math | |
from typing import Iterator, Optional, Sized | |
import torch | |
from mmengine.dist import get_dist_info, is_main_process, sync_random_seed | |
from torch.utils.data import Sampler | |
from mmpretrain.registry import DATA_SAMPLERS | |
class RepeatAugSampler(Sampler):
    """Distributed sampler with repeated augmentation.

    Restricts data loading to a subset of the dataset per process. Every
    sample index is repeated ``num_repeats`` times and the repeated list is
    dealt out round-robin across ranks, so that each augmented version of a
    sample is visible to a different process (GPU).

    Heavily based on torch.utils.data.DistributedSampler.

    This sampler was taken from
    https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
    Copyright (c) 2015-present, Facebook, Inc.

    Args:
        dataset (Sized): The dataset.
        shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
        num_repeats (int): The repeat times of every sample. Defaults to 3.
        seed (int, optional): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Defaults to None.
    """

    def __init__(self,
                 dataset: Sized,
                 shuffle: bool = True,
                 num_repeats: int = 3,
                 seed: Optional[int] = None):
        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.shuffle = shuffle
        if not self.shuffle and is_main_process():
            # Imported lazily: only needed for this one-off warning.
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.warning('The RepeatAugSampler always picks a '
                           'fixed part of data if `shuffle=False`.')

        if seed is None:
            seed = sync_random_seed()
        self.seed = seed
        self.epoch = 0
        self.num_repeats = num_repeats

        # The number of repeated samples in the rank
        self.num_samples = math.ceil(
            len(self.dataset) * num_repeats / world_size)
        # The total number of repeated samples in all ranks.
        self.total_size = self.num_samples * world_size
        # The number of selected samples in the rank
        self.num_selected_samples = math.ceil(len(self.dataset) / world_size)

    def __iter__(self) -> Iterator[int]:
        """Iterate the indices assigned to this rank for the current epoch.

        Returns:
            Iterator[int]: The first ``num_selected_samples`` indices of
            this rank's share of the repeated (and padded) index list.
        """
        # deterministically shuffle based on epoch and seed
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
        indices = [x for x in indices for _ in range(self.num_repeats)]

        # add extra samples to make it evenly divisible
        padding_size = self.total_size - len(indices)
        if padding_size > 0 and indices:
            # The padding may be longer than the index list itself (tiny
            # dataset, many ranks), so cycle the list as often as needed
            # instead of taking a single, possibly too short, prefix.
            cycles = math.ceil(padding_size / len(indices))
            indices += (indices * cycles)[:padding_size]
        assert len(indices) == self.total_size

        # subsample per rank
        indices = indices[self.rank:self.total_size:self.world_size]
        assert len(indices) == self.num_samples

        # return up to num selected samples
        return iter(indices[:self.num_selected_samples])

    def __len__(self) -> int:
        """The number of samples in this rank."""
        return self.num_selected_samples

    def set_epoch(self, epoch: int) -> None:
        """Sets the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different
        random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch