Vincentqyw
fix: roma
c74a070
raw history blame
No virus
5.06 kB
import os
import torch
import torch.optim as optim
from tqdm import tqdm
from torch.autograd import Variable
from network_v0.model import PointModel
from loss_function import KeypointLoss
class Trainer(object):
def __init__(self, config, train_loader=None):
self.config = config
# data parameters
self.train_loader = train_loader
self.num_train = len(self.train_loader)
# training parameters
self.max_epoch = config.max_epoch
self.start_epoch = config.start_epoch
self.momentum = config.momentum
self.lr = config.init_lr
self.lr_factor = config.lr_factor
self.display = config.display
# misc params
self.use_gpu = config.use_gpu
self.random_seed = config.seed
self.gpu = config.gpu
self.ckpt_dir = config.ckpt_dir
self.ckpt_name = "{}-{}".format(config.ckpt_name, config.seed)
# build model
self.model = PointModel(is_test=False)
# training on GPU
if self.use_gpu:
torch.cuda.set_device(self.gpu)
self.model.cuda()
print(
"Number of model parameters: {:,}".format(
sum([p.data.nelement() for p in self.model.parameters()])
)
)
# build loss functional
self.loss_func = KeypointLoss(config)
# build optimizer and scheduler
self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
self.optimizer, milestones=[4, 8], gamma=self.lr_factor
)
# resume
if int(self.config.start_epoch) > 0:
(
self.config.start_epoch,
self.model,
self.optimizer,
self.lr_scheduler,
) = self.load_checkpoint(
int(self.config.start_epoch),
self.model,
self.optimizer,
self.lr_scheduler,
)
def train(self):
print("\nTrain on {} samples".format(self.num_train))
self.save_checkpoint(0, self.model, self.optimizer, self.lr_scheduler)
for epoch in range(self.start_epoch, self.max_epoch):
print(
"\nEpoch: {}/{} --lr: {:.6f}".format(epoch + 1, self.max_epoch, self.lr)
)
# train for one epoch
self.train_one_epoch(epoch)
if self.lr_scheduler:
self.lr_scheduler.step()
self.save_checkpoint(
epoch + 1, self.model, self.optimizer, self.lr_scheduler
)
def train_one_epoch(self, epoch):
self.model.train()
for (i, data) in enumerate(tqdm(self.train_loader)):
if self.use_gpu:
source_img = data["image_aug"].cuda()
target_img = data["image"].cuda()
homography = data["homography"].cuda()
source_img = Variable(source_img)
target_img = Variable(target_img)
homography = Variable(homography)
# forward propogation
output = self.model(source_img, target_img, homography)
# compute loss
loss, loc_loss, desc_loss, score_loss, corres_loss = self.loss_func(output)
# compute gradients and update
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# print training info
msg_batch = (
"Epoch:{} Iter:{} lr:{:.4f} "
"loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} "
"loss={:.4f} ".format(
(epoch + 1),
i,
self.lr,
loc_loss.data,
desc_loss.data,
score_loss.data,
corres_loss.data,
loss.data,
)
)
if (i % self.display) == 0:
print(msg_batch)
return
def save_checkpoint(self, epoch, model, optimizer, lr_scheduler):
filename = self.ckpt_name + "_" + str(epoch) + ".pth"
torch.save(
{
"epoch": epoch,
"model_state": model.state_dict(),
"optimizer_state": optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
},
os.path.join(self.ckpt_dir, filename),
)
def load_checkpoint(self, epoch, model, optimizer, lr_scheduler):
filename = self.ckpt_name + "_" + str(epoch) + ".pth"
ckpt = torch.load(os.path.join(self.ckpt_dir, filename))
epoch = ckpt["epoch"]
model.load_state_dict(ckpt["model_state"])
optimizer.load_state_dict(ckpt["optimizer_state"])
lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt["epoch"]))
return epoch, model, optimizer, lr_scheduler