"""This package includes all the modules related to data loading and preprocessing

 To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
 You need to implement four functions:
    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of dataset.
    -- <__getitem__>:                   get a data point from data loader.
    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.

Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
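
Example (a minimal sketch of the recipe above; BaseDataset is assumed to live
in data.base_dataset, and the returned fields are illustrative):

    # data/dummy_dataset.py
    from data.base_dataset import BaseDataset

    class DummyDataset(BaseDataset):
        def __init__(self, opt):
            BaseDataset.__init__(self, opt)        # always call the base initializer first
            self.paths = []                        # gather sample identifiers here

        def __len__(self):
            return len(self.paths)

        def __getitem__(self, index):
            return {'path': self.paths[index]}     # one data point as a dict

        @staticmethod
        def modify_commandline_options(parser, is_train):
            return parser                          # add dataset-specific options here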
"""
import importlib
import torch.utils.data
from torch.utils.data.distributed import DistributedSampler

# Hard cap on how many samples this loader will ever expose; kept as an int so
# that __len__ always returns an integer, as Python requires.
MAX_DATASET_SIZE = int(1e9)

class CustomDataLoader:
    """Wrapper around a Dataset that performs multi-process data loading."""

    def __init__(self, config, dataset, DDP_gpu=None, drop_last=False):
        """Initialize this class.

        Wraps the given dataset in a torch DataLoader: a plain multi-process
        loader when DDP_gpu is None, otherwise one driven by a
        DistributedSampler pinned to rank DDP_gpu.
        """
        self.config = config
        self.dataset = dataset

        if DDP_gpu is None:
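            # Single-process path: the DataLoader shuffles unless serial batches were requested.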
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=config['dataset']['batch_size'],
                shuffle=not config['dataset']['serial_batches'],
                num_workers=int(config['dataset']['n_threads']), drop_last=drop_last)
        else:
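            # DDP path: DistributedSampler gives each rank a disjoint shard and handles
            # shuffling itself, so the DataLoader's own shuffle must stay off.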
            sampler = DistributedSampler(self.dataset, num_replicas=self.config['training']['world_size'],
                                         rank=DDP_gpu)
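            # Note: DistributedSampler reshuffles across epochs only if
            # sampler.set_epoch(epoch) is called each epoch; this wrapper does not
            # do that, so per-epoch ordering repeats unless the caller handles it.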
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=config['dataset']['batch_size'],
                shuffle=False,
                num_workers=int(config['dataset']['n_threads']),
                sampler=sampler,
                drop_last=drop_last)

    def load_data(self):
        """Return self; kept so callers can treat construction and loading as separate steps."""
        return self

    def __len__(self):
        """Return the number of items in the dataset, capped at MAX_DATASET_SIZE."""
        return min(len(self.dataset), MAX_DATASET_SIZE)

    def __iter__(self):
        """Yield batches of data, stopping once MAX_DATASET_SIZE samples have been served."""
        for i, data in enumerate(self.dataloader):
            if i * self.config['dataset']['batch_size'] >= MAX_DATASET_SIZE:
                break
            yield data
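

# A minimal, self-contained usage sketch (illustrative, not part of the API):
# the config values below are made up, but the keys are exactly the ones
# CustomDataLoader.__init__ reads. Run this module directly to try it.
if __name__ == '__main__':
    _dummy = torch.utils.data.TensorDataset(torch.randn(10, 3))
    _config = {
        'dataset': {'batch_size': 4, 'serial_batches': True, 'n_threads': 0},
        'training': {'world_size': 1},
    }
    _loader = CustomDataLoader(_config, _dummy).load_data()
    print(f'dataset size: {len(_loader)}')  # -> 10
    for _batch in _loader:
        print(_batch[0].shape)              # torch.Size([4, 3]), [4, 3], then [2, 3]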