Spaces:
Runtime error
Runtime error
| import cv2 | |
| import random | |
| import numpy as np | |
| import spiga.data.loaders.dl_config as dl_cfg | |
| import spiga.data.loaders.dataloader as dl | |
| import spiga.data.visualize.plotting as plot | |
def inspect_parser(argv=None):
    """Parse the command-line options of the dataset inspection tool.

    Args:
        argv: Optional list of argument strings to parse. When None (the default)
            argparse falls back to ``sys.argv[1:]``, which matches calling this
            function without arguments.

    Returns:
        argparse.Namespace with the attributes ``database``, ``anns``, ``nopose``,
        ``clean``, ``shape`` and ``img``.

        ``nopose`` is declared with ``store_false`` and ``default=True``: it is
        True when pose generation stays enabled and False when ``-np`` is given.
        ``shape`` is a list of exactly two ints; ``img`` is a list of ints or None.

    Raises:
        SystemExit: Raised by argparse on invalid arguments (unknown database,
            wrong number of ``--shape`` values, non-integer ids, ...).
    """
    import argparse

    # NOTE(review): the description mentions an "N" key, but the viewer loop in
    # DatasetInspector.show_dataset only special-cases the quit key; every other
    # key advances to the next image. Text kept as-is apart from the spacing fix.
    pars = argparse.ArgumentParser(
        description='Data augmentation and dataset visualization. '
                    'Press Q to quit, '
                    'N to visualize the next image'
                    ' and any other key to visualize the next default data.')
    pars.add_argument('database', type=str,
                      choices=['wflw', '300wpublic', '300wprivate', 'cofw68', 'merlrav'],
                      help='Database name')
    pars.add_argument('-a', '--anns', type=str, default='train',
                      help='Annotation type: test, train or valid')
    pars.add_argument('-np', '--nopose', action='store_false', default=True,
                      help='Avoid pose generation')
    pars.add_argument('-c', '--clean', action='store_true',
                      help='Process without data augmentation for train')
    # Exactly two values are required; letting argparse enforce it gives a usage
    # error up front and stops the option from swallowing a following positional.
    pars.add_argument('--shape', nargs=2, type=int, default=[256, 256], metavar=('W', 'H'),
                      help='Image cropped shape (W,H)')
    pars.add_argument('--img', nargs='+', type=int, default=None,
                      help='Select specific image ids')
    return pars.parse_args(argv)
