codinglabsong's picture
Update models.py
e385be8 verified
raw
history blame
998 Bytes
"""Model definitions for the VAE and UNet components."""
from diffusers import UNet2DConditionModel, AutoencoderKL
from config import cfg, device
# Autoencoder used to move between single-channel inputs and the latent space.
# The latent width is read from the shared config (cfg.latent_channels), the
# same value the UNet in this module uses for its input/output channels.
vae = AutoencoderKL(
    # Data-facing shape: one channel in, one channel out.
    sample_size=32,
    in_channels=1,
    out_channels=1,
    latent_channels=cfg.latent_channels,
    # Three encoder stages mirrored by three decoder stages; each stage has
    # its own entry in block_out_channels.
    down_block_types=(
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
    ),
    up_block_types=(
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
    ),
    block_out_channels=(16, 32, 64),
    # Normalization group count; 4 divides every width in block_out_channels.
    norm_num_groups=4,
)
# Place the freshly built module on the device selected in config.
vae = vae.to(device)
# Conditional UNet that works on latents: its input and output channel counts
# both follow cfg.latent_channels, the same value the VAE in this module uses
# for its latent width.
unet = UNet2DConditionModel(
    # Latent-facing shape.
    # NOTE(review): sample_size=8 presumably matches the spatial size of the
    # VAE latents for 32-pixel inputs -- confirm against the training code.
    sample_size=8,
    in_channels=cfg.latent_channels,
    out_channels=cfg.latent_channels,
    # Two resolution levels, two layers per block.
    down_block_types=(
        "AttnDownBlock2D",
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "UpBlock2D",
    ),
    block_out_channels=(128, 256),
    layers_per_block=2,
    # A single normalization group (divides both 128 and 256).
    norm_num_groups=1,
    # Conditioning: a 10-entry class embedding, concatenated with the time
    # embedding (class_embeddings_concat=True), plus a cross-attention
    # context of width 256.
    num_class_embeds=10,
    class_embeddings_concat=True,
    cross_attention_dim=256,
    time_embedding_act_fn="silu",
)
# Place the freshly built module on the device selected in config.
unet = unet.to(device)