Spaces:
Runtime error
Runtime error
import torch
from tops.config import LazyCall as L

from dp2.discriminator import SG2Discriminator
from dp2.loss import StyleGAN2Loss
# Lazy (LazyCall-wrapped) specification of the SG2Discriminator.
# The "${...}" strings are references to other config entries (data.*);
# presumably they are resolved by the config system when the object is
# instantiated -- confirm against the config loader.
discriminator = L(SG2Discriminator)(
    imsize="${data.imsize}",
    im_channels="${data.im_channels}",
    min_fmap_resolution=4,
    max_cnum_mul=8,
    cnum=80,
    input_condition=True,
    conv_clamp=256,
    input_cse=False,
    cse_nc="${data.cse_nc}",
    fix_residual=False,
)
# Lazy specification of the StyleGAN2 loss.
# `lazy_regularization` and `lazy_reg_interval` are also read by D_optim
# (via "${loss_fnc.*}" references) so the optimizer is built with the same
# settings as the loss.
loss_fnc = L(StyleGAN2Loss)(
    lazy_regularization=True,
    lazy_reg_interval=16,
    r1_opts=dict(lambd=5, mask_out=False, mask_out_scale=False),
    EP_lambd=0.001,
    # weight=0 here; NOTE(review): presumably this disables the path-length
    # term -- confirm in StyleGAN2Loss.
    pl_reg_opts=dict(weight=0, batch_shrink=2, start_nimg=int(1e6), pl_decay=0.01),
)
def build_D_optim(type, lr, betas, lazy_regularization, lazy_reg_interval, **kwargs):
    """Build the discriminator optimizer, rescaling hyperparameters for lazy regularization.

    Args:
        type: Callable (e.g. an optimizer class) invoked as
            ``type(lr=..., betas=..., **kwargs)``. The name shadows the
            builtin but is kept because callers pass it by keyword.
        lr: Base learning rate.
        betas: Iterable of beta coefficients.
        lazy_regularization: When truthy, ``lr`` and ``betas`` are rescaled
            before being handed to ``type``; otherwise they pass through
            unchanged.
        lazy_reg_interval: Interval ``k`` used for the correction factor
            ``c = k / (k + 1)``.
        **kwargs: Forwarded unchanged to ``type``.

    Returns:
        The object returned by ``type(lr=lr, betas=betas, **kwargs)``.
    """
    if lazy_regularization:
        # From Analyzing and improving the image quality of stylegan, CVPR 2020
        c = lazy_reg_interval / (lazy_reg_interval + 1)
        # Rescaled betas are collected into a list (a tuple input becomes a list).
        betas = [beta ** c for beta in betas]
        lr *= c
        print(f"Lazy regularization on. Setting lr to: {lr}, betas to: {betas}")
    return type(lr=lr, betas=betas, **kwargs)
# Lazy specification of the discriminator optimizer. It goes through
# build_D_optim so that lr/betas are rescaled when lazy regularization is on;
# the two lazy_* values reference the loss config so both stay in sync.
D_optim = L(build_D_optim)(
    type=torch.optim.Adam,
    lr=0.001,
    betas=(0.0, 0.99),
    lazy_regularization="${loss_fnc.lazy_regularization}",
    lazy_reg_interval="${loss_fnc.lazy_reg_interval}",
)
# Lazy specification of the generator optimizer: plain Adam, with no
# lazy-regularization rescaling applied here.
G_optim = L(torch.optim.Adam)(lr=0.001, betas=(0.0, 0.99))