# WIP: Optimizations are coming!
import argparse
import os
import random
import sys
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
import safetensors.torch
import torch
import torchvision.transforms.v2 as transforms
from diffusers import HunyuanVideoPipeline
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.loaders import HunyuanVideoLoraLoaderMixin
from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
from diffusers.models.attention import Attention
from diffusers.models.embeddings import apply_rotary_emb
from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoPatchEmbed
from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE, retrieve_timesteps
from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import export_to_video, is_torch_xla_available, load_image, logging, replace_example_docstring
from diffusers.utils.state_dict_utils import convert_state_dict_to_diffusers, convert_unet_state_dict_to_peft
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
from PIL import Image
# 20250305 pftq: load settings for customization ####
parser = argparse.ArgumentParser()
parser.add_argument("--base_model_id", type=str, default="hunyuanvideo-community/HunyuanVideo")
parser.add_argument("--transformer_model_id", type=str, default="hunyuanvideo-community/HunyuanVideo")
parser.add_argument("--lora_path", type=str, default="i2v.sft")
parser.add_argument("--use_sage", action="store_true")
parser.add_argument("--use_flash", action="store_true")
parser.add_argument("--cfg", type=float, default=6.0)
parser.add_argument("--num_frames", type=int, default=77)
parser.add_argument("--steps", type=int, default=50)
parser.add_argument("--seed", type=int, default=-1)
parser.add_argument("--prompt", type=str, default="a woman")
parser.add_argument("--height", type=int, default=1280)
parser.add_argument("--width", type=int, default=720)
parser.add_argument("--video_num", type=int, default=1)
parser.add_argument("--image1", type=str, default="https://content.dashtoon.ai/stability-images/e524013d-55d4-483a-b80a-dfc51d639158.png")
parser.add_argument("--image2", type=str, default="https://content.dashtoon.ai/stability-images/0b29c296-0a90-4b92-96b9-1ed0ae21e480.png")
parser.add_argument("--image3", type=str, default="")
parser.add_argument("--image4", type=str, default="")
parser.add_argument("--image5", type=str, default="")
parser.add_argument("--fps", type=int, default=24)
parser.add_argument("--mbps", type=float, default=7)
parser.add_argument("--color_match", action="store_true")
args = parser.parse_args()
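
# Illustrative example invocation (the script filename "hv-i2v.py" and the image paths are
# assumptions; adjust to your setup):
#   python hv-i2v.py --prompt "a woman walking on the beach" \
#       --image1 start.png --image2 end.png \
#       --width 720 --height 1280 --num_frames 77 --steps 50 --use_sage --color_match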
# 20250305 pftq: from main repo at https://github.com/dashtoon/hunyuan-video-keyframe-control-lora/blob/main/hv_control_lora_inference.py
use_sage = False
use_flash = False
if args.use_sage:
    try:
        from sageattention import sageattn, sageattn_varlen
        use_sage = True
    except ImportError:
        sageattn, sageattn_varlen = None, None
elif args.use_flash:
    try:
        import flash_attn
        from flash_attn.flash_attn_interface import _flash_attn_forward, flash_attn_varlen_func
        use_flash = True
    except ImportError:
        flash_attn, _flash_attn_forward, flash_attn_varlen_func = None, None, None
print("Using SageAttention: " + str(use_sage))
print("Using FlashAttention: " + str(use_flash))
video_transforms = transforms.Compose(
    [
        transforms.Lambda(lambda x: x / 255.0),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ]
)


def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
    """
    Resize the image to the bucket resolution: scale so the bucket is fully covered, then center-crop.
    """
    is_pil_image = isinstance(image, Image.Image)
    if is_pil_image:
        image_width, image_height = image.size
    else:
        image_height, image_width = image.shape[:2]
    if bucket_reso == (image_width, image_height):
        return np.array(image) if is_pil_image else image
    bucket_width, bucket_height = bucket_reso
    scale_width = bucket_width / image_width
    scale_height = bucket_height / image_height
    scale = max(scale_width, scale_height)
    image_width = int(image_width * scale + 0.5)
    image_height = int(image_height * scale + 0.5)
    if scale > 1:
        image = Image.fromarray(image) if not is_pil_image else image
        image = image.resize((image_width, image_height), Image.LANCZOS)
        image = np.array(image)
    else:
        image = np.array(image) if is_pil_image else image
        image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA)
    # crop the image to the bucket resolution
    crop_left = (image_width - bucket_width) // 2
    crop_top = (image_height - bucket_height) // 2
    image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
    return image
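
# Example of the cover-then-crop behavior above (numbers are for illustration only):
# a 1920x1080 input with bucket_reso=(720, 1280) is scaled by max(720/1920, 1280/1080) ~ 1.185
# to roughly 2276x1280, then center-cropped to exactly 720x1280.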
# 20250305 pftq: from main repo at https://github.com/dashtoon/hunyuan-video-keyframe-control-lora/blob/main/hv_control_lora_inference.py
def get_cu_seqlens(attention_mask):
    """Calculate cu_seqlens_q, cu_seqlens_kv using attention_mask"""
    batch_size = attention_mask.shape[0]
    text_len = attention_mask.sum(dim=-1, dtype=torch.int)
    max_len = attention_mask.shape[-1]
    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
    for i in range(batch_size):
        s = text_len[i]
        s1 = i * max_len + s
        s2 = (i + 1) * max_len
        cu_seqlens[2 * i + 1] = s1
        cu_seqlens[2 * i + 2] = s2
    return cu_seqlens
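
# Illustrative layout (assumed numbers): for a batch of 2 with max_len=8 and 5 / 3 valid tokens,
# cu_seqlens comes out as [0, 5, 8, 11, 16]: each sample contributes a "valid tokens" segment
# followed by a "padding" segment, which is the packed format the varlen attention kernels below consume.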
class HunyuanVideoFlashAttnProcessor:
    def __init__(self, use_flash_attn=True, use_sageattn=False):
        self.use_flash_attn = use_flash_attn
        self.use_sageattn = use_sageattn
        if self.use_flash_attn:
            assert flash_attn is not None, "Flash attention not available"
        if self.use_sageattn:
            assert sageattn is not None, "Sage attention not available"

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
        # Single-stream blocks have no add_q_proj; their text tokens are appended to the image tokens.
        if attn.add_q_proj is None and encoder_hidden_states is not None:
            hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)
        if image_rotary_emb is not None:
            if attn.add_q_proj is None and encoder_hidden_states is not None:
                # Rotary embeddings apply only to the image tokens, not the appended text tokens.
                query = torch.cat(
                    [
                        apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
                        query[:, :, -encoder_hidden_states.shape[1] :],
                    ],
                    dim=2,
                )
                key = torch.cat(
                    [
                        apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
                        key[:, :, -encoder_hidden_states.shape[1] :],
                    ],
                    dim=2,
                )
            else:
                query = apply_rotary_emb(query, image_rotary_emb)
                key = apply_rotary_emb(key, image_rotary_emb)
        batch_size = hidden_states.shape[0]
        img_seq_len = hidden_states.shape[1]
        txt_seq_len = 0
        if attn.add_q_proj is not None and encoder_hidden_states is not None:
            # Dual-stream blocks project the text tokens separately, then concatenate along the sequence dim.
            encoder_query = attn.add_q_proj(encoder_hidden_states)
            encoder_key = attn.add_k_proj(encoder_hidden_states)
            encoder_value = attn.add_v_proj(encoder_hidden_states)
            encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            if attn.norm_added_q is not None:
                encoder_query = attn.norm_added_q(encoder_query)
            if attn.norm_added_k is not None:
                encoder_key = attn.norm_added_k(encoder_key)
            query = torch.cat([query, encoder_query], dim=2)
            key = torch.cat([key, encoder_key], dim=2)
            value = torch.cat([value, encoder_value], dim=2)
            txt_seq_len = encoder_hidden_states.shape[1]
        max_seqlen_q = max_seqlen_kv = img_seq_len + txt_seq_len
        cu_seqlens_q = cu_seqlens_kv = get_cu_seqlens(attention_mask)
        # Pack [B, H, S, D] into the (total_tokens, heads, head_dim) layout the varlen kernels expect.
        query = query.transpose(1, 2).reshape(-1, query.shape[1], query.shape[3])
        key = key.transpose(1, 2).reshape(-1, key.shape[1], key.shape[3])
        value = value.transpose(1, 2).reshape(-1, value.shape[1], value.shape[3])
        if self.use_flash_attn:
            hidden_states = flash_attn_varlen_func(
                query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
            )
        elif self.use_sageattn:
            hidden_states = sageattn_varlen(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        else:
            raise NotImplementedError("Please set use_flash_attn=True or use_sageattn=True")
        hidden_states = hidden_states.reshape(batch_size, max_seqlen_q, -1)
        hidden_states = hidden_states.to(query.dtype)
        if encoder_hidden_states is not None:
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : -encoder_hidden_states.shape[1]],
                hidden_states[:, -encoder_hidden_states.shape[1] :],
            )
        if getattr(attn, "to_out", None) is not None:
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)
        if getattr(attn, "to_add_out", None) is not None:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
        return hidden_states, encoder_hidden_states
def call_pipe(
    pipe,
    prompt: Union[str, List[str]] = None,
    prompt_2: Union[str, List[str]] = None,
    height: int = 720,
    width: int = 1280,
    num_frames: int = 129,
    num_inference_steps: int = 50,
    sigmas: List[float] = None,
    guidance_scale: float = 6.0,
    num_videos_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
    max_sequence_length: int = 256,
    image_latents: Optional[torch.Tensor] = None,
):
    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

    # 1. Check inputs. Raise error if not correct
    pipe.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds,
        callback_on_step_end_tensor_inputs,
        prompt_template,
    )
    pipe._guidance_scale = guidance_scale
    pipe._attention_kwargs = attention_kwargs
    pipe._current_timestep = None
    pipe._interrupt = False
    device = pipe._execution_device

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # 3. Encode input prompt
    prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_template=prompt_template,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        prompt_attention_mask=prompt_attention_mask,
        device=device,
        max_sequence_length=max_sequence_length,
    )
    transformer_dtype = pipe.transformer.dtype
    prompt_embeds = prompt_embeds.to(transformer_dtype)
    prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
    if pooled_prompt_embeds is not None:
        pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)

    # 4. Prepare timesteps
    sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
    timesteps, num_inference_steps = retrieve_timesteps(
        pipe.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
    )

    # 5. Prepare latent variables
    num_channels_latents = pipe.transformer.config.in_channels
    num_latent_frames = (num_frames - 1) // pipe.vae_scale_factor_temporal + 1
    latents = pipe.prepare_latents(
        batch_size * num_videos_per_prompt,
        num_channels_latents,
        height,
        width,
        num_latent_frames,
        torch.float32,
        device,
        generator,
        latents,
    )

    # 6. Prepare guidance condition
    guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order
    pipe._num_timesteps = len(timesteps)

    # 20250305 pftq: added to properly offload to CPU, was out of memory otherwise
    pipe.text_encoder.to("cpu")
    pipe.text_encoder_2.to("cpu")
    torch.cuda.empty_cache()

    with pipe.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if pipe.interrupt:
                continue
            pipe._current_timestep = t
            latent_model_input = latents.to(transformer_dtype)
            timestep = t.expand(latents.shape[0]).to(latents.dtype)
            noise_pred = pipe.transformer(
                hidden_states=torch.cat([latent_model_input, image_latents], dim=1),
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                encoder_attention_mask=prompt_attention_mask,
                pooled_projections=pooled_prompt_embeds,
                guidance=guidance,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]
            # compute the previous noisy sample x_t -> x_t-1
            latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            # call the callback, if provided
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(pipe, i, t, callback_kwargs)
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0):
                progress_bar.update()

    pipe._current_timestep = None
    if output_type != "latent":
        latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
        video = pipe.vae.decode(latents, return_dict=False)[0]
        video = pipe.video_processor.postprocess_video(video, output_type=output_type)
    else:
        video = latents

    # Offload all models
    pipe.maybe_free_model_hooks()
    if not return_dict:
        return (video,)
    return HunyuanVideoPipelineOutput(frames=video)
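
# Note on call_pipe above: it follows HunyuanVideoPipeline.__call__ step for step, but at every
# denoising step it concatenates the caller-supplied image_latents with the noisy latents along
# the channel dimension (dim=1). The doubled-in_chans x_embedder installed further below is what
# lets the transformer accept this widened input.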
# 20250305 pftq: customizable bitrate
# Function to check if FFmpeg is installed
import subprocess  # For FFmpeg functionality


def is_ffmpeg_installed():
    try:
        subprocess.run(["ffmpeg", "-version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False
# FFmpeg-based video saving with bitrate control
def save_video_with_ffmpeg(frames, output_path, fps, bitrate_mbps, metadata_comment=None):
    frames = [np.array(frame) for frame in frames]
    height, width, _ = frames[0].shape
    bitrate = f"{bitrate_mbps}M"
    cmd = [
        "ffmpeg",
        "-y",
        "-f", "rawvideo",
        "-vcodec", "rawvideo",
        "-s", f"{width}x{height}",
        "-pix_fmt", "rgb24",
        "-r", str(fps),
        "-i", "-",
        "-c:v", "libx264",
        "-b:v", bitrate,
        "-pix_fmt", "yuv420p",
        "-preset", "medium",
    ]
    # Add metadata comment if provided
    if metadata_comment:
        cmd.extend(["-metadata", f"comment={metadata_comment}"])
    cmd.append(output_path)
    process = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE)
    for frame in frames:
        process.stdin.write(frame.tobytes())
    process.stdin.close()
    process.wait()
    stderr_output = process.stderr.read().decode()
    if process.returncode != 0:
        print(f"FFmpeg error: {stderr_output}")
    else:
        print(f"Video saved to {output_path} with FFmpeg")
# Fallback OpenCV-based video saving
def save_video_with_opencv(frames, output_path, fps, bitrate_mbps):
    frames = [np.array(frame) for frame in frames]
    height, width, _ = frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    # Note: cv2.VideoWriter offers no bitrate control, so bitrate_mbps is ignored
    for frame in frames:
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)  # Convert RGB to BGR for OpenCV
        writer.write(frame)
    writer.release()
    print(f"Video saved to {output_path} with OpenCV (bitrate control unavailable)")


# Wrapper to choose between FFmpeg and OpenCV
def save_video_with_quality(frames, output_path, fps, bitrate_mbps, metadata_comment=None):
    if is_ffmpeg_installed():
        save_video_with_ffmpeg(frames, output_path, fps, bitrate_mbps, metadata_comment)
    else:
        print("FFmpeg not found. Falling back to OpenCV (bitrate not customizable).")
        save_video_with_opencv(frames, output_path, fps, bitrate_mbps)
# Reconstruct command-line with quotes and backslash+linebreak after argument-value pairs
def reconstruct_command_line(args, argv):
    cmd_parts = [argv[0]]  # Start with script name
    args_dict = vars(args)  # Convert args to dict
    i = 1
    while i < len(argv):
        arg = argv[i]
        if arg.startswith("--"):
            key = arg[2:]
            if key in args_dict:
                value = args_dict[key]
                if isinstance(value, bool):
                    if value:
                        cmd_parts.append(arg)  # Boolean flag
                    i += 1
                else:
                    # Combine argument and value into one part (quote strings, leave numbers bare)
                    if isinstance(value, str):
                        cmd_parts.append(f'{arg} "{value}"')
                    else:
                        cmd_parts.append(f"{arg} {value}")
                    # Skip the value token too if it was actually present in argv
                    if i + 1 < len(argv) and not argv[i + 1].startswith("--"):
                        i += 2
                    else:
                        i += 1
            else:
                i += 1  # Unknown flag; skip it to avoid looping forever
        else:
            i += 1
    # Build multi-line string with backslash and newline except for the last part
    if len(cmd_parts) > 1:
        result = ""
        for j, part in enumerate(cmd_parts):
            if j < len(cmd_parts) - 1:
                result += part + " \\\n"
            else:
                result += part  # No trailing backslash on last part
        return result
    return cmd_parts[0]  # Single arg case
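
# Illustrative output (script name and arguments assumed): an invocation like
#   python hv-i2v.py --prompt "a woman" --steps 30 --use_sage
# is reconstructed as the multi-line string
#   hv-i2v.py \
#   --prompt "a woman" \
#   --steps 30 \
#   --use_sage
# which is later embedded in the mp4 metadata comment.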
# start executing here ###################
print("Initializing model...")
transformer_subfolder = "transformer"
if args.transformer_model_id == "Skywork/SkyReels-V1-Hunyuan-I2V":
    transformer_subfolder = ""  # 20250305 pftq: Error otherwise - Skywork/SkyReels-V1-Hunyuan-I2V does not appear to have a file named config.json.
transformer = HunyuanVideoTransformer3DModel.from_pretrained(args.transformer_model_id, subfolder=transformer_subfolder, torch_dtype=torch.bfloat16)
pipe = HunyuanVideoPipeline.from_pretrained(args.base_model_id, transformer=transformer, torch_dtype=torch.bfloat16)

# Enable memory savings
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
pipe.enable_model_cpu_offload()

# Apply flash/sage attention to all transformer blocks
if use_sage or use_flash:
    for block in pipe.transformer.transformer_blocks + pipe.transformer.single_transformer_blocks:
        block.attn.processor = HunyuanVideoFlashAttnProcessor(use_flash_attn=use_flash, use_sageattn=use_sage)

with torch.no_grad():  # enable image inputs
    initial_input_channels = pipe.transformer.config.in_channels
    new_img_in = HunyuanVideoPatchEmbed(
        patch_size=(pipe.transformer.config.patch_size_t, pipe.transformer.config.patch_size, pipe.transformer.config.patch_size),
        in_chans=pipe.transformer.config.in_channels * 2,
        embed_dim=pipe.transformer.config.num_attention_heads * pipe.transformer.config.attention_head_dim,
    )
    new_img_in = new_img_in.to(pipe.device, dtype=pipe.dtype)
    # Zero-init the widened projection, then copy the pretrained weights into the first half of the
    # input channels so the extra conditioning channels start out as a no-op.
    new_img_in.proj.weight.zero_()
    new_img_in.proj.weight[:, :initial_input_channels].copy_(pipe.transformer.x_embedder.proj.weight)
    if pipe.transformer.x_embedder.proj.bias is not None:
        new_img_in.proj.bias.copy_(pipe.transformer.x_embedder.proj.bias)
    pipe.transformer.x_embedder = new_img_in

print("Loading lora...")
lora_state_dict = pipe.lora_state_dict(args.lora_path)
transformer_lora_state_dict = {f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") and "lora" in k}
pipe.load_lora_into_transformer(transformer_lora_state_dict, transformer=pipe.transformer, adapter_name="i2v", _pipeline=pipe)
pipe.set_adapters(["i2v"], adapter_weights=[1.0])
pipe.fuse_lora(components=["transformer"], lora_scale=1.0, adapter_names=["i2v"])
pipe.unload_lora_weights()
| print("Loading images...") | |
| cond_frame1 = load_image(args.image1) | |
| cond_frame2 = load_image(args.image2) | |
| cond_frame1 = resize_image_to_bucket(cond_frame1, bucket_reso=(args.width, args.height)) | |
| cond_frame2 = resize_image_to_bucket(cond_frame2, bucket_reso=(args.width, args.height)) | |
| cond_video = np.zeros(shape=(args.num_frames, args.height, args.width, 3)) | |
| # 20250305 pftq: Optional 3rd-5th frame, sadly doesn't work so easily, needs more code | |
| cond_frame3 = None | |
| cond_frame4 = None | |
| cond_frame5 = None | |
| if args.image3 != "": | |
| cond_frame3 = load_image(args.image3) | |
| cond_frame3 = resize_image_to_bucket(cond_frame3, bucket_reso=(args.width, args.height)) | |
| if args.image4 !="": | |
| cond_frame4 = load_image(args.image4) | |
| cond_frame4 = resize_image_to_bucket(cond_frame4, bucket_reso=(args.width, args.height)) | |
| if args.image5 !="": | |
| cond_frame5 = load_image(args.image5) | |
| cond_frame5 = resize_image_to_bucket(cond_frame5, bucket_reso=(args.width, args.height)) | |
| if args.image5 != "" and args.image4 != "" and args.image3 !="" and args.image2 !="": | |
| cond_video[0] = np.array(cond_frame1) | |
| cond_video[args.num_frames//4] = np.array(cond_frame2) | |
| cond_video[(args.num_frames * 2 )//4] = np.array(cond_frame3) | |
| cond_video[(args.num_frames * 3 )//4] = np.array(cond_frame4) | |
| cond_video[args.num_frames -1] = np.array(cond_frame5) | |
| elif args.image4 != "" and args.image3 !="" and args.image2 !="": | |
| cond_video[0] = np.array(cond_frame1) | |
| cond_video[args.num_frames//3] = np.array(cond_frame2) | |
| cond_video[(args.num_frames * 2 )//3] = np.array(cond_frame3) | |
| cond_video[args.num_frames -1] = np.array(cond_frame4) | |
| elif args.image3 != "" and args.image2 !="": | |
| cond_video[0] = np.array(cond_frame1) | |
| cond_video[args.num_frames//2] = np.array(cond_frame2) | |
| cond_video[args.num_frames -1] = np.array(cond_frame3) | |
| else: | |
| cond_video[0] = np.array(cond_frame1) | |
| cond_video[args.num_frames -1] = np.array(cond_frame2) | |
| cond_video = torch.from_numpy(cond_video.copy()).permute(0, 3, 1, 2) | |
| cond_video = torch.stack([video_transforms(x) for x in cond_video], dim=0).unsqueeze(0) | |
| with torch.no_grad(): | |
| image_or_video = cond_video.to(device="cuda", dtype=pipe.dtype) | |
| image_or_video = image_or_video.permute(0, 2, 1, 3, 4).contiguous() # [B, F, C, H, W] -> [B, C, F, H, W] | |
| cond_latents = pipe.vae.encode(image_or_video).latent_dist.sample() | |
| cond_latents = cond_latents * pipe.vae.config.scaling_factor | |
| cond_latents = cond_latents.to(dtype=pipe.dtype) | |
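
# Note (assuming the standard HunyuanVideo causal VAE, whose temporal compression factor is
# pipe.vae_scale_factor_temporal = 4): cond_latents should contain
# (num_frames - 1) // pipe.vae_scale_factor_temporal + 1 latent frames, matching the noise
# latents that call_pipe allocates, so the two can be concatenated channel-wise.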
for idx in range(args.video_num):  # 20250305 pftq: for loop for multiple videos per batch with varying seeds
    if args.seed == -1 or idx > 0:  # 20250305 pftq: seed argument ignored if asking for more than one video
        random.seed(time.time())
        args.seed = int(random.randrange(4294967294))

    # 20250223 pftq: More useful filename and higher customizable bitrate
    from datetime import datetime
    now = datetime.now()
    formatted_time = now.strftime('%Y-%m-%d_%H-%M-%S')
    video_out_file = formatted_time + f"_hunyuankeyframe_{args.width}-{args.num_frames}f_cfg-{args.cfg}_steps-{args.steps}_seed-{args.seed}_{args.prompt[:40].replace('/','')}_{idx}"

    command_line = reconstruct_command_line(args, sys.argv)  # 20250307: Store the full command-line used in the mp4 comment with quotes
    #print(f"Command-line received:\n{command_line}")

    print("Starting video generation #" + str(idx) + " for " + video_out_file)
    video = call_pipe(
        pipe,
        prompt=args.prompt,
        num_frames=args.num_frames,
        num_inference_steps=args.steps,
        image_latents=cond_latents,
        width=args.width,
        height=args.height,
        guidance_scale=args.cfg,
        generator=torch.Generator(device="cuda").manual_seed(args.seed),
    ).frames[0]

    # 20250305 pftq: Color match with direct MKL and temporal smoothing
    if args.color_match:
        #save_video_with_quality(video, f"{video_out_file}_raw.mp4", args.fps, args.mbps)
        print("Applying color matching to video...")
        from color_matcher import ColorMatcher
        from color_matcher.io_handler import load_img_file
        from color_matcher.normalizer import Normalizer

        # Load the reference image (image1)
        ref_img = load_img_file(args.image1)  # Original load
        cm = ColorMatcher()
        matched_video = []
        for frame in video:
            frame_rgb = np.array(frame)  # Direct PIL to numpy
            matched_frame = cm.transfer(src=frame_rgb, ref=ref_img, method='mkl')
            matched_frame = Normalizer(matched_frame).uint8_norm()
            matched_video.append(matched_frame)
        video = matched_video
        # END OF COLOR MATCHING

    print("Saving " + video_out_file)
    #export_to_video(final_video, "output.mp4", fps=24)
    save_video_with_quality(video, f"{video_out_file}.mp4", args.fps, args.mbps, command_line)