---
license: mit
metrics:
- mse
library_name: diffusers
tags:
- diffusion
---

Generate some cute (cursed) kitty cats!

```python
import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from safetensors.torch import load_file  # import the submodule explicitly; `import safetensors` alone does not guarantee `safetensors.torch` is loaded

# Path to the trained UNet weights; adjust if you saved/downloaded the checkpoint elsewhere.
model_path = "ddpm-cats/model_epoch_9.safetensors"
image_size = 128

# The architecture below must match the one the checkpoint was trained with:
# load_state_dict is strict by default and raises on missing/unexpected keys.
# Note the single attention block in each path, at mirrored positions.
model = UNet2DModel(
    sample_size=image_size,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)
state_dict = load_file(model_path)
model.load_state_dict(state_dict)
model.eval()

# Noise schedule used for sampling; presumably the same schedule the model was trained with.
scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule="linear",
)

pipeline = DDPMPipeline(unet=model, scheduler=scheduler)
# Sampling runs the UNet once per timestep, so use a GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline.to(device)

# A dedicated seeded generator makes the sample reproducible without reseeding torch's global RNG.
generator = torch.Generator().manual_seed(0)
image = pipeline(batch_size=1, generator=generator).images[0]
image.save("generated_image_128.png")
```