import math
import random
from typing import TypeVar, Optional, Iterator

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Sampler, Dataset


def create_duplicate_dataset(DatasetBaseClass):
    """Build a subclass of ``DatasetBaseClass`` that repeats its items.

    The returned class behaves like ``DatasetBaseClass`` except that its
    reported length is ``copy`` times the base length, and every index is
    mapped back onto the base dataset by taking it modulo the base length.

    Args:
        DatasetBaseClass: A class providing ``__len__`` and ``__getitem__``
            (and ``get_img_info`` if that accessor is going to be used).
            Its constructor must accept the keyword arguments forwarded by
            ``DupDataset.__init__``.

    Returns:
        The ``DupDataset`` class (not an instance). Instantiate it as
        ``DupDataset(copy, **base_kwargs)``.
    """

    class DupDataset(DatasetBaseClass):
        """``DatasetBaseClass`` repeated ``copy`` times back to back."""

        def __init__(self, copy, **kwargs):
            """Initialise the base dataset and record the repeat count.

            Args:
                copy: Number of times the base dataset is repeated.
                **kwargs: Forwarded unchanged to the base constructor.
            """
            super().__init__(**kwargs)
            self.copy = copy
            # Length of the underlying dataset, captured once at
            # construction time.
            self.length = super().__len__()

        def __len__(self):
            """Return the duplicated length, ``copy * base length``."""
            return self.copy * self.length

        def _wrap_index(self, index):
            """Map an index of the duplicated dataset onto the base one.

            Negative indices count from the end of the duplicated dataset.

            Raises:
                IndexError: If ``index`` lies outside
                    ``[-len(self), len(self))``.
            """
            total = self.copy * self.length
            position = index + total if index < 0 else index
            # Without this check no index would ever be rejected, so the
            # legacy ``__getitem__`` iteration protocol (``for x in ds``,
            # ``list(ds)``) would never see IndexError and never stop. It
            # also avoids a modulo by zero when the base dataset is empty.
            if not 0 <= position < total:
                raise IndexError(
                    f"index {index} is out of range for dataset of "
                    f"length {total}"
                )
            return position % self.length

        def __getitem__(self, index):
            """Return the base item that ``index`` maps to."""
            return super().__getitem__(self._wrap_index(index))

        def get_img_info(self, index):
            """Return the base ``get_img_info`` result for ``index``."""
            return super().get_img_info(self._wrap_index(index))

    return DupDataset