File size: 1,077 Bytes
c8cc76d
 
1678f21
 
 
 
 
c8cc76d
402b78c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
---
license: mit
metrics:
- mse
library_name: diffusers
tags:
- diffusion
---
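Generate some cute (cursed) kitty cats! The script below loads the trained UNet weights from a safetensors checkpoint and samples one image with a DDPM pipeline: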

# Sample one image from the trained DDPM cat model and write it to disk.

# `import safetensors` alone does not load the `torch` submodule; import it
# explicitly instead of relying on another package having imported it first.
import safetensors.torch
import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel

# Checkpoint with the trained UNet weights, the sample size the UNet is built
# with, and the file the generated image is written to.
model_path = "ddpm-cats/model_epoch_9.safetensors"
image_size = 128
output_path = "generated_image_128.png"

state_dict = safetensors.torch.load_file(model_path)

# The architecture below has to match the checkpoint, since the weights are
# loaded into it directly via load_state_dict().
model = UNet2DModel(
    sample_size=image_size,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)
model.load_state_dict(state_dict)
model.eval()  # inference only

scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule="linear",
)

pipeline = DDPMPipeline(
    unet=model,
    scheduler=scheduler,
)

# The generator is seeded with a fixed value (0) for the sampling call.
# Keep the image and the save as separate statements: `save()` returns None,
# so chaining it onto the assignment would leave `image` bound to None.
image = pipeline(batch_size=1, generator=torch.manual_seed(0)).images[0]
image.save(output_path)