import logging

from torch.utils.data import ConcatDataset, Dataset

from .catalog import DatasetCatalog
from .utils import instantiate_from_config

logger = logging.getLogger(__name__)


class MyConcatDataset(Dataset):
    """Concatenate several catalog-registered datasets into one ``Dataset``.

    Each name in ``dataset_name_list`` is looked up as an attribute of a
    ``DatasetCatalog`` instance. The entry must be a mapping with ``'target'``
    and ``'params'`` keys. It is passed to ``instantiate_from_config`` to build
    the dataset. The resulting datasets are chained with
    ``torch.utils.data.ConcatDataset`` in list order.
    """

    def __init__(self, dataset_name_list):
        """Build and concatenate the named datasets.

        Args:
            dataset_name_list: Iterable of attribute names on ``DatasetCatalog``.

        Raises:
            AttributeError: If a name is not defined on the catalog.
            KeyError: If a catalog entry lacks ``'target'`` or ``'params'``.
        """
        super().__init__()
        catalog = DatasetCatalog()
        datasets = []
        for dataset_name in dataset_name_list:
            dataset_dict = getattr(catalog, dataset_name)
            target = dataset_dict['target']
            params = dataset_dict['params']
            # Previously printed to stdout. Route through logging so callers
            # can control verbosity instead of always getting console noise.
            logger.info(
                "Instantiating dataset %r: target=%s params=%s",
                dataset_name, target, params,
            )
            # NOTE(review): presumably instantiate_from_config imports `target`
            # and calls it with `params`; confirm in .utils.
            dataset = instantiate_from_config(dict(target=target, params=params))
            datasets.append(dataset)
        self.datasets = ConcatDataset(datasets)

    def __len__(self):
        """Return the total number of samples across all sub-datasets."""
        return len(self.datasets)

    def __getitem__(self, item):
        """Return the sample at global index ``item`` (delegates to ConcatDataset)."""
        return self.datasets[item]

    def collate(self, instances):
        """Merge a batch of sample dicts into a dict of lists.

        The output keys are taken from the first instance. Each key maps to
        the list of that key's values, in batch order.

        Args:
            instances: Sequence of dict-like samples.

        Returns:
            ``{key: [value, ...]}``, or ``{}`` when ``instances`` is empty.

        Raises:
            KeyError: If a later instance has a key absent from the first one.
        """
        data = {key: [] for key in instances[0].keys()} if instances else {}
        for instance in instances:
            for key, value in instance.items():
                data[key].append(value)
        return data