Spaces:
Running
Running
import cv2 | |
import random | |
import numpy as np | |
import spiga.data.loaders.dl_config as dl_cfg | |
import spiga.data.loaders.dataloader as dl | |
import spiga.data.visualize.plotting as plot | |
def inspect_parser(argv=None):
    """Parse the command-line options of the dataset inspector.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with fields ``database``, ``anns``, ``nopose``
        (True unless ``-np`` is given), ``clean``, ``shape`` and ``img``.
    """
    import argparse

    pars = argparse.ArgumentParser(description='Data augmentation and dataset visualization. '
                                               'Press Q to quit, '
                                               'N to visualize the next image'
                                               ' and any other key to visualize the next default data.')
    pars.add_argument('database', type=str,
                      choices=['wflw', '300wpublic', '300wprivate', 'cofw68', 'merlrav'], help='Database name')
    pars.add_argument('-a', '--anns', type=str, default='train', help='Annotation type: test, train or valid')
    # store_false: args.nopose is True by default and becomes False when -np is passed,
    # so the value can be used directly as the "generate pose" flag.
    pars.add_argument('-np', '--nopose', action='store_false', default=True, help='Avoid pose generation')
    pars.add_argument('-c', '--clean', action='store_true', help='Process without data augmentation for train')
    pars.add_argument('--shape', nargs='+', type=int, default=[256, 256], help='Image cropped shape (W,H)')
    pars.add_argument('--img', nargs='+', type=int, default=None, help='Select specific image ids')
    return pars.parse_args(argv)
