Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
from torch.utils.data.distributed import DistributedSampler | |
# Adapted from VISSL: https://github.com/facebookresearch/vissl/blob/94def58538d3c7037f5e093196494331eea1a2a2/vissl/data/data_helper.py#L93
class StatefulDistributedSampler(DistributedSampler):
    """
    Distributed sampler that can resume mid-epoch from a training iteration.

    Each replica (rank) owns a contiguous slice of ``num_samples`` dataset
    indices. That slice is shuffled with a generator seeded by
    ``epoch + seed``, so the order is deterministic for a given epoch and
    seed. PyTorch's ``DistributedSampler`` always starts at the beginning of
    its permutation. This sampler instead skips the first
    ``start_iter * batch_size`` indices. That lets training on very large
    data (often a single epoch) resume from the iteration it was
    checkpointed at.

    The trailing ``len(dataset) % num_replicas`` samples are dropped, so
    every rank gets exactly ``num_samples`` indices.

    NOTE(review): ``__len__`` is not overridden here. ``len(sampler)`` is
    whatever the parent class reports and does not shrink after
    ``set_start_iter``; confirm callers don't rely on it matching the
    resumed iteration count.
    """

    def __init__(self, dataset, batch_size=None, seed: int = 0, num_replicas=None, rank=None):
        """
        Create the sampler for this rank and reset the resume point to 0.

        Args:
            dataset (Dataset): dataset whose indices are sampled; only
                ``len(dataset)`` is used here.
            batch_size (int): batch size used to convert ``start_iter`` into
                a sample offset. It may be left as ``None`` at construction,
                but must be a positive int before iterating.
            seed (int): base seed; the shuffle is seeded with
                ``epoch + seed``.
            num_replicas (int, optional): number of participating replicas.
                ``None`` lets ``DistributedSampler`` query the process group.
            rank (int, optional): rank of this replica. ``None`` lets
                ``DistributedSampler`` query the process group.
        """
        # shuffle=False: the parent's shuffling/partitioning is not used;
        # __iter__ below implements its own rank-local shuffle.
        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=False, seed=seed
        )
        self.start_iter = 0
        self.batch_size = batch_size
        # Override the parent's (padded) sizes: drop the remainder instead so
        # every rank gets the same number of distinct samples.
        self.total_size = len(dataset) - (len(dataset) % self.num_replicas)
        self.num_samples = self.total_size // self.num_replicas
        print(f"rank: {self.rank}: Sampler created...")

    def __iter__(self):
        """
        Return an iterator over this rank's shuffled indices.

        The first ``start_iter * batch_size`` indices are skipped.

        Raises:
            ValueError: if ``batch_size`` is ``None`` or not positive.
        """
        # Explicit check rather than ``assert``: asserts vanish under
        # ``python -O``, and ``None > 0`` used to raise an unhelpful TypeError.
        if self.batch_size is None or self.batch_size <= 0:
            raise ValueError(
                "batch_size not set for the sampler "
                f"(expected a positive int, got {self.batch_size!r})"
            )

        # Permutation of 0..num_samples-1; the seed depends only on epoch and
        # seed, so it is reproducible across restarts.
        g = torch.Generator()
        g.manual_seed(self.epoch + self.seed)
        shuffling = torch.randperm(self.num_samples, generator=g).tolist()

        # This rank owns [rank * num_samples, (rank + 1) * num_samples);
        # apply the permutation within that range.
        offset = self.rank * self.num_samples
        indices = [offset + i for i in shuffling]

        # Internal invariant: exactly num_samples indices per replica.
        assert len(indices) == self.num_samples

        # Resume: skip the samples consumed before the checkpoint.
        start_index = self.start_iter * self.batch_size
        return iter(indices[start_index:])

    def set_start_iter(self, start_iter):
        """
        Set the training iteration to resume sampling from.

        The next call to ``__iter__`` skips the first
        ``start_iter * batch_size`` indices of this rank's permutation.
        """
        self.start_iter = start_iter