Vincentqyw
fix: roma
8b973ee
raw
history blame
1.39 kB
""" Compose multiple datasets in a single loader. """
import numpy as np
from copy import deepcopy
from torch.utils.data import Dataset
from .wireframe_dataset import WireframeDataset
from .holicity_dataset import HolicityDataset
class MergeDataset(Dataset):
    """Dataset that randomly samples from several sub-datasets.

    Each sub-dataset is built from its own copy of ``config`` with the
    per-dataset entries (``dataset_name``, ``gt_source_train``,
    ``gt_source_test`` and, for HoliCity, ``train_split``) overridden.
    Indexing ignores the requested index: a sub-dataset is drawn according to
    ``config["weights"]`` and then one of its items is drawn uniformly.

    Args:
        mode: Passed unchanged to every sub-dataset constructor.
        config: Dict with the keys ``datasets``, ``gt_source_train``,
            ``gt_source_test``, ``weights`` and, when a ``"holicity"``
            dataset is listed, ``train_splits``. The list-valued entries are
            indexed in parallel with ``datasets``.

    Raises:
        ValueError: If an unknown dataset name is listed, or if the number
            of weights does not match the number of datasets.
    """

    def __init__(self, mode, config=None):
        super().__init__()
        # Initialize the datasets
        self._datasets = []
        for i, name in enumerate(config["datasets"]):
            # Fresh copy per dataset: reusing one copy across iterations would
            # hand every sub-dataset the same dict. Any sub-dataset that keeps
            # a reference to its config would then see later datasets'
            # overrides, and keys set for one dataset (e.g. "train_split")
            # would carry over to the next.
            spec_config = deepcopy(config)
            spec_config["dataset_name"] = name
            spec_config["gt_source_train"] = config["gt_source_train"][i]
            spec_config["gt_source_test"] = config["gt_source_test"][i]
            if name == "wireframe":
                self._datasets.append(WireframeDataset(mode, spec_config))
            elif name == "holicity":
                spec_config["train_split"] = config["train_splits"][i]
                self._datasets.append(HolicityDataset(mode, spec_config))
            else:
                raise ValueError(f"Unknown dataset: {name}")

        self._weights = config["weights"]
        # Fail at construction time instead of on the first __getitem__, where
        # np.random.choice would reject the size mismatch anyway.
        if len(self._weights) != len(self._datasets):
            raise ValueError(
                f"Expected {len(self._datasets)} weights (one per dataset), "
                f"got {len(self._weights)}"
            )

    def __getitem__(self, item):
        """Return a random sample; ``item`` is ignored.

        A sub-dataset is chosen with probabilities ``self._weights``, then an
        index into it is drawn uniformly.
        """
        # NOTE(review): this uses NumPy's global RNG. With multi-process
        # DataLoader workers this may yield repeated samples across workers
        # unless a worker_init_fn reseeds NumPy. Confirm how the loader is
        # configured.
        dataset_idx = np.random.choice(len(self._datasets), p=self._weights)
        dataset = self._datasets[dataset_idx]
        return dataset[np.random.randint(len(dataset))]

    def __len__(self):
        """Return the total number of items across all sub-datasets."""
        # The built-in sum keeps this a Python int even with no datasets.
        # np.sum([]) is the float 0.0, which len() rejects.
        return sum(len(d) for d in self._datasets)