# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.

from functools import partial

from vidar.datasets.augmentations.image import \
colorjitter_sample, normalize_sample
from vidar.datasets.augmentations.crop import \
crop_sample_input, crop_sample
from vidar.datasets.augmentations.misc import \
duplicate_sample, mask_depth_percentage, mask_depth_number, clip_depth, mask_depth_range
from vidar.datasets.augmentations.resize import resize_sample, resize_sample_input
from vidar.datasets.augmentations.tensor import to_tensor_sample
from vidar.datasets.utils.misc import parse_crop
from vidar.utils.types import is_list


def train_transforms(sample, cfg):
"""
Training data augmentation transformations
Parameters
----------
sample : Dict
Sample to be augmented
cfg : Config
Configuration for transformations
Returns
-------
sample : Dict
Augmented sample
"""
# Resize
if cfg.has('resize'):
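        # Supervision labels are resized only when 'resize_supervision' is set:
        # to cfg.resize_supervision if it is given as a shape, otherwise to the input shape cfg.resize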
resize_fn = resize_sample if cfg.has('resize_supervision') else resize_sample_input
shape_supervision = None if not cfg.has('resize_supervision') else \
cfg.resize if not is_list(cfg.resize_supervision) else cfg.resize_supervision
sample = resize_fn(sample, shape=cfg.resize, shape_supervision=shape_supervision,
depth_downsample=cfg.has('depth_downsample', 1.0),
preserve_depth=cfg.has('preserve_depth', False))
# Crop
if cfg.has('crop_borders') or cfg.has('crop_random'):
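        # Supervision labels are cropped as well only when 'crop_supervision' is set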
crop_fn = crop_sample if cfg.has('crop_supervision') else crop_sample_input
        sample = crop_fn(sample, parse_crop(cfg, sample['rgb'][0].size[::-1]))
# Clip depth to a maximum value
if cfg.has('clip_depth'):
sample = clip_depth(sample, cfg.clip_depth)
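    # Mask out depth values outside the given range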
if cfg.has('mask_depth_range'):
sample = mask_depth_range(sample, cfg.mask_depth_range)
# Change input depth density
if 'input_depth' in sample:
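        # Sparsify input depth to a fixed number of valid points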
if cfg.has('input_depth_number'):
sample['input_depth'] = mask_depth_number(
sample['input_depth'], cfg.input_depth_number)
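        # Sparsify input depth to a percentage of its valid points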
if cfg.has('input_depth_percentage'):
sample['input_depth'] = mask_depth_percentage(
sample['input_depth'], cfg.input_depth_percentage)
# Apply jittering
if cfg.has('jittering'):
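        # Duplicate RGB images so an un-jittered copy is kept alongside the jittered one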
sample = duplicate_sample(sample, ['rgb'])
sample = colorjitter_sample(sample, cfg.jittering, cfg.has('background', None), prob=1.0)
# Convert to tensor
sample = to_tensor_sample(sample)
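    # Normalize images with the given mean and standard deviation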
if cfg.has('normalization'):
sample = normalize_sample(sample, cfg.normalization[0], cfg.normalization[1])
# Return augmented sample
    return sample


def no_transform(sample):
"""No transformation, only convert sample to tensors"""
sample = to_tensor_sample(sample)
    return sample


def get_transforms(mode, cfg=None):
"""
Get data augmentation transformations for each split
Parameters
----------
mode : String {'train', 'validation', 'test'}
Mode from which we want the data augmentation transformations
cfg : Config
Configuration file
Returns
-------
XXX_transform: Partial function
Data augmentation transformation for that mode
"""
if mode == 'train':
return partial(train_transforms, cfg=cfg)
elif mode == 'none':
return partial(no_transform)
else:
raise ValueError('Unknown mode {}'.format(mode))
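

# Example usage (a minimal sketch; assumes `cfg` is a vidar Config node exposing the
# augmentation fields referenced above, e.g. cfg.resize or cfg.jittering):
#
#   transform = get_transforms('train', cfg)
#   augmented_sample = transform(sample)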