# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import functools
import math

import torch
import torch.nn.functional as F

from imaginaire.evaluation import compute_fid
from imaginaire.losses import (FeatureMatchingLoss, GANLoss, GaussianKLLoss,
                               PerceptualLoss)
from imaginaire.trainers.base import BaseTrainer
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.model_average import reset_batch_norm, \
    calibrate_batch_norm_momentum
from imaginaire.utils.misc import split_labels, to_device
from imaginaire.utils.visualization import tensor2label


class Trainer(BaseTrainer):
    r"""Initialize SPADE trainer.

    Args:
        cfg (Config): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """

    def __init__(self,
                 cfg,
                 net_G,
                 net_D,
                 opt_G,
                 opt_D,
                 sch_G,
                 sch_D,
                 train_data_loader,
                 val_data_loader):
        super(Trainer, self).__init__(cfg, net_G, net_D, opt_G,
                                      opt_D, sch_G, sch_D,
                                      train_data_loader, val_data_loader)
        if cfg.data.type == 'imaginaire.datasets.paired_videos':
            self.video_mode = True
        else:
            self.video_mode = False

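    # The criteria/weights dictionaries registered below share keys;
    # gen_forward() later combines them as
    #     total = sum_k weights[k] * gen_losses[k],
    # so adding a new loss term only requires registering a criterion and a
    # weight here.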
    def _init_loss(self, cfg):
        r"""Initialize loss terms.

        Args:
            cfg (obj): Global configuration.
        """
        self.criteria['GAN'] = GANLoss(cfg.trainer.gan_mode)
        self.weights['GAN'] = cfg.trainer.loss_weight.gan
        # Set up the perceptual loss. Note that the perceptual loss can run
        # in fp16 mode for additional speed. We find that running it in fp16
        # improves training speed while maintaining the same accuracy.
        if hasattr(cfg.trainer, 'perceptual_loss'):
            self.criteria['Perceptual'] = \
                PerceptualLoss(
                    network=cfg.trainer.perceptual_loss.mode,
                    layers=cfg.trainer.perceptual_loss.layers,
                    weights=cfg.trainer.perceptual_loss.weights)
            self.weights['Perceptual'] = cfg.trainer.loss_weight.perceptual
        # Set up the feature matching loss.
        self.criteria['FeatureMatching'] = FeatureMatchingLoss()
        self.weights['FeatureMatching'] = \
            cfg.trainer.loss_weight.feature_matching
        # Set up the Gaussian KL divergence loss.
        self.criteria['GaussianKL'] = GaussianKLLoss()
        self.weights['GaussianKL'] = cfg.trainer.loss_weight.kl

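    # Note: besides being called at the start of every training iteration,
    # this hook is also reused as the `preprocess` callback when computing
    # FID in _compute_fid() below, so evaluation batches go through exactly
    # the same device placement and resizing as training batches.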
    def _start_of_iteration(self, data, current_iteration):
        r"""Model-specific custom start-of-iteration process. We do two
        things. First, we move all the data to the GPU. Second, we resize the
        inputs so that their spatial size becomes a multiple of the factor
        required for bug-free convolutional operations. This factor is given
        by the YAML config, e.g., base = getattr(self.net_G, 'base', 32).

        Args:
            data (dict): The current batch.
            current_iteration (int): The iteration number of the current
                batch.
        """
        data = to_device(data, 'cuda')
        data = self._resize_data(data)
        return data

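    # Generator update. The objective is a weighted sum of the terms
    # registered in _init_loss():
    #     total = w_gan * L_GAN + w_fm * L_FeatureMatching
    #             + w_kl * L_GaussianKL (+ w_p * L_Perceptual, if configured)
    # where the KL term is only non-zero when the generator uses a style
    # encoder.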
    def gen_forward(self, data):
        r"""Compute the loss for the SPADE generator.

        Args:
            data (dict): Training data at the current iteration.
        """
        net_G_output = self.net_G(data)
        net_D_output = self.net_D(data, net_G_output)
        self._time_before_loss()
        output_fake = self._get_outputs(net_D_output, real=False)
        self.gen_losses['GAN'] = self.criteria['GAN'](
            output_fake, True, dis_update=False)
        self.gen_losses['FeatureMatching'] = self.criteria['FeatureMatching'](
            net_D_output['fake_features'], net_D_output['real_features'])
        if self.net_G_module.use_style_encoder:
            self.gen_losses['GaussianKL'] = \
                self.criteria['GaussianKL'](net_G_output['mu'],
                                            net_G_output['logvar'])
        else:
            self.gen_losses['GaussianKL'] = \
                self.gen_losses['GAN'].new_tensor([0])
        if hasattr(self.cfg.trainer, 'perceptual_loss'):
            self.gen_losses['Perceptual'] = self.criteria['Perceptual'](
                net_G_output['fake_images'], data['images'])
        total_loss = self.gen_losses['GAN'].new_tensor([0])
        for key in self.criteria:
            total_loss += self.gen_losses[key] * self.weights[key]
        self.gen_losses['total'] = total_loss
        return total_loss

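    # Discriminator update. The generator is run under torch.no_grad() and
    # its outputs are detached, so gradients from the GAN loss only flow
    # into the discriminator here.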
    def dis_forward(self, data):
        r"""Compute the loss for the SPADE discriminator.

        Args:
            data (dict): Training data at the current iteration.
        """
        with torch.no_grad():
            net_G_output = self.net_G(data)
            net_G_output['fake_images'] = net_G_output['fake_images'].detach()
        net_D_output = self.net_D(data, net_G_output)
        self._time_before_loss()
        output_fake = self._get_outputs(net_D_output, real=False)
        output_real = self._get_outputs(net_D_output, real=True)
        fake_loss = self.criteria['GAN'](output_fake, False, dis_update=True)
        true_loss = self.criteria['GAN'](output_real, True, dis_update=True)
        self.dis_losses['GAN/fake'] = fake_loss
        self.dis_losses['GAN/true'] = true_loss
        self.dis_losses['GAN'] = fake_loss + true_loss
        total_loss = self.dis_losses['GAN'] * self.weights['GAN']
        self.dis_losses['total'] = total_loss
        return total_loss

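    # Visualization rows are stacked in the order: real images, segmentation
    # maps (and edge maps, if present in the labels), generated images, and,
    # when model averaging is enabled, images from the averaged generator.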
    def _get_visualizations(self, data):
        r"""Compute visualization images. We first recalculate the batch
        norm statistics for the moving average model.

        Args:
            data (dict): The current batch.
        """
        self.recalculate_batch_norm_statistics(
            self.train_data_loader)
        with torch.no_grad():
            label_lengths = self.train_data_loader.dataset.get_label_lengths()
            labels = split_labels(data['label'], label_lengths)
            # Get visualization of the segmentation mask.
            vis_images = list()
            vis_images.append(data['images'])
            net_G_output = self.net_G(data, random_style=True)
            for key in labels.keys():
                if 'seg' in key:
                    segmaps = tensor2label(labels[key], label_lengths[key],
                                           output_normalized_tensor=True)
                    segmaps = torch.cat([x.unsqueeze(0) for x in segmaps], 0)
                    vis_images.append(segmaps)
                if 'edge' in key:
                    edgemaps = torch.cat(
                        (labels[key], labels[key], labels[key]), 1)
                    vis_images.append(edgemaps)
            vis_images.append(net_G_output['fake_images'])
            if self.cfg.trainer.model_average_config.enabled:
                net_G_model_average_output = \
                    self.net_G.module.averaged_model(data, random_style=True)
                vis_images.append(net_G_model_average_output['fake_images'])
            return vis_images

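    # Batch norm recalibration for the averaged generator: the running
    # statistics are reset and then re-estimated by running a configurable
    # number of calibration batches through the averaged model.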
    def recalculate_batch_norm_statistics(self, data_loader):
        r"""Update the statistics in the moving average model.

        Args:
            data_loader (pytorch data loader): Data loader for estimating the
                statistics.
        """
        if not self.cfg.trainer.model_average_config.enabled:
            return
        model_average_iteration = \
            self.cfg.trainer.model_average_config.\
            num_batch_norm_estimation_iterations
        if model_average_iteration == 0:
            return
        with torch.no_grad():
            # Accumulate batch norm statistics.
            self.net_G.module.averaged_model.train()
            # Reset running stats.
            self.net_G.module.averaged_model.apply(reset_batch_norm)
            for cal_it, cal_data in enumerate(data_loader):
                if cal_it >= model_average_iteration:
                    print('Done with {} iterations of updating batch norm '
                          'statistics'.format(model_average_iteration))
                    break
                cal_data = self._start_of_iteration(cal_data, 0)
                # Averaging over all batches.
                self.net_G.module.averaged_model.apply(
                    calibrate_batch_norm_momentum)
                self.net_G.module.averaged_model(cal_data)

    def write_metrics(self):
        r"""If a moving average model is present, we have two meters: one
        for the regular FID and one for the average FID. If there is no
        moving average model, we only report the regular FID.
        """
        if self.cfg.trainer.model_average_config.enabled:
            regular_fid, average_fid = self._compute_fid()
            metric_dict = {'FID/average': average_fid,
                           'FID/regular': regular_fid}
        else:
            regular_fid = self._compute_fid()
            metric_dict = {'FID/regular': regular_fid}
        self._write_to_meters(metric_dict, self.metric_meters, reduce=False)
        self._flush_meters(self.metric_meters)

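    # FID is computed with randomly sampled style codes (random_style=True),
    # and the validation batches are preprocessed with the same
    # _start_of_iteration hook used during training.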
    def _compute_fid(self):
        r"""Compute FID for the regular model in eval mode and, if model
        averaging is enabled, for the moving average model as well.
        """
        self.net_G.eval()
        net_G_for_evaluation = \
            functools.partial(self.net_G, random_style=True)
        regular_fid_path = self._get_save_path('regular_fid', 'npy')
        preprocess = \
            functools.partial(self._start_of_iteration, current_iteration=0)
        regular_fid_value = compute_fid(regular_fid_path,
                                        self.val_data_loader,
                                        net_G_for_evaluation,
                                        preprocess=preprocess)
        print('Epoch {:05}, Iteration {:09}, Regular FID {}'.format(
            self.current_epoch, self.current_iteration, regular_fid_value))
        if self.cfg.trainer.model_average_config.enabled:
            avg_net_G_for_evaluation = \
                functools.partial(self.net_G.module.averaged_model,
                                  random_style=True)
            fid_path = self._get_save_path('average_fid', 'npy')
            fid_value = compute_fid(fid_path, self.val_data_loader,
                                    avg_net_G_for_evaluation,
                                    preprocess=preprocess)
            print('Epoch {:05}, Iteration {:09}, FID {}'.format(
                self.current_epoch, self.current_iteration, fid_value))
            self.net_G.float()
            return regular_fid_value, fid_value
        else:
            self.net_G.float()
            return regular_fid_value

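    # Example of the size rounding below (assuming the default base of 32):
    # a 540x960 label map is resized to 512x960, since 540 rounds down to
    # 16 * 32 = 512 while 960 is already a multiple of 32.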
    def _resize_data(self, data):
        r"""Resize the input label maps and images so that their spatial
        size is a multiple of the base factor expected by the generator.

        Args:
            data (dict): Input dictionary containing the 'label' and
                (optionally) 'images' fields.
        """
        base = getattr(self.net_G, 'base', 32)
        sy = math.floor(data['label'].size()[2] * 1.0 // base) * base
        sx = math.floor(data['label'].size()[3] * 1.0 // base) * base
        data['label'] = F.interpolate(
            data['label'], size=[sy, sx], mode='nearest')
        if 'images' in data.keys():
            data['images'] = F.interpolate(
                data['images'], size=[sy, sx], mode='bicubic')
        return data
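
# A minimal usage sketch (hypothetical wiring; in the Imaginaire codebase the
# networks, optimizers, schedulers, and data loaders are normally built by the
# train.py entry point from the YAML config):
#
#     trainer = Trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
#                       train_data_loader, val_data_loader)
#     data = trainer._start_of_iteration(next(iter(train_data_loader)), 0)
#     d_loss = trainer.dis_forward(data)
#     g_loss = trainer.gen_forward(data)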