#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
#

import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader as loader
import numpy as np
import random
import inspect

from tlt.datasets.dataset import BaseDataset


class PyTorchDataset(BaseDataset):
    """
    Base class to represent a PyTorch Dataset
    """

    def __init__(self, dataset_dir, dataset_name="", dataset_catalog=""):
        """
        Class constructor
        """
        BaseDataset.__init__(self, dataset_dir, dataset_name, dataset_catalog)

    @property
    def train_subset(self):
        """
        A subset of the dataset used for training
        """
        return torch.utils.data.Subset(self._dataset, self._train_indices) if self._train_indices else None

    @property
    def validation_subset(self):
        """
        A subset of the dataset used for validation/evaluation
        """
        return torch.utils.data.Subset(self._dataset, self._validation_indices) if self._validation_indices else None

    @property
    def test_subset(self):
        """
        A subset of the dataset held out for final testing/evaluation
        """
        return torch.utils.data.Subset(self._dataset, self._test_indices) if self._test_indices else None

    @property
    def data_loader(self):
        """
        A data loader object corresponding to the dataset
        """
        return self._data_loader

    @property
    def train_loader(self):
        """
        A data loader object corresponding to the training subset
        """
        return self._train_loader

    @property
    def validation_loader(self):
        """
        A data loader object corresponding to the validation subset
        """
        return self._validation_loader

    @property
    def test_loader(self):
        """
        A data loader object corresponding to the test subset
        """
        return self._test_loader

    def get_batch(self, subset='all'):
        """
        Get a single batch of images and labels from the dataset.

            Args:
                subset (str): default "all", can also be "train", "validation", or "test"

            Returns:
                (examples, labels)

            Raises:
                ValueError: if the dataset is not defined yet or the given subset is not valid
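
            Example (a hypothetical usage sketch; assumes ``dataset`` is a loaded and
            preprocessed PyTorchDataset)::

                images, labels = dataset.get_batch(subset='train')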
        """
        if subset == 'all' and self._dataset is not None:
            return next(iter(self._data_loader))
        elif subset == 'train' and self._train_loader is not None:
            return next(iter(self._train_loader))
        elif subset == 'validation' and self._validation_loader is not None:
            return next(iter(self._validation_loader))
        elif subset == 'test' and self._test_loader is not None:
            return next(iter(self._test_loader))
        else:
            raise ValueError("Unable to return a batch, because the dataset or subset hasn't been defined.")

    def shuffle_split(self, train_pct=.75, val_pct=.25, test_pct=0., shuffle_files=True, seed=None):
        """
        Randomly split the dataset into train, validation, and test subsets, with an optional seed
        for reproducibility.

            Args:
                train_pct (float): default .75, percentage of dataset to use for training
                val_pct (float):  default .25, percentage of dataset to use for validation
                test_pct (float): default 0.0, percentage of dataset to use for testing
                shuffle_files (bool): default True, whether to shuffle the data before splitting
                seed (None or int): default None, set to an int for a reproducible shuffle

            Raises:
                ValueError: if percentage input args are not floats or sum to greater than 1
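
            Example (a hypothetical usage sketch; assumes ``dataset`` is a loaded PyTorchDataset
            with 100 samples)::

                dataset.shuffle_split(train_pct=0.8, val_pct=0.1, test_pct=0.1, seed=10)
                # train/validation/test subsets now hold 80/10/10 samples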
        """
        if not (isinstance(train_pct, float) and isinstance(val_pct, float) and isinstance(test_pct, float)):
            raise ValueError("Percentage arguments must be floats.")
        if train_pct + val_pct + test_pct > 1.0:
            raise ValueError("Sum of percentage arguments must be less than or equal to 1.")

        length = len(self._dataset)
        train_size = int(train_pct * length)
        val_size = int(val_pct * length)
        test_size = int(test_pct * length)
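        # Note: int() truncation can leave a few trailing samples assigned to no subset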
        # Check "is not None" so that a seed of 0 still produces a seeded generator
        generator = torch.Generator().manual_seed(seed) if seed is not None else None
        if shuffle_files:
            dataset_indices = torch.randperm(length, generator=generator).tolist()
        else:
            dataset_indices = list(range(length))
        self._train_indices = dataset_indices[:train_size]
        self._validation_indices = dataset_indices[train_size:train_size + val_size]
        if test_pct:
            self._test_indices = dataset_indices[train_size + val_size:train_size + val_size + test_size]
        else:
            self._test_indices = None
        self._validation_type = 'shuffle_split'
        if self._preprocessed and 'batch_size' in self._preprocessed:
            self._make_data_loaders(batch_size=self._preprocessed['batch_size'], generator=generator)

    def _make_data_loaders(self, batch_size, generator=None):
        """Make data loaders for the whole dataset and the subsets that have indices defined"""
        def seed_worker(worker_id):
            worker_seed = torch.initial_seed() % 2**32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

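        # Any shuffling happens up front in shuffle_split(), so the loaders use shuffle=False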
        if self._dataset:
            self._data_loader = loader(self.dataset, batch_size=batch_size, shuffle=False,
                                       num_workers=self._num_workers, worker_init_fn=seed_worker, generator=generator)
        else:
            self._data_loader = None
        if self._train_indices:
            self._train_loader = loader(self.train_subset, batch_size=batch_size, shuffle=False,
                                        num_workers=self._num_workers, worker_init_fn=seed_worker, generator=generator)
        else:
            self._train_loader = None
        if self._validation_indices:
            self._validation_loader = loader(self.validation_subset, batch_size=batch_size, shuffle=False,
                                             num_workers=self._num_workers, worker_init_fn=seed_worker,
                                             generator=generator)
        else:
            self._validation_loader = None
        if self._test_indices:
            self._test_loader = loader(self.test_subset, batch_size=batch_size, shuffle=False,
                                       num_workers=self._num_workers, worker_init_fn=seed_worker,
                                       generator=generator)
        else:
            self._test_loader = None

    def preprocess(self, image_size='variable', batch_size=32, add_aug=None, **kwargs):
        """
        Preprocess the dataset to resize, normalize, and batch the images. Apply augmentation
        if specified.

            Args:
                image_size (int or 'variable'): desired square image size (if 'variable', does not alter image size)
                batch_size (int): desired batch size (default 32)
                add_aug (None or list[str]): Choice of augmentations (RandomHorizontalFlip, RandomRotation) to be
                                             applied during training
                kwargs: optional; additional keyword arguments for Resize and Normalize transforms
            Raises:
                ValueError: if the dataset is not defined, has already been preprocessed, or if
                    batch_size or image_size is invalid
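
            Example (a hypothetical usage sketch; assumes ``dataset`` is a loaded, not yet
            preprocessed PyTorchDataset)::

                dataset.preprocess(image_size=224, batch_size=32, add_aug=['hflip'])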
        """
        # NOTE: Should this be part of init? If we get image_size and batch size during init,
        # then we don't need a separate call to preprocess.
        if not (self._dataset):
            raise ValueError("Unable to preprocess, because the dataset hasn't been defined.")

        if self._preprocessed:
            raise ValueError("Data has already been preprocessed: {}".format(self._preprocessed))

        if not isinstance(batch_size, int) or batch_size < 1:
            raise ValueError("batch_size should be an positive integer")

        if image_size != 'variable' and not (isinstance(image_size, int) and image_size >= 1):
            raise ValueError("Input image_size must be either a positive int or 'variable'")

        # Get the user-specified keyword arguments
        resize_args = {k: v for k, v in kwargs.items() if k in inspect.getfullargspec(T.Resize).args}
        normalize_args = {k: v for k, v in kwargs.items() if k in inspect.getfullargspec(T.Normalize).args}
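        # Any kwargs that match neither a Resize nor a Normalize parameter are silently ignored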

        def get_transform(image_size, add_aug):
            transforms = []
            if isinstance(image_size, int):
                transforms.append(T.Resize([image_size, image_size], **resize_args))
            if add_aug is not None:
                aug_dict = {'hflip': T.RandomHorizontalFlip(),
                            'rotate': T.RandomRotation(0.5)}
                for option in add_aug:
                    if option not in aug_dict:
                        raise ValueError("Unsupported augmentation for PyTorch: {}. "
                                         "Supported augmentations are {}".format(option, list(aug_dict.keys())))
                    transforms.append(aug_dict[option])
            transforms.append(T.ToTensor())
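            # The mean/std below are the standard ImageNet normalization statistics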
            transforms.append(T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], **normalize_args))

            return T.Compose(transforms)

        self._dataset.transform = get_transform(image_size, add_aug)
        self._preprocessed = {'image_size': image_size, 'batch_size': batch_size}
        self._make_data_loaders(batch_size=batch_size)

    def get_inc_dataloaders(self):
        """
        Get the dataloaders for Intel Neural Compressor (INC): the training loader is used for
        calibration, and the validation loader (falling back to the test loader, then the training
        loader) is used for evaluation.
        """
        calib_dataloader = self.train_loader
        if self.validation_loader is not None:
            eval_dataloader = self.validation_loader
        elif self.test_loader is not None:
            eval_dataloader = self.test_loader
        else:
            eval_dataloader = self.train_loader

        return calib_dataloader, eval_dataloader