deepbeepmeep committed
Commit 92f2b6e (unverified)
2 Parent(s): 01cce3d 828d140

Merge pull request #58 from AmericanPresidentJimmyCarter/add-i2v-script

Files changed (1):
  1. i2v_inference.py  +679 -0
i2v_inference.py ADDED
@@ -0,0 +1,679 @@
import os
import time
import argparse
import json
import torch
import traceback
import gc
import random

# These imports rely on your existing code structure
# They must match the location of your WAN code, etc.
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.modules.attention import get_attention_modes
from wan.utils.utils import cache_video
from mmgp import offload, safetensors2, profile_type

try:
    import triton
except ImportError:
    pass

DATA_DIR = "ckpts"

# --------------------------------------------------
# HELPER FUNCTIONS
# --------------------------------------------------

def sanitize_file_name(file_name):
    """Clean up file name from special chars."""
    return (
        file_name.replace("/", "")
        .replace("\\", "")
        .replace(":", "")
        .replace("|", "")
        .replace("?", "")
        .replace("<", "")
        .replace(">", "")
        .replace('"', "")
    )

def extract_preset(lset_name, lora_dir, loras):
    """
    Load a .lset JSON that lists the LoRA files to apply, plus multipliers
    and possibly a suggested prompt prefix.
    """
    lset_name = sanitize_file_name(lset_name)
    if not lset_name.endswith(".lset"):
        lset_name_filename = os.path.join(lora_dir, lset_name + ".lset")
    else:
        lset_name_filename = os.path.join(lora_dir, lset_name)

    if not os.path.isfile(lset_name_filename):
        raise ValueError(f"Preset '{lset_name}' not found in {lora_dir}")

    with open(lset_name_filename, "r", encoding="utf-8") as reader:
        text = reader.read()
    lset = json.loads(text)

    loras_choices_files = lset["loras"]
    loras_choices = []
    missing_loras = []
    for lora_file in loras_choices_files:
        # Build absolute path and see if it is in loras
        full_lora_path = os.path.join(lora_dir, lora_file)
        if full_lora_path in loras:
            idx = loras.index(full_lora_path)
            loras_choices.append(str(idx))
        else:
            missing_loras.append(lora_file)

    if len(missing_loras) > 0:
        missing_list = ", ".join(missing_loras)
        raise ValueError(f"Missing LoRA files for preset: {missing_list}")

    loras_mult_choices = lset["loras_mult"]
    prompt_prefix = lset.get("prompt", "")
    full_prompt = lset.get("full_prompt", False)
    return loras_choices, loras_mult_choices, prompt_prefix, full_prompt

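# For reference, extract_preset() above reads the keys "loras", "loras_mult", "prompt"
# and "full_prompt" from the .lset JSON. An illustrative (hypothetical) preset file,
# with made-up LoRA names, could therefore look like:
#   {
#     "loras": ["my_style_lora.safetensors"],
#     "loras_mult": "1.0",
#     "prompt": "in the style of my_style",
#     "full_prompt": false
#   }
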
def get_attention_mode(args_attention, installed_modes):
    """
    Decide which attention mode to use: either the user choice or auto fallback.
    """
    if args_attention == "auto":
        for candidate in ["sage2", "sage", "sdpa"]:
            if candidate in installed_modes:
                return candidate
        return "sdpa"  # last fallback
    elif args_attention in installed_modes:
        return args_attention
    else:
        raise ValueError(
            f"Requested attention mode '{args_attention}' not installed. "
            f"Installed modes: {installed_modes}"
        )

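# For example, get_attention_mode("auto", ["sdpa", "sage"]) returns "sage" ("sage2" is
# preferred but not installed), while get_attention_mode("flash", ["sdpa"]) raises
# ValueError because the requested mode is not available.
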
def load_i2v_model(model_filename, text_encoder_filename, is_720p):
    """
    Load the i2v model with a specific size config and text encoder.
    """
    if is_720p:
        print("Loading 14B-720p i2v model ...")
        cfg = WAN_CONFIGS['i2v-14B']
        wan_model = wan.WanI2V(
            config=cfg,
            checkpoint_dir=DATA_DIR,
            device_id=0,
            rank=0,
            t5_fsdp=False,
            dit_fsdp=False,
            use_usp=False,
            i2v720p=True,
            model_filename=model_filename,
            text_encoder_filename=text_encoder_filename
        )
    else:
        print("Loading 14B-480p i2v model ...")
        cfg = WAN_CONFIGS['i2v-14B']
        wan_model = wan.WanI2V(
            config=cfg,
            checkpoint_dir=DATA_DIR,
            device_id=0,
            rank=0,
            t5_fsdp=False,
            dit_fsdp=False,
            use_usp=False,
            i2v720p=False,
            model_filename=model_filename,
            text_encoder_filename=text_encoder_filename
        )
    # Pipe structure
    pipe = {
        "transformer": wan_model.model,
        "text_encoder": wan_model.text_encoder.model,
        "text_encoder_2": wan_model.clip.model,
        "vae": wan_model.vae.model
    }
    return wan_model, pipe

def setup_loras(pipe, lora_dir, lora_preset, num_inference_steps):
    """
    Load loras from a directory, optionally apply a preset.
    """
    from pathlib import Path
    import glob

    if not lora_dir or not Path(lora_dir).is_dir():
        print("No valid --lora-dir provided or directory doesn't exist, skipping LoRA setup.")
        return [], [], [], "", "", False

    # Gather LoRA files
    loras = sorted(
        glob.glob(os.path.join(lora_dir, "*.sft"))
        + glob.glob(os.path.join(lora_dir, "*.safetensors"))
    )
    loras_names = [Path(x).stem for x in loras]

    # Offload them with no activation
    offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False)

    # If user gave a preset, apply it
    default_loras_choices = []
    default_loras_multis_str = ""
    default_prompt_prefix = ""
    preset_applied_full_prompt = False
    if lora_preset:
        loras_choices, loras_mult, prefix, full_prompt = extract_preset(lora_preset, lora_dir, loras)
        default_loras_choices = loras_choices
        # If user stored loras_mult as a list or string in JSON, unify that to str
        if isinstance(loras_mult, list):
            # Just store them in a single line
            default_loras_multis_str = " ".join([str(x) for x in loras_mult])
        else:
            default_loras_multis_str = str(loras_mult)
        default_prompt_prefix = prefix
        preset_applied_full_prompt = full_prompt

    return (
        loras,
        loras_names,
        default_loras_choices,
        default_loras_multis_str,
        default_prompt_prefix,
        preset_applied_full_prompt
    )

def parse_loras_and_activate(
    transformer,
    loras,
    loras_choices,
    loras_mult_str,
    num_inference_steps
):
    """
    Activate the chosen LoRAs with multipliers over the pipeline's transformer.
    Supports stepwise expansions (like "0.5,0.8" for partial steps).
    """
    if not loras or not loras_choices:
        # no LoRAs selected
        return

    # Handle multipliers
    def is_float_or_comma_list(x):
        """
        Example: "0.5", or "0.8,1.0", etc. is valid.
        """
        if not x:
            return False
        for chunk in x.split(","):
            try:
                float(chunk.strip())
            except ValueError:
                return False
        return True

    # Convert multiline or spaced lines to a single list
    lines = [
        line.strip()
        for line in loras_mult_str.replace("\r", "\n").split("\n")
        if line.strip() and not line.strip().startswith("#")
    ]
    # Now combine them by space
    joined_line = " ".join(lines)  # "1.0 2.0,3.0"
    if not joined_line.strip():
        multipliers = []
    else:
        multipliers = joined_line.split(" ")

    # Expand each item
    final_multipliers = []
    for mult in multipliers:
        mult = mult.strip()
        if not mult:
            continue
        if is_float_or_comma_list(mult):
            # Could be "0.7" or "0.5,0.6"
            if "," in mult:
                # expand over steps
                chunk_vals = [float(x.strip()) for x in mult.split(",")]
                expanded = expand_list_over_steps(chunk_vals, num_inference_steps)
                final_multipliers.append(expanded)
            else:
                final_multipliers.append(float(mult))
        else:
            raise ValueError(f"Invalid LoRA multiplier: '{mult}'")

    # If fewer multipliers than chosen LoRAs => pad with 1.0
    needed = len(loras_choices) - len(final_multipliers)
    if needed > 0:
        final_multipliers += [1.0] * needed

    # Actually activate them
    offload.activate_loras(transformer, loras_choices, final_multipliers)

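# Illustrative only: with two LoRAs selected, a --loras-mult value of "1.2 0.8,0.5"
# keeps the first LoRA at 1.2 for every step, while the second uses 0.8 for the first
# half of the denoising steps and 0.5 for the second half (via expand_list_over_steps
# below).
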
def expand_list_over_steps(short_list, num_steps):
    """
    If user gave (0.5, 0.8) for example, expand them over `num_steps`.
    The expansion is simply a linear slice across the steps.
    """
    result = []
    inc = len(short_list) / float(num_steps)
    idxf = 0.0
    for _ in range(num_steps):
        value = short_list[int(idxf)]
        result.append(value)
        idxf += inc
    return result

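# Worked example: expand_list_over_steps([0.5, 0.8], 4) uses inc = 2 / 4 = 0.5, so the
# indices visited are 0, 0, 1, 1 and the result is [0.5, 0.5, 0.8, 0.8].
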
def download_models_if_needed(transformer_filename_i2v, text_encoder_filename, local_folder=DATA_DIR):
    """
    Checks if all required WAN 2.1 i2v files exist locally under 'ckpts/'.
    If not, downloads them from a Hugging Face Hub repo.
    Adjust the 'repo_id' and needed files as appropriate.
    """
    import os
    from pathlib import Path

    try:
        from huggingface_hub import hf_hub_download, snapshot_download
    except ImportError as e:
        raise ImportError(
            "huggingface_hub is required for automatic model download. "
            "Please install it via `pip install huggingface_hub`."
        ) from e

    # Identify just the filename portion for each path
    def basename(path_str):
        return os.path.basename(path_str)

    repo_id = "DeepBeepMeep/Wan2.1"
    target_root = local_folder

    # You can customize this list as needed for i2v usage.
    # At minimum you need:
    # 1) The requested i2v transformer file
    # 2) The requested text encoder file
    # 3) VAE file
    # 4) The open-clip xlm-roberta-large weights
    #
    # If your i2v config references additional files, add them here.
    needed_files = [
        "Wan2.1_VAE.pth",
        "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
        basename(text_encoder_filename),
        basename(transformer_filename_i2v),
    ]

    # The original script also downloads an entire "xlm-roberta-large" folder
    # via snapshot_download. If you require that for your pipeline,
    # you can add it here, for example:
    subfolder_name = "xlm-roberta-large"
    if not Path(os.path.join(target_root, subfolder_name)).exists():
        snapshot_download(repo_id=repo_id, allow_patterns=subfolder_name + "/*", local_dir=target_root)

    for filename in needed_files:
        local_path = os.path.join(target_root, filename)
        if not os.path.isfile(local_path):
            print(f"File '{filename}' not found locally. Downloading from {repo_id} ...")
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=target_root
            )
        else:
            # Already present
            pass

    print("All required i2v files are present.")


# --------------------------------------------------
# ARGUMENT PARSER
# --------------------------------------------------

def parse_args():
    parser = argparse.ArgumentParser(
        description="Image-to-Video inference using WAN 2.1 i2v"
    )
    # Model + Tools
    parser.add_argument(
        "--quantize-transformer",
        action="store_true",
        help="Use on-the-fly transformer quantization"
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Enable PyTorch 2.0 compile for the transformer"
    )
    parser.add_argument(
        "--attention",
        type=str,
        default="auto",
        help="Which attention to use: auto, sdpa, sage, sage2, flash"
    )
    parser.add_argument(
        "--profile",
        type=int,
        default=4,
        help="Memory usage profile number [1..5]; see original script or use 2 if you have low VRAM"
    )
    parser.add_argument(
        "--preload",
        type=int,
        default=0,
        help="Megabytes of the diffusion model to preload in VRAM (only used in some profiles)"
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="Verbosity level [0..5]"
    )

    # i2v Model
    parser.add_argument(
        "--transformer-file",
        type=str,
        default=f"{DATA_DIR}/wan2.1_image2video_480p_14B_quanto_int8.safetensors",
        help="Which i2v model to load"
    )
    parser.add_argument(
        "--text-encoder-file",
        type=str,
        default=f"{DATA_DIR}/models_t5_umt5-xxl-enc-quanto_int8.safetensors",
        help="Which text encoder to use"
    )

    # LoRA
    parser.add_argument(
        "--lora-dir",
        type=str,
        default="",
        help="Path to a directory containing i2v LoRAs"
    )
    parser.add_argument(
        "--lora-preset",
        type=str,
        default="",
        help="A .lset preset name in the lora_dir to auto-apply"
    )

    # Generation Options
    parser.add_argument("--prompt", type=str, default=None, required=True, help="Prompt for generation")
    parser.add_argument("--negative-prompt", type=str, default="", help="Negative prompt")
    parser.add_argument("--resolution", type=str, default="832x480", help="WxH")
    parser.add_argument("--frames", type=int, default=64, help="Number of frames (16 frames = 1s at fps 16). WAN expects 4*N+1 frames; the value is rounded to that form.")
    parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps.")
    parser.add_argument("--guidance-scale", type=float, default=5.0, help="Classifier-free guidance scale")
    parser.add_argument("--flow-shift", type=float, default=3.0, help="Flow shift parameter. Generally 3.0 for 480p, 5.0 for 720p.")
    parser.add_argument("--riflex", action="store_true", help="Enable RIFLEx for longer videos")
    parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
    parser.add_argument("--teacache-start", type=float, default=0.1, help="TeaCache start step percentage [0..100]")
    parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.")

    # LoRA usage
    parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
    parser.add_argument("--loras-mult", type=str, default="", help="Multipliers for each chosen LoRA. Example: '1.0 1.2,1.3' etc.")

    # Input
    parser.add_argument(
        "--input-image",
        type=str,
        default=None,
        required=True,
        help="Path to an input image (or multiple)."
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default="output.mp4",
        help="Where to save the resulting video."
    )

    return parser.parse_args()

# --------------------------------------------------
# MAIN
# --------------------------------------------------

def main():
    args = parse_args()

    # Setup environment
    offload.default_verboseLevel = args.verbose
    installed_attn_modes = get_attention_modes()

    # Decide attention
    chosen_attention = get_attention_mode(args.attention, installed_attn_modes)
    offload.shared_state["_attention"] = chosen_attention

    # Determine i2v resolution format
    if "720" in args.transformer_file:
        is_720p = True
    else:
        is_720p = False

    # Make sure we have the needed models locally
    download_models_if_needed(args.transformer_file, args.text_encoder_file)

    # Load i2v
    wan_model, pipe = load_i2v_model(
        model_filename=args.transformer_file,
        text_encoder_filename=args.text_encoder_file,
        is_720p=is_720p
    )
    wan_model._interrupt = False

    # Offload / profile
    # e.g. for your script: offload.profile(pipe, profile_no=args.profile, compile=..., quantizeTransformer=...)
    # pass the budgets if you want, etc.
    kwargs = {}
    if args.profile == 2 or args.profile == 4:
        # preload is in MB
        if args.preload == 0:
            budgets = {"transformer": 100, "text_encoder": 100, "*": 1000}
        else:
            budgets = {"transformer": args.preload, "text_encoder": 100, "*": 1000}
        kwargs["budgets"] = budgets
    elif args.profile == 3:
        kwargs["budgets"] = {"*": "70%"}

    compile_choice = "transformer" if args.compile else ""
    # Create the offload object
    offloadobj = offload.profile(
        pipe,
        profile_no=args.profile,
        compile=compile_choice,
        quantizeTransformer=args.quantize_transformer,
        **kwargs
    )

    # If user wants to use LoRAs
    (
        loras,
        loras_names,
        default_loras_choices,
        default_loras_multis_str,
        preset_prompt_prefix,
        preset_full_prompt
    ) = setup_loras(pipe, args.lora_dir, args.lora_preset, args.steps)

    # Combine user prompt with preset prompt if the preset indicates so
    if preset_prompt_prefix:
        if preset_full_prompt:
            # Full override
            user_prompt = preset_prompt_prefix
        else:
            # Just prefix
            user_prompt = preset_prompt_prefix + "\n" + args.prompt
    else:
        user_prompt = args.prompt

    # Actually parse user LoRA choices if they did not rely purely on the preset
    if args.loras_choices:
        # If user gave e.g. "0,1", we treat that as new additions
        lora_choice_list = [x.strip() for x in args.loras_choices.split(",")]
    else:
        # Use the defaults from the preset
        lora_choice_list = default_loras_choices

    # Activate them
    parse_loras_and_activate(
        pipe["transformer"], loras, lora_choice_list, args.loras_mult or default_loras_multis_str, args.steps
    )

    # Negative prompt
    negative_prompt = args.negative_prompt or ""

    # Sanity check resolution
    if "*" in args.resolution.lower():
        print("ERROR: resolution must be e.g. 832x480 not '832*480'. Fixing it.")
        resolution_str = args.resolution.lower().replace("*", "x")
    else:
        resolution_str = args.resolution

    try:
        width, height = [int(x) for x in resolution_str.split("x")]
    except ValueError:
        raise ValueError(f"Invalid resolution: '{resolution_str}'")

    # Additional checks (from your original code).
    if "480p" in args.transformer_file:
        # Then we cannot exceed a certain area for the 480p model
        if width * height > 832 * 480:
            raise ValueError("You must use the 720p i2v model to generate bigger than 832x480.")
    # etc.

    # Handle random seed
    if args.seed < 0:
        args.seed = random.randint(0, 999999999)
    print(f"Using seed={args.seed}")

    # Setup TeaCache if needed
    trans = wan_model.model
    trans.enable_teacache = (args.teacache > 0)
    if trans.enable_teacache:
        if "480p" in args.transformer_file:
            # example from your code
            trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
        elif "720p" in args.transformer_file:
            trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
        else:
            raise ValueError("TeaCache not supported for this model variant")

    # Attempt generation
    print("Starting generation ...")
    start_time = time.time()

    # Read the input image
    if not os.path.isfile(args.input_image):
        raise ValueError(f"Input image does not exist: {args.input_image}")

    from PIL import Image
    input_img = Image.open(args.input_image).convert("RGB")

    # Possibly load more than one image if you want "multiple images"; here we just use a single image for demonstration

    # Define the generation call
    # - frames must be a multiple of 4 plus 1 as per the original script's note, e.g. 65, 81, ...
    frame_count = (args.frames // 4) * 4 + 1  # rounds to 4*N+1, e.g. --frames 64 becomes 65
    # RIFLEx
    enable_riflex = args.riflex

    # If teacache => reset counters
    if trans.enable_teacache:
        trans.teacache_counter = 0
        trans.teacache_multiplier = args.teacache
        trans.teacache_start_step = int(args.teacache_start * args.steps / 100.0)
        trans.num_steps = args.steps
        trans.teacache_skipped_steps = 0
        trans.previous_residual_uncond = None
        trans.previous_residual_cond = None

    # VAE Tiling
    device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
    if device_mem_capacity >= 28000:  # 81 frames 720p requires about 28 GB VRAM
        use_vae_config = 1
    elif device_mem_capacity >= 8000:
        use_vae_config = 2
    else:
        use_vae_config = 3

    if use_vae_config == 1:
        VAE_tile_size = 0
    elif use_vae_config == 2:
        VAE_tile_size = 256
    else:
        VAE_tile_size = 128

    print('Using VAE tile size of', VAE_tile_size)

    # Actually run the i2v generation
    try:
        sample_frames = wan_model.generate(
            user_prompt,
            input_img,
            frame_num=frame_count,
            max_area=MAX_AREA_CONFIGS[f"{width}*{height}"],  # or you can pass your custom
            shift=args.flow_shift,
            sampling_steps=args.steps,
            guide_scale=args.guidance_scale,
            n_prompt=negative_prompt,
            seed=args.seed,
            offload_model=False,
            callback=None,  # or define your own callback if you want
            enable_RIFLEx=enable_riflex,
            VAE_tile_size=VAE_tile_size,
        )
    except Exception as e:
        offloadobj.unload_all()
        gc.collect()
        torch.cuda.empty_cache()

        err_str = f"Generation failed with error: {e}"
        # Attempt to detect OOM errors
        s = str(e).lower()
        if any(keyword in s for keyword in ["memory", "cuda", "alloc"]):
            raise RuntimeError("Likely out-of-VRAM or out-of-RAM error. " + err_str)
        else:
            traceback.print_exc()
            raise RuntimeError(err_str)

    # After generation
    offloadobj.unload_all()
    gc.collect()
    torch.cuda.empty_cache()

    if sample_frames is None:
        raise RuntimeError("No frames were returned (maybe generation was aborted or failed).")

    # If teacache was used, we can see how many steps were skipped
    if trans.enable_teacache:
        print(f"TeaCache skipped steps: {trans.teacache_skipped_steps} / {args.steps}")

    # Save result
    sample_frames = sample_frames.cpu()  # shape = c, t, h, w => [3, T, H, W]
    os.makedirs(os.path.dirname(args.output_file) or ".", exist_ok=True)

    # Use the provided helper from your code to store the MP4
    # By default, you used cache_video(tensor=..., save_file=..., fps=16, ...)
    # or you can do your own. We'll do the same for consistency:
    cache_video(
        tensor=sample_frames[None],  # shape => [1, c, T, H, W]
        save_file=args.output_file,
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1)
    )

    end_time = time.time()
    elapsed_s = end_time - start_time
    print(f"Done! Output written to {args.output_file}. Generation time: {elapsed_s:.1f} seconds.")

if __name__ == "__main__":
    main()
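
Example invocation (illustrative only; the prompt, image path, and output name are placeholders, and the default 480p checkpoints are downloaded to ckpts/ automatically on first run):

  python i2v_inference.py --prompt "a cat walking through tall grass" --input-image ./cat.jpg --resolution 832x480 --frames 65 --steps 30 --output-file output.mp4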