File size: 1,157 Bytes
7cdf421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from torch.utils.data import ConcatDataset, Dataset
from .catalog import DatasetCatalog
from .utils import instantiate_from_config


class MyConcatDataset(Dataset):
    def __init__(self, dataset_name_list):
        super(MyConcatDataset, self).__init__()

        _datasets = []

        catalog = DatasetCatalog()
        for dataset_idx, dataset_name in enumerate(dataset_name_list):
            dataset_dict = getattr(catalog, dataset_name)

            target = dataset_dict['target']
            params = dataset_dict['params']
            print(target)
            print(params)
            dataset = instantiate_from_config(dict(target=target, params=params))

            _datasets.append(dataset)
        self.datasets = ConcatDataset(_datasets)

    def __len__(self):
        return self.datasets.__len__()

    def __getitem__(self, item):
        return self.datasets.__getitem__(item)

    def collate(self, instances):
        data = {key: [] for key in instances[0].keys()} if instances else {}

        for instance in instances:
            for key, value in instance.items():
                data[key].append(value)

        return data