import torch
import torch.optim as optim
from tqdm import trange
import os
from tensorboardX import SummaryWriter
import numpy as np
import cv2

from loss import SGMLoss, SGLoss
from valid import valid, dump_train_vis

import sys

ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, ROOT_DIR)

from utils import train_utils


def train_step(optimizer, model, match_loss, data, step, pre_avg_loss):
    """Run one forward/backward pass; skip the parameter update if the loss spikes."""
    data["step"] = step
    result = model(data, test_mode=False)
    loss_res = match_loss.run(data, result)
    optimizer.zero_grad()
    loss_res["total_loss"].backward()
    # average every recorded loss tensor across distributed processes (for logging)
    for key in loss_res.keys():
        loss_res[key] = train_utils.reduce_tensor(loss_res[key], "mean")
    # apply the update only if the loss is not abnormally large (> 7x the running
    # average) once the 200-step warm-up has passed
    if loss_res["total_loss"] < 7 * pre_avg_loss or step < 200 or pre_avg_loss == 0:
        optimizer.step()
        unusual_loss = False
    else:
        optimizer.zero_grad()
        unusual_loss = True
    return loss_res, unusual_loss


def train(model, train_loader, valid_loader, config, model_config):
    """Main training loop with periodic validation, checkpointing and visualization."""
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=config.train_lr)

    if config.model_name == "SGM":
        match_loss = SGMLoss(config, model_config)
    elif config.model_name == "SG":
        match_loss = SGLoss(config, model_config)
    else:
        raise NotImplementedError

    # resume from an existing checkpoint if one is found in the log directory
    checkpoint_path = os.path.join(config.log_base, "checkpoint.pth")
    config.resume = os.path.isfile(checkpoint_path)
    if config.resume:
        if config.local_rank == 0:
            print("==> Resuming from checkpoint...")
        checkpoint = torch.load(
            checkpoint_path, map_location="cuda:{}".format(config.local_rank)
        )
        model.load_state_dict(checkpoint["state_dict"])
        best_acc = checkpoint["best_acc"]
        start_step = checkpoint["step"]
        optimizer.load_state_dict(checkpoint["optimizer"])
    else:
        best_acc = -1
        start_step = 0

    train_loader_iter = iter(train_loader)
    if config.local_rank == 0:
        writer = SummaryWriter(os.path.join(config.log_base, "log_file"))
    # seed the distributed sampler with the epoch corresponding to start_step
    train_loader.sampler.set_epoch(
        start_step * config.train_batch_size // len(train_loader.dataset)
    )
    pre_avg_loss = 0

    progress_bar = (
        trange(start_step, config.train_iter, ncols=config.tqdm_width)
        if config.local_rank == 0
        else range(start_step, config.train_iter)
    )
    for step in progress_bar:
        try:
            train_data = next(train_loader_iter)
        except StopIteration:
            # loader exhausted: start a new epoch and reshuffle the sampler
            if config.local_rank == 0:
                print(
                    "epoch: ",
                    step * config.train_batch_size // len(train_loader.dataset),
                )
            train_loader.sampler.set_epoch(
                step * config.train_batch_size // len(train_loader.dataset)
            )
            train_loader_iter = iter(train_loader)
            train_data = next(train_loader_iter)
        train_data = train_utils.tocuda(train_data)

        # keep the base learning rate until decay_iter, then decay it exponentially
        lr = min(
            config.train_lr * config.decay_rate ** (step - config.decay_iter),
            config.train_lr,
        )
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # run one training step
        loss_res, unusual_loss = train_step(
            optimizer, model, match_loss, train_data, step - start_step, pre_avg_loss
        )
        # track an exponential moving average of the loss after the warm-up phase
        if (step - start_step) <= 200:
            pre_avg_loss = loss_res["total_loss"].data
        if (step - start_step) > 200 and not unusual_loss:
            pre_avg_loss = pre_avg_loss.data * 0.9 + loss_res["total_loss"].data * 0.1
        if unusual_loss and config.local_rank == 0:
            print(
                "unusual loss! pre_avg_loss: ",
                pre_avg_loss,
                "cur_loss: ",
                loss_res["total_loss"].data,
            )

        # log training losses
        if config.local_rank == 0 and step % config.log_intv == 0 and not unusual_loss:
            writer.add_scalar("TotalLoss", loss_res["total_loss"], step)
            writer.add_scalar("CorrLoss", loss_res["loss_corr"], step)
            writer.add_scalar("InCorrLoss", loss_res["loss_incorr"], step)
            writer.add_scalar("dustbin", model.module.dustbin, step)
            if config.model_name == "SGM":
                writer.add_scalar("SeedConfLoss", loss_res["loss_seed_conf"], step)
                writer.add_scalar("MidCorrLoss", loss_res["loss_corr_mid"].sum(), step)
                writer.add_scalar(
                    "MidInCorrLoss", loss_res["loss_incorr_mid"].sum(), step
                )

        # validate and save
        b_save = ((step + 1) % config.save_intv) == 0
        b_validate = ((step + 1) % config.val_intv) == 0

        if b_validate:
            (
                total_loss,
                acc_corr,
                acc_incorr,
                seed_precision_tower,
                seed_recall_tower,
                acc_mid,
            ) = valid(valid_loader, model, match_loss, config, model_config)
            if config.local_rank == 0:
                writer.add_scalar("ValidAcc", acc_corr, step)
                writer.add_scalar("ValidLoss", total_loss, step)
                if config.model_name == "SGM":
                    for i in range(len(seed_recall_tower)):
                        writer.add_scalar(
                            "seed_conf_pre_%d" % i, seed_precision_tower[i], step
                        )
                        writer.add_scalar(
                            "seed_conf_recall_%d" % i, seed_recall_tower[i], step
                        )
                    for i in range(len(acc_mid)):
                        writer.add_scalar("acc_mid%d" % i, acc_mid[i], step)
                    print(
                        "acc_corr: ",
                        acc_corr.data,
                        "acc_incorr: ",
                        acc_incorr.data,
                        "seed_conf_pre: ",
                        seed_precision_tower.mean().data,
                        "seed_conf_recall: ",
                        seed_recall_tower.mean().data,
                        "acc_mid: ",
                        acc_mid.mean().data,
                    )
                else:
                    print("acc_corr: ", acc_corr.data, "acc_incorr: ", acc_incorr.data)

                # save the best model so far
                if acc_corr > best_acc:
                    print("Saving best model with va_res = {}".format(acc_corr))
                    best_acc = acc_corr
                    save_dict = {
                        "step": step + 1,
                        "state_dict": model.state_dict(),
                        "best_acc": best_acc,
                        "optimizer": optimizer.state_dict(),
                    }
                    torch.save(
                        save_dict, os.path.join(config.log_base, "model_best.pth")
                    )

        if b_save:
            if config.local_rank == 0:
                save_dict = {
                    "step": step + 1,
                    "state_dict": model.state_dict(),
                    "best_acc": best_acc,
                    "optimizer": optimizer.state_dict(),
                }
                torch.save(save_dict, checkpoint_path)

            # draw match results for visual inspection
            model.eval()
            with torch.no_grad():
                if config.local_rank == 0:
                    if not os.path.exists(
                        os.path.join(config.train_vis_folder, "train_vis")
                    ):
                        os.mkdir(os.path.join(config.train_vis_folder, "train_vis"))
                    if not os.path.exists(
                        os.path.join(
                            config.train_vis_folder, "train_vis", config.log_base
                        )
                    ):
                        os.mkdir(
                            os.path.join(
                                config.train_vis_folder, "train_vis", config.log_base
                            )
                        )
                    os.mkdir(
                        os.path.join(
                            config.train_vis_folder,
                            "train_vis",
                            config.log_base,
                            str(step),
                        )
                    )
                res = model(train_data)
                dump_train_vis(res, train_data, step, config)
            model.train()

    if config.local_rank == 0:
        writer.close()