File size: 7,266 Bytes
f670afc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import random

from imaginaire.datasets.base import BaseDataset


class Dataset(BaseDataset):
    r"""Image dataset for use in FUNIT.

    Every sample holds one ``images_content`` item and one ``images_style``
    item that are drawn independently of each other (``self.paired`` is
    ``False``). The class of an item is the first ``/``-separated component
    of its sequence name.

    Args:
        cfg (Config): Loaded config object.
        is_inference (bool): In train or inference mode?
        is_test (bool): Forwarded unchanged to the base dataset.
    """

    def __init__(self, cfg, is_inference=False, is_test=False):
        self.paired = False
        # NOTE(review): ``class_name_to_idx`` is read right after this call,
        # so the base constructor is presumably what invokes
        # ``_create_mapping`` - confirm against BaseDataset.
        super(Dataset, self).__init__(cfg, is_inference, is_test)
        self.num_content_classes = len(self.class_name_to_idx['images_content'])
        self.num_style_classes = len(self.class_name_to_idx['images_style'])
        # Style class used in inference mode; see ``set_sample_class_idx``.
        self.sample_class_idx = None
        # Constants of the deterministic content-image selection performed in
        # inference mode (see ``_sample_keys``).
        self.content_offset = 8888
        self.content_interval = 100

    def set_sample_class_idx(self, class_idx=None):
        r"""Set sample class idx.

        Also updates ``self.epoch_length``: with a class idx it becomes the
        number of style images of that class, otherwise the size of the
        largest data type in ``self.mapping``.

        Args:
            class_idx (int): Which class idx to sample from.
        """
        self.sample_class_idx = class_idx
        if class_idx is None:
            self.epoch_length = max(
                len(lmdb_keys) for lmdb_keys in self.mapping.values())
        else:
            self.epoch_length = \
                len(self.mapping_class['images_style'][class_idx])

    def _create_mapping(self):
        r"""Creates mapping from idx to key in LMDB.

        Besides the return values, this sets ``self.class_name_to_idx``
        (data type -> class name -> class idx, with class names sorted) and
        ``self.mapping_class`` (data type -> class idx -> list of LMDB keys).

        Returns:
            (tuple):
              - self.mapping (dict): Dict with data type as key mapping idx to
              LMDB key.
              - self.epoch_length (int): Number of samples in an epoch.
        """
        idx_to_key, class_names = {}, {}
        for lmdb_idx, sequence_list in enumerate(self.sequence_lists):
            for data_type, data_type_sequence_list in sequence_list.items():
                # Accumulate across LMDB roots. Re-initialising the class
                # names here would forget the classes seen in earlier roots
                # and break the class idx lookup below.
                keys_for_type = idx_to_key.setdefault(data_type, [])
                names_for_type = class_names.setdefault(data_type, set())
                for sequence_name, filenames in data_type_sequence_list.items():
                    class_name = sequence_name.split('/')[0]
                    names_for_type.add(class_name)
                    for filename in filenames:
                        keys_for_type.append({
                            'lmdb_root': self.lmdb_roots[lmdb_idx],
                            'lmdb_idx': lmdb_idx,
                            'sequence_name': sequence_name,
                            'filename': filename,
                            'class_name': class_name
                        })
        self.mapping = idx_to_key
        self.epoch_length = max(
            len(lmdb_keys) for lmdb_keys in self.mapping.values())

        # Create mapping from class name to class idx.
        self.class_name_to_idx = {}
        for data_type, names_for_type in class_names.items():
            self.class_name_to_idx[data_type] = {
                class_name: class_idx
                for class_idx, class_name in enumerate(sorted(names_for_type))
            }

        # Add class idx to mapping.
        for data_type, keys_for_type in self.mapping.items():
            name_to_idx = self.class_name_to_idx[data_type]
            for key in keys_for_type:
                key['class_idx'] = name_to_idx[key['class_name']]

        # Create a mapping from index to lmdb key for each class. One bucket
        # is created per actual class idx, so classes spanning several
        # sequences do not produce extra empty buckets.
        idx_to_key_class = {}
        for data_type, keys_for_type in self.mapping.items():
            keys_per_class = {
                class_idx: []
                for class_idx in self.class_name_to_idx[data_type].values()
            }
            for key in keys_for_type:
                keys_per_class[key['class_idx']].append(key)
            idx_to_key_class[data_type] = keys_per_class
        self.mapping_class = idx_to_key_class

        return self.mapping, self.epoch_length

    def _sample_keys(self, index):
        r"""Gets files to load for this sample.

        In inference mode the choice is deterministic: the style image is the
        ``index``-th image of class ``self.sample_class_idx`` and the content
        image is picked with a fixed formula of ``index`` and that class idx.
        Otherwise both images are drawn with ``random.choice``.

        Args:
            index (int): Index in [0, len(dataset)].
        Returns:
            keys (dict): Maps ``'images_content'`` and ``'images_style'`` to
            an LMDB key dict with the entries:
              - lmdb_root: Root of the chosen LMDB dataset.
              - lmdb_idx (int): Chosen LMDB dataset root.
              - sequence_name (str): Chosen sequence in chosen dataset.
              - filename (str): Chosen filename in chosen sequence.
              - class_name (str): Class the sequence belongs to.
              - class_idx (int): Index of that class.
        Raises:
            TypeError: In inference mode, if no sample class idx was set.
        """
        keys = {}
        lmdb_keys_content = self.mapping['images_content']
        if self.is_inference:  # evaluation mode
            class_idx = self.sample_class_idx
            if class_idx is None:
                # Same exception type the arithmetic below would raise, but
                # with a message that says how to fix it.
                raise TypeError(
                    'sample_class_idx is None; call '
                    'set_sample_class_idx(class_idx) before indexing the '
                    'dataset in inference mode.')
            content_idx = (
                (index + self.content_offset * class_idx) *
                self.content_interval) % len(lmdb_keys_content)
            keys['images_content'] = lmdb_keys_content[content_idx]

            lmdb_keys_style = self.mapping_class['images_style'][class_idx]
            keys['images_style'] = lmdb_keys_style[index]
        else:
            lmdb_keys_style = self.mapping['images_style']
            # Content is drawn before style to keep the order in which the
            # random number generator is consumed.
            keys['images_content'] = random.choice(lmdb_keys_content)
            keys['images_style'] = random.choice(lmdb_keys_style)
        return keys

    def __getitem__(self, index):
        r"""Gets selected files.

        Args:
            index (int): Index into dataset.
        Returns:
            data (dict): Dict with all chosen data_types, plus:
              - is_flipped: Flip flag returned by the augmentation step.
              - key (dict): The LMDB keys chosen by ``_sample_keys``.
              - labels_content (int): Class idx of the content image.
              - labels_style (int): Class idx of the style image.
        """
        # Select a sample from the available data.
        keys_per_data_type = self._sample_keys(index)

        # Get keys and lmdbs.
        keys, lmdbs = {}, {}
        for data_type in self.dataset_data_types:
            sampled = keys_per_data_type[data_type]
            keys[data_type] = '%s/%s' % (
                sampled['sequence_name'], sampled['filename'])
            lmdbs[data_type] = self.lmdbs[data_type][sampled['lmdb_idx']]

        # Load all data for this index.
        data = self.load_from_dataset(keys, lmdbs)

        # Apply ops pre augmentation.
        data = self.apply_ops(data, self.pre_aug_ops)

        # Do augmentations for images.
        data, is_flipped = self.perform_augmentation(
            data, paired=False, augment_ops=self.augmentor.augment_ops)

        # Apply ops post augmentation.
        data = self.apply_ops(data, self.post_aug_ops)
        data = self.apply_ops(data, self.full_data_post_aug_ops, full_data=True)

        # Convert images to tensor.
        data = self.to_tensor(data)

        # Remove any extra dimensions: keep only the first entry of each
        # image data type.
        for data_type in self.image_data_types:
            data[data_type] = data[data_type][0]

        # Package output. Labels are looked up by data type name rather than
        # by position, so they cannot be swapped by dict ordering.
        data['is_flipped'] = is_flipped
        data['key'] = keys_per_data_type
        data['labels_content'] = keys_per_data_type['images_content']['class_idx']
        data['labels_style'] = keys_per_data_type['images_style']['class_idx']

        return data