Spaces:
Runtime error
Runtime error
| # Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # This work is made available under the Nvidia Source Code License-NC. | |
| # To view a copy of this license, check out LICENSE.md | |
| # flake8: noqa: E712 | |
| """Utils for handling datasets.""" | |
| import time | |
| import numpy as np | |
| from PIL import Image | |
| # https://github.com/albumentations-team/albumentations#comments | |
| import cv2 | |
| # from imaginaire.utils.distributed import master_only_print as print | |
| import albumentations as alb # noqa nopep8 | |
# OpenCV may spawn its own threads / use OpenCL; disable both so it plays
# nicely inside multi-process data-loader workers (see the albumentations
# FAQ link above).
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
# Recognized low-dynamic-range image extensions (lower- and upper-case).
IMG_EXTENSIONS = ('jpg', 'jpeg', 'png', 'ppm', 'bmp',
                  'pgm', 'tif', 'tiff', 'webp',
                  'JPG', 'JPEG', 'PNG', 'PPM', 'BMP',
                  'PGM', 'TIF', 'TIFF', 'WEBP')
# High-dynamic-range image extensions.
HDR_IMG_EXTENSIONS = ('hdr',)
# NOTE(review): unlike the tuples above this is a plain string, so a
# membership test like `ext in VIDEO_EXTENSIONS` also matches substrings
# ('m', 'p4', ...). Confirm how callers use it before changing the type.
VIDEO_EXTENSIONS = 'mp4'
class Augmentor(object):
    r"""Handles data augmentation using albumentations library.

    The augmentation pipeline is described by config dictionaries mapping
    op names to their parameters. Building the ops also records crop/resize
    geometry on the instance, which is validated in ``__init__``.
    """

    def __init__(self, aug_list, individual_video_frame_aug_list, image_data_types, is_mask,
                 keypoint_data_types, interpolator):
        r"""Initializes augmentation pipeline.

        Args:
            aug_list (dict): Augmentation op name -> parameters, applied in
                sequence. (Iterated with ``.items()`` below, so this is a
                dict-like object despite older docs calling it a list.)
            individual_video_frame_aug_list (dict): Augmentation ops that will
                be applied to individual frames of videos independently.
            image_data_types (list): List of keys in expected inputs.
            is_mask (dict): Whether this data type is discrete masks?
            keypoint_data_types (list): List of keys which are keypoints.
            interpolator (str): Name of a cv2 interpolation constant
                (e.g. ``'INTER_LINEAR'``) used by the resize op.
        """
        self.aug_list = aug_list
        self.individual_video_frame_aug_list = individual_video_frame_aug_list
        self.image_data_types = image_data_types
        self.is_mask = is_mask
        # Crop/resize geometry. These are populated as a side effect of the
        # _build_*_ops() calls below when the corresponding config keys exist.
        self.crop_h, self.crop_w = None, None
        self.resize_h, self.resize_w = None, None
        self.resize_smallest_side = None
        # Max temporal stride when sampling video frames.
        self.max_time_step = 1
        self.keypoint_data_types = keypoint_data_types
        self.interpolator = interpolator
        self.augment_ops = self._build_augmentation_ops()
        self.individual_video_frame_augmentation_ops = self._build_individual_video_frame_augmentation_ops()
        # Both crop and resize can't be none at the same time.
        if self.crop_h is None and self.resize_smallest_side is None and \
                self.resize_h is None:
            raise ValueError('resize_smallest_side, resize_h_w, '
                             'and crop_h_w cannot all be missing.')
        # If resize_smallest_side is given, resize_h_w should not be give.
        if self.resize_smallest_side is not None:
            assert self.resize_h is None, \
                'Cannot have both `resize_smallest_side` and `resize_h_w` set.'
        # No explicit resize requested: default the resize dims to crop dims.
        if self.resize_smallest_side is None and self.resize_h is None:
            self.resize_h, self.resize_w = self.crop_h, self.crop_w

    def _build_individual_video_frame_augmentation_ops(self):
        r"""Builds sequence of augmentation ops that will be applied to each frame in the video independently.

        Returns:
            (list of alb.ops): List of augmentation ops.
        """
        # NOTE(review): unlike _build_augmentation_ops, unknown keys are
        # silently ignored here — confirm that is intentional.
        augs = []
        for key, value in self.individual_video_frame_aug_list.items():
            if key == 'random_scale_limit':
                # A bare float means symmetric bounds, always applied;
                # a dict spells out lower/upper bounds and probability.
                if type(value) == float:
                    scale_limit_lb = scale_limit_ub = value
                    p = 1
                else:
                    scale_limit_lb = value['scale_limit_lb']
                    scale_limit_ub = value['scale_limit_ub']
                    p = value['p']
                augs.append(alb.RandomScale(scale_limit=(-scale_limit_lb, scale_limit_ub), p=p))
            elif key == 'random_crop_h_w':
                # Value is a string 'H,W'. Pad first so inputs smaller than
                # the crop size still produce a valid crop.
                h, w = value.split(',')
                h, w = int(h), int(w)
                self.crop_h, self.crop_w = h, w
                augs.append(alb.PadIfNeeded(min_height=h, min_width=w))
                augs.append(alb.RandomCrop(h, w, always_apply=True, p=1))
        return augs

    def _build_augmentation_ops(self):
        r"""Builds sequence of augmentation ops.

        Returns:
            (list of alb.ops): List of augmentation ops.
        """
        augs = []
        for key, value in self.aug_list.items():
            if key == 'resize_smallest_side':
                # Either a single int (square target for the smaller side)
                # or a string 'H,W'.
                if isinstance(value, int):
                    self.resize_smallest_side = value
                else:
                    h, w = value.split(',')
                    h, w = int(h), int(w)
                    self.resize_smallest_side = (h, w)
            elif key == 'resize_h_w':
                h, w = value.split(',')
                h, w = int(h), int(w)
                self.resize_h, self.resize_w = h, w
            elif key == 'random_resize_h_w_aspect':
                # Value looks like 'H,W,(aspect_min,aspect_max)'; parse the
                # parenthesized aspect range separately from the H,W prefix.
                aspect_start, aspect_end = value.find('('), value.find(')')
                aspect = value[aspect_start+1:aspect_end]
                aspect_min, aspect_max = aspect.split(',')
                h, w = value[:aspect_start].split(',')[:2]
                h, w = int(h), int(w)
                aspect_min, aspect_max = float(aspect_min), float(aspect_max)
                augs.append(alb.RandomResizedCrop(
                    h, w, scale=(1, 1),
                    ratio=(aspect_min, aspect_max), always_apply=True, p=1))
                self.resize_h, self.resize_w = h, w
            elif key == 'rotate':
                augs.append(alb.Rotate(
                    limit=value, always_apply=True, p=1))
            elif key == 'random_rotate_90':
                augs.append(alb.RandomRotate90(always_apply=False, p=0.5))
            elif key == 'random_scale_limit':
                augs.append(alb.RandomScale(scale_limit=(0, value), p=1))
            elif key == 'random_crop_h_w':
                h, w = value.split(',')
                h, w = int(h), int(w)
                self.crop_h, self.crop_w = h, w
                augs.append(alb.RandomCrop(h, w, always_apply=True, p=1))
            elif key == 'center_crop_h_w':
                h, w = value.split(',')
                h, w = int(h), int(w)
                self.crop_h, self.crop_w = h, w
                augs.append(alb.CenterCrop(h, w, always_apply=True, p=1))
            elif key == 'horizontal_flip':
                # This is handled separately as we need to keep track if this
                # was applied in order to correctly modify keypoint data.
                if value:
                    augs.append(alb.HorizontalFlip(always_apply=False, p=0.5))
            # The options below including contrast, blur, motion_blur, compression, gamma
            # were used during developing face-vid2vid.
            elif key == 'contrast':
                brightness_limit = value['brightness_limit']
                contrast_limit = value['contrast_limit']
                p = value['p']
                augs.append(alb.RandomBrightnessContrast(
                    brightness_limit=brightness_limit, contrast_limit=contrast_limit, p=p))
            elif key == 'blur':
                blur_limit = value['blur_limit']
                p = value['p']
                augs.append(alb.Blur(blur_limit=blur_limit, p=p))
            elif key == 'motion_blur':
                blur_limit = value['blur_limit']
                p = value['p']
                augs.append(alb.MotionBlur(blur_limit=blur_limit, p=p))
            elif key == 'compression':
                quality_lower = value['quality_lower']
                p = value['p']
                augs.append(alb.ImageCompression(quality_lower=quality_lower, p=p))
            elif key == 'gamma':
                gamma_limit_lb = value['gamma_limit_lb']
                gamma_limit_ub = value['gamma_limit_ub']
                p = value['p']
                augs.append(alb.RandomGamma(gamma_limit=(gamma_limit_lb, gamma_limit_ub), p=p))
            elif key == 'max_time_step':
                # Not an albumentations op — only records the max temporal
                # stride for video frame sampling.
                self.max_time_step = value
                assert self.max_time_step >= 1, \
                    'max_time_step has to be at least 1'
            else:
                raise ValueError('Unknown augmentation %s' % (key))
        return augs

    def _choose_image_key(self, inputs):
        r"""Choose key to replace with 'image' for input to albumentations.

        Returns:
            key (str): Chosen key to be replace with 'image'.
                Returns None implicitly if no image-type key is present;
                callers assume at least one image-type input exists.
        """
        if 'image' in inputs:
            return 'image'
        for data_type in inputs:
            if data_type in self.image_data_types:
                return data_type

    def _choose_keypoint_key(self, inputs):
        r"""Choose key to replace with 'keypoints' for input to albumentations.

        Returns:
            key (str): Chosen key to be replace with 'keypoints', or None
                when no keypoint data types are configured/present.
        """
        if not self.keypoint_data_types:
            return None
        if 'keypoints' in inputs:
            return 'keypoints'
        for data_type in inputs:
            if data_type in self.keypoint_data_types:
                return data_type

    def _create_augmentation_targets(self, inputs):
        r"""Create additional targets as required by the albumentation library.

        Args:
            inputs (dict): Keys are from self.augmentable_data_types. Values can
                be numpy.ndarray or list of numpy.ndarray
                (image or list of images).
        Returns:
            (dict):
              - targets (dict): Dict containing mapping of keys to image/mask types.
              - new_inputs (dict): Dict containing mapping of keys to data.
        """
        # Get additional target list.
        targets, new_inputs = {}, {}
        for data_type in inputs:
            if data_type in self.keypoint_data_types:
                # Keypoint-type.
                target_type = 'keypoints'
            elif data_type in self.image_data_types:
                # Image-type.
                # Find the target type (image/mask) based on interpolation
                # method: masks get nearest-neighbor treatment in
                # albumentations, images get the regular interpolator.
                if self.is_mask[data_type]:
                    target_type = 'mask'
                else:
                    target_type = 'image'
            else:
                raise ValueError(
                    'Data type: %s is not image or keypoint' % (data_type))
            current_data_type_inputs = inputs[data_type]
            if not isinstance(current_data_type_inputs, list):
                current_data_type_inputs = [current_data_type_inputs]
            # Create additional_targets and inputs when there are multiples.
            # Frame idx > 0 gets a suffixed key 'name::00001', 'name::00002'...
            # so that _collate_augmented can restore ordering by sorting.
            for idx, new_input in enumerate(current_data_type_inputs):
                key = data_type
                if idx > 0:
                    key = '%s::%05d' % (key, idx)
                targets[key] = target_type
                new_inputs[key] = new_input
        return targets, new_inputs

    def _collate_augmented(self, augmented):
        r"""Collate separated images back into sequence, grouped by keys.

        Args:
            augmented (dict): Dict containing frames with keys of the form
                'key', 'key::00001', 'key::00002', ..., 'key::N'.
        Returns:
            (dict):
              - outputs (dict): Dict with list of collated inputs, i.e. frames of
                same key are arranged in order ['key', 'key::00001', ..., 'key::N'].
        """
        # Sorting puts the bare 'key' first, followed by 'key::00001', ...,
        # because '::' sorts after the end of the bare key string.
        full_keys = sorted(augmented.keys())
        outputs = {}
        for full_key in full_keys:
            if '::' not in full_key:
                # First occurrence of this key.
                key = full_key
                outputs[key] = []
            else:
                key = full_key.split('::')[0]
            outputs[key].append(augmented[full_key])
        return outputs

    def _get_resize_h_w(self, height, width):
        r"""Get height and width to resize to, given smallest side.

        Args:
            height (int): Input image height.
            width (int): Input image width.
        Returns:
            (dict):
              - height (int): Height to resize image to.
              - width (int): Width to resize image to.
        """
        if self.resize_smallest_side is None:
            return self.resize_h, self.resize_w
        if isinstance(self.resize_smallest_side, int):
            resize_smallest_height, resize_smallest_width = self.resize_smallest_side, self.resize_smallest_side
        else:
            resize_smallest_height, resize_smallest_width = self.resize_smallest_side
        # Scale the smaller side to its target while preserving the input
        # aspect ratio; the comparison picks which side is the constraint.
        if height * resize_smallest_width <= width * resize_smallest_height:
            new_height = resize_smallest_height
            new_width = int(np.round(new_height * width / float(height)))
        else:
            new_width = resize_smallest_width
            new_height = int(np.round(new_width * height / float(width)))
        return new_height, new_width

    def _perform_unpaired_augmentation(self, inputs, augment_ops):
        r"""Perform different data augmentation on different image inputs.

        Each data type is augmented independently (each gets its own random
        parameters). Note: mutates ``inputs`` in place.

        Args:
            inputs (dict): Keys are from self.image_data_types. Values are list
                of numpy.ndarray (list of images).
            augment_ops (list): The augmentation operations.
        Returns:
            (dict):
              - augmented (dict): Augmented inputs, with same keys as inputs.
              - is_flipped (dict): Flag which tells if images have been LR flipped,
                per data type.
        """
        # Process each data type separately as this is unpaired augmentation.
        is_flipped = {}
        for data_type in inputs:
            assert data_type in self.image_data_types
            augmented, flipped_flag = self._perform_paired_augmentation(
                {data_type: inputs[data_type]}, augment_ops)
            inputs[data_type] = augmented[data_type]
            is_flipped[data_type] = flipped_flag
        return inputs, is_flipped

    def _perform_paired_augmentation(self, inputs, augment_ops):
        r"""Perform same data augmentation on all inputs.

        Side effects: sets self.original_h/w, self.resize_h/w and
        self.is_flipped; converts PIL inputs to numpy arrays in place.

        Args:
            inputs (dict): Keys are from self.augmentable_data_types. Values are
                list of numpy.ndarray (list of images).
            augment_ops (list): The augmentation operations.
        Returns:
            (dict):
              - augmented (dict): Augmented inputs, with same keys as inputs.
              - is_flipped (bool): Flag which tells if images have been LR flipped.
        """
        # Different data types may have different sizes and we use the largest one as the original size.
        # Convert PIL images to numpy array.
        self.original_h, self.original_w = 0, 0
        for data_type in inputs:
            if data_type in self.keypoint_data_types or \
                    data_type not in self.image_data_types:
                continue
            for idx in range(len(inputs[data_type])):
                value = inputs[data_type][idx]
                # Get resize h, w.
                w, h = get_image_size(value)
                self.original_h, self.original_w = max(self.original_h, h), max(self.original_w, w)
                # self.original_h, self.original_w = h, w
                # self.resize_h, self.resize_w = self._get_resize_h_w(h, w)
                # Convert to numpy array with 3 dims (H, W, C).
                value = np.array(value)
                if value.ndim == 2:
                    value = value[..., np.newaxis]
                inputs[data_type][idx] = value
        self.resize_h, self.resize_w = self._get_resize_h_w(self.original_h, self.original_w)
        # Add resize op to augmentation ops.
        # NOTE(review): always_apply=1 is a truthy int where albumentations
        # expects a bool (True) — works, but confirm before relying on it.
        aug_ops_with_resize = [alb.Resize(
            self.resize_h, self.resize_w, interpolation=getattr(cv2, self.interpolator), always_apply=1, p=1
        )] + augment_ops
        # Create targets.
        targets, new_inputs = self._create_augmentation_targets(inputs)
        extra_params = {}
        # Albumentation requires a key called 'image' and
        # a key called 'keypoints', if any keypoints are being passed in.
        # Arbitrarily choose one key of image type to be 'image'.
        chosen_image_key = self._choose_image_key(inputs)
        new_inputs['image'] = new_inputs.pop(chosen_image_key)
        targets['image'] = targets.pop(chosen_image_key)
        # Arbitrarily choose one key of keypoint type to be 'keypoints'.
        chosen_keypoint_key = self._choose_keypoint_key(inputs)
        if chosen_keypoint_key is not None:
            new_inputs['keypoints'] = new_inputs.pop(chosen_keypoint_key)
            targets['keypoints'] = targets.pop(chosen_keypoint_key)
            extra_params['keypoint_params'] = alb.KeypointParams(
                format='xy', remove_invisible=False)
        # Do augmentation. ReplayCompose records the applied parameters so we
        # can inspect them (e.g. whether the flip actually fired).
        augmented = alb.ReplayCompose(
            aug_ops_with_resize, additional_targets=targets,
            **extra_params)(**new_inputs)
        augmentation_params = augmented.pop('replay')
        # Check if flipping has occurred.
        is_flipped = False
        for augmentation_param in augmentation_params['transforms']:
            if 'HorizontalFlip' in augmentation_param['__class_fullname__']:
                is_flipped = augmentation_param['applied']
        self.is_flipped = is_flipped
        # Replace the key 'image' with chosen_image_key, same for 'keypoints'.
        augmented[chosen_image_key] = augmented.pop('image')
        if chosen_keypoint_key is not None:
            augmented[chosen_keypoint_key] = augmented.pop('keypoints')
        # Pack images back into a sequence.
        augmented = self._collate_augmented(augmented)
        # Convert keypoint types to np.array from list.
        for data_type in self.keypoint_data_types:
            augmented[data_type] = np.array(augmented[data_type])
        return augmented, is_flipped

    def perform_augmentation(self, inputs, paired, augment_ops):
        r"""Entry point for augmentation.

        Args:
            inputs (dict): Keys are from self.augmentable_data_types. Values are
                list of numpy.ndarray (list of images).
            paired (bool): Apply same augmentation to all input keys?
            augment_ops (list): The augmentation operations.
        Returns:
            (tuple): Augmented inputs and the is-flipped flag(s), as returned
                by the paired/unpaired augmentation helpers.
        """
        # Make sure that all inputs are of same size, else trouble will
        # ensue. This is because different images might have different
        # aspect ratios.
        # Check within data type.
        # NOTE(review): the size assertions below are commented out, so this
        # loop currently computes sizes without enforcing anything.
        for data_type in inputs:
            if data_type in self.keypoint_data_types or \
                    data_type not in self.image_data_types:
                continue
            for idx in range(len(inputs[data_type])):
                if idx == 0:
                    w, h = get_image_size(inputs[data_type][idx])
                else:
                    this_w, this_h = get_image_size(inputs[data_type][idx])
                    # assert this_w == w and this_h == h
                    # assert this_w / (1.0 * this_h) == w / (1.0 * h)
        # Check across data types.
        # NOTE(review): this loop body only filters and continues — it has no
        # effect. Presumably a cross-data-type size check was removed.
        if paired and self.resize_smallest_side is not None:
            for idx, data_type in enumerate(inputs):
                if data_type in self.keypoint_data_types or \
                        data_type not in self.image_data_types:
                    continue
        if paired:
            return self._perform_paired_augmentation(inputs, augment_ops)
        else:
            return self._perform_unpaired_augmentation(inputs, augment_ops)
