Spaces:
Runtime error
Runtime error
| # TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. | |
| from functools import partial | |
| from vidar.datasets.augmentations.image import \ | |
| colorjitter_sample, normalize_sample | |
| from vidar.datasets.augmentations.crop import \ | |
| crop_sample_input, crop_sample | |
| from vidar.datasets.augmentations.misc import \ | |
| duplicate_sample, mask_depth_percentage, mask_depth_number, clip_depth, mask_depth_range | |
| from vidar.datasets.augmentations.resize import resize_sample, resize_sample_input | |
| from vidar.datasets.augmentations.tensor import to_tensor_sample | |
| from vidar.datasets.utils.misc import parse_crop | |
| from vidar.utils.types import is_list | |
def train_transforms(sample, cfg):
    """
    Training data augmentation transformations

    Parameters
    ----------
    sample : Dict or list[Dict]
        Sample to be augmented. The crop step accepts either a single sample
        dictionary or a sequence of them (presumably one per camera -- TODO
        confirm against the calling datasets).
    cfg : Config
        Configuration for transformations. Optional entries read here:
        resize, resize_supervision, depth_downsample, preserve_depth,
        crop_borders, crop_random, crop_supervision, clip_depth,
        mask_depth_range, input_depth_number, input_depth_percentage,
        jittering, background, normalization

    Returns
    -------
    sample : Dict or list[Dict]
        Augmented sample
    """
    # Resize
    if cfg.has('resize'):
        resize_fn = resize_sample if cfg.has('resize_supervision') else resize_sample_input
        # An explicit list in resize_supervision gives the supervision shape;
        # any other value means "resize supervision to the input shape".
        if not cfg.has('resize_supervision'):
            shape_supervision = None
        elif is_list(cfg.resize_supervision):
            shape_supervision = cfg.resize_supervision
        else:
            shape_supervision = cfg.resize
        sample = resize_fn(sample, shape=cfg.resize, shape_supervision=shape_supervision,
                           depth_downsample=cfg.has('depth_downsample', 1.0),
                           preserve_depth=cfg.has('preserve_depth', False))
    # Crop
    if cfg.has('crop_borders') or cfg.has('crop_random'):
        crop_fn = crop_sample if cfg.has('crop_supervision') else crop_sample_input

        def _crop(single):
            # Crop parameters come from the current-frame image size, reversed
            # (assumes rgb[0] is a PIL image whose size is (width, height), so
            # this passes (height, width) -- TODO confirm).
            return crop_fn(single, parse_crop(cfg, single['rgb'][0].size[::-1]))

        # A single dict sample must be cropped directly: iterating over a dict
        # would yield its string keys and fail on ``s['rgb']``.
        if isinstance(sample, dict):
            sample = _crop(sample)
        else:
            sample = [_crop(s) for s in sample]
    # Clip depth to a maximum value
    if cfg.has('clip_depth'):
        sample = clip_depth(sample, cfg.clip_depth)
    if cfg.has('mask_depth_range'):
        sample = mask_depth_range(sample, cfg.mask_depth_range)
    # Change input depth density
    # NOTE(review): if ``sample`` is a list of dicts, this membership test is
    # always False and input-depth masking is skipped -- confirm whether list
    # samples need it.
    if 'input_depth' in sample:
        if cfg.has('input_depth_number'):
            sample['input_depth'] = mask_depth_number(
                sample['input_depth'], cfg.input_depth_number)
        if cfg.has('input_depth_percentage'):
            sample['input_depth'] = mask_depth_percentage(
                sample['input_depth'], cfg.input_depth_percentage)
    # Apply jittering (duplicate_sample is called with ['rgb'] before jittering)
    if cfg.has('jittering'):
        sample = duplicate_sample(sample, ['rgb'])
        sample = colorjitter_sample(sample, cfg.jittering, cfg.has('background', None), prob=1.0)
    # Convert to tensor
    sample = to_tensor_sample(sample)
    # Normalize with the two configured entries (normalization[0], normalization[1])
    if cfg.has('normalization'):
        sample = normalize_sample(sample, cfg.normalization[0], cfg.normalization[1])
    # Return augmented sample
    return sample
def no_transform(sample):
    """
    Identity augmentation: apply no transformation and only convert the
    sample to tensors.

    Parameters
    ----------
    sample : Dict
        Sample to be converted

    Returns
    -------
    sample : Dict
        Result of ``to_tensor_sample`` applied to the input
    """
    return to_tensor_sample(sample)
def get_transforms(mode, cfg=None):
    """
    Get data augmentation transformations for each split

    Parameters
    ----------
    mode : String {'train', 'none'}
        Mode from which we want the data augmentation transformations.
        'train' returns ``train_transforms`` bound to ``cfg``; 'none' returns
        a transform that only converts samples to tensors.
    cfg : Config
        Configuration file. Needed for 'train', because ``train_transforms``
        calls ``cfg.has`` when the transform is applied. Unused for 'none'.

    Returns
    -------
    XXX_transform: Partial function
        Data augmentation transformation for that mode

    Raises
    ------
    ValueError
        If ``mode`` is not 'train' or 'none'
    """
    if mode == 'train':
        return partial(train_transforms, cfg=cfg)
    if mode == 'none':
        return partial(no_transform)
    raise ValueError('Unknown mode {}'.format(mode))