File size: 2,950 Bytes
6477265 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
"""
Module for training the 2D clutter filtering model with L2 loss.
"""
import os
import argparse
import json
import numpy as np
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from utils import *
from Model_ClutterFilter2D import clutter_filter_2D
from DataGen import DataGen
def data_generation(in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config):
    """Build the training and validation ``DataGen`` generators.

    Both generators share the input dimension and batch size read from
    ``config``. They differ only in the input/output ID lists they are given.

    Args:
        in_ids_tr: Input-sample IDs for the training set.
        in_ids_val: Input-sample IDs for the validation set.
        out_ids_tr: Target-sample IDs for the training set.
        out_ids_val: Target-sample IDs for the validation set.
        config: Parsed configuration dict. Reads
            ``config["network_prm"]["input_dim"]`` and
            ``config["learning_prm"]["batch_size"]``.

    Returns:
        A ``(tr_gen, val_gen)`` tuple of ``DataGen`` instances.
    """
    input_dim = config["network_prm"]["input_dim"]
    batch_size = config["learning_prm"]["batch_size"]

    def _make_gen(in_ids, out_ids):
        # id_list simply enumerates positions 0..len(in_ids)-1.
        return DataGen(
            dim=input_dim,
            in_dir=in_ids,
            out_dir=out_ids,
            id_list=np.arange(len(in_ids)),
            batch_size=batch_size,
            # NOTE(review): the validation generator is also created with
            # tr_phase=True; confirm against DataGen whether validation
            # should use tr_phase=False.
            tr_phase=True,
        )

    training = _make_gen(in_ids_tr, out_ids_tr)
    validation = _make_gen(in_ids_val, out_ids_val)
    return training, validation
def model_chkpnt(val_subject, te_subject, weight_dir, config):
    """Create the checkpoint and learning-rate-reduction callbacks.

    The checkpoint filename encodes the validation/test subjects, several
    network hyper-parameters and the learning rate. It ends with Keras
    placeholders that are filled with the epoch number and the
    training/validation losses when a checkpoint is written.

    Args:
        val_subject: Validation subject identifier (embedded in the name).
        te_subject: Test subject identifier (embedded in the name).
        weight_dir: Directory (as a ``str``) where checkpoints are saved.
        config: Parsed configuration dict. Reads ``network_prm`` keys
            ``n_levels``, ``in_skip``, ``attention``, ``act`` and
            ``n_init_filters``, plus ``learning_prm["lr"]``.

    Returns:
        A ``(model_checkpoint, reduce_lr)`` tuple of Keras callbacks.
    """
    net = config["network_prm"]
    weight_name = (
        f'CF2D_ValTeSbj_{val_subject}_{te_subject}'
        f'_nLvl{net["n_levels"]}'
        f'_InSkp{net["in_skip"]}'
        f'_Att{net["attention"]}'
        f'_Act{net["act"]}'
        f'_nInitFlt{net["n_init_filters"]}'
        f'_lr{config["learning_prm"]["lr"]}'
    )
    # The brace fields below are deliberately NOT an f-string: ModelCheckpoint
    # fills {epoch}, {loss} and {val_loss} itself at save time.
    suffix = '_epc{epoch:03d}_trloss{loss:.5f}_valloss{val_loss:.5f}.hdf5'
    filepath = weight_dir + '/' + weight_name + suffix

    # Save only when val_loss improves on the best value seen so far.
    model_checkpoint = ModelCheckpoint(
        filepath=filepath,
        monitor="val_loss",
        verbose=0,
        save_best_only=True,
    )
    # Multiply the LR by 0.1 after 4 epochs without val_loss improvement,
    # never going below 1e-7.
    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.1,
        patience=4,
        min_lr=1e-7,
    )
    return model_checkpoint, reduce_lr
def main(config):
    """Train the 2D clutter filter for the subject split described by ``config``.

    Steps:
    1. Resolve the train/validation ID lists and the held-out subjects.
    2. Create the weight directory.
    3. Build the data generators and the model.
    4. Fit the model with checkpointing and learning-rate reduction.

    Args:
        config: Parsed JSON configuration dict. ``n_epochs`` is read from
            ``config["learning_prm"]``; the whole dict is also forwarded to
            the helpers and to ``clutter_filter_2D`` as keyword arguments.

    Returns:
        None.
    """
    # id_preparation / create_weight_dir come from the `utils` star import.
    (in_ids_tr, in_ids_val,
     out_ids_tr, out_ids_val,
     val_subject, te_subject) = id_preparation(config)
    weight_dir = create_weight_dir(val_subject, te_subject, config)

    tr_gen, val_gen = data_generation(
        in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, config)
    model = clutter_filter_2D(**config)
    checkpoint_cb, lr_cb = model_chkpnt(val_subject, te_subject, weight_dir, config)

    model.fit(
        tr_gen,
        validation_data=val_gen,
        epochs=config["learning_prm"]["n_epochs"],
        verbose=1,
        callbacks=[checkpoint_cb, lr_cb],
    )
if __name__ == '__main__':
    # Command-line entry point: load the JSON config given by --config and
    # start training.
    parser = argparse.ArgumentParser(
        description="Train the 2D clutter filtering model.")
    parser.add_argument("--config", help="path of the config file", default="config.json")
    args = parser.parse_args()
    # Validate explicitly rather than with `assert`: asserts are stripped
    # under `python -O`, which would let a missing file fail later with a
    # less helpful error. parser.error prints usage and exits with status 2.
    if not os.path.isfile(args.config):
        parser.error(f"config file not found: {args.config}")
    with open(args.config, "r", encoding="utf-8") as read_file:
        config = json.load(read_file)
    main(config)