# Spaces: Running on Zero
| import torch | |
| import torch.optim as optim | |
| from tqdm import trange | |
| import os | |
| from tensorboardX import SummaryWriter | |
| import numpy as np | |
| import cv2 | |
| from loss import SGMLoss,SGLoss | |
| from valid import valid,dump_train_vis | |
| import sys | |
| ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | |
| sys.path.insert(0, ROOT_DIR) | |
| from utils import train_utils | |
def train_step(optimizer, model, match_loss, data, step, pre_avg_loss):
    """Run one optimization step with a guard against exploding losses.

    Performs forward/backward, all-reduces every recorded loss tensor for
    reporting, and applies the optimizer update unless the (reduced) total
    loss looks like an outlier compared to the running average.

    Args:
        optimizer: optimizer holding the model parameters.
        model: matching network; called as model(data, test_mode=False).
        match_loss: loss object whose .run(data, result) returns a dict of
            loss tensors including 'total_loss'.
        data: batch dict (mutated: 'step' is written into it).
        step: step counter relative to the start of this run.
        pre_avg_loss: running average of past total losses (0 = uninitialized).

    Returns:
        (loss_res, unusual_loss): the reduced loss dict and a flag telling
        the caller whether the update was skipped as an outlier.
    """
    data['step'] = step
    prediction = model(data, test_mode=False)
    loss_res = match_loss.run(data, prediction)

    optimizer.zero_grad()
    loss_res['total_loss'].backward()

    # All-reduce every recorded tensor so each rank reports the global mean.
    for name in loss_res:
        loss_res[name] = train_utils.reduce_tensor(loss_res[name], 'mean')

    # Warm-up steps (and an uninitialized running average) always update;
    # afterwards, skip any step whose loss exceeds 7x the running average.
    is_outlier = not (step < 200 or pre_avg_loss == 0
                      or loss_res['total_loss'] < 7 * pre_avg_loss)
    if is_outlier:
        optimizer.zero_grad()  # discard the suspicious gradients
    else:
        optimizer.step()
    return loss_res, is_outlier
