radames HF staff committed on
Commit
06d3d8c
1 Parent(s): e0273d5

Upload 3 files

Browse files
server/utils/__init__.py ADDED
File without changes
server/utils/viewer.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import threading
4
+ import time
5
+ import tkinter as tk
6
+ from multiprocessing import Queue
7
+ from typing import List
8
+ from PIL import Image, ImageTk
9
+ from streamdiffusion.image_utils import postprocess_image
10
+
11
+ sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
12
+
13
+
14
def update_image(
    image_data: Image.Image,
    label: tk.Label,
    width: int = 512,
    height: int = 512,
) -> None:
    """
    Update the image displayed on a Tkinter label.

    Parameters
    ----------
    image_data : Image.Image
        The image to be displayed.
    label : tk.Label
        The label where the image will be updated.
    width : int, optional
        The width to configure on the label, by default 512.
    height : int, optional
        The height to configure on the label, by default 512.
    """
    # The photo image takes its dimensions from `image_data`; the `size`
    # argument is only meant for the "mode string" form of the constructor,
    # so it is not passed here.
    tk_image = ImageTk.PhotoImage(image_data)
    label.configure(image=tk_image, width=width, height=height)
    # Tkinter does not hold a strong reference to the image; without this
    # attribute the image would be garbage-collected and the label go blank.
    label.image = tk_image  # keep a reference
30
+
31
def _receive_images(
    queue: Queue, fps_queue: Queue, label: tk.Label, fps_label: tk.Label
) -> None:
    """
    Continuously receive images from a queue and update the labels.

    Both queues are polled without blocking. An empty queue is simply
    skipped for that iteration, so a transiently empty queue can never
    terminate the loop.

    Parameters
    ----------
    queue : Queue
        The queue to receive images from.
    fps_queue : Queue
        The queue to receive the calculated fps from.
    label : tk.Label
        The label to update with images.
    fps_label : tk.Label
        The label to show fps.
    """
    # Imported locally: the `queue` parameter shadows the module name here.
    from queue import Empty

    while True:
        try:
            try:
                frame = queue.get(block=False)
            except Empty:
                pass
            else:
                # Hand the widget update to the Tk event loop.
                label.after(
                    0,
                    update_image,
                    postprocess_image(frame, output_type="pil")[0],
                    label,
                )

            try:
                fps = fps_queue.get(block=False)
            except Empty:
                pass
            else:
                fps_label.config(text=f"FPS: {fps:.2f}")

            # Short pause so the polling loop does not spin at full speed.
            time.sleep(0.0005)
        except KeyboardInterrupt:
            return
63
+
64
+
65
def receive_images(queue: Queue, fps_queue: Queue) -> None:
    """
    Build the viewer window and run it until it is closed.

    A daemon thread running ``_receive_images`` is started to feed the
    widgets, then the Tkinter main loop runs in the calling thread.

    Parameters
    ----------
    queue : Queue
        The queue to receive images from.
    fps_queue : Queue
        The queue to receive the calculated fps from.
    """
    window = tk.Tk()
    window.title("Image Viewer")

    image_label = tk.Label(window)
    fps_text = tk.Label(window, text="FPS: 0")
    image_label.grid(column=0)
    fps_text.grid(column=1)

    def handle_close() -> None:
        # Invoked when the user closes the window.
        print("window closed")
        window.quit()  # stop event loop

    worker = threading.Thread(
        target=_receive_images,
        args=(queue, fps_queue, image_label, fps_text),
        daemon=True,
    )
    worker.start()

    try:
        window.protocol("WM_DELETE_WINDOW", handle_close)
        window.mainloop()
    except KeyboardInterrupt:
        return
