import os
from abc import abstractmethod

import numpy as np
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset


class Txt2ImgIterableBaseDataset(IterableDataset):
    """Define an interface to make the IterableDatasets for text2img data chainable.

    The constructor only records bookkeeping attributes; subclasses are
    expected to provide the actual iteration logic in ``__iter__``.

    Attributes set by ``__init__``:
        num_records: value reported by ``__len__``.
        valid_ids: stored as given.
        sample_ids: initialised to the same object as ``valid_ids``.
        size: stored as given (default 256).
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        self.valid_ids = valid_ids
        # Starts out as an alias of ``valid_ids`` (same object, not a copy).
        self.sample_ids = valid_ids
        self.size = size

        # Goes through ``__len__`` so a subclass override is reflected in the message.
        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')

    def __len__(self):
        """Return the ``num_records`` value passed at construction time."""
        return self.num_records

    @abstractmethod
    def __iter__(self):
        """Yield the dataset's samples; to be implemented by subclasses."""
        pass


class PRNGMixin:
    """Add a ``prng`` property backed by a per-process ``numpy`` ``RandomState``.

    The generator is reinitialized whenever the pid changes, to avoid
    synchronized sampling behavior when used in conjunction with
    multiprocessing.
    """

    @property
    def prng(self):
        """Return the ``RandomState`` for the current process.

        A new, unseeded ``np.random.RandomState()`` is created on first access
        and again whenever ``os.getpid()`` differs from the pid recorded when
        the cached generator was built.
        """
        currentpid = os.getpid()
        # ``_initpid`` is missing on first access and stale if the pid changed
        # since the generator was created; rebuild the generator in both cases.
        if getattr(self, "_initpid", None) != currentpid:
            self._initpid = currentpid
            self._prng = np.random.RandomState()
        return self._prng