cats / README.md
nroggendorff's picture
Update README.md
402b78c verified
|
raw
history blame
No virus
1.08 kB
---
license: mit
metrics:
- mse
library_name: diffusers
tags:
- diffusion
---
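Generate some cute (cursed) kitty cats! The snippet below loads the epoch-9 checkpoint into a 128×128 `UNet2DModel` and samples an image with `DDPMPipeline`.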
# Sample one 128x128 image from the trained DDPM checkpoint and save it as a PNG.
import torch
from diffusers import UNet2DModel, DDPMPipeline, DDPMScheduler
# `import safetensors` alone does not load the `safetensors.torch` submodule, so
# `safetensors.torch.load_file` could fail with AttributeError unless some other
# package happened to import it first. Import the submodule explicitly (this
# still binds the name `safetensors`).
import safetensors.torch

# Weights file, resolved relative to the current working directory.
model_path = "ddpm-cats/model_epoch_9.safetensors"
state_dict = safetensors.torch.load_file(model_path)

image_size = 128
# This architecture must match the one the checkpoint was trained with:
# `load_state_dict` is strict by default and raises on missing/unexpected keys
# or shape mismatches.
model = UNet2DModel(
    sample_size=image_size,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"),
)
model.load_state_dict(state_dict)
model.eval()

# Linear beta schedule over 1000 timesteps.
# NOTE(review): presumably the same schedule used during training — confirm.
scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule="linear",
)

pipeline = DDPMPipeline(
    unet=model,
    scheduler=scheduler,
)
# DDPMPipeline runs one UNet forward pass per denoising step (1000 by default),
# which is far faster on a GPU; fall back to CPU when none is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = pipeline.to(device)

# Seed the (CPU) generator so the sampled image is reproducible.
image = pipeline(batch_size=1, generator=torch.manual_seed(0)).images[0]
# `save` returns None, so keep the image object and save it as a separate step.
image.save("generated_image_128.png")