# DexinedApp / legacy / train.py
# (uploaded by Dinars34 — "Upload 60 files", commit 89c5d90)
import tensorflow as tf
import matplotlib.pyplot as plt
from models.dexined import dexined
# from models.dexinedBs import dexined
from utls.utls import *
from utls.dataset_manager import (data_parser,
get_training_batch,get_validation_batch, visualize_result)
class m_trainer():
    """Session-style (TF1 compat) trainer for the DexiNed edge-detection model.

    Typical usage::

        trainer = m_trainer(args)
        trainer.setup()
        trainer.run(sess)

    `args` is an argparse-style namespace; the attributes read here are:
    model_name, lr_scheduler, learning_rate, use_previous_trained,
    dataset_name, train_dataset, checkpoint_dir, model_state,
    max_iterations, save_interval, val_interval.
    """

    def __init__(self, args):
        # `init` acts as a health flag: setup() clears it on failure and
        # run() refuses to train when it is False.
        self.init = True
        self.args = args

    def setup(self):
        """Instantiate the model selected by ``args.model_name``.

        Only 'DXN' (DexiNed) is supported. On any failure — including an
        unknown model name — ``self.init`` is set to False so that a later
        call to run() becomes a no-op instead of crashing.
        """
        try:
            if self.args.model_name == 'DXN':
                self.model = dexined(self.args)
                print_info("DL model Set")
            else:
                # FIX: previously an unknown name only printed an error but
                # left self.model unset and self.init True, so run() crashed
                # later with AttributeError. Flag the failure here instead.
                print_error("Error setting model, {}".format(self.args.model_name))
                self.init = False
        except Exception as err:
            print_error("Error setting up DL model, {}".format(err))
            self.init = False

    def run(self, sess):
        """Run the training loop inside an open TF1 session.

        Builds the Adam optimizer, optionally restores a previous
        checkpoint, then iterates: train step -> periodic summary/print ->
        checkpointing (on best loss and on a fixed interval) -> periodic
        validation with a matplotlib preview. Closes the session on exit.

        :param sess: an active ``tf.compat.v1.Session``.
        :raises NotImplementedError: if ``args.lr_scheduler`` is set —
            only a constant learning rate is implemented.
        """
        if not self.init:
            return

        train_data = data_parser(self.args)
        self.model.setup_training(sess)

        # Only a constant learning rate is supported; reject any scheduler
        # name explicitly. (FIX: the message is now actually formatted —
        # exception constructors do not interpolate printf-style args.)
        if self.args.lr_scheduler is not None:
            raise NotImplementedError(
                'Learning rate scheduler type [{}] is not implemented'.format(
                    self.args.lr_scheduler))
        # FIX: was dtype=tf.float16, which stores the LR in half precision
        # (~3 decimal digits) and silently distorts small rates like 1e-4.
        learning_rate = tf.constant(self.args.learning_rate, dtype=tf.float32)

        opt = tf.compat.v1.train.AdamOptimizer(learning_rate)
        trainG = opt.minimize(self.model.loss)  # single fused loss, like HED
        saver = tf.compat.v1.train.Saver(max_to_keep=7)
        sess.run(tf.compat.v1.global_variables_initializer())

        # ---- optionally resume from a previous checkpoint ----
        if self.args.use_previous_trained:
            # Both the BIPED and non-BIPED branches of the original code
            # built the same <checkpoint_dir>/<model>_<dataset>/train path,
            # so the distinction is collapsed here. (The non-BIPED case is
            # meant to reuse BIPED-pretrained weights on another dataset.)
            model_path = os.path.join(
                self.args.checkpoint_dir,
                self.args.model_name + '_' + self.args.train_dataset,
                'train')
            if not os.path.exists(model_path) or len(os.listdir(model_path)) == 0:
                ini = 0
                maxi = self.args.max_iterations + 1
                print_warning('There is not previous trained data for the current model... and')
                print_warning('*** The training process is starting from scratch ***')
            else:
                # Restore from the most recent checkpoint in model_path.
                last_ckpt = tf.train.latest_checkpoint(model_path)
                saver.restore(sess, last_ckpt)
                # Continue iteration numbering after the assumed previous run.
                ini = self.args.max_iterations
                maxi = ini + self.args.max_iterations + 1  # check
                print_info('--> Previous model restored successfully: {}'.format(last_ckpt))
        else:
            print_warning('*** The training process is starting from scratch ***')
            ini = 0
            maxi = ini + self.args.max_iterations

        prev_loss = 1000.  # best (lowest) training loss seen so far

        # Directory where this run's checkpoints are written.
        checkpoint_dir = os.path.join(
            self.args.checkpoint_dir,
            self.args.model_name + '_' + self.args.train_dataset,
            self.args.model_state)
        os.makedirs(checkpoint_dir, exist_ok=True)

        fig = plt.figure()
        for idx in range(ini, maxi):
            x_batch, y_batch, _ = get_training_batch(self.args, train_data)
            # NOTE(review): RunMetadata is created but sess.run() is called
            # without RunOptions, so the metadata written below is empty —
            # kept for trace-file compatibility; confirm if profiling is wanted.
            run_metadata = tf.compat.v1.RunMetadata()
            _, summary, loss, pred_maps = sess.run(
                [trainG, self.model.merged_summary, self.model.loss, self.model.predictions],
                feed_dict={self.model.images: x_batch, self.model.edgemaps: y_batch})

            # Log summaries and progress every 5 iterations.
            if idx % 5 == 0:
                self.model.train_writer.add_run_metadata(
                    run_metadata, 'step{:06}'.format(idx))
                self.model.train_writer.add_summary(summary, idx)
                print(time.ctime(), '[{}/{}]'.format(idx, maxi),
                      ' TRAINING loss: %.5f' % loss,
                      'prev_loss: %.5f' % prev_loss)

            # Checkpoint whenever the loss improves on the best so far...
            if prev_loss > loss:
                saver.save(sess, os.path.join(checkpoint_dir, self.args.model_name),
                           global_step=idx)
                prev_loss = loss
                print("Weights saved in the lowest loss", idx, " Current Loss", prev_loss)
            # ...and unconditionally on the fixed save interval.
            if idx % self.args.save_interval == 0:
                saver.save(sess, os.path.join(checkpoint_dir, self.args.model_name),
                           global_step=idx)
                prev_loss = loss
                print("Weights saved in the interval", idx, " Current Loss", prev_loss)

            # ********* validation *********
            if (idx + 1) % self.args.val_interval == 0:
                pause_show = 0.01
                # Visualize sample #2 of the current batch: input RGB,
                # ground-truth edge map, then every intermediate prediction.
                imgs_list = []
                imgs_list.append(x_batch[2][:, :, 0:3])
                imgs_list.append(y_batch[2])
                for pred in pred_maps:
                    imgs_list.append(pred[2, ...])
                vis_imgs = visualize_result(imgs_list, self.args)
                fig.suptitle("Iterac:" + str(idx + 1) + " Loss:" + '%.5f' % loss + " training")
                fig.add_subplot(1, 1, 1)
                plt.imshow(np.uint8(vis_imgs))
                print("Evaluation in progress...")
                plt.draw()
                plt.pause(pause_show)

                im, em, _ = get_validation_batch(self.args, train_data)
                summary, error, pred_val = sess.run(
                    [self.model.merged_summary, self.model.error, self.model.fuse_output],
                    feed_dict={self.model.images: im, self.model.edgemaps: em})
                # Extra checkpoint when validation error is already low.
                if error <= 0.08:
                    saver.save(sess, os.path.join(checkpoint_dir, self.args.model_name),
                               global_step=idx)
                    prev_loss = loss
                    print("Parameters saved in the validation stage when its error is <=0.08::", error)
                self.model.val_writer.add_summary(summary, idx)
                print_info(('[{}/{}]'.format(idx, self.args.max_iterations),
                            'VALIDATION error: %0.5f' % error,
                            'pError: %.5f' % prev_loss))

            # Recreate the figure periodically so matplotlib state/memory
            # does not accumulate over a long run.
            if (idx + 1) % (self.args.val_interval * 150) == 0:
                print('updating visualisation')
                plt.close()
                fig = plt.figure()

        # Final checkpoint at the last iteration index, then clean up.
        saver.save(sess, os.path.join(checkpoint_dir, self.args.model_name),
                   global_step=idx)
        print("Final Weights saved", idx, " Current Loss", loss)
        self.model.train_writer.close()
        sess.close()