"""This package includes all the modules related to data loading and preprocessing

 To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
 You need to implement four functions:
    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of dataset.
    -- <__getitem__>:                   get a data point from data loader.
    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.

Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from swapae.data.base_dataset import BaseDataset
import swapae.util as util


def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py".

    In the file, the class called DatasetNameDataset() will
    be instantiated. It has to be a subclass of BaseDataset,
    and it is case-insensitive.
    """
    dataset_filename = "swapae.data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        # Compare class names case-insensitively, with underscores in
        # dataset_name ignored (e.g. "dummy" matches DummyDataset).
        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError(
            "In %s.py, there should be a subclass of BaseDataset "
            "with a class name that matches %s in lowercase."
            % (dataset_filename, target_dataset_name))

    return dataset
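
# For example, following the naming convention above,
# find_dataset_using_name("dummy") imports swapae.data.dummy_dataset and
# returns the DummyDataset class, since "DummyDataset".lower() equals
# "dummy".replace('_', '') + "dataset".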


def get_option_setter(dataset_name):
    """Return the static method <modify_commandline_options> of the dataset class."""
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options
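
# Sketch of the intended call site (the actual option-parsing code lives
# elsewhere in the repo; the (parser, is_train) signature is assumed here):
#     option_setter = get_option_setter(opt.dataset_mode)
#     parser = option_setter(parser, is_train)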


def create_dataset(opt):
    """Create the dataset given the options, wrapped in a ConfigurableDataLoader."""
    return ConfigurableDataLoader(opt)
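
# Typical entry point (sketch): build the loader from parsed options and
# iterate over batches. During training the loader repeats forever (see
# ConfigurableDataLoader.__next__ below), so the loop must break on its own;
# 'max_iters' is a hypothetical option:
#     dataset = create_dataset(opt)
#     for i, batch in enumerate(dataset):
#         if i >= opt.max_iters:
#             break
#         ...  # train on batch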


class DataPrefetcher():
    """Wraps an iterator of CPU tensors so that the host-to-device copy of the
    next batch overlaps with computation on the current one, using a side CUDA
    stream. Currently unused (see the commented-out line in
    ConfigurableDataLoader.initialize below)."""

    def __init__(self, dataset):
        self.dataset = dataset
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        try:
            self.next_input = next(self.dataset)
        except StopIteration:
            self.next_input = None
            return

        # Issue the copy on the side stream so it can run concurrently with
        # compute on the default stream.
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)

    def __next__(self):
        # Block until the prefetch copy has finished before handing out the batch.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.next_input
        if batch is None:
            raise StopIteration
        self.preload()
        return batch

    def __iter__(self):
        return self

    def __len__(self):
        return len(self.dataset)
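
# Usage sketch: wrap an iterator so each batch is copied to the GPU ahead of
# time. Note that, as written, DataPrefetcher calls .cuda() on each item
# directly, so it expects single tensors rather than the dicts most datasets
# return:
#     prefetcher = DataPrefetcher(iter(loader))
#     batch = next(prefetcher)  # already resident on the GPU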


class ConfigurableDataLoader():
    """A DataLoader wrapper that can be re-initialized for a different phase
    (e.g. switching between training and evaluation) and that, during
    training, repeats indefinitely instead of stopping at the epoch boundary."""

    def __init__(self, opt):
        self.opt = opt
        self.initialize(opt.phase)

    def initialize(self, phase):
        opt = self.opt
        self.phase = phase
        if hasattr(self, "dataloader"):
            del self.dataloader
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        dataset = dataset_class(util.copyconf(opt, phase=phase, isTrain=phase == "train"))
        # Shuffle during training unless --shuffle_dataset explicitly overrides it.
        if opt.shuffle_dataset is None:
            shuffle = phase == "train"
        else:
            shuffle = opt.shuffle_dataset == "true"
        print("dataset [%s] of size %d was created. shuffled=%s" % (type(dataset).__name__, len(dataset), shuffle))
        # dataset = DataPrefetcher(dataset)  # optional GPU prefetching (disabled)
        self.dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=opt.batch_size,
            shuffle=shuffle,
            num_workers=int(opt.num_gpus),  # one worker per GPU
            drop_last=phase == "train",
        )
        # self.dataloader = dataset  # would bypass the torch DataLoader (disabled)
        self.dataloader_iterator = iter(self.dataloader)
        self.repeat = phase == "train"  # training iteration repeats forever
        self.length = len(dataset)
        self.underlying_dataset = dataset

    def set_phase(self, target_phase):
        if self.phase != target_phase:
            self.initialize(target_phase)

    def __iter__(self):
        self.dataloader_iterator = iter(self.dataloader)
        return self

    def __len__(self):
        return self.length

    def __next__(self):
        try:
            return next(self.dataloader_iterator)
        except StopIteration:
            if self.repeat:
                # During training, transparently restart a new epoch.
                self.dataloader_iterator = iter(self.dataloader)
                return next(self.dataloader_iterator)
            else:
                raise
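
# Example of phase switching (sketch): re-initializing swaps the shuffling,
# drop_last, and repeat behavior along with the underlying dataset.
#     loader = create_dataset(opt)  # starts in opt.phase
#     loader.set_phase("test")      # rebuilds the DataLoader: no shuffling,
#                                   # no drop_last, and iteration stops
#     for batch in loader:
#         ...                       # runs exactly one pass over the data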