# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from torch.utils.data.sampler import BatchSampler


class IterationBasedBatchSampler(BatchSampler):
    """
    Wraps a BatchSampler, resampling from it until
    a specified number of iterations have been sampled
    """

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration <= self.num_iterations:
            # if the underlying sampler has a set_epoch method, like
            # DistributedSampler, used for making each process see
            # a different split of the dataset, then set it
            if hasattr(self.batch_sampler.sampler, "set_epoch"):
                self.batch_sampler.sampler.set_epoch(iteration)
            for batch in self.batch_sampler:
                iteration += 1
                if iteration > self.num_iterations:
                    break
                yield batch

    def __len__(self):
        return self.num_iterations
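

# A minimal usage sketch (not part of the original file): wrapping a plain
# BatchSampler so that iteration cycles through the dataset until the
# requested number of iterations is reached. The dataset size, batch size,
# and iteration count below are illustrative assumptions.
if __name__ == "__main__":
    from torch.utils.data.sampler import SequentialSampler

    # 10 examples, batches of 4 (remainder dropped), run for 7 iterations
    base = BatchSampler(
        SequentialSampler(range(10)), batch_size=4, drop_last=True
    )
    sampler = IterationBasedBatchSampler(base, num_iterations=7)

    # Yields exactly 7 batches, restarting the underlying sampler as needed
    for i, batch in enumerate(sampler):
        print(i, batch)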