Spaces:
Running
Running
import torch | |
import torch.optim as optim | |
from tqdm import trange | |
import os | |
from tensorboardX import SummaryWriter | |
import numpy as np | |
import cv2 | |
from loss import SGMLoss,SGLoss | |
from valid import valid,dump_train_vis | |
import sys | |
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | |
sys.path.insert(0, ROOT_DIR) | |
from utils import train_utils | |
def train_step(optimizer, model, match_loss, data, step, pre_avg_loss):
    """Run one optimisation step, skipping the update when the loss spikes.

    The forward pass is run with ``test_mode=False`` and the loss is
    computed by ``match_loss.run``. Gradients are always back-propagated,
    but the optimizer only steps when the reduced total loss looks sane.

    Args:
        optimizer: optimizer whose gradients are cleared and (usually) stepped.
        model: callable invoked as ``model(data, test_mode=False)``.
        match_loss: object exposing ``run(data, result)`` that returns a
            mapping containing at least ``'total_loss'``.
        data: batch dict; its ``'step'`` entry is overwritten with ``step``.
        step: step counter forwarded to the model through ``data['step']``
            and used for the 200-step warm-up below.
        pre_avg_loss: reference loss level; ``0`` disables the spike check.

    Returns:
        ``(loss_res, unusual_loss)`` where ``loss_res`` holds every loss
        entry after ``train_utils.reduce_tensor(..., 'mean')`` and
        ``unusual_loss`` is ``True`` when the update was skipped.
    """
    data['step'] = step
    prediction = model(data, test_mode=False)
    losses = match_loss.run(data, prediction)

    optimizer.zero_grad()
    losses['total_loss'].backward()

    # Replace every recorded tensor by its reduced value so that logging and
    # the spike check below see the same number.
    # NOTE(review): presumably a cross-process average -- confirm in train_utils.
    for name in list(losses.keys()):
        losses[name] = train_utils.reduce_tensor(losses[name], 'mean')

    # The update is rejected only when all three hold: the loss is at least
    # 7x the reference level, the 200-step warm-up is over, and a reference
    # level exists (pre_avg_loss != 0).
    if not (losses['total_loss'] < 7 * pre_avg_loss
            or step < 200
            or pre_avg_loss == 0):
        # Drop the gradients of the outlier batch instead of applying them.
        optimizer.zero_grad()
        return losses, True

    optimizer.step()
    return losses, False
def _checkpoint_state(step, model, best_acc, optimizer):
    """Bundle the state needed to resume training once ``step`` has finished."""
    return {'step': step + 1,
            'state_dict': model.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict()}


