from lib.net import NormalNet
from lib.common.train_util import *
import logging
import torch
import numpy as np
from torch import nn
from skimage.transform import resize
import pytorch_lightning as pl

torch.backends.cudnn.benchmark = True
logging.getLogger("lightning").setLevel(logging.ERROR)


class Normal(pl.LightningModule):
    """Lightning module that trains the ``NormalNet`` normal-map predictor.

    ``NormalNet`` exposes two sub-networks, ``netF`` and ``netB`` (presumably
    front / back normal branches, matching the ``normal_F`` / ``normal_B``
    targets). Each branch gets its own Adam optimizer and MultiStepLR
    scheduler, and both are stepped manually in ``training_step``.
    """

    def __init__(self, cfg):
        super(Normal, self).__init__()
        self.cfg = cfg
        self.batch_size = self.cfg.batch_size
        self.lr_N = self.cfg.lr_N

        # Filled in by configure_optimizers; read by training_epoch_end to log LRs.
        self.schedulers = []

        self.netG = NormalNet(self.cfg, error_term=nn.SmoothL1Loss())

        # Batch keys fed to the network: the first element of each cfg entry.
        self.in_nml = [item[0] for item in cfg.net.in_nml]

    def get_progress_bar_dict(self):
        """Return Lightning's progress-bar metrics without the ``v_num`` entry."""
        tqdm_dict = super().get_progress_bar_dict()
        tqdm_dict.pop("v_num", None)
        return tqdm_dict

    # Training related
    def configure_optimizers(self):
        """Create one Adam optimizer + MultiStepLR scheduler per branch.

        Returns:
            tuple: ``([optimizer_F, optimizer_B], [scheduler_F, scheduler_B])``.
            The scheduler list is also stored on ``self.schedulers``.
        """
        weight_decay = self.cfg.weight_decay

        optimizer_N_F = torch.optim.Adam(
            self.netG.netF.parameters(), lr=self.lr_N, weight_decay=weight_decay
        )
        optimizer_N_B = torch.optim.Adam(
            self.netG.netB.parameters(), lr=self.lr_N, weight_decay=weight_decay
        )

        scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )
        scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )

        self.schedulers = [scheduler_N_F, scheduler_N_B]
        optims = [optimizer_N_F, optimizer_N_B]

        return optims, self.schedulers

    def render_func(self, render_tensor):
        """Tile the first sample of every tensor in ``render_tensor`` side by side.

        Each entry is indexed with ``[0]`` and transposed ``(1, 2, 0)``, so it
        is expected to be a batched, channel-first tensor (``(B, C, H, W)``).
        Values are mapped with ``(x + 1) / 2``; this assumes inputs lie in
        ``[-1, 1]`` (TODO confirm for every key). Every tile is resized to a
        square whose side is ``render_tensor["image"].shape[2]``.

        Returns:
            np.ndarray: an HWC image with the tiles concatenated along the
            width axis, in dict iteration order.
        """
        height = render_tensor["image"].shape[2]
        result_list = []

        for name in render_tensor.keys():
            sample = render_tensor[name].cpu().numpy()[0]
            sample = ((sample + 1.0) / 2.0).transpose(1, 2, 0)
            result_list.append(
                resize(sample, (height, height), anti_aliasing=True)
            )

        return np.concatenate(result_list, axis=1)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one manual optimization step for both the F and B branches.

        ``optimizer_idx`` is part of Lightning's multi-optimizer signature but
        is unused: both optimizers are zeroed, back-propagated and stepped here.
        Every ``cfg.freq_show_train`` batches (excluding batch 0) an eval-mode
        prediction is rendered and logged as an image.
        """
        export_cfg(self.logger, self.cfg)

        # retrieve the data
        in_tensor = {name: batch[name] for name in self.in_nml}
        FB_tensor = {"normal_F": batch["normal_F"], "normal_B": batch["normal_B"]}

        self.netG.train()

        preds_F, preds_B = self.netG(in_tensor)
        error_NF, error_NB = self.netG.get_norm_error(preds_F, preds_B, FB_tensor)

        (opt_nf, opt_nb) = self.optimizers()

        opt_nf.zero_grad()
        opt_nb.zero_grad()

        self.manual_backward(error_NF, opt_nf)
        self.manual_backward(error_NB, opt_nb)

        opt_nf.step()
        opt_nb.step()

        if batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0:
            # Visualize with inference-mode behaviour, then switch back so the
            # network is not left in eval mode after a training step.
            self.netG.eval()
            with torch.no_grad():
                nmlF, nmlB = self.netG(in_tensor)
                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
                result_array = self.render_func(in_tensor)

            self.logger.experiment.add_image(
                tag=f"Normal-train/{self.global_step}",
                img_tensor=result_array.transpose(2, 0, 1),
                global_step=self.global_step,
            )
            self.netG.train()

        # metrics processing
        metrics_log = {
            "train_loss-NF": error_NF.item(),
            "train_loss-NB": error_NB.item(),
        }

        tf_log = tf_log_convert(metrics_log)
        bar_log = bar_log_convert(metrics_log)

        return {
            "loss": error_NF + error_NB,
            "loss-NF": error_NF,
            "loss-NB": error_NB,
            "log": tf_log,
            "progress_bar": bar_log,
        }

    def training_epoch_end(self, outputs):
        """Average the per-step losses and log them with the current LRs."""
        # NOTE(review): with multiple optimizers Lightning may nest outputs per
        # optimizer; this unwraps the first list when an empty one is present.
        if [] in outputs:
            outputs = outputs[0]

        # metrics processing
        metrics_log = {
            "train_avgloss": batch_mean(outputs, "loss"),
            "train_avgloss-NF": batch_mean(outputs, "loss-NF"),
            "train_avgloss-NB": batch_mean(outputs, "loss-NB"),
        }

        tf_log = tf_log_convert(metrics_log)
        tf_log["lr-NF"] = self.schedulers[0].get_last_lr()[0]
        tf_log["lr-NB"] = self.schedulers[1].get_last_lr()[0]

        return {"log": tf_log}

    def validation_step(self, batch, batch_idx):
        """Compute validation losses for one batch and periodically log images.

        Images are logged on batch 0 and every ``cfg.freq_show_train`` batches.
        """
        # retrieve the data
        in_tensor = {name: batch[name] for name in self.in_nml}
        FB_tensor = {"normal_F": batch["normal_F"], "normal_B": batch["normal_B"]}

        # Eval mode: if netG contains BatchNorm/Dropout layers, training mode
        # would use batch statistics, apply dropout and update running stats
        # from validation data.
        self.netG.eval()

        with torch.no_grad():
            preds_F, preds_B = self.netG(in_tensor)
            error_NF, error_NB = self.netG.get_norm_error(
                preds_F, preds_B, FB_tensor
            )

        if batch_idx == 0 or batch_idx % int(self.cfg.freq_show_train) == 0:
            # The eval-mode predictions above are exactly what a second forward
            # pass would produce, so reuse them for visualization.
            in_tensor.update({"nmlF": preds_F, "nmlB": preds_B})
            result_array = self.render_func(in_tensor)

            self.logger.experiment.add_image(
                tag=f"Normal-val/{self.global_step}",
                img_tensor=result_array.transpose(2, 0, 1),
                global_step=self.global_step,
            )

        return {
            "val_loss": error_NF + error_NB,
            "val_loss-NF": error_NF,
            "val_loss-NB": error_NB,
        }

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses and return them for logging."""
        # metrics processing
        metrics_log = {
            "val_avgloss": batch_mean(outputs, "val_loss"),
            "val_avgloss-NF": batch_mean(outputs, "val_loss-NF"),
            "val_avgloss-NB": batch_mean(outputs, "val_loss-NB"),
        }

        tf_log = tf_log_convert(metrics_log)

        return {"log": tf_log}