radames committed on
Commit
b14c223
1 Parent(s): 248bc06

all together

Browse files
Files changed (1) hide show
  1. app.py +109 -47
app.py CHANGED
@@ -1,33 +1,39 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import StableDiffusionXLPipeline
4
  from diffusers.schedulers import TCDScheduler
 
5
  import spaces
6
  from PIL import Image
7
 
8
  SAFETY_CHECKER = True
9
 
10
- # Constants
11
- base = "stabilityai/stable-diffusion-xl-base-1.0"
12
- repo = "ByteDance/SDXL-Lightning"
13
  checkpoints = {
14
- "2-Step": ["pcm_sdxl_smallcfg_2step_converted.safetensors", 2, 0.0],
15
- "4-Step": ["pcm_sdxl_smallcfg_4step_converted.safetensors", 4, 0.0],
16
- "8-Step": ["pcm_sdxl_smallcfg_8step_converted.safetensors", 8, 0.0],
17
- "16-Step": ["pcm_sdxl_smallcfg_16step_converted.safetensors", 16, 0.0],
18
- "Normal CFG 4-Step": ["pcm_sdxl_normalcfg_4step_converted.safetensors", 4, 7.5],
19
- "Normal CFG 8-Step": ["pcm_sdxl_normalcfg_8step_converted.safetensors", 8, 7.5],
20
- "Normal CFG 16-Step": ["pcm_sdxl_normalcfg_16step_converted.safetensors", 16, 7.5],
21
- "LCM-Like LoRA": ["pcm_sdxl_lcmlike_lora_converted.safetensors", 16, 0.0],
 
 
 
 
22
  }
23
 
24
 
25
  loaded = None
26
 
27
- # Ensure model and scheduler are initialized in GPU-enabled function
28
  if torch.cuda.is_available():
29
- pipe = StableDiffusionXLPipeline.from_pretrained(
30
- base, torch_dtype=torch.float16, variant="fp16"
 
 
 
 
 
31
  ).to("cuda")
32
 
33
  if SAFETY_CHECKER:
@@ -52,29 +58,35 @@ if SAFETY_CHECKER:
52
  return images, has_nsfw_concepts
53
 
54
 
55
- # Function
56
  @spaces.GPU(enable_queue=True)
57
- def generate_image(prompt, ckpt):
 
 
 
 
 
 
58
  global loaded
59
- print(prompt, ckpt)
60
-
61
- checkpoint = checkpoints[ckpt][0]
62
- num_inference_steps = checkpoints[ckpt][1]
63
  guidance_scale = checkpoints[ckpt][2]
 
64
 
65
- if loaded != num_inference_steps:
66
- pipe.scheduler = TCDScheduler(
67
- num_train_timesteps=1000,
68
- beta_start=0.00085,
69
- beta_end=0.012,
70
- beta_schedule="scaled_linear",
71
- timestep_spacing="trailing",
72
- )
73
  pipe.load_lora_weights(
74
- "wangfuyun/PCM_Weights", weight_name=checkpoint, subfolder="sdxl"
75
  )
76
-
77
- loaded = num_inference_steps
 
 
 
 
 
 
 
 
 
 
78
 
79
  results = pipe(
80
  prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale
@@ -89,7 +101,12 @@ def generate_image(prompt, ckpt):
89
  return results.images[0]
90
 
91
 
92
- # Gradio Interface
 
 
 
 
 
93
 
94
  css = """
95
  .gradio-container {
@@ -97,31 +114,76 @@ css = """
97
  }