def load_from_lmdb(keys, lmdbs):
    r"""Load keys from lmdb handles.

    Args:
        keys (dict): This has data_type as key, and a list of paths into LMDB as
            values.
        lmdbs (dict): This has data_type as key, and LMDB handle as value.
    Returns:
        data (dict): This has data_type as key, and a list of decoded items from
            LMDBs as value.
    """
    data = {}
    for data_type, paths in keys.items():
        # A single path is allowed as shorthand for a one-element list.
        if not isinstance(paths, list):
            paths = [paths]
        handle = lmdbs[data_type]
        # LMDB keys are stored as bytes, hence the encode().
        data[data_type] = [handle.getitem_by_path(path.encode(), data_type)
                           for path in paths]
    return data
def load_from_folder(keys, handles):
    r"""Load keys from folder handles.

    Args:
        keys (dict): This has data_type as key, and a list of paths as
            values.
        handles (dict): This has data_type as key, and Folder handle as value.
    Returns:
        data (dict): This has data_type as key, and a list of decoded items from
            folders as value.
    """
    data = {}
    for data_type in keys:
        paths = keys[data_type]
        # Accept a bare path as a one-element list.
        paths = paths if isinstance(paths, list) else [paths]
        items = []
        for path in paths:
            # Handles expect byte keys, matching the LMDB interface.
            items.append(
                handles[data_type].getitem_by_path(path.encode(), data_type))
        data[data_type] = items
    return data
