Spaces:
Running
Running
| import os | |
| import torch | |
| import torch.optim as optim | |
| from tqdm import tqdm | |
| from torch.autograd import Variable | |
| from network_v0.model import PointModel | |
| from loss_function import KeypointLoss | |
class Trainer(object):
    """Train a ``PointModel`` with ``KeypointLoss`` and manage its checkpoints.

    The trainer owns the model, the Adam optimizer and a ``MultiStepLR``
    scheduler (milestones at epochs 4 and 8). After every epoch it writes a
    checkpoint named ``<ckpt_name>-<seed>_<epoch>.pth`` into ``ckpt_dir``.
    """

    def __init__(self, config, train_loader=None):
        """Build the model, loss, optimizer and scheduler, optionally resuming.

        Args:
            config: Configuration object. This class reads ``max_epoch``,
                ``start_epoch``, ``momentum``, ``init_lr``, ``lr_factor``,
                ``display``, ``use_gpu``, ``seed``, ``gpu``, ``ckpt_dir`` and
                ``ckpt_name`` from it; it is also handed to ``KeypointLoss``.
            train_loader: Iterable of batches. Each batch is indexed with the
                keys ``"image_aug"``, ``"image"`` and ``"homography"``. May be
                ``None`` at construction time, but is required by ``train``.
        """
        self.config = config

        # data parameters
        self.train_loader = train_loader
        # Tolerate the declared default (None) instead of crashing in len().
        self.num_train = len(train_loader) if train_loader is not None else 0

        # training parameters
        self.max_epoch = config.max_epoch
        # Cast once so range() in train() works even if the config stores a
        # string (the resume check below already had to cast it).
        self.start_epoch = int(config.start_epoch)
        self.momentum = config.momentum
        self.lr = config.init_lr
        self.lr_factor = config.lr_factor
        self.display = config.display

        # misc params
        self.use_gpu = config.use_gpu
        self.random_seed = config.seed
        self.gpu = config.gpu
        self.ckpt_dir = config.ckpt_dir
        self.ckpt_name = "{}-{}".format(config.ckpt_name, config.seed)

        # build model
        self.model = PointModel(is_test=False)

        # training on GPU
        if self.use_gpu:
            torch.cuda.set_device(self.gpu)
            self.model.cuda()

        print(
            "Number of model parameters: {:,}".format(
                sum(p.numel() for p in self.model.parameters())
            )
        )

        # build loss functional
        self.loss_func = KeypointLoss(config)

        # build optimizer and scheduler
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[4, 8], gamma=self.lr_factor
        )

        # resume
        if self.start_epoch > 0:
            (
                self.start_epoch,
                self.model,
                self.optimizer,
                self.lr_scheduler,
            ) = self.load_checkpoint(
                self.start_epoch,
                self.model,
                self.optimizer,
                self.lr_scheduler,
            )
            # Keep the config in sync, as the original resume logic did.
            self.config.start_epoch = self.start_epoch

    def _current_lr(self):
        """Return the learning rate the optimizer is using right now.

        ``self.lr`` only holds the initial value; the scheduler changes the
        rate inside the optimizer's parameter groups.
        """
        return self.optimizer.param_groups[0]["lr"]

    def train(self):
        """Run epochs ``start_epoch`` .. ``max_epoch - 1``, saving after each.

        Raises:
            ValueError: If the trainer was constructed without a train loader.
        """
        if self.train_loader is None:
            raise ValueError("train() requires a train_loader")

        print("\nTrain on {} samples".format(self.num_train))

        # Only snapshot the untrained model on a fresh run; on resume this
        # would overwrite checkpoint 0 with the resumed weights.
        if self.start_epoch == 0:
            self.save_checkpoint(0, self.model, self.optimizer, self.lr_scheduler)

        for epoch in range(self.start_epoch, self.max_epoch):
            print(
                "\nEpoch: {}/{} --lr: {:.6f}".format(
                    epoch + 1, self.max_epoch, self._current_lr()
                )
            )

            # train for one epoch
            self.train_one_epoch(epoch)

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            self.save_checkpoint(
                epoch + 1, self.model, self.optimizer, self.lr_scheduler
            )

    def train_one_epoch(self, epoch):
        """Run one optimisation pass over ``train_loader``.

        Args:
            epoch: Zero-based epoch index; only used for log output.
        """
        self.model.train()
        for i, data in enumerate(tqdm(self.train_loader)):
            # Fetch the batch unconditionally so CPU training works too.
            source_img = data["image_aug"]
            target_img = data["image"]
            homography = data["homography"]
            if self.use_gpu:
                source_img = source_img.cuda()
                target_img = target_img.cuda()
                homography = homography.cuda()

            # forward propagation
            output = self.model(source_img, target_img, homography)

            # compute loss
            loss, loc_loss, desc_loss, score_loss, corres_loss = self.loss_func(output)

            # compute gradients and update
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # print training info; formatted only when it is shown, because
            # reading the scalar values forces a device synchronisation.
            if (i % self.display) == 0:
                print(
                    "Epoch:{} Iter:{} lr:{:.4f} "
                    "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} "
                    "loss={:.4f} ".format(
                        (epoch + 1),
                        i,
                        self._current_lr(),
                        loc_loss.data,
                        desc_loss.data,
                        score_loss.data,
                        corres_loss.data,
                        loss.data,
                    )
                )

    def save_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        """Write model, optimizer and scheduler state for ``epoch`` to disk.

        The file is ``<ckpt_dir>/<ckpt_name>_<epoch>.pth``; the directory is
        created if it does not exist yet.
        """
        filename = self.ckpt_name + "_" + str(epoch) + ".pth"
        os.makedirs(self.ckpt_dir, exist_ok=True)
        torch.save(
            {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
            },
            os.path.join(self.ckpt_dir, filename),
        )

    def load_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        """Restore the state saved by ``save_checkpoint`` for ``epoch``.

        Args:
            epoch: Epoch number used to build the checkpoint file name.
            model: Module whose parameters are overwritten in place.
            optimizer: Optimizer whose state is overwritten in place.
            lr_scheduler: Scheduler whose state is overwritten in place.

        Returns:
            Tuple ``(epoch, model, optimizer, lr_scheduler)`` where ``epoch``
            is the value stored in the checkpoint and the other three are the
            objects that were passed in.
        """
        filename = self.ckpt_name + "_" + str(epoch) + ".pth"
        # NOTE(security): torch.load unpickles the file; only load trusted
        # checkpoints. Tensors are mapped to CPU so that a checkpoint written
        # on a GPU host can be resumed anywhere; load_state_dict then copies
        # them onto the devices the live model/optimizer already use.
        ckpt = torch.load(
            os.path.join(self.ckpt_dir, filename), map_location="cpu"
        )
        epoch = ckpt["epoch"]
        model.load_state_dict(ckpt["model_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
        print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt["epoch"]))
        return epoch, model, optimizer, lr_scheduler