from lib.net import NormalNet
# train_util is expected to provide the helpers used below
# (export_cfg, batch_mean, tf_log_convert, bar_log_convert)
from lib.common.train_util import *
import logging
import torch
import numpy as np
from torch import nn
from skimage.transform import resize
import pytorch_lightning as pl

torch.backends.cudnn.benchmark = True

logging.getLogger("lightning").setLevel(logging.ERROR)


class Normal(pl.LightningModule):
    """Lightning wrapper that trains NormalNet to predict front/back normal maps."""

    def __init__(self, cfg):
        super(Normal, self).__init__()
        self.cfg = cfg
        self.batch_size = self.cfg.batch_size
        self.lr_N = self.cfg.lr_N

        self.schedulers = []

        # NormalNet holds two sub-networks (netF / netB), trained with a Smooth-L1 loss
        self.netG = NormalNet(self.cfg, error_term=nn.SmoothL1Loss())

        # names of the input tensors expected by the normal networks
        self.in_nml = [item[0] for item in cfg.net.in_nml]

    def get_progress_bar_dict(self):
        # hide the version number from the progress bar
        tqdm_dict = super().get_progress_bar_dict()
        if "v_num" in tqdm_dict:
            del tqdm_dict["v_num"]
        return tqdm_dict

    # Training related
    def configure_optimizers(self):

        # one Adam optimizer + MultiStepLR scheduler per sub-network
        weight_decay = self.cfg.weight_decay
        momentum = self.cfg.momentum  # kept from the config; not used by Adam

        optim_params_N_F = [
            {"params": self.netG.netF.parameters(), "lr": self.lr_N}]
        optim_params_N_B = [
            {"params": self.netG.netB.parameters(), "lr": self.lr_N}]

        optimizer_N_F = torch.optim.Adam(
            optim_params_N_F, lr=self.lr_N, weight_decay=weight_decay
        )
        optimizer_N_B = torch.optim.Adam(
            optim_params_N_B, lr=self.lr_N, weight_decay=weight_decay
        )

        scheduler_N_F = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_N_F, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )
        scheduler_N_B = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_N_B, milestones=self.cfg.schedule, gamma=self.cfg.gamma
        )

        self.schedulers = [scheduler_N_F, scheduler_N_B]
        optims = [optimizer_N_F, optimizer_N_B]

        return optims, self.schedulers

    def render_func(self, render_tensor):
        # tile every tensor in the dict side by side as one displayable image
        height = render_tensor["image"].shape[2]
        result_list = []

        for name in render_tensor.keys():
            result_list.append(
                resize(
                    ((render_tensor[name].cpu().numpy()[0] + 1.0) / 2.0).transpose(
                        1, 2, 0
                    ),
                    (height, height),
                    anti_aliasing=True,
                )
            )
        result_array = np.concatenate(result_list, axis=1)

        return result_array
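
    # Shape sketch for render_func (descriptive note, not part of the original file):
    # assuming every entry of render_tensor is a (B, 3, H, W) tensor in [-1, 1],
    # the first sample of each entry is mapped to [0, 1], resized to (H, H, 3),
    # and the K entries are concatenated along the width into an (H, K*H, 3) array.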

    def training_step(self, batch, batch_idx, optimizer_idx):

        export_cfg(self.logger, self.cfg)

        # retrieve the data
        in_tensor = {}
        for name in self.in_nml:
            in_tensor[name] = batch[name]

        FB_tensor = {"normal_F": batch["normal_F"],
                     "normal_B": batch["normal_B"]}

        self.netG.train()

        preds_F, preds_B = self.netG(in_tensor)
        error_NF, error_NB = self.netG.get_norm_error(
            preds_F, preds_B, FB_tensor)

        # manual optimization: step both optimizers (front / back) every batch
        (opt_nf, opt_nb) = self.optimizers()

        opt_nf.zero_grad()
        opt_nb.zero_grad()

        self.manual_backward(error_NF, opt_nf)
        self.manual_backward(error_NB, opt_nb)

        opt_nf.step()
        opt_nb.step()

        # periodically log a qualitative preview of inputs and predicted normals
        if batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0:

            self.netG.eval()
            with torch.no_grad():
                nmlF, nmlB = self.netG(in_tensor)
                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
                result_array = self.render_func(in_tensor)

            self.logger.experiment.add_image(
                tag=f"Normal-train/{self.global_step}",
                img_tensor=result_array.transpose(2, 0, 1),
                global_step=self.global_step,
            )

        # metrics processing
        metrics_log = {
            "train_loss-NF": error_NF.item(),
            "train_loss-NB": error_NB.item(),
        }

        tf_log = tf_log_convert(metrics_log)
        bar_log = bar_log_convert(metrics_log)

        return {
            "loss": error_NF + error_NB,
            "loss-NF": error_NF,
            "loss-NB": error_NB,
            "log": tf_log,
            "progress_bar": bar_log,
        }

    def training_epoch_end(self, outputs):

        # unwrap nested outputs (one list per optimizer) if present
        if [] in outputs:
            outputs = outputs[0]

        # metrics processing
        metrics_log = {
            "train_avgloss": batch_mean(outputs, "loss"),
            "train_avgloss-NF": batch_mean(outputs, "loss-NF"),
            "train_avgloss-NB": batch_mean(outputs, "loss-NB"),
        }

        tf_log = tf_log_convert(metrics_log)

        # also log the current learning rate of each scheduler
        tf_log["lr-NF"] = self.schedulers[0].get_last_lr()[0]
        tf_log["lr-NB"] = self.schedulers[1].get_last_lr()[0]

        return {"log": tf_log}

    def validation_step(self, batch, batch_idx):

        # retrieve the data
        in_tensor = {}
        for name in self.in_nml:
            in_tensor[name] = batch[name]

        FB_tensor = {"normal_F": batch["normal_F"],
                     "normal_B": batch["normal_B"]}

        # the network stays in train mode here, so normalization/dropout layers
        # behave exactly as they do during training
        self.netG.train()

        preds_F, preds_B = self.netG(in_tensor)
        error_NF, error_NB = self.netG.get_norm_error(
            preds_F, preds_B, FB_tensor)

        # log a qualitative preview on the first batch and then periodically
        if (batch_idx > 0 and batch_idx % int(self.cfg.freq_show_train) == 0) or (
            batch_idx == 0
        ):

            with torch.no_grad():
                nmlF, nmlB = self.netG(in_tensor)
                in_tensor.update({"nmlF": nmlF, "nmlB": nmlB})
                result_array = self.render_func(in_tensor)

            self.logger.experiment.add_image(
                tag=f"Normal-val/{self.global_step}",
                img_tensor=result_array.transpose(2, 0, 1),
                global_step=self.global_step,
            )

        return {
            "val_loss": error_NF + error_NB,
            "val_loss-NF": error_NF,
            "val_loss-NB": error_NB,
        }

    def validation_epoch_end(self, outputs):

        # metrics processing
        metrics_log = {
            "val_avgloss": batch_mean(outputs, "val_loss"),
            "val_avgloss-NF": batch_mean(outputs, "val_loss-NF"),
            "val_avgloss-NB": batch_mean(outputs, "val_loss-NB"),
        }

        tf_log = tf_log_convert(metrics_log)

        return {"log": tf_log}
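

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module).
# Assuming the older pytorch_lightning API that the hooks above target
# (training_epoch_end, get_progress_bar_dict, manual_backward(loss, optimizer)),
# training might look roughly like this; `cfg`, `train_loader`, and `val_loader`
# are hypothetical placeholders for the project config and for dataloaders that
# yield the keys listed in cfg.net.in_nml plus "normal_F" / "normal_B":
#
#     model = Normal(cfg)
#     trainer = pl.Trainer(max_epochs=10, gpus=1)
#     trainer.fit(model, train_loader, val_loader)
# ---------------------------------------------------------------------------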