def load_from_object_store(keys, handles):
    r"""Load keys from AWS S3 handles.

    Args:
        keys (dict): This has data_type as key, and a list of paths as
            values.
        handles (dict): This has data_type as key, and object-store handle
            as value.
    Returns:
        data (dict): This has data_type as key, and a list of decoded items from
            the object store as value.
    """
    data = {}
    for data_type, paths in keys.items():
        if not isinstance(paths, list):
            paths = [paths]
        loaded = []
        for key in paths:
            # Retry indefinitely on transient object-store failures,
            # backing off 30 seconds between attempts.
            while True:
                try:
                    item = handles[data_type].getitem_by_path(key, data_type)
                except Exception as e:
                    print(e)
                    print(key, data_type)
                    print('Retrying in 30 seconds')
                    time.sleep(30)
                else:
                    loaded.append(item)
                    break
        data[data_type] = loaded
    return data
def get_paired_input_image_channel_number(data_cfg):
    r"""Get number of channels for the input image.

    Sums ``num_channels`` over every input type whose key is listed in
    ``data_cfg.input_image``.

    Args:
        data_cfg (obj): Data configuration structure with ``input_types``
            (list of dicts mapping key -> cfg with ``num_channels``) and
            ``input_image`` (collection of image keys).
    Returns:
        num_channels (int): Number of input image channels.
    """
    num_channels = 0
    # The previous version used enumerate() but never used the index.
    for data_type in data_cfg.input_types:
        for k in data_type:
            if k in data_cfg.input_image:
                num_channels += data_type[k].num_channels
                print('Concatenate %s for input.' % data_type)
    print('\tNum. of channels in the input image: %d' % num_channels)
    return num_channels
