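"""PyTorch Lightning module for font generation.

`FontLightningModule` wraps a generator and optional content/style PatchGAN
discriminators, trains them with manual optimization (L1 + hinge GAN losses),
and evaluates generated glyphs with SSIM, MS-SSIM, and LPIPS.
"""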
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import importlib
import PIL.Image as Image

import models
import datasets
from evaluator.ssim import SSIM, MSSSIM
import lpips
from models.loss import GANHingeLoss
from utils import set_logger, magic_image_handler

NUM_TEST_SAVE_IMAGE = 10
class FontLightningModule(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.args = args

        self.losses = {}
        self.metrics = {}

        self.networks = nn.ModuleDict(self.build_models())
        self.module_keys = list(self.networks.keys())

        self.losses = self.build_losses()
        self.metrics = self.build_metrics()

        self.opt_tag = {key: None for key in self.networks.keys()}
        self.sched_tag = {key: None for key in self.networks.keys()}
        self.sched_use = False

        # self.automatic_optimization = False
        self.train_d_content = True
        self.train_d_style = True
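
    # Expected config layout (sketch; the exact schema lives in this repo's config
    # files, so the fields below are only those this module actually reads):
    #   args.models.{g, d_content, d_style}.optim: {class, lr, betas}
    #   args.logging: {lambda_L1, freq: {train: ...}}
    #   args.datasets: {type, train/eval: {batch_size, shuffle, num_workers}}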
    def build_models(self):
        networks = {}
        for key, hp_model in self.args.models.items():
            key_ = key.lower()
            if 'g' == key_[0]:
                model_ = models.Generator(hp_model)
            elif 'd' == key_[0]:
                model_ = models.PatchGANDiscriminator(hp_model)  # TODO: add option for selecting discriminator
            else:
                raise ValueError(f"No key such as {key}")
            networks[key_] = model_
        return networks
    def build_losses(self):
        losses_dict = {}
        losses_dict['L1'] = torch.nn.L1Loss()

        if 'd_content' in self.module_keys:
            losses_dict['GANLoss_content'] = GANHingeLoss()
        if 'd_style' in self.module_keys:
            losses_dict['GANLoss_style'] = GANHingeLoss()

        return losses_dict
    def build_metrics(self):
        metrics_dict = nn.ModuleDict()
        metrics_dict['ssim'] = SSIM(val_range=1)  # image values are in [0, 1]
        metrics_dict['msssim'] = MSSSIM(weights=[0.45, 0.3, 0.25], val_range=1)  # since imsize=64, len(weights) <= 3
        metrics_dict['lpips'] = lpips.LPIPS(net='vgg')
        return metrics_dict
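
    # Optimizers are instantiated dynamically from the dotted class path in each
    # model's `optim.class` config entry (e.g. "torch.optim.Adam"; the lr/betas
    # kwargs assume an Adam-style optimizer). opt_tag records each network's index
    # into the optimizer list returned below so training_step can look it up.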
    def configure_optimizers(self):
        optims = {}
        for key, args_model in self.args.models.items():
            key = key.lower()
            if args_model['optim'] is not None:
                args_optim = args_model['optim']
                module, cls = args_optim['class'].rsplit(".", 1)
                O = getattr(importlib.import_module(module, package=None), cls)
                o = O([p for p in self.networks[key].parameters() if p.requires_grad],
                      lr=args_optim.lr, betas=args_optim.betas)
                optims[key] = o

        optim_module_keys = optims.keys()

        count = 0
        optim_list = []
        for _key in self.module_keys:
            if _key in optim_module_keys:
                optim_list.append(optims[_key])
                self.opt_tag[_key] = count
                count += 1

        return optim_list
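
    # The generator consumes a (content_images, style_images) tuple and returns
    # generated images with the same shape as gt_images.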
    def forward(self, content_images, style_images):
        return self.networks['g']((content_images, style_images))
    def common_forward(self, batch, batch_idx):
        loss = {}
        logs = {}

        content_images = batch['content_images']
        style_images = batch['style_images']
        gt_images = batch['gt_images']
        image_paths = batch['image_paths']
        char_idx = batch['char_idx']

        generated_images = self(content_images, style_images)

        # L1 reconstruction loss
        loss['g_L1'] = self.losses['L1'](generated_images, gt_images)
        loss['g_backward'] = loss['g_L1'] * self.args.logging.lambda_L1

        # losses for training the generator
        if 'd_content' in self.module_keys:
            loss = self.d_content_loss_for_G(content_images, generated_images, loss)
        if 'd_style' in self.networks.keys():
            loss = self.d_style_loss_for_G(style_images, generated_images, loss)

        # losses for training the discriminators (generator output is detached)
        generated_images = generated_images.detach()

        if 'd_content' in self.module_keys:
            if self.train_d_content:
                loss = self.d_content_loss_for_D(content_images, generated_images, gt_images, loss)
        if 'd_style' in self.module_keys:
            if self.train_d_style:
                loss = self.d_style_loss_for_D(style_images, generated_images, gt_images, loss)

        logs['content_images'] = content_images
        logs['style_images'] = style_images
        logs['gt_images'] = gt_images
        logs['generated_images'] = generated_images

        return loss, logs
    @property
    def automatic_optimization(self):
        return False
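
    # With automatic optimization disabled, Lightning does not call backward()/step()
    # on our behalf; training_step below zeroes, backpropagates, and steps each
    # optimizer explicitly via manual_backward().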
    def training_step(self, batch, batch_idx):
        metrics = {}

        # forward
        loss, logs = self.common_forward(batch, batch_idx)

        if self.global_step % self.args.logging.freq['train'] == 0:
            with torch.no_grad():
                metrics.update(self.calc_metrics(logs['gt_images'], logs['generated_images']))

        # backward
        opts = self.optimizers()

        opts[self.opt_tag['g']].zero_grad()
        self.manual_backward(loss['g_backward'])

        if 'd_content' in self.module_keys:
            if self.train_d_content:
                opts[self.opt_tag['d_content']].zero_grad()
                self.manual_backward(loss['dcontent_backward'])

        if 'd_style' in self.module_keys:
            if self.train_d_style:
                opts[self.opt_tag['d_style']].zero_grad()
                self.manual_backward(loss['dstyle_backward'])

        opts[self.opt_tag['g']].step()

        if 'd_content' in self.module_keys:
            if self.train_d_content:
                opts[self.opt_tag['d_content']].step()
        if 'd_style' in self.module_keys:
            if self.train_d_style:
                opts[self.opt_tag['d_style']].step()

        if self.global_step % self.args.logging.freq['train'] == 0:
            self.custom_log(loss, metrics, logs, mode='train')
    def validation_step(self, batch, batch_idx):
        metrics = {}
        loss, logs = self.common_forward(batch, batch_idx)
        self.custom_log(loss, metrics, logs, mode='eval')
    def test_step(self, batch, batch_idx):
        metrics = {}
        loss, logs = self.common_forward(batch, batch_idx)
        metrics.update(self.calc_metrics(logs['gt_images'], logs['generated_images']))

        if batch_idx < NUM_TEST_SAVE_IMAGE:
            for key, value in logs.items():
                if 'image' in key:
                    sample_images = (magic_image_handler(value) * 255)[..., 0].astype(np.uint8)
                    Image.fromarray(sample_images).save(f"{batch_idx:02d}_{key}.png")

        return loss, logs, metrics
    def test_epoch_end(self, test_step_outputs):
        # aggregate metrics over the outputs of all test batches
        ssim_list = []
        msssim_list = []
        for test_output in test_step_outputs:
            ssim_list.append(test_output[2]['SSIM'].cpu().numpy())
            msssim_list.append(test_output[2]['MSSSIM'].cpu().numpy())
        print(f"SSIM: {np.mean(ssim_list)}")
        print(f"MSSSIM: {np.mean(msssim_list)}")
    def common_dataloader(self, mode='train', batch_size=None):
        dataset_cls = getattr(datasets, self.args.datasets.type)
        dataset_config = getattr(self.args.datasets, mode)
        dataset = dataset_cls(dataset_config, mode=mode)
        _batch_size = batch_size if batch_size is not None else dataset_config.batch_size
        dataloader = DataLoader(dataset,
                                shuffle=dataset_config.shuffle,
                                batch_size=_batch_size,
                                num_workers=dataset_config.num_workers,
                                drop_last=True)
        return dataloader
    def train_dataloader(self):
        return self.common_dataloader(mode='train')

    def val_dataloader(self):
        return self.common_dataloader(mode='eval')

    def test_dataloader(self):
        return self.common_dataloader(mode='eval')
    def calc_metrics(self, gt_images, generated_images):
        """Compute SSIM, MS-SSIM, and LPIPS between ground-truth and generated images.

        :param gt_images: ground-truth images, values expected in [0, 1]
        :param generated_images: generator outputs, values expected in [0, 1]
        :return: dict with keys 'SSIM', 'MSSSIM', 'LPIPS'
        """
        metrics = {}
        _gt = torch.clamp(gt_images.clone(), 0, 1)
        _gen = torch.clamp(generated_images.clone(), 0, 1)
        metrics['SSIM'] = self.metrics['ssim'](_gt, _gen)
        msssim_value = self.metrics['msssim'](_gt, _gen)
        metrics['MSSSIM'] = msssim_value if not torch.isnan(msssim_value) else torch.tensor(0.).type_as(_gt)
        metrics['LPIPS'] = self.metrics['lpips'](_gt * 2 - 1, _gen * 2 - 1).squeeze().mean()  # LPIPS expects [-1, 1]
        return metrics
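
    # The helpers below condition the PatchGAN discriminators pix2pix-style: the
    # conditioning image (content or style) is channel-concatenated with either the
    # generated or the ground-truth image before being passed to the discriminator.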
    # region step
    def d_content_loss_for_G(self, content_images, generated_images, loss):
        pred_generated = self.networks['d_content'](torch.cat([content_images, generated_images], dim=1))
        loss['g_gan_content'] = self.losses['GANLoss_content'](pred_generated, True, for_discriminator=False)
        loss['g_backward'] += loss['g_gan_content']
        return loss
    def d_content_loss_for_D(self, content_images, generated_images, gt_images, loss):
        # D
        if 'd_content' in self.module_keys:
            if self.train_d_content:
                pred_gt_images = self.networks['d_content'](torch.cat([content_images, gt_images], dim=1))
                pred_generated_images = self.networks['d_content'](torch.cat([content_images, generated_images], dim=1))

                loss['dcontent_gt'] = self.losses['GANLoss_content'](pred_gt_images, True, for_discriminator=True)
                loss['dcontent_gen'] = self.losses['GANLoss_content'](pred_generated_images, False, for_discriminator=True)
                loss['dcontent_backward'] = (loss['dcontent_gt'] + loss['dcontent_gen'])
        return loss
    def d_style_loss_for_G(self, style_images, generated_images, loss):
        pred_generated = self.networks['d_style'](torch.cat([style_images, generated_images], dim=1))
        loss['g_gan_style'] = self.losses['GANLoss_style'](pred_generated, True, for_discriminator=False)
        assert self.train_d_style
        loss['g_backward'] += loss['g_gan_style']
        return loss
    def d_style_loss_for_D(self, style_images, generated_images, gt_images, loss):
        pred_gt_images = self.networks['d_style'](torch.cat([style_images, gt_images], dim=1))
        pred_generated_images = self.networks['d_style'](torch.cat([style_images, generated_images], dim=1))

        loss['dstyle_gt'] = self.losses['GANLoss_style'](pred_gt_images, True, for_discriminator=True)
        loss['dstyle_gen'] = self.losses['GANLoss_style'](pred_generated_images, False, for_discriminator=True)
        loss['dstyle_backward'] = (loss['dstyle_gt'] + loss['dstyle_gen'])
        return loss
    def custom_log(self, loss, metrics, logs, mode):
        # log scalar values to tensorboard
        for loss_full_key, value in loss.items():
            model_type, loss_type = loss_full_key.split('_')[0], "_".join(loss_full_key.split('_')[1:])
            self.log(f'{model_type}/{mode}_{loss_type}', value)

        for metric_full_key, value in metrics.items():
            model_type, metric_type = metric_full_key.split('_')[0], "_".join(metric_full_key.split('_')[1:])
            self.log(f'{model_type}/{mode}_{metric_type}', value)

        # log images, params, etc.
        tensorboard = self.logger.experiment
        for key, value in logs.items():
            if 'image' in key:
                sample_images = magic_image_handler(value)
                tensorboard.add_image(f"{mode}/" + key, sample_images, self.global_step, dataformats='HWC')
            elif 'param' in key:
                tensorboard.add_histogram(f"{mode}/" + key, value, self.global_step)
            else:
                raise RuntimeError(f"Only logging with one of keywords: image, param | current input: {key}")
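

# Example usage (sketch): how this module might be wired into a Trainer. The
# config loader and paths are illustrative, not this repo's actual entry point;
# an OmegaConf-style config is assumed because the code mixes attribute and
# dict-style access on `args`.
#
#   from omegaconf import OmegaConf
#   args = OmegaConf.load("configs/font.yaml")   # hypothetical config path
#   model = FontLightningModule(args)
#   trainer = pl.Trainer(max_epochs=100, gpus=1)  # hyperparameters illustrative
#   trainer.fit(model)
#   trainer.test(model)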