nroggendorff commited on
Commit
402b78c
1 Parent(s): 8dc4fcc

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +35 -1
README.md CHANGED
@@ -6,4 +6,38 @@ library_name: diffusers
6
  tags:
7
  - diffusion
8
  ---
9
- Generate some cute kitty cats!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  tags:
7
  - diffusion
8
  ---
9
+ Generate some cute (cursed) kitty cats!
10
+
11
+ import torch
12
+ from diffusers import UNet2DModel, DDPMPipeline, DDPMScheduler
13
+ import safetensors
14
+
15
+ model_path = "ddpm-cats/model_epoch_9.safetensors"
16
+ state_dict = safetensors.torch.load_file(model_path)
17
+ image_size = 128
18
+
19
+ model = UNet2DModel(
20
+ sample_size=image_size,
21
+ in_channels=3,
22
+ out_channels=3,
23
+ layers_per_block=2,
24
+ block_out_channels=(128, 128, 256, 256, 512, 512),
25
+ down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
26
+ up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"),
27
+ )
28
+ model.load_state_dict(state_dict)
29
+ model.eval()
30
+
31
+ scheduler = DDPMScheduler(
32
+ num_train_timesteps=1000,
33
+ beta_start=0.0001,
34
+ beta_end=0.02,
35
+ beta_schedule="linear",
36
+ )
37
+
38
+ pipeline = DDPMPipeline(
39
+ unet=model,
40
+ scheduler=scheduler,
41
+ )
42
+
43
+ image = pipeline(batch_size=1, generator=torch.manual_seed(0)).images[0].save("generated_image_128.png")