class DatasetInspector:
    """Interactive viewer for alignment datasets loaded through ``dl.get_dataloader``.

    Draws the landmarks, and the head pose when pose generation is enabled, on the
    processed crop. Also draws the landmarks on the original full image.
    """

    def __init__(self, database, anns_type, data_aug=True, pose=True, image_shape=(256, 256)):
        """Configure and load the dataset.

        Args:
            database: Database name forwarded to ``dl_cfg.AlignConfig``.
            anns_type: Annotation split forwarded to ``dl_cfg.AlignConfig``.
            data_aug: When False, augmentation is disabled by clearing ``aug_names``.
            pose: Stored as ``generate_pose``; also gates pose plotting.
            image_shape: Used for both ``image_size`` and ``ftmap_size``.
        """
        data_config = dl_cfg.AlignConfig(database, anns_type)
        data_config.image_size = image_shape
        data_config.ftmap_size = image_shape
        data_config.generate_pose = pose
        if not data_aug:
            data_config.aug_names = []
        self.data_config = data_config
        self.reload_dataset()
        self.colors_dft = {'lnd': (plot.GREEN, plot.RED), 'pose': (plot.BLUE, plot.GREEN, plot.RED)}

    def show_dataset(self, ids_list=None):
        """Show samples one at a time in the 'crop' and 'image' OpenCV windows.

        Press 'q' or 'Q' to stop; any other key advances to the next sample.
        The windows are closed when the loop ends.

        Args:
            ids_list: Dataset indices to show. Defaults to every index, shuffled
                when ``data_config.shuffle`` is truthy.
        """
        if ids_list is None:
            ids_list = self.get_idx(shuffle=self.data_config.shuffle)

        try:
            for img_id in ids_list:
                data_dict = self.dataset[img_id]
                crop_imgs, full_imgs = self.plot_features(data_dict)
                # Prefer the landmarks + pose overlay when pose was generated.
                cv2.imshow('crop', crop_imgs.get('merge', crop_imgs['lnd']))
                cv2.imshow('image', full_imgs['lnd'])
                # Keep only the low byte: some backends report extra modifier bits.
                key = cv2.waitKey() & 0xFF
                if key in (ord('q'), ord('Q')):
                    break
        finally:
            cv2.destroyAllWindows()

    def plot_features(self, data_dict, colors=None):
        """Render landmarks (and pose, if enabled) on the crop and the full image.

        Args:
            data_dict: One dataset sample. Keys read: 'image', 'landmarks', 'visible',
                'mask_ldm', 'landmarks_ori', 'visible_ori', 'mask_ldm_ori', and either
                'image_ori' or 'imgpath'. When pose generation is on, also the keys
                read by ``_extract_pose``.
            colors: Optional overrides for ``colors_dft`` ('lnd' / 'pose'); missing
                keys fall back to the defaults.

        Returns:
            (crop_imgs, full_imgs): dicts of rendered images. ``crop_imgs`` holds 'lnd'
            and, with pose enabled, 'pose' and 'merge'. ``full_imgs`` holds 'lnd'.

        Raises:
            FileNotFoundError: the full image had to be read from 'imgpath' and
                OpenCV could not load it.
        """
        colors = {**self.colors_dft, **(colors or {})}
        crop_imgs = {}
        full_imgs = {}

        # Cropped (processed) sample
        image = data_dict['image']
        landmarks = data_dict['landmarks']
        visible = self._visibility_or_none(data_dict['visible'])
        mask = data_dict['mask_ldm']

        # Full original image
        if 'image_ori' in data_dict:
            image_ori = data_dict['image_ori']
        else:
            image_ori = cv2.imread(data_dict['imgpath'])
            if image_ori is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise FileNotFoundError(f"Could not read image: {data_dict['imgpath']}")
        landmarks_ori = data_dict['landmarks_ori']
        visible_ori = self._visibility_or_none(data_dict['visible_ori'])
        mask_ori = data_dict['mask_ldm_ori']

        # Landmarks
        crop_imgs['lnd'] = self._plot_lnd(image, landmarks, visible, mask, colors=colors['lnd'])
        full_imgs['lnd'] = self._plot_lnd(image_ori, landmarks_ori, visible_ori, mask_ori, colors=colors['lnd'])

        if self.data_config.generate_pose:
            rot, trl, cam_matrix = self._extract_pose(data_dict)
            # Pose alone on the crop, then pose drawn over the landmark rendering.
            crop_imgs['pose'] = plot.draw_pose(image, rot, trl, cam_matrix, euler=True, colors=colors['pose'])
            crop_imgs['merge'] = plot.draw_pose(crop_imgs['lnd'], rot, trl, cam_matrix, euler=True,
                                                colors=colors['pose'])
        return crop_imgs, full_imgs

    def get_idx(self, shuffle=False):
        """Return every dataset index, in random order when ``shuffle`` is truthy."""
        ids = list(range(len(self.dataset)))
        if shuffle:
            random.shuffle(ids)
        return ids

    def reload_dataset(self, data_config=None):
        """(Re)build the dataloader and dataset with batch size 1.

        Args:
            data_config: New configuration. When given it also replaces
                ``self.data_config``, so later plotting (``generate_pose``,
                ``aug_names``) matches the data actually loaded. Defaults to the
                current configuration.
        """
        if data_config is None:
            data_config = self.data_config
        self.data_config = data_config
        self.dataloader, self.dataset = dl.get_dataloader(1, data_config, debug=True)

    def _extract_pose(self, data_dict):
        """Return ``(rot, trl, cam_matrix)`` for pose drawing.

        ``rot`` is ``data_dict['pose'][:3]`` and ``trl`` is ``data_dict['pose'][3:]``.
        If the sample has 'headpose_ori' and no augmentation is configured, ``rot``
        is replaced by that ground-truth head pose. The rationale is presumably that
        augmentation would make the original pose inconsistent with the crop.
        """
        pose = data_dict['pose']
        rot = pose[:3]
        trl = pose[3:]
        cam_matrix = data_dict['cam_matrix']

        if 'headpose_ori' in data_dict and len(self.data_config.aug_names) == 0:
            print('Image headpose generated by ground truth data')
            rot = data_dict['headpose_ori']
        return rot, trl, cam_matrix

    @staticmethod
    def _visibility_or_none(visible):
        """Return ``visible``, or None if it contains any NaN.

        NaN presumably marks missing visibility annotations.
        """
        return None if np.any(np.isnan(visible)) else visible

    def _plot_lnd(self, image, landmarks, visible, mask, max_shape_thr=720, colors=None):
        """Draw landmarks on ``image``; downscale the result if it is too large.

        If the longest image side exceeds ``max_shape_thr``, the drawing thickness is
        scaled by the inverse factor. The output is then resized so that its longest
        side is ``max_shape_thr`` (up to int truncation).
        """
        if colors is None:
            colors = self.colors_dft['lnd']

        # numpy images are (rows, cols[, channels]); also works for 2-D grayscale.
        height, width = image.shape[:2]
        longest = max(height, width)
        if longest <= max_shape_thr:
            return plot.draw_landmarks(image, landmarks, visible=visible, mask=mask, colors=colors)

        scale_factor = max_shape_thr / longest
        image_out = plot.draw_landmarks(image, landmarks, visible=visible, mask=mask,
                                        thick_scale=1 / scale_factor, colors=colors)
        # cv2.resize takes dsize as (width, height).
        return cv2.resize(image_out, (int(width * scale_factor), int(height * scale_factor)))
if __name__ == '__main__':
    # Script entry point: parse the CLI, validate the crop shape, then browse the dataset.
    cli_args = inspect_parser()

    if len(cli_args.shape) != 2:
        raise ValueError('--shape requires two values: width and height. Ej: --shape 256 256')
    crop_shape = tuple(cli_args.shape)

    # --clean disables augmentation; --nopose (store_false) already holds the pose flag.
    inspector = DatasetInspector(cli_args.database, cli_args.anns,
                                 data_aug=not cli_args.clean,
                                 pose=cli_args.nopose,
                                 image_shape=crop_shape)
    inspector.show_dataset(ids_list=cli_args.img)