import os
import json
import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from spiga.data.loaders.transforms import get_transformers


class AlignmentsDataset(Dataset):
    '''Loads datasets of images with landmarks and bounding boxes.
    '''
    def __init__(self,
                 database,
                 json_file,
                 images_dir,
                 image_size=(128, 128),
                 transform=None,
                 indices=None,
                 debug=False):
        """
        :param database: DatabaseStruct instance containing all the specifics of the database.
        :param json_file: path to the JSON file containing the image names, landmarks, bounding boxes, etc.
        :param images_dir: path to the directory containing the images.
        :param image_size: output image size as a tuple, e.g. (128, 128).
        :param transform: composition of transformations applied to each sample.
        :param indices: optional list of indices selecting a subset of the items;
                        if None, the whole set is used.
        :param debug: if True, each sample additionally keeps the original
                      (pre-transform) annotations for debugging purposes.
        """
        self.database = database
        self.images_dir = images_dir
        self.transform = transform
        self.image_size = image_size
        self.indices = indices
        self._imgs_dict = None
        self.debug = debug
        with open(json_file) as jsonfile:
            self.data = json.load(jsonfile)

    def __len__(self):
        '''Returns the length of the dataset.'''
        if self.indices is None:
            return len(self.data)
        else:
            return len(self.indices)

    def __getitem__(self, sample_idx):
        '''Returns the sample of the dataset at index sample_idx.'''
        # Allow working with a subset of the data
        if self.indices is not None:
            sample_idx = self.indices[sample_idx]

        # Load sample image
        img_name = os.path.join(self.images_dir, self.data[sample_idx]['imgpath'])
        if not self._imgs_dict:
            image_cv = cv2.imread(img_name)
        else:
            image_cv = self._imgs_dict[sample_idx]

        # Some images are black and white; make sure every image has three channels
        if len(image_cv.shape) == 2:
            image_cv = np.repeat(image_cv[:, :, np.newaxis], 3, axis=-1)
        # Some images have an alpha channel; drop it
        image_cv = image_cv[:, :, :3]
        image_cv = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image_cv)
        # Load sample annotations
        ids = np.array(self.data[sample_idx]['ids'])
        landmarks = np.array(self.data[sample_idx]['landmarks'])
        vis = np.array(self.data[sample_idx]['visible'])
        headpose = self.data[sample_idx]['headpose']
        bbox = self.data[sample_idx]['bbox']

        # Generate the bbox from the visible landmarks if it is missing
        if bbox is None:
            aux = landmarks[vis == 1.0]
            bbox = np.zeros(4)
            bbox[0] = min(aux[:, 0])
            bbox[1] = min(aux[:, 1])
            bbox[2] = max(aux[:, 0]) - bbox[0]
            bbox[3] = max(aux[:, 1]) - bbox[1]
        else:
            bbox = np.array(bbox)
        # Clean and mask landmarks
        mask_ldm = np.ones(self.database.num_landmarks)
        if self.database.ldm_ids != ids.tolist():
            new_ldm = np.zeros((self.database.num_landmarks, 2))
            new_vis = np.zeros(self.database.num_landmarks)
            xyv = np.hstack((landmarks, vis[:, np.newaxis]))
            ids_dict = dict(zip(ids.astype(int).astype(str), xyv))
            for pos, identifier in enumerate(self.database.ldm_ids):
                if str(identifier) in ids_dict:
                    x, y, v = ids_dict[str(identifier)]
                    new_ldm[pos] = [x, y]
                    new_vis[pos] = v
                else:
                    mask_ldm[pos] = 0
            landmarks = new_ldm
            vis = new_vis
        sample = {'image': image,
                  'sample_idx': sample_idx,
                  'imgpath': img_name,
                  'ids_ldm': np.array(self.database.ldm_ids),
                  'bbox': bbox,
                  'bbox_raw': bbox,
                  'landmarks': landmarks,
                  'visible': vis.astype(np.float64),
                  'mask_ldm': mask_ldm,
                  'imgpath_local': self.data[sample_idx]['imgpath'],
                  }

        if self.debug:
            sample['landmarks_ori'] = landmarks
            sample['visible_ori'] = vis.astype(np.float64)
            sample['mask_ldm_ori'] = mask_ldm
            if headpose is not None:
                sample['headpose_ori'] = np.array(headpose)

        if self.transform:
            sample = self.transform(sample)

        return sample


def get_dataset(data_config, pretreat=None, debug=False):
    augmentors = get_transformers(data_config)
    if pretreat is not None:
        augmentors.append(pretreat)

    dataset = AlignmentsDataset(data_config.database,
                                data_config.anns_file,
                                data_config.image_dir,
                                image_size=data_config.image_size,
                                transform=transforms.Compose(augmentors),
                                indices=data_config.ids,
                                debug=debug)
    return dataset
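

# --- Usage sketch (illustrative only, not part of the original loader) ---
# A minimal, hypothetical example of instantiating the dataset directly.
# `my_database` stands in for a DatabaseStruct-like object exposing
# `num_landmarks` and `ldm_ids`; the JSON annotation file and image
# directory below are placeholder paths, not real assets.
#
#   dataset = AlignmentsDataset(my_database,
#                               json_file='annotations/train.json',
#                               images_dir='images/',
#                               image_size=(128, 128))
#   sample = dataset[0]
#   # sample['image'] is a PIL.Image, sample['landmarks'] an (N, 2) array,
#   # and sample['mask_ldm'] flags which database landmark ids were found.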