|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler |
|
from src import (ContentEncoder, |
|
StyleEncoder, |
|
UNet) |
|
|
|
|
|
def build_unet(args):
    """Create the project ``UNet`` configured from parsed run arguments.

    Args:
        args: Object exposing the attributes ``resolution``,
            ``unet_channels``, ``style_start_channel``, ``channel_attn``,
            ``content_encoder_downsample_size`` and
            ``content_start_channel``.

    Returns:
        The newly constructed ``UNet`` instance.
    """
    # The block layout is fixed by this builder; only channel widths and
    # the attention/content options come from ``args``.
    down_blocks = (
        'DownBlock2D',
        'MCADownBlock2D',
        'MCADownBlock2D',
        'DownBlock2D',
    )
    up_blocks = (
        'UpBlock2D',
        'StyleRSIUpBlock2D',
        'StyleRSIUpBlock2D',
        'UpBlock2D',
    )

    unet_kwargs = {
        'sample_size': args.resolution,
        'in_channels': 3,
        'out_channels': 3,
        'flip_sin_to_cos': True,
        'freq_shift': 0,
        'down_block_types': down_blocks,
        'up_block_types': up_blocks,
        'block_out_channels': args.unet_channels,
        'layers_per_block': 2,
        'downsample_padding': 1,
        'mid_block_scale_factor': 1,
        'act_fn': 'silu',
        'norm_num_groups': 32,
        'norm_eps': 1e-05,
        # Cross-attention width is tied to the style encoder's start channel
        # (16x). NOTE(review): presumably this matches the style encoder's
        # output feature width -- confirm against StyleEncoder.
        'cross_attention_dim': args.style_start_channel * 16,
        'attention_head_dim': 1,
        'channel_attn': args.channel_attn,
        'content_encoder_downsample_size': args.content_encoder_downsample_size,
        'content_start_channel': args.content_start_channel,
        'reduction': 32,
    }
    return UNet(**unet_kwargs)
|
|
|
|
|
def build_style_encoder(args):
    """Create the ``StyleEncoder`` and report it on stdout.

    Args:
        args: Object exposing ``style_start_channel`` and an indexable
            ``style_image_size``; only element ``0`` of the latter is used.

    Returns:
        The newly constructed ``StyleEncoder`` instance.
    """
    start_channel = args.style_start_channel
    image_side = args.style_image_size[0]
    encoder = StyleEncoder(G_ch=start_channel, resolution=image_side)
    # The message is emitted only after construction has succeeded.
    print("Get CG-GAN Style Encoder!")
    return encoder
|
|
|
|
|
def build_content_encoder(args):
    """Create the ``ContentEncoder`` and report it on stdout.

    Args:
        args: Object exposing ``content_start_channel`` and an indexable
            ``content_image_size``; only element ``0`` of the latter is used.

    Returns:
        The newly constructed ``ContentEncoder`` instance.
    """
    start_channel = args.content_start_channel
    image_side = args.content_image_size[0]
    encoder = ContentEncoder(G_ch=start_channel, resolution=image_side)
    # The message is emitted only after construction has succeeded.
    print("Get CG-GAN Content Encoder!")
    return encoder
|
|
|
|
|
def build_ddpm_scheduler(args):
    """Create the ``DDPMScheduler`` used by this project.

    Every setting is fixed in this builder except the beta schedule name,
    which is read from ``args.beta_scheduler``.

    Args:
        args: Object exposing the attribute ``beta_scheduler``.

    Returns:
        The newly constructed ``DDPMScheduler`` instance.
    """
    scheduler_config = {
        "num_train_timesteps": 1000,
        "beta_start": 0.0001,
        "beta_end": 0.02,
        "beta_schedule": args.beta_scheduler,
        "trained_betas": None,
        "variance_type": "fixed_small",
        "clip_sample": True,
    }
    return DDPMScheduler(**scheduler_config)