File size: 4,683 Bytes
			
			476daa5 a27e593 286a978 476daa5 a27e593 476daa5 a27e593 476daa5 a27e593 476daa5 a27e593 476daa5 a27e593 476daa5 286a978 476daa5 286a978 476daa5 a27e593 476daa5 286a978 476daa5 286a978 a27e593 476daa5 286a978 a27e593 476daa5 a27e593 286a978  | 
								1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109  | 
								import os, sys
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing
import argparse
from configparser import ConfigParser
from shutil import copy2
import os
from datetime import datetime
import DeepDeformationMapRegistration.utils.constants as C
import re
from COMET.augmentation_constants import LAYER_SELECTION
# Hard-coded training-set path. Not read by the code in this script as shown:
# the training folder passed to launch_train comes from [DATASETS] train in the .ini.
TRAIN_DATASET = ('/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/'
                 'OSLO_COMET_CT/Formatted_128x128x128/train')
# NOTE(review): module-level list that nothing in this script reads; confirm no
# external importer relies on it before removing.
err = []
def _parse_freeze_layers(raw):
    """Parse the TRAIN/freeze option into a list of layer names, or None.

    Accepted forms of ``raw``:
      * ``None`` (option absent), a blank string or the literal ``None`` -> None
      * a Python literal such as ``['ENCODER', 'DECODER']`` -> names kept as written
      * free-form names separated by commas, semicolons and/or whitespace,
        e.g. ``encoder; decoder`` -> upper-cased tokens

    A single literal value (e.g. ``'ENCODER'``) is wrapped in a list. Duplicate
    names are dropped, keeping first-seen order so runs are reproducible.
    """
    import ast  # local import: only this helper needs it

    if raw is None or not raw.strip():
        return None
    try:
        # literal_eval rather than eval: eval would execute arbitrary code from
        # the .ini file, and a bare name that happens to exist in this module
        # (e.g. "os" or "C") would evaluate to a module object instead of
        # reaching the free-form fallback below.
        layers = ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        layers = [tok for tok in re.split(r'[;,\s]+', raw.upper()) if tok]
    if layers is None:
        return None
    if not isinstance(layers, (list, tuple, set)):
        layers = [layers]
    return list(dict.fromkeys(layers))


def _parse_int_list(raw, default):
    """Return the ints in a comma-separated string such as ``"16,32,64"``.

    A copy of ``default`` is returned when ``raw`` is None (option absent) or
    blank. Empty items (e.g. from a trailing comma) are ignored.
    """
    if raw is None or not raw.strip():
        return list(default)
    return [int(tok) for tok in raw.split(',') if tok.strip()]


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--ini', required=True, help='Configuration file')
    args = parser.parse_args()

    configFile = ConfigParser()
    # ConfigParser.read() silently skips files it cannot open; fail with a clear
    # message here instead of a KeyError on the first section lookup below.
    if not configFile.read(args.ini):
        parser.error('Could not read configuration file: ' + args.ini)
    print('Loaded configuration file: ' + args.ini)
    print({section: dict(configFile[section]) for section in configFile.sections()})
    print('\n\n')

    trainConfig = configFile['TRAIN']
    lossesConfig = configFile['LOSSES']
    datasetConfig = configFile['DATASETS']
    othersConfig = configFile['OTHERS']
    # [AUGMENTATION] is optional: when it is absent the C.* flags are left untouched.
    augmentationConfig = configFile['AUGMENTATION'] if configFile.has_section('AUGMENTATION') else None

    # Optional model file; an empty string is forwarded when the option is absent.
    model_file = trainConfig.get('model', '')
    if model_file:
        print('TRAIN MODEL IN ' + model_file)

    # Comma-separated loss names; surrounding spaces are stripped so that
    # "ncc, dice" yields 'dice' rather than ' dice'.
    simil = [s.strip() for s in lossesConfig['similarity'].split(',')]
    segm = [s.strip() for s in lossesConfig['segmentation'].split(',')]

    # Pick the training entry point. Only the UW trainer receives the full loss
    # lists; the other trainers get the first similarity/segmentation loss.
    trainer_name = trainConfig['name']
    if trainer_name.lower() == 'uw':
        from COMET.COMET_train_UW import launch_train
        run_name = '{}_Lsim_{}__Lseg_{}'.format(trainer_name, '_'.join(simil), '_'.join(segm))
    elif trainer_name.lower() == 'segguided':
        # NOTE(review): module name is spelled 'seggguided' (three g's); kept
        # as-is -- confirm it matches the file on disk.
        from COMET.COMET_train_seggguided import launch_train
        simil, segm = simil[0], segm[0]
        run_name = '{}_Lsim_{}__Lseg_{}'.format(trainer_name, simil, segm)
    else:
        from COMET.COMET_train import launch_train
        simil, segm = simil[0], segm[0]
        run_name = '{}_Lsim_{}'.format(trainer_name, simil)
    # Timestamp suffix so runs started at different times get different folders.
    output_folder = os.path.join(othersConfig['outputFolder'],
                                 run_name + '_' + datetime.now().strftime("%H%M%S-%d%m%Y"))

    frozen_layers = _parse_freeze_layers(trainConfig.get('freeze'))
    if frozen_layers is not None and any(s not in LAYER_SELECTION.keys() for s in frozen_layers):
        # Raised explicitly (not assert) so the check survives `python -O`.
        raise ValueError('Invalid option for "freeze". Expected one or several of: '
                         + ', '.join(LAYER_SELECTION.keys()))

    if augmentationConfig:
        # Only override the flags that are actually present in the section.
        if 'gamma' in augmentationConfig:
            C.GAMMA_AUGMENTATION = augmentationConfig['gamma'].lower() == 'true'
        if 'brightness' in augmentationConfig:
            C.BRIGHTNESS_AUGMENTATION = augmentationConfig['brightness'].lower() == 'true'

    # Copy the configuration file to the destination folder.
    os.makedirs(output_folder, exist_ok=True)  # TODO: move this within the "resume" if case, and bring here the creation of the resume-output folder!
    copy2(args.ini, os.path.join(output_folder, os.path.basename(args.ini)))

    # Each architecture option falls back to its own default independently.
    unet = _parse_int_list(trainConfig.get('unet'), [16, 32, 64, 128, 256])
    head = _parse_int_list(trainConfig.get('head'), [16, 16])
    resume_checkpoint = trainConfig.get('resumeCheckpoint')  # None when absent

    # NOTE: the options below are passed through eval(), which executes arbitrary
    # Python from the .ini file -- only use trusted configuration files.
    launch_train(dataset_folder=datasetConfig['train'],
                 validation_folder=datasetConfig['validation'],
                 output_folder=output_folder,
                 gpu_num=eval(trainConfig['gpu']),
                 lr=eval(trainConfig['learningRate']),
                 rw=eval(trainConfig['regularizationWeight']),
                 simil=simil,
                 segm=segm,
                 max_epochs=eval(trainConfig['epochs']),
                 image_size=eval(trainConfig['imageSize']),
                 early_stop_patience=eval(trainConfig['earlyStopPatience']),
                 model_file=model_file,
                 freeze_layers=frozen_layers,
                 acc_gradients=eval(trainConfig['accumulativeGradients']),
                 batch_size=eval(trainConfig['batchSize']),
                 unet=unet,
                 head=head,
                 resume=resume_checkpoint)
 |