import os

import torch
import torch.optim as optim
from tqdm import tqdm

from network_v0.model import PointModel
from loss_function import KeypointLoss


class Trainer(object):
    def __init__(self, config, train_loader=None):
        self.config = config

        # data parameters
        self.train_loader = train_loader
        self.num_train = len(self.train_loader)

        # training parameters
        self.max_epoch = config.max_epoch
        self.start_epoch = config.start_epoch
        self.momentum = config.momentum
        self.lr = config.init_lr
        self.lr_factor = config.lr_factor
        self.display = config.display

        # misc params
        self.use_gpu = config.use_gpu
        self.random_seed = config.seed
        self.gpu = config.gpu
        self.ckpt_dir = config.ckpt_dir
        self.ckpt_name = "{}-{}".format(config.ckpt_name, config.seed)

        # build model
        self.model = PointModel(is_test=False)

        # training on GPU
        if self.use_gpu:
            torch.cuda.set_device(self.gpu)
            self.model.cuda()

        print(
            "Number of model parameters: {:,}".format(
                sum(p.data.nelement() for p in self.model.parameters())
            )
        )

        # build loss functional
        self.loss_func = KeypointLoss(config)

        # build optimizer and scheduler
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[4, 8], gamma=self.lr_factor
        )

        # resume from a saved checkpoint if a start epoch is given
        if int(self.config.start_epoch) > 0:
            (
                self.config.start_epoch,
                self.model,
                self.optimizer,
                self.lr_scheduler,
            ) = self.load_checkpoint(
                int(self.config.start_epoch),
                self.model,
                self.optimizer,
                self.lr_scheduler,
            )

    def train(self):
        print("\nTrain on {} batches".format(self.num_train))
        self.save_checkpoint(0, self.model, self.optimizer, self.lr_scheduler)
        for epoch in range(self.start_epoch, self.max_epoch):
            # report the learning rate actually used this epoch (the scheduler updates it)
            self.lr = self.optimizer.param_groups[0]["lr"]
            print(
                "\nEpoch: {}/{} --lr: {:.6f}".format(epoch + 1, self.max_epoch, self.lr)
            )

            # train for one epoch
            self.train_one_epoch(epoch)

            if self.lr_scheduler:
                self.lr_scheduler.step()

            self.save_checkpoint(
                epoch + 1, self.model, self.optimizer, self.lr_scheduler
            )

    def train_one_epoch(self, epoch):
        self.model.train()
        for i, data in enumerate(tqdm(self.train_loader)):
            # move the batch to the GPU when available
            # (Variable wrappers are no longer needed since PyTorch 0.4)
            if self.use_gpu:
                source_img = data["image_aug"].cuda()
                target_img = data["image"].cuda()
                homography = data["homography"].cuda()
            else:
                source_img = data["image_aug"]
                target_img = data["image"]
                homography = data["homography"]

            # forward propagation
            output = self.model(source_img, target_img, homography)

            # compute loss
            loss, loc_loss, desc_loss, score_loss, corres_loss = self.loss_func(output)

            # compute gradients and update
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # print training info
            msg_batch = (
                "Epoch:{} Iter:{} lr:{:.4f} "
                "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} "
                "loss={:.4f} ".format(
                    (epoch + 1),
                    i,
                    self.lr,
                    loc_loss.item(),
                    desc_loss.item(),
                    score_loss.item(),
                    corres_loss.item(),
                    loss.item(),
                )
            )
            if (i % self.display) == 0:
                print(msg_batch)
        return

    def save_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        filename = self.ckpt_name + "_" + str(epoch) + ".pth"
        torch.save(
            {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
            },
            os.path.join(self.ckpt_dir, filename),
        )

    def load_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        filename = self.ckpt_name + "_" + str(epoch) + ".pth"
        ckpt = torch.load(os.path.join(self.ckpt_dir, filename))
        epoch = ckpt["epoch"]
        model.load_state_dict(ckpt["model_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
        print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt["epoch"]))
        return epoch, model, optimizer, lr_scheduler
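

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this Trainer could be driven, assuming a config object
# exposing the attributes read in __init__ (max_epoch, start_epoch, momentum,
# init_lr, lr_factor, display, use_gpu, seed, gpu, ckpt_dir, ckpt_name) and a
# DataLoader yielding dicts with "image", "image_aug", and "homography" tensors.
# The SimpleNamespace config values and the dataset/loader names below are
# hypothetical; the real project presumably builds its config elsewhere.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        max_epoch=12,
        start_epoch=0,
        momentum=0.9,
        init_lr=1e-3,
        lr_factor=0.5,
        display=100,
        use_gpu=torch.cuda.is_available(),
        seed=0,
        gpu=0,
        ckpt_dir="./checkpoints",
        ckpt_name="pointmodel",
    )
    os.makedirs(config.ckpt_dir, exist_ok=True)

    # `train_dataset` is assumed to be a Dataset returning the dict described above;
    # uncomment once such a dataset is available.
    # from torch.utils.data import DataLoader
    # train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
    # trainer = Trainer(config, train_loader=train_loader)
    # trainer.train()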