|
""" Compose multiple datasets in a single loader. """ |
|
|
|
import numpy as np |
|
from copy import deepcopy |
|
from torch.utils.data import Dataset |
|
|
|
from .wireframe_dataset import WireframeDataset |
|
from .holicity_dataset import HolicityDataset |
|
|
|
|
|
class MergeDataset(Dataset):
    """Dataset that draws samples at random from several sub-datasets.

    One sub-dataset is built for every name listed in ``config["datasets"]``
    ("wireframe" or "holicity"). Indexing ignores the requested index: a
    sub-dataset is first picked according to ``config["weights"]`` and a
    sample is then drawn uniformly at random from it.
    """

    def __init__(self, mode, config=None):
        """Build one sub-dataset per entry of ``config["datasets"]``.

        Args:
            mode: Forwarded unchanged to every sub-dataset constructor.
            config: Mapping holding the keys "datasets", "gt_source_train",
                "gt_source_test" and "weights", plus "train_splits" when a
                "holicity" entry is present. The per-dataset sequences are
                indexed in parallel with "datasets". Required in practice
                despite the ``None`` default.

        Raises:
            ValueError: If a dataset name is unknown, or if the number of
                weights differs from the number of datasets.
        """
        super().__init__()

        self._datasets = []
        for i, name in enumerate(config["datasets"]):
            # Every sub-dataset gets its own private copy of the config, so
            # that a dataset keeping a reference to it is not affected by the
            # values written for the following datasets (and so that keys such
            # as "train_split" do not leak from one dataset to the next).
            spec_config = deepcopy(config)
            spec_config["dataset_name"] = name
            spec_config["gt_source_train"] = config["gt_source_train"][i]
            spec_config["gt_source_test"] = config["gt_source_test"][i]

            if name == "wireframe":
                self._datasets.append(WireframeDataset(mode, spec_config))
            elif name == "holicity":
                spec_config["train_split"] = config["train_splits"][i]
                self._datasets.append(HolicityDataset(mode, spec_config))
            else:
                raise ValueError(f"Unknown dataset: {name}")

        self._weights = config["weights"]
        # Fail at construction time instead of on the first sample access.
        if len(self._weights) != len(self._datasets):
            raise ValueError(
                f"Expected {len(self._datasets)} weights (one per dataset), "
                f"got {len(self._weights)}"
            )

    def __getitem__(self, item):
        """Return a random sample; ``item`` is intentionally ignored.

        A sub-dataset is selected with probabilities ``self._weights`` and a
        uniformly random element of that sub-dataset is returned. Both draws
        use NumPy's global random state.
        """
        dataset_idx = np.random.choice(len(self._datasets), p=self._weights)
        dataset = self._datasets[dataset_idx]
        return dataset[np.random.randint(len(dataset))]

    def __len__(self):
        """Return the total number of samples over all sub-datasets."""
        # Built-in sum keeps the result a plain int (np.sum yields a NumPy
        # scalar, and a float for an empty list, which len() rejects).
        return sum(len(dataset) for dataset in self._datasets)
|
|