radames commited on
Commit
cac33a7
1 Parent(s): 2eb807c
Files changed (3) hide show
  1. README.md +2 -2
  2. app.py +65 -40
  3. requirements.txt +6 -5
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Unofficial SDXL Turbo Real Time Text To Image
3
  emoji: 🏆
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.7.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
+ title: Real-Time Text-to-Image SDXL Lightning
3
  emoji: 🏆
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -1,6 +1,7 @@
1
- from diffusers import DiffusionPipeline
2
  import torch
3
  import os
 
4
 
5
  try:
6
  import intel_extension_for_pytorch as ipex
@@ -8,14 +9,26 @@ except:
8
  pass
9
 
10
  from PIL import Image
11
- import numpy as np
12
  import gradio as gr
13
- import psutil
14
  import time
 
15
 
16
- SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
17
- TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
18
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # check if MPS is available OSX only M1/M2/M3 chips
20
  mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
21
  xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
@@ -25,7 +38,6 @@ device = torch.device(
25
  torch_device = device
26
  torch_dtype = torch.float16
27
 
28
- print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
29
  print(f"TORCH_COMPILE: {TORCH_COMPILE}")
30
  print(f"device: {device}")
31
 
@@ -34,32 +46,30 @@ if mps_available:
34
  torch_device = "cpu"
35
  torch_dtype = torch.float32
36
 
37
- if SAFETY_CHECKER == "True":
38
- pipe = DiffusionPipeline.from_pretrained(
39
- "stabilityai/sdxl-turbo",
40
- torch_dtype=torch_dtype,
41
- variant="fp16" if torch_dtype == torch.float16 else "fp32")
42
- else:
43
- pipe = DiffusionPipeline.from_pretrained(
44
- "stabilityai/sdxl-turbo",
45
- safety_checker=None,
46
- torch_dtype=torch_dtype,
47
- variant="fp16" if torch_dtype == torch.float16 else "fp32",
48
- )
49
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  pipe.to(device=torch_device, dtype=torch_dtype).to(device)
52
- pipe.unet.to(memory_format=torch.channels_last)
53
  pipe.set_progress_bar_config(disable=True)
54
 
55
 
56
- def predict(prompt, steps, seed=1231231):
57
  generator = torch.manual_seed(seed)
58
  last_time = time.time()
59
  results = pipe(
60
  prompt=prompt,
61
  generator=generator,
62
- num_inference_steps=steps,
63
  guidance_scale=0.0,
64
  width=512,
65
  height=512,
@@ -108,37 +118,52 @@ with gr.Blocks(css=css) as demo:
108
 
109
  image = gr.Image(type="filepath")
110
  with gr.Accordion("Advanced options", open=False):
111
- steps = gr.Slider(label="Steps", value=2, minimum=1, maximum=10, step=1)
112
  seed = gr.Slider(
113
  randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
114
  )
115
  with gr.Accordion("Run with diffusers"):
116
  gr.Markdown(
117
  """## Running SDXL Turbo with `diffusers`
118
- ```bash
119
- pip install diffusers==0.23.1
120
- ```
121
  ```py
122
- from diffusers import DiffusionPipeline
123
-
124
- pipe = DiffusionPipeline.from_pretrained(
125
- "stabilityai/sdxl-turbo", variant="fp16", torch_dtype=torch.float16
126
- ).to("cuda")
127
- results = pipe(
128
- prompt="A cinematic shot of a baby racoon wearing an intricate italian priest robe",
129
- num_inference_steps=1,
130
- guidance_scale=0.0,
131
- )
132
- imga = results.images[0]
133
- imga.save("image.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  ```
135
  """
136
  )
137
 
138
- inputs = [prompt, steps, seed]
139
  generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
140
  prompt.input(fn=predict, inputs=inputs, outputs=image, show_progress=False)
141
- steps.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
142
  seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
143
 
144
  demo.queue()
 
1
+ from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
2
  import torch
3
  import os
4
+ from huggingface_hub import hf_hub_download
5
 
6
  try:
7
  import intel_extension_for_pytorch as ipex
 
9
  pass
10
 
11
  from PIL import Image
 
12
  import gradio as gr
 
13
  import time
14
+ from safetensors.torch import load_file
15
 
16
+
17
+ # Constants
18
+ BASE = "stabilityai/stable-diffusion-xl-base-1.0"
19
+ REPO = "ByteDance/SDXL-Lightning"
20
+ # 1-step
21
+ CHECKPOINT = "sdxl_lightning_1step_unet_x0.safetensors"
22
+
23
+ # {
24
+ # "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
25
+ # "2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
26
+ # "4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
27
+ # "8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
28
+ # }
29
+
30
+
31
+ TORCH_COMPILE = os.environ.get("TORCH_COMPILE", "0") == "1"
32
  # check if MPS is available OSX only M1/M2/M3 chips
33
  mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
34
  xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
 
38
  torch_device = device
39
  torch_dtype = torch.float16
40
 
 
41
  print(f"TORCH_COMPILE: {TORCH_COMPILE}")
42
  print(f"device: {device}")
43
 
 
46
  torch_device = "cpu"
47
  torch_dtype = torch.float32
48
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+ pipe = StableDiffusionXLPipeline.from_pretrained(
51
+ BASE, torch_dtype=torch.float16, variant="fp16"
52
+ )
53
+
54
+ pipe.scheduler = EulerDiscreteScheduler.from_config(
55
+ pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample"
56
+ )
57
+
58
+ pipe.unet.load_state_dict(
59
+ torch.load(load_file(hf_hub_download(REPO, CHECKPOINT)), map_location="cuda")
60
+ )
61
 
62
  pipe.to(device=torch_device, dtype=torch_dtype).to(device)
 
63
  pipe.set_progress_bar_config(disable=True)
64
 
65
 
66
+ def predict(prompt, seed=1231231):
67
  generator = torch.manual_seed(seed)
68
  last_time = time.time()
69
  results = pipe(
70
  prompt=prompt,
71
  generator=generator,
72
+ num_inference_steps=1,
73
  guidance_scale=0.0,
74
  width=512,
75
  height=512,
 
118
 
119
  image = gr.Image(type="filepath")
120
  with gr.Accordion("Advanced options", open=False):
 
121
  seed = gr.Slider(
122
  randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
123
  )
124
  with gr.Accordion("Run with diffusers"):
125
  gr.Markdown(
126
  """## Running SDXL Turbo with `diffusers`
 
 
 
127
  ```py
128
+ import torch
129
+ from diffusers import (
130
+ StableDiffusionXLPipeline,
131
+ UNet2DConditionModel,
132
+ EulerDiscreteScheduler,
133
+ )
134
+ from huggingface_hub import hf_hub_download
135
+ from safetensors.torch import load_file
136
+
137
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
138
+ repo = "ByteDance/SDXL-Lightning"
139
+ ckpt = "sdxl_lightning_1step_unet_x0.safetensors" # Use the correct ckpt for your step setting!
140
+
141
+ # Load model.
142
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(
143
+ "cuda", torch.float16
144
+ )
145
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
146
+ pipe = StableDiffusionXLPipeline.from_pretrained(
147
+ base, unet=unet, torch_dtype=torch.float16, variant="fp16"
148
+ ).to("cuda")
149
+
150
+ # Ensure sampler uses "trailing" timesteps and "sample" prediction type.
151
+ pipe.scheduler = EulerDiscreteScheduler.from_config(
152
+ pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample"
153
+ )
154
+
155
+ # Ensure using the same inference steps as the loaded model and CFG set to 0.
156
+ pipe("A girl smiling", num_inference_steps=1, guidance_scale=0).images[0].save(
157
+ "output.png"
158
+ )
159
+
160
  ```
161
  """
162
  )
163
 
164
+ inputs = [prompt, seed]
165
  generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
166
  prompt.input(fn=predict, inputs=inputs, outputs=image, show_progress=False)
 
167
  seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
168
 
169
  demo.queue()
requirements.txt CHANGED
@@ -1,8 +1,7 @@
1
- diffusers==0.23.1
2
  transformers
3
- gradio==4.7.1
4
- --extra-index-url https://download.pytorch.org/whl/cu121
5
- torch==2.1.0
6
  fastapi==0.104.0
7
  uvicorn==0.23.2
8
  Pillow==10.1.0
@@ -11,4 +10,6 @@ compel==2.0.2
11
  controlnet-aux==0.0.7
12
  peft==0.6.0
13
  xformers
14
- hf_transfer
 
 
 
1
+ diffusers==0.26.3
2
  transformers
3
+ gradio==4.19.1
4
+ torch==2.2.0
 
5
  fastapi==0.104.0
6
  uvicorn==0.23.2
7
  Pillow==10.1.0
 
10
  controlnet-aux==0.0.7
11
  peft==0.6.0
12
  xformers
13
+ hf_transfer
14
+ huggingface_hub
15
+ safetensors