def train(model, train_loader, valid_loader, config, model_config):
    """Run the main (distributed) training loop.

    Trains `model` for `config.train_iter` steps, periodically validating,
    checkpointing, and — on rank 0 only — logging to TensorBoard and dumping
    visualizations of training matches.

    Args:
        model: DDP-wrapped matching network (exposes `model.module.dustbin`).
        train_loader: training DataLoader with a distributed sampler.
        valid_loader: validation DataLoader.
        config: run configuration (lr schedule, intervals, paths, rank, ...).
        model_config: model-specific configuration forwarded to the loss.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=config.train_lr)

    # Pick the loss matching the model variant.
    if config.model_name == 'SGM':
        match_loss = SGMLoss(config, model_config)
    elif config.model_name == 'SG':
        match_loss = SGLoss(config, model_config)
    else:
        raise NotImplementedError

    checkpoint_path = os.path.join(config.log_base, 'checkpoint.pth')
    config.resume = os.path.isfile(checkpoint_path)
    if config.resume:
        if config.local_rank == 0:
            print('==> Resuming from checkpoint..')
        checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(config.local_rank))
        model.load_state_dict(checkpoint['state_dict'])
        best_acc = checkpoint['best_acc']
        start_step = checkpoint['step']
        optimizer.load_state_dict(checkpoint['optimizer'])
    else:
        best_acc = -1
        start_step = 0

    train_loader_iter = iter(train_loader)
    if config.local_rank == 0:
        writer = SummaryWriter(os.path.join(config.log_base, 'log_file'))

    # Restore the sampler epoch so shuffling continues from the resumed step.
    train_loader.sampler.set_epoch(start_step * config.train_batch_size // len(train_loader.dataset))
    pre_avg_loss = 0

    progress_bar = trange(start_step, config.train_iter, ncols=config.tqdm_width) \
        if config.local_rank == 0 else range(start_step, config.train_iter)
    for step in progress_bar:
        # Fetch the next batch, restarting the iterator at epoch boundaries.
        try:
            train_data = next(train_loader_iter)
        except StopIteration:
            if config.local_rank == 0:
                print('epoch: ', step * config.train_batch_size // len(train_loader.dataset))
            train_loader.sampler.set_epoch(step * config.train_batch_size // len(train_loader.dataset))
            train_loader_iter = iter(train_loader)
            train_data = next(train_loader_iter)
        train_data = train_utils.tocuda(train_data)

        # Exponential LR decay after config.decay_iter, capped at the base LR
        # (before decay_iter the exponent is negative, so min() keeps train_lr).
        lr = min(config.train_lr * config.decay_rate ** (step - config.decay_iter), config.train_lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # run training
        loss_res, unusual_loss = train_step(optimizer, model, match_loss, train_data,
                                            step - start_step, pre_avg_loss)

        # Maintain an EMA of the total loss (plain copy for the first 200
        # steps); steps rejected as outliers do not pollute the average.
        if (step - start_step) <= 200:
            pre_avg_loss = loss_res['total_loss'].data
        if (step - start_step) > 200 and not unusual_loss:
            pre_avg_loss = pre_avg_loss.data * 0.9 + loss_res['total_loss'].data * 0.1
        if unusual_loss and config.local_rank == 0:
            print('unusual loss! pre_avg_loss: ', pre_avg_loss, 'cur_loss: ', loss_res['total_loss'].data)

        # log
        if config.local_rank == 0 and step % config.log_intv == 0 and not unusual_loss:
            writer.add_scalar('TotalLoss', loss_res['total_loss'], step)
            writer.add_scalar('CorrLoss', loss_res['loss_corr'], step)
            writer.add_scalar('InCorrLoss', loss_res['loss_incorr'], step)
            writer.add_scalar('dustbin', model.module.dustbin, step)
            if config.model_name == 'SGM':
                writer.add_scalar('SeedConfLoss', loss_res['loss_seed_conf'], step)
                writer.add_scalar('MidCorrLoss', loss_res['loss_corr_mid'].sum(), step)
                writer.add_scalar('MidInCorrLoss', loss_res['loss_incorr_mid'].sum(), step)

        # valid and save
        b_save = ((step + 1) % config.save_intv) == 0
        b_validate = ((step + 1) % config.val_intv) == 0
        if b_validate:
            total_loss, acc_corr, acc_incorr, seed_precision_tower, seed_recall_tower, acc_mid = \
                valid(valid_loader, model, match_loss, config, model_config)
            if config.local_rank == 0:
                writer.add_scalar('ValidAcc', acc_corr, step)
                writer.add_scalar('ValidLoss', total_loss, step)
                if config.model_name == 'SGM':
                    for i in range(len(seed_recall_tower)):
                        writer.add_scalar('seed_conf_pre_%d' % i, seed_precision_tower[i], step)
                        # BUGFIX: the recall tag previously logged precision
                        # values (seed_precision_tower) — log recall instead.
                        writer.add_scalar('seed_conf_recall_%d' % i, seed_recall_tower[i], step)
                    for i in range(len(acc_mid)):
                        writer.add_scalar('acc_mid%d' % i, acc_mid[i], step)
                    print('acc_corr: ', acc_corr.data, 'acc_incorr: ', acc_incorr.data,
                          'seed_conf_pre: ', seed_precision_tower.mean().data,
                          'seed_conf_recall: ', seed_recall_tower.mean().data,
                          'acc_mid: ', acc_mid.mean().data)
                else:
                    print('acc_corr: ', acc_corr.data, 'acc_incorr: ', acc_incorr.data)

                # saving best (removed a redundant save_dict.update(save_dict) no-op)
                if acc_corr > best_acc:
                    print("Saving best model with va_res = {}".format(acc_corr))
                    best_acc = acc_corr
                    save_dict = {'step': step + 1,
                                 'state_dict': model.state_dict(),
                                 'best_acc': best_acc,
                                 'optimizer': optimizer.state_dict()}
                    torch.save(save_dict, os.path.join(config.log_base, 'model_best.pth'))

        if b_save:
            if config.local_rank == 0:
                save_dict = {'step': step + 1,
                             'state_dict': model.state_dict(),
                             'best_acc': best_acc,
                             'optimizer': optimizer.state_dict()}
                torch.save(save_dict, checkpoint_path)

            # draw match results (forward pass runs on every rank; only rank 0
            # prepares the output folders and writes visualizations)
            model.eval()
            with torch.no_grad():
                if config.local_rank == 0:
                    # makedirs(exist_ok=True) is race-free and survives
                    # resuming at a step whose folder already exists
                    # (os.mkdir would raise FileExistsError).
                    os.makedirs(os.path.join(config.train_vis_folder, 'train_vis'), exist_ok=True)
                    os.makedirs(os.path.join(config.train_vis_folder, 'train_vis', config.log_base), exist_ok=True)
                    os.makedirs(os.path.join(config.train_vis_folder, 'train_vis', config.log_base, str(step)), exist_ok=True)
                res = model(train_data)
                dump_train_vis(res, train_data, step, config)
            model.train()

    if config.local_rank == 0:
        writer.close()