# venite's picture
# history blame
# 25.3 kB
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out
# flake8: noqa: E712
"""Utils for handling datasets."""
import time
import numpy as np
from PIL import Image
import cv2
# from imaginaire.utils.distributed import master_only_print as print
import albumentations as alb # noqa nopep8
# Supported image file extensions: every base extension in lowercase,
# followed by the same extensions spelled in uppercase.
IMG_EXTENSIONS = (
    'jpg', 'jpeg', 'png', 'ppm', 'bmp', 'pgm', 'tif', 'tiff', 'webp',
    'JPG', 'JPEG', 'PNG', 'PPM', 'BMP', 'PGM', 'TIF', 'TIFF', 'WEBP',
)
class Augmentor(object):
    r"""Handles data augmentation using albumentations library.

    Augmentation settings are read from config mappings at construction time
    and turned into albumentations ops. A resize op is prepended to the
    pipeline every time paired augmentation is performed.
    """

    def __init__(self, aug_list, individual_video_frame_aug_list, image_data_types, is_mask,
                 keypoint_data_types, interpolator):
        r"""Initializes augmentation pipeline.

        Args:
            aug_list (dict): Augmentation name -> parameters, applied in
                sequence to the inputs.
            individual_video_frame_aug_list (dict): Augmentation name ->
                parameters, applied to individual frames of videos
                independently.
            image_data_types (list): List of keys in expected inputs.
            is_mask (dict): Whether each image data type holds discrete masks.
            keypoint_data_types (list): List of keys which are keypoints.
            interpolator (str): Name of the OpenCV interpolation flag used by
                the resize op, resolved with ``getattr(cv2, interpolator)``.

        Raises:
            ValueError: If crop_h_w, resize_h_w and resize_smallest_side are
                all missing from the configuration.
        """
        self.aug_list = aug_list
        self.individual_video_frame_aug_list = individual_video_frame_aug_list
        self.image_data_types = image_data_types
        self.is_mask = is_mask

        self.crop_h, self.crop_w = None, None
        self.resize_h, self.resize_w = None, None
        self.resize_smallest_side = None
        self.max_time_step = 1
        self.keypoint_data_types = keypoint_data_types
        self.interpolator = interpolator
        # Building the ops also fills in the crop / resize / max_time_step
        # attributes above as a side effect.
        self.augment_ops = self._build_augmentation_ops()
        self.individual_video_frame_augmentation_ops = self._build_individual_video_frame_augmentation_ops()

        # Both crop and resize can't be none at the same time.
        if self.crop_h is None and self.resize_smallest_side is None and \
                self.resize_h is None:
            raise ValueError('resize_smallest_side, resize_h_w, '
                             'and crop_h_w cannot all be missing.')
        # If resize_smallest_side is given, resize_h_w should not be given.
        if self.resize_smallest_side is not None:
            assert self.resize_h is None, \
                'Cannot have both `resize_smallest_side` and `resize_h_w` set.'
        # Without any explicit resize setting, resize straight to the crop size.
        if self.resize_smallest_side is None and self.resize_h is None:
            self.resize_h, self.resize_w = self.crop_h, self.crop_w

    def _build_individual_video_frame_augmentation_ops(self):
        r"""Builds sequence of augmentation ops that will be applied to each frame in the video independently.

        Supported keys of ``self.individual_video_frame_aug_list``:
            - ``random_scale_limit``: either a number ``s`` (scale limit
              ``(-s, s)``, always applied) or a dict with ``scale_limit_lb``,
              ``scale_limit_ub`` and ``p``.
            - ``random_crop_h_w``: ``'H,W'`` string; pads if needed, then
              randomly crops to H x W. Also sets ``self.crop_h/crop_w``.
        Other keys are ignored.

        Returns:
            (list of alb.ops): List of augmentation ops.
        """
        augs = []
        for key, value in self.individual_video_frame_aug_list.items():
            if key == 'random_scale_limit':
                # Accept any plain number (ints and float subclasses such as
                # numpy.float64 included); everything else must be a dict.
                if isinstance(value, (int, float)):
                    scale_limit_lb = scale_limit_ub = value
                    p = 1
                else:
                    scale_limit_lb = value['scale_limit_lb']
                    scale_limit_ub = value['scale_limit_ub']
                    p = value['p']
                augs.append(alb.RandomScale(scale_limit=(-scale_limit_lb, scale_limit_ub), p=p))
            elif key == 'random_crop_h_w':
                h, w = value.split(',')
                h, w = int(h), int(w)
                self.crop_h, self.crop_w = h, w
                augs.append(alb.PadIfNeeded(min_height=h, min_width=w))
                augs.append(alb.RandomCrop(h, w, always_apply=True, p=1))
        return augs

    def _build_augmentation_ops(self):
        r"""Builds sequence of augmentation ops.

        Parses ``self.aug_list`` in order. Besides building ops, this records
        ``resize_smallest_side``, ``resize_h/resize_w``, ``crop_h/crop_w`` and
        ``max_time_step`` on the instance.

        Returns:
            (list of alb.ops): List of augmentation ops.

        Raises:
            ValueError: On an unknown augmentation key.
        """
        augs = []
        for key, value in self.aug_list.items():
            if key == 'resize_smallest_side':
                if isinstance(value, int):
                    self.resize_smallest_side = value
                else:
                    h, w = value.split(',')
                    h, w = int(h), int(w)
                    self.resize_smallest_side = (h, w)
            elif key == 'resize_h_w':
                h, w = value.split(',')
                h, w = int(h), int(w)
                self.resize_h, self.resize_w = h, w
            elif key == 'random_resize_h_w_aspect':
                # Format: 'H,W,(aspect_min,aspect_max)'.
                aspect_start, aspect_end = value.find('('), value.find(')')
                aspect = value[aspect_start+1:aspect_end]
                aspect_min, aspect_max = aspect.split(',')
                h, w = value[:aspect_start].split(',')[:2]
                h, w = int(h), int(w)
                aspect_min, aspect_max = float(aspect_min), float(aspect_max)
                augs.append(alb.RandomResizedCrop(
                    h, w, scale=(1, 1),
                    ratio=(aspect_min, aspect_max), always_apply=True, p=1))
                self.resize_h, self.resize_w = h, w
            elif key == 'rotate':
                augs.append(alb.Rotate(
                    limit=value, always_apply=True, p=1))
            elif key == 'random_rotate_90':
                augs.append(alb.RandomRotate90(always_apply=False, p=0.5))
            elif key == 'random_scale_limit':
                augs.append(alb.RandomScale(scale_limit=(0, value), p=1))
            elif key == 'random_crop_h_w':
                h, w = value.split(',')
                h, w = int(h), int(w)
                self.crop_h, self.crop_w = h, w
                augs.append(alb.RandomCrop(h, w, always_apply=True, p=1))
            elif key == 'center_crop_h_w':
                h, w = value.split(',')
                h, w = int(h), int(w)
                self.crop_h, self.crop_w = h, w
                augs.append(alb.CenterCrop(h, w, always_apply=True, p=1))
            elif key == 'horizontal_flip':
                # This is handled separately as we need to keep track if this
                # was applied in order to correctly modify keypoint data.
                if value:
                    augs.append(alb.HorizontalFlip(always_apply=False, p=0.5))
            # The options below including contrast, blur, motion_blur, compression, gamma
            # were used during developing face-vid2vid.
            elif key == 'contrast':
                brightness_limit = value['brightness_limit']
                contrast_limit = value['contrast_limit']
                p = value['p']
                augs.append(alb.RandomBrightnessContrast(
                    brightness_limit=brightness_limit, contrast_limit=contrast_limit, p=p))
            elif key == 'blur':
                blur_limit = value['blur_limit']
                p = value['p']
                augs.append(alb.Blur(blur_limit=blur_limit, p=p))
            elif key == 'motion_blur':
                blur_limit = value['blur_limit']
                p = value['p']
                augs.append(alb.MotionBlur(blur_limit=blur_limit, p=p))
            elif key == 'compression':
                quality_lower = value['quality_lower']
                p = value['p']
                augs.append(alb.ImageCompression(quality_lower=quality_lower, p=p))
            elif key == 'gamma':
                gamma_limit_lb = value['gamma_limit_lb']
                gamma_limit_ub = value['gamma_limit_ub']
                p = value['p']
                augs.append(alb.RandomGamma(gamma_limit=(gamma_limit_lb, gamma_limit_ub), p=p))
            elif key == 'max_time_step':
                self.max_time_step = value
                assert self.max_time_step >= 1, \
                    'max_time_step has to be at least 1'
            else:
                raise ValueError('Unknown augmentation %s' % (key))
        return augs

    def _choose_image_key(self, inputs):
        r"""Choose key to replace with 'image' for input to albumentations.

        Args:
            inputs (dict): Input data keyed by data type.
        Returns:
            key (str): 'image' if present, otherwise the first key of
                ``inputs`` that is an image data type, or None if none is.
        """
        if 'image' in inputs:
            return 'image'
        for data_type in inputs:
            if data_type in self.image_data_types:
                return data_type
        return None

    def _choose_keypoint_key(self, inputs):
        r"""Choose key to replace with 'keypoints' for input to albumentations.

        Args:
            inputs (dict): Input data keyed by data type.
        Returns:
            key (str): 'keypoints' if present, otherwise the first key of
                ``inputs`` that is a keypoint data type; None if no keypoint
                types are configured or none is present.
        """
        if not self.keypoint_data_types:
            return None
        if 'keypoints' in inputs:
            return 'keypoints'
        for data_type in inputs:
            if data_type in self.keypoint_data_types:
                return data_type
        return None

    def _create_augmentation_targets(self, inputs):
        r"""Create additional targets as required by the albumentation library.

        Args:
            inputs (dict): Keys are from self.augmentable_data_types. Values can
                be numpy.ndarray or list of numpy.ndarray
                (image or list of images).
        Returns:
            (tuple):
              - targets (dict): Dict containing mapping of keys to image/mask types.
              - new_inputs (dict): Dict containing mapping of keys to data.
              The i-th (i > 0) element of a list value is stored under
              ``'<key>::%05d' % i``; the first keeps the plain key.
        Raises:
            ValueError: If a data type is neither image nor keypoint.
        """
        # Get additional target list.
        targets, new_inputs = {}, {}
        for data_type in inputs:
            if data_type in self.keypoint_data_types:
                # Keypoint-type.
                target_type = 'keypoints'
            elif data_type in self.image_data_types:
                # Image-type.
                # Find the target type (image/mask) based on interpolation
                # method.
                if self.is_mask[data_type]:
                    target_type = 'mask'
                else:
                    target_type = 'image'
            else:
                raise ValueError(
                    'Data type: %s is not image or keypoint' % (data_type))

            current_data_type_inputs = inputs[data_type]
            if not isinstance(current_data_type_inputs, list):
                current_data_type_inputs = [current_data_type_inputs]

            # Create additional_targets and inputs when there are multiples.
            for idx, new_input in enumerate(current_data_type_inputs):
                key = data_type
                if idx > 0:
                    key = '%s::%05d' % (key, idx)
                targets[key] = target_type
                new_inputs[key] = new_input

        return targets, new_inputs

    def _collate_augmented(self, augmented):
        r"""Collate separated images back into sequence, grouped by keys.

        Args:
            augmented (dict): Dict containing frames with keys of the form
                'key', 'key::00001', 'key::00002', ..., 'key::N'.
        Returns:
            outputs (dict): Dict with list of collated inputs, i.e. frames of
                same key are arranged in order ['key', 'key::00001', ..., 'key::N'].
        """
        # Sorting puts 'key' before 'key::00001' (prefix sorts first), and the
        # zero-padded indices keep frames in order.
        full_keys = sorted(augmented.keys())
        outputs = {}
        for full_key in full_keys:
            if '::' not in full_key:
                # First occurrence of this key.
                key = full_key
                outputs[key] = []
            else:
                key = full_key.split('::')[0]
            outputs[key].append(augmented[full_key])
        return outputs

    def _get_resize_h_w(self, height, width):
        r"""Get height and width to resize to, given smallest side.

        Args:
            height (int): Input image height.
            width (int): Input image width.
        Returns:
            (tuple):
              - height (int): Height to resize image to.
              - width (int): Width to resize image to.
            When ``resize_smallest_side`` is unset, ``(resize_h, resize_w)``
            is returned unchanged.
        """
        if self.resize_smallest_side is None:
            return self.resize_h, self.resize_w

        if isinstance(self.resize_smallest_side, int):
            resize_smallest_height, resize_smallest_width = self.resize_smallest_side, self.resize_smallest_side
        else:
            resize_smallest_height, resize_smallest_width = self.resize_smallest_side

        # Scale so that both sides reach at least the target, preserving the
        # input aspect ratio.
        if height * resize_smallest_width <= width * resize_smallest_height:
            new_height = resize_smallest_height
            new_width = int(np.round(new_height * width / float(height)))
        else:
            new_width = resize_smallest_width
            new_height = int(np.round(new_width * height / float(width)))
        return new_height, new_width

    def _perform_unpaired_augmentation(self, inputs, augment_ops):
        r"""Perform different data augmentation on different image inputs.

        Note that this operation only works with image-type inputs: each data
        type is augmented independently through the paired pipeline.

        Args:
            inputs (dict): Keys are from self.image_data_types. Values are list
                of numpy.ndarray (list of images).
            augment_ops (list): The augmentation operations.
        Returns:
            (tuple):
              - augmented (dict): Augmented inputs, with same keys as inputs.
              - is_flipped (dict): Flag which tells if images have been LR flipped.
        """
        # Process each data type separately as this is unpaired augmentation.
        is_flipped = {}
        for data_type in inputs:
            assert data_type in self.image_data_types
            augmented, flipped_flag = self._perform_paired_augmentation(
                {data_type: inputs[data_type]}, augment_ops)
            inputs[data_type] = augmented[data_type]
            is_flipped[data_type] = flipped_flag
        return inputs, is_flipped

    def _perform_paired_augmentation(self, inputs, augment_ops):
        r"""Perform same data augmentation on all inputs.

        Image entries of ``inputs`` are converted in place to numpy arrays with
        3 dims (H, W, C). Sets ``self.original_h/original_w``,
        ``self.resize_h/resize_w`` and ``self.is_flipped``.

        Args:
            inputs (dict): Keys are from self.augmentable_data_types. Values are
                list of numpy.ndarray (list of images).
            augment_ops (list): The augmentation operations.
        Returns:
            (tuple):
              - augmented (dict): Augmented inputs, with same keys as inputs.
              - is_flipped (bool): Flag which tells if images have been LR flipped.
        """
        # Different data types may have different sizes and we use the largest one as the original size.
        # Convert PIL images to numpy array.
        self.original_h, self.original_w = 0, 0
        for data_type in inputs:
            if data_type in self.keypoint_data_types or \
                    data_type not in self.image_data_types:
                continue
            for idx in range(len(inputs[data_type])):
                value = inputs[data_type][idx]
                # Get resize h, w.
                w, h = get_image_size(value)
                self.original_h, self.original_w = max(self.original_h, h), max(self.original_w, w)
                # Convert to numpy array with 3 dims (H, W, C).
                value = np.array(value)
                if value.ndim == 2:
                    value = value[..., np.newaxis]
                inputs[data_type][idx] = value
        self.resize_h, self.resize_w = self._get_resize_h_w(self.original_h, self.original_w)

        # Add resize op to augmentation ops.
        aug_ops_with_resize = [alb.Resize(
            self.resize_h, self.resize_w, interpolation=getattr(cv2, self.interpolator), always_apply=1, p=1
        )] + augment_ops

        # Create targets.
        targets, new_inputs = self._create_augmentation_targets(inputs)
        extra_params = {}

        # Albumentation requires a key called 'image' and
        # a key called 'keypoints', if any keypoints are being passed in.
        # Arbitrarily choose one key of image type to be 'image'.
        chosen_image_key = self._choose_image_key(inputs)
        new_inputs['image'] = new_inputs.pop(chosen_image_key)
        targets['image'] = targets.pop(chosen_image_key)
        # Arbitrarily choose one key of keypoint type to be 'keypoints'.
        chosen_keypoint_key = self._choose_keypoint_key(inputs)
        if chosen_keypoint_key is not None:
            new_inputs['keypoints'] = new_inputs.pop(chosen_keypoint_key)
            targets['keypoints'] = targets.pop(chosen_keypoint_key)
            extra_params['keypoint_params'] = alb.KeypointParams(
                format='xy', remove_invisible=False)

        # Do augmentation. ReplayCompose records which ops were applied.
        augmented = alb.ReplayCompose(
            aug_ops_with_resize, additional_targets=targets,
            **extra_params)(**new_inputs)
        augmentation_params = augmented.pop('replay')

        # Check if flipping has occurred.
        is_flipped = False
        for augmentation_param in augmentation_params['transforms']:
            if 'HorizontalFlip' in augmentation_param['__class_fullname__']:
                is_flipped = augmentation_param['applied']
        self.is_flipped = is_flipped

        # Replace the key 'image' with chosen_image_key, same for 'keypoints'.
        augmented[chosen_image_key] = augmented.pop('image')
        if chosen_keypoint_key is not None:
            augmented[chosen_keypoint_key] = augmented.pop('keypoints')

        # Pack images back into a sequence.
        augmented = self._collate_augmented(augmented)

        # Convert keypoint types to np.array from list. Only convert the
        # keypoint types actually present: unpaired augmentation passes one
        # image type at a time, so configured keypoint types may be absent.
        for data_type in self.keypoint_data_types:
            if data_type in augmented:
                augmented[data_type] = np.array(augmented[data_type])

        return augmented, is_flipped

    def perform_augmentation(self, inputs, paired, augment_ops):
        r"""Entry point for augmentation.

        Args:
            inputs (dict): Keys are from self.augmentable_data_types. Values are
                list of numpy.ndarray (list of images).
            paired (bool): Apply same augmentation to all input keys?
            augment_ops (list): The augmentation operations.
        Returns:
            (tuple): ``(augmented, is_flipped)``; ``is_flipped`` is a bool
                when paired and a dict keyed by data type when unpaired.
        """
        # Make sure that all inputs are of same size, else trouble will
        # ensue. This is because different images might have different
        # aspect ratios.
        # NOTE: the size/aspect-ratio assertions are currently disabled, so
        # these loops only check that each image's size can be read.
        # Check within data type.
        for data_type in inputs:
            if data_type in self.keypoint_data_types or \
                    data_type not in self.image_data_types:
                continue
            for value in inputs[data_type]:
                get_image_size(value)
        # Check across data types.
        if paired and self.resize_smallest_side is not None:
            for data_type in inputs:
                if data_type in self.keypoint_data_types or \
                        data_type not in self.image_data_types:
                    continue
                get_image_size(inputs[data_type][0])
        if paired:
            return self._perform_paired_augmentation(inputs, augment_ops)
        return self._perform_unpaired_augmentation(inputs, augment_ops)