def get_paired_input_label_channel_number(data_cfg, video=False):
    r"""Get number of channels for the input label map.

    Args:
        data_cfg (obj): Data configuration structure.
        video (bool): Whether we are dealing with video data.
    Returns:
        num_labels (int): Number of input label map channels.
    """
    num_labels = 0
    # Configs without labels (e.g. unconditional models) contribute nothing.
    if not hasattr(data_cfg, 'input_labels'):
        return num_labels
    # The previous version used enumerate() but never used the index.
    for data_type in data_cfg.input_types:
        for k in data_type:
            if k not in data_cfg.input_labels:
                continue
            if hasattr(data_cfg, 'one_hot_num_classes') and k in data_cfg.one_hot_num_classes:
                # Label will be expanded into a one-hot representation.
                num_labels += data_cfg.one_hot_num_classes[k]
                if getattr(data_cfg, 'use_dont_care', False):
                    # One extra channel for the "don't care" class.
                    num_labels += 1
            else:
                num_labels += data_type[k].num_channels
            print('Concatenate %s for input.' % data_type)
    if video:
        # NOTE(review): a missing initial_sequence_length yields None here and
        # the arithmetic below would raise TypeError — confirm whether video
        # configs always define it.
        num_time_steps = getattr(data_cfg.train, 'initial_sequence_length',
                                 None)
        num_labels *= num_time_steps
        num_labels += get_paired_input_image_channel_number(data_cfg) * (
            num_time_steps - 1)
    print('\tNum. of channels in the input label: %d' % num_labels)
    return num_labels
