Spaces:
Paused
Paused
import os | |
import numpy as np | |
from abc import abstractmethod | |
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset | |
class Txt2ImgIterableBaseDataset(IterableDataset):
    """Base class for chainable text-to-image ``IterableDataset``s.

    Defines a common interface so the IterableDatasets for text2img data can
    be chained. Subclasses must implement ``__iter__``. The number of
    examples is supplied up front through ``num_records`` so that ``len()``
    works on an otherwise streaming dataset.

    Args:
        num_records: Value reported by ``len()`` and in the construction
            message.
        valid_ids: Stored unchanged as both ``self.valid_ids`` and
            ``self.sample_ids``. Presumably the ids of usable records;
            confirm against subclasses.
        size: Stored as ``self.size``. Presumably the target image
            resolution; confirm against subclasses.
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        self.valid_ids = valid_ids
        # Same object, not a copy: an in-place change to one is visible in
        # the other.
        self.sample_ids = valid_ids
        self.size = size

        # Goes through __len__ so that subclasses overriding it are reported
        # correctly.
        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')

    def __len__(self):
        """Return the record count passed to the constructor."""
        return self.num_records

    @abstractmethod
    def __iter__(self):
        """Yield the dataset's examples; must be overridden by subclasses.

        ``@abstractmethod`` only blocks instantiation when the class's
        metaclass is ``ABCMeta``, which depends on the installed torch
        version. The explicit raise makes a missing override fail clearly
        either way.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement __iter__'
        )
class PRNGMixin:
    """Mixin adding a per-process ``prng`` property.

    ``prng`` is a numpy ``RandomState``. It is re-created whenever the
    process id changes, which avoids synchronized sampling behavior when the
    object is used with multiprocessing.
    """

    @property
    def prng(self):
        """Return this process's ``RandomState``, creating it if needed.

        A new ``np.random.RandomState()`` is built, unseeded, so numpy draws
        fresh entropy, on first access and whenever the current pid differs
        from the pid recorded at the last creation. Otherwise the cached
        instance is returned.
        """
        current_pid = os.getpid()
        # A forked child inherits _initpid/_prng from its parent. Comparing
        # against the live pid detects this, so the child gets its own
        # generator instead of replaying the parent's random stream.
        if getattr(self, "_initpid", None) != current_pid:
            self._initpid = current_pid
            self._prng = np.random.RandomState()
        return self._prng