98
  """
99
  with gr.Blocks(css=css) as demo:
100
- gr.HTML("<h1><center>SDXL-Lightning ⚡</center></h1>")
101
- gr.HTML(
102
- "<p><center>Lightning-fast text-to-image generation</center></p><p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>"
 
 
103
  )
104
  with gr.Group():
105
  with gr.Row():
106
- prompt = gr.Textbox(label="Enter your prompt (English)", scale=8)
107
  ckpt = gr.Dropdown(
108
  label="Select inference steps",
109
  choices=list(checkpoints.keys()),
110
  value="4-Step",
111
- interactive=True,
112
  )
113
- submit = gr.Button(scale=1, variant="primary")
114
- img = gr.Image(label="SDXL-Lightning Generated Image")
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
- prompt.submit(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  fn=generate_image,
118
- inputs=[prompt, ckpt],
119
- outputs=img,
120
  )
121
- submit.click(
 
122
  fn=generate_image,
123
- inputs=[prompt, ckpt],
124
- outputs=img,
 
 
 
 
 
 
 
125
  )
126
 
127
- demo.queue().launch()
 
 
1
  import gradio as gr
2
  import torch
3
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, LCMScheduler
4
  from diffusers.schedulers import TCDScheduler
5
+
6
  import spaces
7
  from PIL import Image
8
 
9
  SAFETY_CHECKER = True
10
 
 
 
 
11
  checkpoints = {
12
+ "2-Step": ["pcm_{}_smallcfg_2step_converted.safetensors", 2, 0.0],
13
+ "4-Step": ["pcm_{}_smallcfg_4step_converted.safetensors", 4, 0.0],
14
+ "8-Step": ["pcm_{}_smallcfg_8step_converted.safetensors", 8, 0.0],
15
+ "16-Step": ["pcm_{}_smallcfg_16step_converted.safetensors", 16, 0.0],
16
+ "Normal CFG 4-Step": ["pcm_{}_normalcfg_4step_converted.safetensors", 4, 7.5],
17
+ "Normal CFG 8-Step": ["pcm_{}_normalcfg_8step_converted.safetensors", 8, 7.5],
18
+ "Normal CFG 16-Step": ["pcm_{}_normalcfg_16step_converted.safetensors", 16, 7.5],
19
+ "LCM-Like LoRA": [
20
+ "pcm_{}_lcmlike_lora_converted.safetensors",
21
+ 4,
22
+ 0.0,
23
+ ],
24
  }
25
 
26
 
27
  loaded = None
28
 
 
29
  if torch.cuda.is_available():
30
+ pipe_sdxl = StableDiffusionXLPipeline.from_pretrained(
31
+ "stabilityai/stable-diffusion-xl-base-1.0",
32
+ torch_dtype=torch.float16,
33
+ variant="fp16",
34
+ ).to("cuda")
35
+ pipe_sd15 = StableDiffusionPipeline.from_pretrained(
36
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
37
  ).to("cuda")
38
 
39
  if SAFETY_CHECKER:
 
58
  return images, has_nsfw_concepts
59
 
60
 
 
61
  @spaces.GPU(enable_queue=True)
62
+ def generate_image(
63
+ prompt,
64
+ ckpt,
65
+ num_inference_steps,
66
+ progress=gr.Progress(track_tqdm=True),
67
+ mode="sdxl",
68
+ ):
69
  global loaded
70
+ checkpoint = checkpoints[ckpt][0].format(mode)
 
 
 
71
  guidance_scale = checkpoints[ckpt][2]
72
+ pipe = pipe_sdxl if mode == "sdxl" else pipe_sd15
73
 
74
+ if loaded != (ckpt + mode):
 
 
 
 
 
 
 
75
  pipe.load_lora_weights(
76
+ "wangfuyun/PCM_Weights", weight_name=checkpoint, subfolder=mode
77
  )
78
+ loaded = ckpt + mode
79
+
80
+ if ckpt == "LCM-Like LoRA":
81
+ pipe.scheduler = LCMScheduler()
82
+ else:
83
+ pipe.scheduler = TCDScheduler(
84
+ num_train_timesteps=1000,
85
+ beta_start=0.00085,
86
+ beta_end=0.012,
87
+ beta_schedule="scaled_linear",
88
+ timestep_spacing="trailing",
89
+ )
90
 
91
  results = pipe(
92
  prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale
 
101
  return results.images[0]
102
 
103
 
104
+ def update_steps(ckpt):
105
+ num_inference_steps = checkpoints[ckpt][1]
106
+ if ckpt == "LCM-Like LoRA":
107
+ return gr.update(interactive=True, value=num_inference_steps)
108
+ return gr.update(interactive=False, value=num_inference_steps)
109
+
110
 
111
  css = """
112
  .gradio-container {
 
114
  }
115
  """
116
  with gr.Blocks(css=css) as demo:
117
+ gr.Markdown(
118
+ """
119
+ # Phased Consistency Model
120
+ [[paper](https://huggingface.co/papers/2405.18407)] [[arXiv](https://arxiv.org/abs/2405.18407)] [[code](https://github.com/G-U-N/Phased-Consistency-Model)] [[project page](https://g-u-n.github.io/projects/pcm)]
121
+ """
122
  )
123
  with gr.Group():
124
  with gr.Row():
125
+ prompt = gr.Textbox(label="Prompt", scale=8)
126
  ckpt = gr.Dropdown(
127
  label="Select inference steps",
128
  choices=list(checkpoints.keys()),
129
  value="4-Step",
 
130
  )
131
+ steps = gr.Slider(
132
+ label="Number of Inference Steps",
133
+ minimum=1,
134
+ maximum=20,
135
+ step=1,
136
+ value=4,
137
+ interactive=False,
138
+ )
139
+ ckpt.change(
140
+ fn=update_steps,
141
+ inputs=[ckpt],
142
+ outputs=[steps],
143
+ queue=False,
144
+ show_progress=False,
145
+ )
146
 
147
+ submit_sdxl = gr.Button("Run on SDXL", scale=1)
148
+ submit_sd15 = gr.Button("Run on SD15", scale=1)
149
+
150
+ img = gr.Image(label="PCM Image")
151
+ gr.Examples(
152
+ examples=[
153
+ [
154
+ "Echoes of a forgotten song drift across the moonlit sea, where a ghost ship sails, its spectral crew bound to an eternal quest for redemption.",
155
+ "4-Step",
156
+ 4,
157
+ ],
158
+ [
159
+ "Roger rabbit as a real person, photorealistic, cinematic.",
160
+ "16-Step",
161
+ 16,
162
+ ],
163
+ [
164
+ "tanding tall amidst the ruins, a stone golem awakens, vines and flowers sprouting from the crevices in its body.",
165
+ "LCM-Like LoRA",
166
+ 4,
167
+ ],
168
+ ],
169
+ inputs=[prompt, ckpt, steps],
170
+ outputs=[img],
171
  fn=generate_image,
172
+ cache_examples="lazy",
 
173
  )
174
+
175
+ gr.on(
176
  fn=generate_image,
177
+ triggers=[prompt.submit, submit_sdxl.click],
178
+ inputs=[prompt, ckpt, steps],
179
+ outputs=[img],
180
+ )
181
+ gr.on(
182
+ fn=lambda *args: generate_image(*args, mode="sd15"),
183
+ triggers=[submit_sd15.click],
184
+ inputs=[prompt, ckpt, steps],
185
+ outputs=[img],
186
  )
187
 
188
+
189
+ demo.queue(api_open=False).launch(show_api=False)