---
license: mit
metrics:
- mse
library_name: diffusers
tags:
- diffusion
---
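Generate some cute (cursed) kitty cats! This is an unconditional DDPM (denoising diffusion probabilistic model) that produces 128×128, 3-channel cat images. The snippet below loads the epoch-9 checkpoint and samples a single image: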
import torch
from diffusers import UNet2DModel, DDPMPipeline, DDPMScheduler
# `import safetensors` alone does not load the `safetensors.torch` submodule;
# import it explicitly so `safetensors.torch.load_file` is always available.
import safetensors.torch

# Trained UNet weights stored in safetensors format (checkpoint `model_epoch_9`).
model_path = "ddpm-cats/model_epoch_9.safetensors"
state_dict = safetensors.torch.load_file(model_path)

# Spatial resolution (height and width) of generated samples.
image_size = 128

# The architecture below must match the one used for training exactly:
# `load_state_dict` is strict by default and raises on missing/unexpected keys.
model = UNet2DModel(
    sample_size=image_size,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # self-attention at the second-lowest resolution
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",  # mirrors the attention block on the down path
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)
model.load_state_dict(state_dict)
model.eval()  # inference mode (disables dropout etc.)

# Linear beta noise schedule over 1000 diffusion timesteps.
scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule="linear",
)

pipeline = DDPMPipeline(
    unet=model,
    scheduler=scheduler,
)

# A fixed seed makes sampling reproducible. `.images[0]` is a PIL image
# (the pipeline's default output type). Keep a reference to the image:
# PIL's `save` returns None, so chaining it into the assignment would bind
# `image` to None.
image = pipeline(batch_size=1, generator=torch.manual_seed(0)).images[0]
image.save("generated_image_128.png")