jbilcke-hf committed
Commit e08f02e
1 Parent(s): 0c09322
Files changed (4)
  1. utils/__init__.py +0 -0
  2. utils/viewer.py +98 -0
  3. utils/wrapper.py +664 -0
  4. vid2vid.py +0 -7
utils/__init__.py ADDED
File without changes
utils/viewer.py ADDED
@@ -0,0 +1,98 @@
+ import os
+ import sys
+ import threading
+ import time
+ import tkinter as tk
+ from multiprocessing import Queue
+ from typing import List
+ from PIL import Image, ImageTk
+ sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
+
+ from streamv2v.image_utils import postprocess_image
+
+
+ def update_image(image_data: Image.Image, label: tk.Label) -> None:
+     """
+     Update the image displayed on a Tkinter label.
+
+     Parameters
+     ----------
+     image_data : Image.Image
+         The image to be displayed.
+     label : tk.Label
+         The label where the image will be updated.
+     """
+     width = 512
+     height = 512
+     tk_image = ImageTk.PhotoImage(image_data)
+     label.configure(image=tk_image, width=width, height=height)
+     label.image = tk_image  # keep a reference
+
+ def _receive_images(
+     queue: Queue, fps_queue: Queue, label: tk.Label, fps_label: tk.Label
+ ) -> None:
+     """
+     Continuously receive images from a queue and update the labels.
+
+     Parameters
+     ----------
+     queue : Queue
+         The queue to receive images from.
+     fps_queue : Queue
+         The queue from which to read the calculated fps.
+     label : tk.Label
+         The label to update with images.
+     fps_label : tk.Label
+         The label to show fps.
+     """
+     while True:
+         try:
+             if not queue.empty():
+                 label.after(
+                     0,
+                     update_image,
+                     postprocess_image(queue.get(block=False), output_type="pil")[0],
+                     label,
+                 )
+             if not fps_queue.empty():
+                 fps_label.config(text=f"FPS: {fps_queue.get(block=False):.2f}")
+
+             time.sleep(0.0005)
+         except KeyboardInterrupt:
+             return
+
+
+ def receive_images(queue: Queue, fps_queue: Queue) -> None:
+     """
+     Set up the Tkinter window and start the thread to receive images.
+
+     Parameters
+     ----------
+     queue : Queue
+         The queue to receive images from.
+     fps_queue : Queue
+         The queue from which to read the calculated fps.
+     """
+     root = tk.Tk()
+     root.title("Image Viewer")
+     label = tk.Label(root)
+     fps_label = tk.Label(root, text="FPS: 0")
+     label.grid(column=0)
+     fps_label.grid(column=1)
+
+     def on_closing():
+         print("window closed")
+         root.quit()  # stop event loop
+         return
+
+     thread = threading.Thread(
+         target=_receive_images, args=(queue, fps_queue, label, fps_label), daemon=True
+     )
+     thread.start()
+
+     try:
+         root.protocol("WM_DELETE_WINDOW", on_closing)
+         root.mainloop()
+     except KeyboardInterrupt:
+         return
+
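For reference (not part of this commit), here is a minimal sketch of how receive_images could be driven from a separate producer process. The dummy-tensor producer, queue setup, and frame rate below are illustrative assumptions; only receive_images itself comes from utils/viewer.py above, and in real use the frames would be produced by the StreamV2V pipeline rather than torch.rand.

# hypothetical demo script, not included in the commit
import time
from multiprocessing import Process, Queue

import torch

from utils.viewer import receive_images


def produce_frames(queue: Queue, fps_queue: Queue) -> None:
    # Push dummy image tensors shaped like the pipeline's output (assumed N, C, H, W);
    # real code would put frames generated by StreamV2VWrapper here instead.
    last = time.time()
    while True:
        queue.put(torch.rand(1, 3, 512, 512))
        now = time.time()
        fps_queue.put(1.0 / max(now - last, 1e-6))
        last = now
        time.sleep(1 / 30)  # roughly 30 frames per second


if __name__ == "__main__":
    queue, fps_queue = Queue(), Queue()
    Process(target=produce_frames, args=(queue, fps_queue), daemon=True).start()
    receive_images(queue, fps_queue)  # blocks in the Tk main loop until the window closes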
utils/wrapper.py ADDED
@@ -0,0 +1,664 @@
+ import gc
+ import os
+ from pathlib import Path
+ import traceback
+ from typing import List, Literal, Optional, Union, Dict
+
+ import numpy as np
+ import torch
+ from diffusers import AutoencoderTiny, StableDiffusionPipeline
+ from diffusers.models.attention_processor import XFormersAttnProcessor, AttnProcessor2_0
+ from PIL import Image
+
+ from streamv2v import StreamV2V
+ from streamv2v.image_utils import postprocess_image
+ from streamv2v.models.attention_processor import CachedSTXFormersAttnProcessor, CachedSTAttnProcessor2_0
+
+
+ torch.set_grad_enabled(False)
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+
+ class StreamV2VWrapper:
+     def __init__(
+         self,
+         model_id_or_path: str,
+         t_index_list: List[int],
+         lora_dict: Optional[Dict[str, float]] = None,
+         output_type: Literal["pil", "pt", "np", "latent"] = "pil",
+         lcm_lora_id: Optional[str] = None,
+         vae_id: Optional[str] = None,
+         device: Literal["cpu", "cuda"] = "cuda",
+         dtype: torch.dtype = torch.float16,
+         frame_buffer_size: int = 1,
+         width: int = 512,
+         height: int = 512,
+         warmup: int = 10,
+         acceleration: Literal["none", "xformers", "tensorrt"] = "xformers",
+         do_add_noise: bool = True,
+         device_ids: Optional[List[int]] = None,
+         use_lcm_lora: bool = True,
+         use_tiny_vae: bool = True,
+         enable_similar_image_filter: bool = False,
+         similar_image_filter_threshold: float = 0.98,
+         similar_image_filter_max_skip_frame: int = 10,
+         use_denoising_batch: bool = True,
+         cfg_type: Literal["none", "full", "self", "initialize"] = "self",
+         use_cached_attn: bool = True,
+         use_feature_injection: bool = True,
+         feature_injection_strength: float = 0.8,
+         feature_similarity_threshold: float = 0.98,
+         cache_interval: int = 4,
+         cache_maxframes: int = 1,
+         use_tome_cache: bool = True,
+         tome_metric: str = "keys",
+         tome_ratio: float = 0.5,
+         use_grid: bool = False,
+         seed: int = 2,
+         use_safety_checker: bool = False,
+         engine_dir: Optional[Union[str, Path]] = "engines",
+     ):
+         """
+         Initializes the StreamV2VWrapper.
+
+         Parameters
+         ----------
+         model_id_or_path : str
+             The model identifier or path to load.
+         t_index_list : List[int]
+             The list of indices to use for inference.
+         lora_dict : Optional[Dict[str, float]], optional
+             Dictionary of LoRA names and their corresponding scales,
+             by default None. Example: {'LoRA_1': 0.5, 'LoRA_2': 0.7, ...}
+         output_type : Literal["pil", "pt", "np", "latent"], optional
+             The type of output image, by default "pil".
+         lcm_lora_id : Optional[str], optional
+             The identifier for the LCM-LoRA to load, by default None.
+             If None, the default LCM-LoRA ("latent-consistency/lcm-lora-sdv1-5") is used.
+         vae_id : Optional[str], optional
+             The identifier for the VAE to load, by default None.
+             If None, the default TinyVAE ("madebyollin/taesd") is used.
+         device : Literal["cpu", "cuda"], optional
+             The device to use for inference, by default "cuda".
+         dtype : torch.dtype, optional
+             The data type for inference, by default torch.float16.
+         frame_buffer_size : int, optional
+             The size of the frame buffer for denoising batch, by default 1.
+         width : int, optional
+             The width of the image, by default 512.
+         height : int, optional
+             The height of the image, by default 512.
+         warmup : int, optional
+             The number of warmup steps to perform, by default 10.
+         acceleration : Literal["none", "xformers", "tensorrt"], optional
+             The acceleration method, by default "xformers".
+         do_add_noise : bool, optional
+             Whether to add noise during denoising steps, by default True.
+         device_ids : Optional[List[int]], optional
+             List of device IDs to use for DataParallel, by default None.
+         use_lcm_lora : bool, optional
+             Whether to use LCM-LoRA, by default True.
+         use_tiny_vae : bool, optional
+             Whether to use TinyVAE, by default True.
+         enable_similar_image_filter : bool, optional
+             Whether to enable similar image filtering, by default False.
+         similar_image_filter_threshold : float, optional
+             The threshold for the similar image filter, by default 0.98.
+         similar_image_filter_max_skip_frame : int, optional
+             The maximum number of frames to skip for similar image filter, by default 10.
+         use_denoising_batch : bool, optional
+             Whether to use denoising batch, by default True.
+         cfg_type : Literal["none", "full", "self", "initialize"], optional
+             The CFG type for img2img mode, by default "self".
+         use_cached_attn : bool, optional
+             Whether to cache self-attention maps from previous frames to improve temporal consistency, by default True.
+         use_feature_injection : bool, optional
+             Whether to use feature maps from previous frames to improve temporal consistency, by default True.
+         feature_injection_strength : float, optional
+             The strength of feature injection, by default 0.8.
+         feature_similarity_threshold : float, optional
+             The similarity threshold for feature injection, by default 0.98.
+         cache_interval : int, optional
+             The interval at which to cache attention maps, by default 4.
+         cache_maxframes : int, optional
+             The maximum number of frames to cache attention maps, by default 1.
+         use_tome_cache : bool, optional
+             Whether to use Tome caching, by default True.
+         tome_metric : str, optional
+             The metric to use for Tome, by default "keys".
+         tome_ratio : float, optional
+             The ratio for Tome, by default 0.5.
+         use_grid : bool, optional
+             Whether to use grid, by default False.
+         seed : int, optional
+             The seed for random number generation, by default 2.
+         use_safety_checker : bool, optional
+             Whether to use a safety checker, by default False.
+         engine_dir : Optional[Union[str, Path]], optional
+             The directory for the engine, by default "engines".
+         """
+         # TODO: Test SD turbo
+         self.sd_turbo = "turbo" in model_id_or_path
+
+         assert use_denoising_batch, "vid2vid mode must use denoising batch for now."
+
+         self.device = device
+         self.dtype = dtype
+         self.width = width
+         self.height = height
+         self.output_type = output_type
+         self.frame_buffer_size = frame_buffer_size
+         self.batch_size = (
+             len(t_index_list) * frame_buffer_size
+             if use_denoising_batch
+             else frame_buffer_size
+         )
+
+         self.use_denoising_batch = use_denoising_batch
+         self.use_cached_attn = use_cached_attn
+         self.use_feature_injection = use_feature_injection
+         self.feature_injection_strength = feature_injection_strength
+         self.feature_similarity_threshold = feature_similarity_threshold
+         self.cache_interval = cache_interval
+         self.cache_maxframes = cache_maxframes
+         self.use_tome_cache = use_tome_cache
+         self.tome_metric = tome_metric
+         self.tome_ratio = tome_ratio
+         self.use_grid = use_grid
+         self.use_safety_checker = use_safety_checker
+
+         self.stream: StreamV2V = self._load_model(
+             model_id_or_path=model_id_or_path,
+             lora_dict=lora_dict,
+             lcm_lora_id=lcm_lora_id,
+             vae_id=vae_id,
+             t_index_list=t_index_list,
+             acceleration=acceleration,
+             warmup=warmup,
+             do_add_noise=do_add_noise,
+             use_lcm_lora=use_lcm_lora,
+             use_tiny_vae=use_tiny_vae,
+             cfg_type=cfg_type,
+             seed=seed,
+             engine_dir=engine_dir,
+         )
+
+         if device_ids is not None:
+             self.stream.unet = torch.nn.DataParallel(
+                 self.stream.unet, device_ids=device_ids
+             )
+
+         if enable_similar_image_filter:
+             self.stream.enable_similar_image_filter(similar_image_filter_threshold, similar_image_filter_max_skip_frame)
+
+     def prepare(
+         self,
+         prompt: str,
+         negative_prompt: str = "",
+         num_inference_steps: int = 50,
+         guidance_scale: float = 1.2,
+         delta: float = 1.0,
+     ) -> None:
+         """
+         Prepares the model for inference.
+
+         Parameters
+         ----------
+         prompt : str
+             The prompt to generate images from.
+         num_inference_steps : int, optional
+             The number of inference steps to perform, by default 50.
+         guidance_scale : float, optional
+             The guidance scale to use, by default 1.2.
+         delta : float, optional
+             The delta multiplier of virtual residual noise,
+             by default 1.0.
+         """
+         self.stream.prepare(
+             prompt,
+             negative_prompt,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             delta=delta,
+         )
+
+     def __call__(
+         self,
+         image: Union[str, Image.Image, torch.Tensor],
+         prompt: Optional[str] = None,
+     ) -> Union[Image.Image, List[Image.Image]]:
+         """
+         Performs img2img.
+
+         Parameters
+         ----------
+         image : Union[str, Image.Image, torch.Tensor]
+             The image to generate from.
+         prompt : Optional[str]
+             The prompt to generate images from.
+
+         Returns
+         -------
+         Union[Image.Image, List[Image.Image]]
+             The generated image.
+         """
+         return self.img2img(image, prompt)
+
+     def img2img(
+         self, image: Union[str, Image.Image, torch.Tensor], prompt: Optional[str] = None
+     ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
+         """
+         Performs img2img.
+
+         Parameters
+         ----------
+         image : Union[str, Image.Image, torch.Tensor]
+             The image to generate from.
+
+         Returns
+         -------
+         Image.Image
+             The generated image.
+         """
+         if prompt is not None:
+             self.stream.update_prompt(prompt)
+
+         if isinstance(image, str) or isinstance(image, Image.Image):
+             image = self.preprocess_image(image)
+
+         image_tensor = self.stream(image)
+         image = self.postprocess_image(image_tensor, output_type=self.output_type)
+
+         if self.use_safety_checker:
+             safety_checker_input = self.feature_extractor(
+                 image, return_tensors="pt"
+             ).to(self.device)
+             _, has_nsfw_concept = self.safety_checker(
+                 images=image_tensor.to(self.dtype),
+                 clip_input=safety_checker_input.pixel_values.to(self.dtype),
+             )
+             image = self.nsfw_fallback_img if has_nsfw_concept[0] else image
+
+         return image
+
+     def preprocess_image(self, image: Union[str, Image.Image]) -> torch.Tensor:
+         """
+         Preprocesses the image.
+
+         Parameters
+         ----------
+         image : Union[str, Image.Image]
+             The image to preprocess.
+
+         Returns
+         -------
+         torch.Tensor
+             The preprocessed image.
+         """
+         if isinstance(image, str):
+             image = Image.open(image).convert("RGB").resize((self.width, self.height))
+         if isinstance(image, Image.Image):
+             image = image.convert("RGB").resize((self.width, self.height))
+
+         return self.stream.image_processor.preprocess(
+             image, self.height, self.width
+         ).to(device=self.device, dtype=self.dtype)
+
+     def postprocess_image(
+         self, image_tensor: torch.Tensor, output_type: str = "pil"
+     ) -> Union[Image.Image, List[Image.Image], torch.Tensor, np.ndarray]:
+         """
+         Postprocesses the image.
+
+         Parameters
+         ----------
+         image_tensor : torch.Tensor
+             The image tensor to postprocess.
+
+         Returns
+         -------
+         Union[Image.Image, List[Image.Image]]
+             The postprocessed image.
+         """
+         if self.frame_buffer_size > 1:
+             return postprocess_image(image_tensor.cpu(), output_type=output_type)
+         else:
+             return postprocess_image(image_tensor.cpu(), output_type=output_type)[0]
+
+     def _load_model(
+         self,
+         model_id_or_path: str,
+         t_index_list: List[int],
+         lora_dict: Optional[Dict[str, float]] = None,
+         lcm_lora_id: Optional[str] = None,
+         vae_id: Optional[str] = None,
+         acceleration: Literal["none", "xformers", "tensorrt"] = "xformers",
+         warmup: int = 10,
+         do_add_noise: bool = True,
+         use_lcm_lora: bool = True,
+         use_tiny_vae: bool = True,
+         cfg_type: Literal["none", "full", "self", "initialize"] = "self",
+         seed: int = 2,
+         engine_dir: Optional[Union[str, Path]] = "engines",
+     ) -> StreamV2V:
+         """
+         Loads the model.
+
+         This method does the following:
+
+         1. Loads the model from the model_id_or_path.
+         2. Loads and fuses the LCM-LoRA model from the lcm_lora_id if needed.
+         3. Loads the VAE model from the vae_id if needed.
+         4. Enables acceleration if needed.
+         5. Prepares the model for inference.
+         6. Loads the safety checker if needed.
+
+         Parameters
+         ----------
+         model_id_or_path : str
+             The model id or path to load.
+         t_index_list : List[int]
+             The t_index_list to use for inference.
+         lora_dict : Optional[Dict[str, float]], optional
+             The lora_dict to load, by default None.
+             Keys are the LoRA names and values are the LoRA scales.
+             Example: {'LoRA_1': 0.5, 'LoRA_2': 0.7, ...}
+         lcm_lora_id : Optional[str], optional
+             The lcm_lora_id to load, by default None.
+         vae_id : Optional[str], optional
+             The vae_id to load, by default None.
+         acceleration : Literal["none", "xformers", "sfast", "tensorrt"], optional
+             The acceleration method, by default "xformers".
+         warmup : int, optional
+             The number of warmup steps to perform, by default 10.
+         do_add_noise : bool, optional
+             Whether to add noise for following denoising steps or not,
+             by default True.
+         use_lcm_lora : bool, optional
+             Whether to use LCM-LoRA or not, by default True.
+         use_tiny_vae : bool, optional
+             Whether to use TinyVAE or not, by default True.
+         cfg_type : Literal["none", "full", "self", "initialize"],
+         optional
+             The cfg_type for img2img mode,
+             by default "self".
+         seed : int, optional
+             The seed, by default 2.
+
+         Returns
+         -------
+         StreamV2V
+             The loaded model.
+         """
+
+         try:  # Load a diffusers-format model (local directory or Hub repo id)
+             pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
+                 model_id_or_path,
+             ).to(device=self.device, dtype=self.dtype)
+
+         except ValueError:  # Fall back to loading a single checkpoint file
+             pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file(
+                 model_id_or_path,
+             ).to(device=self.device, dtype=self.dtype)
+         except Exception:  # No model found
+             traceback.print_exc()
+             print("Model load failed: the model could not be found.")
+             exit()
+
+         stream = StreamV2V(
+             pipe=pipe,
+             t_index_list=t_index_list,
+             torch_dtype=self.dtype,
+             width=self.width,
+             height=self.height,
+             do_add_noise=do_add_noise,
+             frame_buffer_size=self.frame_buffer_size,
+             use_denoising_batch=self.use_denoising_batch,
+             cfg_type=cfg_type,
+         )
+         if not self.sd_turbo:
+             if use_lcm_lora:
+                 if lcm_lora_id is not None:
+                     stream.load_lcm_lora(
+                         pretrained_model_name_or_path_or_dict=lcm_lora_id,
+                         adapter_name="lcm")
+                 else:
+                     stream.load_lcm_lora(
+                         pretrained_model_name_or_path_or_dict="latent-consistency/lcm-lora-sdv1-5",
+                         adapter_name="lcm"
+                     )
+
+             if lora_dict is not None:
+                 for lora_name, lora_scale in lora_dict.items():
+                     stream.load_lora(lora_name)
+
+         if use_tiny_vae:
+             if vae_id is not None:
+                 stream.vae = AutoencoderTiny.from_pretrained(vae_id).to(
+                     device=pipe.device, dtype=pipe.dtype
+                 )
+             else:
+                 stream.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to(
+                     device=pipe.device, dtype=pipe.dtype
+                 )
+
+         try:
+             if acceleration == "xformers":
+                 stream.pipe.enable_xformers_memory_efficient_attention()
+                 if self.use_cached_attn:
+                     attn_processors = stream.pipe.unet.attn_processors
+                     new_attn_processors = {}
+                     for key, attn_processor in attn_processors.items():
+                         assert isinstance(attn_processor, XFormersAttnProcessor), \
+                             "We only replace 'XFormersAttnProcessor' with 'CachedSTXFormersAttnProcessor'"
+                         new_attn_processors[key] = CachedSTXFormersAttnProcessor(name=key,
+                             use_feature_injection=self.use_feature_injection,
+                             feature_injection_strength=self.feature_injection_strength,
+                             feature_similarity_threshold=self.feature_similarity_threshold,
+                             interval=self.cache_interval,
+                             max_frames=self.cache_maxframes,
+                             use_tome_cache=self.use_tome_cache,
+                             tome_metric=self.tome_metric,
+                             tome_ratio=self.tome_ratio,
+                             use_grid=self.use_grid)
+                     stream.pipe.unet.set_attn_processor(new_attn_processors)
+
+             if acceleration == "tensorrt":
+                 if self.use_cached_attn:
+                     raise NotImplementedError("TensorRT does not seem to support the custom attention processor")
+                 else:
+                     stream.pipe.enable_xformers_memory_efficient_attention()
+                     if self.use_cached_attn:
+                         attn_processors = stream.pipe.unet.attn_processors
+                         new_attn_processors = {}
+                         for key, attn_processor in attn_processors.items():
+                             assert isinstance(attn_processor, XFormersAttnProcessor), \
+                                 "We only replace 'XFormersAttnProcessor' with 'CachedSTXFormersAttnProcessor'"
+                             new_attn_processors[key] = CachedSTXFormersAttnProcessor(name=key,
+                                 use_feature_injection=self.use_feature_injection,
+                                 feature_injection_strength=self.feature_injection_strength,
+                                 feature_similarity_threshold=self.feature_similarity_threshold,
+                                 interval=self.cache_interval,
+                                 max_frames=self.cache_maxframes,
+                                 use_tome_cache=self.use_tome_cache,
+                                 tome_metric=self.tome_metric,
+                                 tome_ratio=self.tome_ratio,
+                                 use_grid=self.use_grid)
+                         stream.pipe.unet.set_attn_processor(new_attn_processors)
+
+                 from polygraphy import cuda
+                 from streamv2v.acceleration.tensorrt import (
+                     TorchVAEEncoder,
+                     compile_unet,
+                     compile_vae_decoder,
+                     compile_vae_encoder,
+                 )
+                 from streamv2v.acceleration.tensorrt.engine import (
+                     AutoencoderKLEngine,
+                     UNet2DConditionModelEngine,
+                 )
+                 from streamv2v.acceleration.tensorrt.models import (
+                     VAE,
+                     UNet,
+                     VAEEncoder,
+                 )
+
+                 def create_prefix(
+                     model_id_or_path: str,
+                     max_batch_size: int,
+                     min_batch_size: int,
+                 ):
+                     maybe_path = Path(model_id_or_path)
+                     if maybe_path.exists():
+                         return f"{maybe_path.stem}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--cache--{self.use_cached_attn}"
+                     else:
+                         return f"{model_id_or_path}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--cache--{self.use_cached_attn}"
+
+                 engine_dir = Path(engine_dir)
+                 unet_path = os.path.join(
+                     engine_dir,
+                     create_prefix(
+                         model_id_or_path=model_id_or_path,
+                         max_batch_size=stream.trt_unet_batch_size,
+                         min_batch_size=stream.trt_unet_batch_size,
+                     ),
+                     "unet.engine",
+                 )
+                 vae_encoder_path = os.path.join(
+                     engine_dir,
+                     create_prefix(
+                         model_id_or_path=model_id_or_path,
+                         max_batch_size=stream.frame_bff_size,
+                         min_batch_size=stream.frame_bff_size,
+                     ),
+                     "vae_encoder.engine",
+                 )
+                 vae_decoder_path = os.path.join(
+                     engine_dir,
+                     create_prefix(
+                         model_id_or_path=model_id_or_path,
+                         max_batch_size=stream.frame_bff_size,
+                         min_batch_size=stream.frame_bff_size,
+                     ),
+                     "vae_decoder.engine",
+                 )
+
+                 if not os.path.exists(unet_path):
+                     os.makedirs(os.path.dirname(unet_path), exist_ok=True)
+                     unet_model = UNet(
+                         fp16=True,
+                         device=stream.device,
+                         max_batch_size=stream.trt_unet_batch_size,
+                         min_batch_size=stream.trt_unet_batch_size,
+                         embedding_dim=stream.text_encoder.config.hidden_size,
+                         unet_dim=stream.unet.config.in_channels,
+                     )
+                     compile_unet(
+                         stream.unet,
+                         unet_model,
+                         unet_path + ".onnx",
+                         unet_path + ".opt.onnx",
+                         unet_path,
+                         opt_batch_size=stream.trt_unet_batch_size,
+                     )
+
+                 if not os.path.exists(vae_decoder_path):
+                     os.makedirs(os.path.dirname(vae_decoder_path), exist_ok=True)
+                     stream.vae.forward = stream.vae.decode
+                     vae_decoder_model = VAE(
+                         device=stream.device,
+                         max_batch_size=stream.frame_bff_size,
+                         min_batch_size=stream.frame_bff_size,
+                     )
+                     compile_vae_decoder(
+                         stream.vae,
+                         vae_decoder_model,
+                         vae_decoder_path + ".onnx",
+                         vae_decoder_path + ".opt.onnx",
+                         vae_decoder_path,
+                         opt_batch_size=stream.frame_bff_size,
+                     )
+                     delattr(stream.vae, "forward")
+
+                 if not os.path.exists(vae_encoder_path):
+                     os.makedirs(os.path.dirname(vae_encoder_path), exist_ok=True)
+                     vae_encoder = TorchVAEEncoder(stream.vae).to(torch.device("cuda"))
+                     vae_encoder_model = VAEEncoder(
+                         device=stream.device,
+                         max_batch_size=stream.frame_bff_size,
+                         min_batch_size=stream.frame_bff_size,
+                     )
+                     compile_vae_encoder(
+                         vae_encoder,
+                         vae_encoder_model,
+                         vae_encoder_path + ".onnx",
+                         vae_encoder_path + ".opt.onnx",
+                         vae_encoder_path,
+                         opt_batch_size=stream.frame_bff_size,
+                     )
+
+                 cuda_stream = cuda.Stream()
+
+                 vae_config = stream.vae.config
+                 vae_dtype = stream.vae.dtype
+
+                 stream.unet = UNet2DConditionModelEngine(
+                     unet_path, cuda_stream, use_cuda_graph=False
+                 )
+                 stream.vae = AutoencoderKLEngine(
+                     vae_encoder_path,
+                     vae_decoder_path,
+                     cuda_stream,
+                     stream.pipe.vae_scale_factor,
+                     use_cuda_graph=False,
+                 )
+                 setattr(stream.vae, "config", vae_config)
+                 setattr(stream.vae, "dtype", vae_dtype)
+
+                 gc.collect()
+                 torch.cuda.empty_cache()
+
+                 print("TensorRT acceleration enabled.")
+             if acceleration == "sfast":
+                 if self.use_cached_attn:
+                     raise NotImplementedError
+                 from streamv2v.acceleration.sfast import (
+                     accelerate_with_stable_fast,
+                 )
+
+                 stream = accelerate_with_stable_fast(stream)
+                 print("StableFast acceleration enabled.")
+         except Exception:
+             traceback.print_exc()
+             print("Acceleration has failed. Falling back to normal mode.")
+
+         if seed < 0:  # Random seed
+             seed = np.random.randint(0, 1000000)
+
+         stream.prepare(
+             "",
+             "",
+             num_inference_steps=50,
+             guidance_scale=1.1
+             if stream.cfg_type in ["full", "self", "initialize"]
+             else 1.0,
+             generator=torch.manual_seed(seed),
+             seed=seed,
+         )
+
+         if self.use_safety_checker:
+             from transformers import CLIPFeatureExtractor
+             from diffusers.pipelines.stable_diffusion.safety_checker import (
+                 StableDiffusionSafetyChecker,
+             )
+
+             self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+                 "CompVis/stable-diffusion-safety-checker"
+             ).to(pipe.device)
+             self.feature_extractor = CLIPFeatureExtractor.from_pretrained(
+                 "openai/clip-vit-base-patch32"
+             )
+             self.nsfw_fallback_img = Image.new("RGB", (512, 512), (0, 0, 0))
+
+         return stream
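For reference (not part of this commit), a minimal usage sketch of StreamV2VWrapper on a short frame sequence. The model id, t_index_list values, prompt, and file names below are illustrative assumptions; only the wrapper API shown (constructor, prepare, and calling the wrapper on a frame) comes from utils/wrapper.py above.

# hypothetical usage sketch, not included in the commit
from PIL import Image

from utils.wrapper import StreamV2VWrapper

# Constructor arguments are illustrative; see the docstring above for the full list.
wrapper = StreamV2VWrapper(
    model_id_or_path="runwayml/stable-diffusion-v1-5",  # assumed SD 1.5 checkpoint
    t_index_list=[32, 45],                              # assumed denoising indices
    acceleration="xformers",
    width=512,
    height=512,
)
wrapper.prepare(prompt="claymation style", guidance_scale=1.2)

# Process a few frames; the paths are placeholders for an extracted video sequence.
for i in range(3):
    frame = Image.open(f"frames/{i:04d}.png")
    result = wrapper(frame)  # img2img on one frame; returns a PIL image by default
    result.save(f"outputs/{i:04d}.png")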
vid2vid.py CHANGED
@@ -1,13 +1,6 @@
  import sys
  import os

- sys.path.append(
-     os.path.join(
-         os.path.dirname(__file__),
-         "..",
-     )
- )
-
  from utils.wrapper import StreamV2VWrapper

  import torch