def load_from_lmdb(keys, lmdbs):
    r"""Load keys from lmdb handles.

    Args:
        keys (dict): Maps each data_type to one path, or a list of paths,
            into that data type's LMDB. Paths are UTF-8 encoded before lookup.
        lmdbs (dict): Maps each data_type to an LMDB handle exposing
            ``getitem_by_path(path_bytes, data_type)``.
    Returns:
        data (dict): Maps each data_type to the list of items returned by the
            handle, in the same order as the requested paths.
    """
    data = {}
    for data_type, paths in keys.items():
        if not isinstance(paths, list):
            paths = [paths]
        bucket = data.setdefault(data_type, [])
        for path in paths:
            # The handle is looked up lazily, so an empty path list never
            # requires a handle for its data type.
            bucket.append(lmdbs[data_type].getitem_by_path(path.encode(), data_type))
    return data
def load_from_folder(keys, handles):
    r"""Load keys from folder handles.

    Args:
        keys (dict): Maps each data_type to one path, or a list of paths,
            relative to that data type's folder. Paths are UTF-8 encoded
            before lookup.
        handles (dict): Maps each data_type to a folder handle exposing
            ``getitem_by_path(path_bytes, data_type)``.
    Returns:
        data (dict): Maps each data_type to the list of decoded items, in the
            same order as the requested paths.
    """
    data = {}
    for data_type in keys:
        requested = keys[data_type]
        paths = requested if isinstance(requested, list) else [requested]
        data[data_type] = [
            handles[data_type].getitem_by_path(path.encode(), data_type)
            for path in paths
        ]
    return data
