""" Compose multiple datasets in a single loader. """

import numpy as np
from copy import deepcopy
from torch.utils.data import Dataset

from .wireframe_dataset import WireframeDataset
from .holicity_dataset import HolicityDataset


class MergeDataset(Dataset):
    """Dataset that randomly samples from several sub-datasets.

    Each ``__getitem__`` call first picks a sub-dataset according to
    ``config["weights"]``, then returns a uniformly random sample from it.
    The requested index is ignored.
    """

    def __init__(self, mode, config=None):
        """Build every sub-dataset listed in ``config["datasets"]``.

        Args:
            mode: Forwarded unchanged to each sub-dataset constructor.
            config: Mapping with at least the keys ``"datasets"`` (list of
                dataset names), ``"gt_source_train"`` and ``"gt_source_test"``
                (lists indexed in parallel with ``"datasets"``) and
                ``"weights"`` (sampling probabilities, one per dataset).
                Holicity entries additionally read ``"train_splits"[i]``.

        Raises:
            ValueError: If a dataset name is unknown, or if the number of
                weights does not match the number of datasets.
        """
        super().__init__()

        # Initialize the datasets
        self._datasets = []
        for i, d in enumerate(config["datasets"]):
            # Give every sub-dataset its own copy of the config: a single
            # shared dict would be mutated by later iterations (and would
            # carry a stale "train_split" from a previous holicity entry),
            # corrupting any dataset that keeps a reference to its config.
            spec_config = deepcopy(config)
            spec_config["dataset_name"] = d
            spec_config["gt_source_train"] = config["gt_source_train"][i]
            spec_config["gt_source_test"] = config["gt_source_test"][i]
            if d == "wireframe":
                self._datasets.append(WireframeDataset(mode, spec_config))
            elif d == "holicity":
                spec_config["train_split"] = config["train_splits"][i]
                self._datasets.append(HolicityDataset(mode, spec_config))
            else:
                raise ValueError("Unknown dataset: " + d)

        self._weights = config["weights"]
        # Fail early with a clear message instead of inside np.random.choice
        # on the first sample request.
        if len(self._weights) != len(self._datasets):
            raise ValueError(
                "Expected one weight per dataset, got %d weights for %d "
                "datasets" % (len(self._weights), len(self._datasets))
            )

    def __getitem__(self, item):
        """Return a random sample; ``item`` is intentionally unused.

        NOTE(review): sampling uses the global NumPy RNG. With a multi-worker
        DataLoader, make sure each worker's NumPy RNG is seeded differently
        (e.g. via ``worker_init_fn``), otherwise workers may draw identical
        samples — confirm against the training setup.
        """
        dataset = self._datasets[
            np.random.choice(len(self._datasets), p=self._weights)
        ]
        return dataset[np.random.randint(len(dataset))]

    def __len__(self):
        """Return the total number of samples across all sub-datasets."""
        # Built-in sum yields a Python int (0 for no datasets); np.sum would
        # return a NumPy scalar, and float 0.0 for an empty list, which
        # len() rejects.
        return sum(len(d) for d in self._datasets)