# SandSnapModelDemo / app_files / src / sedinet_infer.py
# Written by Dr Daniel Buscombe, Marda Science LLC
# for the SandSnap Program
#
# MIT License
#
# Copyright (c) 2020-2021, Marda Science LLC
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
##> Release v1.4 (Aug 2021)
from sedinet_models import *
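# NOTE: this module leans on the star import above; it is assumed that
# sedinet_models provides the helpers used below (get_df, make_cat_sedinet,
# make_sedinet_siso_simo, the get_data_generator_* and predict_* functions,
# the plotting helpers and tidy), the Keras utilities (EarlyStopping,
# ModelCheckpoint, plot_model, K), np, gc, os, plt, TemporaryFile, and the
# configuration globals (BATCH_SIZE, VALID_BATCH_SIZE, NUM_EPOCHS, IM_HEIGHT,
# IM_WIDTH, DO_AUG, DO_STANDARDIZE, SHALLOW, CAT_LOSS, CONT_LOSS)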
###===================================================
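# Illustrative call of the function below (the file names, variable name, and
# settings here are hypothetical placeholders, not values shipped with SandSnap):
#   run_training_siso_simo(vars=['P50'], train_csvfile='train.csv',
#                          test_csvfile='test.csv', val_csvfile='val.csv',
#                          name='sandsnap', res_folder='results', mode='siso',
#                          greyscale=True, dropout=0.5, numclass=0)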
def run_training_siso_simo(vars, train_csvfile, test_csvfile, val_csvfile, name, res_folder,
mode, greyscale, dropout, numclass): #scale
"""
This function generates, trains and evaluates a sedinet model for
continuous prediction
"""
if numclass>0:
ID_MAP = dict(zip(np.arange(numclass), [str(k) for k in range(numclass)]))
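# e.g. with numclass=3 this gives ID_MAP = {0: '0', 1: '1', 2: '2'},
# i.e. integer class indices mapped to string labels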
##======================================
## the categorical branch below needs the dataframes and indices up front;
## get_df randomly selects imagery for the training and testing sets
## while also making sure that both training and testing sets have
## at least 3 examples of each category
if numclass>0:
    train_idx, train_df, _ = get_df(train_csvfile, fortrain=True)
    test_idx, test_df, _ = get_df(test_csvfile, fortrain=True)
##==============================================
## create a sedinet model to estimate category
if numclass>0:
SM = make_cat_sedinet(ID_MAP, dropout)
else:
SM = make_sedinet_siso_simo(vars, greyscale, dropout)
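# make_cat_sedinet builds a classifier over the ID_MAP categories;
# make_sedinet_siso_simo is assumed to attach one continuous output per entry
# in vars (single-input single-output or single-input multi-output)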
# if scale==True:
# CS = []
# for var in vars:
# cs = RobustScaler() ##alternative = MinMaxScaler()
# cs.fit_transform(
# np.r_[train_df[var].values, test_df[var].values].reshape(-1,1)
# )
# CS.append(cs)
# del cs
# else:
# CS = []
##==============================================
## train model
if numclass==0:
if type(BATCH_SIZE)==list:
SMs = []; weights_path = []
for batch_size, valid_batch_size in zip(BATCH_SIZE, VALID_BATCH_SIZE):
sm, wp,train_df, test_df, val_df, train_idx, test_idx, val_idx = train_sedinet_siso_simo(SM, name,
train_csvfile, test_csvfile, val_csvfile, vars, mode, greyscale, #CS,
dropout, batch_size, valid_batch_size,
res_folder)#, scale)
SMs.append(sm)
weights_path.append(wp)
gc.collect()
else:
SM, weights_path,train_df, test_df, val_df, train_idx, test_idx, val_idx = train_sedinet_siso_simo(SM, name,
train_csvfile, test_csvfile, val_csvfile, vars, mode, greyscale, #CS,
dropout, BATCH_SIZE, VALID_BATCH_SIZE,
res_folder)#, scale)
else:
if type(BATCH_SIZE)==list:
SMs = []; weights_path = []
for batch_size, valid_batch_size in zip(BATCH_SIZE, VALID_BATCH_SIZE):
sm, wp = train_sedinet_cat(SM, train_df, test_df, train_idx,
test_idx, ID_MAP, vars, greyscale, name, mode,
batch_size, valid_batch_size, res_folder)
SMs.append(sm)
weights_path.append(wp)
gc.collect()
else:
SM, weights_path = train_sedinet_cat(SM, train_df, test_df, train_idx,
test_idx, ID_MAP, vars, greyscale, name, mode,
BATCH_SIZE, VALID_BATCH_SIZE, res_folder)
classes = np.arange(len(ID_MAP))
K.clear_session()
##==============================================
# test model
do_aug = False
for_training = False
if type(test_df)==list:
print('Reading in all test files and memory mapping in batches ... takes a while')
test_gen = []
for df,id in zip(test_df,test_idx):
test_gen.append(get_data_generator_Nvars_siso_simo(df, id, for_training,
vars, len(id), greyscale, do_aug, DO_STANDARDIZE, IM_HEIGHT)) #CS,
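# to avoid holding every test batch in RAM at once, each generator's full
# output is written to a temporary file-backed numpy memmap, flushed to disk,
# then re-opened read-only before being appended to x_test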
x_test = []; test_vals = []; files = []
for gen in test_gen:
a, b = next(gen)
outfile = TemporaryFile()
files.append(outfile)
dt = a.dtype; sh = a.shape
fp = np.memmap(outfile, dtype=dt, mode='w+', shape=sh)
fp[:] = a[:]
fp.flush()
del a
del fp
a = np.memmap(outfile, dtype=dt, mode='r', shape=sh)
x_test.append(a)
test_vals.append(b)
else:
# train_gen = get_data_generator_Nvars_siso_simo(train_df, train_idx, for_training,
# vars, len(train_idx), greyscale, do_aug, DO_STANDARDIZE, IM_HEIGHT)#CS,
# x_train, train_vals = next(train_gen)
test_gen = get_data_generator_Nvars_siso_simo(test_df, test_idx, for_training,
vars, len(test_idx), greyscale, do_aug, DO_STANDARDIZE, IM_HEIGHT)
x_test, test_vals = next(test_gen)
# if numclass==0:
# # suffix = 'train'
# if type(BATCH_SIZE)==list:
# count_in = 0
# predict_train_siso_simo(x_train, train_vals, vars, #train_df, test_df, train_idx, test_idx, vars, x_test, test_vals,
# SMs, weights_path, name, mode, greyscale,# CS,
# dropout, DO_AUG,DO_STANDARDIZE, count_in)#scale,
# else:
# if type(x_train)==list:
# for count_in, (a, b) in enumerate(zip(x_train, train_vals)): #x_test, test_vals
# predict_train_siso_simo(a, b, vars, #train_df, test_df, train_idx, test_idx, vars, c, d,
# SM, weights_path, name, mode, greyscale,# CS,
# dropout, DO_AUG,DO_STANDARDIZE, count_in)#scale,
# plot_all_save_all(weights_path, vars)
# else:
# count_in = 0; consolidate = False
# predict_train_siso_simo(x_train, train_vals, vars, #train_df, test_df, train_idx, test_idx, vars, x_test, test_vals,
# SM, weights_path, name, mode, greyscale,# CS,
# dropout, DO_AUG,DO_STANDARDIZE, count_in)#scale,
if numclass==0:
if type(BATCH_SIZE)==list:
count_in = 0
predict_train_siso_simo(x_test, test_vals, vars,
SMs, weights_path, name, mode, greyscale,
dropout, DO_AUG,DO_STANDARDIZE, count_in)
else:
if type(x_test)==list:
for count_in, (a, b) in enumerate(zip(x_test, test_vals)):
predict_train_siso_simo(a, b, vars,
SM, weights_path, name, mode, greyscale,
dropout, DO_AUG,DO_STANDARDIZE, count_in)
plot_all_save_all(weights_path, vars)
else:
count_in = 0; #consolidate = False
predict_train_siso_simo(x_test, test_vals, vars,
SM, weights_path, name, mode, greyscale,
dropout, DO_AUG,DO_STANDARDIZE, count_in)
else:
if type(BATCH_SIZE)==list:
predict_test_train_cat(train_df, test_df, train_idx, test_idx, vars[0],
SMs, [i for i in ID_MAP.keys()], weights_path, greyscale,
name, DO_AUG,DO_STANDARDIZE)
else:
predict_test_train_cat(train_df, test_df, train_idx, test_idx, vars[0],
SM, [i for i in ID_MAP.keys()], weights_path, greyscale,
name, DO_AUG,DO_STANDARDIZE)
K.clear_session()
#
##===================================
## move model files and plots to the results folder
tidy(name, res_folder)
###==================================
def train_sedinet_cat(SM, train_df, test_df, train_idx, test_idx,
ID_MAP, vars, greyscale, name, mode, batch_size, valid_batch_size,
res_folder):
"""
This function trains an implementation of SediNet
"""
##================================
## create training and testing file generators, set the weights path,
## plot the model, and create a callback list for model training
for_training=True
train_gen = get_data_generator_1image(train_df, train_idx, for_training, ID_MAP,
vars[0], batch_size, greyscale, DO_AUG, DO_STANDARDIZE, IM_HEIGHT) ##BATCH_SIZE
do_aug = False
valid_gen = get_data_generator_1image(test_df, test_idx, for_training, ID_MAP,
vars[0], valid_batch_size, greyscale, do_aug, DO_STANDARDIZE, IM_HEIGHT) ##VALID_BATCH_SIZE
if SHALLOW is True:
if DO_AUG is True:
weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
"_"+str(IM_WIDTH)+"_shallow_"+vars[0]+"_"+CAT_LOSS+"_aug.hdf5"
else:
weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
"_"+str(IM_WIDTH)+"_shallow_"+vars[0]+"_"+CAT_LOSS+"_noaug.hdf5"
else:
if DO_AUG is True:
weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
"_"+str(IM_WIDTH)+"_"+vars[0]+"_"+CAT_LOSS+"_aug.hdf5"
else:
weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
"_"+str(IM_WIDTH)+"_"+vars[0]+"_"+CAT_LOSS+"_noaug.hdf5"
if os.path.exists(weights_path):
SM.load_weights(weights_path)
print("==========================================")
print("Loading weights that already exist: %s" % (weights_path) )
print("Skipping model training")
elif os.path.exists(res_folder+os.sep+weights_path):
weights_path = res_folder+os.sep+weights_path
SM.load_weights(weights_path)
print("==========================================")
print("Loading weights that already exist: %s" % (weights_path) )
print("Skipping model training")
else:
try:
plot_model(SM, weights_path.replace('.hdf5', '_model.png'),
show_shapes=True, show_layer_names=True)
except:
pass
callbacks_list = [
ModelCheckpoint(weights_path, monitor='val_loss', verbose=1,
save_best_only=True, mode='min',
save_weights_only = True)
]
print("=========================================")
print("[INFORMATION] schematic of the model has been written out to: "+\
weights_path.replace('.hdf5', '_model.png'))
print("[INFORMATION] weights will be written out to: "+weights_path)
##==============================================
## set checkpoint file and parameters that control early stopping,
## and reduction of learning rate if and when validation
## scores plateau upon successive epochs
# reduceloss_plat = ReduceLROnPlateau(monitor='val_loss', factor=FACTOR,
# patience=STOP_PATIENCE, verbose=1, mode='auto', min_delta=MIN_DELTA,
# cooldown=STOP_PATIENCE, min_lr=MIN_LR)
#
earlystop = EarlyStopping(monitor="val_loss", mode="min", patience=10)
model_checkpoint = ModelCheckpoint(weights_path, monitor='val_loss',
verbose=1, save_best_only=True, mode='min',
save_weights_only = True)
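# EarlyStopping halts training once val_loss has not improved for 10
# consecutive epochs; ModelCheckpoint writes weights only (not the full
# model), and only when val_loss reaches a new minimum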
##==============================================
## train the model
## with non-adaptive exponentially decreasing learning rate
#exponential_decay_fn = exponential_decay(MAX_LR, NUM_EPOCHS)
#lr_scheduler = LearningRateScheduler(exponential_decay_fn)
callbacks_list = [model_checkpoint, earlystop] #lr_scheduler
## train the model
history = SM.fit(train_gen,
steps_per_epoch=len(train_idx)//batch_size, ##BATCH_SIZE
epochs=NUM_EPOCHS,
callbacks=callbacks_list,
validation_data=valid_gen, #use_multiprocessing=True,
validation_steps=len(test_idx)//valid_batch_size) #max_queue_size=10 ##VALID_BATCH_SIZE
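# note that steps_per_epoch and validation_steps use floor division, so any
# partial batch left over after the last full batch is not drawn in an epoch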
###===================================================
## Plot the loss and accuracy as a function of epoch
plot_train_history_1var(history)
# plt.savefig(vars+'_'+str(IM_HEIGHT)+'_batch'+str(batch_size)+'_history.png', ##BATCH_SIZE
# dpi=300, bbox_inches='tight')
plt.savefig(weights_path.replace('.hdf5','_history.png'),dpi=300, bbox_inches='tight')
plt.close('all')
# serialize model to JSON to use later to predict
model_json = SM.to_json()
with open(weights_path.replace('.hdf5','.json'), "w") as json_file:
json_file.write(model_json)
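# the .json file holds the architecture only; the best weights live in the
# .hdf5 written by ModelCheckpoint, and both are needed to rebuild the
# trained model for prediction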
return SM, weights_path
###===================================================
def train_sedinet_siso_simo(SM, name, train_csvfile, test_csvfile, val_csvfile, #train_df, test_df, train_idx, test_idx,
vars, mode, greyscale, dropout, batch_size, valid_batch_size,#CS,
res_folder):#, scale):
"""
This function trains an implementation of sedinet
"""
##==============================================
## create training and testing file generators, set the weights path,
## plot the model, and create a callback list for model training
# get a string saying how many variables, for the output files
varstring = str(len(vars))+'vars' #''.join([str(k)+'_' for k in vars])
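# e.g. vars=['P50', 'P84'] gives varstring == '2vars'; the weights file name
# assembled below then looks something like (hypothetical values)
# "sandsnap_siso_batch8_im768_1024_shallow_2vars_mse_aug.hdf5", encoding the
# run name, mode, batch size, image size, network depth, number of variables,
# loss (CONT_LOSS) and whether augmentation (DO_AUG) was used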
# make the appropriate weights file name
if SHALLOW is True:
if DO_AUG is True:
# if len(CS)>0:#scale is True:
# weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
# "_"+str(IM_WIDTH)+"_shallow_"+varstring+"_"+CONT_LOSS+"_aug_scale.hdf5"
# else:
weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
"_"+str(IM_WIDTH)+"_shallow_"+varstring+"_"+CONT_LOSS+"_aug.hdf5"
else:
# if len(CS)>0:#scale is True:
# weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
# "_"+str(IM_WIDTH)+"_shallow_"+varstring+"_"+CONT_LOSS+"_noaug_scale.hdf5"
# else:
weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
"_"+str(IM_WIDTH)+"_shallow_"+varstring+"_"+CONT_LOSS+"_noaug.hdf5"
else:
if DO_AUG is True:
# if len(CS)>0:#scale is True:
# weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
# "_"+str(IM_WIDTH)+"_"+varstring+"_"+CONT_LOSS+"_aug_scale.hdf5"
# else:
weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
"_"+str(IM_WIDTH)+"_"+varstring+"_"+CONT_LOSS+"_aug.hdf5"
else:
# if len(CS)>0:#scale is True:
# weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
# "_"+str(IM_WIDTH)+"_"+varstring+"_"+CONT_LOSS+"_noaug_scale.hdf5"
# else:
weights_path = name+"_"+mode+"_batch"+str(batch_size)+"_im"+str(IM_HEIGHT)+\
"_"+varstring+"_"+CONT_LOSS+"_noaug.hdf5"
# if it already exists, skip training
if os.path.exists(weights_path):
SM.load_weights(weights_path)
print("==========================================")
print("Loading weights that already exist: %s" % (weights_path) )
print("Skipping model training")
##======================================
## this randomly selects imagery for training and testing imagery sets
## while also making sure that both training and testing sets have
## at least 3 examples of each category
train_idx, train_df, _ = get_df(train_csvfile,fortrain=False)
test_idx, test_df, _ = get_df(test_csvfile,fortrain=False)
val_idx, val_df, _ = get_df(val_csvfile,fortrain=False)
for_training = False
train_gen = get_data_generator_Nvars_siso_simo(train_df, train_idx, for_training,
vars, batch_size, greyscale,
DO_AUG, DO_STANDARDIZE, IM_HEIGHT) # CS,
do_aug = False
valid_gen = get_data_generator_Nvars_siso_simo(val_df, val_idx, for_training,
vars, valid_batch_size, greyscale,
do_aug, DO_STANDARDIZE, IM_HEIGHT) ##only augment training # CS,
# do_aug = False
# test_gen = get_data_generator_Nvars_siso_simo(test_df, test_idx, for_training,
# vars, valid_batch_size, greyscale,
# do_aug, DO_STANDARDIZE, IM_HEIGHT) ##only augment training # CS,
# if it already exists in res_folder, skip training
elif os.path.exists(res_folder+os.sep+weights_path):
weights_path = res_folder+os.sep+weights_path
SM.load_weights(weights_path)
print("==========================================")
print("Loading weights that already exist: %s" % (weights_path) )
print("Skipping model training")
##======================================
## this randomly selects imagery for training and testing imagery sets
## while also making sure that both training and testing sets have
## at least 3 examples of each category
train_idx, train_df, _ = get_df(train_csvfile,fortrain=False)
test_idx, test_df, _ = get_df(test_csvfile,fortrain=False)
val_idx, val_df, _ = get_df(val_csvfile,fortrain=False)
for_training = False
train_gen = get_data_generator_Nvars_siso_simo(train_df, train_idx, for_training,
vars, batch_size, greyscale,
DO_AUG, DO_STANDARDIZE, IM_HEIGHT) # CS,
do_aug = False
valid_gen = get_data_generator_Nvars_siso_simo(val_df, val_idx, for_training,
vars, valid_batch_size, greyscale,
do_aug, DO_STANDARDIZE, IM_HEIGHT) ##only augment training # CS,
# do_aug = False
# test_gen = get_data_generator_Nvars_siso_simo(test_df, test_idx, for_training,
# vars, valid_batch_size, greyscale,
# do_aug, DO_STANDARDIZE, IM_HEIGHT) ##only augment training # CS,
else: #train
##======================================
## this randomly selects imagery for training and testing imagery sets
## while also making sure that both training and testing sets have
## at least 3 examples of each category
train_idx, train_df, _ = get_df(train_csvfile,fortrain=True)
test_idx, test_df, _ = get_df(test_csvfile,fortrain=True)
val_idx, val_df, _ = get_df(val_csvfile,fortrain=True)
for_training = True
train_gen = get_data_generator_Nvars_siso_simo(train_df, train_idx, for_training,
vars, batch_size, greyscale,
DO_AUG, DO_STANDARDIZE, IM_HEIGHT) # CS,
# do_aug = False
# test_gen = get_data_generator_Nvars_siso_simo(test_df, test_idx, for_training,
# vars, valid_batch_size, greyscale,
# do_aug, DO_STANDARDIZE, IM_HEIGHT) ##only augment training # CS,
do_aug = False
valid_gen = get_data_generator_Nvars_siso_simo(val_df, val_idx, for_training,
vars, valid_batch_size, greyscale,
do_aug, DO_STANDARDIZE, IM_HEIGHT) ##only augment training # CS,
# if scaling was used (CS is non-empty), dump out the scalers to a pickle file
# if len(CS)==0:
# pass
# else:
# joblib.dump(CS, weights_path.replace('.hdf5','_scaler.pkl'))
# print('Wrote scaler to pkl file')
try: # plot the model if pydot/graphviz installed
plot_model(SM, weights_path.replace('.hdf5', '_model.png'),
show_shapes=True, show_layer_names=True)
print("model schematic written to: "+\
weights_path.replace('.hdf5', '_model.png'))
except:
pass
print("==========================================")
print("weights will be written out to: "+weights_path)
##==============================================
## set checkpoint file and parameters that control early stopping,
## and reduction of learning rate if and when validation scores plateau upon successive epochs
# reduceloss_plat = ReduceLROnPlateau(monitor='val_loss', factor=FACTOR,
# patience=STOP_PATIENCE, verbose=1, mode='auto',
# min_delta=MIN_DELTA, cooldown=5,
# min_lr=MIN_LR)
earlystop = EarlyStopping(monitor="val_loss", mode="min",
patience=10)
# set model checkpoint. only save best weights, based on min validation loss
model_checkpoint = ModelCheckpoint(weights_path, monitor='val_loss', verbose=1,
save_best_only=True, mode='min',
save_weights_only = True)
#tqdm_callback = tfa.callbacks.TQDMProgressBar()
# callbacks_list = [model_checkpoint, reduceloss_plat, earlystop] #, tqdm_callback]
try: #write summary of the model to txt file
with open(weights_path.replace('.hdf5','') + '_report.txt','w') as fh:
# Pass the file handle in as a lambda function to make it callable
SM.summary(print_fn=lambda x: fh.write(x + '\n'))
fh.close()
print("model summary written to: "+ \
weights_path.replace('.hdf5','') + '_report.txt')
with open(weights_path.replace('.hdf5','') + '_report.txt','r') as fh:
tmp = fh.readlines()
print("===============================================")
print("Total parameters: %s" %\
(''.join(tmp).split('Total params:')[-1].split('\n')[0]))
fh.close()
print("===============================================")
except:
pass
##==============================================
## train the model
## non-adaptive exponentially decreasing learning rate
# exponential_decay_fn = exponential_decay(MAX_LR, NUM_EPOCHS)
#lr_scheduler = LearningRateScheduler(exponential_decay_fn)
callbacks_list = [model_checkpoint, earlystop] #lr_scheduler
## train the model
history = SM.fit(train_gen,
steps_per_epoch=len(train_idx)//batch_size, ##BATCH_SIZE
epochs=NUM_EPOCHS,
callbacks=callbacks_list,
validation_data=valid_gen, #use_multiprocessing=True,
validation_steps=len(val_idx)//valid_batch_size) #max_queue_size=10 ##VALID_BATCH_SIZE
###===================================================
## Plot the loss and accuracy as a function of epoch
if len(vars)==1:
plot_train_history_1var_mae(history)
else:
plot_train_history_Nvar(history, vars, len(vars))
varstring = ''.join([str(k)+'_' for k in vars])
plt.savefig(weights_path.replace('.hdf5', '_history.png'), dpi=300,
bbox_inches='tight')
plt.close('all')
# serialize model to JSON to use later to predict
model_json = SM.to_json()
with open(weights_path.replace('.hdf5','.json'), "w") as json_file:
json_file.write(model_json)
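# the dataframes and index arrays are returned alongside the model and
# weights path so that run_training_siso_simo can build its test-time
# generators from exactly the same train/test/validation splits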
return SM, weights_path,train_df, test_df, val_df, train_idx, test_idx, val_idx
#