def get_class_number(data_cfg):
    r"""Get number of classes for a class-conditional GAN model.

    Args:
        data_cfg (obj): Data configuration structure.
    Returns:
        (int): Number of classes.
    """
    # Thin accessor kept for API symmetry with the other get_* helpers.
    return data_cfg.num_classes
def get_crop_h_w(augmentation):
    r"""Get height and width of crop.

    Scans the augmentation config's attributes for any key containing
    'crop_h_w' (e.g. 'random_crop_h_w', 'center_crop_h_w') and parses its
    'H,W' string value.

    Args:
        augmentation (dict): Dict of applied augmentations. Must support both
            attribute enumeration (``__dict__``) and item access
            (``augmentation[k]``), as an AttrDict-style config does.
    Returns:
        (dict):
          - crop_h (int): Height of the image crop.
          - crop_w (int): Width of the image crop.
    Raises:
        AttributeError: If no crop_h_w-style key is present.
    """
    print(augmentation.__dict__.keys())
    for key in augmentation.__dict__.keys():
        if 'crop_h_w' in key:
            # Fixed the 'filed' typo of the previous version.
            field = augmentation[key]
            crop_h, crop_w = field.split(',')
            crop_h = int(crop_h)
            crop_w = int(crop_w)
            # NOTE(review): an earlier (commented-out) assertion required
            # square crops; non-square crops are currently accepted.
            print('\tCrop size: (%d, %d)' % (crop_h, crop_w))
            return crop_h, crop_w
    # Give the bare AttributeError of the old version a useful message.
    raise AttributeError('No crop_h_w key found in augmentation config.')
def get_image_size(x):
    r"""Get the (width, height) of an image.

    Args:
        x: Either a PIL.Image-like object exposing a ``size`` tuple, or a
            numpy-style array of shape (H, W, C).
    Returns:
        (tuple): width, height of the image.
    """
    try:
        # PIL images expose .size as a (width, height) tuple.
        width, height = x.size
    except Exception:
        # Fall back to array-style (H, W, C) shape. Note numpy arrays also
        # have .size, but it is a scalar, so unpacking it lands here.
        height, width, _ = x.shape
    return width, height