def train(model, train_loader, valid_loader, config, model_config):
    """Train ``model`` for ``config.train_iter`` steps with periodic validation.

    The loop resumes automatically from ``<config.log_base>/checkpoint.pth``
    when that file exists, applies an exponential learning-rate decay after
    ``config.decay_iter`` steps, logs scalars to tensorboard on rank 0, and
    writes ``checkpoint.pth`` / ``model_best.pth`` into ``config.log_base``.

    Args:
        model: network to train. ``model.module.dustbin`` is read for logging,
            so the model is presumably wrapped (e.g. DistributedDataParallel)
            -- confirm with the caller.
        train_loader: training loader; its ``sampler`` must expose
            ``set_epoch`` and its ``dataset`` must support ``len``.
        valid_loader: loader forwarded to ``valid``.
        config: run configuration. Attributes read here: ``train_lr``,
            ``model_name``, ``log_base``, ``local_rank``, ``train_batch_size``,
            ``train_iter``, ``tqdm_width``, ``decay_rate``, ``decay_iter``,
            ``log_intv``, ``save_intv``, ``val_intv``, ``train_vis_folder``.
            ``config.resume`` is overwritten by this function.
        model_config: model configuration forwarded to the loss and ``valid``.

    Raises:
        NotImplementedError: if ``config.model_name`` is neither ``'SGM'``
            nor ``'SG'``.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=config.train_lr)
    if config.model_name == 'SGM':
        match_loss = SGMLoss(config, model_config)
    elif config.model_name == 'SG':
        match_loss = SGLoss(config, model_config)
    else:
        raise NotImplementedError

    checkpoint_path = os.path.join(config.log_base, 'checkpoint.pth')
    config.resume = os.path.isfile(checkpoint_path)
    if config.resume:
        if config.local_rank == 0:
            print('==> Resuming from checkpoint..')
        checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(config.local_rank))
        model.load_state_dict(checkpoint['state_dict'])
        best_acc = checkpoint['best_acc']
        start_step = checkpoint['step']
        optimizer.load_state_dict(checkpoint['optimizer'])
    else:
        best_acc = -1
        start_step = 0

    if config.local_rank == 0:
        writer = SummaryWriter(os.path.join(config.log_base, 'log_file'))

    # The sampler epoch must be set BEFORE the iterator is created; otherwise
    # the first (possibly prefetched) epoch after a resume is shuffled with
    # the wrong seed.
    train_loader.sampler.set_epoch(start_step * config.train_batch_size // len(train_loader.dataset))
    train_loader_iter = iter(train_loader)

    pre_avg_loss = 0
    if config.local_rank == 0:
        progress_bar = trange(start_step, config.train_iter, ncols=config.tqdm_width)
    else:
        progress_bar = range(start_step, config.train_iter)

    for step in progress_bar:
        try:
            train_data = next(train_loader_iter)
        except StopIteration:
            # Loader exhausted: start the next epoch with a fresh shuffle.
            epoch = step * config.train_batch_size // len(train_loader.dataset)
            if config.local_rank == 0:
                print('epoch: ', epoch)
            train_loader.sampler.set_epoch(epoch)
            train_loader_iter = iter(train_loader)
            train_data = next(train_loader_iter)

        train_data = train_utils.tocuda(train_data)

        # Constant learning rate until decay_iter, exponential decay after.
        lr = min(config.train_lr * config.decay_rate ** (step - config.decay_iter), config.train_lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # run training
        steps_this_run = step - start_step
        loss_res, unusual_loss = train_step(optimizer, model, match_loss, train_data,
                                            steps_this_run, pre_avg_loss)

        # Track a reference loss level: the raw loss during the first 200
        # steps of this run, then an exponential moving average that ignores
        # rejected (unusual) steps.
        if steps_this_run <= 200:
            pre_avg_loss = loss_res['total_loss'].data
        elif not unusual_loss:
            pre_avg_loss = pre_avg_loss.data * 0.9 + loss_res['total_loss'].data * 0.1
        if unusual_loss and config.local_rank == 0:
            print('unusual loss! pre_avg_loss: ', pre_avg_loss, 'cur_loss: ', loss_res['total_loss'].data)

        # log
        if config.local_rank == 0 and step % config.log_intv == 0 and not unusual_loss:
            writer.add_scalar('TotalLoss', loss_res['total_loss'], step)
            writer.add_scalar('CorrLoss', loss_res['loss_corr'], step)
            writer.add_scalar('InCorrLoss', loss_res['loss_incorr'], step)
            writer.add_scalar('dustbin', model.module.dustbin, step)
            if config.model_name == 'SGM':
                writer.add_scalar('SeedConfLoss', loss_res['loss_seed_conf'], step)
                writer.add_scalar('MidCorrLoss', loss_res['loss_corr_mid'].sum(), step)
                writer.add_scalar('MidInCorrLoss', loss_res['loss_incorr_mid'].sum(), step)

        # validate and save
        b_save = ((step + 1) % config.save_intv) == 0
        b_validate = ((step + 1) % config.val_intv) == 0

        if b_validate:
            # Validation runs on every rank; only rank 0 logs and saves.
            (total_loss, acc_corr, acc_incorr,
             seed_precision_tower, seed_recall_tower, acc_mid) = valid(
                valid_loader, model, match_loss, config, model_config)
            if config.local_rank == 0:
                writer.add_scalar('ValidAcc', acc_corr, step)
                writer.add_scalar('ValidLoss', total_loss, step)
                if config.model_name == 'SGM':
                    for i in range(len(seed_recall_tower)):
                        writer.add_scalar('seed_conf_pre_%d' % i, seed_precision_tower[i], step)
                        # Log recall under the recall tag (previously the
                        # precision value was written here by mistake).
                        writer.add_scalar('seed_conf_recall_%d' % i, seed_recall_tower[i], step)
                    for i in range(len(acc_mid)):
                        writer.add_scalar('acc_mid%d' % i, acc_mid[i], step)
                    print('acc_corr: ', acc_corr.data, 'acc_incorr: ', acc_incorr.data,
                          'seed_conf_pre: ', seed_precision_tower.mean().data,
                          'seed_conf_recall: ', seed_recall_tower.mean().data,
                          'acc_mid: ', acc_mid.mean().data)
                else:
                    print('acc_corr: ', acc_corr.data, 'acc_incorr: ', acc_incorr.data)

                # saving best
                if acc_corr > best_acc:
                    print("Saving best model with va_res = {}".format(acc_corr))
                    best_acc = acc_corr
                    torch.save(_checkpoint_state(step, model, best_acc, optimizer),
                               os.path.join(config.log_base, 'model_best.pth'))

        if b_save:
            if config.local_rank == 0:
                torch.save(_checkpoint_state(step, model, best_acc, optimizer), checkpoint_path)

            # draw match results
            model.eval()
            with torch.no_grad():
                if config.local_rank == 0:
                    vis_dir = os.path.join(config.train_vis_folder, 'train_vis',
                                           config.log_base, str(step))
                    # makedirs creates missing parents (log_base may be a
                    # nested path) and tolerates a directory left over from
                    # an earlier run.
                    os.makedirs(vis_dir, exist_ok=True)
                    res = model(train_data)
                    dump_train_vis(res, train_data, step, config)
            model.train()

    if config.local_rank == 0:
        writer.close()