PeterL1n commited on
Commit
8a94aee
1 Parent(s): 1b39ebb

Update readme

Browse files
Files changed (1) hide show
  1. README.md +7 -5
README.md CHANGED
@@ -27,14 +27,15 @@ Please always use the correct checkpoint for the corresponding inference steps.
27
  import torch
28
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
29
  from huggingface_hub import hf_hub_download
 
30
 
31
  base = "stabilityai/stable-diffusion-xl-base-1.0"
32
  repo = "ByteDance/SDXL-Lightning"
33
- ckpt = "sdxl_lightning_4step_unet.pth" # Use the correct ckpt for your step setting!
34
 
35
  # Load model.
36
  unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
37
- unet.load_state_dict(torch.load(hf_hub_download(repo, ckpt), map_location="cuda"))
38
  pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
39
 
40
  # Ensure sampler uses "trailing" timesteps.
@@ -53,7 +54,7 @@ from huggingface_hub import hf_hub_download
53
 
54
  base = "stabilityai/stable-diffusion-xl-base-1.0"
55
  repo = "ByteDance/SDXL-Lightning"
56
- ckpt = "sdxl_lightning_4step_lora.pth" # Use the correct ckpt for your step setting!
57
 
58
  # Load model.
59
  pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
@@ -75,14 +76,15 @@ The 1-step model uses "sample" prediction instead of "epsilon" prediction! The s
75
  import torch
76
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
77
  from huggingface_hub import hf_hub_download
 
78
 
79
  base = "stabilityai/stable-diffusion-xl-base-1.0"
80
  repo = "ByteDance/SDXL-Lightning"
81
- ckpt = "sdxl_lightning_1step_unet_x0.pth" # Use the correct ckpt for your step setting!
82
 
83
  # Load model.
84
  unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
85
- unet.load_state_dict(torch.load(hf_hub_download(repo, ckpt), map_location="cuda"))
86
  pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
87
 
88
  # Ensure sampler uses "trailing" timesteps and "sample" prediction type.
 
27
  import torch
28
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
29
  from huggingface_hub import hf_hub_download
30
+ from safetensors.torch import load_file
31
 
32
  base = "stabilityai/stable-diffusion-xl-base-1.0"
33
  repo = "ByteDance/SDXL-Lightning"
34
+ ckpt = "sdxl_lightning_4step_unet.safetensors" # Use the correct ckpt for your step setting!
35
 
36
  # Load model.
37
  unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
38
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
39
  pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
40
 
41
  # Ensure sampler uses "trailing" timesteps.
 
54
 
55
  base = "stabilityai/stable-diffusion-xl-base-1.0"
56
  repo = "ByteDance/SDXL-Lightning"
57
+ ckpt = "sdxl_lightning_4step_lora.safetensors" # Use the correct ckpt for your step setting!
58
 
59
  # Load model.
60
  pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
 
76
  import torch
77
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
78
  from huggingface_hub import hf_hub_download
79
+ from safetensors.torch import load_file
80
 
81
  base = "stabilityai/stable-diffusion-xl-base-1.0"
82
  repo = "ByteDance/SDXL-Lightning"
83
+ ckpt = "sdxl_lightning_1step_unet_x0.safetensors" # Use the correct ckpt for your step setting!
84
 
85
  # Load model.
86
  unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
87
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
88
  pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
89
 
90
  # Ensure sampler uses "trailing" timesteps and "sample" prediction type.