98
+
server/utils/wrapper.py ADDED
@@ -0,0 +1,654 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+ from pathlib import Path
4
+ import traceback
5
+ from typing import List, Literal, Optional, Union, Dict
6
+
7
+ import numpy as np
8
+ import torch
9
+ from diffusers import AutoencoderTiny, StableDiffusionPipeline
10
+ from PIL import Image
11
+ from polygraphy import cuda
12
+
13
+ from streamdiffusion import StreamDiffusion
14
+ from streamdiffusion.image_utils import postprocess_image
15
+
16
+
17
# Module-level side effects applied at import time:
# gradient tracking is switched off globally ...
torch.set_grad_enabled(False)
# ... and TF32 is allowed for CUDA matmul and cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
20
+
21
+
22
class StreamDiffusionWrapper:
    """
    Convenience wrapper that builds a ``StreamDiffusion`` pipeline and
    exposes img2img / txt2img inference with optional acceleration and
    an optional safety checker.
    """

    def __init__(
        self,
        model_id_or_path: str,
        t_index_list: List[int],
        lora_dict: Optional[Dict[str, float]] = None,
        mode: Literal["img2img", "txt2img"] = "img2img",
        output_type: Literal["pil", "pt", "np", "latent"] = "pil",
        lcm_lora_id: Optional[str] = None,
        vae_id: Optional[str] = None,
        device: Literal["cpu", "cuda"] = "cuda",
        dtype: torch.dtype = torch.float16,
        frame_buffer_size: int = 1,
        width: int = 512,
        height: int = 512,
        warmup: int = 10,
        acceleration: Literal["none", "xformers", "sfast", "tensorrt"] = "tensorrt",
        do_add_noise: bool = True,
        device_ids: Optional[List[int]] = None,
        use_lcm_lora: bool = True,
        use_tiny_vae: bool = True,
        enable_similar_image_filter: bool = False,
        similar_image_filter_threshold: float = 0.98,
        similar_image_filter_max_skip_frame: int = 10,
        use_denoising_batch: bool = True,
        cfg_type: Literal["none", "full", "self", "initialize"] = "self",
        seed: int = 2,
        use_safety_checker: bool = False,
    ):
        """
        Initializes the StreamDiffusionWrapper.

        Parameters
        ----------
        model_id_or_path : str
            The model id or path to load.
        t_index_list : List[int]
            The t_index_list to use for inference.
        lora_dict : Optional[Dict[str, float]], optional
            The lora_dict to load, by default None.
            Keys are the LoRA names and values are the LoRA scales.
            Example: {"LoRA_1" : 0.5 , "LoRA_2" : 0.7 ,...}
        mode : Literal["img2img", "txt2img"], optional
            txt2img or img2img, by default "img2img".
        output_type : Literal["pil", "pt", "np", "latent"], optional
            The output type of image, by default "pil".
        lcm_lora_id : Optional[str], optional
            The lcm_lora_id to load, by default None.
            If None, the default LCM-LoRA
            ("latent-consistency/lcm-lora-sdv1-5") will be used.
        vae_id : Optional[str], optional
            The vae_id to load, by default None.
            If None, the default TinyVAE
            ("madebyollin/taesd") will be used.
        device : Literal["cpu", "cuda"], optional
            The device to use for inference, by default "cuda".
        dtype : torch.dtype, optional
            The dtype for inference, by default torch.float16.
        frame_buffer_size : int, optional
            The frame buffer size for denoising batch, by default 1.
        width : int, optional
            The width of the image, by default 512.
        height : int, optional
            The height of the image, by default 512.
        warmup : int, optional
            The number of warmup steps to perform, by default 10.
        acceleration : Literal["none", "xformers", "sfast", "tensorrt"], optional
            The acceleration method, by default "tensorrt".
        do_add_noise : bool, optional
            Whether to add noise for following denoising steps or not,
            by default True.
        device_ids : Optional[List[int]], optional
            The device ids to use for DataParallel, by default None.
        use_lcm_lora : bool, optional
            Whether to use LCM-LoRA or not, by default True.
        use_tiny_vae : bool, optional
            Whether to use TinyVAE or not, by default True.
        enable_similar_image_filter : bool, optional
            Whether to enable similar image filter or not,
            by default False.
        similar_image_filter_threshold : float, optional
            The threshold for similar image filter, by default 0.98.
        similar_image_filter_max_skip_frame : int, optional
            The max skip frame for similar image filter, by default 10.
        use_denoising_batch : bool, optional
            Whether to use denoising batch or not, by default True.
        cfg_type : Literal["none", "full", "self", "initialize"],
        optional
            The cfg_type for img2img mode, by default "self".
            You cannot use anything other than "none" for txt2img mode.
        seed : int, optional
            The seed, by default 2.
        use_safety_checker : bool, optional
            Whether to use safety checker or not, by default False.

        Raises
        ------
        ValueError
            If txt2img mode is combined with a cfg_type other than "none",
            or with a denoising batch and frame_buffer_size > 1 on a
            non-turbo model.
        NotImplementedError
            If img2img mode is requested without denoising batch.
        """
        # Turbo models are detected purely from the model id / path string.
        self.sd_turbo = "turbo" in model_id_or_path

        if mode == "txt2img":
            if cfg_type != "none":
                raise ValueError(
                    f"txt2img mode accepts only cfg_type = 'none', but got {cfg_type}"
                )
            if use_denoising_batch and frame_buffer_size > 1 and not self.sd_turbo:
                raise ValueError(
                    "txt2img mode cannot use denoising batch with frame_buffer_size > 1."
                )

        if mode == "img2img" and not use_denoising_batch:
            raise NotImplementedError(
                "img2img mode must use denoising batch for now."
            )

        self.device = device
        self.dtype = dtype
        self.width = width
        self.height = height
        self.mode = mode
        self.output_type = output_type
        self.frame_buffer_size = frame_buffer_size
        self.batch_size = (
            len(t_index_list) * frame_buffer_size
            if use_denoising_batch
            else frame_buffer_size
        )

        self.use_denoising_batch = use_denoising_batch
        self.use_safety_checker = use_safety_checker

        self.stream: StreamDiffusion = self._load_model(
            model_id_or_path=model_id_or_path,
            lora_dict=lora_dict,
            lcm_lora_id=lcm_lora_id,
            vae_id=vae_id,
            t_index_list=t_index_list,
            acceleration=acceleration,
            warmup=warmup,
            do_add_noise=do_add_noise,
            use_lcm_lora=use_lcm_lora,
            use_tiny_vae=use_tiny_vae,
            cfg_type=cfg_type,
            seed=seed,
        )

        if device_ids is not None:
            self.stream.unet = torch.nn.DataParallel(
                self.stream.unet, device_ids=device_ids
            )

        if enable_similar_image_filter:
            self.stream.enable_similar_image_filter(
                similar_image_filter_threshold,
                similar_image_filter_max_skip_frame,
            )

    def prepare(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 1.2,
        delta: float = 1.0,
    ) -> None:
        """
        Prepares the model for inference.

        Parameters
        ----------
        prompt : str
            The prompt to generate images from.
        negative_prompt : str, optional
            The negative prompt forwarded to the stream, by default "".
        num_inference_steps : int, optional
            The number of inference steps to perform, by default 50.
        guidance_scale : float, optional
            The guidance scale to use, by default 1.2.
        delta : float, optional
            The delta multiplier of virtual residual noise,
            by default 1.0.
        """
        self.stream.prepare(
            prompt,
            negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            delta=delta,
        )

    def __call__(
        self,
        image: Optional[Union[str, Image.Image, torch.Tensor]] = None,
        prompt: Optional[str] = None,
    ) -> Union[Image.Image, List[Image.Image]]:
        """
        Performs img2img or txt2img based on the mode.

        Parameters
        ----------
        image : Optional[Union[str, Image.Image, torch.Tensor]]
            The image to generate from (used in img2img mode).
        prompt : Optional[str]
            The prompt to generate images from (used in txt2img mode).

        Returns
        -------
        Union[Image.Image, List[Image.Image]]
            The generated image.
        """
        if self.mode == "img2img":
            return self.img2img(image)
        return self.txt2img(prompt)

    def txt2img(
        self, prompt: Optional[str] = None
    ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
        """
        Performs txt2img.

        Parameters
        ----------
        prompt : Optional[str]
            The prompt to generate images from. When given, the stream's
            prompt is updated before generating.

        Returns
        -------
        Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]
            The generated image, or the fallback image when the safety
            checker is enabled and flags the result.
        """
        if prompt is not None:
            self.stream.update_prompt(prompt)

        if self.sd_turbo:
            image_tensor = self.stream.txt2img_sd_turbo(self.batch_size)
        else:
            image_tensor = self.stream.txt2img(self.frame_buffer_size)
        image = self.postprocess_image(image_tensor, output_type=self.output_type)

        if self.use_safety_checker:
            image = self._apply_safety_checker(image, image_tensor)

        return image

    def img2img(
        self, image: Union[str, Image.Image, torch.Tensor]
    ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
        """
        Performs img2img.

        Parameters
        ----------
        image : Union[str, Image.Image, torch.Tensor]
            The image to generate from. Paths and PIL images are
            preprocessed first; tensors are passed to the stream as-is.

        Returns
        -------
        Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]
            The generated image, or the fallback image when the safety
            checker is enabled and flags the result.
        """
        if isinstance(image, (str, Image.Image)):
            image = self.preprocess_image(image)

        image_tensor = self.stream(image)
        image = self.postprocess_image(image_tensor, output_type=self.output_type)

        if self.use_safety_checker:
            image = self._apply_safety_checker(image, image_tensor)

        return image

    def _apply_safety_checker(self, image, image_tensor: torch.Tensor):
        """
        Run the safety checker and swap in the fallback image if flagged.

        Parameters
        ----------
        image
            The postprocessed output, fed to the feature extractor.
        image_tensor : torch.Tensor
            The raw stream output, fed to the safety checker.

        Returns
        -------
        The fallback image when the first NSFW flag is set, else ``image``.
        """
        safety_checker_input = self.feature_extractor(
            image, return_tensors="pt"
        ).to(self.device)
        _, has_nsfw_concept = self.safety_checker(
            images=image_tensor.to(self.dtype),
            clip_input=safety_checker_input.pixel_values.to(self.dtype),
        )
        # Only the first flag is inspected, as in the original inline code.
        return self.nsfw_fallback_img if has_nsfw_concept[0] else image

    def preprocess_image(self, image: Union[str, Image.Image]) -> torch.Tensor:
        """
        Preprocesses the image.

        Parameters
        ----------
        image : Union[str, Image.Image]
            The image, or a path to an image file, to preprocess.

        Returns
        -------
        torch.Tensor
            The preprocessed image on ``self.device`` with ``self.dtype``.
        """
        if isinstance(image, str):
            # Close the file handle once the pixel data has been converted.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        if isinstance(image, Image.Image):
            image = image.convert("RGB").resize((self.width, self.height))

        return self.stream.image_processor.preprocess(
            image, self.height, self.width
        ).to(device=self.device, dtype=self.dtype)

    def postprocess_image(
        self, image_tensor: torch.Tensor, output_type: str = "pil"
    ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
        """
        Postprocesses the image.

        Parameters
        ----------
        image_tensor : torch.Tensor
            The image tensor to postprocess.
        output_type : str, optional
            The output type passed to ``postprocess_image``, by default "pil".

        Returns
        -------
        Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]
            The full postprocessed result when ``frame_buffer_size > 1``,
            otherwise only its first element.
        """
        images = postprocess_image(image_tensor.cpu(), output_type=output_type)
        return images if self.frame_buffer_size > 1 else images[0]

    def _load_model(
        self,
        model_id_or_path: str,
        t_index_list: List[int],
        lora_dict: Optional[Dict[str, float]] = None,
        lcm_lora_id: Optional[str] = None,
        vae_id: Optional[str] = None,
        acceleration: Literal["none", "xformers", "sfast", "tensorrt"] = "tensorrt",
        warmup: int = 10,
        do_add_noise: bool = True,
        use_lcm_lora: bool = True,
        use_tiny_vae: bool = True,
        cfg_type: Literal["none", "full", "self", "initialize"] = "self",
        seed: int = 2,
    ) -> StreamDiffusion:
        """
        Loads the model.

        This method does the following:

        1. Loads the model from the model_id_or_path.
        2. Loads and fuses the LCM-LoRA model from the lcm_lora_id if needed.
        3. Loads the VAE model from the vae_id if needed.
        4. Enables acceleration if needed.
        5. Prepares the model for inference.
        6. Load the safety checker if needed.

        Parameters
        ----------
        model_id_or_path : str
            The model id or path to load.
        t_index_list : List[int]
            The t_index_list to use for inference.
        lora_dict : Optional[Dict[str, float]], optional
            The lora_dict to load, by default None.
            Keys are the LoRA names and values are the LoRA scales.
            Example: {"LoRA_1" : 0.5 , "LoRA_2" : 0.7 ,...}
        lcm_lora_id : Optional[str], optional
            The lcm_lora_id to load, by default None.
        vae_id : Optional[str], optional
            The vae_id to load, by default None.
        acceleration : Literal["none", "xformers", "sfast", "tensorrt"], optional
            The acceleration method, by default "tensorrt".
        warmup : int, optional
            The number of warmup steps to perform, by default 10.
            Accepted for interface compatibility; not read by this method.
        do_add_noise : bool, optional
            Whether to add noise for following denoising steps or not,
            by default True.
        use_lcm_lora : bool, optional
            Whether to use LCM-LoRA or not, by default True.
        use_tiny_vae : bool, optional
            Whether to use TinyVAE or not, by default True.
        cfg_type : Literal["none", "full", "self", "initialize"],
        optional
            The cfg_type for img2img mode, by default "self".
            You cannot use anything other than "none" for txt2img mode.
        seed : int, optional
            The seed, by default 2.

        Returns
        -------
        StreamDiffusion
            The loaded model.
        """

        try:  # First attempt: pipeline loaded via `from_pretrained`
            pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
                model_id_or_path,
            ).to(device=self.device, dtype=self.dtype)

        except ValueError:  # Fallback: treat the argument as a single checkpoint file
            pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file(
                model_id_or_path,
            ).to(device=self.device, dtype=self.dtype)
        except Exception:  # No model found
            traceback.print_exc()
            print("Model load has failed. Doesn't exist.")
            exit()

        stream = StreamDiffusion(
            pipe=pipe,
            t_index_list=t_index_list,
            torch_dtype=self.dtype,
            width=self.width,
            height=self.height,
            do_add_noise=do_add_noise,
            frame_buffer_size=self.frame_buffer_size,
            use_denoising_batch=self.use_denoising_batch,
            cfg_type=cfg_type,
        )
        # LoRA loading is skipped entirely for turbo models.
        if not self.sd_turbo:
            if use_lcm_lora:
                if lcm_lora_id is not None:
                    stream.load_lcm_lora(
                        pretrained_model_name_or_path_or_dict=lcm_lora_id
                    )
                else:
                    stream.load_lcm_lora()
                stream.fuse_lora()

            if lora_dict is not None:
                for lora_name, lora_scale in lora_dict.items():
                    stream.load_lora(lora_name)
                    stream.fuse_lora(lora_scale=lora_scale)
                    print(f"Use LoRA: {lora_name} in weights {lora_scale}")

        if use_tiny_vae:
            tiny_vae_id = vae_id if vae_id is not None else "madebyollin/taesd"
            stream.vae = AutoencoderTiny.from_pretrained(tiny_vae_id).to(
                device=pipe.device, dtype=pipe.dtype
            )

        # Any failure while enabling acceleration is reported and then
        # ignored so that inference can continue without it.
        try:
            if acceleration == "xformers":
                stream.pipe.enable_xformers_memory_efficient_attention()
            if acceleration == "tensorrt":
                from streamdiffusion.acceleration.tensorrt import (
                    TorchVAEEncoder,
                    compile_unet,
                    compile_vae_decoder,
                    compile_vae_encoder,
                )
                from streamdiffusion.acceleration.tensorrt.engine import (
                    AutoencoderKLEngine,
                    UNet2DConditionModelEngine,
                )
                from streamdiffusion.acceleration.tensorrt.models import (
                    VAE,
                    UNet,
                    VAEEncoder,
                )

                def create_prefix(
                    model_id_or_path: str,
                    max_batch_size: int,
                    min_batch_size: int,
                ):
                    # Engine directory name; an existing local path is
                    # reduced to its stem, anything else is used verbatim.
                    maybe_path = Path(model_id_or_path)
                    model_name = (
                        maybe_path.stem if maybe_path.exists() else model_id_or_path
                    )
                    return f"{model_name}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--mode-{self.mode}"

                unet_batch_size = stream.trt_unet_batch_size
                # Batch size used for both VAE engines (min == max == opt).
                vae_batch_size = (
                    self.batch_size
                    if self.mode == "txt2img"
                    else stream.frame_bff_size
                )

                engine_dir = os.path.join("engines")
                unet_path = os.path.join(
                    engine_dir,
                    create_prefix(
                        model_id_or_path=model_id_or_path,
                        max_batch_size=unet_batch_size,
                        min_batch_size=unet_batch_size,
                    ),
                    "unet.engine",
                )
                vae_prefix = create_prefix(
                    model_id_or_path=model_id_or_path,
                    max_batch_size=vae_batch_size,
                    min_batch_size=vae_batch_size,
                )
                vae_encoder_path = os.path.join(
                    engine_dir, vae_prefix, "vae_encoder.engine"
                )
                vae_decoder_path = os.path.join(
                    engine_dir, vae_prefix, "vae_decoder.engine"
                )

                # Each engine is compiled only when its file is missing.
                if not os.path.exists(unet_path):
                    os.makedirs(os.path.dirname(unet_path), exist_ok=True)
                    unet_model = UNet(
                        fp16=True,
                        device=stream.device,
                        max_batch_size=unet_batch_size,
                        min_batch_size=unet_batch_size,
                        embedding_dim=stream.text_encoder.config.hidden_size,
                        unet_dim=stream.unet.config.in_channels,
                    )
                    compile_unet(
                        stream.unet,
                        unet_model,
                        unet_path + ".onnx",
                        unet_path + ".opt.onnx",
                        unet_path,
                        opt_batch_size=unet_batch_size,
                    )

                if not os.path.exists(vae_decoder_path):
                    os.makedirs(os.path.dirname(vae_decoder_path), exist_ok=True)
                    # Temporarily expose `decode` as `forward` for the
                    # decoder compilation; removed again right after.
                    stream.vae.forward = stream.vae.decode
                    vae_decoder_model = VAE(
                        device=stream.device,
                        max_batch_size=vae_batch_size,
                        min_batch_size=vae_batch_size,
                    )
                    compile_vae_decoder(
                        stream.vae,
                        vae_decoder_model,
                        vae_decoder_path + ".onnx",
                        vae_decoder_path + ".opt.onnx",
                        vae_decoder_path,
                        opt_batch_size=vae_batch_size,
                    )
                    delattr(stream.vae, "forward")

                if not os.path.exists(vae_encoder_path):
                    os.makedirs(os.path.dirname(vae_encoder_path), exist_ok=True)
                    vae_encoder = TorchVAEEncoder(stream.vae).to(torch.device("cuda"))
                    vae_encoder_model = VAEEncoder(
                        device=stream.device,
                        max_batch_size=vae_batch_size,
                        min_batch_size=vae_batch_size,
                    )
                    compile_vae_encoder(
                        vae_encoder,
                        vae_encoder_model,
                        vae_encoder_path + ".onnx",
                        vae_encoder_path + ".opt.onnx",
                        vae_encoder_path,
                        opt_batch_size=vae_batch_size,
                    )

                cuda_stream = cuda.Stream()

                # Saved before the VAE is replaced so they can be
                # re-attached to the engine object below.
                vae_config = stream.vae.config
                vae_dtype = stream.vae.dtype

                stream.unet = UNet2DConditionModelEngine(
                    unet_path, cuda_stream, use_cuda_graph=False
                )
                stream.vae = AutoencoderKLEngine(
                    vae_encoder_path,
                    vae_decoder_path,
                    cuda_stream,
                    stream.pipe.vae_scale_factor,
                    use_cuda_graph=False,
                )
                setattr(stream.vae, "config", vae_config)
                setattr(stream.vae, "dtype", vae_dtype)

                gc.collect()
                torch.cuda.empty_cache()

                print("TensorRT acceleration enabled.")
            if acceleration == "sfast":
                from streamdiffusion.acceleration.sfast import (
                    accelerate_with_stable_fast,
                )

                stream = accelerate_with_stable_fast(stream)
                print("StableFast acceleration enabled.")
        except Exception:
            traceback.print_exc()
            print("Acceleration has failed. Falling back to normal mode.")

        stream.prepare(
            "",
            "",
            num_inference_steps=50,
            guidance_scale=1.1
            if stream.cfg_type in ["full", "self", "initialize"]
            else 1.0,
            generator=torch.manual_seed(seed),
            seed=seed,
        )

        if self.use_safety_checker:
            from transformers import CLIPFeatureExtractor
            from diffusers.pipelines.stable_diffusion.safety_checker import (
                StableDiffusionSafetyChecker,
            )

            self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
                "CompVis/stable-diffusion-safety-checker"
            ).to(pipe.device)
            self.feature_extractor = CLIPFeatureExtractor.from_pretrained(
                "openai/clip-vit-base-patch32"
            )
            # Black fallback image sized to the configured output
            # dimensions (512x512 with the default width/height).
            self.nsfw_fallback_img = Image.new(
                "RGB", (self.width, self.height), (0, 0, 0)
            )

        return stream