File size: 3,256 Bytes
3b554c2
 
 
 
 
 
 
 
 
a27d55f
3b554c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286a978
 
 
 
 
3b554c2
 
 
 
 
 
 
 
 
 
 
 
 
286a978
3b554c2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Make the repository root importable when this script is executed directly
# (scripts cannot use package-relative imports).
import os, sys
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

import argparse
from configparser import ConfigParser
from datetime import datetime

import ddmr.utils.constants as C  # module-level augmentation flags, overwritten below from the .ini

# NOTE(review): TRAIN_DATASET is never referenced in this script — the actual
# training path comes from the [DATASETS] section of the .ini; confirm before removing.
TRAIN_DATASET = '/mnt/EncryptedData1/Users/javier/ext_datasets/IXI_dataset/T1/training'

# NOTE(review): 'err' is never used in this script — candidate for removal.
err = list()

if __name__ == '__main__':
    # Safe replacement for eval(): parses Python literals (ints, floats,
    # tuples, None) from .ini values without executing arbitrary code.
    from ast import literal_eval

    parser = argparse.ArgumentParser()
    # required=True: a missing --ini previously surfaced as an opaque
    # TypeError inside ConfigParser.read(None).
    parser.add_argument('--ini', help='Configuration file', required=True)
    args = parser.parse_args()

    configFile = ConfigParser()
    configFile.read(args.ini)
    print('Loaded configuration file: ' + args.ini)
    print({section: dict(configFile[section]) for section in configFile.sections()})
    print('\n\n')

    trainConfig = configFile['TRAIN']
    lossesConfig = configFile['LOSSES']
    datasetConfig = configFile['DATASETS']
    othersConfig = configFile['OTHERS']
    augmentationConfig = configFile['AUGMENTATION']

    # Comma-separated loss names, e.g. "ncc,ssim".
    simil = lossesConfig['similarity'].split(',')
    segm = lossesConfig['segmentation'].split(',')

    # Select the training entry point. The uncertainty-weighted trainer
    # takes the full loss lists; the others use only the first loss of each.
    model_name = trainConfig['name'].lower()
    if model_name == 'uw':
        from Brain_study.Train_UncertaintyWeighted import launch_train
        loss_config = {'simil': simil, 'segm': segm}
    elif model_name == 'segguided':
        from Brain_study.Train_SegmentationGuided import launch_train
        loss_config = {'simil': simil[0], 'segm': segm[0]}
    else:
        from Brain_study.Train_Baseline import launch_train
        loss_config = {'simil': simil[0]}

    # Unique run folder: <name>_Lsim_<losses>__Lseg_<losses>_<HHMMSS-DDMMYYYY>
    output_folder = os.path.join(othersConfig['outputFolder'],
                                 '{}_Lsim_{}__Lseg_{}'.format(trainConfig['name'], '_'.join(simil), '_'.join(segm)))
    output_folder = output_folder + '_' + datetime.now().strftime("%H%M%S-%d%m%Y")

    print('TRAIN ' + datasetConfig['train'])

    # Toggle global augmentation flags; anything other than the literal
    # string 'true' (case-insensitive) disables the augmentation.
    if augmentationConfig:
        C.GAMMA_AUGMENTATION = augmentationConfig['gamma'].lower() == 'true'
        C.BRIGHTNESS_AUGMENTATION = augmentationConfig['brightness'].lower() == 'true'

    # Optional architecture overrides; fall back to defaults when the key
    # is absent from [TRAIN].
    try:
        unet = [int(x) for x in trainConfig['unet'].split(',')]
    except KeyError:
        unet = [16, 32, 64, 128, 256]

    try:
        head = [int(x) for x in trainConfig['head'].split(',')]
    except KeyError:
        head = [16, 16]

    # SectionProxy.get returns None when the key is absent.
    resume_checkpoint = trainConfig.get('resumeCheckpoint')

    launch_train(dataset_folder=datasetConfig['train'],
                 validation_folder=datasetConfig['validation'],
                 output_folder=output_folder,
                 gpu_num=literal_eval(trainConfig['gpu']),
                 lr=literal_eval(trainConfig['learningRate']),
                 rw=literal_eval(trainConfig['regularizationWeight']),
                 acc_gradients=literal_eval(trainConfig['accumulativeGradients']),
                 batch_size=literal_eval(trainConfig['batchSize']),
                 max_epochs=literal_eval(trainConfig['epochs']),
                 image_size=literal_eval(trainConfig['imageSize']),
                 early_stop_patience=literal_eval(trainConfig['earlyStopPatience']),
                 unet=unet,
                 head=head,
                 resume=resume_checkpoint,
                 **loss_config)