import os
import json

import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from spiga.data.loaders.transforms import get_transformers


class AlignmentsDataset(Dataset):
    '''Loads datasets of images with landmarks and bounding boxes.
    '''

    def __init__(self,
                 database,
                 json_file,
                 images_dir,
                 image_size=(128, 128),
                 transform=None,
                 indices=None,
                 debug=False):
        """
        :param database: class DatabaseStruct containing all the specifics of the database
            (this class reads its ``num_landmarks`` and ``ldm_ids`` attributes).
        :param json_file: path to the json file which contains the names of the images,
            landmarks, bounding boxes, etc
        :param images_dir: path of the directory containing the images.
        :param image_size: tuple like e.g. (128, 128). Stored on the instance; it is not
            used by the code of this class itself.
        :param transform: composition of transformations that will be applied to the samples.
        :param indices: If it is a list of indices, allows to work with the subset of items
            specified by the list. If it is None, the whole set is used.
        :param debug: bool, if True the returned samples carry extra ``*_ori`` entries with
            the annotations as they were before ``transform`` is applied.
        """
        self.database = database
        self.images_dir = images_dir
        self.transform = transform
        self.image_size = image_size
        self.indices = indices
        # Optional pre-loaded images indexed by sample index; when it is empty or
        # None the images are read from disk on every access.
        self._imgs_dict = None
        self.debug = debug

        with open(json_file) as jsonfile:
            self.data = json.load(jsonfile)

    def __len__(self):
        '''Returns the length of the dataset (of the subset when indices are given).
        '''
        if self.indices is None:
            return len(self.data)
        return len(self.indices)

    def __getitem__(self, sample_idx):
        '''Returns the sample of the dataset at position sample_idx.

        :param sample_idx: position in the dataset, or in ``indices`` when a subset is used.
        :return: dict with the image (PIL, RGB) and its annotations, after applying
            ``transform`` when one was given.
        :raises OSError: if the image has to be read from disk and cannot be read.
        '''
        # To allow work with a subset
        if self.indices is not None:
            sample_idx = self.indices[sample_idx]
        ann = self.data[sample_idx]

        # Load sample image
        img_name = os.path.join(self.images_dir, ann['imgpath'])
        image = self._load_image(img_name, sample_idx)

        # Load sample anns
        ids = np.array(ann['ids'])
        landmarks = np.array(ann['landmarks'])
        vis = np.array(ann['visible'])
        headpose = ann['headpose']

        # Generate bbox if needed. The check has to be done on the raw annotation:
        # once wrapped with np.array() the value is never None.
        bbox_ann = ann.get('bbox')
        if bbox_ann is None:
            bbox = self._bbox_from_landmarks(landmarks, vis)
        else:
            bbox = np.array(bbox_ann)

        # Clean and mask landmarks
        landmarks, vis, mask_ldm = self._remap_landmarks(ids, landmarks, vis)

        # NOTE: 'bbox' and 'bbox_raw' reference the same array object.
        sample = {'image': image,
                  'sample_idx': sample_idx,
                  'imgpath': img_name,
                  'ids_ldm': np.array(self.database.ldm_ids),
                  'bbox': bbox,
                  'bbox_raw': bbox,
                  'landmarks': landmarks,
                  'visible': vis.astype(np.float64),
                  'mask_ldm': mask_ldm,
                  'imgpath_local': ann['imgpath'],
                  }

        if self.debug:
            sample['landmarks_ori'] = landmarks
            sample['visible_ori'] = vis.astype(np.float64)
            sample['mask_ldm_ori'] = mask_ldm
            if headpose is not None:
                sample['headpose_ori'] = np.array(headpose)

        if self.transform:
            sample = self.transform(sample)

        return sample

    def _load_image(self, img_name, sample_idx):
        '''Returns the image of a sample as a three channel RGB PIL image.'''
        if not self._imgs_dict:
            image_cv = cv2.imread(img_name)
            if image_cv is None:
                # cv2.imread does not raise on failure, it just returns None.
                raise OSError('Unable to read image: {}'.format(img_name))
        else:
            image_cv = self._imgs_dict[sample_idx]

        # Some images are B&W. We make sure that any image has three channels.
        if image_cv.ndim == 2:
            image_cv = np.repeat(image_cv[:, :, np.newaxis], 3, axis=-1)

        # Some images have alpha channel
        image_cv = image_cv[:, :, :3]

        image_cv = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB)
        return Image.fromarray(image_cv)

    @staticmethod
    def _bbox_from_landmarks(landmarks, vis):
        '''Returns the [x, y, width, height] box enclosing the landmarks with vis == 1.'''
        visible_pts = landmarks[vis == 1.0]
        x_min = visible_pts[:, 0].min()
        y_min = visible_pts[:, 1].min()
        bbox = np.zeros(4)
        bbox[0] = x_min
        bbox[1] = y_min
        bbox[2] = visible_pts[:, 0].max() - x_min
        bbox[3] = visible_pts[:, 1].max() - y_min
        return bbox

    def _remap_landmarks(self, ids, landmarks, vis):
        '''Reorders the annotations to follow database.ldm_ids.

        Landmarks of the database that are not annotated in the sample are set to
        zero and flagged with 0 in the returned mask.

        :return: tuple (landmarks, visible, mask_ldm).
        '''
        num_ldm = self.database.num_landmarks
        mask_ldm = np.ones(num_ldm)
        if self.database.ldm_ids == ids.tolist():
            # Annotations already follow the database layout.
            return landmarks, vis, mask_ldm

        new_ldm = np.zeros((num_ldm, 2))
        new_vis = np.zeros(num_ldm)
        # One (x, y, visible) row per annotated landmark, keyed by its id as a string.
        xyv = np.hstack((landmarks, vis[:, np.newaxis]))
        ids_dict = dict(zip(ids.astype(int).astype(str), xyv))

        for pos, identifier in enumerate(self.database.ldm_ids):
            key = str(identifier)
            if key in ids_dict:
                x, y, v = ids_dict[key]
                new_ldm[pos] = [x, y]
                new_vis[pos] = v
            else:
                mask_ldm[pos] = 0
        return new_ldm, new_vis, mask_ldm


def get_dataset(data_config, pretreat=None, debug=False):
    '''Builds an AlignmentsDataset from a data configuration.

    :param data_config: configuration object; its ``database``, ``anns_file``,
        ``image_dir``, ``image_size`` and ``ids`` attributes are read here and it is
        also passed to ``get_transformers``.
    :param pretreat: optional extra transformation appended after the ones returned
        by ``get_transformers``.
    :param debug: forwarded to AlignmentsDataset.
    :return: the AlignmentsDataset instance.
    '''
    augmentors = get_transformers(data_config)
    if pretreat is not None:
        augmentors.append(pretreat)

    dataset = AlignmentsDataset(data_config.database,
                                data_config.anns_file,
                                data_config.image_dir,
                                image_size=data_config.image_size,
                                transform=transforms.Compose(augmentors),
                                indices=data_config.ids,
                                debug=debug)
    return dataset