# NOTE: web-page capture residue, kept as a comment: "Spaces: Runtime error / Runtime error"
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Loss functions."""
import numpy as np | |
import torch | |
from torch_utils import training_stats | |
from torch_utils.ops import conv2d_gradfix | |
from torch_utils.ops import upfirdn2d | |
# ----------------------------------------------------------------------------
class Loss:
    """Base interface for training losses.

    Subclasses implement `accumulate_gradients`, which is expected to run the
    forward passes for the requested phase and call `.backward()` itself.
    """

    # to be overridden by subclass
    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg):
        """Accumulate gradients for one training phase.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError()
# ----------------------------------------------------------------------------
class StyleGAN2Loss(Loss):
    """GAN loss with optional path-length (G) and R1 (D) regularization.

    The main terms are `softplus(-logits)` for the generator and for real
    images in the discriminator, and `softplus(logits)` for generated images
    in the discriminator. Each phase calls `.backward()` on its own loss, so
    gradients accumulate in the parameters of whichever networks currently
    have `requires_grad` enabled.
    """

    def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0):
        """Store the networks and loss hyper-parameters.

        Args:
            device: Device on which the running path-length mean is allocated.
            G: Generator exposing `mapping(z, c, update_emas=...)` and
                `synthesis(ws, update_emas=...)`.
            D: Discriminator, called as `D(img, c, update_emas=...)`.
            augment_pipe: Optional callable applied to images before `D`.
            r1_gamma: R1 weight; the penalty is scaled by `r1_gamma / 2`.
                A value of 0 turns the 'Dreg' phase into a no-op.
            style_mixing_prob: Probability of applying style mixing in
                `run_G`; 0 disables it.
            pl_weight: Multiplier for the path-length penalty. A value of 0
                turns the 'Greg' phase into a no-op.
            pl_batch_shrink: Divisor applied to the batch size for the
                path-length pass.
            pl_decay: Interpolation weight used to update the running mean of
                path lengths.
            pl_no_weight_grad: Forwarded to
                `conv2d_gradfix.no_weight_gradients` while computing the
                path-length gradients.
            blur_init_sigma: Blur sigma applied to discriminator inputs at
                the start of training.
            blur_fade_kimg: Number of thousands of images over which the blur
                sigma fades linearly to 0; 0 disables blurring.
        """
        super().__init__()
        self.device = device
        self.G = G
        self.D = D
        self.augment_pipe = augment_pipe
        self.r1_gamma = r1_gamma
        self.style_mixing_prob = style_mixing_prob
        self.pl_weight = pl_weight
        self.pl_batch_shrink = pl_batch_shrink
        self.pl_decay = pl_decay
        self.pl_no_weight_grad = pl_no_weight_grad
        # Running mean of path lengths, kept as a 0-dim tensor on `device`.
        self.pl_mean = torch.zeros([], device=device)
        self.blur_init_sigma = blur_init_sigma
        self.blur_fade_kimg = blur_fade_kimg

    def run_G(self, z, c, update_emas=False):
        """Run the generator and return `(img, ws)`.

        When `style_mixing_prob > 0`, the tail `ws[:, cutoff:]` is overwritten
        in place with the mapping of a fresh random latent.
        """
        ws = self.G.mapping(z, c, update_emas=update_emas)
        if self.style_mixing_prob > 0:
            with torch.autograd.profiler.record_function('style_mixing'):
                # Random crossover index in [1, ws.shape[1]).
                cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                # Keep the random index with probability `style_mixing_prob`;
                # otherwise use ws.shape[1], which makes the slice below empty.
                cutoff = torch.where(
                    torch.rand([], device=ws.device) < self.style_mixing_prob,
                    cutoff,
                    torch.full_like(cutoff, ws.shape[1]),
                )
                ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
        img = self.G.synthesis(ws, update_emas=update_emas)
        return img, ws

    def run_D(self, img, c, blur_sigma=0, update_emas=False):
        """Optionally blur and augment `img`, then return the discriminator logits."""
        blur_size = np.floor(blur_sigma * 3)
        if blur_size > 0:
            with torch.autograd.profiler.record_function('blur'):
                # 1-D kernel exp2(-(x / sigma)^2) over [-blur_size, blur_size],
                # normalized to sum to 1 before filtering.
                f = torch.arange(-blur_size, blur_size + 1, device=img.device).div(blur_sigma).square().neg().exp2()
                img = upfirdn2d.filter2d(img, f / f.sum())
        if self.augment_pipe is not None:
            img = self.augment_pipe(img)
        logits = self.D(img, c, update_emas=update_emas)
        return logits

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg):
        """Run the forward/backward passes for one training phase.

        Args:
            phase: One of 'Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'.
            real_img: Batch of real images fed to the discriminator.
            real_c: Conditioning labels for `real_img`.
            gen_z: Latents fed to the generator.
            gen_c: Conditioning labels for the generated images.
            gain: Scalar multiplied into every loss before `.backward()`.
            cur_nimg: Number of images seen so far; drives the blur fade.
        """
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        # With a regularizer disabled, its dedicated phase becomes a no-op and
        # the combined phase degrades to the main loss only.
        if self.pl_weight == 0:
            phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase)
        if self.r1_gamma == 0:
            phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase)
        # Blur sigma decays linearly from blur_init_sigma to 0 over
        # blur_fade_kimg thousand images.
        blur_sigma = (
            max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma
            if self.blur_fade_kimg > 0 else 0
        )

        # Gmain: Maximize logits for generated images.
        if phase in ['Gmain', 'Gboth']:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c)
                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                # -log(sigmoid(gen_logits))
                loss_Gmain = torch.nn.functional.softplus(-gen_logits)
                training_stats.report('Loss/G/loss', loss_Gmain)
            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain.mean().mul(gain).backward()

        # Gpl: Apply path length regularization.
        if phase in ['Greg', 'Gboth']:
            with torch.autograd.profiler.record_function('Gpl_forward'):
                batch_size = gen_z.shape[0] // self.pl_batch_shrink
                gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size])
                pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients(self.pl_no_weight_grad):
                    pl_grads = torch.autograd.grad(
                        outputs=[(gen_img * pl_noise).sum()],
                        inputs=[gen_ws],
                        create_graph=True,
                        only_inputs=True,
                    )[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                # Persist the running mean without keeping the graph alive.
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report('Loss/pl_penalty', pl_penalty)
                loss_Gpl = pl_penalty * self.pl_weight
                training_stats.report('Loss/G/reg', loss_Gpl)
            with torch.autograd.profiler.record_function('Gpl_backward'):
                loss_Gpl.mean().mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if phase in ['Dmain', 'Dboth']:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, update_emas=True)
                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                # -log(1 - sigmoid(gen_logits))
                loss_Dgen = torch.nn.functional.softplus(gen_logits)
            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if phase in ['Dmain', 'Dreg', 'Dboth']:
            name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                # Input gradients are only needed when the R1 penalty is computed.
                real_img_tmp = real_img.detach().requires_grad_(phase in ['Dreg', 'Dboth'])
                real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real', real_logits.sign())

                loss_Dreal = 0
                if phase in ['Dmain', 'Dboth']:
                    # -log(sigmoid(real_logits))
                    loss_Dreal = torch.nn.functional.softplus(-real_logits)
                    training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if phase in ['Dreg', 'Dboth']:
                    with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                        r1_grads = torch.autograd.grad(
                            outputs=[real_logits.sum()],
                            inputs=[real_img_tmp],
                            create_graph=True,
                            only_inputs=True,
                        )[0]
                    r1_penalty = r1_grads.square().sum([1, 2, 3])
                    loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty', r1_penalty)
                    training_stats.report('Loss/D/reg', loss_Dr1)

            with torch.autograd.profiler.record_function(name + '_backward'):
                (loss_Dreal + loss_Dr1).mean().mul(gain).backward()
# ----------------------------------------------------------------------------