Update Custom-Advanced-VACE-Node/nodes_utility.py
* use empty_frame_level
* support end frame easing (see the sketch below)
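
For orientation, the new control_ease option in WanVideoVACEStartToEndFrame blends the control frames next to a start image (or before an end image) toward a flat frame at empty_frame_level, so the control video fades in or out instead of switching on abruptly. Below is a minimal standalone sketch of that idea, assuming a control tensor of shape (frames, H, W, C) in the 0..1 range and n_ease < frames - 1; the helper name and exact blend weights are illustrative, not the node's own code.

import torch

def ease_control_toward_empty(control, n_ease, empty_level=0.5, at_start=True):
    # Blend the n_ease frames adjacent to the anchored start (or end) frame toward
    # a flat frame at empty_level; frames closest to the anchor are blended hardest.
    out = control.clone()
    empty = torch.full_like(control[0], empty_level)
    for k in range(n_ease):
        weight = (n_ease - k) / (n_ease + 1)
        idx = k + 1 if at_start else out.shape[0] - 2 - k
        out[idx] = torch.lerp(out[idx], empty, weight)
    return out

# e.g. fade a pose/depth control video in over 8 frames after the start image:
# eased = ease_control_toward_empty(control_frames, n_ease=8, empty_level=0.5)
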
Custom-Advanced-VACE-Node/nodes_utility.py
CHANGED
@@ -1,703 +1,708 @@
import torch
import numpy as np
from comfy.utils import common_upscale
from .utils import log
from einops import rearrange

try:
    from server import PromptServer
except:
    PromptServer = None

VAE_STRIDE = (4, 8, 8)
PATCH_SIZE = (1, 2, 2)

class WanVideoImageResizeToClosest:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "image": ("IMAGE", {"tooltip": "Image to resize"}),
            "generation_width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}),
            "generation_height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}),
            "aspect_ratio_preservation": (["keep_input", "stretch_to_new", "crop_to_new"],),
            },
        }

    RETURN_TYPES = ("IMAGE", "INT", "INT", )
    RETURN_NAMES = ("image", "width", "height",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Resizes image to the closest supported resolution based on aspect ratio and max pixels, according to the original code"

    def process(self, image, generation_width, generation_height, aspect_ratio_preservation):

        H, W = image.shape[1], image.shape[2]
        max_area = generation_width * generation_height

        crop = "disabled"

        if aspect_ratio_preservation == "keep_input":
            aspect_ratio = H / W
        elif aspect_ratio_preservation == "stretch_to_new" or aspect_ratio_preservation == "crop_to_new":
            aspect_ratio = generation_height / generation_width
            if aspect_ratio_preservation == "crop_to_new":
                crop = "center"

        lat_h = round(
            np.sqrt(max_area * aspect_ratio) // VAE_STRIDE[1] //
            PATCH_SIZE[1] * PATCH_SIZE[1])
        lat_w = round(
            np.sqrt(max_area / aspect_ratio) // VAE_STRIDE[2] //
            PATCH_SIZE[2] * PATCH_SIZE[2])
        h = lat_h * VAE_STRIDE[1]
        w = lat_w * VAE_STRIDE[2]

        resized_image = common_upscale(image.movedim(-1, 1), w, h, "lanczos", crop).movedim(1, -1)

        return (resized_image, w, h)

class ExtractStartFramesForContinuations:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "input_video_frames": ("IMAGE", {"tooltip": "Input video frames to extract the start frames from."}),
                "num_frames": ("INT", {"default": 10, "min": 1, "max": 1024, "step": 1, "tooltip": "Number of frames to get from the start of the video."}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("start_frames",)
    FUNCTION = "get_start_frames"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Extracts the first N frames from a video sequence for continuations."

    def get_start_frames(self, input_video_frames, num_frames):
        if input_video_frames is None or input_video_frames.shape[0] == 0:
            log.warning("Input video frames are empty. Returning an empty tensor.")
            if input_video_frames is not None:
                return (torch.empty((0,) + input_video_frames.shape[1:], dtype=input_video_frames.dtype),)
            else:
                # Return a tensor with 4 dimensions, as expected for an IMAGE type.
                return (torch.empty((0, 64, 64, 3), dtype=torch.float32),)

        total_frames = input_video_frames.shape[0]
        num_to_get = min(num_frames, total_frames)

        if num_to_get < num_frames:
            log.warning(f"Requested {num_frames} frames, but input video only has {total_frames} frames. Returning first {num_to_get} frames.")

        start_frames = input_video_frames[:num_to_get]

        return (start_frames.cpu().float(),)

class WanVideoVACEStartToEndFrame:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}),
            "empty_frame_level": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "White level of empty frame to use"}),
            },
            "optional": {
                "start_image": ("IMAGE",),
                "end_image": ("IMAGE",),
                "control_images": ("IMAGE",),
                "inpaint_mask": ("MASK", {"tooltip": "Inpaint mask to use for the empty frames"}),
                "start_index": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Index to start from"}),
                "end_index": ("INT", {"default": -1, "min": -10000, "max": 10000, "step": 1, "tooltip": "Index to end at"}),
                "control_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01, "tooltip": "How much does the control images apply?"}),
                "control_ease": ("INT", {"default": 0, "min": 0, "max": 100, "step": 1, "tooltip": "How many frames to ease in the control video?"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK", )
    RETURN_NAMES = ("images", "masks",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Helper node to create start/end frame batch and masks for VACE"

    def process(self, num_frames, empty_frame_level, start_image=None, end_image=None, control_images=None, inpaint_mask=None, start_index=0, end_index=-1, control_strength=1.0, control_ease=0):

        device = start_image.device if start_image is not None else end_image.device
        B, H, W, C = start_image.shape if start_image is not None else end_image.shape

        if control_strength < 1.0 and control_images is not None:
            # strength happens at much smaller number
            control_strength *= 2.0
            control_strength = control_strength * control_strength / 8.0
            control_images = torch.lerp(torch.ones((control_images.shape[0], control_images.shape[1], control_images.shape[2], control_images.shape[3])) * empty_frame_level, control_images, control_strength)

        # ease in control stuff?
        if control_images is not None and num_frames > control_ease and control_ease > 0:
            empty_frame = torch.ones((1, control_images.shape[1], control_images.shape[2], control_images.shape[3])) * empty_frame_level
            if start_image is not None:
                for i in range(1, control_ease + 1):
                    control_images[i] = torch.lerp(control_images[i], empty_frame, (control_ease - i) / (1 + control_ease))
            else:
                for i in range(num_frames - control_ease - 1, num_frames - 1):
                    control_images[i] = torch.lerp(control_images[i], empty_frame, i / (1 + control_ease))

        if start_image is None and end_image is None and control_images is not None:
            if control_images.shape[0] >= num_frames:
                control_images = control_images[:num_frames]
            elif control_images.shape[0] < num_frames:
                # pad with empty_frame_level frames
                padding = torch.ones((num_frames - control_images.shape[0], control_images.shape[1], control_images.shape[2], control_images.shape[3]), device=control_images.device) * empty_frame_level
                control_images = torch.cat([control_images, padding], dim=0)
            return (control_images.cpu().float(), torch.zeros_like(control_images[:, :, :, 0]).cpu().float())

        # Convert negative end_index to positive
        if end_index < 0:
            end_index = num_frames + end_index

        # Create output batch with empty frames
        out_batch = torch.ones((num_frames, H, W, 3), device=device) * empty_frame_level

        # Create mask tensor with proper dimensions
        masks = torch.ones((num_frames, H, W), device=device)

        # Pre-process all images at once to avoid redundant work
        if end_image is not None and (end_image.shape[1] != H or end_image.shape[2] != W):
            end_image = common_upscale(end_image.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(1, -1)

        if control_images is not None and (control_images.shape[1] != H or control_images.shape[2] != W):
            control_images = common_upscale(control_images.movedim(-1, 1), W, H, "lanczos", "disabled").movedim(1, -1)

        # Place start image at start_index
        if start_image is not None:
            frames_to_copy = min(start_image.shape[0], num_frames - start_index)
            if frames_to_copy > 0:
                out_batch[start_index:start_index + frames_to_copy] = start_image[:frames_to_copy]
                masks[start_index:start_index + frames_to_copy] = 0

        # Place end image at end_index
        if end_image is not None:
            # Calculate where to start placing end images
            end_start = end_index - end_image.shape[0] + 1
            if end_start < 0:  # Handle case where end images won't all fit
                end_image = end_image[abs(end_start):]
                end_start = 0

            frames_to_copy = min(end_image.shape[0], num_frames - end_start)
            if frames_to_copy > 0:
                out_batch[end_start:end_start + frames_to_copy] = end_image[:frames_to_copy]
                masks[end_start:end_start + frames_to_copy] = 0

        # Apply control images to remaining frames that don't have start or end images
        if control_images is not None:
            # Create a mask of frames that are still empty (mask == 1)
            empty_frames = masks.sum(dim=(1, 2)) > 0.5 * H * W

            if empty_frames.any():
                # Only apply control images where they exist
                control_length = control_images.shape[0]
                for frame_idx in range(num_frames):
                    if empty_frames[frame_idx] and frame_idx < control_length:
                        out_batch[frame_idx] = control_images[frame_idx]

        # Apply inpaint mask if provided
        if inpaint_mask is not None:
            inpaint_mask = common_upscale(inpaint_mask.unsqueeze(1), W, H, "nearest-exact", "disabled").squeeze(1).to(device)

            # Handle different mask lengths efficiently
            if inpaint_mask.shape[0] > num_frames:
                inpaint_mask = inpaint_mask[:num_frames]
            elif inpaint_mask.shape[0] < num_frames:
                repeat_factor = (num_frames + inpaint_mask.shape[0] - 1) // inpaint_mask.shape[0]  # Ceiling division
                inpaint_mask = inpaint_mask.repeat(repeat_factor, 1, 1)[:num_frames]

            # Apply mask in one operation
            masks = inpaint_mask * masks

        return (out_batch.cpu().float(), masks.cpu().float())


class CreateCFGScheduleFloatList:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "steps": ("INT", {"default": 30, "min": 2, "max": 1000, "step": 1, "tooltip": "Number of steps to schedule cfg for"}),
            "cfg_scale_start": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 30.0, "step": 0.01, "round": 0.01, "tooltip": "CFG scale to use for the steps"}),
            "cfg_scale_end": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 30.0, "step": 0.01, "round": 0.01, "tooltip": "CFG scale to use for the steps"}),
            "interpolation": (["linear", "ease_in", "ease_out"], {"default": "linear", "tooltip": "Interpolation method to use for the cfg scale"}),
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01, "tooltip": "Start percent of the steps to apply cfg"}),
            "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01, "tooltip": "End percent of the steps to apply cfg"}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = ("FLOAT", )
    RETURN_NAMES = ("float_list",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Helper node to generate a list of floats that can be used to schedule cfg scale for the steps, outside the set range cfg is set to 1.0"

    def process(self, steps, cfg_scale_start, cfg_scale_end, interpolation, start_percent, end_percent, unique_id):

        # Create a list of floats for the cfg schedule
        cfg_list = [1.0] * steps
        start_idx = min(int(steps * start_percent), steps - 1)
        end_idx = min(int(steps * end_percent), steps - 1)

        for i in range(start_idx, end_idx + 1):
            if i >= steps:
                break

            if end_idx == start_idx:
                t = 0
            else:
                t = (i - start_idx) / (end_idx - start_idx)

            if interpolation == "linear":
                factor = t
            elif interpolation == "ease_in":
                factor = t * t
            elif interpolation == "ease_out":
                factor = t * (2 - t)

            cfg_list[i] = round(cfg_scale_start + factor * (cfg_scale_end - cfg_scale_start), 2)

        # If start_percent > 0, always include the first step
        if start_percent > 0:
            cfg_list[0] = 1.0

        if unique_id and PromptServer is not None:
            try:
                PromptServer.instance.send_progress_text(
                    f"{cfg_list}",
                    unique_id
                )
            except:
                pass

        return (cfg_list,)

class CreateScheduleFloatList:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "steps": ("INT", {"default": 30, "min": 2, "max": 1000, "step": 1, "tooltip": "Number of steps to schedule cfg for"}),
            "start_value": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01, "tooltip": "CFG scale to use for the steps"}),
            "end_value": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01, "tooltip": "CFG scale to use for the steps"}),
            "default_value": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.01, "round": 0.01, "tooltip": "Default value to use for the steps"}),
            "interpolation": (["linear", "ease_in", "ease_out"], {"default": "linear", "tooltip": "Interpolation method to use for the cfg scale"}),
            "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01, "tooltip": "Start percent of the steps to apply cfg"}),
            "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "round": 0.01, "tooltip": "End percent of the steps to apply cfg"}),
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = ("FLOAT", )
    RETURN_NAMES = ("float_list",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Helper node to generate a list of floats that can be used to schedule things like cfg and lora scale per step"

    def process(self, steps, start_value, end_value, default_value, interpolation, start_percent, end_percent, unique_id):

        # Create a list of floats for the cfg schedule
        cfg_list = [default_value] * steps
        start_idx = min(int(steps * start_percent), steps - 1)
        end_idx = min(int(steps * end_percent), steps - 1)

        for i in range(start_idx, end_idx + 1):
            if i >= steps:
                break

            if end_idx == start_idx:
                t = 0
            else:
                t = (i - start_idx) / (end_idx - start_idx)

            if interpolation == "linear":
                factor = t
            elif interpolation == "ease_in":
                factor = t * t
            elif interpolation == "ease_out":
                factor = t * (2 - t)

            cfg_list[i] = round(start_value + factor * (end_value - start_value), 2)

        # If start_percent > 0, always include the first step
        if start_percent > 0:
            cfg_list[0] = default_value

        if unique_id and PromptServer is not None:
            try:
                PromptServer.instance.send_progress_text(
                    f"{cfg_list}",
                    unique_id
                )
            except:
                pass

        return (cfg_list,)


class DummyComfyWanModelObject:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "shift": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "Sigma shift value"}),
            }
        }

    RETURN_TYPES = ("MODEL", )
    RETURN_NAMES = ("model",)
    FUNCTION = "create"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Helper node to create empty Wan model to use with BasicScheduler -node to get sigmas"

    def create(self, shift):
        from comfy.model_sampling import ModelSamplingDiscreteFlow
        class DummyModel:
            def get_model_object(self, name):
                if name == "model_sampling":
                    model_sampling = ModelSamplingDiscreteFlow()
                    model_sampling.set_parameters(shift=shift)
                    return model_sampling
                return None
        return (DummyModel(),)

class WanVideoLatentReScale:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "samples": ("LATENT",),
            "direction": (["comfy_to_wrapper", "wrapper_to_comfy"], {"tooltip": "Direction to rescale latents, from comfy to wrapper or vice versa"}),
            }
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("samples",)
    FUNCTION = "encode"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Rescale latents to match the expected range for encoding or decoding between native ComfyUI VAE and the WanVideoWrapper VAE."

    def encode(self, samples, direction):
        samples = samples.copy()
        latents = samples["samples"]

        if latents.shape[1] == 48:
            mean = [
                -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
                -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
                -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
                -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
                -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
                0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
            ]
            std = [
                0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
                0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
                0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
                0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
                0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
                0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
            ]
        else:
            mean = [
                -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
                0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
            ]
            std = [
                2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
                3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
            ]
        mean = torch.tensor(mean).view(1, latents.shape[1], 1, 1, 1)
        std = torch.tensor(std).view(1, latents.shape[1], 1, 1, 1)
        inv_std = (1.0 / std).view(1, latents.shape[1], 1, 1, 1)
        if direction == "comfy_to_wrapper":
            latents = (latents - mean.to(latents)) * inv_std.to(latents)
        elif direction == "wrapper_to_comfy":
            latents = latents / inv_std.to(latents) + mean.to(latents)

        samples["samples"] = latents

        return (samples,)

class WanVideoSigmaToStep:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "sigma": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.001}),
            },
        }

    RETURN_TYPES = ("INT", )
    RETURN_NAMES = ("step",)
    FUNCTION = "convert"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Simply passes a float value as an integer, used to set start/end steps with sigma threshold"

    def convert(self, sigma):
        return (sigma,)

class NormalizeAudioLoudness:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "audio": ("AUDIO",),
            "lufs": ("FLOAT", {"default": -23.0, "min": -100.0, "max": 0.0, "step": 0.1, "tooltip": "Loudness Units relative to Full Scale, higher LUFS values (closer to 0) mean louder audio. Lower LUFS values (more negative) mean quieter audio."}),
            },
        }

    RETURN_TYPES = ("AUDIO", )
    RETURN_NAMES = ("audio", )
    FUNCTION = "normalize"
    CATEGORY = "WanVideoWrapper"

    def normalize(self, audio, lufs):
        audio_input = audio["waveform"]
        sample_rate = audio["sample_rate"]
        if audio_input.dim() == 3:
            audio_input = audio_input.squeeze(0)
        audio_input_np = audio_input.detach().transpose(0, 1).numpy().astype(np.float32)
        audio_input_np = np.ascontiguousarray(audio_input_np)
        normalized_audio = self.loudness_norm(audio_input_np, sr=sample_rate, lufs=lufs)

        out_audio = {"waveform": torch.from_numpy(normalized_audio).transpose(0, 1).unsqueeze(0).float(), "sample_rate": sample_rate}

        return (out_audio, )

    def loudness_norm(self, audio_array, sr=16000, lufs=-23):
        try:
            import pyloudnorm
        except:
            raise ImportError("pyloudnorm package is not installed")
        meter = pyloudnorm.Meter(sr)
        loudness = meter.integrated_loudness(audio_array)
        if abs(loudness) > 100:
            return audio_array
        normalized_audio = pyloudnorm.normalize.loudness(audio_array, loudness, lufs)
        return normalized_audio

class WanVideoPassImagesFromSamples:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "samples": ("LATENT",),
            }
        }

    RETURN_TYPES = ("IMAGE", "STRING",)
    RETURN_NAMES = ("images", "output_path",)
    OUTPUT_TOOLTIPS = ("Decoded images from the samples dictionary", "Output path if provided in the samples dictionary",)
    FUNCTION = "decode"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Gets possible already decoded images from the samples dictionary, used with Multi/InfiniteTalk sampling"

    def decode(self, samples):
        video = samples.get("video", None)
        video.clamp_(-1.0, 1.0)
        video.add_(1.0).div_(2.0)
        return video.cpu().float(), samples.get("output_path", "")


class FaceMaskFromPoseKeypoints:
    @classmethod
    def INPUT_TYPES(s):
        input_types = {
            "required": {
                "pose_kps": ("POSE_KEYPOINT",),
                "person_index": ("INT", {"default": 0, "min": 0, "max": 100, "step": 1, "tooltip": "Index of the person to start with"}),
            }
        }
        return input_types
    RETURN_TYPES = ("MASK",)
    FUNCTION = "createmask"
    CATEGORY = "ControlNet Preprocessors/Pose Keypoint Postprocess"

    def createmask(self, pose_kps, person_index):
        pose_frames = pose_kps
        prev_center = None
        np_frames = []
        for i, pose_frame in enumerate(pose_frames):
            selected_idx, prev_center = self.select_closest_person(pose_frame, person_index if i == 0 else prev_center)
            np_frames.append(self.draw_kps(pose_frame, selected_idx))

        if not np_frames:
            # Handle case where no frames were processed
            log.warning("No valid pose frames found, returning empty mask")
            return (torch.zeros((1, 64, 64), dtype=torch.float32),)

        np_frames = np.stack(np_frames, axis=0)
        tensor = torch.from_numpy(np_frames).float() / 255.
        print("tensor.shape:", tensor.shape)
        tensor = tensor[:, :, :, 0]
        return (tensor,)

    def select_closest_person(self, pose_frame, prev_center_or_index):
        people = pose_frame["people"]
        if not people:
            return -1, None

        centers = []
        valid_people_indices = []

        for idx, person in enumerate(people):
            # Check if face keypoints exist and are valid
            if "face_keypoints_2d" not in person or not person["face_keypoints_2d"]:
                continue

            kps = np.array(person["face_keypoints_2d"])
            if len(kps) == 0:
                continue

            n = len(kps) // 3
            if n == 0:
                continue

            facial_kps = rearrange(kps, "(n c) -> n c", n=n, c=3)[:, :2]

            # Check if we have valid coordinates (not all zeros)
            if np.all(facial_kps == 0):
                continue

            center = facial_kps.mean(axis=0)

            # Check if center is valid (not NaN or infinite)
            if np.isnan(center).any() or np.isinf(center).any():
                continue

            centers.append(center)
            valid_people_indices.append(idx)

        if not centers:
            return -1, None

        if isinstance(prev_center_or_index, (int, np.integer)):
            # First frame: use person_index, but map to valid people
            if 0 <= prev_center_or_index < len(valid_people_indices):
                idx = valid_people_indices[prev_center_or_index]
                return idx, centers[prev_center_or_index]
            elif valid_people_indices:
                # Fallback to first valid person
                idx = valid_people_indices[0]
                return idx, centers[0]
            else:
                return -1, None
        elif prev_center_or_index is not None:
            # Find closest to previous center
            prev_center = np.array(prev_center_or_index)
            dists = [np.linalg.norm(center - prev_center) for center in centers]
            min_idx = int(np.argmin(dists))
            actual_idx = valid_people_indices[min_idx]
            return actual_idx, centers[min_idx]
        else:
            # prev_center_or_index is None, fallback to first valid person
            if valid_people_indices:
                idx = valid_people_indices[0]
                return idx, centers[0]
            else:
                return -1, None

    def draw_kps(self, pose_frame, person_index):
        import cv2
        width, height = pose_frame["canvas_width"], pose_frame["canvas_height"]
        canvas = np.zeros((height, width, 3), dtype=np.uint8)
        people = pose_frame["people"]

        if person_index < 0 or person_index >= len(people):
            return canvas  # Out of bounds, return blank

        person = people[person_index]

        # Check if face keypoints exist and are valid
        if "face_keypoints_2d" not in person or not person["face_keypoints_2d"]:
            return canvas  # No face keypoints, return blank

        face_kps_data = person["face_keypoints_2d"]
        if len(face_kps_data) == 0:
            return canvas  # Empty keypoints, return blank

        n = len(face_kps_data) // 3
        if n < 17:  # Need at least 17 points for outer contour
            return canvas  # Not enough keypoints, return blank

        facial_kps = rearrange(np.array(face_kps_data), "(n c) -> n c", n=n, c=3)[:, :2]

        # Check if we have valid coordinates (not all zeros)
        if np.all(facial_kps == 0):
            return canvas  # All keypoints are zero, return blank

        # Check for NaN or infinite values
        if np.isnan(facial_kps).any() or np.isinf(facial_kps).any():
            return canvas  # Invalid coordinates, return blank

        # Check for negative coordinates or coordinates that would create streaks
        if np.any(facial_kps < 0):
            return canvas  # Negative coordinates, likely bad detection

        # Check if coordinates are reasonable (not too close to edges which might indicate bad detection)
        min_margin = 5  # Minimum distance from edges
        if (np.any(facial_kps[:, 0] < min_margin) or
                np.any(facial_kps[:, 1] < min_margin) or
                np.any(facial_kps[:, 0] > width - min_margin) or
                np.any(facial_kps[:, 1] > height - min_margin)):
            # Check if this looks like a streak to corner (many points near 0,0)
            corner_points = np.sum((facial_kps[:, 0] < min_margin) & (facial_kps[:, 1] < min_margin))
            if corner_points > 3:  # Too many points near corner, likely bad detection
                return canvas

        facial_kps = facial_kps.astype(np.int32)

        # Ensure coordinates are within canvas bounds
        facial_kps[:, 0] = np.clip(facial_kps[:, 0], 0, width - 1)
        facial_kps[:, 1] = np.clip(facial_kps[:, 1], 0, height - 1)

        part_color = (255, 255, 255)
        outer_contour = facial_kps[:17]

        # Additional validation for the contour before drawing
        # Check if contour points are too spread out (indicating bad detection)
        if len(outer_contour) >= 3:
            # Calculate bounding box of the contour
            min_x, min_y = np.min(outer_contour, axis=0)
            max_x, max_y = np.max(outer_contour, axis=0)
            contour_width = max_x - min_x
            contour_height = max_y - min_y

            # If contour spans more than 80% of canvas, likely bad detection
            if (contour_width > 0.8 * width or contour_height > 0.8 * height):
                return canvas

        # Check if we have a valid contour (at least 3 unique points)
        unique_points = np.unique(outer_contour, axis=0)
        if len(unique_points) >= 3:
            # Final check: ensure the contour is reasonable
            # Calculate area to see if it's too large or too small
            contour_area = cv2.contourArea(outer_contour)
            canvas_area = width * height

            # If contour is less than 0.1% or more than 50% of canvas, skip
            if 0.001 * canvas_area <= contour_area <= 0.5 * canvas_area:
                cv2.fillPoly(canvas, pts=[outer_contour], color=part_color)

        return canvas

NODE_CLASS_MAPPINGS = {
    "WanVideoImageResizeToClosest": WanVideoImageResizeToClosest,
    "WanVideoVACEStartToEndFrame": WanVideoVACEStartToEndFrame,
    "ExtractStartFramesForContinuations": ExtractStartFramesForContinuations,
    "CreateCFGScheduleFloatList": CreateCFGScheduleFloatList,
    "DummyComfyWanModelObject": DummyComfyWanModelObject,
    "WanVideoLatentReScale": WanVideoLatentReScale,
    "CreateScheduleFloatList": CreateScheduleFloatList,
    "WanVideoSigmaToStep": WanVideoSigmaToStep,
    "NormalizeAudioLoudness": NormalizeAudioLoudness,
    "WanVideoPassImagesFromSamples": WanVideoPassImagesFromSamples,
    "FaceMaskFromPoseKeypoints": FaceMaskFromPoseKeypoints,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoImageResizeToClosest": "WanVideo Image Resize To Closest",
    "WanVideoVACEStartToEndFrame": "WanVideo VACE Start To End Frame",
    "ExtractStartFramesForContinuations": "Extract Start Frames For Continuations",
    "CreateCFGScheduleFloatList": "Create CFG Schedule Float List",
    "DummyComfyWanModelObject": "Dummy Comfy Wan Model Object",
    "WanVideoLatentReScale": "WanVideo Latent ReScale",
    "CreateScheduleFloatList": "Create Schedule Float List",
    "WanVideoSigmaToStep": "WanVideo Sigma To Step",
    "NormalizeAudioLoudness": "Normalize Audio Loudness",
    "WanVideoPassImagesFromSamples": "WanVideo Pass Images From Samples",
    "FaceMaskFromPoseKeypoints": "Face Mask From Pose Keypoints",
}
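
Note: ComfyUI only registers these nodes if the package's __init__.py (not part of this diff) re-exports the two mapping dicts. A minimal sketch of that conventional file, assuming the usual custom-node layout; if the package defines nodes in several modules, the real __init__.py would merge their mappings instead:

# __init__.py (hypothetical, standard ComfyUI custom-node convention)
from .nodes_utility import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]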