import numpy as np
import torch
from torch_utils import training_stats
from torch_utils import misc
from torch_utils.ops import conv2d_gradfix
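
# Loss functions for StyleGAN2-style training. In this variant the mapping
# network and the discriminator take an additional conditioning input `h`
# alongside the conditioning input `c`; both are simply forwarded below.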


class Loss:
    def accumulate_gradients(
        self, phase, real_img, real_c, real_h, gen_z, gen_c, gen_h, sync, gain
    ):  # to be overridden by subclass
        raise NotImplementedError()


class StyleGAN2Loss(Loss):
    def __init__(
        self,
        device,
        G_mapping,
        G_synthesis,
        D,
        augment_pipe=None,
        style_mixing_prob=0.9,
        r1_gamma=10,
        pl_batch_shrink=2,
        pl_decay=0.01,
        pl_weight=2,
    ):
        super().__init__()
        self.device = device
        self.G_mapping = G_mapping
        self.G_synthesis = G_synthesis
        self.D = D
        self.augment_pipe = augment_pipe
        self.style_mixing_prob = style_mixing_prob
        self.r1_gamma = r1_gamma
        self.pl_batch_shrink = pl_batch_shrink
        self.pl_decay = pl_decay
        self.pl_weight = pl_weight
        self.pl_mean = torch.zeros([], device=device)  # running average of path lengths, updated in place

    def run_G(self, z, c, h, sync):
        # Map (z, c, h) to intermediate latents, optionally apply style mixing,
        # then run synthesis. `sync` controls DDP gradient synchronization.
        with misc.ddp_sync(self.G_mapping, sync):
            ws = self.G_mapping(z, c, h)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function("style_mixing"):
                    # With probability style_mixing_prob, replace the styles after
                    # a random cutoff layer with those of an independent latent.
                    cutoff = torch.empty(
                        [], dtype=torch.int64, device=ws.device
                    ).random_(1, ws.shape[1])
                    cutoff = torch.where(
                        torch.rand([], device=ws.device) < self.style_mixing_prob,
                        cutoff,
                        torch.full_like(cutoff, ws.shape[1]),
                    )
                    ws[:, cutoff:] = self.G_mapping(
                        torch.randn_like(z), c, h, skip_w_avg_update=True
                    )[:, cutoff:]
        with misc.ddp_sync(self.G_synthesis, sync):
            img = self.G_synthesis(ws)
        return img, ws

    def run_D(self, img, c, h, sync):
        # The augmentation pipeline (if any) is applied to both real and generated
        # images before they reach the discriminator.
        if self.augment_pipe is not None:
            img = self.augment_pipe(img)
        with misc.ddp_sync(self.D, sync):
            logits = self.D(img, c, h)
        return logits

    def accumulate_gradients(
        self, phase, real_img, real_c, real_h, gen_z, gen_c, gen_h, sync, gain
    ):
        assert phase in ["Gmain", "Greg", "Gboth", "Dmain", "Dreg", "Dboth"]
        do_Gmain = phase in ["Gmain", "Gboth"]
        do_Dmain = phase in ["Dmain", "Dboth"]
        do_Gpl = (phase in ["Greg", "Gboth"]) and (self.pl_weight != 0)
        do_Dr1 = (phase in ["Dreg", "Dboth"]) and (self.r1_gamma != 0)
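
        # The adversarial objective below is the non-saturating logistic loss,
        # written via softplus:
        #   Gmain: E_z[softplus(-D(G(z)))]
        #   Dmain: E_z[softplus(D(G(z)))] + E_x[softplus(-D(x))]
        # Gpl (path length) and Dr1 (R1) are lazy regularizers; the training loop
        # is expected to run the *reg phases less frequently and to use `gain` to
        # compensate for that reduced frequency.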

        # Gmain: Maximize logits for generated images.
        if do_Gmain:
            with torch.autograd.profiler.record_function("Gmain_forward"):
                gen_img, _gen_ws = self.run_G(
                    gen_z, gen_c, gen_h, sync=(sync and not do_Gpl)
                )  # May get synced by Gpl.
                gen_logits = self.run_D(gen_img, gen_c, gen_h, sync=False)
                training_stats.report("Loss/scores/fake", gen_logits)
                training_stats.report("Loss/signs/fake", gen_logits.sign())
                loss_Gmain = torch.nn.functional.softplus(-gen_logits)  # -log(sigmoid(gen_logits))
                training_stats.report("Loss/G/loss", loss_Gmain)
            with torch.autograd.profiler.record_function("Gmain_backward"):
                loss_Gmain.mean().mul(gain).backward()
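
        # Gpl: path length regularization. pl_lengths measures the norm of
        # J_w^T y for a random image-space direction y (scaled by 1/sqrt(H*W)),
        # averaged over the ws layers; the penalty is (pl_lengths - pl_mean)^2,
        # where pl_mean is an exponential moving average with decay pl_decay.
        # Only a shrunken batch (pl_batch_shrink) is used to keep this pass cheap.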
        if do_Gpl:
            with torch.autograd.profiler.record_function("Gpl_forward"):
                batch_size = gen_z.shape[0] // self.pl_batch_shrink
                gen_img, gen_ws = self.run_G(
                    gen_z[:batch_size],
                    gen_c[:batch_size],
                    gen_h[:batch_size],
                    sync=sync,
                )
                pl_noise = torch.randn_like(gen_img) / np.sqrt(
                    gen_img.shape[2] * gen_img.shape[3]
                )
                with torch.autograd.profiler.record_function(
                    "pl_grads"
                ), conv2d_gradfix.no_weight_gradients():
                    pl_grads = torch.autograd.grad(
                        outputs=[(gen_img * pl_noise).sum()],
                        inputs=[gen_ws],
                        create_graph=True,
                        only_inputs=True,
                    )[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report("Loss/pl_penalty", pl_penalty)
                loss_Gpl = pl_penalty * self.pl_weight
                training_stats.report("Loss/G/reg", loss_Gpl)
            with torch.autograd.profiler.record_function("Gpl_backward"):
                # The 0-weighted gen_img term keeps the synthesis output in the
                # autograd graph without changing the loss value.
                (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if do_Dmain:
            with torch.autograd.profiler.record_function("Dgen_forward"):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, gen_h, sync=False)
                gen_logits = self.run_D(
                    gen_img, gen_c, gen_h, sync=False
                )  # Gets synced by loss_Dreal.
                training_stats.report("Loss/scores/fake", gen_logits)
                training_stats.report("Loss/signs/fake", gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(gen_logits)  # -log(1 - sigmoid(gen_logits))
            with torch.autograd.profiler.record_function("Dgen_backward"):
                loss_Dgen.mean().mul(gain).backward()
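
        # Dr1: R1 regularization penalizes the discriminator's gradient w.r.t.
        # real images, r1_penalty = ||grad_x D(x)||^2, weighted by r1_gamma / 2.
        # Dmain (real term) and Dr1 share one forward pass over the real images.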
        if do_Dmain or do_Dr1:
            name = (
                "Dreal_Dr1" if do_Dmain and do_Dr1 else "Dreal" if do_Dmain else "Dr1"
            )
            with torch.autograd.profiler.record_function(name + "_forward"):
                real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
                real_logits = self.run_D(real_img_tmp, real_c, real_h, sync=sync)
                training_stats.report("Loss/scores/real", real_logits)
                training_stats.report("Loss/signs/real", real_logits.sign())

                # Dmain: Maximize logits for real images.
                loss_Dreal = 0
                if do_Dmain:
                    loss_Dreal = torch.nn.functional.softplus(-real_logits)  # -log(sigmoid(real_logits))
                    training_stats.report("Loss/D/loss", loss_Dgen + loss_Dreal)

                # Dr1: Apply R1 regularization.
                loss_Dr1 = 0
                if do_Dr1:
                    with torch.autograd.profiler.record_function(
                        "r1_grads"
                    ), conv2d_gradfix.no_weight_gradients():
                        r1_grads = torch.autograd.grad(
                            outputs=[real_logits.sum()],
                            inputs=[real_img_tmp],
                            create_graph=True,
                            only_inputs=True,
                        )[0]
                    r1_penalty = r1_grads.square().sum([1, 2, 3])
                    loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report("Loss/r1_penalty", r1_penalty)
                    training_stats.report("Loss/D/reg", loss_Dr1)

            with torch.autograd.profiler.record_function(name + "_backward"):
                # The 0-weighted real_logits term keeps the discriminator graph in
                # this backward pass without changing the loss value.
                (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
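

# Usage sketch (illustrative only; `opt` and the batch tensors below are
# hypothetical names, not definitions from this module):
#
#   loss = StyleGAN2Loss(device, G_mapping, G_synthesis, D, augment_pipe)
#   for phase in ["Gmain", "Greg", "Dmain", "Dreg"]:
#       opt[phase].zero_grad(set_to_none=True)
#       loss.accumulate_gradients(phase, real_img, real_c, real_h,
#                                 gen_z, gen_c, gen_h, sync=True, gain=1)
#       opt[phase].step()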