def load_from_object_store(keys, handles, max_retries=None, retry_interval=30):
    r"""Load keys from object-store (e.g. AWS S3) handles, retrying on errors.

    Args:
        keys (dict): Maps each data_type to one path, or a list of paths,
            into the object store.
        handles (dict): Maps each data_type to an object-store handle exposing
            ``getitem_by_path(path, data_type)``.
        max_retries (int or None): Number of retries per item after the first
            failed attempt. ``None`` (default) retries forever, as before.
        retry_interval (float): Seconds to sleep between attempts.
    Returns:
        data (dict): Maps each data_type to the list of fetched items, in the
            same order as the requested paths.
    Raises:
        Exception: The last error from ``getitem_by_path`` once
            ``max_retries`` retries have been exhausted.
    """
    data = {}
    for data_type in keys:
        if data_type not in data:
            data[data_type] = []
        data_type_keys = keys[data_type]
        if not isinstance(data_type_keys, list):
            data_type_keys = [data_type_keys]
        for key in data_type_keys:
            num_failures = 0
            while True:
                try:
                    item = handles[data_type].getitem_by_path(key, data_type)
                except Exception as e:  # Deliberate retry boundary for flaky remote I/O.
                    num_failures += 1
                    print(e)
                    print(key, data_type)
                    if max_retries is not None and num_failures > max_retries:
                        raise
                    print('Retrying in %s seconds' % retry_interval)
                    time.sleep(retry_interval)
                else:
                    data[data_type].append(item)
                    break
    return data
