PandA / networks /genforce /configs /stylegan_ffhq256_encoder_y.py
james-oldfield's picture
Upload 194 files
2a76164
raw
history blame contribute delete
No virus
2.58 kB
# python3.7
"""Training configuration for the StyleGAN encoder on the FF-HQ (256) dataset.

Per-process values such as `batch_size` and `num_workers` are specified for a
single replica (one GPU), not for the whole job.
"""

# Checkpoint file paths.
gan_model_path = 'checkpoints/stylegan_ffhq256.pth'
perceptual_model_path = 'checkpoints/vgg16.pth'

# Runner / model selection.
runner_type = 'EncoderRunner'
gan_type = 'stylegan'
space_of_latent = 'y'
resolution = 256

# Batch sizes (per replica, see module docstring) and image budget.
batch_size = 12
val_batch_size = 25
total_img = 14_000_000  # presumably the total number of training images
# The training dataset is repeated up front (see `repeat`) so that it does not
# have to be loaded again at the end of each epoch; this saves some I/O time.
data = {
    'num_workers': 4,
    'repeat': 500,
    # Earlier variant without an image list, kept for reference:
    # train=dict(root_dir='data/ffhq', resolution=resolution, mirror=0.5),
    # val=dict(root_dir='data/ffhq', resolution=resolution),
    'train': {
        'root_dir': 'data/',
        'data_format': 'list',
        'image_list_path': 'data/ffhq/ffhq_train_list.txt',
        'resolution': resolution,
        'mirror': 0.5,
    },
    'val': {
        'root_dir': 'data/',
        'data_format': 'list',
        # NOTE(review): unlike the train list, this path carries a leading
        # './' -- confirm both resolve the same way in the dataset loader.
        'image_list_path': './data/ffhq/ffhq_val_list.txt',
        'resolution': resolution,
    },
}
# Per-controller settings, keyed by controller name. Intervals are expressed
# in iterations (`every_n_iters`).
controllers = {
    'RunningLogger': {'every_n_iters': 50},
    'Snapshoter': {'every_n_iters': 10000, 'first_iter': True, 'num': 200},
    'Checkpointer': {'every_n_iters': 10000, 'first_iter': False},
}
# Settings for each network: model construction arguments, learning-rate
# schedule, optimizer, and extra keyword arguments for train / val calls.
# The discriminator and encoder use the same schedule and optimizer values.
modules = {
    'discriminator': {
        'model': {'gan_type': gan_type, 'resolution': resolution},
        'lr': {
            'lr_type': 'ExpSTEP',
            'decay_factor': 0.8,
            'decay_step': 36458 // 2,
        },
        'opt': {'opt_type': 'Adam', 'base_lr': 1e-4, 'betas': (0.9, 0.99)},
        'kwargs_train': {},
        'kwargs_val': {},
    },
    # The generator entry has no `lr` / `opt` section.
    'generator': {
        'model': {
            'gan_type': gan_type,
            'resolution': resolution,
            'repeat_w': True,
        },
        'kwargs_val': {'randomize_noise': False},
    },
    'encoder': {
        'model': {
            'gan_type': gan_type,
            'resolution': resolution,
            'network_depth': 18,
            # 14 entries in total: eight of 1024, then two each of
            # 512, 256 and 128.
            'latent_dim': [1024] * 8 + [512] * 2 + [256] * 2 + [128] * 2,
            # Sums to 14 as well -- presumably has to agree with the number
            # of `latent_dim` entries; confirm against the encoder model.
            'num_latents_per_head': [4, 4, 6],
            'use_fpn': True,
            'fpn_channels': 512,
            'use_sam': True,
            'sam_channels': 512,
        },
        'lr': {
            'lr_type': 'ExpSTEP',
            'decay_factor': 0.8,
            'decay_step': 36458 // 2,
        },
        'opt': {'opt_type': 'Adam', 'base_lr': 1e-4, 'betas': (0.9, 0.99)},
        'kwargs_train': {},
        'kwargs_val': {},
    },
}
# Loss settings: loss type plus keyword arguments for the discriminator
# loss, the encoder loss (`*_lw` entries are presumably loss weights), and
# the perceptual model.
loss = {
    'type': 'EncoderLoss',
    'd_loss_kwargs': {'r1_gamma': 10.0},
    'e_loss_kwargs': {'adv_lw': 0.08, 'perceptual_lw': 5e-5},
    'perceptual_kwargs': {
        'output_layer_idx': 23,
        'pretrained_weight_path': perceptual_model_path,
    },
}