# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch

from imaginaire.evaluation import compute_fid
from imaginaire.losses import (GANLoss, GaussianKLLoss,
                               PerceptualLoss)
from imaginaire.trainers.base import BaseTrainer
from imaginaire.utils.misc import random_shift
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.diff_aug import apply_diff_aug
class Trainer(BaseTrainer): | |
r"""Reimplementation of the MUNIT (https://arxiv.org/abs/1804.04732) | |
algorithm. | |
Args: | |
cfg (obj): Global configuration. | |
net_G (obj): Generator network. | |
net_D (obj): Discriminator network. | |
opt_G (obj): Optimizer for the generator network. | |
opt_D (obj): Optimizer for the discriminator network. | |
sch_G (obj): Scheduler for the generator optimizer. | |
sch_D (obj): Scheduler for the discriminator optimizer. | |
train_data_loader (obj): Train data loader. | |
val_data_loader (obj): Validation data loader. | |
""" | |
def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, | |
train_data_loader, val_data_loader): | |
super().__init__(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D, | |
train_data_loader, val_data_loader) | |
self.gan_recon = getattr(cfg.trainer, 'gan_recon', False) | |
self.best_fid_a = None | |
self.best_fid_b = None | |
def _init_loss(self, cfg): | |
r"""Initialize loss terms. In MUNIT, we have several loss terms | |
including the GAN loss, the image reconstruction loss, the content | |
reconstruction loss, the style reconstruction loss, the cycle | |
reconstruction loss. We also have an optional perceptual loss. A user | |
can choose to have gradient penalty or consistency regularization too. | |
Args: | |
cfg (obj): Global configuration. | |
""" | |
self.criteria['gan'] = GANLoss(cfg.trainer.gan_mode) | |
self.criteria['kl'] = GaussianKLLoss() | |
self.criteria['image_recon'] = torch.nn.L1Loss() | |
if getattr(cfg.trainer.loss_weight, 'perceptual', 0) > 0: | |
self.criteria['perceptual'] = \ | |
PerceptualLoss(network=cfg.trainer.perceptual_mode, | |
layers=cfg.trainer.perceptual_layers) | |
for loss_name, loss_weight in cfg.trainer.loss_weight.__dict__.items(): | |
if loss_weight > 0: | |
self.weights[loss_name] = loss_weight | |
def gen_forward(self, data): | |
r"""Compute the loss for MUNIT generator. | |
Args: | |
data (dict): Training data at the current iteration. | |
""" | |
cycle_recon = 'cycle_recon' in self.weights | |
image_recon = 'image_recon' in self.weights | |
perceptual = 'perceptual' in self.weights | |
within_latent_recon = 'style_recon_within' in self.weights or \ | |
'content_recon_within' in self.weights | |
net_G_output = self.net_G(data, | |
image_recon=image_recon, | |
cycle_recon=cycle_recon, | |
within_latent_recon=within_latent_recon) | |
# Differentiable augmentation. | |
keys = ['images_ab', 'images_ba'] | |
if self.gan_recon: | |
keys += ['images_aa', 'images_bb'] | |
net_D_output = self.net_D(data, | |
apply_diff_aug( | |
net_G_output, keys, self.aug_policy), | |
real=False, | |
gan_recon=self.gan_recon) | |
self._time_before_loss() | |
# GAN loss | |
if self.gan_recon: | |
self.gen_losses['gan_a'] = \ | |
0.5 * (self.criteria['gan'](net_D_output['out_ba'], | |
True, dis_update=False) + | |
self.criteria['gan'](net_D_output['out_aa'], | |
True, dis_update=False)) | |
self.gen_losses['gan_b'] = \ | |
0.5 * (self.criteria['gan'](net_D_output['out_ab'], | |
True, dis_update=False) + | |
self.criteria['gan'](net_D_output['out_bb'], | |
True, dis_update=False)) | |
else: | |
self.gen_losses['gan_a'] = self.criteria['gan']( | |
net_D_output['out_ba'], True, dis_update=False) | |
self.gen_losses['gan_b'] = self.criteria['gan']( | |
net_D_output['out_ab'], True, dis_update=False) | |
self.gen_losses['gan'] = \ | |
self.gen_losses['gan_a'] + self.gen_losses['gan_b'] | |
# Perceptual loss | |
if perceptual: | |
self.gen_losses['perceptual_a'] = \ | |
self.criteria['perceptual'](net_G_output['images_ab'], | |
data['images_a']) | |
self.gen_losses['perceptual_b'] = \ | |
self.criteria['perceptual'](net_G_output['images_ba'], | |
data['images_b']) | |
self.gen_losses['perceptual'] = \ | |
self.gen_losses['perceptual_a'] + \ | |
self.gen_losses['perceptual_b'] | |
# Image reconstruction loss | |
if image_recon: | |
self.gen_losses['image_recon'] = \ | |
self.criteria['image_recon'](net_G_output['images_aa'], | |
data['images_a']) + \ | |
self.criteria['image_recon'](net_G_output['images_bb'], | |
data['images_b']) | |
# Style reconstruction loss | |
self.gen_losses['style_recon_a'] = torch.abs( | |
net_G_output['style_ba'] - | |
net_G_output['style_a_rand']).mean() | |
self.gen_losses['style_recon_b'] = torch.abs( | |
net_G_output['style_ab'] - | |
net_G_output['style_b_rand']).mean() | |
self.gen_losses['style_recon'] = \ | |
self.gen_losses['style_recon_a'] + self.gen_losses['style_recon_b'] | |
if within_latent_recon: | |
self.gen_losses['style_recon_aa'] = torch.abs( | |
net_G_output['style_aa'] - | |
net_G_output['style_a'].detach()).mean() | |
self.gen_losses['style_recon_bb'] = torch.abs( | |
net_G_output['style_bb'] - | |
net_G_output['style_b'].detach()).mean() | |
self.gen_losses['style_recon_within'] = \ | |
self.gen_losses['style_recon_aa'] + \ | |
self.gen_losses['style_recon_bb'] | |
# Content reconstruction loss | |
self.gen_losses['content_recon_a'] = torch.abs( | |
net_G_output['content_ab'] - | |
net_G_output['content_a'].detach()).mean() | |
self.gen_losses['content_recon_b'] = torch.abs( | |
net_G_output['content_ba'] - | |
net_G_output['content_b'].detach()).mean() | |
self.gen_losses['content_recon'] = \ | |
self.gen_losses['content_recon_a'] + \ | |
self.gen_losses['content_recon_b'] | |
if within_latent_recon: | |
self.gen_losses['content_recon_aa'] = torch.abs( | |
net_G_output['content_aa'] - | |
net_G_output['content_a'].detach()).mean() | |
self.gen_losses['content_recon_bb'] = torch.abs( | |
net_G_output['content_bb'] - | |
net_G_output['content_b'].detach()).mean() | |
self.gen_losses['content_recon_within'] = \ | |
self.gen_losses['content_recon_aa'] + \ | |
self.gen_losses['content_recon_bb'] | |
# KL loss | |
self.gen_losses['kl'] = \ | |
self.criteria['kl'](net_G_output['style_a']) + \ | |
self.criteria['kl'](net_G_output['style_b']) | |
# Cycle reconstruction loss | |
if cycle_recon: | |
self.gen_losses['cycle_recon'] = \ | |
torch.abs(net_G_output['images_aba'] - | |
data['images_a']).mean() + \ | |
torch.abs(net_G_output['images_bab'] - | |
data['images_b']).mean() | |
# Compute total loss | |
total_loss = self._get_total_loss(gen_forward=True) | |
return total_loss | |
def dis_forward(self, data): | |
r"""Compute the loss for MUNIT discriminator. | |
Args: | |
data (dict): Training data at the current iteration. | |
""" | |
with torch.no_grad(): | |
net_G_output = self.net_G(data, | |
image_recon=self.gan_recon, | |
latent_recon=False, | |
cycle_recon=False, | |
within_latent_recon=False) | |
net_G_output['images_ba'].requires_grad = True | |
net_G_output['images_ab'].requires_grad = True | |
# Differentiable augmentation. | |
keys_fake = ['images_ab', 'images_ba'] | |
if self.gan_recon: | |
keys_fake += ['images_aa', 'images_bb'] | |
keys_real = ['images_a', 'images_b'] | |
net_D_output = self.net_D( | |
apply_diff_aug(data, keys_real, self.aug_policy), | |
apply_diff_aug(net_G_output, keys_fake, self.aug_policy), | |
gan_recon=self.gan_recon) | |
self._time_before_loss() | |
# GAN loss. | |
self.dis_losses['gan_a'] = \ | |
self.criteria['gan'](net_D_output['out_a'], True) + \ | |
self.criteria['gan'](net_D_output['out_ba'], False) | |
self.dis_losses['gan_b'] = \ | |
self.criteria['gan'](net_D_output['out_b'], True) + \ | |
self.criteria['gan'](net_D_output['out_ab'], False) | |
self.dis_losses['gan'] = \ | |
self.dis_losses['gan_a'] + self.dis_losses['gan_b'] | |
# Consistency regularization. | |
self.dis_losses['consistency_reg'] = \ | |
torch.tensor(0., device=torch.device('cuda')) | |
if 'consistency_reg' in self.weights: | |
data_aug, net_G_output_aug = {}, {} | |
data_aug['images_a'] = random_shift(data['images_a'].flip(-1)) | |
data_aug['images_b'] = random_shift(data['images_b'].flip(-1)) | |
net_G_output_aug['images_ab'] = \ | |
random_shift(net_G_output['images_ab'].flip(-1)) | |
net_G_output_aug['images_ba'] = \ | |
random_shift(net_G_output['images_ba'].flip(-1)) | |
net_D_output_aug = self.net_D(data_aug, net_G_output_aug) | |
feature_names = ['fea_ba', 'fea_ab', | |
'fea_a', 'fea_b'] | |
for feature_name in feature_names: | |
self.dis_losses['consistency_reg'] += \ | |
torch.pow(net_D_output_aug[feature_name] - | |
net_D_output[feature_name], 2).mean() | |
# Compute total loss | |
total_loss = self._get_total_loss(gen_forward=False) | |
return total_loss | |
def _get_visualizations(self, data): | |
r"""Compute visualization image. | |
Args: | |
data (dict): The current batch. | |
""" | |
if self.cfg.trainer.model_average_config.enabled: | |
net_G_for_evaluation = self.net_G.module.averaged_model | |
else: | |
net_G_for_evaluation = self.net_G | |
with torch.no_grad(): | |
net_G_output = net_G_for_evaluation(data, random_style=False) | |
net_G_output_random = net_G_for_evaluation(data) | |
vis_images = [data['images_a'], | |
data['images_b'], | |
net_G_output['images_aa'], | |
net_G_output['images_bb'], | |
net_G_output['images_ab'], | |
net_G_output_random['images_ab'], | |
net_G_output['images_ba'], | |
net_G_output_random['images_ba'], | |
net_G_output['images_aba'], | |
net_G_output['images_bab']] | |
return vis_images | |
def write_metrics(self): | |
r"""Compute metrics and save them to tensorboard""" | |
cur_fid_a, cur_fid_b = self._compute_fid() | |
if self.best_fid_a is not None: | |
self.best_fid_a = min(self.best_fid_a, cur_fid_a) | |
else: | |
self.best_fid_a = cur_fid_a | |
if self.best_fid_b is not None: | |
self.best_fid_b = min(self.best_fid_b, cur_fid_b) | |
else: | |
self.best_fid_b = cur_fid_b | |
self._write_to_meters({'FID_a': cur_fid_a, | |
'best_FID_a': self.best_fid_a, | |
'FID_b': cur_fid_b, | |
'best_FID_b': self.best_fid_b}, | |
self.metric_meters) | |
self._flush_meters(self.metric_meters) | |
def _compute_fid(self): | |
r"""Compute FID for both domains. | |
""" | |
self.net_G.eval() | |
if self.cfg.trainer.model_average_config.enabled: | |
net_G_for_evaluation = self.net_G.module.averaged_model | |
else: | |
net_G_for_evaluation = self.net_G | |
fid_a_path = self._get_save_path('fid_a', 'npy') | |
fid_b_path = self._get_save_path('fid_b', 'npy') | |
fid_value_a = compute_fid(fid_a_path, self.val_data_loader, | |
net_G_for_evaluation, 'images_a', 'images_ba') | |
fid_value_b = compute_fid(fid_b_path, self.val_data_loader, | |
net_G_for_evaluation, 'images_b', 'images_ab') | |
print('Epoch {:05}, Iteration {:09}, FID a {}, FID b {}'.format( | |
self.current_epoch, self.current_iteration, | |
fid_value_a, fid_value_b)) | |
return fid_value_a, fid_value_b | |