# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import collections
import os

import torch
import torch.nn as nn

from imaginaire.config import Config
from imaginaire.generators.spade import Generator as SPADEGenerator
from imaginaire.losses import (FeatureMatchingLoss, GaussianKLLoss,
                               PerceptualLoss)
from imaginaire.model_utils.gancraft.loss import GANLoss
from imaginaire.trainers.base import BaseTrainer
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.io import get_checkpoint
from imaginaire.utils.misc import split_labels, to_device
from imaginaire.utils.trainer import ModelAverage, WrappedModel
from imaginaire.utils.visualization import tensor2label


class GauGANLoader(object):
    r"""Manages the SPADE/GauGAN model used to generate pseudo-GTs for
    training GANcraft.

    Args:
        gaugan_cfg (Config): SPADE configuration.
    """

    def __init__(self, gaugan_cfg):
        print('[GauGANLoader] Loading GauGAN model.')
        cfg = Config(gaugan_cfg.config)
        default_checkpoint_path = os.path.basename(gaugan_cfg.config).split('.yaml')[0] + '-' + \
            cfg.pretrained_weight + '.pt'
        checkpoint = get_checkpoint(default_checkpoint_path, cfg.pretrained_weight)
        ckpt = torch.load(checkpoint)
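        # Rebuild the generator with the same wrappers used at training time
        # (model averaging plus the module wrapper) so the checkpoint keys line up.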
        net_G = WrappedModel(ModelAverage(SPADEGenerator(cfg.gen, cfg.data).to('cuda')))
        net_G.load_state_dict(ckpt['net_G'])
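        # Keep only the exponentially averaged weights; the model is used purely
        # for inference, so it is frozen in eval mode and cast to fp16.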
        self.net_GG = net_G.module.averaged_model
        self.net_GG.eval()
        self.net_GG.half()
        print('[GauGANLoader] GauGAN loading complete.')

    def eval(self, label, z=None, style_img=None):
        r"""Produce output given segmentation and other conditioning inputs.
        A random style is used if neither z nor style_img is provided.

        Args:
            label (N x C x H x W tensor): One-hot segmentation mask.
            z (tensor): Optional style vector.
            style_img (tensor): Optional style image.
        """
        inputs = {'label': label[:, :-1].detach().half()}
        random_style = True
        if z is not None:
            random_style = False
            inputs['z'] = z.detach().half()
        elif style_img is not None:
            random_style = False
            inputs['images'] = style_img.detach().half()
        net_GG_output = self.net_GG(inputs, random_style=random_style)
        return net_GG_output['fake_images']
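

# Minimal usage sketch (the variables below are illustrative only):
#   loader = GauGANLoader(cfg.trainer.gaugan_loader)
#   pseudo_real = loader.eval(one_hot_label)             # random style
#   pseudo_real = loader.eval(one_hot_label, z=style_z)  # explicit style vector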


class Trainer(BaseTrainer):
    r"""Initialize GANcraft trainer.

    Args:
        cfg (Config): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """

    def __init__(self,
                 cfg,
                 net_G,
                 net_D,
                 opt_G,
                 opt_D,
                 sch_G,
                 sch_D,
                 train_data_loader,
                 val_data_loader):
        super(Trainer, self).__init__(cfg, net_G, net_D, opt_G,
                                      opt_D, sch_G, sch_D,
                                      train_data_loader, val_data_loader)
        # The pseudo-GT network is only needed during training.
        if not self.is_inference:
            self.gaugan_model = GauGANLoader(cfg.trainer.gaugan_loader)

    def _init_loss(self, cfg):
        r"""Initialize loss terms.

        Args:
            cfg (obj): Global configuration.
        """
        if hasattr(cfg.trainer.loss_weight, 'gan'):
            self.criteria['GAN'] = GANLoss()
            self.weights['GAN'] = cfg.trainer.loss_weight.gan
        if hasattr(cfg.trainer.loss_weight, 'pseudo_gan'):
            self.criteria['PGAN'] = GANLoss()
            self.weights['PGAN'] = cfg.trainer.loss_weight.pseudo_gan
        if hasattr(cfg.trainer.loss_weight, 'l2'):
            self.criteria['L2'] = nn.MSELoss()
            self.weights['L2'] = cfg.trainer.loss_weight.l2
        if hasattr(cfg.trainer.loss_weight, 'l1'):
            self.criteria['L1'] = nn.L1Loss()
            self.weights['L1'] = cfg.trainer.loss_weight.l1
        # Setup the perceptual loss.
        if hasattr(cfg.trainer, 'perceptual_loss'):
            self.criteria['Perceptual'] = \
                PerceptualLoss(
                    network=cfg.trainer.perceptual_loss.mode,
                    layers=cfg.trainer.perceptual_loss.layers,
                    weights=cfg.trainer.perceptual_loss.weights)
            self.weights['Perceptual'] = cfg.trainer.loss_weight.perceptual
        # Setup the feature matching loss.
        if hasattr(cfg.trainer.loss_weight, 'feature_matching'):
            self.criteria['FeatureMatching'] = FeatureMatchingLoss()
            self.weights['FeatureMatching'] = \
                cfg.trainer.loss_weight.feature_matching
        # Setup the Gaussian KL divergence loss.
        if hasattr(cfg.trainer.loss_weight, 'kl'):
            self.criteria['GaussianKL'] = GaussianKLLoss()
            self.weights['GaussianKL'] = cfg.trainer.loss_weight.kl

    def _start_of_epoch(self, current_epoch):
        torch.cuda.empty_cache()  # Prevent the first iteration from running OOM.

    def _start_of_iteration(self, data, current_iteration):
        r"""Model specific custom start of iteration process: move the batch to
        the GPU, then sample camera poses and generate the corresponding
        pseudo-GTs with the GauGAN model.

        Args:
            data (dict): The current batch.
            current_iteration (int): The iteration number of the current batch.
        """
        data = to_device(data, 'cuda')
        # Sample camera poses and pseudo-GTs.
        with torch.no_grad():
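            # Passing the GauGAN eval function lets sample_camera render
            # pseudo-GTs for the sampled camera poses on the fly.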
            samples = self.net_G.module.sample_camera(data, self.gaugan_model.eval)
        return {**data, **samples}

    def gen_forward(self, data):
        r"""Compute the loss for the GANcraft generator.

        Args:
            data (dict): Training data at the current iteration.
        """
        net_G_output = self.net_G(data, random_style=False)

        self._time_before_loss()

        if 'GAN' in self.criteria or 'PGAN' in self.criteria:
            incl_pseudo_real = False
            if 'FeatureMatching' in self.criteria:
                incl_pseudo_real = True
            net_D_output = self.net_D(data, net_G_output, incl_real=False, incl_pseudo_real=incl_pseudo_real)
            output_fake = net_D_output['fake_outputs']  # Choose from real_outputs and fake_outputs.
            # Use whichever adversarial criterion is enabled (both are plain GANLoss).
            gan_key = 'GAN' if 'GAN' in self.criteria else 'PGAN'
            gan_loss = self.criteria[gan_key](output_fake, True, dis_update=False)
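            # Only the fake outputs enter the generator update, so the GAN and PGAN
            # terms share the same loss value here; they differ only in the
            # discriminator pass (real vs. pseudo-real targets).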
            if 'GAN' in self.criteria:
                self.gen_losses['GAN'] = gan_loss
            if 'PGAN' in self.criteria:
                self.gen_losses['PGAN'] = gan_loss
            if 'FeatureMatching' in self.criteria:
                self.gen_losses['FeatureMatching'] = self.criteria['FeatureMatching'](
                    net_D_output['fake_features'], net_D_output['pseudo_real_features'])
        if 'GaussianKL' in self.criteria:
            self.gen_losses['GaussianKL'] = self.criteria['GaussianKL'](net_G_output['mu'], net_G_output['logvar'])
        # Perceptual loss is always between fake image and pseudo real image.
        if 'Perceptual' in self.criteria:
            self.gen_losses['Perceptual'] = self.criteria['Perceptual'](
                net_G_output['fake_images'], data['pseudo_real_img'])
        # Reconstruction loss between fake and pseudo real.
        if 'L2' in self.criteria:
            self.gen_losses['L2'] = self.criteria['L2'](net_G_output['fake_images'], data['pseudo_real_img'])
        if 'L1' in self.criteria:
            self.gen_losses['L1'] = self.criteria['L1'](net_G_output['fake_images'], data['pseudo_real_img'])

        total_loss = 0
        for key in self.criteria:
            total_loss = total_loss + self.gen_losses[key] * self.weights[key]
        self.gen_losses['total'] = total_loss
        return total_loss

    def dis_forward(self, data):
        r"""Compute the loss for the GANcraft discriminator.

        Args:
            data (dict): Training data at the current iteration.
        """
        if 'GAN' not in self.criteria and 'PGAN' not in self.criteria:
            return

        with torch.no_grad():
            net_G_output = self.net_G(data, random_style=False)
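        # Detach so the discriminator update does not backpropagate into the generator.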
        net_G_output['fake_images'] = net_G_output['fake_images'].detach()

        incl_real = False
        incl_pseudo_real = False
        if 'GAN' in self.criteria:
            incl_real = True
        if 'PGAN' in self.criteria:
            incl_pseudo_real = True
        net_D_output = self.net_D(data, net_G_output, incl_real=incl_real, incl_pseudo_real=incl_pseudo_real)

        self._time_before_loss()

        total_loss = 0
        if 'GAN' in self.criteria:
            output_fake = net_D_output['fake_outputs']
            output_real = net_D_output['real_outputs']
            fake_loss = self.criteria['GAN'](output_fake, False, dis_update=True)
            true_loss = self.criteria['GAN'](output_real, True, dis_update=True)
            self.dis_losses['GAN/fake'] = fake_loss
            self.dis_losses['GAN/true'] = true_loss
            self.dis_losses['GAN'] = fake_loss + true_loss
            total_loss = total_loss + self.dis_losses['GAN'] * self.weights['GAN']
        if 'PGAN' in self.criteria:
            output_fake = net_D_output['fake_outputs']
            output_pseudo_real = net_D_output['pseudo_real_outputs']
            fake_loss = self.criteria['PGAN'](output_fake, False, dis_update=True)
            true_loss = self.criteria['PGAN'](output_pseudo_real, True, dis_update=True)
            self.dis_losses['PGAN/fake'] = fake_loss
            self.dis_losses['PGAN/true'] = true_loss
            self.dis_losses['PGAN'] = fake_loss + true_loss
            total_loss = total_loss + self.dis_losses['PGAN'] * self.weights['PGAN']
        self.dis_losses['total'] = total_loss
        return total_loss

    def _get_visualizations(self, data):
        r"""Compute visualization images.

        Args:
            data (dict): The current batch.
        """
        with torch.no_grad():
            label_lengths = self.train_data_loader.dataset.get_label_lengths()
            labels = split_labels(data['label'], label_lengths)
            # Get visualization of the real image and segmentation mask.
            segmap = tensor2label(labels['seg_maps'], label_lengths['seg_maps'], output_normalized_tensor=True)
            segmap = torch.cat([x.unsqueeze(0) for x in segmap], 0)
            # Get output from the GANcraft model.
            net_G_output_randstyle = self.net_G(data, random_style=True)
            net_G_output = self.net_G(data, random_style=False)
            vis_images = [data['images'], segmap, net_G_output_randstyle['fake_images'], net_G_output['fake_images']]
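            # Collected so far: real image, input segmentation, a random-style render,
            # and a fixed-style render; pseudo-GT mask/image and the model-average
            # render are appended below when available.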
            if 'fake_masks' in data:
                # Get pseudo-GT.
                labels = split_labels(data['fake_masks'], label_lengths)
                segmap = tensor2label(labels['seg_maps'], label_lengths['seg_maps'], output_normalized_tensor=True)
                segmap = torch.cat([x.unsqueeze(0) for x in segmap], 0)
                vis_images.append(segmap)
            if 'pseudo_real_img' in data:
                vis_images.append(data['pseudo_real_img'])
            if self.cfg.trainer.model_average_config.enabled:
                net_G_model_average_output = self.net_G.module.averaged_model(data, random_style=True)
                vis_images.append(net_G_model_average_output['fake_images'])
        return vis_images

    def load_checkpoint(self, cfg, checkpoint_path, resume=None, load_sch=True):
        r"""Load network weights, optimizer parameters, and scheduler parameters
        from a checkpoint.

        Args:
            cfg (obj): Global configuration.
            checkpoint_path (str): Path to the checkpoint.
            resume (bool or None): If not ``None``, will determine whether or
                not to load optimizers in addition to network weights.
            load_sch (bool): Whether to load the scheduler state.
        """
        ret = super().load_checkpoint(cfg, checkpoint_path, resume, load_sch)
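        # Optionally clear the optimizer state after resuming (the loaded network
        # weights are kept); controlled by reset_opt_g_on_resume and
        # reset_opt_d_on_resume in the trainer config.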
        if getattr(cfg.trainer, 'reset_opt_g_on_resume', False):
            self.opt_G.state = collections.defaultdict(dict)
            print('[GANcraft::load_checkpoint] Resetting opt_G.state')
        if getattr(cfg.trainer, 'reset_opt_d_on_resume', False):
            self.opt_D.state = collections.defaultdict(dict)
            print('[GANcraft::load_checkpoint] Resetting opt_D.state')
        return ret

    def test(self, data_loader, output_dir, inference_args):
        r"""Compute result images for a batch of input data and save the
        results in the specified folder.

        Args:
            data_loader (torch.utils.data.DataLoader): PyTorch dataloader.
            output_dir (str): Target location for saving the output images.
            inference_args (obj): Inference arguments.
        """
        if self.cfg.trainer.model_average_config.enabled:
            net_G = self.net_G.module.averaged_model
        else:
            net_G = self.net_G.module
        net_G.eval()

        torch.cuda.empty_cache()
        with torch.no_grad():
            net_G.inference(output_dir, **vars(inference_args))