radames commited on
Commit
dbe1557
1 Parent(s): 547a086

add SDXL Turbo examples

Browse files
pipelines/controlnetSDXLTurbo.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import (
2
+ StableDiffusionXLControlNetImg2ImgPipeline,
3
+ ControlNetModel,
4
+ AutoencoderKL,
5
+ )
6
+ from compel import Compel, ReturnedEmbeddingsType
7
+ import torch
8
+ from pipelines.utils.canny_gpu import SobelOperator
9
+
10
+ try:
11
+ import intel_extension_for_pytorch as ipex # type: ignore
12
+ except:
13
+ pass
14
+
15
+ import psutil
16
+ from config import Args
17
+ from pydantic import BaseModel, Field
18
+ from PIL import Image
19
+
20
+ controlnet_model = "diffusers/controlnet-canny-sdxl-1.0"
21
+ model_id = "stabilityai/sdxl-turbo"
22
+
23
+ default_prompt = "Portrait of The Terminator with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece"
24
+ default_negative_prompt = "blurry, low quality, render, 3D, oversaturated"
25
+ page_content = """
26
+ <h1 class="text-3xl font-bold">Real-Time SDXL Turbo</h1>
27
+ <h3 class="text-xl font-bold">Image-to-Image ControlNet</h3>
28
+ <p class="text-sm">
29
+ This demo showcases
30
+ <a
31
+ href="https://huggingface.co/stabilityai/sdxl-turbo"
32
+ target="_blank"
33
+ class="text-blue-500 underline hover:no-underline">SDXL Turbo</a>
34
+ Image to Image pipeline using
35
+ <a
36
+ href="https://huggingface.co/docs/diffusers/main/en/using-diffusers/sdxl_turbo"
37
+ target="_blank"
38
+ class="text-blue-500 underline hover:no-underline">Diffusers</a
39
+ > with a MJPEG stream server.
40
+ </p>
41
+ <p class="text-sm text-gray-500">
42
+ Change the prompt to generate different images, accepts <a
43
+ href="https://github.com/damian0815/compel/blob/main/doc/syntax.md"
44
+ target="_blank"
45
+ class="text-blue-500 underline hover:no-underline">Compel</a
46
+ > syntax.
47
+ </p>
48
+ """
49
+
50
+
51
+ class Pipeline:
52
+ class Info(BaseModel):
53
+ name: str = "controlnet+SDXL+Turbo"
54
+ title: str = "SDXL Turbo + Controlnet"
55
+ description: str = "Generates an image from a text prompt"
56
+ input_mode: str = "image"
57
+ page_content: str = page_content
58
+
59
+ class InputParams(BaseModel):
60
+ prompt: str = Field(
61
+ default_prompt,
62
+ title="Prompt",
63
+ field="textarea",
64
+ id="prompt",
65
+ )
66
+ negative_prompt: str = Field(
67
+ default_negative_prompt,
68
+ title="Negative Prompt",
69
+ field="textarea",
70
+ id="negative_prompt",
71
+ hide=True,
72
+ )
73
+ seed: int = Field(
74
+ 2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
75
+ )
76
+ steps: int = Field(
77
+ 4, min=1, max=15, title="Steps", field="range", hide=True, id="steps"
78
+ )
79
+ width: int = Field(
80
+ 512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
81
+ )
82
+ height: int = Field(
83
+ 512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
84
+ )
85
+ guidance_scale: float = Field(
86
+ 1.0,
87
+ min=0,
88
+ max=20,
89
+ step=0.001,
90
+ title="Guidance Scale",
91
+ field="range",
92
+ hide=True,
93
+ id="guidance_scale",
94
+ )
95
+ strength: float = Field(
96
+ 0.5,
97
+ min=0.25,
98
+ max=1.0,
99
+ step=0.001,
100
+ title="Strength",
101
+ field="range",
102
+ hide=True,
103
+ id="strength",
104
+ )
105
+ controlnet_scale: float = Field(
106
+ 0.5,
107
+ min=0,
108
+ max=1.0,
109
+ step=0.001,
110
+ title="Controlnet Scale",
111
+ field="range",
112
+ hide=True,
113
+ id="controlnet_scale",
114
+ )
115
+ controlnet_start: float = Field(
116
+ 0.0,
117
+ min=0,
118
+ max=1.0,
119
+ step=0.001,
120
+ title="Controlnet Start",
121
+ field="range",
122
+ hide=True,
123
+ id="controlnet_start",
124
+ )
125
+ controlnet_end: float = Field(
126
+ 1.0,
127
+ min=0,
128
+ max=1.0,
129
+ step=0.001,
130
+ title="Controlnet End",
131
+ field="range",
132
+ hide=True,
133
+ id="controlnet_end",
134
+ )
135
+ canny_low_threshold: float = Field(
136
+ 0.31,
137
+ min=0,
138
+ max=1.0,
139
+ step=0.001,
140
+ title="Canny Low Threshold",
141
+ field="range",
142
+ hide=True,
143
+ id="canny_low_threshold",
144
+ )
145
+ canny_high_threshold: float = Field(
146
+ 0.125,
147
+ min=0,
148
+ max=1.0,
149
+ step=0.001,
150
+ title="Canny High Threshold",
151
+ field="range",
152
+ hide=True,
153
+ id="canny_high_threshold",
154
+ )
155
+ debug_canny: bool = Field(
156
+ False,
157
+ title="Debug Canny",
158
+ field="checkbox",
159
+ hide=True,
160
+ id="debug_canny",
161
+ )
162
+
163
+ def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
164
+ controlnet_canny = ControlNetModel.from_pretrained(
165
+ controlnet_model, torch_dtype=torch_dtype
166
+ ).to(device)
167
+ vae = AutoencoderKL.from_pretrained(
168
+ "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
169
+ )
170
+ if args.safety_checker:
171
+ self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
172
+ model_id,
173
+ controlnet=controlnet_canny,
174
+ vae=vae,
175
+ )
176
+ else:
177
+ self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
178
+ model_id,
179
+ safety_checker=None,
180
+ controlnet=controlnet_canny,
181
+ vae=vae,
182
+ )
183
+ self.canny_torch = SobelOperator(device=device)
184
+
185
+ self.pipe.set_progress_bar_config(disable=True)
186
+ self.pipe.to(device=device, dtype=torch_dtype).to(device)
187
+ if device.type != "mps":
188
+ self.pipe.unet.to(memory_format=torch.channels_last)
189
+
190
+ if psutil.virtual_memory().total < 64 * 1024**3:
191
+ self.pipe.enable_attention_slicing()
192
+
193
+ self.pipe.compel_proc = Compel(
194
+ tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
195
+ text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
196
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
197
+ requires_pooled=[False, True],
198
+ )
199
+
200
+ if args.torch_compile:
201
+ self.pipe.unet = torch.compile(
202
+ self.pipe.unet, mode="reduce-overhead", fullgraph=True
203
+ )
204
+ self.pipe.vae = torch.compile(
205
+ self.pipe.vae, mode="reduce-overhead", fullgraph=True
206
+ )
207
+ self.pipe(
208
+ prompt="warmup",
209
+ image=[Image.new("RGB", (768, 768))],
210
+ control_image=[Image.new("RGB", (768, 768))],
211
+ )
212
+
213
+ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
214
+ generator = torch.manual_seed(params.seed)
215
+
216
+ prompt_embeds, pooled_prompt_embeds = self.pipe.compel_proc(
217
+ [params.prompt, params.negative_prompt]
218
+ )
219
+ control_image = self.canny_torch(
220
+ params.image, params.canny_low_threshold, params.canny_high_threshold
221
+ )
222
+ steps = params.steps
223
+ strength = params.strength
224
+ if steps == 1:
225
+ strength = 1
226
+
227
+ results = self.pipe(
228
+ image=params.image,
229
+ control_image=control_image,
230
+ prompt_embeds=prompt_embeds[0:1],
231
+ pooled_prompt_embeds=pooled_prompt_embeds[0:1],
232
+ negative_prompt_embeds=prompt_embeds[1:2],
233
+ negative_pooled_prompt_embeds=pooled_prompt_embeds[1:2],
234
+ generator=generator,
235
+ strength=strength,
236
+ num_inference_steps=steps,
237
+ guidance_scale=params.guidance_scale,
238
+ width=params.width,
239
+ height=params.height,
240
+ output_type="pil",
241
+ controlnet_conditioning_scale=params.controlnet_scale,
242
+ control_guidance_start=params.controlnet_start,
243
+ control_guidance_end=params.controlnet_end,
244
+ )
245
+
246
+ nsfw_content_detected = (
247
+ results.nsfw_content_detected[0]
248
+ if "nsfw_content_detected" in results
249
+ else False
250
+ )
251
+ if nsfw_content_detected:
252
+ return None
253
+ result_image = results.images[0]
254
+ if params.debug_canny:
255
+ # paste control_image on top of result_image
256
+ w0, h0 = (200, 200)
257
+ control_image = control_image.resize((w0, h0))
258
+ w1, h1 = result_image.size
259
+ result_image.paste(control_image, (w1 - w0, h1 - h0))
260
+
261
+ return result_image
pipelines/img2imgSDXLTurbo.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import (
2
+ AutoPipelineForImage2Image,
3
+ AutoencoderTiny,
4
+ )
5
+ from compel import Compel, ReturnedEmbeddingsType
6
+ import torch
7
+
8
+ try:
9
+ import intel_extension_for_pytorch as ipex # type: ignore
10
+ except:
11
+ pass
12
+
13
+ import psutil
14
+ from config import Args
15
+ from pydantic import BaseModel, Field
16
+ from PIL import Image
17
+
18
+ base_model = "stabilityai/sdxl-turbo"
19
+ taesd_model = "madebyollin/taesd"
20
+
21
+ default_prompt = "close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm summilux"
22
+ default_negative_prompt = "blurry, low quality, render, 3D, oversaturated"
23
+ page_content = """
24
+ <h1 class="text-3xl font-bold">Real-Time SDXL Turbo</h1>
25
+ <h3 class="text-xl font-bold">Image-to-Image</h3>
26
+ <p class="text-sm">
27
+ This demo showcases
28
+ <a
29
+ href="https://huggingface.co/stabilityai/sdxl-turbo"
30
+ target="_blank"
31
+ class="text-blue-500 underline hover:no-underline">SDXL Turbo</a>
32
+ Image to Image pipeline using
33
+ <a
34
+ href="https://huggingface.co/docs/diffusers/main/en/using-diffusers/sdxl_turbo"
35
+ target="_blank"
36
+ class="text-blue-500 underline hover:no-underline">Diffusers</a
37
+ > with a MJPEG stream server.
38
+ </p>
39
+ <p class="text-sm text-gray-500">
40
+ Change the prompt to generate different images, accepts <a
41
+ href="https://github.com/damian0815/compel/blob/main/doc/syntax.md"
42
+ target="_blank"
43
+ class="text-blue-500 underline hover:no-underline">Compel</a
44
+ > syntax.
45
+ </p>
46
+ """
47
+
48
+
49
+ class Pipeline:
50
+ class Info(BaseModel):
51
+ name: str = "img2img"
52
+ title: str = "Image-to-Image SDXL"
53
+ description: str = "Generates an image from a text prompt"
54
+ input_mode: str = "image"
55
+ page_content: str = page_content
56
+
57
+ class InputParams(BaseModel):
58
+ prompt: str = Field(
59
+ default_prompt,
60
+ title="Prompt",
61
+ field="textarea",
62
+ id="prompt",
63
+ )
64
+ negative_prompt: str = Field(
65
+ default_negative_prompt,
66
+ title="Negative Prompt",
67
+ field="textarea",
68
+ id="negative_prompt",
69
+ hide=True,
70
+ )
71
+ seed: int = Field(
72
+ 2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
73
+ )
74
+ steps: int = Field(
75
+ 4, min=1, max=15, title="Steps", field="range", hide=True, id="steps"
76
+ )
77
+ width: int = Field(
78
+ 512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
79
+ )
80
+ height: int = Field(
81
+ 512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
82
+ )
83
+ guidance_scale: float = Field(
84
+ 0.2,
85
+ min=0,
86
+ max=20,
87
+ step=0.001,
88
+ title="Guidance Scale",
89
+ field="range",
90
+ hide=True,
91
+ id="guidance_scale",
92
+ )
93
+ strength: float = Field(
94
+ 0.5,
95
+ min=0.25,
96
+ max=1.0,
97
+ step=0.001,
98
+ title="Strength",
99
+ field="range",
100
+ hide=True,
101
+ id="strength",
102
+ )
103
+
104
+ def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
105
+ if args.safety_checker:
106
+ self.pipe = AutoPipelineForImage2Image.from_pretrained(base_model)
107
+ else:
108
+ self.pipe = AutoPipelineForImage2Image.from_pretrained(
109
+ base_model,
110
+ safety_checker=None,
111
+ )
112
+ if args.use_taesd:
113
+ self.pipe.vae = AutoencoderTiny.from_pretrained(
114
+ taesd_model, torch_dtype=torch_dtype, use_safetensors=True
115
+ )
116
+
117
+ self.pipe.set_progress_bar_config(disable=True)
118
+ self.pipe.to(device=device, dtype=torch_dtype)
119
+ if device.type != "mps":
120
+ self.pipe.unet.to(memory_format=torch.channels_last)
121
+
122
+ # check if computer has less than 64GB of RAM using sys or os
123
+ if psutil.virtual_memory().total < 64 * 1024**3:
124
+ self.pipe.enable_attention_slicing()
125
+
126
+ if args.torch_compile:
127
+ print("Running torch compile")
128
+ self.pipe.unet = torch.compile(
129
+ self.pipe.unet, mode="reduce-overhead", fullgraph=True
130
+ )
131
+ self.pipe.vae = torch.compile(
132
+ self.pipe.vae, mode="reduce-overhead", fullgraph=True
133
+ )
134
+
135
+ self.pipe(
136
+ prompt="warmup",
137
+ image=[Image.new("RGB", (768, 768))],
138
+ )
139
+
140
+ self.pipe.compel_proc = Compel(
141
+ tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
142
+ text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
143
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
144
+ requires_pooled=[False, True],
145
+ )
146
+
147
+ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
148
+ generator = torch.manual_seed(params.seed)
149
+ prompt_embeds, pooled_prompt_embeds = self.pipe.compel_proc(
150
+ [params.prompt, params.negative_prompt]
151
+ )
152
+ steps = params.steps
153
+ strength = params.strength
154
+ if steps <= 1:
155
+ strength = 1
156
+ else:
157
+ strength = 1 / steps
158
+
159
+ results = self.pipe(
160
+ image=params.image,
161
+ prompt_embeds=prompt_embeds[0:1],
162
+ pooled_prompt_embeds=pooled_prompt_embeds[0:1],
163
+ negative_prompt_embeds=prompt_embeds[1:2],
164
+ negative_pooled_prompt_embeds=pooled_prompt_embeds[1:2],
165
+ generator=generator,
166
+ strength=strength,
167
+ num_inference_steps=steps,
168
+ guidance_scale=params.guidance_scale,
169
+ width=params.width,
170
+ height=params.height,
171
+ output_type="pil",
172
+ )
173
+
174
+ nsfw_content_detected = (
175
+ results.nsfw_content_detected[0]
176
+ if "nsfw_content_detected" in results
177
+ else False
178
+ )
179
+ if nsfw_content_detected:
180
+ return None
181
+ result_image = results.images[0]
182
+
183
+ return result_image