Merge pull request #58 from AmericanPresidentJimmyCarter/add-i2v-script
i2v_inference.py  ADDED  (+679 -0)
@@ -0,0 +1,679 @@
import os
import time
import argparse
import json
import torch
import traceback
import gc
import random

# These imports rely on your existing code structure
# They must match the location of your WAN code, etc.
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.modules.attention import get_attention_modes
from wan.utils.utils import cache_video
from mmgp import offload, safetensors2, profile_type

try:
    import triton
except ImportError:
    pass

DATA_DIR = "ckpts"

# --------------------------------------------------
# HELPER FUNCTIONS
# --------------------------------------------------

def sanitize_file_name(file_name):
    """Clean up file name from special chars."""
    return (
        file_name.replace("/", "")
        .replace("\\", "")
        .replace(":", "")
        .replace("|", "")
        .replace("?", "")
        .replace("<", "")
        .replace(">", "")
        .replace('"', "")
    )

def extract_preset(lset_name, lora_dir, loras):
    """
    Load a .lset JSON that lists the LoRA files to apply, plus multipliers
    and possibly a suggested prompt prefix.
    """
    lset_name = sanitize_file_name(lset_name)
    if not lset_name.endswith(".lset"):
        lset_name_filename = os.path.join(lora_dir, lset_name + ".lset")
    else:
        lset_name_filename = os.path.join(lora_dir, lset_name)

    if not os.path.isfile(lset_name_filename):
        raise ValueError(f"Preset '{lset_name}' not found in {lora_dir}")

    with open(lset_name_filename, "r", encoding="utf-8") as reader:
        text = reader.read()
    lset = json.loads(text)

    loras_choices_files = lset["loras"]
    loras_choices = []
    missing_loras = []
    for lora_file in loras_choices_files:
        # Build absolute path and see if it is in loras
        full_lora_path = os.path.join(lora_dir, lora_file)
        if full_lora_path in loras:
            idx = loras.index(full_lora_path)
            loras_choices.append(str(idx))
        else:
            missing_loras.append(lora_file)

    if len(missing_loras) > 0:
        missing_list = ", ".join(missing_loras)
        raise ValueError(f"Missing LoRA files for preset: {missing_list}")

    loras_mult_choices = lset["loras_mult"]
    prompt_prefix = lset.get("prompt", "")
    full_prompt = lset.get("full_prompt", False)
    return loras_choices, loras_mult_choices, prompt_prefix, full_prompt
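# For reference, a minimal .lset preset read by extract_preset() could look like
# this (illustrative values; "loras" and "loras_mult" are required, "prompt" and
# "full_prompt" are optional):
#
#   {
#     "loras": ["my_style_lora.safetensors"],
#     "loras_mult": [1.0],
#     "prompt": "in the style of my_style",
#     "full_prompt": false
#   }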

def get_attention_mode(args_attention, installed_modes):
    """
    Decide which attention mode to use: either the user choice or auto fallback.
    """
    if args_attention == "auto":
        for candidate in ["sage2", "sage", "sdpa"]:
            if candidate in installed_modes:
                return candidate
        return "sdpa"  # last fallback
    elif args_attention in installed_modes:
        return args_attention
    else:
        raise ValueError(
            f"Requested attention mode '{args_attention}' not installed. "
            f"Installed modes: {installed_modes}"
        )

def load_i2v_model(model_filename, text_encoder_filename, is_720p):
    """
    Load the i2v model with a specific size config and text encoder.
    """
    if is_720p:
        print("Loading 14B-720p i2v model ...")
        cfg = WAN_CONFIGS['i2v-14B']
        wan_model = wan.WanI2V(
            config=cfg,
            checkpoint_dir=DATA_DIR,
            device_id=0,
            rank=0,
            t5_fsdp=False,
            dit_fsdp=False,
            use_usp=False,
            i2v720p=True,
            model_filename=model_filename,
            text_encoder_filename=text_encoder_filename
        )
    else:
        print("Loading 14B-480p i2v model ...")
        cfg = WAN_CONFIGS['i2v-14B']
        wan_model = wan.WanI2V(
            config=cfg,
            checkpoint_dir=DATA_DIR,
            device_id=0,
            rank=0,
            t5_fsdp=False,
            dit_fsdp=False,
            use_usp=False,
            i2v720p=False,
            model_filename=model_filename,
            text_encoder_filename=text_encoder_filename
        )
    # Pipe structure
    pipe = {
        "transformer": wan_model.model,
        "text_encoder": wan_model.text_encoder.model,
        "text_encoder_2": wan_model.clip.model,
        "vae": wan_model.vae.model
    }
    return wan_model, pipe
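# Note: "pipe" is a plain dict that groups the sub-models (transformer, the two
# text encoders, VAE) under names that offload.profile() in main() uses to decide
# what to quantize and offload.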

def setup_loras(pipe, lora_dir, lora_preset, num_inference_steps):
    """
    Load loras from a directory, optionally apply a preset.
    """
    from pathlib import Path
    import glob

    if not lora_dir or not Path(lora_dir).is_dir():
        print("No valid --lora-dir provided or directory doesn't exist, skipping LoRA setup.")
        return [], [], [], "", "", False

    # Gather LoRA files
    loras = sorted(
        glob.glob(os.path.join(lora_dir, "*.sft"))
        + glob.glob(os.path.join(lora_dir, "*.safetensors"))
    )
    loras_names = [Path(x).stem for x in loras]

    # Offload them with no activation
    offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False)

    # If user gave a preset, apply it
    default_loras_choices = []
    default_loras_multis_str = ""
    default_prompt_prefix = ""
    preset_applied_full_prompt = False
    if lora_preset:
        loras_choices, loras_mult, prefix, full_prompt = extract_preset(lora_preset, lora_dir, loras)
        default_loras_choices = loras_choices
        # If user stored loras_mult as a list or string in JSON, unify that to str
        if isinstance(loras_mult, list):
            # Just store them in a single line
            default_loras_multis_str = " ".join([str(x) for x in loras_mult])
        else:
            default_loras_multis_str = str(loras_mult)
        default_prompt_prefix = prefix
        preset_applied_full_prompt = full_prompt

    return (
        loras,
        loras_names,
        default_loras_choices,
        default_loras_multis_str,
        default_prompt_prefix,
        preset_applied_full_prompt
    )

def parse_loras_and_activate(
    transformer,
    loras,
    loras_choices,
    loras_mult_str,
    num_inference_steps
):
    """
    Activate the chosen LoRAs with multipliers over the pipeline's transformer.
    Supports stepwise expansions (like "0.5,0.8" for partial steps).
    """
    if not loras or not loras_choices:
        # no LoRAs selected
        return

    # Handle multipliers
    def is_float_or_comma_list(x):
        """
        Example: "0.5", or "0.8,1.0", etc. is valid.
        """
        if not x:
            return False
        for chunk in x.split(","):
            try:
                float(chunk.strip())
            except ValueError:
                return False
        return True

    # Convert multiline or spaced lines to a single list
    lines = [
        line.strip()
        for line in loras_mult_str.replace("\r", "\n").split("\n")
        if line.strip() and not line.strip().startswith("#")
    ]
    # Now combine them by space
    joined_line = " ".join(lines)  # "1.0 2.0,3.0"
    if not joined_line.strip():
        multipliers = []
    else:
        multipliers = joined_line.split(" ")

    # Expand each item
    final_multipliers = []
    for mult in multipliers:
        mult = mult.strip()
        if not mult:
            continue
        if is_float_or_comma_list(mult):
            # Could be "0.7" or "0.5,0.6"
            if "," in mult:
                # expand over steps
                chunk_vals = [float(x.strip()) for x in mult.split(",")]
                expanded = expand_list_over_steps(chunk_vals, num_inference_steps)
                final_multipliers.append(expanded)
            else:
                final_multipliers.append(float(mult))
        else:
            raise ValueError(f"Invalid LoRA multiplier: '{mult}'")

    # If fewer multipliers than chosen LoRAs => pad with 1.0
    needed = len(loras_choices) - len(final_multipliers)
    if needed > 0:
        final_multipliers += [1.0]*needed

    # Actually activate them
    offload.activate_loras(transformer, loras_choices, final_multipliers)
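# Illustrative --loras-mult strings and how the parser above treats them:
#   "1.0"         -> first LoRA at strength 1.0 (any further LoRAs are padded to 1.0)
#   "1.0 0.5,0.8" -> first LoRA fixed at 1.0; the second uses 0.5 for the first half
#                    of the denoising steps and 0.8 for the second half (see
#                    expand_list_over_steps below).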

def expand_list_over_steps(short_list, num_steps):
    """
    If user gave (0.5, 0.8) for example, expand them over `num_steps`.
    The expansion is simply linear slice across steps.
    """
    result = []
    inc = len(short_list) / float(num_steps)
    idxf = 0.0
    for _ in range(num_steps):
        value = short_list[int(idxf)]
        result.append(value)
        idxf += inc
    return result
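# Example (illustrative): expand_list_over_steps([0.5, 0.8], 4) returns
# [0.5, 0.5, 0.8, 0.8]; each value is held for an equal share of the steps.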

def download_models_if_needed(transformer_filename_i2v, text_encoder_filename, local_folder=DATA_DIR):
    """
    Checks if all required WAN 2.1 i2v files exist locally under 'ckpts/'.
    If not, downloads them from a Hugging Face Hub repo.
    Adjust the 'repo_id' and needed files as appropriate.
    """
    import os
    from pathlib import Path

    try:
        from huggingface_hub import hf_hub_download, snapshot_download
    except ImportError as e:
        raise ImportError(
            "huggingface_hub is required for automatic model download. "
            "Please install it via `pip install huggingface_hub`."
        ) from e

    # Identify just the filename portion for each path
    def basename(path_str):
        return os.path.basename(path_str)

    repo_id = "DeepBeepMeep/Wan2.1"
    target_root = local_folder

    # You can customize this list as needed for i2v usage.
    # At minimum you need:
    #   1) The requested i2v transformer file
    #   2) The requested text encoder file
    #   3) VAE file
    #   4) The open-clip xlm-roberta-large weights
    #
    # If your i2v config references additional files, add them here.
    needed_files = [
        "Wan2.1_VAE.pth",
        "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
        basename(text_encoder_filename),
        basename(transformer_filename_i2v),
    ]

    # The original script also downloads an entire "xlm-roberta-large" folder
    # via snapshot_download. If you require that for your pipeline,
    # you can add it here, for example:
    subfolder_name = "xlm-roberta-large"
    if not Path(os.path.join(target_root, subfolder_name)).exists():
        snapshot_download(repo_id=repo_id, allow_patterns=subfolder_name + "/*", local_dir=target_root)

    for filename in needed_files:
        local_path = os.path.join(target_root, filename)
        if not os.path.isfile(local_path):
            print(f"File '{filename}' not found locally. Downloading from {repo_id} ...")
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=target_root
            )
        else:
            # Already present
            pass

    print("All required i2v files are present.")


# --------------------------------------------------
# ARGUMENT PARSER
# --------------------------------------------------

def parse_args():
    parser = argparse.ArgumentParser(
        description="Image-to-Video inference using WAN 2.1 i2v"
    )
    # Model + Tools
    parser.add_argument(
        "--quantize-transformer",
        action="store_true",
        help="Use on-the-fly transformer quantization"
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Enable PyTorch 2.0 compile for the transformer"
    )
    parser.add_argument(
        "--attention",
        type=str,
        default="auto",
        help="Which attention to use: auto, sdpa, sage, sage2, flash"
    )
    parser.add_argument(
        "--profile",
        type=int,
        default=4,
        help="Memory usage profile number [1..5]; see original script or use 2 if you have low VRAM"
    )
    parser.add_argument(
        "--preload",
        type=int,
        default=0,
        help="Megabytes of the diffusion model to preload in VRAM (only used in some profiles)"
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="Verbosity level [0..5]"
    )

    # i2v Model
    parser.add_argument(
        "--transformer-file",
        type=str,
        default=f"{DATA_DIR}/wan2.1_image2video_480p_14B_quanto_int8.safetensors",
        help="Which i2v model to load"
    )
    parser.add_argument(
        "--text-encoder-file",
        type=str,
        default=f"{DATA_DIR}/models_t5_umt5-xxl-enc-quanto_int8.safetensors",
        help="Which text encoder to use"
    )

    # LoRA
    parser.add_argument(
        "--lora-dir",
        type=str,
        default="",
        help="Path to a directory containing i2v LoRAs"
    )
    parser.add_argument(
        "--lora-preset",
        type=str,
        default="",
        help="A .lset preset name in the lora_dir to auto-apply"
    )

    # Generation Options
    parser.add_argument("--prompt", type=str, default=None, required=True, help="Prompt for generation")
    parser.add_argument("--negative-prompt", type=str, default="", help="Negative prompt")
    parser.add_argument("--resolution", type=str, default="832x480", help="WxH")
    parser.add_argument("--frames", type=int, default=64, help="Number of frames (16 = 1s at fps=16). WAN expects 4*N+1 frames (e.g. 65, 81); other values are rounded.")
    parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps.")
    parser.add_argument("--guidance-scale", type=float, default=5.0, help="Classifier-free guidance scale")
    parser.add_argument("--flow-shift", type=float, default=3.0, help="Flow shift parameter. Generally 3.0 for 480p, 5.0 for 720p.")
    parser.add_argument("--riflex", action="store_true", help="Enable RIFLEx for longer videos")
    parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
    parser.add_argument("--teacache-start", type=float, default=0.1, help="TeaCache start step percentage [0..100]")
    parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.")

    # LoRA usage
    parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
    parser.add_argument("--loras-mult", type=str, default="", help="Multipliers for each chosen LoRA. Example: '1.0 1.2,1.3' etc.")

    # Input
    parser.add_argument(
        "--input-image",
        type=str,
        default=None,
        required=True,
        help="Path to the input image (a single image is used)."
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default="output.mp4",
        help="Where to save the resulting video."
    )

    return parser.parse_args()
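# Example invocation (illustrative paths; the defaults above expect checkpoints
# under ./ckpts):
#
#   python i2v_inference.py \
#       --prompt "a cat walking through tall grass" \
#       --input-image ./cat.jpg \
#       --resolution 832x480 --frames 65 --steps 30 \
#       --output-file output.mp4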

# --------------------------------------------------
# MAIN
# --------------------------------------------------

def main():
    args = parse_args()

    # Setup environment
    offload.default_verboseLevel = args.verbose
    installed_attn_modes = get_attention_modes()

    # Decide attention
    chosen_attention = get_attention_mode(args.attention, installed_attn_modes)
    offload.shared_state["_attention"] = chosen_attention

    # Determine i2v resolution format
    if "720" in args.transformer_file:
        is_720p = True
    else:
        is_720p = False

    # Make sure we have the needed models locally
    download_models_if_needed(args.transformer_file, args.text_encoder_file)

    # Load i2v
    wan_model, pipe = load_i2v_model(
        model_filename=args.transformer_file,
        text_encoder_filename=args.text_encoder_file,
        is_720p=is_720p
    )
    wan_model._interrupt = False

    # Offload / profile
    # e.g. for your script: offload.profile(pipe, profile_no=args.profile, compile=..., quantizeTransformer=...)
    # pass the budgets if you want, etc.
    kwargs = {}
    if args.profile == 2 or args.profile == 4:
        # preload is in MB
        if args.preload == 0:
            budgets = {"transformer": 100, "text_encoder": 100, "*": 1000}
        else:
            budgets = {"transformer": args.preload, "text_encoder": 100, "*": 1000}
        kwargs["budgets"] = budgets
    elif args.profile == 3:
        kwargs["budgets"] = {"*": "70%"}

    compile_choice = "transformer" if args.compile else ""
    # Create the offload object
    offloadobj = offload.profile(
        pipe,
        profile_no=args.profile,
        compile=compile_choice,
        quantizeTransformer=args.quantize_transformer,
        **kwargs
    )

    # If user wants to use LoRAs
    (
        loras,
        loras_names,
        default_loras_choices,
        default_loras_multis_str,
        preset_prompt_prefix,
        preset_full_prompt
    ) = setup_loras(pipe, args.lora_dir, args.lora_preset, args.steps)

    # Combine user prompt with preset prompt if the preset indicates so
    if preset_prompt_prefix:
        if preset_full_prompt:
            # Full override
            user_prompt = preset_prompt_prefix
        else:
            # Just prefix
            user_prompt = preset_prompt_prefix + "\n" + args.prompt
    else:
        user_prompt = args.prompt

    # Actually parse user LoRA choices if they did not rely purely on the preset
    if args.loras_choices:
        # If user gave e.g. "0,1", we treat that as new additions
        lora_choice_list = [x.strip() for x in args.loras_choices.split(",")]
    else:
        # Use the defaults from the preset
        lora_choice_list = default_loras_choices

    # Activate them
    parse_loras_and_activate(
        pipe["transformer"], loras, lora_choice_list, args.loras_mult or default_loras_multis_str, args.steps
    )

    # Negative prompt
    negative_prompt = args.negative_prompt or ""

    # Sanity check resolution
    if "*" in args.resolution.lower():
        print("WARNING: resolution should be given as 832x480, not '832*480'. Auto-correcting.")
        resolution_str = args.resolution.lower().replace("*", "x")
    else:
        resolution_str = args.resolution

    try:
        width, height = [int(x) for x in resolution_str.split("x")]
    except ValueError:
        raise ValueError(f"Invalid resolution: '{resolution_str}'")

    # Additional checks (from your original code).
    if "480p" in args.transformer_file:
        # Then we cannot exceed certain area for 480p model
        if width * height > 832*480:
            raise ValueError("You must use the 720p i2v model to generate bigger than 832x480.")
    # etc.

    # Handle random seed
    if args.seed < 0:
        args.seed = random.randint(0, 999999999)
    print(f"Using seed={args.seed}")

    # Setup tea cache if needed
    trans = wan_model.model
    trans.enable_teacache = (args.teacache > 0)
    if trans.enable_teacache:
        if "480p" in args.transformer_file:
            # example from your code
            trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
        elif "720p" in args.transformer_file:
            trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
        else:
            raise ValueError("Teacache not supported for this model variant")

    # Attempt generation
    print("Starting generation ...")
    start_time = time.time()

    # Read the input image
    if not os.path.isfile(args.input_image):
        raise ValueError(f"Input image does not exist: {args.input_image}")

    from PIL import Image
    input_img = Image.open(args.input_image).convert("RGB")

    # Possibly load more than one image if you want "multiple images" – but here we'll just do single for demonstration

    # Define the generation call
    #  - frames => must be multiple of 4 plus 1 as per original script's note, e.g. 81, 65, ...
    #    You can correct to that if needed:
    frame_count = (args.frames // 4)*4 + 1  # ensures it's 4*N+1
    # RIFLEx
    enable_riflex = args.riflex

    # If teacache => reset counters
    if trans.enable_teacache:
        trans.teacache_counter = 0
        trans.teacache_multiplier = args.teacache
        trans.teacache_start_step = int(args.teacache_start * args.steps / 100.0)
        trans.num_steps = args.steps
        trans.teacache_skipped_steps = 0
        trans.previous_residual_uncond = None
        trans.previous_residual_cond = None

    # VAE Tiling
    device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
    if device_mem_capacity >= 28000:  # 81 frames 720p requires about 28 GB VRAM
        use_vae_config = 1
    elif device_mem_capacity >= 8000:
        use_vae_config = 2
    else:
        use_vae_config = 3

    if use_vae_config == 1:
        VAE_tile_size = 0
    elif use_vae_config == 2:
        VAE_tile_size = 256
    else:
        VAE_tile_size = 128

    print('Using VAE tile size of', VAE_tile_size)
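    # Note (interpretation): the branch above picks the VAE tile size from total VRAM.
    # A value of 0 is taken to mean "no tiling" on large-VRAM GPUs, while 256/128-pixel
    # tiles trade decode speed for lower peak memory on smaller GPUs.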

    # Actually run the i2v generation
    try:
        sample_frames = wan_model.generate(
            user_prompt,
            input_img,
            frame_num=frame_count,
            max_area=MAX_AREA_CONFIGS[f"{width}*{height}"],  # or you can pass your custom
            shift=args.flow_shift,
            sampling_steps=args.steps,
            guide_scale=args.guidance_scale,
            n_prompt=negative_prompt,
            seed=args.seed,
            offload_model=False,
            callback=None,  # or define your own callback if you want
            enable_RIFLEx=enable_riflex,
            VAE_tile_size=VAE_tile_size,
        )
    except Exception as e:
        offloadobj.unload_all()
        gc.collect()
        torch.cuda.empty_cache()

        err_str = f"Generation failed with error: {e}"
        # Attempt to detect OOM errors
        s = str(e).lower()
        if any(keyword in s for keyword in ["memory", "cuda", "alloc"]):
            raise RuntimeError("Likely out-of-VRAM or out-of-RAM error. " + err_str)
        else:
            traceback.print_exc()
            raise RuntimeError(err_str)

    # After generation
    offloadobj.unload_all()
    gc.collect()
    torch.cuda.empty_cache()

    if sample_frames is None:
        raise RuntimeError("No frames were returned (maybe generation was aborted or failed).")

    # If teacache was used, we can see how many steps were skipped
    if trans.enable_teacache:
        print(f"TeaCache skipped steps: {trans.teacache_skipped_steps} / {args.steps}")

    # Save result
    sample_frames = sample_frames.cpu()  # shape = c, t, h, w => [3, T, H, W]
    os.makedirs(os.path.dirname(args.output_file) or ".", exist_ok=True)

    # Use the provided helper from your code to store the MP4
    # By default, you used cache_video(tensor=..., save_file=..., fps=16, ...)
    # or you can do your own. We'll do the same for consistency:
    cache_video(
        tensor=sample_frames[None],  # shape => [1, c, T, H, W]
        save_file=args.output_file,
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1)
    )

    end_time = time.time()
    elapsed_s = end_time - start_time
    print(f"Done! Output written to {args.output_file}. Generation time: {elapsed_s:.1f} seconds.")

if __name__ == "__main__":
    main()