Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | |
# | |
# NVIDIA CORPORATION and its licensors retain all intellectual property | |
# and proprietary rights in and to this software, related documentation | |
# and any modifications thereto. Any use, reproduction, disclosure or | |
# distribution of this software and related documentation without an express | |
# license agreement from NVIDIA CORPORATION is strictly prohibited. | |
from email import generator | |
from cv2 import DescriptorMatcher | |
import training | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from torchvision.utils import save_image | |
from torch_utils import training_stats | |
from torch_utils import misc | |
from torch_utils.ops import conv2d_gradfix | |
#---------------------------------------------------------------------------- | |
class Loss:
    """Base interface for training losses.

    Concrete losses override ``accumulate_gradients``; the base version only
    signals that no implementation was provided.
    """

    def accumulate_gradients(self, **kwargs):
        """Compute the loss for one phase and backpropagate it (subclass hook).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
#---------------------------------------------------------------------------- | |
class StyleGAN2Loss(Loss): | |
def __init__( | |
self, device, G_mapping, G_synthesis, D, | |
G_encoder=None, augment_pipe=None, D_ema=None, | |
style_mixing_prob=0.9, r1_gamma=10, | |
pl_batch_shrink=2, pl_decay=0.01, pl_weight=2, other_weights=None, | |
curriculum=None, alpha_start=0.0, cycle_consistency=False, label_smooth=0, | |
generator_mode='random_z_random_c'): | |
super().__init__() | |
self.device = device | |
self.G_mapping = G_mapping | |
self.G_synthesis = G_synthesis | |
self.G_encoder = G_encoder | |
self.D = D | |
self.D_ema = D_ema | |
self.augment_pipe = augment_pipe | |
self.style_mixing_prob = style_mixing_prob | |
self.r1_gamma = r1_gamma | |
self.pl_batch_shrink = pl_batch_shrink | |
self.pl_decay = pl_decay | |
self.pl_weight = pl_weight | |
self.other_weights = other_weights | |
self.pl_mean = torch.zeros([], device=device) | |
self.curriculum = curriculum | |
self.alpha_start = alpha_start | |
self.alpha = None | |
self.cycle_consistency = cycle_consistency | |
self.label_smooth = label_smooth | |
self.generator_mode = generator_mode | |
if self.G_encoder is not None: | |
import lpips | |
self.lpips_loss = lpips.LPIPS(net='vgg').to(device=device) | |
def set_alpha(self, steps): | |
alpha = None | |
if self.curriculum is not None: | |
if self.curriculum == 'upsample': | |
alpha = 0.0 | |
else: | |
assert len(self.curriculum) == 2, "currently support one stage for now" | |
start, end = self.curriculum | |
alpha = min(1., max(0., (steps / 1e3 - start) / (end - start))) | |
if self.alpha_start > 0: | |
alpha = self.alpha_start + (1 - self.alpha_start) * alpha | |
self.alpha = alpha | |
self.steps = steps | |
self.curr_status = None | |
def _apply(m): | |
if hasattr(m, "set_alpha") and m != self: | |
m.set_alpha(alpha) | |
if hasattr(m, "set_steps") and m != self: | |
m.set_steps(steps) | |
if hasattr(m, "set_resolution") and m != self: | |
m.set_resolution(self.curr_status) | |
self.G_synthesis.apply(_apply) | |
self.curr_status = self.resolution | |
self.D.apply(_apply) | |
if self.G_encoder is not None: | |
self.G_encoder.apply(_apply) | |
def run_G(self, z, c, sync, img=None, mode=None, get_loss=True): | |
synthesis_kwargs = {'camera_mode': 'random'} | |
generator_mode = self.generator_mode if mode is None else mode | |
if (generator_mode == 'image_z_random_c') or (generator_mode == 'image_z_image_c'): | |
assert (self.G_encoder is not None) and (img is not None) | |
with misc.ddp_sync(self.G_encoder, sync): | |
ws = self.G_encoder(img)['ws'] | |
if generator_mode == 'image_z_image_c': | |
with misc.ddp_sync(self.D, False): | |
synthesis_kwargs['camera_RT'] = misc.get_func(self.D, 'get_estimated_camera')[0](img) | |
with misc.ddp_sync(self.G_synthesis, sync): | |
out = self.G_synthesis(ws, **synthesis_kwargs) | |
if get_loss: # consistency loss given the image predicted camera (train the image encoder jointly) | |
out['consist_l1_loss'] = F.smooth_l1_loss(out['img'], img['img']) * 2.0 # TODO: DEBUG | |
out['consist_lpips_loss'] = self.lpips_loss(out['img'], img['img']) * 10.0 # TODO: DEBUG | |
elif (generator_mode == 'random_z_random_c') or (generator_mode == 'random_z_image_c'): | |
with misc.ddp_sync(self.G_mapping, sync): | |
ws = self.G_mapping(z, c) | |
if self.style_mixing_prob > 0: | |
with torch.autograd.profiler.record_function('style_mixing'): | |
cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) | |
cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1])) | |
ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:] | |
if generator_mode == 'random_z_image_c': | |
assert img is not None | |
with torch.no_grad(): | |
D = self.D_ema if self.D_ema is not None else self.D | |
with misc.ddp_sync(D, sync): | |
estimated_c = misc.get_func(D, 'get_estimated_camera')(img)[0].detach() | |
if estimated_c.size(-1) == 16: | |
synthesis_kwargs['camera_RT'] = estimated_c | |
if estimated_c.size(-1) == 3: | |
synthesis_kwargs['camera_UV'] = estimated_c | |
with misc.ddp_sync(self.G_synthesis, sync): | |
out = self.G_synthesis(ws, **synthesis_kwargs) | |
else: | |
raise NotImplementedError(f'wrong generator_mode {generator_mode}') | |
return out, ws | |
def run_D(self, img, c, sync): | |
with misc.ddp_sync(self.D, sync): | |
logits = self.D(img, c, aug_pipe=self.augment_pipe) | |
return logits | |
def get_loss(self, outputs, module='D'): | |
reg_loss, logs, del_keys = 0, [], [] | |
if isinstance(outputs, dict): | |
for key in outputs: | |
if key[-5:] == '_loss': | |
logs += [(f'Loss/{module}/{key}', outputs[key])] | |
del_keys += [key] | |
if (self.other_weights is not None) and (key in self.other_weights): | |
reg_loss = reg_loss + outputs[key].mean() * self.other_weights[key] | |
else: | |
reg_loss = reg_loss + outputs[key].mean() | |
for key in del_keys: | |
del outputs[key] | |
for key, loss in logs: | |
training_stats.report(key, loss) | |
return reg_loss | |
def resolution(self): | |
return misc.get_func(self.G_synthesis, 'get_current_resolution')()[-1] | |
def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, fake_img, sync, gain, scaler=None): | |
assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'] | |
do_Gmain = (phase in ['Gmain', 'Gboth']) | |
do_Dmain = (phase in ['Dmain', 'Dboth']) | |
do_Gpl = (phase in ['Greg', 'Gboth']) | |
do_Dr1 = (phase in ['Dreg', 'Dboth']) | |
losses = {} | |
# Gmain: Maximize logits for generated images. | |
loss_Gmain, reg_loss = 0, 0 | |
if isinstance(fake_img, dict): fake_img = fake_img['img'] | |
if do_Gmain: | |
with torch.autograd.profiler.record_function('Gmain_forward'): | |
gen_img, gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl), img=fake_img) # May get synced by Gpl. | |
reg_loss += self.get_loss(gen_img, 'G') | |
gen_logits = self.run_D(gen_img, gen_c, sync=False) | |
reg_loss += self.get_loss(gen_logits, 'G') | |
if isinstance(gen_logits, dict): | |
gen_logits = gen_logits['logits'] | |
loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits)) | |
if self.label_smooth > 0: | |
loss_Gmain = loss_Gmain * (1 - self.label_smooth) + torch.nn.functional.softplus(gen_logits) * self.label_smooth | |
training_stats.report('Loss/scores/fake', gen_logits) | |
training_stats.report('Loss/signs/fake', gen_logits.sign()) | |
training_stats.report('Loss/G/loss', loss_Gmain) | |
with torch.autograd.profiler.record_function('Gmain_backward'): | |
loss_Gmain = loss_Gmain + reg_loss | |
losses['Gmain'] = loss_Gmain.mean().mul(gain) | |
loss = scaler.scale(losses['Gmain']) if scaler is not None else losses['Gmain'] | |
loss.backward() | |
# Gpl: Apply path length regularization. | |
if do_Gpl and (self.pl_weight != 0): | |
with torch.autograd.profiler.record_function('Gpl_forward'): | |
batch_size = max(1, gen_z.shape[0] // self.pl_batch_shrink) | |
gen_img, gen_ws = self.run_G( | |
gen_z[:batch_size], gen_c[:batch_size], sync=sync, | |
img=fake_img[:batch_size] if fake_img is not None else None) | |
if isinstance(gen_img, dict): gen_img = gen_img['img'] | |
pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3]) | |
with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients(): | |
# with torch.autograd.profiler.record_function('pl_grads'): | |
pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True, allow_unused=True)[0] | |
pl_lengths = pl_grads.square().sum(2).mean(1).sqrt() | |
pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay) | |
self.pl_mean.copy_(pl_mean.detach()) | |
pl_penalty = (pl_lengths - pl_mean).square() | |
training_stats.report('Loss/pl_penalty', pl_penalty) | |
loss_Gpl = pl_penalty * self.pl_weight | |
training_stats.report('Loss/G/reg', loss_Gpl) | |
with torch.autograd.profiler.record_function('Gpl_backward'): | |
losses['Gpl'] = (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain) | |
loss = scaler.scale(losses['Gpl']) if scaler is not None else losses['Gpl'] | |
loss.backward() | |
# Dmain: Minimize logits for generated images. | |
loss_Dgen, reg_loss = 0, 0 | |
if do_Dmain: | |
with torch.autograd.profiler.record_function('Dgen_forward'): | |
gen_img = self.run_G(gen_z, gen_c, sync=False, img=fake_img)[0] | |
reg_loss += self.get_loss(gen_img, 'D') | |
gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal. | |
reg_loss += self.get_loss(gen_logits, 'D') | |
if isinstance(gen_logits, dict): | |
gen_logits = gen_logits['logits'] | |
training_stats.report('Loss/scores/fake', gen_logits) | |
training_stats.report('Loss/signs/fake', gen_logits.sign()) | |
loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits)) | |
with torch.autograd.profiler.record_function('Dgen_backward'): | |
loss_Dgen = loss_Dgen + reg_loss | |
losses['Dgen'] = loss_Dgen.mean().mul(gain) | |
loss = scaler.scale(losses['Dgen']) if scaler is not None else losses['Dgen'] | |
loss.backward() | |
# Dmain: Maximize logits for real images. | |
# Dr1: Apply R1 regularization. | |
if do_Dmain or (do_Dr1 and (self.r1_gamma != 0)): | |
name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1' | |
with torch.autograd.profiler.record_function(name + '_forward'): | |
if isinstance(real_img, dict): | |
real_img['img'] = real_img['img'].requires_grad_(do_Dr1) | |
else: | |
real_img = real_img.requires_grad_(do_Dr1) | |
real_logits = self.run_D(real_img, real_c, sync=sync) | |
if isinstance(real_logits, dict): | |
real_logits = real_logits['logits'] | |
training_stats.report('Loss/scores/real', real_logits) | |
training_stats.report('Loss/signs/real', real_logits.sign()) | |
loss_Dreal = 0 | |
if do_Dmain: | |
loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits)) | |
if self.label_smooth > 0: | |
loss_Dreal = loss_Dreal * (1 - self.label_smooth) + torch.nn.functional.softplus(real_logits) * self.label_smooth | |
training_stats.report('Loss/D/loss', loss_Dgen.mean() + loss_Dreal.mean()) | |
loss_Dr1 = 0 | |
if do_Dr1: | |
with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients(): | |
real_img_tmp = real_img['img'] if isinstance(real_img, dict) else real_img | |
r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0] | |
r1_penalty = r1_grads.square().sum([1,2,3]) | |
loss_Dr1 = r1_penalty * (self.r1_gamma / 2) | |
training_stats.report('Loss/r1_penalty', r1_penalty) | |
training_stats.report('Loss/D/reg', loss_Dr1) | |
with torch.autograd.profiler.record_function(name + '_backward'): | |
losses['Dr1'] = (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain) | |
loss = scaler.scale(losses['Dr1']) if scaler is not None else losses['Dr1'] | |
loss.backward() | |
return losses | |
#---------------------------------------------------------------------------- | |