radames committed on
Commit
36c93fe
1 Parent(s): 23b3095

sdxl lightning

server/pipelines/controlnetLoraSDXL-Lightning.py ADDED
@@ -0,0 +1,309 @@
+ from diffusers import (
+     UNet2DConditionModel,
+     StableDiffusionXLControlNetImg2ImgPipeline,
+     ControlNetModel,
+     AutoencoderKL,
+     AutoencoderTiny,
+     EulerDiscreteScheduler,
+ )
+ from compel import Compel, ReturnedEmbeddingsType
+ import torch
+ from pipelines.utils.canny_gpu import SobelOperator
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+
+ try:
+     import intel_extension_for_pytorch as ipex  # type: ignore
+ except Exception:
+     pass
+
+ import psutil
+ from config import Args
+ from pydantic import BaseModel, Field
+ from PIL import Image
+ import math
+
+ controlnet_model = "diffusers/controlnet-canny-sdxl-1.0-small"
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
+ repo = "ByteDance/SDXL-Lightning"
+ ckpt = "sdxl_lightning_2step_unet.safetensors"
+ taesd_model = "madebyollin/taesdxl"
+ NUM_STEPS = 2
+
+
+ default_prompt = "Portrait of The Terminator, glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece"
+ default_negative_prompt = "blurry, low quality, render, 3D, oversaturated"
+ page_content = """
+ <h1 class="text-3xl font-bold">Real-Time SDXL Lightning</h1>
+ <h3 class="text-xl font-bold">SDXL-Lightning + Controlnet</h3>
+ <p class="text-sm">
+     This demo showcases
+     <a
+     href="https://huggingface.co/ByteDance/SDXL-Lightning"
+     target="_blank"
+     class="text-blue-500 underline hover:no-underline">SDXL-Lightning</a>
+     + Controlnet + Image to Image pipeline using
+     <a
+     href="https://huggingface.co/docs/diffusers/main/en/using-diffusers/controlnet"
+     target="_blank"
+     class="text-blue-500 underline hover:no-underline">Diffusers</a
+     > with an MJPEG stream server.
+ </p>
+ <p class="text-sm text-gray-500">
+     Change the prompt to generate different images, accepts <a
+     href="https://github.com/damian0815/compel/blob/main/doc/syntax.md"
+     target="_blank"
+     class="text-blue-500 underline hover:no-underline">Compel</a
+     > syntax.
+ </p>
+ """
+
+
+ class Pipeline:
+     class Info(BaseModel):
+         name: str = "controlnet+loras+sdxl+lightning"
+         title: str = "SDXL-Lightning + Controlnet"
+         description: str = "Generates an image from a text prompt"
+         input_mode: str = "image"
+         page_content: str = page_content
+
+     class InputParams(BaseModel):
+         prompt: str = Field(
+             default_prompt,
+             title="Prompt",
+             field="textarea",
+             id="prompt",
+         )
+         negative_prompt: str = Field(
+             default_negative_prompt,
+             title="Negative Prompt",
+             field="textarea",
+             id="negative_prompt",
+             hide=True,
+         )
+         seed: int = Field(
+             2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
+         )
+         steps: int = Field(
+             1, min=1, max=10, title="Steps", field="range", hide=True, id="steps"
+         )
+         width: int = Field(
+             1024, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
+         )
+         height: int = Field(
+             1024, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
+         )
+         guidance_scale: float = Field(
+             0.0,
+             min=0,
+             max=2.0,
+             step=0.001,
+             title="Guidance Scale",
+             field="range",
+             hide=True,
+             id="guidance_scale",
+         )
+         strength: float = Field(
+             1,
+             min=0.25,
+             max=1.0,
+             step=0.0001,
+             title="Strength",
+             field="range",
+             hide=True,
+             id="strength",
+         )
+         controlnet_scale: float = Field(
+             0.5,
+             min=0,
+             max=1.0,
+             step=0.001,
+             title="Controlnet Scale",
+             field="range",
+             hide=True,
+             id="controlnet_scale",
+         )
+         controlnet_start: float = Field(
+             0.0,
+             min=0,
+             max=1.0,
+             step=0.001,
+             title="Controlnet Start",
+             field="range",
+             hide=True,
+             id="controlnet_start",
+         )
+         controlnet_end: float = Field(
+             1.0,
+             min=0,
+             max=1.0,
+             step=0.001,
+             title="Controlnet End",
+             field="range",
+             hide=True,
+             id="controlnet_end",
+         )
+         canny_low_threshold: float = Field(
+             0.31,
+             min=0,
+             max=1.0,
+             step=0.001,
+             title="Canny Low Threshold",
+             field="range",
+             hide=True,
+             id="canny_low_threshold",
+         )
+         canny_high_threshold: float = Field(
+             0.125,
+             min=0,
+             max=1.0,
+             step=0.001,
+             title="Canny High Threshold",
+             field="range",
+             hide=True,
+             id="canny_high_threshold",
+         )
+         debug_canny: bool = Field(
+             False,
+             title="Debug Canny",
+             field="checkbox",
+             hide=True,
+             id="debug_canny",
+         )
+
+     def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
+
+         if args.taesd:
+             vae = AutoencoderTiny.from_pretrained(
+                 taesd_model, torch_dtype=torch_dtype, use_safetensors=True
+             )
+         else:
+             vae = AutoencoderKL.from_pretrained(
+                 "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
+             )
+
+         controlnet_canny = ControlNetModel.from_pretrained(
+             controlnet_model, torch_dtype=torch_dtype
+         )
+
+         # Instantiate the base SDXL UNet from its config and load the
+         # SDXL-Lightning 2-step distilled weights into it.
+         unet = UNet2DConditionModel.from_config(base, subfolder="unet")
+         unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device.type))
+         self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
+             base,
+             unet=unet,
+             torch_dtype=torch_dtype,
+             variant="fp16",
+             controlnet=controlnet_canny,
+             vae=vae,
+         )
+
+         # Ensure the sampler uses "trailing" timesteps, as SDXL-Lightning expects.
+         self.pipe.scheduler = EulerDiscreteScheduler.from_config(
+             self.pipe.scheduler.config, timestep_spacing="trailing"
+         )
+
+         self.canny_torch = SobelOperator(device=device)
+         self.pipe.set_progress_bar_config(disable=True)
+         self.pipe.to(device=device, dtype=torch_dtype)
+
+         if args.sfast:
+             from sfast.compilers.stable_diffusion_pipeline_compiler import (
+                 compile,
+                 CompilationConfig,
+             )
+
+             config = CompilationConfig.Default()
+             config.enable_xformers = True
+             config.enable_triton = True
+             config.enable_cuda_graph = True
+             self.pipe = compile(self.pipe, config=config)
+
+         if device.type != "mps":
+             self.pipe.unet.to(memory_format=torch.channels_last)
+
+         if args.compel:
+             self.pipe.compel_proc = Compel(
+                 tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
+                 text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
+                 returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+                 requires_pooled=[False, True],
+             )
+
+         if args.torch_compile:
+             self.pipe.unet = torch.compile(
+                 self.pipe.unet, mode="reduce-overhead", fullgraph=True
+             )
+             self.pipe.vae = torch.compile(
+                 self.pipe.vae, mode="reduce-overhead", fullgraph=True
+             )
+             # Warm up the compiled graphs with a dummy generation.
+             self.pipe(
+                 prompt="warmup",
+                 image=[Image.new("RGB", (768, 768))],
+                 control_image=[Image.new("RGB", (768, 768))],
+             )
+
+     def predict(self, params: "Pipeline.InputParams") -> Image.Image:
+         generator = torch.manual_seed(params.seed)
+
+         prompt = params.prompt
+         negative_prompt = params.negative_prompt
+         prompt_embeds = None
+         pooled_prompt_embeds = None
+         negative_prompt_embeds = None
+         negative_pooled_prompt_embeds = None
+         if hasattr(self.pipe, "compel_proc"):
+             # Encode the prompt and negative prompt as one batch, then split
+             # into positive/negative embeddings. Slice the negative pooled
+             # embeddings before overwriting pooled_prompt_embeds, otherwise
+             # they end up empty.
+             _prompt_embeds, pooled_prompt_embeds = self.pipe.compel_proc(
+                 [params.prompt, params.negative_prompt]
+             )
+             prompt = None
+             negative_prompt = None
+             prompt_embeds = _prompt_embeds[0:1]
+             negative_prompt_embeds = _prompt_embeds[1:2]
+             negative_pooled_prompt_embeds = pooled_prompt_embeds[1:2]
+             pooled_prompt_embeds = pooled_prompt_embeds[0:1]
+
+         control_image = self.canny_torch(
+             params.image, params.canny_low_threshold, params.canny_high_threshold
+         )
+         steps = params.steps
+         strength = params.strength
+         if int(steps * strength) < 1:
+             steps = math.ceil(1 / max(0.10, strength))
+
+         results = self.pipe(
+             image=params.image,
+             control_image=control_image,
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             prompt_embeds=prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+             generator=generator,
+             strength=strength,
+             num_inference_steps=steps,
+             guidance_scale=params.guidance_scale,
+             width=params.width,
+             height=params.height,
+             output_type="pil",
+             controlnet_conditioning_scale=params.controlnet_scale,
+             control_guidance_start=params.controlnet_start,
+             control_guidance_end=params.controlnet_end,
+         )
+
+         nsfw_content_detected = (
+             results.nsfw_content_detected[0]
+             if "nsfw_content_detected" in results
+             else False
+         )
+         if nsfw_content_detected:
+             return None
+         result_image = results.images[0]
+         if params.debug_canny:
+             # paste control_image on top of result_image
+             w0, h0 = (200, 200)
+             control_image = control_image.resize((w0, h0))
+             w1, h1 = result_image.size
+             result_image.paste(control_image, (w1 - w0, h1 - h0))
+
+         return result_image
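
Both new pipelines guard the step count in predict(): diffusers img2img executes roughly num_inference_steps * strength denoising steps, so a low strength could round the effective count down to zero. A quick check of the guard's arithmetic, with illustrative values:

    import math

    steps, strength = 1, 0.5
    if int(steps * strength) < 1:  # int(0.5) == 0, no denoising step would run
        steps = math.ceil(1 / max(0.10, strength))  # ceil(1 / 0.5) == 2
    assert steps == 2
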
server/pipelines/img2imgSDXL-Lightning.py ADDED
@@ -0,0 +1,226 @@
+ from diffusers import (
+     AutoPipelineForImage2Image,
+     AutoencoderTiny,
+     AutoencoderKL,
+     UNet2DConditionModel,
+     EulerDiscreteScheduler,
+ )
+ from compel import Compel, ReturnedEmbeddingsType
+ import torch
+
+ try:
+     import intel_extension_for_pytorch as ipex  # type: ignore
+ except Exception:
+     pass
+
+ from safetensors.torch import load_file
+ from huggingface_hub import hf_hub_download
+ from config import Args
+ from pydantic import BaseModel, Field
+ from PIL import Image
+ import math
+
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
+ repo = "ByteDance/SDXL-Lightning"
+ ckpt = "sdxl_lightning_2step_unet.safetensors"
+ taesd_model = "madebyollin/taesdxl"
+ NUM_STEPS = 2
+
+ default_prompt = "close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm summilux"
+ default_negative_prompt = "blurry, low quality, render, 3D, oversaturated"
+ page_content = """
+ <h1 class="text-3xl font-bold">Real-Time SDXL Lightning</h1>
+ <h3 class="text-xl font-bold">Image-to-Image</h3>
+ <p class="text-sm">
+     This demo showcases an
+     <a
+     href="https://huggingface.co/ByteDance/SDXL-Lightning"
+     target="_blank"
+     class="text-blue-500 underline hover:no-underline">SDXL-Lightning</a>
+     Image to Image pipeline using
+     <a
+     href="https://huggingface.co/docs/diffusers/main/en/using-diffusers/img2img"
+     target="_blank"
+     class="text-blue-500 underline hover:no-underline">Diffusers</a
+     > with an MJPEG stream server.
+ </p>
+ <p class="text-sm text-gray-500">
+     Change the prompt to generate different images, accepts <a
+     href="https://github.com/damian0815/compel/blob/main/doc/syntax.md"
+     target="_blank"
+     class="text-blue-500 underline hover:no-underline">Compel</a
+     > syntax.
+ </p>
+ """
+
+
+ class Pipeline:
+     class Info(BaseModel):
+         name: str = "img2img"
+         title: str = "Image-to-Image SDXL-Lightning"
+         description: str = "Generates an image from a text prompt"
+         input_mode: str = "image"
+         page_content: str = page_content
+
+     class InputParams(BaseModel):
+         prompt: str = Field(
+             default_prompt,
+             title="Prompt",
+             field="textarea",
+             id="prompt",
+         )
+         negative_prompt: str = Field(
+             default_negative_prompt,
+             title="Negative Prompt",
+             field="textarea",
+             id="negative_prompt",
+             hide=True,
+         )
+         seed: int = Field(
+             2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
+         )
+         steps: int = Field(
+             1, min=1, max=10, title="Steps", field="range", hide=True, id="steps"
+         )
+         width: int = Field(
+             1024, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
+         )
+         height: int = Field(
+             1024, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
+         )
+         guidance_scale: float = Field(
+             0.0,
+             min=0,
+             max=1,
+             step=0.001,
+             title="Guidance Scale",
+             field="range",
+             hide=True,
+             id="guidance_scale",
+         )
+         strength: float = Field(
+             0.5,
+             min=0.25,
+             max=1.0,
+             step=0.001,
+             title="Strength",
+             field="range",
+             hide=True,
+             id="strength",
+         )
+
+     def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
+
+         if args.taesd:
+             vae = AutoencoderTiny.from_pretrained(
+                 taesd_model, torch_dtype=torch_dtype, use_safetensors=True
+             )
+         else:
+             vae = AutoencoderKL.from_pretrained(
+                 "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
+             )
+
+         # Instantiate the base SDXL UNet from its config and load the
+         # SDXL-Lightning 2-step distilled weights into it.
+         unet = UNet2DConditionModel.from_config(base, subfolder="unet")
+         unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device.type))
+         self.pipe = AutoPipelineForImage2Image.from_pretrained(
+             base,
+             unet=unet,
+             torch_dtype=torch_dtype,
+             variant="fp16",
+             safety_checker=False,
+             vae=vae,
+         )
+         # Ensure the sampler uses "trailing" timesteps, as SDXL-Lightning expects.
+         self.pipe.scheduler = EulerDiscreteScheduler.from_config(
+             self.pipe.scheduler.config, timestep_spacing="trailing"
+         )
+
+         if args.sfast:
+             from sfast.compilers.stable_diffusion_pipeline_compiler import (
+                 compile,
+                 CompilationConfig,
+             )
+
+             config = CompilationConfig.Default()
+             config.enable_xformers = True
+             config.enable_triton = True
+             config.enable_cuda_graph = True
+             self.pipe = compile(self.pipe, config=config)
+
+         self.pipe.set_progress_bar_config(disable=True)
+         self.pipe.to(device=device, dtype=torch_dtype)
+         if device.type != "mps":
+             self.pipe.unet.to(memory_format=torch.channels_last)
+
+         if args.torch_compile:
+             print("Running torch compile")
+             self.pipe.unet = torch.compile(
+                 self.pipe.unet, mode="reduce-overhead", fullgraph=True
+             )
+             self.pipe.vae = torch.compile(
+                 self.pipe.vae, mode="reduce-overhead", fullgraph=True
+             )
+             # Warm up the compiled graphs with a dummy generation.
+             self.pipe(
+                 prompt="warmup",
+                 image=[Image.new("RGB", (768, 768))],
+             )
+
+         if args.compel:
+             self.pipe.compel_proc = Compel(
+                 tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
+                 text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
+                 returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+                 requires_pooled=[False, True],
+             )
+
+     def predict(self, params: "Pipeline.InputParams") -> Image.Image:
+         generator = torch.manual_seed(params.seed)
+         prompt = params.prompt
+         negative_prompt = params.negative_prompt
+         prompt_embeds = None
+         pooled_prompt_embeds = None
+         negative_prompt_embeds = None
+         negative_pooled_prompt_embeds = None
+         if hasattr(self.pipe, "compel_proc"):
+             # Encode the prompt and negative prompt as one batch, then split
+             # into positive/negative embeddings. Slice the negative pooled
+             # embeddings before overwriting pooled_prompt_embeds, otherwise
+             # they end up empty.
+             _prompt_embeds, pooled_prompt_embeds = self.pipe.compel_proc(
+                 [params.prompt, params.negative_prompt]
+             )
+             prompt = None
+             negative_prompt = None
+             prompt_embeds = _prompt_embeds[0:1]
+             negative_prompt_embeds = _prompt_embeds[1:2]
+             negative_pooled_prompt_embeds = pooled_prompt_embeds[1:2]
+             pooled_prompt_embeds = pooled_prompt_embeds[0:1]
+
+         steps = params.steps
+         strength = params.strength
+         if int(steps * strength) < 1:
+             steps = math.ceil(1 / max(0.10, strength))
+
+         results = self.pipe(
+             image=params.image,
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             prompt_embeds=prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+             generator=generator,
+             strength=strength,
+             num_inference_steps=steps,
+             guidance_scale=params.guidance_scale,
+             width=params.width,
+             height=params.height,
+             output_type="pil",
+         )
+
+         nsfw_content_detected = (
+             results.nsfw_content_detected[0]
+             if "nsfw_content_detected" in results
+             else False
+         )
+         if nsfw_content_detected:
+             return None
+         result_image = results.images[0]
+
+         return result_image
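
A minimal sketch of how the new img2img pipeline might be driven outside the MJPEG server loop. This is not part of the commit: it assumes a CUDA GPU with fp16 weights, that it is run from the server/ directory so `config` resolves, that `input.png` exists, and it uses SimpleNamespace objects as stand-ins for the server's Args flags and request params (including the `image` field the server normally attaches):

    import importlib.util
    from types import SimpleNamespace

    import torch
    from PIL import Image

    # The filename contains a hyphen, so load the module dynamically.
    spec = importlib.util.spec_from_file_location(
        "img2img_sdxl_lightning", "pipelines/img2imgSDXL-Lightning.py"
    )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Hypothetical stand-ins for the server's Args flags and request params.
    args = SimpleNamespace(taesd=False, sfast=False, compel=True, torch_compile=False)
    device = torch.device("cuda")
    pipeline = module.Pipeline(args, device, torch.float16)

    params = SimpleNamespace(
        prompt=module.default_prompt,
        negative_prompt=module.default_negative_prompt,
        seed=2159232,
        steps=2,  # matches the 2-step Lightning checkpoint
        strength=0.5,
        guidance_scale=0.0,
        width=1024,
        height=1024,
        image=Image.open("input.png").convert("RGB").resize((1024, 1024)),
    )
    frame = pipeline.predict(params)
    frame.save("output.png")
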