| """Model definitions for the VAE and UNet components.""" | |
| from diffusers import UNet2DConditionModel, AutoencoderKL | |
| from config import cfg, device | |
| vae = AutoencoderKL( | |
| in_channels=1, | |
| out_channels=1, | |
| latent_channels=cfg.latent_channels, | |
| sample_size=32, | |
| block_out_channels=(16, 32, 64), | |
| norm_num_groups=4, | |
| down_block_types=( | |
| "DownEncoderBlock2D", | |
| "DownEncoderBlock2D", | |
| "DownEncoderBlock2D", | |
| ), | |
| up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"), | |
| ).to(device) | |
| unet = UNet2DConditionModel( | |
| in_channels=cfg.latent_channels, | |
| out_channels=cfg.latent_channels, | |
| sample_size=8, | |
| layers_per_block=2, | |
| down_block_types=("AttnDownBlock2D", "AttnDownBlock2D"), | |
| up_block_types=("AttnUpBlock2D", "UpBlock2D"), | |
| block_out_channels=(128, 256), | |
| norm_num_groups=1, | |
| num_class_embeds=10, | |
| time_embedding_act_fn="silu", | |
| cross_attention_dim=256, | |
| class_embeddings_concat=True, | |
| ).to(device) | |
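

# --- Optional shape sanity check -------------------------------------------
# A minimal sketch, run only when this module is executed directly. It assumes
# `device` from config points at an available device and that a dummy 256-dim
# tensor is an acceptable stand-in for real cross-attention conditioning.
if __name__ == "__main__":
    import torch

    with torch.no_grad():
        x = torch.randn(2, 1, 32, 32, device=device)       # dummy grayscale batch
        z = vae.encode(x).latent_dist.sample()              # (2, cfg.latent_channels, 8, 8)
        recon = vae.decode(z).sample                         # (2, 1, 32, 32)

        t = torch.randint(0, 1000, (2,), device=device)      # arbitrary timesteps
        labels = torch.randint(0, 10, (2,), device=device)    # class ids in [0, 10)
        context = torch.randn(2, 1, 256, device=device)       # dummy cross-attention context
        noise_pred = unet(
            z, t, encoder_hidden_states=context, class_labels=labels
        ).sample                                               # same shape as z

    print(
        "image:", tuple(x.shape),
        "latent:", tuple(z.shape),
        "recon:", tuple(recon.shape),
        "unet out:", tuple(noise_pred.shape),
    )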