File size: 3,390 Bytes
92f0e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from typing import Dict, Any, Iterable
import numbers
import yaml
import os

# Load the project-wide settings from 'globals.yaml' as a side effect of importing
# this module. The path is relative, so it is resolved against the current working
# directory of the process, not against the location of this file.
# NOTE(review): the name `globals` shadows the built-in `globals()` inside this
# module; it is left unchanged here because other modules may import it by name.
# NOTE(review): yaml.FullLoader is not intended for untrusted input - consider
# yaml.safe_load if this file only ever holds plain data (confirm before changing).
with open('globals.yaml', 'r') as stream:
        globals = yaml.load(stream, Loader=yaml.FullLoader)

def _prepare_config(config) -> Dict[str, Any]:
    """
    Returns the final configuration used by the model. May be used to clean user input,
    sort items, etc.
    """

    classes_per_task = {'detection': 1, 'grading': 6, 'simple_grading': 4}
    NUM_CLASSES = classes_per_task[config["TASK"]]

    return {
        "learning_rate": config["LEARNING_RATE"],
        "model_name": config["MODEL_NAME"],
        "dataset": "VerSe",
        "batch_size": config["BATCH_SIZE"], 
        "input_size": config["INPUT_SIZE"],
        "input_dim": config["INPUT_DIM"],
        "mask": config["MASK"],
        "coordinates": config["COORDINATES"],
        "oversampling": config["OVERSAMPLING"],
        "fold": config["FOLD"],
        "dropout": config["DROPOUT"],
        "frozen_layers": config["FROZEN_LAYERS"],
        "num_classes": NUM_CLASSES,
        "early_stopping_patience": config["EARLY_STOPPING_PATIENCE"],
        "min_vertebrae_level": config["MIN_VERTEBRAE_LEVEL"],
        "dataset_path": config["DATASET_PATH"],
        "loss": config["LOSS"], 
        "weighted_loss": config["WEIGHTED_LOSS"],
        "transforms": sorted(config["TRANSFORMS"]),
        "task": config["TASK"],

        # passed through, will not be part of final hyperparameters
        "USE_WANDB": config["USE_WANDB"],
        "WANDB_API_KEY": config["WANDB_API_KEY"],
        "SAVE_MODEL": config["SAVE_MODEL"]
    }

def _sanity_check_config(config):
    """
    Runs simple assertions to test that the config file is actually valid.
    """
    # option validity assertions
    assert any([config['mask'] == o for o in ['none', 'channel', 'apply', 'apply_all', 'crop']])
    assert os.path.exists(config["dataset_path"])

    # datatype assertions
    for numeric_key in ["batch_size", "input_size", "input_dim", "dropout", "early_stopping_patience", "min_vertebrae_level", "fold"]:
        assert isinstance(config[numeric_key], numbers.Number)

    for list_key in ["transforms", "frozen_layers"]:
        assert isinstance(config[list_key], Iterable)

    # logic assertions
    assert not (config['task'] == 'detection' and config['loss'] != 'binary_cross_entropy' and config['loss'] != 'focal')

    # ensure models fit the data
    nets_3d = ["UNet3D", "ModelsGenesis"]
    for net_3d in nets_3d:
        assert not (net_3d in config['model_name']) or config['input_dim'] == 3

    if config['oversampling'] and config['weighted_loss']:
        print("Oversampling as well as weighted loss are enabled, you may want to disable one")

    if config['loss'] == 'focal' and config['weighted_loss']:
        print("Focal loss does not support manual class weighting")

    if config['loss'] == 'focal' and config['oversampling']:
        print("Focal loss and oversampling are enabled, you may want to disable oversampling")

def get_config(config_file_path: str):
    """
    Load the YAML configuration stored at `config_file_path` and return it in
    its prepared form.

    The parsed file content is passed through `_prepare_config` and the result
    is validated with `_sanity_check_config` before it is returned.
    """
    # NOTE(review): yaml.FullLoader is not intended for untrusted input; if
    # config files may come from outside the project, consider yaml.safe_load.
    with open(config_file_path, 'r') as config_file:
        raw_config = yaml.load(config_file, Loader=yaml.FullLoader)

    prepared_config = _prepare_config(raw_config)
    _sanity_check_config(prepared_config)
    return prepared_config