File size: 2,950 Bytes
6477265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
""" 
Module for training the 2D clutter filtering model with L2 loss.
"""
import os
import argparse
import json
import numpy as np
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau

from utils import *
from Model_ClutterFilter2D import clutter_filter_2D
from DataGen import DataGen

def data_generation(in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config):
    """Create the training and validation ``DataGen`` instances.

    Both generators share the input dimension taken from
    ``config["network_prm"]["input_dim"]`` and the batch size taken from
    ``config["learning_prm"]["batch_size"]``. Each one indexes its samples
    with ``0 .. len(inputs) - 1`` and is created with ``tr_phase=True``.

    Args:
        in_ids_tr: Training inputs, forwarded to ``DataGen`` as ``in_dir``.
        in_ids_val: Validation inputs, forwarded to ``DataGen`` as ``in_dir``.
        out_ids_tr: Training targets, forwarded to ``DataGen`` as ``out_dir``.
        out_ids_val: Validation targets, forwarded to ``DataGen`` as ``out_dir``.
        config: Nested configuration dictionary holding the ``network_prm``
            and ``learning_prm`` sections.

    Returns:
        A ``(tr_gen, val_gen)`` tuple of ``DataGen`` objects.
    """
    # Settings that are identical for the two generators.
    shared = dict(
        dim=config["network_prm"]["input_dim"],
        batch_size=config["learning_prm"]["batch_size"],
        tr_phase=True,
    )
    train_kwargs = dict(
        shared,
        in_dir=in_ids_tr,
        out_dir=out_ids_tr,
        id_list=np.arange(len(in_ids_tr)),
    )
    valid_kwargs = dict(
        shared,
        in_dir=in_ids_val,
        out_dir=out_ids_val,
        id_list=np.arange(len(in_ids_val)),
    )
    return DataGen(**train_kwargs), DataGen(**valid_kwargs)

def model_chkpnt(val_subject, te_subject, weight_dir, config):
    """Create the checkpointing and learning-rate callbacks for training.

    The checkpoint file name encodes the validation/test subjects, selected
    network parameters and the learning rate, followed by placeholders for
    the epoch number, training loss and validation loss that are left in the
    path for the ``ModelCheckpoint`` callback.

    Args:
        val_subject: Validation subject identifier embedded in the file name.
        te_subject: Test subject identifier embedded in the file name.
        weight_dir: Directory path (string) under which weights are saved.
        config: Nested configuration dictionary holding the ``network_prm``
            and ``learning_prm`` sections.

    Returns:
        A ``(model_checkpoint, reduce_lr)`` tuple: a ``ModelCheckpoint`` that
        monitors ``val_loss`` with ``save_best_only=True`` and a
        ``ReduceLROnPlateau`` that monitors ``val_loss`` with factor 0.1,
        patience 4 and a minimum learning rate of 1e-7.
    """
    net_prm = config["network_prm"]
    n_levels = net_prm["n_levels"]
    in_skip = net_prm["in_skip"]
    attention = net_prm["attention"]
    act = net_prm["act"]
    n_init_filters = net_prm["n_init_filters"]
    lr = config["learning_prm"]["lr"]

    run_tag = (
        f"CF2D_ValTeSbj_{val_subject}_{te_subject}"
        f"_nLvl{n_levels}_InSkp{in_skip}_Att{attention}"
        f"_Act{act}_nInitFlt{n_init_filters}_lr{lr}"
    )
    # Plain (non-f) string: the braces are kept as literal placeholders in
    # the path handed to ModelCheckpoint.
    metrics_suffix = "_epc{epoch:03d}_trloss{loss:.5f}_valloss{val_loss:.5f}.hdf5"
    filepath = weight_dir + "/" + run_tag + metrics_suffix

    checkpoint_cb = ModelCheckpoint(
        filepath=filepath,
        monitor="val_loss",
        verbose=0,
        save_best_only=True,
    )
    plateau_cb = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.1,
        patience=4,
        min_lr=1e-7,
    )
    return checkpoint_cb, plateau_cb

def main(config):
    """Train the 2D clutter filtering model as described by ``config``.

    Steps, in order: prepare the sample ids and the validation/test subjects
    with ``id_preparation``, create the weight directory, build the training
    and validation generators, instantiate the model by passing the top-level
    ``config`` entries as keyword arguments to ``clutter_filter_2D``, create
    the callbacks, and call ``model.fit`` for
    ``config["learning_prm"]["n_epochs"]`` epochs.

    Args:
        config: Nested configuration dictionary loaded from the JSON file.

    Returns:
        None.
    """
    (in_ids_tr, in_ids_val,
     out_ids_tr, out_ids_val,
     val_subject, te_subject) = id_preparation(config)

    weight_dir = create_weight_dir(val_subject, te_subject, config)
    train_data, valid_data = data_generation(
        in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config)

    model = clutter_filter_2D(**config)
    checkpoint_cb, plateau_cb = model_chkpnt(
        val_subject, te_subject, weight_dir, config)

    n_epochs = config["learning_prm"]["n_epochs"]
    model.fit(
        train_data,
        validation_data=valid_data,
        epochs=n_epochs,
        verbose=1,
        callbacks=[checkpoint_cb, plateau_cb],
    )

if __name__ == '__main__':
    # Command-line entry point: read the JSON configuration file given by
    # --config (default "config.json") and start training with it.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="path of the config file", default="config.json")
    args = parser.parse_args()
    # Validate explicitly instead of using ``assert``: assertions are removed
    # when Python runs with -O and a bare AssertionError does not tell the
    # user what went wrong. parser.error prints the usage text plus this
    # message and exits with status 2.
    if not os.path.isfile(args.config):
        parser.error(f"config file not found: {args.config!r}")
    # JSON is UTF-8 encoded; do not depend on the platform default encoding.
    with open(args.config, "r", encoding="utf-8") as read_file:
        config = json.load(read_file)
    main(config)