File size: 4,106 Bytes
92f0e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from pathlib import Path
from typing import Callable, List, Optional, Tuple

from monai.transforms import Compose

from transforms.base import get_image_loading_transform, get_apply_crop_transform, get_stacking_transform
from transforms.mask import get_mask_transform
from transforms.coordinates import get_normalized_coordinates_transform
from transforms.augmentation import *
from transforms.backbone import *


def _build_transforms_composition(hparams, transform_getters: List[Callable], *initial_args) -> Tuple[Compose, List[str]]:
    """
    Builds a transforms composition from the given functions, which take the hparams and loaded keys as arguments, and
    produce a Compose containing the desired transforms. The initialization function receives the provided initial arguments.
    """
    transforms = []
    keys = []

    for i in range(0, len(transform_getters)):
        if len(keys) == 0:
            assert i == 0, f"Function {transform_getters[i]} did not yield any loaded keys."
            # initialize
            transform, keys = transform_getters[0](hparams, *initial_args)
        else:
            transform, keys = transform_getters[i](hparams, keys)
        transforms.append(transform)

    return Compose(transforms), keys

def _get_config_transform_by_name(transform_name: str) -> Callable:
    if transform_name == "intensity":
        return intensity_transform
    elif transform_name.startswith("spatial3d"):
        if "simple" in transform_name:
            return lambda hparams, loaded_keys: spatial_transform(hparams, loaded_keys, mode='simple')
        else:
            return lambda hparams, loaded_keys: spatial_transform(hparams, loaded_keys, mode='default')
    elif transform_name == "modelsgenesis":
        return models_genesis_transform
    elif transform_name == "pretrained_resnet":
        return pretrained_resnet_transform
    elif transform_name == "robustness":
        return robustness_transform
    else:
        raise ValueError(f"Unknown transform: {transform_name}")

def get_training_transforms(hparams, image_dir: Path, mask_dir: Optional[Path] = None) -> Compose:
    transforms_base = [get_image_loading_transform, get_mask_transform]

    # robustness has to run early as we may need to operate on the whole volume for affine transformation and padding, 
    # which must occur prior to any cropping or normalization
    if "robustness" in hparams.transforms: transforms_base.append(_get_config_transform_by_name("robustness"))

    transforms_base.extend([get_apply_crop_transform, get_normalized_coordinates_transform])

    # preprocessing transforms must be run first
    preprocessing_transforms = ["modelsgenesis", "pretrained_resnet"]
    config_transforms = [_get_config_transform_by_name(transform_name) for transform_name in hparams.transforms if transform_name in preprocessing_transforms]
    
    # then append the rest minus the robustness transform that is run earlier
    exclusion_criterion = lambda transform_name: transform_name in preprocessing_transforms or transform_name == "robustness"
    config_transforms.extend([_get_config_transform_by_name(transform_name) for transform_name in hparams.transforms if not exclusion_criterion])

    # the stacking transform must not occur before config transforms are run to avoid any interference
    return _build_transforms_composition(hparams, transforms_base + config_transforms + [get_stacking_transform], image_dir, mask_dir)[0]

def get_base_transforms(hparams, image_dir: Path, mask_dir: Optional[Path] = None) -> Compose:
    transforms_base = [get_image_loading_transform, get_mask_transform, get_apply_crop_transform, get_normalized_coordinates_transform]

    # apply preprocessing transforms
    preprocessing_transforms = ["modelsgenesis", "pretrained_resnet"]
    config_transforms = [_get_config_transform_by_name(transform_name) for transform_name in hparams.transforms if transform_name in preprocessing_transforms]

    return _build_transforms_composition(hparams, transforms_base + config_transforms + [get_stacking_transform], image_dir, mask_dir)[0]