# Dataset loader for facial-landmark alignment annotations (images + landmarks + bboxes).
import os
import json
import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from spiga.data.loaders.transforms import get_transformers
class AlignmentsDataset(Dataset):
    '''Loads datasets of images with landmarks and bounding boxes.
    '''
    def __init__(self,
                 database,
                 json_file,
                 images_dir,
                 image_size=(128, 128),
                 transform=None,
                 indices=None,
                 debug=False):
        """
        :param database: class DatabaseStruct containing all the specifics of the database
        :param json_file: path to the json file which contains the names of the images,
                          landmarks, bounding boxes, etc.
        :param images_dir: path of the directory containing the images.
        :param image_size: tuple like e.g. (128, 128)
        :param transform: composition of transformations that will be applied to the samples.
        :param indices: If it is a list of indices, allows to work with the subset of
                        items specified by the list. If it is None, the whole set is used.
        :param debug: bool, if True each sample additionally keeps untransformed copies of
                      its annotations ('landmarks_ori', 'visible_ori', ...) for debugging.
        """
        self.database = database
        self.images_dir = images_dir
        self.transform = transform
        self.image_size = image_size
        self.indices = indices
        # Optional in-memory image cache keyed by sample index; None means read from disk.
        self._imgs_dict = None
        self.debug = debug
        with open(json_file) as jsonfile:
            self.data = json.load(jsonfile)

    def __len__(self):
        '''Returns the length of the dataset (restricted to the subset when indices is set).
        '''
        if self.indices is None:
            return len(self.data)
        return len(self.indices)

    def __getitem__(self, sample_idx):
        '''Returns sample of the dataset of index idx'''
        # To allow work with a subset
        if self.indices is not None:
            sample_idx = self.indices[sample_idx]
        anns = self.data[sample_idx]

        # Load sample image (from the in-memory cache when available)
        img_name = os.path.join(self.images_dir, anns['imgpath'])
        if not self._imgs_dict:
            image_cv = cv2.imread(img_name)
            if image_cv is None:
                # cv2.imread returns None on failure instead of raising; fail loudly here
                # rather than with an opaque shape error below.
                raise FileNotFoundError('Unable to read image: %s' % img_name)
        else:
            image_cv = self._imgs_dict[sample_idx]

        # Some images are B&W. We make sure that any image has three channels.
        if len(image_cv.shape) == 2:
            image_cv = np.repeat(image_cv[:, :, np.newaxis], 3, axis=-1)
        # Some images have alpha channel
        image_cv = image_cv[:, :, :3]
        image_cv = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image_cv)

        # Load sample anns
        ids = np.array(anns['ids'])
        landmarks = np.array(anns['landmarks'])
        vis = np.array(anns['visible'])
        headpose = anns['headpose']

        # Generate bbox if needed. NOTE: the raw JSON value must be tested, because
        # np.array(None) yields a 0-d object array that is never `is None`, which made
        # this fallback unreachable in the previous version.
        if anns['bbox'] is None:
            # Compute bbox [x, y, w, h] from the visible landmarks
            aux = landmarks[vis == 1.0]
            bbox = np.zeros(4)
            bbox[0] = aux[:, 0].min()
            bbox[1] = aux[:, 1].min()
            bbox[2] = aux[:, 0].max() - bbox[0]
            bbox[3] = aux[:, 1].max() - bbox[1]
        else:
            bbox = np.array(anns['bbox'])

        # Clean and mask landmarks: remap the annotated ids onto the database's landmark
        # id layout, zero-filling (and masking out) any landmark missing from this sample.
        mask_ldm = np.ones(self.database.num_landmarks)
        if self.database.ldm_ids != ids.tolist():
            new_ldm = np.zeros((self.database.num_landmarks, 2))
            new_vis = np.zeros(self.database.num_landmarks)
            xyv = np.hstack((landmarks, vis[:, np.newaxis]))
            ids_dict = dict(zip(ids.astype(int).astype(str), xyv))
            for pos, identifier in enumerate(self.database.ldm_ids):
                if str(identifier) in ids_dict:
                    x, y, v = ids_dict[str(identifier)]
                    new_ldm[pos] = [x, y]
                    new_vis[pos] = v
                else:
                    mask_ldm[pos] = 0
            landmarks = new_ldm
            vis = new_vis

        sample = {'image': image,
                  'sample_idx': sample_idx,
                  'imgpath': img_name,
                  'ids_ldm': np.array(self.database.ldm_ids),
                  'bbox': bbox,
                  'bbox_raw': bbox,
                  'landmarks': landmarks,
                  'visible': vis.astype(np.float64),
                  'mask_ldm': mask_ldm,
                  'imgpath_local': anns['imgpath'],
                  }

        if self.debug:
            # Keep untransformed copies so transformed vs. raw annotations can be compared.
            sample['landmarks_ori'] = landmarks
            sample['visible_ori'] = vis.astype(np.float64)
            sample['mask_ldm_ori'] = mask_ldm
            if headpose is not None:
                sample['headpose_ori'] = np.array(headpose)

        if self.transform:
            sample = self.transform(sample)
        return sample
def get_dataset(data_config, pretreat=None, debug=False):
    """Build an AlignmentsDataset configured from *data_config*.

    :param data_config: configuration object providing database, annotation file,
                        image directory, image size, transformers and subset ids.
    :param pretreat: optional extra transformation appended after the augmentors.
    :param debug: forwarded to AlignmentsDataset (keeps untransformed annotations).
    :return: the configured AlignmentsDataset instance.
    """
    transformer_list = get_transformers(data_config)
    if pretreat is not None:
        transformer_list.append(pretreat)
    return AlignmentsDataset(data_config.database,
                             data_config.anns_file,
                             data_config.image_dir,
                             image_size=data_config.image_size,
                             transform=transforms.Compose(transformer_list),
                             indices=data_config.ids,
                             debug=debug)