---
license: mit
metrics:
  - mse
library_name: diffusers
tags:
  - diffusion
---

Generate some cute (cursed) kitty cats!

```python
import torch
import safetensors.torch  # import the submodule explicitly; a bare `import safetensors` may not expose it
from diffusers import UNet2DModel, DDPMPipeline, DDPMScheduler

# Path to the checkpoint; adjust to wherever the weights are stored locally.
model_path = "ddpm-cats/model_epoch_9.safetensors"
state_dict = safetensors.torch.load_file(model_path)
image_size = 128

# The architecture must match the checkpoint for load_state_dict to succeed.
model = UNet2DModel(
    sample_size=image_size,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)
model.load_state_dict(state_dict)
model.eval()

scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule="linear",
)

pipeline = DDPMPipeline(unet=model, scheduler=scheduler)

# A fixed seed makes the generated image reproducible.
image = pipeline(batch_size=1, generator=torch.manual_seed(0)).images[0]
image.save("generated_image_128.png")  # save() returns None, so keep it on its own line
```
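
If you are curious what the scheduler settings mean, here is a rough pure-Python sketch of a linear beta schedule with the same values (`beta_start=0.0001`, `beta_end=0.02`, 1000 steps). The helper names `linear_betas` and `cumulative_alphas` are made up for illustration and are not part of diffusers; the library computes the equivalent values internally with tensors.

```python
# Illustrative sketch only: hypothetical helpers, not the diffusers API.

def linear_betas(beta_start=0.0001, beta_end=0.02, num_train_timesteps=1000):
    """Evenly spaced noise variances from beta_start to beta_end (inclusive)."""
    step = (beta_end - beta_start) / (num_train_timesteps - 1)
    return [beta_start + i * step for i in range(num_train_timesteps)]


def cumulative_alphas(betas):
    """Running product of (1 - beta_t): the share of signal variance left at step t."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


betas = linear_betas()
alpha_bars = cumulative_alphas(betas)

print(f"first beta: {betas[0]:.4f}, last beta: {betas[-1]:.4f}")
print(f"signal share remaining at the final step: {alpha_bars[-1]:.2e}")
```

The running product shrinks toward zero, which is why sampling can start from pure Gaussian noise: by the last training step almost none of the original image is left.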