import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.model_sampling
import comfy.utils
import math
import numpy as np
import av
from io import BytesIO
from typing_extensions import override
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy_api.latest import ComfyExtension, io


class EmptyLTXVLatentVideo(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="EmptyLTXVLatentVideo",
            category="latent/video/ltxv",
            inputs=[
                io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
                io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
                io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=8),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
            ],
            outputs=[
                io.Latent.Output(),
            ],
        )

    @classmethod
    def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
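        # LTXV latents have 128 channels; time is downscaled 8x (the first frame is
        # encoded on its own, hence (length - 1) // 8 + 1) and width/height 32x.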
        latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
        return io.NodeOutput({"samples": latent})

    generate = execute  # TODO: remove


class LTXVImgToVideo(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTXVImgToVideo",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Image.Input("image"),
                io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
                io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
                io.Int.Input("length", default=97, min=9, max=nodes.MAX_RESOLUTION, step=8),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
                io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, positive, negative, image, vae, width, height, length, batch_size, strength) -> io.NodeOutput:
        pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
        encode_pixels = pixels[:, :, :, :3]
        t = vae.encode(encode_pixels)

        latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
        latent[:, :, :t.shape[2]] = t
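        # In the noise mask, 1.0 means "fully re-noised / generated"; the latent frames
        # taken from the encoded image get 1.0 - strength so they are (mostly) kept.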
        conditioning_latent_frames_mask = torch.ones(
            (batch_size, 1, latent.shape[2], 1, 1),
            dtype=torch.float32,
            device=latent.device,
        )
        conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength

        return io.NodeOutput(positive, negative, {"samples": latent, "noise_mask": conditioning_latent_frames_mask})

    generate = execute  # TODO: remove


def conditioning_get_any_value(conditioning, key, default=None):
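    """Return the first value stored under `key` in any conditioning entry, else `default`."""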
    for t in conditioning:
        if key in t[1]:
            return t[1][key]
    return default


def get_noise_mask(latent):
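    """Return the latent's noise mask, or an all-ones (fully re-noised) mask if it has none."""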
    noise_mask = latent.get("noise_mask", None)
    latent_image = latent["samples"]
    if noise_mask is None:
        batch_size, _, latent_length, _, _ = latent_image.shape
        noise_mask = torch.ones(
            (batch_size, 1, latent_length, 1, 1),
            dtype=torch.float32,
            device=latent_image.device,
        )
    else:
        noise_mask = noise_mask.clone()
    return noise_mask


def get_keyframe_idxs(cond):
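    """Return the stored keyframe pixel coordinates and the number of distinct keyframes."""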
    keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
    if keyframe_idxs is None:
        return None, 0
    num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
    return keyframe_idxs, num_keyframes


class LTXVAddGuide(io.ComfyNode):
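    # The first NUM_PREFIX_FRAMES guide latents are appended as separate keyframe
    # tokens; the rest of the guide overwrites latent frames in place (see execute).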
    NUM_PREFIX_FRAMES = 2
    PATCHIFIER = SymmetricPatchifier(1)

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTXVAddGuide",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Latent.Input("latent"),
                io.Image.Input(
                    "image",
                    tooltip="Image or video to condition the latent video on. Must be 8*n + 1 frames. "
                    "If the video is not 8*n + 1 frames, it will be cropped to the nearest 8*n + 1 frames.",
                ),
                io.Int.Input(
                    "frame_idx",
                    default=0,
                    min=-9999,
                    max=9999,
                    tooltip="Frame index to start the conditioning at. "
                    "For single-frame images or videos with 1-8 frames, any frame_idx value is acceptable. "
                    "For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
                    "down to the nearest multiple of 8. Negative values are counted from the end of the video.",
                ),
                io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def encode(cls, vae, latent_width, latent_height, images, scale_factors):
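        # Crop the frame count to 8*n + 1 to match the VAE's temporal downscale, then
        # resize to the latent's pixel resolution before encoding.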
        time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
        images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
        pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
        encode_pixels = pixels[:, :, :, :3]
        t = vae.encode(encode_pixels)
        return encode_pixels, t

    @classmethod
    def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
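        """Resolve a (possibly negative) pixel frame_idx to an aligned frame index and its latent index."""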
        time_scale_factor, _, _ = scale_factors
        _, num_keyframes = get_keyframe_idxs(cond)
        latent_count = latent_length - num_keyframes
        frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
        if guide_length > 1 and frame_idx != 0:
            frame_idx = (frame_idx - 1) // time_scale_factor * time_scale_factor + 1  # frame index - 1 must be divisible by 8 or frame_idx == 0
        latent_idx = (frame_idx + time_scale_factor - 1) // time_scale_factor
        return frame_idx, latent_idx

    @classmethod
    def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors):
        keyframe_idxs, _ = get_keyframe_idxs(cond)
        _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
        pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0)  # we need the causal fix only if we're placing the new latents at index 0
        pixel_coords[:, 0] += frame_idx
        if keyframe_idxs is None:
            keyframe_idxs = pixel_coords
        else:
            keyframe_idxs = torch.cat([keyframe_idxs, pixel_coords], dim=2)
        return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})

    @classmethod
    def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
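        """Append guide latents as extra keyframe tokens at the end of the latent
        sequence and extend the noise mask to match."""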
        _, latent_idx = cls.get_latent_index(
            cond=positive,
            latent_length=latent_image.shape[2],
            guide_length=guiding_latent.shape[2],
            frame_idx=frame_idx,
            scale_factors=scale_factors,
        )
        noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0

        positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
        negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)

        mask = torch.full(
            (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
            1.0 - strength,
            dtype=noise_mask.dtype,
            device=noise_mask.device,
        )

        latent_image = torch.cat([latent_image, guiding_latent], dim=2)
        noise_mask = torch.cat([noise_mask, mask], dim=2)
        return positive, negative, latent_image, noise_mask

    @classmethod
    def replace_latent_frames(cls, latent_image, noise_mask, guiding_latent, latent_idx, strength):
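        """Overwrite latent frames at latent_idx with the guide latents and mark them
        as (mostly) protected in the noise mask."""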
        cond_length = guiding_latent.shape[2]
        assert latent_image.shape[2] >= latent_idx + cond_length, "Conditioning frames exceed the length of the latent sequence."

        mask = torch.full(
            (noise_mask.shape[0], 1, cond_length, 1, 1),
            1.0 - strength,
            dtype=noise_mask.dtype,
            device=noise_mask.device,
        )

        latent_image = latent_image.clone()
        noise_mask = noise_mask.clone()
        latent_image[:, :, latent_idx : latent_idx + cond_length] = guiding_latent
        noise_mask[:, :, latent_idx : latent_idx + cond_length] = mask
        return latent_image, noise_mask

    @classmethod
    def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
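        # The first NUM_PREFIX_FRAMES guide latents are appended as keyframe tokens
        # (append_keyframe); any remaining guide latents overwrite the latent frames
        # at their target position (replace_latent_frames).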
        scale_factors = vae.downscale_index_formula
        latent_image = latent["samples"]
        noise_mask = get_noise_mask(latent)

        _, _, latent_length, latent_height, latent_width = latent_image.shape
        image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)

        frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
        assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."

        num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])

        positive, negative, latent_image, noise_mask = cls.append_keyframe(
            positive,
            negative,
            frame_idx,
            latent_image,
            noise_mask,
            t[:, :, :num_prefix_frames],
            strength,
            scale_factors,
        )

        latent_idx += num_prefix_frames

        t = t[:, :, num_prefix_frames:]
        if t.shape[2] == 0:
            return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})

        latent_image, noise_mask = cls.replace_latent_frames(
            latent_image,
            noise_mask,
            t,
            latent_idx,
            strength,
        )

        return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})

    generate = execute  # TODO: remove


class LTXVCropGuides(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTXVCropGuides",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Latent.Input("latent"),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, positive, negative, latent) -> io.NodeOutput:
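        # Keyframe tokens were appended at the end of the latent sequence by
        # LTXVAddGuide, so dropping the last num_keyframes frames removes them.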
        latent_image = latent["samples"].clone()
        noise_mask = get_noise_mask(latent)

        _, num_keyframes = get_keyframe_idxs(positive)
        if num_keyframes == 0:
            return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})

        latent_image = latent_image[:, :, :-num_keyframes]
        noise_mask = noise_mask[:, :, :-num_keyframes]

        positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
        negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})

        return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})

    crop = execute  # TODO: remove


class LTXVConditioning(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTXVConditioning",
            category="conditioning/video_models",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Float.Input("frame_rate", default=25.0, min=0.0, max=1000.0, step=0.01),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
            ],
        )

    @classmethod
    def execute(cls, positive, negative, frame_rate) -> io.NodeOutput:
        positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
        negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
        return io.NodeOutput(positive, negative)


class ModelSamplingLTXV(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="ModelSamplingLTXV",
            category="advanced/model",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
                io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
                io.Latent.Input("latent", optional=True),
            ],
            outputs=[
                io.Model.Output(),
            ],
        )

    @classmethod
    def execute(cls, model, max_shift, base_shift, latent=None) -> io.NodeOutput:
        m = model.clone()

        if latent is None:
            tokens = 4096
        else:
            tokens = math.prod(latent["samples"].shape[2:])
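        # Linearly interpolate the sampling shift between base_shift at 1024 latent
        # tokens and max_shift at 4096 latent tokens.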
        x1 = 1024
        x2 = 4096
        mm = (max_shift - base_shift) / (x2 - x1)
        b = base_shift - mm * x1
        shift = tokens * mm + b

        sampling_base = comfy.model_sampling.ModelSamplingFlux
        sampling_type = comfy.model_sampling.CONST

        class ModelSamplingAdvanced(sampling_base, sampling_type):
            pass

        model_sampling = ModelSamplingAdvanced(model.model.model_config)
        model_sampling.set_parameters(shift=shift)
        m.add_object_patch("model_sampling", model_sampling)
        return io.NodeOutput(m)


class LTXVScheduler(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTXVScheduler",
            category="sampling/custom_sampling/schedulers",
            inputs=[
                io.Int.Input("steps", default=20, min=1, max=10000),
                io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
                io.Float.Input("base_shift", default=0.95, min=0.0, max=100.0, step=0.01),
                io.Boolean.Input(
                    "stretch",
                    default=True,
                    tooltip="Stretch the sigmas to be in the range [terminal, 1].",
                ),
                io.Float.Input(
                    "terminal",
                    default=0.1,
                    min=0.0,
                    max=0.99,
                    step=0.01,
                    tooltip="The terminal value of the sigmas after stretching.",
                ),
                io.Latent.Input("latent", optional=True),
            ],
            outputs=[
                io.Sigmas.Output(),
            ],
        )

    @classmethod
    def execute(cls, steps, max_shift, base_shift, stretch, terminal, latent=None) -> io.NodeOutput:
        if latent is None:
            tokens = 4096
        else:
            tokens = math.prod(latent["samples"].shape[2:])

        sigmas = torch.linspace(1.0, 0.0, steps + 1)

        # Same token-count interpolation as ModelSamplingLTXV.
        x1 = 1024
        x2 = 4096
        mm = (max_shift - base_shift) / (x2 - x1)
        b = base_shift - mm * x1
        sigma_shift = tokens * mm + b

        power = 1
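        # Time-shift each sigma: sigma' = e^shift / (e^shift + (1 / sigma - 1)^power),
        # which biases the schedule toward higher noise levels as the shift grows.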
        sigmas = torch.where(
            sigmas != 0,
            math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
            0,
        )

        # Stretch sigmas so that their final value matches the given terminal value.
        if stretch:
            non_zero_mask = sigmas != 0
            non_zero_sigmas = sigmas[non_zero_mask]
            one_minus_z = 1.0 - non_zero_sigmas
            scale_factor = one_minus_z[-1] / (1.0 - terminal)
            stretched = 1.0 - (one_minus_z / scale_factor)
            sigmas[non_zero_mask] = stretched

        return io.NodeOutput(sigmas)


def encode_single_frame(output_file, image_array: np.ndarray, crf):
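    """Encode a single RGB frame into `output_file` as a one-frame H.264 MP4 at the given CRF."""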
    container = av.open(output_file, "w", format="mp4")
    try:
        stream = container.add_stream(
            "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
        )
        stream.height = image_array.shape[0]
        stream.width = image_array.shape[1]
        av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
            format="yuv420p"
        )
        container.mux(stream.encode(av_frame))
        container.mux(stream.encode())
    finally:
        container.close()


def decode_single_frame(video_file):
    container = av.open(video_file)
    try:
        stream = next(s for s in container.streams if s.type == "video")
        frame = next(container.decode(stream))
    finally:
        container.close()
    return frame.to_ndarray(format="rgb24")


def preprocess(image: torch.Tensor, crf=29):
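    """Round-trip an image through single-frame H.264 at the given CRF, baking in
    codec compression artifacts; crf=0 returns the image unchanged."""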
    if crf == 0:
        return image

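    # Crop to even dimensions (required for yuv420p) and convert to uint8.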
    image_array = (image[:(image.shape[0] // 2) * 2, :(image.shape[1] // 2) * 2] * 255.0).byte().cpu().numpy()
    with BytesIO() as output_file:
        encode_single_frame(output_file, image_array, crf)
        video_bytes = output_file.getvalue()

    with BytesIO(video_bytes) as video_file:
        image_array = decode_single_frame(video_file)

    tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
    return tensor


class LTXVPreprocess(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LTXVPreprocess",
            category="image",
            inputs=[
                io.Image.Input("image"),
                io.Int.Input(
                    "img_compression", default=35, min=0, max=100, tooltip="Amount of compression to apply to the image."
                ),
            ],
            outputs=[
                io.Image.Output(display_name="output_image"),
            ],
        )

    @classmethod
    def execute(cls, image, img_compression) -> io.NodeOutput:
        output_images = []
        for i in range(image.shape[0]):
            output_images.append(preprocess(image[i], img_compression))
        return io.NodeOutput(torch.stack(output_images))

    preprocess = execute  # TODO: remove


class LtxvExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            EmptyLTXVLatentVideo,
            LTXVImgToVideo,
            ModelSamplingLTXV,
            LTXVConditioning,
            LTXVScheduler,
            LTXVAddGuide,
            LTXVPreprocess,
            LTXVCropGuides,
        ]


async def comfy_entrypoint() -> LtxvExtension:
    return LtxvExtension()