|
import os |
|
import numpy as np |
|
from abc import abstractmethod |
|
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset |
|
|
|
|
|
class Txt2ImgIterableBaseDataset(IterableDataset):
    """Common base for chainable text-to-image ``IterableDataset`` subclasses.

    Records the bookkeeping attributes shared by concrete datasets and prints
    the dataset size when an instance is built. Concrete subclasses must
    implement ``__iter__``.

    Args:
        num_records: Value reported by ``__len__``.
        valid_ids: Stored unchanged as both ``valid_ids`` and ``sample_ids``;
            the two attributes refer to the same object (no copy is made).
        size: Stored as ``size``; its meaning is left to subclasses
            (presumably the output image resolution -- TODO confirm).
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        # Both names alias the very same object.
        self.valid_ids = self.sample_ids = valid_ids
        self.size = size
        # __len__ is invoked directly rather than through len(), so whatever a
        # subclass override returns is printed without len()'s type checks.
        print(f'{type(self).__name__} dataset contains {self.__len__()} examples.')

    def __len__(self):
        """Return the record count supplied at construction."""
        return self.num_records

    @abstractmethod
    def __iter__(self):
        """Produce the dataset's samples; concrete subclasses must override."""
|
|
|
|
|
class PRNGMixin:
    """Mixin exposing ``prng``, a numpy ``RandomState`` bound to the current process.

    The generator is rebuilt (freshly seeded) the first time ``prng`` is read
    in a process whose pid differs from the one that created it. This keeps
    forked workers (e.g. multiprocessing data loaders) from drawing identical
    random streams.
    """

    @property
    def prng(self):
        """Return this process's ``np.random.RandomState``, creating it if needed."""
        pid = os.getpid()
        # Fast path: the cached generator already belongs to this process.
        if getattr(self, "_initpid", None) == pid:
            return self._prng
        # First access, or we are now in a different (e.g. forked) process.
        self._initpid = pid
        self._prng = np.random.RandomState()
        return self._prng
|
|