Vincentqyw
update: features and matchers
a80d6bb
raw
history blame
No virus
4.74 kB
import os
import torch
import torch.optim as optim
from tqdm import tqdm
from torch.autograd import Variable
from network_v0.model import PointModel
from loss_function import KeypointLoss
class Trainer(object):
    """Train a ``PointModel`` with ``KeypointLoss`` using Adam and a MultiStepLR schedule.

    Args:
        config: Configuration object. Attributes read here: ``max_epoch``,
            ``start_epoch``, ``momentum``, ``init_lr``, ``lr_factor``,
            ``display``, ``use_gpu``, ``seed``, ``gpu``, ``ckpt_dir``,
            ``ckpt_name`` and, optionally, ``lr_milestones`` (defaults to
            ``[4, 8]``). The whole object is also passed to ``KeypointLoss``.
        train_loader: Iterable of batches. Each batch is indexed with the keys
            ``'image_aug'``, ``'image'`` and ``'homography'``.
    """

    def __init__(self, config, train_loader=None):
        self.config = config

        # data parameters
        self.train_loader = train_loader
        # NOTE(review): len() of a DataLoader is presumably the number of
        # batches rather than samples — confirm against the loader type.
        self.num_train = len(train_loader) if train_loader is not None else 0

        # training parameters
        self.max_epoch = config.max_epoch
        self.start_epoch = int(config.start_epoch)
        self.momentum = config.momentum  # kept as an attribute; Adam below does not use it
        self.lr = config.init_lr
        self.lr_factor = config.lr_factor
        self.display = config.display

        # misc params
        self.use_gpu = config.use_gpu
        self.random_seed = config.seed
        self.gpu = config.gpu
        self.ckpt_dir = config.ckpt_dir
        self.ckpt_name = '{}-{}'.format(config.ckpt_name, config.seed)

        # build model (optionally on the selected GPU)
        self.model = PointModel(is_test=False)
        if self.use_gpu:
            torch.cuda.set_device(self.gpu)
            self.model.cuda()
        print('Number of model parameters: {:,}'.format(
            sum(p.data.nelement() for p in self.model.parameters())))

        # build loss functional
        self.loss_func = KeypointLoss(config)

        # build optimizer and scheduler; milestones are configurable but keep
        # the historical [4, 8] default.
        milestones = getattr(config, 'lr_milestones', None) or [4, 8]
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=milestones, gamma=self.lr_factor)

        # resume from the checkpoint saved at the end of `start_epoch`
        if self.start_epoch > 0:
            (self.config.start_epoch, self.model, self.optimizer,
             self.lr_scheduler) = self.load_checkpoint(
                self.start_epoch, self.model, self.optimizer, self.lr_scheduler)
            self.start_epoch = int(self.config.start_epoch)
            # The restored optimizer carries the (possibly decayed) LR.
            self.lr = self._current_lr()

    def _current_lr(self):
        """Return the learning rate of the optimizer's first parameter group."""
        return self.optimizer.param_groups[0]['lr']

    def _to_device(self, tensor):
        """Move ``tensor`` to the current CUDA device when ``use_gpu`` is set."""
        return tensor.cuda() if self.use_gpu else tensor

    def train(self):
        """Train epochs ``start_epoch`` .. ``max_epoch - 1``, checkpointing after each."""
        print("\nTrain on {} samples".format(self.num_train))
        # Snapshot the initial weights only on a fresh run: when resuming,
        # this would overwrite the epoch-0 file with already-trained weights.
        if self.start_epoch <= 0:
            self.save_checkpoint(0, self.model, self.optimizer, self.lr_scheduler)
        for epoch in range(self.start_epoch, self.max_epoch):
            # Read the LR back from the optimizer so the log reflects scheduler decay.
            self.lr = self._current_lr()
            print("\nEpoch: {}/{} --lr: {:.6f}".format(epoch + 1, self.max_epoch, self.lr))
            # train for one epoch
            self.train_one_epoch(epoch)
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.save_checkpoint(epoch + 1, self.model, self.optimizer, self.lr_scheduler)

    def train_one_epoch(self, epoch):
        """Run one optimization pass over ``train_loader``.

        Args:
            epoch: Zero-based epoch index; it is only used (as ``epoch + 1``)
                in the progress log.
        """
        self.model.train()
        for i, data in enumerate(tqdm(self.train_loader)):
            # Works on CPU too; previously these names were bound only when use_gpu was set.
            source_img = self._to_device(data['image_aug'])
            target_img = self._to_device(data['image'])
            homography = self._to_device(data['homography'])

            # forward propagation
            output = self.model(source_img, target_img, homography)
            # compute loss
            loss, loc_loss, desc_loss, score_loss, corres_loss = self.loss_func(output)

            # compute gradients and update
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Build the message only when it is printed: .item() forces a
            # device-to-host sync. A falsy `display` disables logging instead
            # of dividing by zero.
            if self.display and i % self.display == 0:
                msg_batch = ("Epoch:{} Iter:{} lr:{:.4f} "
                             "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} "
                             "loss={:.4f} ").format(
                    epoch + 1, i, self._current_lr(),
                    loc_loss.item(), desc_loss.item(), score_loss.item(),
                    corres_loss.item(), loss.item())
                print(msg_batch)

    def save_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        """Save model/optimizer/scheduler state to ``<ckpt_dir>/<ckpt_name>_<epoch>.pth``."""
        filename = '{}_{}.pth'.format(self.ckpt_name, epoch)
        if self.ckpt_dir:
            os.makedirs(self.ckpt_dir, exist_ok=True)
        torch.save(
            {'epoch': epoch,
             'model_state': model.state_dict(),
             'optimizer_state': optimizer.state_dict(),
             'lr_scheduler': lr_scheduler.state_dict() if lr_scheduler is not None else None},
            os.path.join(self.ckpt_dir, filename))

    def load_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        """Restore state saved by ``save_checkpoint`` for ``epoch`` into the given objects.

        Returns:
            Tuple ``(epoch, model, optimizer, lr_scheduler)``. ``epoch`` is the
            value stored in the checkpoint; the other three are the same
            objects that were passed in, now updated in place.
        """
        filename = '{}_{}.pth'.format(self.ckpt_name, epoch)
        # Load onto CPU first. load_state_dict then copies the values onto
        # whatever device the model/optimizer parameters already live on.
        # NOTE: torch.load unpickles its input — only load trusted checkpoints.
        ckpt = torch.load(os.path.join(self.ckpt_dir, filename), map_location='cpu')
        epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model_state'])
        optimizer.load_state_dict(ckpt['optimizer_state'])
        if lr_scheduler is not None and ckpt.get('lr_scheduler') is not None:
            lr_scheduler.load_state_dict(ckpt['lr_scheduler'])
        print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt['epoch']))
        return epoch, model, optimizer, lr_scheduler