File size: 5,540 Bytes
d015578
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os
import json
import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from spiga.data.loaders.transforms import get_transformers


class AlignmentsDataset(Dataset):
    '''Loads datasets of images with landmarks and bounding boxes.

    Each entry of the annotations JSON list is read as a dict with the keys
    'imgpath', 'ids', 'landmarks', 'bbox', 'visible' and 'headpose'.
    '''

    def __init__(self,
                 database,
                 json_file,
                 images_dir,
                 image_size=(128, 128),
                 transform=None,
                 indices=None,
                 debug=False):
        """
        :param database: class DatabaseStruct containing all the specifics of the database
                         (this class reads its ``num_landmarks`` and ``ldm_ids`` attributes).

        :param json_file: path to the json file which contains the names of the images, landmarks, bounding boxes, etc

        :param images_dir: path of the directory containing the images.

        :param image_size: tuple like e.g. (128, 128). Stored as ``self.image_size``;
                           not used directly by this class.

        :param transform: composition of transformations that will be applied to the samples.

        :param indices: If it is a list of indices, allows to work with the subset of
                        items specified by the list. If it is None, the whole set is used.

        :param debug: bool. If True, every sample additionally carries copies of the
                      landmarks, visibility and landmark mask taken before ``transform``
                      runs (keys ending in '_ori'), plus the head pose when annotated.
        """

        self.database = database
        self.images_dir = images_dir
        self.transform = transform
        self.image_size = image_size
        self.indices = indices
        # Optional in-memory image store keyed by sample index. When non-empty,
        # images are taken from it instead of disk. Not filled by this class
        # (presumably populated by a caller — verify).
        self._imgs_dict = None
        self.debug = debug

        with open(json_file) as jsonfile:
            self.data = json.load(jsonfile)

    def __len__(self):
        '''Returns the length of the dataset (or of the ``indices`` subset).
        '''
        if self.indices is None:
            return len(self.data)
        return len(self.indices)

    def __getitem__(self, sample_idx):
        '''Returns the sample of the dataset at position ``sample_idx``.

        :param sample_idx: position in the dataset. When ``indices`` was given it
                           indexes into that list rather than into the annotations.
        :return: dict with the RGB PIL image, annotations and metadata, after
                 ``transform`` (if any) has been applied.
        :raises FileNotFoundError: if the image cannot be read from disk.
        '''
        # To allow work with a subset
        if self.indices is not None:
            sample_idx = self.indices[sample_idx]
        ann = self.data[sample_idx]

        # Load sample image
        img_name = os.path.join(self.images_dir, ann['imgpath'])
        image = self._load_image(sample_idx, img_name)

        # Load sample anns
        ids = np.array(ann['ids'])
        landmarks = np.array(ann['landmarks'])
        vis = np.array(ann['visible'])
        headpose = ann['headpose']

        # The bbox must be checked for None BEFORE np.array(): np.array(None) is a
        # 0-d object array, never None. The box is fitted on the raw
        # (pre-remapping) landmarks.
        bbox = self._resolve_bbox(ann['bbox'], landmarks, vis)

        # Clean and mask landmarks
        landmarks, vis, mask_ldm = self._remap_landmarks(ids, landmarks, vis)

        sample = {'image': image,
                  'sample_idx': sample_idx,
                  'imgpath': img_name,
                  'ids_ldm': np.array(self.database.ldm_ids),
                  'bbox': bbox,
                  # Separate array so a transform editing 'bbox' in place
                  # cannot also change the raw copy.
                  'bbox_raw': bbox.copy(),
                  'landmarks': landmarks,
                  'visible': vis.astype(np.float64),
                  'mask_ldm': mask_ldm,
                  'imgpath_local': ann['imgpath'],
                  }

        if self.debug:
            # Copies, so the "_ori" values survive in-place edits by transforms.
            sample['landmarks_ori'] = landmarks.copy()
            sample['visible_ori'] = vis.astype(np.float64)
            sample['mask_ldm_ori'] = mask_ldm.copy()
            if headpose is not None:
                sample['headpose_ori'] = np.array(headpose)

        if self.transform:
            sample = self.transform(sample)

        return sample

    def _load_image(self, sample_idx, img_name):
        '''Loads the image of ``sample_idx`` and returns it as a 3-channel RGB PIL image.'''
        if not self._imgs_dict:
            image_cv = cv2.imread(img_name)
            # cv2.imread reports failure by returning None instead of raising.
            if image_cv is None:
                raise FileNotFoundError('Unable to read image: %s' % img_name)
        else:
            image_cv = self._imgs_dict[sample_idx]

        # Some images are B&W. We make sure that any image has three channels.
        if image_cv.ndim == 2:
            image_cv = np.repeat(image_cv[:, :, np.newaxis], 3, axis=-1)

        # Some images have alpha channel
        image_cv = image_cv[:, :, :3]

        # OpenCV arrays are BGR; convert to RGB before handing over to PIL.
        image_cv = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB)
        return Image.fromarray(image_cv)

    @staticmethod
    def _resolve_bbox(raw_bbox, landmarks, vis):
        '''Returns the annotated bbox as an array, or one fitted to the visible landmarks if it is None.'''
        if raw_bbox is not None:
            return np.array(raw_bbox)
        # Tight [x, y, w, h] box around the landmarks whose visibility equals 1.
        visible = landmarks[vis == 1.0]
        x_min, y_min = visible[:, 0].min(), visible[:, 1].min()
        x_max, y_max = visible[:, 0].max(), visible[:, 1].max()
        return np.array([x_min, y_min, x_max - x_min, y_max - y_min], dtype=np.float64)

    def _remap_landmarks(self, ids, landmarks, vis):
        '''Reorders landmarks/visibility to ``database.ldm_ids``; returns (landmarks, vis, mask).'''
        num_ldm = self.database.num_landmarks
        mask_ldm = np.ones(num_ldm)
        # list() makes the comparison a single bool even if ldm_ids is an ndarray.
        if list(self.database.ldm_ids) == ids.tolist():
            return landmarks, vis, mask_ldm

        new_ldm = np.zeros((num_ldm, 2))
        new_vis = np.zeros(num_ldm)
        # Rows of (x, y, visible), keyed by the annotated id cast to int then str.
        xyv = np.hstack((landmarks, vis[:, np.newaxis]))
        ids_dict = dict(zip(ids.astype(int).astype(str), xyv))

        for pos, identifier in enumerate(self.database.ldm_ids):
            key = str(identifier)
            if key in ids_dict:
                x, y, v = ids_dict[key]
                new_ldm[pos] = [x, y]
                new_vis[pos] = v
            else:
                # Landmark expected by the database but absent in this annotation.
                mask_ldm[pos] = 0
        return new_ldm, new_vis, mask_ldm


def get_dataset(data_config, pretreat=None, debug=False):
    """Build an ``AlignmentsDataset`` described by ``data_config``.

    :param data_config: configuration object; this function reads its
                        ``database``, ``anns_file``, ``image_dir``,
                        ``image_size`` and ``ids`` attributes, and passes it
                        to ``get_transformers``.
    :param pretreat: optional extra transform appended after the augmentors
                     returned by ``get_transformers``.
    :param debug: forwarded to ``AlignmentsDataset``.
    :return: the configured ``AlignmentsDataset``.
    """
    pipeline = get_transformers(data_config)
    if pretreat is not None:
        # Appended at the end, so Compose runs it after every augmentor.
        pipeline.append(pretreat)
    composed = transforms.Compose(pipeline)

    return AlignmentsDataset(
        data_config.database,
        data_config.anns_file,
        data_config.image_dir,
        image_size=data_config.image_size,
        transform=composed,
        indices=data_config.ids,
        debug=debug,
    )