---
license: mit
metrics:
- mse
library_name: diffusers
tags:
- diffusion
---
|
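Generate some cute (if occasionally cursed) kitty cats! This is an unconditional DDPM that produces 128×128 RGB cat images. To sample one image with Diffusers: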
|
|
|
import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel

# `import safetensors` alone does not load the `safetensors.torch` submodule,
# so import the PyTorch loader explicitly instead of relying on some other
# package having imported it as a side effect.
from safetensors.torch import load_file

MODEL_PATH = "ddpm-cats/model_epoch_9.safetensors"
IMAGE_SIZE = 128
OUTPUT_PATH = f"generated_image_{IMAGE_SIZE}.png"

# Rebuild the UNet. These arguments must describe the same architecture the
# checkpoint was trained with: load_state_dict() is strict by default and
# raises on missing/unexpected keys or mismatched tensor shapes.
model = UNet2DModel(
    sample_size=IMAGE_SIZE,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)
model.load_state_dict(load_file(MODEL_PATH))
model.eval()

# Linear beta schedule over 1000 timesteps.
# NOTE(review): presumably identical to the training schedule — confirm.
scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule="linear",
)

# Full DDPM sampling runs one UNet pass per timestep, so use a GPU when one
# is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = DDPMPipeline(unet=model, scheduler=scheduler).to(device)

# A dedicated seeded generator makes the sample reproducible without
# mutating the global RNG state.
generator = torch.Generator().manual_seed(0)

# .images[0] is the generated image; save() returns None, so keep the image
# in its own variable rather than binding the result of save().
image = pipeline(batch_size=1, generator=generator).images[0]
image.save(OUTPUT_PATH)