def get_paired_input_image_channel_number(data_cfg):
    r"""Sum the channel counts of every data type listed as an input image.

    Args:
        data_cfg (obj): Data configuration. ``input_types`` is a list of
            mappings from data-type name to a config with ``num_channels``;
            ``input_image`` lists the names concatenated into the input image.
    Returns:
        num_channels (int): Number of input image channels.
    """
    total = 0
    wanted = data_cfg.input_image
    for entry in data_cfg.input_types:
        matching = [name for name in entry if name in wanted]
        for name in matching:
            total += entry[name].num_channels
            print('Concatenate %s for input.' % entry)
    print('\tNum. of channels in the input image: %d' % total)
    return total
def get_paired_input_label_channel_number(data_cfg, video=False):
    r"""Get number of channels for the input label map.

    Args:
        data_cfg (obj): Data configuration structure.
        video (bool): Whether we are dealing with video data.
    Returns:
        num_channels (int): Number of input label map channels; 0 if
            ``data_cfg`` has no ``input_labels``.
    Raises:
        ValueError: If ``video`` is True but
            ``data_cfg.train.initial_sequence_length`` is not set.
    """
    num_labels = 0
    if not hasattr(data_cfg, 'input_labels'):
        return num_labels
    for data_type in data_cfg.input_types:
        for k in data_type:
            if k not in data_cfg.input_labels:
                continue
            if hasattr(data_cfg, 'one_hot_num_classes') and k in data_cfg.one_hot_num_classes:
                num_labels += data_cfg.one_hot_num_classes[k]
                # One extra channel when use_dont_care is enabled.
                if getattr(data_cfg, 'use_dont_care', False):
                    num_labels += 1
            else:
                num_labels += data_type[k].num_channels
            print('Concatenate %s for input.' % data_type)

    if video:
        num_time_steps = getattr(data_cfg.train, 'initial_sequence_length',
                                 None)
        if num_time_steps is None:
            # Previously this surfaced as an opaque TypeError from `*= None`.
            raise ValueError(
                'data_cfg.train.initial_sequence_length must be set to compute '
                'the input label channel number for video data.')
        num_labels *= num_time_steps
        # Input images of the other (num_time_steps - 1) steps are counted
        # as extra label channels.
        num_labels += get_paired_input_image_channel_number(data_cfg) * (
            num_time_steps - 1)

    print('\tNum. of channels in the input label: %d' % num_labels)
    return num_labels
