from typing import Dict, Any, Iterable
import numbers
import yaml
import os

# Load globals.yaml once at import time; note that this module-level name
# shadows the built-in globals().
with open('globals.yaml', 'r') as stream:
    globals = yaml.load(stream, Loader=yaml.FullLoader)


def _prepare_config(config) -> Dict[str, Any]:
    """
    Returns the final configuration used by the model. May be used to clean user input,
    sort items, etc.
    """
    classes_per_task = {'detection': 1, 'grading': 6, 'simple_grading': 4}
    NUM_CLASSES = classes_per_task[config["TASK"]]

    return {
        "learning_rate": config["LEARNING_RATE"],
        "model_name": config["MODEL_NAME"],
        "dataset": "VerSe",
        "batch_size": config["BATCH_SIZE"],
        "input_size": config["INPUT_SIZE"],
        "input_dim": config["INPUT_DIM"],
        "mask": config["MASK"],
        "coordinates": config["COORDINATES"],
        "oversampling": config["OVERSAMPLING"],
        "fold": config["FOLD"],
        "dropout": config["DROPOUT"],
        "frozen_layers": config["FROZEN_LAYERS"],
        "num_classes": NUM_CLASSES,
        "early_stopping_patience": config["EARLY_STOPPING_PATIENCE"],
        "min_vertebrae_level": config["MIN_VERTEBRAE_LEVEL"],
        "dataset_path": config["DATASET_PATH"],
        "loss": config["LOSS"],
        "weighted_loss": config["WEIGHTED_LOSS"],
        "transforms": sorted(config["TRANSFORMS"]),
        "task": config["TASK"],
        # passed through, will not be part of final hyperparameters
        "USE_WANDB": config["USE_WANDB"],
        "WANDB_API_KEY": config["WANDB_API_KEY"],
        "SAVE_MODEL": config["SAVE_MODEL"]
    }
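
# For reference, a raw config file is expected to provide the uppercase keys
# read above. The values below are purely illustrative placeholders, not
# project defaults:
#
#   TASK: detection
#   MODEL_NAME: UNet3D
#   LEARNING_RATE: 0.0001
#   BATCH_SIZE: 8
#   INPUT_SIZE: 128
#   INPUT_DIM: 3
#   MASK: channel
#   COORDINATES: false
#   OVERSAMPLING: false
#   FOLD: 0
#   DROPOUT: 0.0
#   FROZEN_LAYERS: []
#   EARLY_STOPPING_PATIENCE: 10
#   MIN_VERTEBRAE_LEVEL: 1
#   DATASET_PATH: /path/to/dataset
#   LOSS: binary_cross_entropy
#   WEIGHTED_LOSS: false
#   TRANSFORMS: []
#   USE_WANDB: false
#   WANDB_API_KEY: ""
#   SAVE_MODEL: false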


def _sanity_check_config(config):
    """
    Runs simple assertions to test that the config file is actually valid.
    """
    # option validity assertions
    assert config['mask'] in ['none', 'channel', 'apply', 'apply_all', 'crop']
    assert os.path.exists(config["dataset_path"])

    # datatype assertions
    for numeric_key in ["batch_size", "input_size", "input_dim", "dropout",
                        "early_stopping_patience", "min_vertebrae_level", "fold"]:
        assert isinstance(config[numeric_key], numbers.Number)
    for list_key in ["transforms", "frozen_layers"]:
        assert isinstance(config[list_key], Iterable)

    # logic assertions
    assert not (config['task'] == 'detection'
                and config['loss'] not in ['binary_cross_entropy', 'focal'])

    # ensure models fit the data
    nets_3d = ["UNet3D", "ModelsGenesis"]
    for net_3d in nets_3d:
        assert not (net_3d in config['model_name']) or config['input_dim'] == 3

    if config['oversampling'] and config['weighted_loss']:
        print("Oversampling as well as weighted loss are enabled, you may want to disable one")
    if config['loss'] == 'focal' and config['weighted_loss']:
        print("Focal loss does not support manual class weighting")
    if config['loss'] == 'focal' and config['oversampling']:
        print("Focal loss and oversampling are enabled, you may want to disable oversampling")


def get_config(config_file_path: str) -> Dict[str, Any]:
    """
    Retrieves the configuration from the given file path after running a sanity check and
    pre-processing steps.
    """
    with open(config_file_path, 'r') as stream:
        raw_config = yaml.load(stream, Loader=yaml.FullLoader)

    config = _prepare_config(raw_config)
    _sanity_check_config(config)

    return config
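

# A minimal usage sketch. "config.yaml" is an assumed example path, not a file
# that ships with this module:
if __name__ == "__main__":
    example_config = get_config("config.yaml")
    print(f"Training {example_config['model_name']} on {example_config['dataset']} "
          f"with {example_config['num_classes']} classes")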