import os
import torch
import torch.optim as optim
from tqdm import tqdm
from torch.autograd import Variable  # kept for compatibility; no longer needed (no-op since PyTorch 0.4)
from network_v0.model import PointModel
from loss_function import KeypointLoss


class Trainer(object):
    """Train a ``PointModel`` with ``KeypointLoss`` and save a checkpoint per epoch.

    Args:
        config: Configuration object. Attributes read here: ``max_epoch``,
            ``start_epoch``, ``momentum``, ``init_lr``, ``lr_factor``,
            ``display``, ``use_gpu``, ``seed``, ``gpu``, ``ckpt_dir``,
            ``ckpt_name`` and, optionally, ``lr_milestones`` (defaults to
            ``[4, 8]``). The object is also passed to ``KeypointLoss``.
        train_loader: Iterable yielding dicts with the keys ``'image_aug'``,
            ``'image'`` and ``'homography'``; it must support ``len()``.
    """

    def __init__(self, config, train_loader=None):
        self.config = config

        # data parameters
        self.train_loader = train_loader
        # NOTE(review): for a torch DataLoader len() counts batches, not samples,
        # although train() reports this number as "samples".
        self.num_train = len(self.train_loader)

        # training parameters
        self.max_epoch = config.max_epoch
        # start_epoch may arrive as a string (e.g. from a CLI); normalise to int
        self.start_epoch = int(config.start_epoch)
        self.momentum = config.momentum  # stored only; Adam below does not use it
        self.lr = config.init_lr
        self.lr_factor = config.lr_factor
        self.display = config.display

        # misc params
        self.use_gpu = config.use_gpu
        self.random_seed = config.seed
        self.gpu = config.gpu
        self.ckpt_dir = config.ckpt_dir
        self.ckpt_name = '{}-{}'.format(config.ckpt_name, config.seed)

        # build model
        self.model = PointModel(is_test=False)

        # training on GPU
        if self.use_gpu:
            torch.cuda.set_device(self.gpu)
            self.model.cuda()

        num_params = sum(p.numel() for p in self.model.parameters())
        print('Number of model parameters: {:,}'.format(num_params))

        # build loss functional
        self.loss_func = KeypointLoss(config)

        # build optimizer and scheduler
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        milestones = list(getattr(config, 'lr_milestones', [4, 8]))
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=milestones, gamma=self.lr_factor)

        # resume
        if self.start_epoch > 0:
            (self.config.start_epoch, self.model, self.optimizer,
             self.lr_scheduler) = self.load_checkpoint(
                self.start_epoch, self.model, self.optimizer, self.lr_scheduler)
            # the training loop starts from the epoch recorded in the checkpoint
            self.start_epoch = int(self.config.start_epoch)
            # the restored optimizer state carries the (possibly decayed) lr
            self.lr = self.optimizer.param_groups[0]['lr']

    def train(self):
        """Train from ``self.start_epoch`` up to ``self.max_epoch``.

        A checkpoint is written after every epoch, named by the number of
        completed epochs. The untrained model is saved as epoch 0 only on a
        fresh run, so resuming does not overwrite that file with trained weights.
        """
        print("\nTrain on {} samples".format(self.num_train))
        if self.start_epoch <= 0:
            self.save_checkpoint(0, self.model, self.optimizer, self.lr_scheduler)
        for epoch in range(self.start_epoch, self.max_epoch):
            # read the live lr so the log reflects scheduler decay
            self.lr = self.optimizer.param_groups[0]['lr']
            print("\nEpoch: {}/{} --lr: {:.6f}".format(epoch + 1, self.max_epoch, self.lr))
            # train for one epoch
            self.train_one_epoch(epoch)
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.save_checkpoint(epoch + 1, self.model, self.optimizer, self.lr_scheduler)

    def train_one_epoch(self, epoch):
        """Run one optimisation pass over ``self.train_loader``.

        Args:
            epoch: Zero-based epoch index; used only for logging (shown 1-based).
        """
        self.model.train()
        for i, data in enumerate(tqdm(self.train_loader)):
            source_img = self._to_device(data['image_aug'])
            target_img = self._to_device(data['image'])
            homography = self._to_device(data['homography'])

            # forward propagation
            output = self.model(source_img, target_img, homography)

            # compute loss
            loss, loc_loss, desc_loss, score_loss, corres_loss = self.loss_func(output)

            # compute gradients and update
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Format only when the message is shown: .item() synchronises with
            # the device, so doing it every iteration would slow training.
            # A falsy `display` disables batch logging instead of dividing by 0.
            if self.display and i % self.display == 0:
                lr = self.optimizer.param_groups[0]['lr']
                msg_batch = (
                    "Epoch:{} Iter:{} lr:{:.6f} "
                    "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} "
                    "loss={:.4f} "
                ).format(epoch + 1, i, lr, loc_loss.item(), desc_loss.item(),
                         score_loss.item(), corres_loss.item(), loss.item())
                # tqdm.write keeps the progress bar intact
                tqdm.write(msg_batch)

    def _to_device(self, tensor):
        """Return ``tensor`` on the current CUDA device when GPU training is on, else unchanged."""
        return tensor.cuda() if self.use_gpu else tensor

    def _ckpt_filename(self, epoch):
        """Return the checkpoint file name ``<ckpt_name>_<epoch>.pth``."""
        return '{}_{}.pth'.format(self.ckpt_name, epoch)

    def save_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        """Save model, optimizer and scheduler state for ``epoch`` into ``self.ckpt_dir``.

        The checkpoint directory is created if it does not exist yet.
        """
        if self.ckpt_dir:
            os.makedirs(self.ckpt_dir, exist_ok=True)
        state = {
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict() if lr_scheduler is not None else None,
        }
        torch.save(state, os.path.join(self.ckpt_dir, self._ckpt_filename(epoch)))

    def load_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        """Restore the state written by ``save_checkpoint`` for ``epoch``.

        Returns:
            ``(epoch, model, optimizer, lr_scheduler)``: ``epoch`` is the value
            stored in the checkpoint; the other objects are the arguments,
            updated in place via ``load_state_dict``.
        """
        filename = self._ckpt_filename(epoch)
        # Map tensors to CPU when not training on GPU so GPU-saved checkpoints
        # can still be resumed on a CPU-only machine.
        map_location = None if self.use_gpu else 'cpu'
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(os.path.join(self.ckpt_dir, filename), map_location=map_location)
        model.load_state_dict(ckpt['model_state'])
        optimizer.load_state_dict(ckpt['optimizer_state'])
        if lr_scheduler is not None and ckpt.get('lr_scheduler') is not None:
            lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
        print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt['epoch']))
        return ckpt['epoch'], model, optimizer, lr_scheduler