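# Page metadata captured together with this file (kept as a comment so the module parses):
# author: deepkyu | commit: 1ba3df3 ("initial commit") | size: 12.2 kB | scan: "No virus"
# (page links: raw, history, blame)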
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import importlib
import PIL.Image as Image
import models
import datasets
from evaluator.ssim import SSIM, MSSSIM
import lpips
from models.loss import GANHingeLoss
from utils import set_logger, magic_image_handler
# How many leading test batches (batch_idx < this value) get their images saved as PNGs.
NUM_TEST_SAVE_IMAGE: int = 10
class FontLightningModule(pl.LightningModule):
    """LightningModule that trains a font-image generator with optional GAN discriminators.

    One network is built per entry in ``args.models``. Keys starting with ``g`` become a
    ``models.Generator`` and keys starting with ``d`` become a
    ``models.PatchGANDiscriminator``. The generator ``'g'`` is trained with an L1 loss
    (weighted by ``args.logging.lambda_L1``) plus hinge GAN losses from the optional
    ``'d_content'`` and ``'d_style'`` discriminators. Optimization is manual (see the
    ``automatic_optimization`` property).
    """

    def __init__(self, args):
        """Build networks, losses and metrics from the experiment config.

        :param args: experiment config. It is read with both attribute and item access
            (``args.models``, ``args_model['optim']``), which suggests an OmegaConf
            ``DictConfig`` -- TODO confirm.
        """
        super().__init__()
        self.args = args
        self.networks = nn.ModuleDict(self.build_models())
        self.module_keys = list(self.networks.keys())
        # build_losses() reads self.module_keys, so it must run after the networks exist.
        self.losses = self.build_losses()
        self.metrics = self.build_metrics()
        # Index of each network's optimizer in the list returned by configure_optimizers()
        # (stays None for networks without an optimizer).
        self.opt_tag = {key: None for key in self.networks.keys()}
        # Scheduler bookkeeping; not read anywhere in the code shown here.
        self.sched_tag = {key: None for key in self.networks.keys()}
        self.sched_use = False
        # Manual optimization is switched on by the automatic_optimization property.
        self.train_d_content = True
        self.train_d_style = True

    def build_models(self):
        """Instantiate one network per entry in ``self.args.models``.

        :return: dict mapping the lower-cased config key to its ``nn.Module``.
        :raises ValueError: if a key starts with neither ``g`` nor ``d``.
        """
        networks = {}
        for key, hp_model in self.args.models.items():
            key_ = key.lower()
            if key_.startswith('g'):
                model_ = models.Generator(hp_model)
            elif key_.startswith('d'):
                # TODO: add option for selecting discriminator
                model_ = models.PatchGANDiscriminator(hp_model)
            else:
                raise ValueError(f"No key such as {key}")
            networks[key_] = model_
        return networks

    def build_losses(self):
        """Return the loss modules: L1 always, plus a hinge GAN loss per discriminator."""
        losses_dict = {'L1': torch.nn.L1Loss()}
        if 'd_content' in self.module_keys:
            losses_dict['GANLoss_content'] = GANHingeLoss()
        if 'd_style' in self.module_keys:
            losses_dict['GANLoss_style'] = GANHingeLoss()
        return losses_dict

    def build_metrics(self):
        """Return an ``nn.ModuleDict`` with the SSIM, MS-SSIM and LPIPS(VGG) metrics."""
        metrics_dict = nn.ModuleDict()
        metrics_dict['ssim'] = SSIM(val_range=1)  # img value is in [0, 1]
        # Per the original author: with imsize=64, MS-SSIM supports at most 3 scales.
        metrics_dict['msssim'] = MSSSIM(weights=[0.45, 0.3, 0.25], val_range=1)
        metrics_dict['lpips'] = lpips.LPIPS(net='vgg')
        return metrics_dict

    def configure_optimizers(self):
        """Create one optimizer per network whose config has a non-None ``optim`` entry.

        ``optim['class']`` is a dotted path to an optimizer class (imported at runtime),
        which is constructed with the ``lr`` and ``betas`` from the same config. The
        optimizers are returned in ``self.module_keys`` order, and each one's list index
        is recorded in ``self.opt_tag`` so ``training_step`` can find it.
        """
        optims = {}
        for key, args_model in self.args.models.items():
            key = key.lower()
            args_optim = args_model['optim']
            if args_optim is None:
                continue
            module_name, cls_name = args_optim['class'].rsplit(".", 1)
            optim_cls = getattr(importlib.import_module(module_name), cls_name)
            trainable = [p for p in self.networks[key].parameters() if p.requires_grad]
            optims[key] = optim_cls(trainable, lr=args_optim.lr, betas=args_optim.betas)

        optim_list = []
        for _key in self.module_keys:
            if _key in optims:
                self.opt_tag[_key] = len(optim_list)
                optim_list.append(optims[_key])
        return optim_list

    def forward(self, content_images, style_images):
        """Run the generator on a ``(content_images, style_images)`` pair."""
        return self.networks['g']((content_images, style_images))

    def common_forward(self, batch, batch_idx):
        """Shared forward pass for train/val/test: generate images and compute all losses.

        :param batch: dict with at least ``content_images``, ``style_images`` and
            ``gt_images``.
        :param batch_idx: unused; kept for the Lightning step signature.
        :return: ``(loss, logs)``. ``loss`` maps names to loss tensors, and
            ``g_backward``, ``dcontent_backward`` and ``dstyle_backward`` are the totals
            to backpropagate. ``logs`` holds the inputs, targets and detached outputs.
        """
        content_images = batch['content_images']
        style_images = batch['style_images']
        gt_images = batch['gt_images']

        generated_images = self(content_images, style_images)

        # Generator reconstruction loss.
        loss = {'g_L1': self.losses['L1'](generated_images, gt_images)}
        loss['g_backward'] = loss['g_L1'] * self.args.logging.lambda_L1

        # Adversarial terms for the generator (gradients flow back into G).
        if 'd_content' in self.module_keys:
            loss = self.d_content_loss_for_G(content_images, generated_images, loss)
        if 'd_style' in self.module_keys:
            loss = self.d_style_loss_for_G(style_images, generated_images, loss)

        # Discriminator losses use detached outputs, so they never backpropagate into G.
        generated_images = generated_images.detach()
        if 'd_content' in self.module_keys and self.train_d_content:
            loss = self.d_content_loss_for_D(content_images, generated_images, gt_images, loss)
        if 'd_style' in self.module_keys and self.train_d_style:
            loss = self.d_style_loss_for_D(style_images, generated_images, gt_images, loss)

        logs = {
            'content_images': content_images,
            'style_images': style_images,
            'gt_images': gt_images,
            'generated_images': generated_images,
        }
        return loss, logs

    @property
    def automatic_optimization(self):
        """Disable Lightning's automatic optimization; training_step steps optimizers itself."""
        return False

    def training_step(self, batch, batch_idx):
        """Run one manual-optimization step for G and every enabled discriminator.

        All gradients are computed before any optimizer steps, so every update uses the
        same parameters. ``g_backward`` also puts gradients on the discriminators'
        parameters, because G's adversarial loss passes through them. For each trained
        discriminator, those gradients are cleared by its ``zero_grad`` before its own
        backward.
        """
        # Decide once. global_step may advance during the optimizer steps below
        # (depending on the Lightning version), and re-checking it afterwards could log
        # with empty metrics or skip logging for a step whose metrics were computed.
        should_log = self.global_step % self.args.logging.freq['train'] == 0

        metrics = {}
        loss, logs = self.common_forward(batch, batch_idx)
        if should_log:
            with torch.no_grad():
                metrics.update(self.calc_metrics(logs['gt_images'], logs['generated_images']))

        train_d_content = 'd_content' in self.module_keys and self.train_d_content
        train_d_style = 'd_style' in self.module_keys and self.train_d_style

        opts = self.optimizers()
        # With a single optimizer, Lightning returns it directly rather than in a list.
        if not isinstance(opts, (list, tuple)):
            opts = [opts]

        # backward
        opts[self.opt_tag['g']].zero_grad()
        self.manual_backward(loss['g_backward'])
        if train_d_content:
            opts[self.opt_tag['d_content']].zero_grad()
            self.manual_backward(loss['dcontent_backward'])
        if train_d_style:
            opts[self.opt_tag['d_style']].zero_grad()
            self.manual_backward(loss['dstyle_backward'])

        # step
        opts[self.opt_tag['g']].step()
        if train_d_content:
            opts[self.opt_tag['d_content']].step()
        if train_d_style:
            opts[self.opt_tag['d_style']].step()

        if should_log:
            self.custom_log(loss, metrics, logs, mode='train')

    def validation_step(self, batch, batch_idx):
        """Run the shared forward pass and log its losses and images under mode ``'eval'``."""
        loss, logs = self.common_forward(batch, batch_idx)
        self.custom_log(loss, {}, logs, mode='eval')

    def test_step(self, batch, batch_idx):
        """Compute metrics for one test batch and save sample images for early batches.

        For ``batch_idx < NUM_TEST_SAVE_IMAGE``, every image entry in ``logs`` is written
        to ``{batch_idx:02d}_{key}.png`` (relative path, i.e. the current working
        directory). Only the first channel is saved.

        :return: ``(loss, logs, metrics)``, consumed by ``test_epoch_end``.
        """
        loss, logs = self.common_forward(batch, batch_idx)
        metrics = self.calc_metrics(logs['gt_images'], logs['generated_images'])
        if batch_idx < NUM_TEST_SAVE_IMAGE:
            for key, value in logs.items():
                if 'image' not in key:
                    continue
                # Clip before the uint8 cast so out-of-range values saturate instead of
                # wrapping around.
                pixels = np.clip(magic_image_handler(value) * 255, 0, 255)
                Image.fromarray(pixels[..., 0].astype(np.uint8)).save(f"{batch_idx:02d}_{key}.png")
        return loss, logs, metrics

    def test_epoch_end(self, test_step_outputs):
        """Print the mean SSIM and MS-SSIM over all test batches.

        :param test_step_outputs: list of ``(loss, logs, metrics)`` tuples from
            ``test_step``.
        """
        ssim_list = []
        msssim_list = []
        for _, _, metrics in test_step_outputs:
            ssim_list.append(metrics['SSIM'].cpu().numpy())
            msssim_list.append(metrics['MSSSIM'].cpu().numpy())
        print(f"SSIM: {np.mean(ssim_list)}")
        print(f"MSSSIM: {np.mean(msssim_list)}")

    def common_dataloader(self, mode='train', batch_size=None, drop_last=None):
        """Build a DataLoader for the ``mode`` sub-config of ``self.args.datasets``.

        :param mode: dataset sub-config name (``'train'`` or ``'eval'`` in this module).
        :param batch_size: overrides ``dataset_config.batch_size`` when given.
        :param drop_last: whether to drop the final incomplete batch. By default it is
            True only for ``'train'``. Evaluation keeps every sample, so metrics cover
            the whole split and a split smaller than one batch still yields a batch.
        :return: the configured ``DataLoader``.
        """
        dataset_cls = getattr(datasets, self.args.datasets.type)
        dataset_config = getattr(self.args.datasets, mode)
        dataset = dataset_cls(dataset_config, mode=mode)
        _batch_size = batch_size if batch_size is not None else dataset_config.batch_size
        _drop_last = (mode == 'train') if drop_last is None else drop_last
        return DataLoader(dataset,
                          shuffle=dataset_config.shuffle,
                          batch_size=_batch_size,
                          num_workers=dataset_config.num_workers,
                          drop_last=_drop_last)

    def train_dataloader(self):
        """Return the training DataLoader."""
        return self.common_dataloader(mode='train')

    def val_dataloader(self):
        """Return the validation DataLoader (built from the ``'eval'`` config)."""
        return self.common_dataloader(mode='eval')

    def test_dataloader(self):
        """Return the test DataLoader (built from the same ``'eval'`` config as validation)."""
        return self.common_dataloader(mode='eval')

    def calc_metrics(self, gt_images, generated_images):
        """Compute SSIM, MS-SSIM and LPIPS between target and generated images.

        Both inputs are clamped to [0, 1] first, matching ``val_range=1`` for the SSIM
        metrics. LPIPS receives them rescaled to [-1, 1].

        :param gt_images: target images.
        :param generated_images: generator outputs, same shape as ``gt_images``.
        :return: dict of tensors under ``'SSIM'``, ``'MSSSIM'`` and ``'LPIPS'``. A NaN
            MS-SSIM is replaced by 0.
        """
        # torch.clamp is out-of-place, so the caller's tensors are not modified.
        _gt = torch.clamp(gt_images, 0, 1)
        _gen = torch.clamp(generated_images, 0, 1)

        metrics = {'SSIM': self.metrics['ssim'](_gt, _gen)}
        msssim_value = self.metrics['msssim'](_gt, _gen)
        if torch.isnan(msssim_value):
            msssim_value = torch.tensor(0.).type_as(_gt)
        metrics['MSSSIM'] = msssim_value
        metrics['LPIPS'] = self.metrics['lpips'](_gt * 2 - 1, _gen * 2 - 1).squeeze().mean()
        return metrics

    # region step
    def d_content_loss_for_G(self, content_images, generated_images, loss):
        """Add the content discriminator's generator-side hinge loss to ``loss``.

        D sees ``[content, generated]`` concatenated along dim 1 (assumed to be the
        channel axis -- TODO confirm). Stores ``g_gan_content`` and adds it to
        ``g_backward``.
        """
        pred_generated = self.networks['d_content'](torch.cat([content_images, generated_images], dim=1))
        loss['g_gan_content'] = self.losses['GANLoss_content'](pred_generated, True, for_discriminator=False)
        # Out-of-place add: never mutates a tensor autograd may have saved.
        loss['g_backward'] = loss['g_backward'] + loss['g_gan_content']
        return loss

    def d_content_loss_for_D(self, content_images, generated_images, gt_images, loss):
        """Add the content discriminator's real/fake hinge losses to ``loss``.

        Returns ``loss`` unchanged unless ``d_content`` exists and
        ``self.train_d_content`` is set. ``generated_images`` should already be
        detached (``common_forward`` does this).
        """
        if 'd_content' not in self.module_keys or not self.train_d_content:
            return loss
        d_content = self.networks['d_content']
        pred_gt_images = d_content(torch.cat([content_images, gt_images], dim=1))
        pred_generated_images = d_content(torch.cat([content_images, generated_images], dim=1))
        loss['dcontent_gt'] = self.losses['GANLoss_content'](pred_gt_images, True, for_discriminator=True)
        loss['dcontent_gen'] = self.losses['GANLoss_content'](pred_generated_images, False, for_discriminator=True)
        loss['dcontent_backward'] = loss['dcontent_gt'] + loss['dcontent_gen']
        return loss

    def d_style_loss_for_G(self, style_images, generated_images, loss):
        """Add the style discriminator's generator-side hinge loss to ``loss``.

        Stores ``g_gan_style`` and adds it to ``g_backward``.
        The assert requires ``self.train_d_style`` whenever this term is used.
        """
        assert self.train_d_style
        pred_generated = self.networks['d_style'](torch.cat([style_images, generated_images], dim=1))
        loss['g_gan_style'] = self.losses['GANLoss_style'](pred_generated, True, for_discriminator=False)
        # Out-of-place add: never mutates a tensor autograd may have saved.
        loss['g_backward'] = loss['g_backward'] + loss['g_gan_style']
        return loss

    def d_style_loss_for_D(self, style_images, generated_images, gt_images, loss):
        """Add the style discriminator's real/fake hinge losses to ``loss``.

        ``generated_images`` should already be detached (``common_forward`` does this).
        """
        d_style = self.networks['d_style']
        pred_gt_images = d_style(torch.cat([style_images, gt_images], dim=1))
        pred_generated_images = d_style(torch.cat([style_images, generated_images], dim=1))
        loss['dstyle_gt'] = self.losses['GANLoss_style'](pred_gt_images, True, for_discriminator=True)
        loss['dstyle_gen'] = self.losses['GANLoss_style'](pred_generated_images, False, for_discriminator=True)
        loss['dstyle_backward'] = loss['dstyle_gt'] + loss['dstyle_gen']
        return loss

    def custom_log(self, loss, metrics, logs, mode):
        """Log scalars with ``self.log`` and images/histograms to the TensorBoard writer.

        Each scalar key is split at its first ``_`` into ``<prefix>/<mode>_<rest>``
        (e.g. ``g_L1`` -> ``g/train_L1``).

        :raises RuntimeError: if a ``logs`` key contains neither ``image`` nor ``param``.
        """
        # logging values with tensorboard
        for full_key, value in [*loss.items(), *metrics.items()]:
            prefix, _, name = full_key.partition('_')
            self.log(f'{prefix}/{mode}_{name}', value)

        # logging images, params, etc. (assumes a TensorBoard logger -- add_image /
        # add_histogram are SummaryWriter methods)
        tensorboard = self.logger.experiment
        for key, value in logs.items():
            if 'image' in key:
                sample_images = magic_image_handler(value)
                tensorboard.add_image(f"{mode}/" + key, sample_images, self.global_step, dataformats='HWC')
            elif 'param' in key:
                # Use the same "<mode>/" tag prefix as images.
                tensorboard.add_histogram(f"{mode}/" + key, value, self.global_step)
            else:
                raise RuntimeError(f"Only logging with one of keywords: image, param | current input: {key}")