def get_class_number(data_cfg):
    r"""Return the number of classes for a class-conditional GAN model.

    Args:
        data_cfg (obj): Data configuration; must expose ``num_classes``.
    Returns:
        (int): The value of ``data_cfg.num_classes``.
    """
    num_classes = data_cfg.num_classes
    return num_classes
def get_crop_h_w(augmentation):
    r"""Get height and width of crop.

    Scans the attribute names of ``augmentation`` for the first one containing
    ``'crop_h_w'`` (e.g. ``random_crop_h_w`` or ``center_crop_h_w``) and parses
    its ``'H,W'`` string value, read with item access.

    Args:
        augmentation (dict): Dict of applied augmentations (must support both
            attribute listing via ``__dict__`` and ``augmentation[key]``).
    Returns:
        (tuple):
          - crop_h (int): Height of the image crop.
          - crop_w (int): Width of the image crop.
    Raises:
        AttributeError: If no ``*crop_h_w`` entry is present.
    """
    for k in augmentation.__dict__.keys():
        if 'crop_h_w' in k:
            crop_h, crop_w = augmentation[k].split(',')
            crop_h = int(crop_h)
            crop_w = int(crop_w)
            print('\tCrop size: (%d, %d)' % (crop_h, crop_w))
            return crop_h, crop_w
    raise AttributeError(
        'No "crop_h_w" entry (e.g. random_crop_h_w or center_crop_h_w) '
        'found in the augmentation config.')
def get_image_size(x):
    r"""Return the ``(width, height)`` of an image.

    Args:
        x: A PIL-style image whose ``size`` is ``(w, h)``, or an array with
            shape ``(H, W)`` or ``(H, W, C)``.
    Returns:
        (tuple): ``(w, h)``.
    """
    try:
        w, h = x.size
    except (AttributeError, TypeError, ValueError):
        # numpy arrays expose an integer ``size`` (element count) that cannot
        # be unpacked, so fall back to the shape. Slicing the first two dims
        # also supports 2-D (grayscale) arrays, which `h, w, _ = shape` broke.
        h, w = x.shape[:2]
    return w, h