# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import os

from imaginaire.evaluation import compute_fid, compute_kid
from imaginaire.utils.diff_aug import apply_diff_aug
from imaginaire.losses import GANLoss
from imaginaire.trainers.base import BaseTrainer
from imaginaire.utils.distributed import is_master


class Trainer(BaseTrainer):
    r"""Reimplementation of the FUNIT (https://arxiv.org/abs/1905.01723)
    algorithm.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """

    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                 train_data_loader, val_data_loader):
        # Best metric values seen so far; ``None`` until the first
        # evaluation. ``best_fid`` is initialized here as well so that
        # ``write_metrics`` does not rely on another class defining it.
        self.best_kid = None
        self.best_fid = None
        # Metric switches and KID sampling settings, read from
        # ``cfg.trainer`` with the defaults given below.
        self.use_fid = getattr(cfg.trainer, 'use_fid', False)
        self.use_kid = getattr(cfg.trainer, 'use_kid', True)
        self.kid_num_subsets = getattr(cfg.trainer, 'kid_num_subsets', 1)
        self.kid_sample_size = getattr(cfg.trainer, 'kid_sample_size', 256)
        self.kid_subset_size = getattr(cfg.trainer, 'kid_subset_size', 256)
        super().__init__(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                         train_data_loader, val_data_loader)

    def _init_loss(self, cfg):
        r"""Initialize loss terms. In FUNIT, we have several loss terms
        including the GAN loss, the image reconstruction loss, the feature
        matching loss, and the gradient penalty loss.

        Args:
            cfg (obj): Global configuration.
        """
        self.criteria['gan'] = GANLoss(cfg.trainer.gan_mode)
        self.criteria['image_recon'] = nn.L1Loss()
        self.criteria['feature_matching'] = nn.L1Loss()

        # Only losses with a strictly positive weight are registered; the
        # forward passes test ``name in self.weights`` to decide whether a
        # term is computed at all.
        for loss_name, loss_weight in vars(cfg.trainer.loss_weight).items():
            if loss_weight > 0:
                self.weights[loss_name] = loss_weight

    def gen_forward(self, data):
        r"""Compute the loss for FUNIT generator.

        Args:
            data (dict): Training data at the current iteration.
        Returns:
            The value returned by ``self._get_total_loss(gen_forward=True)``.
        """
        net_G_output = self.net_G(data)

        # Differentiable augmentation.
        keys = ['images_recon', 'images_trans']
        net_D_output = self.net_D(
            data, apply_diff_aug(net_G_output, keys, self.aug_policy))

        self._time_before_loss()

        # GAN loss
        # We use both the translation and reconstruction streams.
        if 'gan' in self.weights:
            gan_trans = self.criteria['gan'](
                net_D_output['fake_out_trans'], True, dis_update=False)
            gan_recon = self.criteria['gan'](
                net_D_output['fake_out_recon'], True, dis_update=False)
            self.gen_losses['gan'] = 0.5 * (gan_trans + gan_recon)

        # Image reconstruction loss
        if 'image_recon' in self.weights:
            self.gen_losses['image_recon'] = self.criteria['image_recon'](
                net_G_output['images_recon'], data['images_content'])

        # Feature matching loss
        if 'feature_matching' in self.weights:
            self.gen_losses['feature_matching'] = \
                self.criteria['feature_matching'](
                    net_D_output['fake_features_trans'],
                    net_D_output['real_features_style'])

        # Compute total loss
        return self._get_total_loss(gen_forward=True)

    def dis_forward(self, data):
        r"""Compute the loss for FUNIT discriminator.

        Args:
            data (dict): Training data at the current iteration.
        Returns:
            The value returned by ``self._get_total_loss(gen_forward=False)``.
        """
        # The generator is run without building a graph: this pass only
        # updates the discriminator.
        with torch.no_grad():
            net_G_output = self.net_G(data)
        # Re-enable gradients on the translated images before they are
        # passed to the discriminator.
        net_G_output['images_trans'].requires_grad = True
        net_D_output = self.net_D(
            apply_diff_aug(data, ['images_style'], self.aug_policy),
            apply_diff_aug(net_G_output, ['images_trans'], self.aug_policy),
            recon=False)

        self._time_before_loss()

        real_loss = self.criteria['gan'](net_D_output['real_out_style'], True)
        fake_loss = self.criteria['gan'](net_D_output['fake_out_trans'], False)
        self.dis_losses['gan'] = real_loss + fake_loss

        # Compute total loss
        return self._get_total_loss(gen_forward=False)

    def _get_visualizations(self, data):
        r"""Compute visualization image.

        The returned list holds, in order: content images, style images,
        reconstructions, translations, then (if the generator output has an
        ``'attn_a'`` entry) one 3-channel image per channel of ``'attn_a'``
        followed by one per channel of ``'attn_b'``, and finally the
        reconstructions and translations of the averaged model when model
        averaging is enabled.

        Args:
            data (dict): The current batch.
        Returns:
            (list): Images to visualize.
        """
        with torch.no_grad():
            net_G_output = self.net_G(data)
            vis_images = [data['images_content'],
                          data['images_style'],
                          net_G_output['images_recon'],
                          net_G_output['images_trans']]
            _, _, h, w = net_G_output['images_recon'].size()

            if 'attn_a' in net_G_output:
                # Each attention map is iterated over its OWN channel count
                # (previously the 'attn_b' loop reused the 'attn_a' count).
                for attn_key in ('attn_a', 'attn_b'):
                    attn = net_G_output[attn_key]
                    for i in range(attn.size(1)):
                        # Resize one channel to the image resolution and
                        # replicate it to 3 channels.
                        resized = F.interpolate(
                            attn[:, i:i + 1, :, :], (h, w))
                        vis_images.append(resized.expand(-1, 3, -1, -1))

            if self.cfg.trainer.model_average_config.enabled:
                averaged_output = self.net_G.module.averaged_model(data)
                vis_images += [averaged_output['images_recon'],
                               averaged_output['images_trans']]
        return vis_images

    def _net_G_for_metrics(self):
        r"""Put the generator in eval mode and return the network used for
        metric computation: the averaged model when model averaging is
        enabled, otherwise ``self.net_G`` itself.
        """
        self.net_G.eval()
        if self.cfg.trainer.model_average_config.enabled:
            return self.net_G.module.averaged_model
        return self.net_G

    def _compute_fid(self):
        r"""Compute FID. We will compute a FID value per test class. That is
        if you have 30 test classes, we will compute 30 different FID values.
        We will then report the mean of the FID values as the final
        performance number as described in the FUNIT paper.

        Returns:
            Mean FID over all test classes on the master process, ``None``
            on every other process.
        """
        net_G_for_evaluation = self._net_G_for_metrics()

        all_fid_values = []
        num_test_classes = self.val_data_loader.dataset.num_style_classes
        for class_idx in range(num_test_classes):
            fid_path = self._get_save_path(
                os.path.join('fid', str(class_idx)), 'npy')
            # Restrict the validation dataset to the current class before
            # evaluating it.
            self.val_data_loader.dataset.set_sample_class_idx(class_idx)
            fid_value = compute_fid(fid_path, self.val_data_loader,
                                    net_G_for_evaluation,
                                    'images_style', 'images_trans')
            all_fid_values.append(fid_value)

        if not is_master():
            return None
        mean_fid = np.mean(all_fid_values)
        print('Epoch {:05}, Iteration {:09}, Mean FID {}'.format(
            self.current_epoch, self.current_iteration, mean_fid))
        return mean_fid

    def _compute_kid(self):
        r"""Compute KID. One KID value is computed per test class, and the
        mean over all classes is reported.

        Returns:
            Mean KID over all test classes on the master process, ``None``
            on every other process.
        """
        net_G_for_evaluation = self._net_G_for_metrics()

        all_kid_values = []
        num_test_classes = self.val_data_loader.dataset.num_style_classes
        for class_idx in range(num_test_classes):
            kid_path = self._get_save_path(
                os.path.join('kid', str(class_idx)), 'npy')
            # Restrict the validation dataset to the current class before
            # evaluating it.
            self.val_data_loader.dataset.set_sample_class_idx(class_idx)
            kid_value = compute_kid(
                kid_path, self.val_data_loader, net_G_for_evaluation,
                'images_style', 'images_trans',
                num_subsets=self.kid_num_subsets,
                sample_size=self.kid_sample_size,
                subset_size=self.kid_subset_size)
            all_kid_values.append(kid_value)

        if not is_master():
            return None
        mean_kid = np.mean(all_kid_values)
        # The message previously said "Mean FID" for the KID value.
        print('Epoch {:05}, Iteration {:09}, Mean KID {}'.format(
            self.current_epoch, self.current_iteration, mean_kid))
        return mean_kid

    def write_metrics(self):
        r"""Write metrics to the tensorboard."""
        metric_dict = {}

        if self.use_kid:
            cur_kid = self._compute_kid()
            # ``None`` means this process did not obtain a value (see
            # ``_compute_kid``), so there is nothing to record.
            if cur_kid is not None:
                self.best_kid = cur_kid if self.best_kid is None \
                    else min(self.best_kid, cur_kid)
                metric_dict.update(
                    {'KID': cur_kid, 'best_KID': self.best_kid})

        if self.use_fid:
            cur_fid = self._compute_fid()
            if cur_fid is not None:
                self.best_fid = cur_fid if self.best_fid is None \
                    else min(self.best_fid, cur_fid)
                metric_dict.update(
                    {'FID': cur_fid, 'best_FID': self.best_fid})

        if is_master():
            self._write_to_meters(metric_dict, self.metric_meters)
            self._flush_meters(self.metric_meters)