Phr00t committed on
Commit 1faaad3 · verified · 1 Parent(s): 7a703ee

Update Custom-Advanced-VACE-Node/nodes_utility.py


* use empty_frame_level instead of a hardcoded 0.5 for empty control frames
* support end frame easing (see the sketch below)

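A minimal sketch (not part of the commit) of how the eased control blend behaves with these changes; ease_control, has_start and the mirrored end-frame ramp are illustrative assumptions rather than code from the repository:

import torch

def ease_control(control_images: torch.Tensor, empty_frame_level: float, control_ease: int, has_start: bool) -> torch.Tensor:
    # Blend the first (or last) control_ease frames toward a flat frame at empty_frame_level.
    num_frames = control_images.shape[0]
    if control_ease <= 0 or num_frames <= control_ease:
        return control_images
    out = control_images.clone()
    empty = torch.full_like(out[0], empty_frame_level)
    if has_start:
        # right after a start image the control is mostly suppressed, then ramps up to full strength
        for i in range(1, control_ease + 1):
            weight = (control_ease - i) / (1 + control_ease)
            out[i] = torch.lerp(out[i], empty, weight)
    else:
        # mirrored ramp: control fades toward the empty level as the end frame approaches
        for i in range(num_frames - control_ease - 1, num_frames - 1):
            weight = (i - (num_frames - control_ease - 1)) / (1 + control_ease)
            out[i] = torch.lerp(out[i], empty, weight)
    return out

# example: 81 control frames eased out over the last 8 frames before an end image
frames = torch.rand(81, 8, 8, 3)
eased = ease_control(frames, empty_frame_level=0.5, control_ease=8, has_start=False)
print(eased.shape)  # torch.Size([81, 8, 8, 3])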
Custom-Advanced-VACE-Node/nodes_utility.py CHANGED
@@ -1,703 +1,708 @@
1
- import torch
2
- import numpy as np
3
- from comfy.utils import common_upscale
4
- from .utils import log
5
- from einops import rearrange
6
-
7
- try:
8
- from server import PromptServer
9
- except:
10
- PromptServer = None
11
-
12
- VAE_STRIDE = (4, 8, 8)
13
- PATCH_SIZE = (1, 2, 2)
14
-
15
- class WanVideoImageResizeToClosest:
16
- @classmethod
17
- def INPUT_TYPES(s):
18
- return {"required": {
19
- "image": ("IMAGE", {"tooltip": "Image to resize"}),
20
- "generation_width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}),
21
- "generation_height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}),
22
- "aspect_ratio_preservation": (["keep_input", "stretch_to_new", "crop_to_new"],),
23
- },
24
- }
25
-
26
- RETURN_TYPES = ("IMAGE", "INT", "INT", )
27
- RETURN_NAMES = ("image","width","height",)
28
- FUNCTION = "process"
29
- CATEGORY = "WanVideoWrapper"
30
- DESCRIPTION = "Resizes image to the closest supported resolution based on aspect ratio and max pixels, according to the original code"
31
-
32
- def process(self, image, generation_width, generation_height, aspect_ratio_preservation ):
33
-
34
- H, W = image.shape[1], image.shape[2]
35
- max_area = generation_width * generation_height
36
-
37
- crop = "disabled"
38
-
39
- if aspect_ratio_preservation == "keep_input":
40
- aspect_ratio = H / W
41
- elif aspect_ratio_preservation == "stretch_to_new" or aspect_ratio_preservation == "crop_to_new":
42
- aspect_ratio = generation_height / generation_width
43
- if aspect_ratio_preservation == "crop_to_new":
44
- crop = "center"
45
-
46
- lat_h = round(
47
- np.sqrt(max_area * aspect_ratio) // VAE_STRIDE[1] //
48
- PATCH_SIZE[1] * PATCH_SIZE[1])
49
- lat_w = round(
50
- np.sqrt(max_area / aspect_ratio) // VAE_STRIDE[2] //
51
- PATCH_SIZE[2] * PATCH_SIZE[2])
52
- h = lat_h * VAE_STRIDE[1]
53
- w = lat_w * VAE_STRIDE[2]
54
-
55
- resized_image = common_upscale(image.movedim(-1, 1), w, h, "lanczos", crop).movedim(1, -1)
56
-
57
- return (resized_image, w, h)
58
-
59
- class ExtractStartFramesForContinuations:
60
- @classmethod
61
- def INPUT_TYPES(s):
62
- return {
63
- "required": {
64
- "input_video_frames": ("IMAGE", {"tooltip": "Input video frames to extract the start frames from."}),
65
- "num_frames": ("INT", {"default": 10, "min": 1, "max": 1024, "step": 1, "tooltip": "Number of frames to get from the start of the video."}),
66
- },
67
- }
68
-
69
- RETURN_TYPES = ("IMAGE",)
70
- RETURN_NAMES = ("start_frames",)
71
- FUNCTION = "get_start_frames"
72
- CATEGORY = "WanVideoWrapper"
73
- DESCRIPTION = "Extracts the first N frames from a video sequence for continuations."
74
-
75
- def get_start_frames(self, input_video_frames, num_frames):
76
- if input_video_frames is None or input_video_frames.shape[0] == 0:
77
- log.warning("Input video frames are empty. Returning an empty tensor.")
78
- if input_video_frames is not None:
79
- return (torch.empty((0,) + input_video_frames.shape[1:], dtype=input_video_frames.dtype),)
80
- else:
81
- # Return a tensor with 4 dimensions, as expected for an IMAGE type.
82
- return (torch.empty((0, 64, 64, 3), dtype=torch.float32),)
83
-
84
- total_frames = input_video_frames.shape[0]
85
- num_to_get = min(num_frames, total_frames)
86
-
87
- if num_to_get < num_frames:
88
- log.warning(f"Requested {num_frames} frames, but input video only has {total_frames} frames. Returning first {num_to_get} frames.")
89
-
90
- start_frames = input_video_frames[:num_to_get]
91
-
92
- return (start_frames.cpu().float(),)
93
-
94
- class WanVideoVACEStartToEndFrame:
95
- @classmethod
96
- def INPUT_TYPES(s):
97
- return {"required": {
98
- "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}),
99
- "empty_frame_level": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "White level of empty frame to use"}),
100
- },
101
- "optional": {
102
- "start_image": ("IMAGE",),
103
- "end_image": ("IMAGE",),
104
- "control_images": ("IMAGE",),
105
- "inpaint_mask": ("MASK", {"tooltip": "Inpaint mask to use for the empty frames"}),
106
- "start_index": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Index to start from"}),
107
- "end_index": ("INT", {"default": -1, "min": -10000, "max": 10000, "step": 1, "tooltip": "Index to end at"}),
108
- "control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01, "tooltip": "How much does the control images apply?"}),
109
- "control_ease": ("INT", {"default": 8.0, "min": 0.0, "max": 81.0, "step": 1, "tooltip": "How many frames to ease in the control video?"}),
110
- },
111
- }
112
-
113
- RETURN_TYPES = ("IMAGE", "MASK", )
114
- RETURN_NAMES = ("images", "masks",)
115
- FUNCTION = "process"
116
- CATEGORY = "WanVideoWrapper"
117
- DESCRIPTION = "Helper node to create start/end frame batch and masks for VACE"
118
-
119
- def process(self, num_frames, empty_frame_level, start_image=None, end_image=None, control_images=None, inpaint_mask=None, start_index=0, end_index=-1, control_strength=1.0, control_ease=8):
120
-
121
- device = start_image.device if start_image is not None else end_image.device
122
- B, H, W, C = start_image.shape if start_image is not None else end_image.shape
123
-
124
- if control_strength < 1.0 and control_images is not None:
125
- # strength happens at much smaller number
126
- control_strength *= 2.0
127
- control_strength = control_strength * control_strength / 8.0
128
- control_images = torch.lerp(torch.ones((control_images.shape[0], control_images.shape[1], control_images.shape[2], control_images.shape[3])) * 0.5, control_images, control_strength)
129
- # if we have a start image, don't immediately apply control strength.. ease it in
130
- if start_image is not None and num_frames > control_ease and control_ease > 0:
131
- empty_frame = torch.ones((1, control_images.shape[1], control_images.shape[2], control_images.shape[3])) * 0.5
132
- for i in range(1, control_ease + 1):
133
- control_images[i] = torch.lerp(control_images[i], empty_frame, (control_ease - i) / (1 + control_ease))
134
-
135
- if start_image is None and end_image is None and control_images is not None:
136
- if control_images.shape[0] >= num_frames:
137
- control_images = control_images[:num_frames]
138
- elif control_images.shape[0] < num_frames:
139
- # padd with empty_frame_level frames
140
- padding = torch.ones((num_frames - control_images.shape[0], control_images.shape[1], control_images.shape[2], control_images.shape[3]), device=control_images.device) * empty_frame_level
141
- control_images = torch.cat([control_images, padding], dim=0)
142
- return (control_images.cpu().float(), torch.zeros_like(control_images[:, :, :, 0]).cpu().float())
143
-
144
- # Convert negative end_index to positive
145
- if end_index < 0:
146
- end_index = num_frames + end_index
147
-
148
- # Create output batch with empty frames
149
- out_batch = torch.ones((num_frames, H, W, 3), device=device) * empty_frame_level
150
-
151
- # Create mask tensor with proper dimensions
152
- masks = torch.ones((num_frames, H, W), device=device)
153
-
154
- # Pre-process all images at once to avoid redundant work
155
- if end_image is not None and (end_image.shape[1] != H or end_image.shape[2] != W):
156
- end_image = common_upscale(end_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(1, -1)
157
-
158
- if control_images is not None and (control_images.shape[1] != H or control_images.shape[2] != W):
159
- control_images = common_upscale(control_images.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(1, -1)
160
-
161
- # Place start image at start_index
162
- if start_image is not None:
163
- frames_to_copy = min(start_image.shape[0], num_frames - start_index)
164
- if frames_to_copy > 0:
165
- out_batch[start_index:start_index + frames_to_copy] = start_image[:frames_to_copy]
166
- masks[start_index:start_index + frames_to_copy] = 0
167
-
168
- # Place end image at end_index
169
- if end_image is not None:
170
- # Calculate where to start placing end images
171
- end_start = end_index - end_image.shape[0] + 1
172
- if end_start < 0: # Handle case where end images won't all fit
173
- end_image = end_image[abs(end_start):]
174
- end_start = 0
175
-
176
- frames_to_copy = min(end_image.shape[0], num_frames - end_start)
177
- if frames_to_copy > 0:
178
- out_batch[end_start:end_start + frames_to_copy] = end_image[:frames_to_copy]
179
- masks[end_start:end_start + frames_to_copy] = 0
180
-
181
- # Apply control images to remaining frames that don't have start or end images
182
- if control_images is not None:
183
- # Create a mask of frames that are still empty (mask == 1)
184
- empty_frames = masks.sum(dim=(1, 2)) > 0.5 * H * W
185
-
186
- if empty_frames.any():
187
- # Only apply control images where they exist
188
- control_length = control_images.shape[0]
189
- for frame_idx in range(num_frames):
190
- if empty_frames[frame_idx] and frame_idx < control_length:
191
- out_batch[frame_idx] = control_images[frame_idx]
192
-
193
- # Apply inpaint mask if provided
194
- if inpaint_mask is not None:
195
- inpaint_mask = common_upscale(inpaint_mask.unsqueeze(1), W, H, "nearest-exact", "disabled").squeeze(1).to(device)
196
-
197
- # Handle different mask lengths efficiently
198
- if inpaint_mask.shape[0] > num_frames:
199
- inpaint_mask = inpaint_mask[:num_frames]
200
- elif inpaint_mask.shape[0] < num_frames:
201
- repeat_factor = (num_frames + inpaint_mask.shape[0] - 1) // inpaint_mask.shape[0] # Ceiling division
202
- inpaint_mask = inpaint_mask.repeat(repeat_factor, 1, 1)[:num_frames]
203
-
204
- # Apply mask in one operation
205
- masks = inpaint_mask * masks
206
-
207
- return (out_batch.cpu().float(), masks.cpu().float())
208
-
209
-
210
- class CreateCFGScheduleFloatList:
211
- @classmethod
212
- def INPUT_TYPES(s):
213
- return {"required": {
214
- "steps": ("INT", {"default": 30, "min": 2, "max": 1000, "step": 1, "tooltip": "Number of steps to schedule cfg for"} ),
215
- "cfg_scale_start": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 30.0, "step": 0.01, "round": 0.01, "tooltip": "CFG scale to use for the steps"}),
216
- "cfg_scale_end": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 30.0, "step": 0.01, "round": 0.01, "tooltip": "CFG scale to use for the steps"}),
217
- "interpolation": (["linear", "ease_in", "ease_out"], {"default": "linear", "tooltip": "Interpolation method to use for the cfg scale"}),
218
- "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01,"tooltip": "Start percent of the steps to apply cfg"}),
219
- "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01,"tooltip": "End percent of the steps to apply cfg"}),
220
- },
221
- "hidden": {
222
- "unique_id": "UNIQUE_ID",
223
- },
224
- }
225
-
226
- RETURN_TYPES = ("FLOAT", )
227
- RETURN_NAMES = ("float_list",)
228
- FUNCTION = "process"
229
- CATEGORY = "WanVideoWrapper"
230
- DESCRIPTION = "Helper node to generate a list of floats that can be used to schedule cfg scale for the steps, outside the set range cfg is set to 1.0"
231
-
232
- def process(self, steps, cfg_scale_start, cfg_scale_end, interpolation, start_percent, end_percent, unique_id):
233
-
234
- # Create a list of floats for the cfg schedule
235
- cfg_list = [1.0] * steps
236
- start_idx = min(int(steps * start_percent), steps - 1)
237
- end_idx = min(int(steps * end_percent), steps - 1)
238
-
239
- for i in range(start_idx, end_idx + 1):
240
- if i >= steps:
241
- break
242
-
243
- if end_idx == start_idx:
244
- t = 0
245
- else:
246
- t = (i - start_idx) / (end_idx - start_idx)
247
-
248
- if interpolation == "linear":
249
- factor = t
250
- elif interpolation == "ease_in":
251
- factor = t * t
252
- elif interpolation == "ease_out":
253
- factor = t * (2 - t)
254
-
255
- cfg_list[i] = round(cfg_scale_start + factor * (cfg_scale_end - cfg_scale_start), 2)
256
-
257
- # If start_percent > 0, always include the first step
258
- if start_percent > 0:
259
- cfg_list[0] = 1.0
260
-
261
- if unique_id and PromptServer is not None:
262
- try:
263
- PromptServer.instance.send_progress_text(
264
- f"{cfg_list}",
265
- unique_id
266
- )
267
- except:
268
- pass
269
-
270
- return (cfg_list,)
271
-
272
- class CreateScheduleFloatList:
273
- @classmethod
274
- def INPUT_TYPES(s):
275
- return {"required": {
276
- "steps": ("INT", {"default": 30, "min": 2, "max": 1000, "step": 1, "tooltip": "Number of steps to schedule cfg for"} ),
277
- "start_value": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01, "tooltip": "CFG scale to use for the steps"}),
278
- "end_value": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01, "tooltip": "CFG scale to use for the steps"}),
279
- "default_value": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, "round": 0.01, "tooltip": "Default value to use for the steps"}),
280
- "interpolation": (["linear", "ease_in", "ease_out"], {"default": "linear", "tooltip": "Interpolation method to use for the cfg scale"}),
281
- "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01,"tooltip": "Start percent of the steps to apply cfg"}),
282
- "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01,"tooltip": "End percent of the steps to apply cfg"}),
283
- },
284
- "hidden": {
285
- "unique_id": "UNIQUE_ID",
286
- },
287
- }
288
-
289
- RETURN_TYPES = ("FLOAT", )
290
- RETURN_NAMES = ("float_list",)
291
- FUNCTION = "process"
292
- CATEGORY = "WanVideoWrapper"
293
- DESCRIPTION = "Helper node to generate a list of floats that can be used to schedule things like cfg and lora scale per step"
294
-
295
- def process(self, steps, start_value, end_value, default_value,interpolation, start_percent, end_percent, unique_id):
296
-
297
- # Create a list of floats for the cfg schedule
298
- cfg_list = [default_value] * steps
299
- start_idx = min(int(steps * start_percent), steps - 1)
300
- end_idx = min(int(steps * end_percent), steps - 1)
301
-
302
- for i in range(start_idx, end_idx + 1):
303
- if i >= steps:
304
- break
305
-
306
- if end_idx == start_idx:
307
- t = 0
308
- else:
309
- t = (i - start_idx) / (end_idx - start_idx)
310
-
311
- if interpolation == "linear":
312
- factor = t
313
- elif interpolation == "ease_in":
314
- factor = t * t
315
- elif interpolation == "ease_out":
316
- factor = t * (2 - t)
317
-
318
- cfg_list[i] = round(start_value + factor * (end_value - start_value), 2)
319
-
320
- # If start_percent > 0, always include the first step
321
- if start_percent > 0:
322
- cfg_list[0] = default_value
323
-
324
- if unique_id and PromptServer is not None:
325
- try:
326
- PromptServer.instance.send_progress_text(
327
- f"{cfg_list}",
328
- unique_id
329
- )
330
- except:
331
- pass
332
-
333
- return (cfg_list,)
334
-
335
-
336
- class DummyComfyWanModelObject:
337
- @classmethod
338
- def INPUT_TYPES(s):
339
- return {"required": {
340
- "shift": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "Sigma shift value"}),
341
- }
342
- }
343
-
344
- RETURN_TYPES = ("MODEL", )
345
- RETURN_NAMES = ("model",)
346
- FUNCTION = "create"
347
- CATEGORY = "WanVideoWrapper"
348
- DESCRIPTION = "Helper node to create empty Wan model to use with BasicScheduler -node to get sigmas"
349
-
350
- def create(self, shift):
351
- from comfy.model_sampling import ModelSamplingDiscreteFlow
352
- class DummyModel:
353
- def get_model_object(self, name):
354
- if name == "model_sampling":
355
- model_sampling = ModelSamplingDiscreteFlow()
356
- model_sampling.set_parameters(shift=shift)
357
- return model_sampling
358
- return None
359
- return (DummyModel(),)
360
-
361
- class WanVideoLatentReScale:
362
- @classmethod
363
- def INPUT_TYPES(s):
364
- return {"required": {
365
- "samples": ("LATENT",),
366
- "direction": (["comfy_to_wrapper", "wrapper_to_comfy"], {"tooltip": "Direction to rescale latents, from comfy to wrapper or vice versa"}),
367
- }
368
- }
369
-
370
- RETURN_TYPES = ("LATENT",)
371
- RETURN_NAMES = ("samples",)
372
- FUNCTION = "encode"
373
- CATEGORY = "WanVideoWrapper"
374
- DESCRIPTION = "Rescale latents to match the expected range for encoding or decoding between native ComfyUI VAE and the WanVideoWrapper VAE."
375
-
376
- def encode(self, samples, direction):
377
- samples = samples.copy()
378
- latents = samples["samples"]
379
-
380
- if latents.shape[1] == 48:
381
- mean = [
382
- -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
383
- -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
384
- -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
385
- -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
386
- -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
387
- 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
388
- ]
389
- std = [
390
- 0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
391
- 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
392
- 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
393
- 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
394
- 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
395
- 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
396
- ]
397
- else:
398
- mean = [
399
- -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
400
- 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
401
- ]
402
- std = [
403
- 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
404
- 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
405
- ]
406
- mean = torch.tensor(mean).view(1, latents.shape[1], 1, 1, 1)
407
- std = torch.tensor(std).view(1, latents.shape[1], 1, 1, 1)
408
- inv_std = (1.0 / std).view(1, latents.shape[1], 1, 1, 1)
409
- if direction == "comfy_to_wrapper":
410
- latents = (latents - mean.to(latents)) * inv_std.to(latents)
411
- elif direction == "wrapper_to_comfy":
412
- latents = latents / inv_std.to(latents) + mean.to(latents)
413
-
414
- samples["samples"] = latents
415
-
416
- return (samples,)
417
-
418
- class WanVideoSigmaToStep:
419
- @classmethod
420
- def INPUT_TYPES(s):
421
- return {"required": {
422
- "sigma": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.001}),
423
- },
424
- }
425
-
426
- RETURN_TYPES = ("INT", )
427
- RETURN_NAMES = ("step",)
428
- FUNCTION = "convert"
429
- CATEGORY = "WanVideoWrapper"
430
- DESCRIPTION = "Simply passes a float value as an integer, used to set start/end steps with sigma threshold"
431
-
432
- def convert(self, sigma):
433
- return (sigma,)
434
-
435
- class NormalizeAudioLoudness:
436
- @classmethod
437
- def INPUT_TYPES(s):
438
- return {"required": {
439
- "audio": ("AUDIO",),
440
- "lufs": ("FLOAT", {"default": -23.0, "min": -100.0, "max": 0.0, "step": 0.1, "tool": "Loudness Units relative to Full Scale, higher LUFS values (closer to 0) mean louder audio. Lower LUFS values (more negative) mean quieter audio."}),
441
- },
442
- }
443
-
444
- RETURN_TYPES = ("AUDIO", )
445
- RETURN_NAMES = ("audio", )
446
- FUNCTION = "normalize"
447
- CATEGORY = "WanVideoWrapper"
448
-
449
- def normalize(self, audio, lufs):
450
- audio_input = audio["waveform"]
451
- sample_rate = audio["sample_rate"]
452
- if audio_input.dim() == 3:
453
- audio_input = audio_input.squeeze(0)
454
- audio_input_np = audio_input.detach().transpose(0, 1).numpy().astype(np.float32)
455
- audio_input_np = np.ascontiguousarray(audio_input_np)
456
- normalized_audio = self.loudness_norm(audio_input_np, sr=sample_rate, lufs=lufs)
457
-
458
- out_audio = {"waveform": torch.from_numpy(normalized_audio).transpose(0, 1).unsqueeze(0).float(), "sample_rate": sample_rate}
459
-
460
- return (out_audio, )
461
-
462
- def loudness_norm(self, audio_array, sr=16000, lufs=-23):
463
- try:
464
- import pyloudnorm
465
- except:
466
- raise ImportError("pyloudnorm package is not installed")
467
- meter = pyloudnorm.Meter(sr)
468
- loudness = meter.integrated_loudness(audio_array)
469
- if abs(loudness) > 100:
470
- return audio_array
471
- normalized_audio = pyloudnorm.normalize.loudness(audio_array, loudness, lufs)
472
- return normalized_audio
473
-
474
- class WanVideoPassImagesFromSamples:
475
- @classmethod
476
- def INPUT_TYPES(s):
477
- return {"required": {
478
- "samples": ("LATENT",),
479
- }
480
- }
481
-
482
- RETURN_TYPES = ("IMAGE", "STRING",)
483
- RETURN_NAMES = ("images", "output_path",)
484
- OUTPUT_TOOLTIPS = ("Decoded images from the samples dictionary", "Output path if provided in the samples dictionary",)
485
- FUNCTION = "decode"
486
- CATEGORY = "WanVideoWrapper"
487
- DESCRIPTION = "Gets possible already decoded images from the samples dictionary, used with Multi/InfiniteTalk sampling"
488
-
489
- def decode(self, samples):
490
- video = samples.get("video", None)
491
- video.clamp_(-1.0, 1.0)
492
- video.add_(1.0).div_(2.0)
493
- return video.cpu().float(), samples.get("output_path", "")
494
-
495
-
496
- class FaceMaskFromPoseKeypoints:
497
- @classmethod
498
- def INPUT_TYPES(s):
499
- input_types = {
500
- "required": {
501
- "pose_kps": ("POSE_KEYPOINT",),
502
- "person_index": ("INT", {"default": 0, "min": 0, "max": 100, "step": 1, "tooltip": "Index of the person to start with"}),
503
- }
504
- }
505
- return input_types
506
- RETURN_TYPES = ("MASK",)
507
- FUNCTION = "createmask"
508
- CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"
509
-
510
- def createmask(self, pose_kps, person_index):
511
- pose_frames = pose_kps
512
- prev_center = None
513
- np_frames = []
514
- for i, pose_frame in enumerate(pose_frames):
515
- selected_idx, prev_center = self.select_closest_person(pose_frame, person_index if i == 0 else prev_center)
516
- np_frames.append(self.draw_kps(pose_frame, selected_idx))
517
-
518
- if not np_frames:
519
- # Handle case where no frames were processed
520
- log.warning("No valid pose frames found, returning empty mask")
521
- return (torch.zeros((1, 64, 64), dtype=torch.float32),)
522
-
523
- np_frames = np.stack(np_frames, axis=0)
524
- tensor = torch.from_numpy(np_frames).float() / 255.
525
- print("tensor.shape:", tensor.shape)
526
- tensor = tensor[:, :, :, 0]
527
- return (tensor,)
528
-
529
- def select_closest_person(self, pose_frame, prev_center_or_index):
530
- people = pose_frame["people"]
531
- if not people:
532
- return -1, None
533
-
534
- centers = []
535
- valid_people_indices = []
536
-
537
- for idx, person in enumerate(people):
538
- # Check if face keypoints exist and are valid
539
- if "face_keypoints_2d" not in person or not person["face_keypoints_2d"]:
540
- continue
541
-
542
- kps = np.array(person["face_keypoints_2d"])
543
- if len(kps) == 0:
544
- continue
545
-
546
- n = len(kps) // 3
547
- if n == 0:
548
- continue
549
-
550
- facial_kps = rearrange(kps, "(n c) -> n c", n=n, c=3)[:, :2]
551
-
552
- # Check if we have valid coordinates (not all zeros)
553
- if np.all(facial_kps == 0):
554
- continue
555
-
556
- center = facial_kps.mean(axis=0)
557
-
558
- # Check if center is valid (not NaN or infinite)
559
- if np.isnan(center).any() or np.isinf(center).any():
560
- continue
561
-
562
- centers.append(center)
563
- valid_people_indices.append(idx)
564
-
565
- if not centers:
566
- return -1, None
567
-
568
- if isinstance(prev_center_or_index, (int, np.integer)):
569
- # First frame: use person_index, but map to valid people
570
- if 0 <= prev_center_or_index < len(valid_people_indices):
571
- idx = valid_people_indices[prev_center_or_index]
572
- return idx, centers[prev_center_or_index]
573
- elif valid_people_indices:
574
- # Fallback to first valid person
575
- idx = valid_people_indices[0]
576
- return idx, centers[0]
577
- else:
578
- return -1, None
579
- elif prev_center_or_index is not None:
580
- # Find closest to previous center
581
- prev_center = np.array(prev_center_or_index)
582
- dists = [np.linalg.norm(center - prev_center) for center in centers]
583
- min_idx = int(np.argmin(dists))
584
- actual_idx = valid_people_indices[min_idx]
585
- return actual_idx, centers[min_idx]
586
- else:
587
- # prev_center_or_index is None, fallback to first valid person
588
- if valid_people_indices:
589
- idx = valid_people_indices[0]
590
- return idx, centers[0]
591
- else:
592
- return -1, None
593
-
594
- def draw_kps(self, pose_frame, person_index):
595
- import cv2
596
- width, height = pose_frame["canvas_width"], pose_frame["canvas_height"]
597
- canvas = np.zeros((height, width, 3), dtype=np.uint8)
598
- people = pose_frame["people"]
599
-
600
- if person_index < 0 or person_index >= len(people):
601
- return canvas # Out of bounds, return blank
602
-
603
- person = people[person_index]
604
-
605
- # Check if face keypoints exist and are valid
606
- if "face_keypoints_2d" not in person or not person["face_keypoints_2d"]:
607
- return canvas # No face keypoints, return blank
608
-
609
- face_kps_data = person["face_keypoints_2d"]
610
- if len(face_kps_data) == 0:
611
- return canvas # Empty keypoints, return blank
612
-
613
- n = len(face_kps_data) // 3
614
- if n < 17: # Need at least 17 points for outer contour
615
- return canvas # Not enough keypoints, return blank
616
-
617
- facial_kps = rearrange(np.array(face_kps_data), "(n c) -> n c", n=n, c=3)[:, :2]
618
-
619
- # Check if we have valid coordinates (not all zeros)
620
- if np.all(facial_kps == 0):
621
- return canvas # All keypoints are zero, return blank
622
-
623
- # Check for NaN or infinite values
624
- if np.isnan(facial_kps).any() or np.isinf(facial_kps).any():
625
- return canvas # Invalid coordinates, return blank
626
-
627
- # Check for negative coordinates or coordinates that would create streaks
628
- if np.any(facial_kps < 0):
629
- return canvas # Negative coordinates, likely bad detection
630
-
631
- # Check if coordinates are reasonable (not too close to edges which might indicate bad detection)
632
- min_margin = 5 # Minimum distance from edges
633
- if (np.any(facial_kps[:, 0] < min_margin) or
634
- np.any(facial_kps[:, 1] < min_margin) or
635
- np.any(facial_kps[:, 0] > width - min_margin) or
636
- np.any(facial_kps[:, 1] > height - min_margin)):
637
- # Check if this looks like a streak to corner (many points near 0,0)
638
- corner_points = np.sum((facial_kps[:, 0] < min_margin) & (facial_kps[:, 1] < min_margin))
639
- if corner_points > 3: # Too many points near corner, likely bad detection
640
- return canvas
641
-
642
- facial_kps = facial_kps.astype(np.int32)
643
-
644
- # Ensure coordinates are within canvas bounds
645
- facial_kps[:, 0] = np.clip(facial_kps[:, 0], 0, width - 1)
646
- facial_kps[:, 1] = np.clip(facial_kps[:, 1], 0, height - 1)
647
-
648
- part_color = (255, 255, 255)
649
- outer_contour = facial_kps[:17]
650
-
651
- # Additional validation for the contour before drawing
652
- # Check if contour points are too spread out (indicating bad detection)
653
- if len(outer_contour) >= 3:
654
- # Calculate bounding box of the contour
655
- min_x, min_y = np.min(outer_contour, axis=0)
656
- max_x, max_y = np.max(outer_contour, axis=0)
657
- contour_width = max_x - min_x
658
- contour_height = max_y - min_y
659
-
660
- # If contour spans more than 80% of canvas, likely bad detection
661
- if (contour_width > 0.8 * width or contour_height > 0.8 * height):
662
- return canvas
663
-
664
- # Check if we have a valid contour (at least 3 unique points)
665
- unique_points = np.unique(outer_contour, axis=0)
666
- if len(unique_points) >= 3:
667
- # Final check: ensure the contour is reasonable
668
- # Calculate area to see if it's too large or too small
669
- contour_area = cv2.contourArea(outer_contour)
670
- canvas_area = width * height
671
-
672
- # If contour is less than 0.1% or more than 50% of canvas, skip
673
- if 0.001 * canvas_area <= contour_area <= 0.5 * canvas_area:
674
- cv2.fillPoly(canvas, pts=[outer_contour], color=part_color)
675
-
676
- return canvas
677
-
678
- NODE_CLASS_MAPPINGS = {
679
- "WanVideoImageResizeToClosest": WanVideoImageResizeToClosest,
680
- "WanVideoVACEStartToEndFrame": WanVideoVACEStartToEndFrame,
681
- "ExtractStartFramesForContinuations": ExtractStartFramesForContinuations,
682
- "CreateCFGScheduleFloatList": CreateCFGScheduleFloatList,
683
- "DummyComfyWanModelObject": DummyComfyWanModelObject,
684
- "WanVideoLatentReScale": WanVideoLatentReScale,
685
- "CreateScheduleFloatList": CreateScheduleFloatList,
686
- "WanVideoSigmaToStep": WanVideoSigmaToStep,
687
- "NormalizeAudioLoudness": NormalizeAudioLoudness,
688
- "WanVideoPassImagesFromSamples": WanVideoPassImagesFromSamples,
689
- "FaceMaskFromPoseKeypoints": FaceMaskFromPoseKeypoints,
690
- }
691
- NODE_DISPLAY_NAME_MAPPINGS = {
692
- "WanVideoImageResizeToClosest": "WanVideo Image Resize To Closest",
693
- "WanVideoVACEStartToEndFrame": "WanVideo VACE Start To End Frame",
694
- "ExtractStartFramesForContinuations": "Extract Start Frames For Continuations",
695
- "CreateCFGScheduleFloatList": "Create CFG Schedule Float List",
696
- "DummyComfyWanModelObject": "Dummy Comfy Wan Model Object",
697
- "WanVideoLatentReScale": "WanVideo Latent ReScale",
698
- "CreateScheduleFloatList": "Create Schedule Float List",
699
- "WanVideoSigmaToStep": "WanVideo Sigma To Step",
700
- "NormalizeAudioLoudness": "Normalize Audio Loudness",
701
- "WanVideoPassImagesFromSamples": "WanVideo Pass Images From Samples",
702
- "FaceMaskFromPoseKeypoints": "Face Mask From Pose Keypoints",
 
 
 
 
 
703
  }
 
1
+ import torch
2
+ import numpy as np
3
+ from comfy.utils import common_upscale
4
+ from .utils import log
5
+ from einops import rearrange
6
+
7
+ try:
8
+ from server import PromptServer
9
+ except:
10
+ PromptServer = None
11
+
12
+ VAE_STRIDE = (4, 8, 8)
13
+ PATCH_SIZE = (1, 2, 2)
14
+
15
+ class WanVideoImageResizeToClosest:
16
+ @classmethod
17
+ def INPUT_TYPES(s):
18
+ return {"required": {
19
+ "image": ("IMAGE", {"tooltip": "Image to resize"}),
20
+ "generation_width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}),
21
+ "generation_height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}),
22
+ "aspect_ratio_preservation": (["keep_input", "stretch_to_new", "crop_to_new"],),
23
+ },
24
+ }
25
+
26
+ RETURN_TYPES = ("IMAGE", "INT", "INT", )
27
+ RETURN_NAMES = ("image","width","height",)
28
+ FUNCTION = "process"
29
+ CATEGORY = "WanVideoWrapper"
30
+ DESCRIPTION = "Resizes image to the closest supported resolution based on aspect ratio and max pixels, according to the original code"
31
+
32
+ def process(self, image, generation_width, generation_height, aspect_ratio_preservation ):
33
+
34
+ H, W = image.shape[1], image.shape[2]
35
+ max_area = generation_width * generation_height
36
+
37
+ crop = "disabled"
38
+
39
+ if aspect_ratio_preservation == "keep_input":
40
+ aspect_ratio = H / W
41
+ elif aspect_ratio_preservation == "stretch_to_new" or aspect_ratio_preservation == "crop_to_new":
42
+ aspect_ratio = generation_height / generation_width
43
+ if aspect_ratio_preservation == "crop_to_new":
44
+ crop = "center"
45
+
46
+ lat_h = round(
47
+ np.sqrt(max_area * aspect_ratio) // VAE_STRIDE[1] //
48
+ PATCH_SIZE[1] * PATCH_SIZE[1])
49
+ lat_w = round(
50
+ np.sqrt(max_area / aspect_ratio) // VAE_STRIDE[2] //
51
+ PATCH_SIZE[2] * PATCH_SIZE[2])
52
+ h = lat_h * VAE_STRIDE[1]
53
+ w = lat_w * VAE_STRIDE[2]
54
+
55
+ resized_image = common_upscale(image.movedim(-1, 1), w, h, "lanczos", crop).movedim(1, -1)
56
+
57
+ return (resized_image, w, h)
58
+
59
+ class ExtractStartFramesForContinuations:
60
+ @classmethod
61
+ def INPUT_TYPES(s):
62
+ return {
63
+ "required": {
64
+ "input_video_frames": ("IMAGE", {"tooltip": "Input video frames to extract the start frames from."}),
65
+ "num_frames": ("INT", {"default": 10, "min": 1, "max": 1024, "step": 1, "tooltip": "Number of frames to get from the start of the video."}),
66
+ },
67
+ }
68
+
69
+ RETURN_TYPES = ("IMAGE",)
70
+ RETURN_NAMES = ("start_frames",)
71
+ FUNCTION = "get_start_frames"
72
+ CATEGORY = "WanVideoWrapper"
73
+ DESCRIPTION = "Extracts the first N frames from a video sequence for continuations."
74
+
75
+ def get_start_frames(self, input_video_frames, num_frames):
76
+ if input_video_frames is None or input_video_frames.shape[0] == 0:
77
+ log.warning("Input video frames are empty. Returning an empty tensor.")
78
+ if input_video_frames is not None:
79
+ return (torch.empty((0,) + input_video_frames.shape[1:], dtype=input_video_frames.dtype),)
80
+ else:
81
+ # Return a tensor with 4 dimensions, as expected for an IMAGE type.
82
+ return (torch.empty((0, 64, 64, 3), dtype=torch.float32),)
83
+
84
+ total_frames = input_video_frames.shape[0]
85
+ num_to_get = min(num_frames, total_frames)
86
+
87
+ if num_to_get < num_frames:
88
+ log.warning(f"Requested {num_frames} frames, but input video only has {total_frames} frames. Returning first {num_to_get} frames.")
89
+
90
+ start_frames = input_video_frames[:num_to_get]
91
+
92
+ return (start_frames.cpu().float(),)
93
+
94
+ class WanVideoVACEStartToEndFrame:
95
+ @classmethod
96
+ def INPUT_TYPES(s):
97
+ return {"required": {
98
+ "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}),
99
+ "empty_frame_level": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "White level of empty frame to use"}),
100
+ },
101
+ "optional": {
102
+ "start_image": ("IMAGE",),
103
+ "end_image": ("IMAGE",),
104
+ "control_images": ("IMAGE",),
105
+ "inpaint_mask": ("MASK", {"tooltip": "Inpaint mask to use for the empty frames"}),
106
+ "start_index": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Index to start from"}),
107
+ "end_index": ("INT", {"default": -1, "min": -10000, "max": 10000, "step": 1, "tooltip": "Index to end at"}),
108
+ "control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01, "tooltip": "How much does the control images apply?"}),
109
+ "control_ease": ("INT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 1, "tooltip": "How many frames to ease in the control video?"}),
110
+ },
111
+ }
112
+
113
+ RETURN_TYPES = ("IMAGE", "MASK", )
114
+ RETURN_NAMES = ("images", "masks",)
115
+ FUNCTION = "process"
116
+ CATEGORY = "WanVideoWrapper"
117
+ DESCRIPTION = "Helper node to create start/end frame batch and masks for VACE"
118
+
119
+ def process(self, num_frames, empty_frame_level, start_image=None, end_image=None, control_images=None, inpaint_mask=None, start_index=0, end_index=-1, control_strength=1.0, control_ease=0):
120
+
121
+ device = start_image.device if start_image is not None else end_image.device
122
+ B, H, W, C = start_image.shape if start_image is not None else end_image.shape
123
+
124
+ if control_strength < 1.0 and control_images is not None:
125
+ # perceived control strength saturates quickly, so remap the 0-1 slider onto a smaller nonlinear range
126
+ control_strength *= 2.0
127
+ control_strength = control_strength * control_strength / 8.0
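+ # remap: a slider value s < 1 becomes (2s)^2 / 8, a quadratic blend weight that stays below 0.5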
128
+ control_images = torch.lerp(torch.full_like(control_images, empty_frame_level), control_images, control_strength)
129
+
130
+ # optionally ease the control video in after a start image, or out toward the end frame
131
+ if control_images is not None and num_frames > control_ease and control_ease > 0:
132
+ empty_frame = torch.ones((1, control_images.shape[1], control_images.shape[2], control_images.shape[3])) * empty_frame_level
133
+ if start_image is not None:
134
+ for i in range(1, control_ease + 1):
135
+ control_images[i] = torch.lerp(control_images[i], empty_frame, (control_ease - i) / (1 + control_ease))
136
+ else:
137
+ for i in range(num_frames - control_ease - 1, num_frames - 1):
138
+ # ramp the blend toward the empty frame as the end frame approaches (mirrors the start easing)
+ control_images[i] = torch.lerp(control_images[i], empty_frame, (i - (num_frames - control_ease - 1)) / (1 + control_ease))
139
+
140
+ if start_image is None and end_image is None and control_images is not None:
141
+ if control_images.shape[0] >= num_frames:
142
+ control_images = control_images[:num_frames]
143
+ elif control_images.shape[0] < num_frames:
144
+ # pad with empty_frame_level frames
145
+ padding = torch.ones((num_frames - control_images.shape[0], control_images.shape[1], control_images.shape[2], control_images.shape[3]), device=control_images.device) * empty_frame_level
146
+ control_images = torch.cat([control_images, padding], dim=0)
147
+ return (control_images.cpu().float(), torch.zeros_like(control_images[:, :, :, 0]).cpu().float())
148
+
149
+ # Convert negative end_index to positive
150
+ if end_index < 0:
151
+ end_index = num_frames + end_index
152
+
153
+ # Create output batch with empty frames
154
+ out_batch = torch.ones((num_frames, H, W, 3), device=device) * empty_frame_level
155
+
156
+ # Create mask tensor with proper dimensions
157
+ masks = torch.ones((num_frames, H, W), device=device)
158
+
159
+ # Pre-process all images at once to avoid redundant work
160
+ if end_image is not None and (end_image.shape[1] != H or end_image.shape[2] != W):
161
+ end_image = common_upscale(end_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(1, -1)
162
+
163
+ if control_images is not None and (control_images.shape[1] != H or control_images.shape[2] != W):
164
+ control_images = common_upscale(control_images.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(1, -1)
165
+
166
+ # Place start image at start_index
167
+ if start_image is not None:
168
+ frames_to_copy = min(start_image.shape[0], num_frames - start_index)
169
+ if frames_to_copy > 0:
170
+ out_batch[start_index:start_index + frames_to_copy] = start_image[:frames_to_copy]
171
+ masks[start_index:start_index + frames_to_copy] = 0
172
+
173
+ # Place end image at end_index
174
+ if end_image is not None:
175
+ # Calculate where to start placing end images
176
+ end_start = end_index - end_image.shape[0] + 1
177
+ if end_start < 0: # Handle case where end images won't all fit
178
+ end_image = end_image[abs(end_start):]
179
+ end_start = 0
180
+
181
+ frames_to_copy = min(end_image.shape[0], num_frames - end_start)
182
+ if frames_to_copy > 0:
183
+ out_batch[end_start:end_start + frames_to_copy] = end_image[:frames_to_copy]
184
+ masks[end_start:end_start + frames_to_copy] = 0
185
+
186
+ # Apply control images to remaining frames that don't have start or end images
187
+ if control_images is not None:
188
+ # Create a mask of frames that are still empty (mask == 1)
189
+ empty_frames = masks.sum(dim=(1, 2)) > 0.5 * H * W
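+ # a frame counts as empty when more than half of its mask pixels are still 1, i.e. not covered by a start or end image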
190
+
191
+ if empty_frames.any():
192
+ # Only apply control images where they exist
193
+ control_length = control_images.shape[0]
194
+ for frame_idx in range(num_frames):
195
+ if empty_frames[frame_idx] and frame_idx < control_length:
196
+ out_batch[frame_idx] = control_images[frame_idx]
197
+
198
+ # Apply inpaint mask if provided
199
+ if inpaint_mask is not None:
200
+ inpaint_mask = common_upscale(inpaint_mask.unsqueeze(1), W, H, "nearest-exact", "disabled").squeeze(1).to(device)
201
+
202
+ # Handle different mask lengths efficiently
203
+ if inpaint_mask.shape[0] > num_frames:
204
+ inpaint_mask = inpaint_mask[:num_frames]
205
+ elif inpaint_mask.shape[0] < num_frames:
206
+ repeat_factor = (num_frames + inpaint_mask.shape[0] - 1) // inpaint_mask.shape[0] # Ceiling division
207
+ inpaint_mask = inpaint_mask.repeat(repeat_factor, 1, 1)[:num_frames]
208
+
209
+ # Apply mask in one operation
210
+ masks = inpaint_mask * masks
211
+
212
+ return (out_batch.cpu().float(), masks.cpu().float())
213
+
214
+
215
+ class CreateCFGScheduleFloatList:
216
+ @classmethod
217
+ def INPUT_TYPES(s):
218
+ return {"required": {
219
+ "steps": ("INT", {"default": 30, "min": 2, "max": 1000, "step": 1, "tooltip": "Number of steps to schedule cfg for"} ),
220
+ "cfg_scale_start": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 30.0, "step": 0.01, "round": 0.01, "tooltip": "CFG scale to use for the steps"}),
221
+ "cfg_scale_end": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 30.0, "step": 0.01, "round": 0.01, "tooltip": "CFG scale to use for the steps"}),
222
+ "interpolation": (["linear", "ease_in", "ease_out"], {"default": "linear", "tooltip": "Interpolation method to use for the cfg scale"}),
223
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01,"tooltip": "Start percent of the steps to apply cfg"}),
224
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01,"tooltip": "End percent of the steps to apply cfg"}),
225
+ },
226
+ "hidden": {
227
+ "unique_id": "UNIQUE_ID",
228
+ },
229
+ }
230
+
231
+ RETURN_TYPES = ("FLOAT", )
232
+ RETURN_NAMES = ("float_list",)
233
+ FUNCTION = "process"
234
+ CATEGORY = "WanVideoWrapper"
235
+ DESCRIPTION = "Helper node to generate a list of floats that can be used to schedule cfg scale for the steps, outside the set range cfg is set to 1.0"
236
+
237
+ def process(self, steps, cfg_scale_start, cfg_scale_end, interpolation, start_percent, end_percent, unique_id):
238
+
239
+ # Create a list of floats for the cfg schedule
240
+ cfg_list = [1.0] * steps
241
+ start_idx = min(int(steps * start_percent), steps - 1)
242
+ end_idx = min(int(steps * end_percent), steps - 1)
243
+
244
+ for i in range(start_idx, end_idx + 1):
245
+ if i >= steps:
246
+ break
247
+
248
+ if end_idx == start_idx:
249
+ t = 0
250
+ else:
251
+ t = (i - start_idx) / (end_idx - start_idx)
252
+
253
+ if interpolation == "linear":
254
+ factor = t
255
+ elif interpolation == "ease_in":
256
+ factor = t * t
257
+ elif interpolation == "ease_out":
258
+ factor = t * (2 - t)
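+ # quadratic ease-out: the value changes fastest at the start of the range and levels off toward the end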
259
+
260
+ cfg_list[i] = round(cfg_scale_start + factor * (cfg_scale_end - cfg_scale_start), 2)
261
+
262
+ # If start_percent > 0, always include the first step
263
+ if start_percent > 0:
264
+ cfg_list[0] = 1.0
265
+
266
+ if unique_id and PromptServer is not None:
267
+ try:
268
+ PromptServer.instance.send_progress_text(
269
+ f"{cfg_list}",
270
+ unique_id
271
+ )
272
+ except:
273
+ pass
274
+
275
+ return (cfg_list,)
276
+
277
+ class CreateScheduleFloatList:
278
+ @classmethod
279
+ def INPUT_TYPES(s):
280
+ return {"required": {
281
+ "steps": ("INT", {"default": 30, "min": 2, "max": 1000, "step": 1, "tooltip": "Number of steps to schedule cfg for"} ),
282
+ "start_value": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01, "tooltip": "CFG scale to use for the steps"}),
283
+ "end_value": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01, "tooltip": "CFG scale to use for the steps"}),
284
+ "default_value": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, "round": 0.01, "tooltip": "Default value to use for the steps"}),
285
+ "interpolation": (["linear", "ease_in", "ease_out"], {"default": "linear", "tooltip": "Interpolation method to use for the cfg scale"}),
286
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01,"tooltip": "Start percent of the steps to apply cfg"}),
287
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01,"tooltip": "End percent of the steps to apply cfg"}),
288
+ },
289
+ "hidden": {
290
+ "unique_id": "UNIQUE_ID",
291
+ },
292
+ }
293
+
294
+ RETURN_TYPES = ("FLOAT", )
295
+ RETURN_NAMES = ("float_list",)
296
+ FUNCTION = "process"
297
+ CATEGORY = "WanVideoWrapper"
298
+ DESCRIPTION = "Helper node to generate a list of floats that can be used to schedule things like cfg and lora scale per step"
299
+
300
+ def process(self, steps, start_value, end_value, default_value,interpolation, start_percent, end_percent, unique_id):
301
+
302
+ # Create a list of floats for the cfg schedule
303
+ cfg_list = [default_value] * steps
304
+ start_idx = min(int(steps * start_percent), steps - 1)
305
+ end_idx = min(int(steps * end_percent), steps - 1)
306
+
307
+ for i in range(start_idx, end_idx + 1):
308
+ if i >= steps:
309
+ break
310
+
311
+ if end_idx == start_idx:
312
+ t = 0
313
+ else:
314
+ t = (i - start_idx) / (end_idx - start_idx)
315
+
316
+ if interpolation == "linear":
317
+ factor = t
318
+ elif interpolation == "ease_in":
319
+ factor = t * t
320
+ elif interpolation == "ease_out":
321
+ factor = t * (2 - t)
322
+
323
+ cfg_list[i] = round(start_value + factor * (end_value - start_value), 2)
324
+
325
+ # If start_percent > 0, always include the first step
326
+ if start_percent > 0:
327
+ cfg_list[0] = default_value
328
+
329
+ if unique_id and PromptServer is not None:
330
+ try:
331
+ PromptServer.instance.send_progress_text(
332
+ f"{cfg_list}",
333
+ unique_id
334
+ )
335
+ except:
336
+ pass
337
+
338
+ return (cfg_list,)
339
+
340
+
341
+ class DummyComfyWanModelObject:
342
+ @classmethod
343
+ def INPUT_TYPES(s):
344
+ return {"required": {
345
+ "shift": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "Sigma shift value"}),
346
+ }
347
+ }
348
+
349
+ RETURN_TYPES = ("MODEL", )
350
+ RETURN_NAMES = ("model",)
351
+ FUNCTION = "create"
352
+ CATEGORY = "WanVideoWrapper"
353
+ DESCRIPTION = "Helper node to create empty Wan model to use with BasicScheduler -node to get sigmas"
354
+
355
+ def create(self, shift):
356
+ from comfy.model_sampling import ModelSamplingDiscreteFlow
357
+ class DummyModel:
358
+ def get_model_object(self, name):
359
+ if name == "model_sampling":
360
+ model_sampling = ModelSamplingDiscreteFlow()
361
+ model_sampling.set_parameters(shift=shift)
362
+ return model_sampling
363
+ return None
364
+ return (DummyModel(),)
365
+
366
+ class WanVideoLatentReScale:
367
+ @classmethod
368
+ def INPUT_TYPES(s):
369
+ return {"required": {
370
+ "samples": ("LATENT",),
371
+ "direction": (["comfy_to_wrapper", "wrapper_to_comfy"], {"tooltip": "Direction to rescale latents, from comfy to wrapper or vice versa"}),
372
+ }
373
+ }
374
+
375
+ RETURN_TYPES = ("LATENT",)
376
+ RETURN_NAMES = ("samples",)
377
+ FUNCTION = "encode"
378
+ CATEGORY = "WanVideoWrapper"
379
+ DESCRIPTION = "Rescale latents to match the expected range for encoding or decoding between native ComfyUI VAE and the WanVideoWrapper VAE."
380
+
381
+ def encode(self, samples, direction):
382
+ samples = samples.copy()
383
+ latents = samples["samples"]
384
+
385
+ if latents.shape[1] == 48:
386
+ mean = [
387
+ -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
388
+ -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
389
+ -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
390
+ -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
391
+ -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
392
+ 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
393
+ ]
394
+ std = [
395
+ 0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
396
+ 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
397
+ 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
398
+ 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
399
+ 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
400
+ 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
401
+ ]
402
+ else:
403
+ mean = [
404
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
405
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
406
+ ]
407
+ std = [
408
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
409
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
410
+ ]
411
+ mean = torch.tensor(mean).view(1, latents.shape[1], 1, 1, 1)
412
+ std = torch.tensor(std).view(1, latents.shape[1], 1, 1, 1)
413
+ inv_std = (1.0 / std).view(1, latents.shape[1], 1, 1, 1)
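+ # comfy_to_wrapper standardizes the latents per channel (subtract mean, divide by std); wrapper_to_comfy reverses it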
414
+ if direction == "comfy_to_wrapper":
415
+ latents = (latents - mean.to(latents)) * inv_std.to(latents)
416
+ elif direction == "wrapper_to_comfy":
417
+ latents = latents / inv_std.to(latents) + mean.to(latents)
418
+
419
+ samples["samples"] = latents
420
+
421
+ return (samples,)
422
+
423
+ class WanVideoSigmaToStep:
424
+ @classmethod
425
+ def INPUT_TYPES(s):
426
+ return {"required": {
427
+ "sigma": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.001}),
428
+ },
429
+ }
430
+
431
+ RETURN_TYPES = ("INT", )
432
+ RETURN_NAMES = ("step",)
433
+ FUNCTION = "convert"
434
+ CATEGORY = "WanVideoWrapper"
435
+ DESCRIPTION = "Simply passes a float value as an integer, used to set start/end steps with sigma threshold"
436
+
437
+ def convert(self, sigma):
438
+ return (sigma,)
439
+
440
+ class NormalizeAudioLoudness:
441
+ @classmethod
442
+ def INPUT_TYPES(s):
443
+ return {"required": {
444
+ "audio": ("AUDIO",),
445
+ "lufs": ("FLOAT", {"default": -23.0, "min": -100.0, "max": 0.0, "step": 0.1, "tool": "Loudness Units relative to Full Scale, higher LUFS values (closer to 0) mean louder audio. Lower LUFS values (more negative) mean quieter audio."}),
446
+ },
447
+ }
448
+
449
+ RETURN_TYPES = ("AUDIO", )
450
+ RETURN_NAMES = ("audio", )
451
+ FUNCTION = "normalize"
452
+ CATEGORY = "WanVideoWrapper"
453
+
454
+ def normalize(self, audio, lufs):
455
+ audio_input = audio["waveform"]
456
+ sample_rate = audio["sample_rate"]
457
+ if audio_input.dim() == 3:
458
+ audio_input = audio_input.squeeze(0)
459
+ audio_input_np = audio_input.detach().transpose(0, 1).numpy().astype(np.float32)
460
+ audio_input_np = np.ascontiguousarray(audio_input_np)
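+ # pyloudnorm operates on float arrays shaped (samples, channels), hence the transpose above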
461
+ normalized_audio = self.loudness_norm(audio_input_np, sr=sample_rate, lufs=lufs)
462
+
463
+ out_audio = {"waveform": torch.from_numpy(normalized_audio).transpose(0, 1).unsqueeze(0).float(), "sample_rate": sample_rate}
464
+
465
+ return (out_audio, )
466
+
467
+ def loudness_norm(self, audio_array, sr=16000, lufs=-23):
468
+ try:
469
+ import pyloudnorm
470
+ except:
471
+ raise ImportError("pyloudnorm package is not installed")
472
+ meter = pyloudnorm.Meter(sr)
473
+ loudness = meter.integrated_loudness(audio_array)
474
+ if abs(loudness) > 100:
475
+ return audio_array
476
+ normalized_audio = pyloudnorm.normalize.loudness(audio_array, loudness, lufs)
477
+ return normalized_audio
478
+
479
+ class WanVideoPassImagesFromSamples:
480
+ @classmethod
481
+ def INPUT_TYPES(s):
482
+ return {"required": {
483
+ "samples": ("LATENT",),
484
+ }
485
+ }
486
+
487
+ RETURN_TYPES = ("IMAGE", "STRING",)
488
+ RETURN_NAMES = ("images", "output_path",)
489
+ OUTPUT_TOOLTIPS = ("Decoded images from the samples dictionary", "Output path if provided in the samples dictionary",)
490
+ FUNCTION = "decode"
491
+ CATEGORY = "WanVideoWrapper"
492
+ DESCRIPTION = "Gets possible already decoded images from the samples dictionary, used with Multi/InfiniteTalk sampling"
493
+
494
+ def decode(self, samples):
495
+ video = samples.get("video", None)
+ if video is None:
+ raise ValueError("samples do not contain decoded video frames")
496
+ video.clamp_(-1.0, 1.0)
497
+ video.add_(1.0).div_(2.0)
498
+ return video.cpu().float(), samples.get("output_path", "")
499
+
500
+
501
+ class FaceMaskFromPoseKeypoints:
502
+ @classmethod
503
+ def INPUT_TYPES(s):
504
+ input_types = {
505
+ "required": {
506
+ "pose_kps": ("POSE_KEYPOINT",),
507
+ "person_index": ("INT", {"default": 0, "min": 0, "max": 100, "step": 1, "tooltip": "Index of the person to start with"}),
508
+ }
509
+ }
510
+ return input_types
511
+ RETURN_TYPES = ("MASK",)
512
+ FUNCTION = "createmask"
513
+ CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"
514
+
515
+ def createmask(self, pose_kps, person_index):
516
+ pose_frames = pose_kps
517
+ prev_center = None
518
+ np_frames = []
519
+ for i, pose_frame in enumerate(pose_frames):
520
+ selected_idx, prev_center = self.select_closest_person(pose_frame, person_index if i == 0 else prev_center)
521
+ np_frames.append(self.draw_kps(pose_frame, selected_idx))
522
+
523
+ if not np_frames:
524
+ # Handle case where no frames were processed
525
+ log.warning("No valid pose frames found, returning empty mask")
526
+ return (torch.zeros((1, 64, 64), dtype=torch.float32),)
527
+
528
+ np_frames = np.stack(np_frames, axis=0)
529
+ tensor = torch.from_numpy(np_frames).float() / 255.
530
+ print("tensor.shape:", tensor.shape)
531
+ tensor = tensor[:, :, :, 0]
532
+ return (tensor,)
533
+
534
+ def select_closest_person(self, pose_frame, prev_center_or_index):
535
+ people = pose_frame["people"]
536
+ if not people:
537
+ return -1, None
538
+
539
+ centers = []
540
+ valid_people_indices = []
541
+
542
+ for idx, person in enumerate(people):
543
+ # Check if face keypoints exist and are valid
544
+ if "face_keypoints_2d" not in person or not person["face_keypoints_2d"]:
545
+ continue
546
+
547
+ kps = np.array(person["face_keypoints_2d"])
548
+ if len(kps) == 0:
549
+ continue
550
+
551
+ n = len(kps) // 3
552
+ if n == 0:
553
+ continue
554
+
555
+ facial_kps = rearrange(kps, "(n c) -> n c", n=n, c=3)[:, :2]
556
+
557
+ # Check if we have valid coordinates (not all zeros)
558
+ if np.all(facial_kps == 0):
559
+ continue
560
+
561
+ center = facial_kps.mean(axis=0)
562
+
563
+ # Check if center is valid (not NaN or infinite)
564
+ if np.isnan(center).any() or np.isinf(center).any():
565
+ continue
566
+
567
+ centers.append(center)
568
+ valid_people_indices.append(idx)
569
+
570
+ if not centers:
571
+ return -1, None
572
+
573
+ if isinstance(prev_center_or_index, (int, np.integer)):
574
+ # First frame: use person_index, but map to valid people
575
+ if 0 <= prev_center_or_index < len(valid_people_indices):
576
+ idx = valid_people_indices[prev_center_or_index]
577
+ return idx, centers[prev_center_or_index]
578
+ elif valid_people_indices:
579
+ # Fallback to first valid person
580
+ idx = valid_people_indices[0]
581
+ return idx, centers[0]
582
+ else:
583
+ return -1, None
584
+ elif prev_center_or_index is not None:
585
+ # Find closest to previous center
586
+ prev_center = np.array(prev_center_or_index)
587
+ dists = [np.linalg.norm(center - prev_center) for center in centers]
588
+ min_idx = int(np.argmin(dists))
589
+ actual_idx = valid_people_indices[min_idx]
590
+ return actual_idx, centers[min_idx]
591
+ else:
592
+ # prev_center_or_index is None, fallback to first valid person
593
+ if valid_people_indices:
594
+ idx = valid_people_indices[0]
595
+ return idx, centers[0]
596
+ else:
597
+ return -1, None
598
+
599
+ def draw_kps(self, pose_frame, person_index):
600
+ import cv2
601
+ width, height = pose_frame["canvas_width"], pose_frame["canvas_height"]
602
+ canvas = np.zeros((height, width, 3), dtype=np.uint8)
603
+ people = pose_frame["people"]
604
+
605
+ if person_index < 0 or person_index >= len(people):
606
+ return canvas # Out of bounds, return blank
607
+
608
+ person = people[person_index]
609
+
610
+ # Check if face keypoints exist and are valid
611
+ if "face_keypoints_2d" not in person or not person["face_keypoints_2d"]:
612
+ return canvas # No face keypoints, return blank
613
+
614
+ face_kps_data = person["face_keypoints_2d"]
615
+ if len(face_kps_data) == 0:
616
+ return canvas # Empty keypoints, return blank
617
+
618
+ n = len(face_kps_data) // 3
619
+ if n < 17: # Need at least 17 points for outer contour
620
+ return canvas # Not enough keypoints, return blank
621
+
622
+ facial_kps = rearrange(np.array(face_kps_data), "(n c) -> n c", n=n, c=3)[:, :2]
623
+
624
+ # Check if we have valid coordinates (not all zeros)
625
+ if np.all(facial_kps == 0):
626
+ return canvas # All keypoints are zero, return blank
627
+
628
+ # Check for NaN or infinite values
629
+ if np.isnan(facial_kps).any() or np.isinf(facial_kps).any():
630
+ return canvas # Invalid coordinates, return blank
631
+
632
+ # Check for negative coordinates or coordinates that would create streaks
633
+ if np.any(facial_kps < 0):
634
+ return canvas # Negative coordinates, likely bad detection
635
+
636
+ # Check if coordinates are reasonable (not too close to edges which might indicate bad detection)
637
+ min_margin = 5 # Minimum distance from edges
638
+ if (np.any(facial_kps[:, 0] < min_margin) or
639
+ np.any(facial_kps[:, 1] < min_margin) or
640
+ np.any(facial_kps[:, 0] > width - min_margin) or
641
+ np.any(facial_kps[:, 1] > height - min_margin)):
642
+ # Check if this looks like a streak to corner (many points near 0,0)
643
+ corner_points = np.sum((facial_kps[:, 0] < min_margin) & (facial_kps[:, 1] < min_margin))
644
+ if corner_points > 3: # Too many points near corner, likely bad detection
645
+ return canvas
646
+
647
+ facial_kps = facial_kps.astype(np.int32)
648
+
649
+ # Ensure coordinates are within canvas bounds
650
+ facial_kps[:, 0] = np.clip(facial_kps[:, 0], 0, width - 1)
651
+ facial_kps[:, 1] = np.clip(facial_kps[:, 1], 0, height - 1)
652
+
653
+ part_color = (255, 255, 255)
654
+ outer_contour = facial_kps[:17]
655
+
656
+ # Additional validation for the contour before drawing
657
+ # Check if contour points are too spread out (indicating bad detection)
658
+ if len(outer_contour) >= 3:
659
+ # Calculate bounding box of the contour
660
+ min_x, min_y = np.min(outer_contour, axis=0)
661
+ max_x, max_y = np.max(outer_contour, axis=0)
662
+ contour_width = max_x - min_x
663
+ contour_height = max_y - min_y
664
+
665
+ # If contour spans more than 80% of canvas, likely bad detection
666
+ if (contour_width > 0.8 * width or contour_height > 0.8 * height):
667
+ return canvas
668
+
669
+ # Check if we have a valid contour (at least 3 unique points)
670
+ unique_points = np.unique(outer_contour, axis=0)
671
+ if len(unique_points) >= 3:
672
+ # Final check: ensure the contour is reasonable
673
+ # Calculate area to see if it's too large or too small
674
+ contour_area = cv2.contourArea(outer_contour)
675
+ canvas_area = width * height
676
+
677
+ # If contour is less than 0.1% or more than 50% of canvas, skip
678
+ if 0.001 * canvas_area <= contour_area <= 0.5 * canvas_area:
679
+ cv2.fillPoly(canvas, pts=[outer_contour], color=part_color)
680
+
681
+ return canvas
682
+
683
+ NODE_CLASS_MAPPINGS = {
684
+ "WanVideoImageResizeToClosest": WanVideoImageResizeToClosest,
685
+ "WanVideoVACEStartToEndFrame": WanVideoVACEStartToEndFrame,
686
+ "ExtractStartFramesForContinuations": ExtractStartFramesForContinuations,
687
+ "CreateCFGScheduleFloatList": CreateCFGScheduleFloatList,
688
+ "DummyComfyWanModelObject": DummyComfyWanModelObject,
689
+ "WanVideoLatentReScale": WanVideoLatentReScale,
690
+ "CreateScheduleFloatList": CreateScheduleFloatList,
691
+ "WanVideoSigmaToStep": WanVideoSigmaToStep,
692
+ "NormalizeAudioLoudness": NormalizeAudioLoudness,
693
+ "WanVideoPassImagesFromSamples": WanVideoPassImagesFromSamples,
694
+ "FaceMaskFromPoseKeypoints": FaceMaskFromPoseKeypoints,
695
+ }
696
+ NODE_DISPLAY_NAME_MAPPINGS = {
697
+ "WanVideoImageResizeToClosest": "WanVideo Image Resize To Closest",
698
+ "WanVideoVACEStartToEndFrame": "WanVideo VACE Start To End Frame",
699
+ "ExtractStartFramesForContinuations": "Extract Start Frames For Continuations",
700
+ "CreateCFGScheduleFloatList": "Create CFG Schedule Float List",
701
+ "DummyComfyWanModelObject": "Dummy Comfy Wan Model Object",
702
+ "WanVideoLatentReScale": "WanVideo Latent ReScale",
703
+ "CreateScheduleFloatList": "Create Schedule Float List",
704
+ "WanVideoSigmaToStep": "WanVideo Sigma To Step",
705
+ "NormalizeAudioLoudness": "Normalize Audio Loudness",
706
+ "WanVideoPassImagesFromSamples": "WanVideo Pass Images From Samples",
707
+ "FaceMaskFromPoseKeypoints": "Face Mask From Pose Keypoints",
708
  }