class DatasetInspector:
    """Interactive viewer for samples produced by the alignment dataloader.

    The dataset is built through ``dl.get_dataloader`` from a ``dl_cfg.AlignConfig``.
    For each sample the cropped image and the full image are shown in OpenCV
    windows with the landmarks (and optionally the head pose) drawn on top.
    """

    def __init__(self, database, anns_type, data_aug=True, pose=True, image_shape=(256, 256)):
        """Build the data configuration and load the dataset.

        Args:
            database: Database name forwarded to ``dl_cfg.AlignConfig``.
            anns_type: Annotation type forwarded to ``dl_cfg.AlignConfig``.
            data_aug: When False, the config's ``aug_names`` is replaced by an
                empty list.
            pose: Stored as ``generate_pose`` on the config; also decides whether
                ``plot_features`` draws the pose.
            image_shape: Assigned to both ``image_size`` and ``ftmap_size``.
        """
        data_config = dl_cfg.AlignConfig(database, anns_type)
        data_config.image_size = image_shape
        data_config.ftmap_size = image_shape
        data_config.generate_pose = pose
        if not data_aug:
            data_config.aug_names = []

        self.data_config = data_config
        # Single code path for (re)building the dataset and dataloader.
        self.reload_dataset()
        self.colors_dft = {'lnd': (plot.GREEN, plot.RED), 'pose': (plot.BLUE, plot.GREEN, plot.RED)}

    def show_dataset(self, ids_list=None):
        """Display samples one by one until the list ends or Q is pressed.

        Args:
            ids_list: Dataset indices to show, in order. When None, every index
                is shown, shuffled if ``self.data_config.shuffle`` is truthy.
        """
        if ids_list is None:
            ids = self.get_idx(shuffle=self.data_config.shuffle)
        else:
            ids = ids_list

        for img_id in ids:
            data_dict = self.dataset[img_id]
            crop_imgs, full_imgs = self.plot_features(data_dict)

            # Plot crop: prefer the landmarks + pose overlay when it exists
            crop = crop_imgs['merge'] if 'merge' in crop_imgs else crop_imgs['lnd']
            cv2.imshow('crop', crop)

            # Plot full
            cv2.imshow('image', full_imgs['lnd'])

            # Keep only the low byte: some backends set higher bits in the key
            # code, which would make a plain comparison with ord('q') fail.
            key = cv2.waitKey() & 0xFF
            if key in (ord('q'), ord('Q')):
                break

    def plot_features(self, data_dict, colors=None):
        """Draw the annotations of one sample on its cropped and full images.

        Args:
            data_dict: Sample mapping. Reads ``image``, ``landmarks``, ``visible``,
                ``mask_ldm``, ``landmarks_ori``, ``visible_ori``, ``mask_ldm_ori``
                and either ``image_ori`` or ``imgpath``; the pose keys used by
                ``_extract_pose`` are read when pose generation is enabled.
            colors: Mapping with ``'lnd'`` and ``'pose'`` entries. Defaults to
                ``self.colors_dft``.

        Returns:
            Tuple ``(crop_imgs, full_imgs)`` of dicts. Both contain ``'lnd'``;
            ``crop_imgs`` also contains ``'pose'`` and ``'merge'`` when
            ``self.data_config.generate_pose`` is truthy.

        Raises:
            FileNotFoundError: If the full image has to be read from ``imgpath``
                and OpenCV cannot load it.
        """
        # Init variables
        crop_imgs = {}
        full_imgs = {}
        if colors is None:
            colors = self.colors_dft

        # Cropped image
        image = data_dict['image']
        landmarks = data_dict['landmarks']
        visible = data_dict['visible']
        if np.any(np.isnan(visible)):
            # Any NaN entry disables the visibility input altogether
            visible = None
        mask = data_dict['mask_ldm']

        # Full image
        if 'image_ori' in data_dict:
            image_ori = data_dict['image_ori']
        else:
            image_ori = cv2.imread(data_dict['imgpath'])
            if image_ori is None:
                # cv2.imread signals failure by returning None instead of raising
                raise FileNotFoundError('Unable to read image: %s' % data_dict['imgpath'])
        landmarks_ori = data_dict['landmarks_ori']
        visible_ori = data_dict['visible_ori']
        if np.any(np.isnan(visible_ori)):
            visible_ori = None
        mask_ori = data_dict['mask_ldm_ori']

        # Plot landmarks
        crop_imgs['lnd'] = self._plot_lnd(image, landmarks, visible, mask, colors=colors['lnd'])
        full_imgs['lnd'] = self._plot_lnd(image_ori, landmarks_ori, visible_ori, mask_ori, colors=colors['lnd'])

        if self.data_config.generate_pose:
            rot, trl, cam_matrix = self._extract_pose(data_dict)

            # Plot pose
            crop_imgs['pose'] = plot.draw_pose(image, rot, trl, cam_matrix, euler=True, colors=colors['pose'])

            # Plot merge features
            crop_imgs['merge'] = plot.draw_pose(crop_imgs['lnd'], rot, trl, cam_matrix, euler=True,
                                                colors=colors['pose'])

        return crop_imgs, full_imgs

    def get_idx(self, shuffle=False):
        """Return every dataset index as a list, optionally in random order."""
        ids = list(range(len(self.dataset)))
        if shuffle:
            random.shuffle(ids)
        return ids

    def reload_dataset(self, data_config=None):
        """Rebuild ``self.dataset`` and ``self.dataloader``.

        Args:
            data_config: Configuration to load. When None the stored
                ``self.data_config`` is reused.
        """
        if data_config is None:
            data_config = self.data_config
        dataloader, dataset = dl.get_dataloader(1, data_config, debug=True)
        self.dataset = dataset
        self.dataloader = dataloader
        # Keep the stored config in sync with the loaded dataset: plot_features
        # and _extract_pose read generate_pose / aug_names from self.data_config.
        # Assigned only after the load succeeded.
        self.data_config = data_config

    def _extract_pose(self, data_dict):
        """Split the pose of a sample into rotation, translation and camera matrix.

        Args:
            data_dict: Sample mapping with ``pose`` and ``cam_matrix`` and,
                optionally, ``headpose_ori``.

        Returns:
            Tuple ``(rot, trl, cam_matrix)`` where ``rot`` is ``pose[:3]`` and
            ``trl`` is ``pose[3:]``. When ``headpose_ori`` is present and no
            augmentation is configured, ``rot`` is ``headpose_ori`` instead.
        """
        # Rotation and translation matrix
        pose = data_dict['pose']
        rot = pose[:3]
        trl = pose[3:]

        # Camera matrix
        cam_matrix = data_dict['cam_matrix']

        # Check for ground truth anns
        if 'headpose_ori' in data_dict and len(self.data_config.aug_names) == 0:
            print('Image headpose generated by ground truth data')
            rot = data_dict['headpose_ori']

        return rot, trl, cam_matrix

    def _plot_lnd(self, image, landmarks, visible, mask, max_shape_thr=720, colors=None):
        """Draw landmarks on an image, downscaling the result if it is too large.

        Args:
            image: Image array; the first two axes are taken as (height, width).
            landmarks: Landmarks forwarded to ``plot.draw_landmarks``.
            visible: Visibility input forwarded to ``plot.draw_landmarks`` (may be None).
            mask: Landmark mask forwarded to ``plot.draw_landmarks``.
            max_shape_thr: Maximum allowed side length, in pixels, of the output.
            colors: Landmark colors. Defaults to ``self.colors_dft['lnd']``.

        Returns:
            The annotated image, resized so its longest side is about
            ``max_shape_thr`` when either side exceeded that threshold.
        """
        if colors is None:
            colors = self.colors_dft['lnd']

        # Array layout is (rows, cols[, channels]) == (height, width[, channels]).
        # Slicing also keeps 2-D (single channel) arrays from failing the unpack.
        height, width = image.shape[:2]

        # Original image resize if need it
        if height > max_shape_thr or width > max_shape_thr:
            scale_factor = max_shape_thr / max(height, width)
            # cv2.resize takes the target size as (width, height)
            resize_shape = (int(width * scale_factor), int(height * scale_factor))
            # Thickness is scaled up before drawing to compensate for the
            # downscale that follows.
            image_out = plot.draw_landmarks(image, landmarks, visible=visible, mask=mask,
                                            thick_scale=1 / scale_factor, colors=colors)
            image_out = cv2.resize(image_out, resize_shape)
        else:
            image_out = plot.draw_landmarks(image, landmarks, visible=visible, mask=mask, colors=colors)

        return image_out
if __name__ == '__main__':
    # Script entry point: parse the CLI, validate the crop shape, launch the viewer.
    args = inspect_parser()

    # Guard clause: the crop shape must be a (width, height) pair.
    if len(args.shape) != 2:
        raise ValueError('--shape requires two values: width and height. Ej: --shape 256 256')

    visualizer = DatasetInspector(
        args.database,
        args.anns,
        data_aug=not args.clean,       # --clean switches augmentation off
        pose=args.nopose,              # True unless -np/--nopose was given
        image_shape=tuple(args.shape),
    )
    visualizer.show_dataset(ids_list=args.img)