File size: 3,570 Bytes
fc16538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# TRI-VIDAR - Copyright 2022 Toyota Research Institute.  All rights reserved.

from functools import partial

from vidar.datasets.augmentations.image import \
    colorjitter_sample, normalize_sample
from vidar.datasets.augmentations.crop import \
    crop_sample_input, crop_sample
from vidar.datasets.augmentations.misc import \
    duplicate_sample, mask_depth_percentage, mask_depth_number, clip_depth, mask_depth_range
from vidar.datasets.augmentations.resize import resize_sample, resize_sample_input
from vidar.datasets.augmentations.tensor import to_tensor_sample
from vidar.datasets.utils.misc import parse_crop
from vidar.utils.types import is_list


def _apply_per_sample(sample, fn):
    """
    Apply a single-sample function to a sample dict, or to every element of a sequence of samples

    Parameters
    ----------
    sample : Dict or list[Dict]
        A single sample dict, or an iterable of sample dicts
    fn : Function
        Function taking one sample dict and returning the transformed sample

    Returns
    -------
    sample : Dict or list[Dict]
        ``fn(sample)`` for a dict input, otherwise a list with ``fn`` applied to each element
    """
    if isinstance(sample, dict):
        return fn(sample)
    return [fn(s) for s in sample]


def _mask_input_depth(sample, cfg):
    """
    Reduce the density of a single sample's input depth, according to the configuration

    Parameters
    ----------
    sample : Dict
        Sample dict; left unchanged if it has no 'input_depth' key
    cfg : Config
        Configuration, read for 'input_depth_number' and 'input_depth_percentage'

    Returns
    -------
    sample : Dict
        The same sample dict, with 'input_depth' replaced when masking was applied
    """
    if 'input_depth' in sample:
        if cfg.has('input_depth_number'):
            sample['input_depth'] = mask_depth_number(
                sample['input_depth'], cfg.input_depth_number)
        if cfg.has('input_depth_percentage'):
            sample['input_depth'] = mask_depth_percentage(
                sample['input_depth'], cfg.input_depth_percentage)
    return sample


def train_transforms(sample, cfg):
    """
    Training data augmentation transformations

    Steps, each applied only when its configuration entry is present:
    resize, crop, depth clipping, depth range masking, input depth densification
    (number / percentage of points), color jittering. Then the sample is
    converted to tensors and optionally normalized.

    Parameters
    ----------
    sample : Dict or list[Dict]
        Sample to be augmented. Either a single sample dict or a list of sample
        dicts (presumably one per sensor/camera — confirm against the datasets).
        The crop and input-depth steps handle both forms.
    cfg : Config
        Configuration for transformations

    Returns
    -------
    sample : Dict or list[Dict]
        Augmented sample. The crop step returns a list when given a non-dict sample.
    """
    # Resize
    if cfg.has('resize'):
        if cfg.has('resize_supervision'):
            # Supervision is resized too: a list value gives it its own shape,
            # any other value makes it follow the input shape
            resize_fn = resize_sample
            shape_supervision = cfg.resize_supervision \
                if is_list(cfg.resize_supervision) else cfg.resize
        else:
            resize_fn = resize_sample_input
            shape_supervision = None
        sample = resize_fn(sample, shape=cfg.resize, shape_supervision=shape_supervision,
                           depth_downsample=cfg.has('depth_downsample', 1.0),
                           preserve_depth=cfg.has('preserve_depth', False))
    # Crop
    if cfg.has('crop_borders') or cfg.has('crop_random'):
        crop_fn = crop_sample if cfg.has('crop_supervision') else crop_sample_input
        # Crop parameters are computed per sample from the current-frame (context 0) image.
        # Presumably rgb[0] is a PIL image here, so .size is (W, H); it is reversed to (H, W).
        sample = _apply_per_sample(
            sample, lambda s: crop_fn(s, parse_crop(cfg, s['rgb'][0].size[::-1])))
    # Clip depth to a maximum value
    if cfg.has('clip_depth'):
        sample = clip_depth(sample, cfg.clip_depth)
    if cfg.has('mask_depth_range'):
        sample = mask_depth_range(sample, cfg.mask_depth_range)
    # Change input depth density. Applied per sample so that a list of samples
    # is handled too (a plain `'input_depth' in sample` check on a list never matches).
    if cfg.has('input_depth_number') or cfg.has('input_depth_percentage'):
        sample = _apply_per_sample(sample, lambda s: _mask_input_depth(s, cfg))
    # Apply jittering (keep an unjittered copy of 'rgb' via duplicate_sample first)
    if cfg.has('jittering'):
        sample = duplicate_sample(sample, ['rgb'])
        sample = colorjitter_sample(sample, cfg.jittering, cfg.has('background', None), prob=1.0)
    # Convert to tensor
    sample = to_tensor_sample(sample)
    if cfg.has('normalization'):
        # normalization is read as a pair: (first, second) arguments of normalize_sample
        sample = normalize_sample(sample, cfg.normalization[0], cfg.normalization[1])
    # Return augmented sample
    return sample


def no_transform(sample):
    """
    Identity augmentation: apply no changes beyond converting the sample to tensors

    Parameters
    ----------
    sample : Dict
        Sample to be converted

    Returns
    -------
    sample : Dict
        Result of to_tensor_sample on the input sample
    """
    return to_tensor_sample(sample)


def get_transforms(mode, cfg=None):
    """
    Get data augmentation transformations for a given mode

    Parameters
    ----------
    mode : String {'train', 'none'}
        Which transformation to build: 'train' applies the configured
        augmentations and then converts to tensors; 'none' only converts
        the sample to tensors
    cfg : Config
        Transformation configuration. Needed for 'train' (train_transforms
        reads it immediately); unused for 'none'

    Returns
    -------
    transform : Function
        Callable that takes a sample and returns the transformed sample

    Raises
    ------
    ValueError
        If mode is neither 'train' nor 'none'
    """
    if mode == 'train':
        # Bind the configuration so the returned callable only needs the sample
        return partial(train_transforms, cfg=cfg)
    if mode == 'none':
        # No arguments to bind; the function itself is the transform
        return no_transform
    raise ValueError(f'Unknown mode {mode}')