# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.

from functools import partial

from vidar.datasets.augmentations.image import \
    colorjitter_sample, normalize_sample
from vidar.datasets.augmentations.crop import \
    crop_sample_input, crop_sample
from vidar.datasets.augmentations.misc import \
    duplicate_sample, mask_depth_percentage, mask_depth_number, clip_depth, mask_depth_range
from vidar.datasets.augmentations.resize import resize_sample, resize_sample_input
from vidar.datasets.augmentations.tensor import to_tensor_sample
from vidar.datasets.utils.misc import parse_crop
from vidar.utils.types import is_list
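
# NOTE: `cfg.has(key)` is used below as an existence check, while the two-argument
# form `cfg.has(key, default)` is assumed to return the stored value, or `default`
# when the key is missing (vidar Config convention).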

def train_transforms(sample, cfg):
    """
    Training data augmentation transformations

    Parameters
    ----------
    sample : Dict
        Sample to be augmented
    cfg : Config
        Configuration for transformations

    Returns
    -------
    sample : Dict
        Augmented sample
    """
    # Resize
    if cfg.has('resize'):
        resize_fn = resize_sample if cfg.has('resize_supervision') else resize_sample_input
        shape_supervision = None if not cfg.has('resize_supervision') else \
            cfg.resize if not is_list(cfg.resize_supervision) else cfg.resize_supervision
        sample = resize_fn(sample, shape=cfg.resize, shape_supervision=shape_supervision,
                           depth_downsample=cfg.has('depth_downsample', 1.0),
                           preserve_depth=cfg.has('preserve_depth', False))
    # Crop
    if cfg.has('crop_borders') or cfg.has('crop_random'):
        crop_fn = crop_sample if cfg.has('crop_supervision') else crop_sample_input
        sample = crop_fn(sample, parse_crop(cfg, sample['rgb'][0].size[::-1]))
    # Clip depth to a maximum value
    if cfg.has('clip_depth'):
        sample = clip_depth(sample, cfg.clip_depth)
    # Mask out depth values outside a given range
    if cfg.has('mask_depth_range'):
        sample = mask_depth_range(sample, cfg.mask_depth_range)
    # Change input depth density
    if 'input_depth' in sample:
        if cfg.has('input_depth_number'):
            sample['input_depth'] = mask_depth_number(
                sample['input_depth'], cfg.input_depth_number)
        if cfg.has('input_depth_percentage'):
            sample['input_depth'] = mask_depth_percentage(
                sample['input_depth'], cfg.input_depth_percentage)
    # Apply color jittering (RGB entries are duplicated first, so an un-jittered copy is kept)
    if cfg.has('jittering'):
        sample = duplicate_sample(sample, ['rgb'])
        sample = colorjitter_sample(sample, cfg.jittering, cfg.has('background', None), prob=1.0)
    # Convert to tensor
    sample = to_tensor_sample(sample)
    # Normalize images
    if cfg.has('normalization'):
        sample = normalize_sample(sample, cfg.normalization[0], cfg.normalization[1])
    # Return augmented sample
    return sample
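

# Illustrative augmentation config for `train_transforms` (hypothetical values,
# written YAML-style); the keys are the ones read above, but the exact values
# and schema depend on the experiment configs shipped with vidar:
#
#   augmentation:
#       resize: [192, 640]
#       depth_downsample: 0.5
#       jittering: [0.2, 0.2, 0.2, 0.05]
#       normalization: [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]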

def no_transform(sample):
    """No transformation, only convert sample to tensors"""
    sample = to_tensor_sample(sample)
    return sample

def get_transforms(mode, cfg=None):
    """
    Get data augmentation transformations for each split

    Parameters
    ----------
    mode : String {'train', 'none'}
        Mode from which we want the data augmentation transformations
    cfg : Config
        Configuration for transformations

    Returns
    -------
    transform : Partial function
        Data augmentation transformation for that mode
    """
    if mode == 'train':
        return partial(train_transforms, cfg=cfg)
    elif mode == 'none':
        return partial(no_transform)
    else:
        raise ValueError('Unknown mode {}'.format(mode))
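

# Usage sketch: how a dataset wrapper might request and apply these transforms.
# `cfg.augmentation` and `dataset` below are hypothetical stand-ins for the real
# vidar Config and dataset objects:
#
#   transform = get_transforms('train', cfg=cfg.augmentation)
#   sample = transform(dataset[0])    # resized / cropped / jittered tensors
#
#   transform = get_transforms('none')
#   sample = transform(dataset[0])    # only converted to tensors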