Add inference code and attention settings for sfp4 checkpoint-750
- backend_snapshot/README.md +30 -0
- backend_snapshot/examples/inference/basic/basic.py +47 -0
- backend_snapshot/fastvideo/api/compat.py +503 -0
- backend_snapshot/fastvideo/configs/pipelines/wan.py +203 -0
- backend_snapshot/fastvideo/configs/sample/base.py +292 -0
- backend_snapshot/fastvideo/configs/sample/wan.py +154 -0
- backend_snapshot/fastvideo/configs/wan_1.3B_t2v_pipeline.json +40 -0
- backend_snapshot/fastvideo/entrypoints/cli/generate.py +115 -0
- backend_snapshot/fastvideo/entrypoints/video_generator.py +797 -0
- backend_snapshot/fastvideo/fastvideo_args.py +1188 -0
- backend_snapshot/fastvideo/pipelines/basic/wan/__init__.py +0 -0
- backend_snapshot/fastvideo/pipelines/basic/wan/wan_pipeline.py +60 -0
- backend_snapshot/fastvideo/pipelines/composed_pipeline_base.py +474 -0
- backend_snapshot/manifest.sha256 +17 -1
- backend_snapshot/scripts/inference/run_sfp4_ours_p_checkpoint_750.sh +54 -0
- backend_snapshot/scripts/inference/run_sfp4_single.sh +38 -0
- backend_snapshot/scripts/inference/run_validate_and_gen.sh +91 -0
- backend_snapshot/training_attention_settings.json +62 -0
backend_snapshot/README.md CHANGED

````diff
@@ -13,6 +13,17 @@ Key runtime settings:
 - `VSA_WARMUP_STEPS=0`
 - tile size: `4 x 4 x 4 = 64` video tokens
 
+Training attention semantics:
+
+- Video self-attention uses `SPARSE_FP4_OURS_P_ATTN`.
+- Cross-attention is not quantized/sparse in this backend. It falls back to
+  dense SDPA when `query_length != key_length`.
+- `force_dense` paths also use dense SDPA.
+- Q/K/V fake quantization uses FP4 with STE and no q/k mean subtraction.
+- Selected sparse tiles use group-local P quantization in the Triton kernel.
+- Dropped VSA tiles use tile-level q_mean/k_mean score plus mean_v
+  compensation.
+
 Important files:
 
 - `fastvideo/attention/backends/sparse_fp4_ours_p_attn.py`: Python attention backend, Q/K/V fake quantization, top-k block map, tile mean setup.
@@ -25,6 +36,25 @@ Important files:
 - `fastvideo/training/training_pipeline.py` and `fastvideo/training/wan_training_pipeline.py`: legacy SFT training path used by the launch script.
 - `scripts/training/run_sparse_fp4_train_v4_1n_sparse09_hpo_on_ours_p_init2050_interactive.sh`: exact Slurm wrapper for this run.
 - `scripts/training/run_sparse_fp4_train_v4_common.sh`: common SFT launch/resume script.
+- `training_attention_settings.json`: structured attention/training settings
+  for this checkpoint.
+- `scripts/inference/run_sfp4_ours_p_checkpoint_750.sh`: inference example
+  for the uploaded transformer checkpoint.
+- `fastvideo/entrypoints/cli/generate.py`, `fastvideo/entrypoints/video_generator.py`,
+  `fastvideo/pipelines/basic/wan/wan_pipeline.py`, and
+  `fastvideo/pipelines/stages/denoising.py`: `fastvideo generate` inference
+  path used by the example script.
+
+Example inference flow:
+
+```bash
+hf download yitongl/sparse_quant_exp \
+  --repo-type model \
+  --local-dir checkpoints/hf_download/sparse_quant_exp \
+  --include 'transformer/*'
+
+bash backend_snapshot/scripts/inference/run_sfp4_ours_p_checkpoint_750.sh
+```
 
 Source repo HEAD when staged:
 
````
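To make the quantization semantics above concrete, here is a minimal PyTorch sketch of FP4 fake quantization with a straight-through estimator (STE). It is illustrative only: `fp4_fake_quant_ste` and its per-tensor scale are hypothetical stand-ins for the per-tile logic in `sparse_fp4_ours_p_attn.py`, not that backend's actual code.

```python
import torch

# Positive magnitudes representable in e2m1 FP4; the full grid is symmetric.
_FP4_POS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
_FP4_GRID = torch.cat([-_FP4_POS.flip(0)[:-1], _FP4_POS])


def fp4_fake_quant_ste(x: torch.Tensor) -> torch.Tensor:
    """Round x onto the FP4 grid after scaling; STE passes gradients through."""
    grid = _FP4_GRID.to(device=x.device, dtype=x.dtype)
    scale = x.abs().amax().clamp(min=1e-8) / 6.0  # 6.0 is the largest FP4 magnitude
    idx = (x / scale).unsqueeze(-1).sub(grid).abs().argmin(dim=-1)
    x_q = grid[idx] * scale
    return x + (x_q - x).detach()  # forward = x_q, backward = identity
```

The dropped-tile compensation can be sketched the same way, under our assumption (not confirmed by the repo) that a dropped key tile contributes to the softmax as if every key equalled its tile mean `k_mean` and every value equalled `mean_v`:

```python
def dropped_tile_term(q: torch.Tensor, k_mean: torch.Tensor, mean_v: torch.Tensor,
                      tile_len: int, scale: float) -> tuple[torch.Tensor, torch.Tensor]:
    """q: (Tq, d) query tile; k_mean, mean_v: (d,) means of one dropped key tile.
    Returns the tile's unnormalized value contribution and its softmax mass.
    (A real kernel would subtract a running max before exp for stability.)"""
    score = (q @ k_mean) * scale        # (Tq,) one score per query row
    mass = tile_len * torch.exp(score)  # the tile stands in for tile_len equal keys
    return mass[:, None] * mean_v, mass
```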
backend_snapshot/examples/inference/basic/basic.py ADDED

```python
from fastvideo import VideoGenerator

# from fastvideo.configs.sample import SamplingParam

OUTPUT_PATH = "video_samples"
def main():
    # FastVideo will automatically use the optimal default arguments for the
    # model.
    # If a local path is provided, FastVideo will make a best effort
    # attempt to identify the optimal arguments.
    generator = VideoGenerator.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        # FastVideo will automatically handle distributed setup
        num_gpus=1,
        use_fsdp_inference=False,  # set to True if GPU is out of memory
        dit_cpu_offload=False,
        vae_cpu_offload=False,
        text_encoder_cpu_offload=True,
        pin_cpu_memory=True,  # set to false if low CPU RAM or hit obscure "CUDA error: Invalid argument"
        # image_encoder_cpu_offload=False,
    )

    # sampling_param = SamplingParam.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
    # sampling_param.num_frames = 45
    # sampling_param.image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
    # Generate videos with the same simple API, regardless of GPU count
    prompt = (
        "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes "
        "wide with interest. The playful yet serene atmosphere is complemented by soft "
        "natural light filtering through the petals. Mid-shot, warm and cheerful tones."
    )
    video = generator.generate_video(prompt, output_path=OUTPUT_PATH, save_video=True)
    # video = generator.generate_video(prompt, sampling_param=sampling_param, output_path="wan_t2v_videos/")

    # Generate another video with a different prompt, without reloading the
    # model!
    prompt2 = (
        "A majestic lion strides across the golden savanna, its powerful frame "
        "glistening under the warm afternoon sun. The tall grass ripples gently in "
        "the breeze, enhancing the lion's commanding presence. The tone is vibrant, "
        "embodying the raw energy of the wild. Low angle, steady tracking shot, "
        "cinematic.")
    video2 = generator.generate_video(prompt2, output_path=OUTPUT_PATH, save_video=True)


if __name__ == "__main__":
    main()
```
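The commented-out `SamplingParam` path in the script above can be enabled as follows. This is a sketch that reuses only names appearing elsewhere in this snapshot; the prompt and override values are arbitrary.

```python
from fastvideo import VideoGenerator
from fastvideo.configs.sample import SamplingParam

generator = VideoGenerator.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", num_gpus=1)

# Start from the model's registered defaults, then override individual fields.
sampling_param = SamplingParam.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
sampling_param.num_frames = 45
sampling_param.num_inference_steps = 30

generator.generate_video(
    "A sailboat gliding across a calm bay at dawn.",
    sampling_param=sampling_param,
    output_path="video_samples",
    save_video=True,
)
```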
backend_snapshot/fastvideo/api/compat.py ADDED

```python
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from collections.abc import Mapping
from copy import deepcopy
from dataclasses import fields, is_dataclass
from pathlib import Path
from typing import Any

from fastvideo.api.overrides import apply_overrides, parse_cli_overrides
from fastvideo.api.parser import config_to_dict, load_raw_config, parse_config
from fastvideo.api.schema import (
    GenerationRequest,
    GeneratorConfig,
    InputConfig,
    OutputConfig,
    RequestRuntimeConfig,
    SamplingConfig,
)
from fastvideo.configs.sample import SamplingParam
from fastvideo.fastvideo_args import FastVideoArgs
from fastvideo.utils import shallow_asdict

_EXPLICIT_REQUEST_ATTR = "_fastvideo_explicit_request"
_INPUT_FIELD_NAMES = {field.name for field in fields(InputConfig)}
_SAMPLING_FIELD_NAMES = {field.name for field in fields(SamplingConfig)}
_RUNTIME_FIELD_NAMES = {field.name for field in fields(RequestRuntimeConfig)}
_OUTPUT_FIELD_NAMES = {field.name for field in fields(OutputConfig)}
_MISSING = object()
_LEGACY_REQUEST_ALIASES = {
    "neg_prompt": "negative_prompt",
}
_REQUEST_PIPELINE_OVERRIDE_FIELDS = frozenset({
    "embedded_cfg_scale",
})


def normalize_generator_config(config: GeneratorConfig | Mapping[str, Any], ) -> GeneratorConfig:
    if isinstance(config, GeneratorConfig):
        return config
    return parse_config(GeneratorConfig, config)


def load_generator_config_from_file(
    path: str | Path,
    overrides: list[str] | Mapping[str, Any] | None = None,
) -> GeneratorConfig:
    raw = load_raw_config(path)
    normalized_overrides = _normalize_overrides(overrides)

    if _looks_like_run_or_serve_config(raw):
        if normalized_overrides:
            raw = apply_overrides(raw, normalized_overrides)
        return parse_config(GeneratorConfig, raw["generator"])

    if normalized_overrides:
        adjusted = normalized_overrides
        if all(key.startswith("generator.") for key in adjusted):
            adjusted = {key[len("generator."):]: value for key, value in adjusted.items()}
        raw = apply_overrides(raw, adjusted)

    return parse_config(GeneratorConfig, raw)


def legacy_from_pretrained_to_config(
    model_path: str,
    kwargs: Mapping[str, Any],
) -> GeneratorConfig:
    raw: dict[str, Any] = {"model_path": model_path}
    engine: dict[str, Any] = {}
    parallelism: dict[str, Any] = {}
    offload: dict[str, Any] = {}
    compile_config: dict[str, Any] = {}
    pipeline: dict[str, Any] = {}
    components: dict[str, Any] = {}
    quantization: dict[str, Any] = {}
    experimental: dict[str, Any] = {}

    for key, value in kwargs.items():
        if key == "revision":
            raw["revision"] = value
        elif key == "trust_remote_code":
            raw["trust_remote_code"] = value
        elif key == "num_gpus":
            engine["num_gpus"] = value
        elif key == "distributed_executor_backend":
            engine["execution_backend"] = value
        elif key in {"tp_size", "sp_size", "hsdp_replicate_dim", "hsdp_shard_dim", "dist_timeout"}:
            parallelism[key] = value
        elif key == "dit_cpu_offload":
            offload["dit"] = value
        elif key == "dit_layerwise_offload":
            offload["dit_layerwise"] = value
        elif key == "text_encoder_cpu_offload":
            offload["text_encoder"] = value
        elif key == "image_encoder_cpu_offload":
            offload["image_encoder"] = value
        elif key == "vae_cpu_offload":
            offload["vae"] = value
        elif key == "pin_cpu_memory":
            offload["pin_cpu_memory"] = value
        elif key == "enable_torch_compile":
            compile_config["enabled"] = value
        elif key == "torch_compile_kwargs":
            compile_config["kwargs"] = deepcopy(value)
        elif key in {"enable_stage_verification", "use_fsdp_inference", "disable_autocast"}:
            engine[key] = value
        elif key == "override_text_encoder_quant":
            quantization["text_encoder_quant"] = value
        elif key == "transformer_quant":
            quantization["transformer_quant"] = value
        elif key == "workload_type":
            pipeline["workload_type"] = value
        elif key == "lora_path":
            components["lora_path"] = value
        elif key == "override_pipeline_cls_name":
            components["override_pipeline_cls_name"] = value
        elif key == "override_transformer_cls_name":
            components["override_transformer_cls_name"] = value
        elif key == "pipeline_config":
            if isinstance(value, str):
                components["pipeline_config_path"] = value
            else:
                experimental[key] = deepcopy(value)
        elif key == "override_text_encoder_safetensors":
            components["text_encoder_weights"] = value
        elif key == "init_weights_from_safetensors":
            components["transformer_weights"] = value
        elif key == "init_weights_from_safetensors_2":
            components["transformer_2_weights"] = value
        else:
            experimental[key] = deepcopy(value)

    if parallelism:
        engine["parallelism"] = parallelism
    if offload:
        engine["offload"] = offload
    if compile_config:
        engine["compile"] = compile_config
    if quantization:
        engine["quantization"] = quantization
    if engine:
        raw["engine"] = engine

    if components:
        pipeline["components"] = components
    if experimental:
        pipeline["experimental"] = experimental
    if pipeline:
        raw["pipeline"] = pipeline

    return parse_config(GeneratorConfig, raw)


def generator_config_to_fastvideo_args(config: GeneratorConfig | Mapping[str, Any], ) -> FastVideoArgs:
    normalized = normalize_generator_config(config)
    unsupported = []
    if normalized.pipeline.profile is not None:
        unsupported.append("pipeline.profile")
    if normalized.pipeline.profile_version is not None:
        unsupported.append("pipeline.profile_version")
    if normalized.pipeline.components.config_root is not None:
        unsupported.append("pipeline.components.config_root")
    if normalized.pipeline.components.vae_weights is not None:
        unsupported.append("pipeline.components.vae_weights")
    if normalized.pipeline.components.upsampler_weights is not None:
        unsupported.append("pipeline.components.upsampler_weights")
    if unsupported:
        joined = ", ".join(unsupported)
        raise NotImplementedError(f"VideoGenerator compatibility adapter does not support {joined} yet")

    engine = normalized.engine
    kwargs: dict[str, Any] = {
        "model_path": normalized.model_path,
        "revision": normalized.revision,
        "trust_remote_code": normalized.trust_remote_code,
        "num_gpus": engine.num_gpus,
        "distributed_executor_backend": engine.execution_backend,
        "tp_size": engine.parallelism.tp_size,
        "sp_size": engine.parallelism.sp_size,
        "hsdp_replicate_dim": engine.parallelism.hsdp_replicate_dim,
        "hsdp_shard_dim": engine.parallelism.hsdp_shard_dim,
        "dist_timeout": engine.parallelism.dist_timeout,
        "dit_cpu_offload": engine.offload.dit,
        "dit_layerwise_offload": engine.offload.dit_layerwise,
        "text_encoder_cpu_offload": engine.offload.text_encoder,
        "image_encoder_cpu_offload": engine.offload.image_encoder,
        "vae_cpu_offload": engine.offload.vae,
        "pin_cpu_memory": engine.offload.pin_cpu_memory,
        "enable_torch_compile": engine.compile.enabled,
        "torch_compile_kwargs": deepcopy(engine.compile.kwargs),
        "enable_stage_verification": engine.enable_stage_verification,
        "use_fsdp_inference": engine.use_fsdp_inference,
        "disable_autocast": engine.disable_autocast,
    }
    if normalized.pipeline.workload_type is not None:
        kwargs["workload_type"] = normalized.pipeline.workload_type

    quantization = engine.quantization
    if quantization is not None and quantization.text_encoder_quant is not None:
        kwargs["override_text_encoder_quant"] = quantization.text_encoder_quant
    if quantization is not None and quantization.transformer_quant is not None:
        kwargs["transformer_quant"] = quantization.transformer_quant

    components = normalized.pipeline.components
    if components.pipeline_config_path is not None:
        kwargs["pipeline_config"] = components.pipeline_config_path
    if components.lora_path is not None:
        kwargs["lora_path"] = components.lora_path
    if components.override_pipeline_cls_name is not None:
        kwargs["override_pipeline_cls_name"] = components.override_pipeline_cls_name
    if components.override_transformer_cls_name is not None:
        kwargs["override_transformer_cls_name"] = components.override_transformer_cls_name
    if components.text_encoder_weights is not None:
        kwargs["override_text_encoder_safetensors"] = components.text_encoder_weights
    if components.transformer_weights is not None:
        kwargs["init_weights_from_safetensors"] = components.transformer_weights
    if components.transformer_2_weights is not None:
        kwargs["init_weights_from_safetensors_2"] = components.transformer_2_weights

    kwargs.update(deepcopy(normalized.pipeline.profile_overrides))
    kwargs.update(deepcopy(normalized.pipeline.experimental))
    return FastVideoArgs.from_kwargs(**kwargs)


def normalize_generation_request(request: GenerationRequest | Mapping[str, Any], ) -> GenerationRequest:
    normalized = (request if isinstance(request, GenerationRequest) else parse_config(GenerationRequest, request))

    if not hasattr(normalized, _EXPLICIT_REQUEST_ATTR):
        setattr(normalized, _EXPLICIT_REQUEST_ATTR, _serialize_generation_request(normalized))
    return normalized


def legacy_generate_call_to_request(
    prompt: str | None,
    sampling_param: SamplingParam | None,
    *,
    mouse_cond: Any | None = None,
    keyboard_cond: Any | None = None,
    grid_sizes: Any | None = None,
    legacy_kwargs: Mapping[str, Any] | None = None,
) -> GenerationRequest:
    raw = _sampling_param_to_request_raw(sampling_param)
    if prompt is not None:
        raw["prompt"] = prompt

    for key, value in (legacy_kwargs or {}).items():
        _apply_request_field(raw, key, value)

    if mouse_cond is not None:
        raw.setdefault("inputs", {})["mouse_cond"] = mouse_cond
    if keyboard_cond is not None:
        raw.setdefault("inputs", {})["keyboard_cond"] = keyboard_cond
    if grid_sizes is not None:
        raw.setdefault("inputs", {})["grid_sizes"] = grid_sizes

    normalized = parse_config(GenerationRequest, raw)
    setattr(normalized, _EXPLICIT_REQUEST_ATTR, deepcopy(raw))
    return normalized


def request_to_sampling_param(
    request: GenerationRequest,
    *,
    model_path: str,
) -> SamplingParam:
    if request.plan is not None:
        raise NotImplementedError("GenerationRequest.plan is not wired into VideoGenerator yet")
    if request.state is not None:
        raise NotImplementedError("GenerationRequest.state is not wired into VideoGenerator yet")

    sampling_param = SamplingParam.from_pretrained(model_path)
    updates = _explicit_request_updates(request)

    for key, value in updates.items():
        if hasattr(sampling_param, key):
            setattr(sampling_param, key, deepcopy(value))
        elif key in _REQUEST_PIPELINE_OVERRIDE_FIELDS or _is_supported_as_default_only(key, value):
            continue
        else:
            raise ValueError(f"Request field {key!r} is not supported by sampling params for {model_path}")

    sampling_param.__post_init__()
    sampling_param.check_sampling_param()
    return sampling_param


def expand_request_prompt_batch(request: GenerationRequest, ) -> list[GenerationRequest]:
    if not isinstance(request.prompt, list):
        return [request]

    requests: list[GenerationRequest] = []
    for index, prompt in enumerate(request.prompt):
        single_request = deepcopy(request)
        single_request.prompt = prompt
        _fan_out_batched_input_value(request, single_request, "image_path", index)
        _fan_out_batched_input_value(request, single_request, "video_path", index)
        _fan_out_explicit_request_metadata(request, single_request, index, prompt)
        requests.append(single_request)
    return requests


def _looks_like_run_or_serve_config(raw: Mapping[str, Any]) -> bool:
    return isinstance(raw.get("generator"), Mapping)


def _normalize_overrides(overrides: list[str] | Mapping[str, Any] | None, ) -> dict[str, Any] | None:
    if not overrides:
        return None
    if isinstance(overrides, list):
        return parse_cli_overrides(overrides)
    return dict(overrides)


def _sampling_param_to_request_raw(sampling_param: SamplingParam | None, ) -> dict[str, Any]:
    if sampling_param is None:
        return {}

    raw: dict[str, Any] = {}
    for key, value in shallow_asdict(sampling_param).items():
        if key == "prompt":
            continue
        _apply_request_field(raw, key, deepcopy(value))
    return raw


def _apply_request_field(
    raw: dict[str, Any],
    key: str,
    value: Any,
) -> None:
    key = _LEGACY_REQUEST_ALIASES.get(key, key)
    if key == "negative_prompt":
        raw["negative_prompt"] = value
        return
    if key in _INPUT_FIELD_NAMES:
        raw.setdefault("inputs", {})[key] = value
        return
    if key in _SAMPLING_FIELD_NAMES:
        raw.setdefault("sampling", {})[key] = value
        return
    if key in _RUNTIME_FIELD_NAMES:
        raw.setdefault("runtime", {})[key] = value
        return
    if key in _OUTPUT_FIELD_NAMES:
        raw.setdefault("output", {})[key] = value
        return
    raw.setdefault("extensions", {})[key] = value


def request_to_pipeline_overrides(request: GenerationRequest) -> dict[str, Any]:
    overrides: dict[str, Any] = {}
    for key, value in _explicit_request_updates(request).items():
        if key in _REQUEST_PIPELINE_OVERRIDE_FIELDS:
            overrides[key] = deepcopy(value)
    return overrides


def _explicit_request_updates(request: GenerationRequest) -> dict[str, Any]:
    raw = getattr(request, _EXPLICIT_REQUEST_ATTR, None)
    if raw is None:
        raw = _serialize_generation_request(request)

    return _extract_request_updates(raw)


def _extract_request_updates(raw: Mapping[str, Any]) -> dict[str, Any]:
    updates: dict[str, Any] = {}
    if "negative_prompt" in raw:
        updates["negative_prompt"] = deepcopy(raw["negative_prompt"])

    for section_name in ("inputs", "sampling", "runtime", "output"):
        section = raw.get(section_name)
        if not isinstance(section, Mapping):
            continue
        for key, value in section.items():
            updates[key] = deepcopy(value)

    stage_overrides = raw.get("stage_overrides")
    if stage_overrides:
        updates.update(_flatten_stage_overrides(stage_overrides))

    extensions = raw.get("extensions")
    if isinstance(extensions, Mapping):
        for key, value in extensions.items():
            updates[key] = deepcopy(value)

    return updates


def _flatten_stage_overrides(stage_overrides: Any) -> dict[str, Any]:
    if not isinstance(stage_overrides, Mapping):
        raise ValueError("GenerationRequest.stage_overrides must be a mapping")

    flattened: dict[str, Any] = {}
    for stage_name, overrides in stage_overrides.items():
        if not isinstance(overrides, Mapping):
            raise ValueError(f"GenerationRequest.stage_overrides.{stage_name} must be a mapping")
        for key, value in overrides.items():
            if key in flattened and flattened[key] != value:
                raise ValueError(f"Conflicting stage override for {key!r} across stages")
            flattened[key] = deepcopy(value)
    return flattened


def _serialize_generation_request(request: GenerationRequest) -> dict[str, Any]:
    return deepcopy(config_to_dict(request))


def _fan_out_batched_input_value(
    source_request: GenerationRequest,
    target_request: GenerationRequest,
    field_name: str,
    index: int,
) -> None:
    value = getattr(source_request.inputs, field_name)
    if not isinstance(value, list):
        return
    _validate_batched_input_length(source_request.prompt, value, field_name)
    setattr(target_request.inputs, field_name, deepcopy(value[index]))


def _fan_out_explicit_request_metadata(
    source_request: GenerationRequest,
    target_request: GenerationRequest,
    index: int,
    prompt: str,
) -> None:
    raw = getattr(source_request, _EXPLICIT_REQUEST_ATTR, None)
    if raw is None:
        return

    raw = deepcopy(raw)
    raw["prompt"] = prompt
    inputs = raw.get("inputs")
    if isinstance(inputs, dict):
        for field_name in ("image_path", "video_path"):
            value = inputs.get(field_name)
            if isinstance(value, list):
                _validate_batched_input_length(source_request.prompt, value, field_name)
                inputs[field_name] = deepcopy(value[index])

    setattr(target_request, _EXPLICIT_REQUEST_ATTR, raw)


def _validate_batched_input_length(
    prompts: str | list[str] | None,
    values: list[Any],
    field_name: str,
) -> None:
    if not isinstance(prompts, list):
        return
    if len(values) != len(prompts):
        raise ValueError(f"GenerationRequest.inputs.{field_name} must have the same length as request.prompt")


def _is_supported_as_default_only(key: str, value: Any) -> bool:
    default_value = _DEFAULT_REQUEST_UPDATES.get(key, _MISSING)
    return default_value is not _MISSING and _values_equal(value, default_value)


def _collect_non_default_fields(
    value: Any,
    default: Any,
) -> dict[str, Any]:
    if not (is_dataclass(value) and is_dataclass(default)):
        return {}

    result: dict[str, Any] = {}
    for field in fields(value):
        current = getattr(value, field.name)
        default_value = getattr(default, field.name)
        if is_dataclass(current) and is_dataclass(default_value):
            nested = _collect_non_default_fields(current, default_value)
            if nested:
                result[field.name] = nested
            continue
        if not _values_equal(current, default_value):
            result[field.name] = deepcopy(current)
    return result


def _values_equal(left: Any, right: Any) -> bool:
    if left is right:
        return True
    try:
        return bool(left == right)
    except Exception:
        return False


_DEFAULT_REQUEST_UPDATES = _extract_request_updates(config_to_dict(GenerationRequest()))

__all__ = [
    "generator_config_to_fastvideo_args",
    "legacy_from_pretrained_to_config",
    "legacy_generate_call_to_request",
    "load_generator_config_from_file",
    "normalize_generation_request",
    "normalize_generator_config",
    "request_to_pipeline_overrides",
    "request_to_sampling_param",
]
```
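For orientation, a sketch of how the flat legacy kwargs accepted by `VideoGenerator.from_pretrained` land in the nested config via the adapter above; the kwarg values are illustrative.

```python
from fastvideo.api.compat import (
    generator_config_to_fastvideo_args,
    legacy_from_pretrained_to_config,
)

# Flat legacy kwargs are routed into the nested GeneratorConfig sections.
config = legacy_from_pretrained_to_config(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    {
        "num_gpus": 1,                     # -> engine.num_gpus
        "tp_size": 1,                      # -> engine.parallelism.tp_size
        "text_encoder_cpu_offload": True,  # -> engine.offload.text_encoder
        "enable_torch_compile": False,     # -> engine.compile.enabled
    },
)

# ...and flattened back into engine arguments via FastVideoArgs.from_kwargs.
args = generator_config_to_fastvideo_args(config)
```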
backend_snapshot/fastvideo/configs/pipelines/wan.py ADDED

```python
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Callable
from dataclasses import dataclass, field

import torch

from fastvideo.configs.models import DiTConfig, EncoderConfig, VAEConfig
from fastvideo.configs.models.dits import WanVideoConfig
from fastvideo.configs.models.dits.matrixgame import MatrixGameWanVideoConfig
from fastvideo.configs.models.encoders import (BaseEncoderOutput, CLIPVisionConfig, T5Config,
                                               WAN2_1ControlCLIPVisionConfig)
from fastvideo.configs.models.vaes import WanVAEConfig
from fastvideo.configs.pipelines.base import PipelineConfig


def t5_postprocess_text(outputs: BaseEncoderOutput) -> torch.Tensor:
    mask: torch.Tensor = outputs.attention_mask
    hidden_state: torch.Tensor = outputs.last_hidden_state
    seq_lens = mask.gt(0).sum(dim=1).long()
    assert torch.isnan(hidden_state).sum() == 0
    prompt_embeds = [u[:v] for u, v in zip(hidden_state, seq_lens, strict=True)]
    prompt_embeds_tensor: torch.Tensor = torch.stack(
        [torch.cat([u, u.new_zeros(512 - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0)
    return prompt_embeds_tensor


@dataclass
class WanT2V480PConfig(PipelineConfig):
    """Base configuration for Wan T2V 1.3B pipeline architecture."""

    # WanConfig-specific parameters with defaults
    # DiT
    dit_config: DiTConfig = field(default_factory=WanVideoConfig)
    # VAE
    vae_config: VAEConfig = field(default_factory=WanVAEConfig)
    vae_tiling: bool = False
    vae_sp: bool = False

    # Denoising stage
    flow_shift: float | None = 3.0

    # Text encoding stage
    text_encoder_configs: tuple[EncoderConfig, ...] = field(default_factory=lambda: (T5Config(), ))
    postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor],
                                  ...] = field(default_factory=lambda: (t5_postprocess_text, ))

    # Precision for each component
    precision: str = "bf16"
    vae_precision: str = "fp32"
    text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32", ))

    # self-forcing params
    warp_denoising_step: bool = True

    # WanConfig-specific added parameters

    def __post_init__(self):
        self.vae_config.load_encoder = False
        self.vae_config.load_decoder = True


@dataclass
class WanT2V720PConfig(WanT2V480PConfig):
    """Base configuration for Wan T2V 14B 720P pipeline architecture."""

    # WanConfig-specific parameters with defaults

    # Denoising stage
    flow_shift: float | None = 5.0


@dataclass
class WanI2V480PConfig(WanT2V480PConfig):
    """Base configuration for Wan I2V 14B 480P pipeline architecture."""

    # WanConfig-specific parameters with defaults

    # Precision for each component
    image_encoder_config: EncoderConfig = field(default_factory=CLIPVisionConfig)
    image_encoder_precision: str = "fp32"

    def __post_init__(self) -> None:
        self.vae_config.load_encoder = True
        self.vae_config.load_decoder = True


@dataclass
class WanI2V720PConfig(WanI2V480PConfig):
    """Base configuration for Wan I2V 14B 720P pipeline architecture."""

    # WanConfig-specific parameters with defaults

    # Denoising stage
    flow_shift: float | None = 5.0


@dataclass
class WANV2VConfig(WanI2V480PConfig):
    """Configuration for WAN2.1 1.3B Control pipeline."""

    image_encoder_config: EncoderConfig = field(default_factory=WAN2_1ControlCLIPVisionConfig)
    # CLIP encoder precision
    image_encoder_precision: str = 'bf16'


@dataclass
class FastWan2_1_T2V_480P_Config(WanT2V480PConfig):
    """Base configuration for FastWan T2V 1.3B 480P pipeline architecture with DMD"""

    # WanConfig-specific parameters with defaults

    # Denoising stage
    flow_shift: float | None = 8.0
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 757, 522])


@dataclass
class Wan2_2_TI2V_5B_Config(WanT2V480PConfig):
    flow_shift: float | None = 5.0
    ti2v_task: bool = True
    expand_timesteps: bool = True

    def __post_init__(self) -> None:
        self.vae_config.load_encoder = True
        self.vae_config.load_decoder = True
        self.dit_config.expand_timesteps = self.expand_timesteps


@dataclass
class FastWan2_2_TI2V_5B_Config(Wan2_2_TI2V_5B_Config):
    flow_shift: float | None = 5.0
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 757, 522])


@dataclass
class Wan2_2_T2V_A14B_Config(WanT2V480PConfig):
    flow_shift: float | None = 12.0
    boundary_ratio: float | None = 0.875

    # self-forcing params
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 750, 500, 250])
    warp_denoising_step: bool = True

    def __post_init__(self) -> None:
        self.dit_config.boundary_ratio = self.boundary_ratio


@dataclass
class Wan2_2_I2V_A14B_Config(WanI2V480PConfig):
    flow_shift: float | None = 5.0
    boundary_ratio: float | None = 0.900

    def __post_init__(self) -> None:
        super().__post_init__()
        self.dit_config.boundary_ratio = self.boundary_ratio


# =============================================
# ============= Causal Self-Forcing =============
# =============================================
@dataclass
class SelfForcingWanT2V480PConfig(WanT2V480PConfig):
    is_causal: bool = True
    flow_shift: float | None = 5.0
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 750, 500, 250])
    warp_denoising_step: bool = True


@dataclass
class SelfForcingWan2_2_T2V480PConfig(Wan2_2_T2V_A14B_Config):
    is_causal: bool = True
    flow_shift: float | None = 12.0
    boundary_ratio: float | None = 0.875
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 850, 700, 550, 350, 275, 200, 125])
    warp_denoising_step: bool = True

    def __post_init__(self) -> None:
        self.vae_config.load_encoder = True
        self.vae_config.load_decoder = True


# =============================================
# ============= Matrix Game ===================
# =============================================
@dataclass
class MatrixGameBaseI2V480PConfig(WanI2V480PConfig):
    dit_config: DiTConfig = field(default_factory=MatrixGameWanVideoConfig)
    flow_shift: float | None = 5.0


@dataclass
class MatrixGameI2V480PConfig(WanI2V480PConfig):
    dit_config: DiTConfig = field(default_factory=MatrixGameWanVideoConfig)

    image_encoder_config: EncoderConfig = field(default_factory=WAN2_1ControlCLIPVisionConfig)

    is_causal: bool = True
    flow_shift: float | None = 5.0
    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 666, 333])
    warp_denoising_step: bool = True
    context_noise: int = 0
    num_frames_per_block: int = 3
    # sliding_window_num_frames: int = 15
```
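These configs compose by inheritance: each subclass overrides only the fields that differ from its parent. A quick illustration, assuming the dataclasses construct with the defaults shown above:

```python
from fastvideo.configs.pipelines.wan import WanT2V480PConfig, WanT2V720PConfig

base = WanT2V480PConfig()
hi_res = WanT2V720PConfig()

assert base.flow_shift == 3.0
assert hi_res.flow_shift == 5.0                 # the only field the 720P variant changes
assert hi_res.precision == "bf16"               # inherited from WanT2V480PConfig
assert hi_res.vae_config.load_encoder is False  # set by the shared __post_init__
```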
backend_snapshot/fastvideo/configs/sample/base.py ADDED

```python
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Any

from fastvideo.logger import init_logger
from fastvideo.utils import StoreBoolean

logger = init_logger(__name__)


@dataclass
class SamplingParam:
    """
    Sampling parameters for video generation.
    """
    # All fields below are copied from ForwardBatch
    data_type: str = "video"

    # Image inputs
    image_path: str | None = None
    pil_image: Any | None = None

    # Video inputs
    video_path: str | None = None

    # Action control inputs (Matrix-Game)
    mouse_cond: Any | None = None  # Shape: (B, T, 2)
    keyboard_cond: Any | None = None  # Shape: (B, T, K)
    grid_sizes: Any | None = None  # Shape: (3,) [F,H,W]

    # Camera control inputs (HYWorld)
    pose: str | None = None  # Camera trajectory: pose string (e.g., 'w-31') or JSON file path

    # Camera control inputs (LingBotWorld)
    c2ws_plucker_emb: Any | None = None  # Plucker embedding: [B, C, F_lat, H_lat, W_lat]

    # Refine inputs (LongCat 480p->720p upscaling)
    # Path-based refine (load stage1 video from disk, e.g. MP4)
    refine_from: str | None = None  # Path to stage1 video (480p output from distill)
    t_thresh: float = 0.5  # Threshold for timestep scheduling in refinement
    spatial_refine_only: bool = False  # If True, only spatial (no temporal doubling)
    num_cond_frames: int = 0  # Number of conditioning frames
    # In-memory refine input (for two-stage pipeline where stage1 frames are already in memory)
    # This mirrors LongCat's demo where a list of frames (e.g. np.ndarray or PIL.Image)
    # is passed directly to the refinement pipeline instead of reloading from disk.
    stage1_video: Any | None = None

    # Text inputs
    prompt: str | list[str] | None = None
    negative_prompt: str | None = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
    prompt_path: str | None = None
    output_path: str = "outputs/"
    output_video_name: str | None = None

    # Batch info
    num_videos_per_prompt: int = 1
    seed: int = 1024

    # Original dimensions (before VAE scaling)
    num_frames: int = 125
    height: int = 720
    width: int = 1280
    height_sr: int = 1072
    width_sr: int = 1920
    fps: int = 24

    # Denoising parameters
    num_inference_steps: int = 50
    num_inference_steps_sr: int = 50
    guidance_scale: float = 1.0
    guidance_scale_2: float | None = None
    guidance_rescale: float = 0.0
    boundary_ratio: float | None = None
    sigmas: list[float] | None = None

    # TeaCache parameters
    enable_teacache: bool = False

    # GEN3C camera control
    trajectory_type: str | None = None
    movement_distance: float | None = None
    camera_rotation: str | None = None

    # Misc
    save_video: bool = True
    return_frames: bool = True
    return_trajectory_latents: bool = False  # returns all latents for each timestep
    return_trajectory_decoded: bool = False  # returns decoded latents for each timestep

    def __post_init__(self) -> None:
        self.data_type = "video" if self.num_frames > 1 else "image"

    def __getattr__(self, name: str) -> Any:
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def check_sampling_param(self) -> None:
        if self.prompt_path and not self.prompt_path.endswith(".txt"):
            raise ValueError("prompt_path must be a txt file")

    def update(self, source_dict: dict[str, Any]) -> None:
        for key, value in source_dict.items():
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                logger.exception("%s has no attribute %s", type(self).__name__, key)

        self.__post_init__()

    @classmethod
    def from_pretrained(cls, model_path: str) -> "SamplingParam":
        from fastvideo.registry import get_sampling_param_cls_for_name
        sampling_cls = get_sampling_param_cls_for_name(model_path)
        if sampling_cls is not None:
            sampling_param: SamplingParam = sampling_cls()
        else:
            logger.warning("Couldn't find an optimal sampling param for %s. Using the default sampling param.",
                           model_path)
            sampling_param = cls()

        return sampling_param

    @staticmethod
    def add_cli_args(parser: Any) -> Any:
        """Add CLI arguments for SamplingParam fields"""
        parser.add_argument(
            "--prompt",
            type=str,
            default=SamplingParam.prompt,
            help="Text prompt for video generation",
        )
        parser.add_argument(
            "--negative-prompt",
            type=str,
            default=SamplingParam.negative_prompt,
            help="Negative text prompt for video generation",
        )
        parser.add_argument(
            "--prompt-path",
            type=str,
            default=SamplingParam.prompt_path,
            help="Path to a text file containing the prompt",
        )
        parser.add_argument(
            "--output-path",
            type=str,
            default=SamplingParam.output_path,
            help="Path to save the generated video",
        )
        parser.add_argument(
            "--output-video-name",
            type=str,
            default=SamplingParam.output_video_name,
            help="Name of the output video",
        )
        parser.add_argument(
            "--num-videos-per-prompt",
            type=int,
            default=SamplingParam.num_videos_per_prompt,
            help="Number of videos to generate per prompt",
        )
        parser.add_argument(
            "--seed",
            type=int,
            default=SamplingParam.seed,
            help="Random seed for generation",
        )
        parser.add_argument(
            "--num-frames",
            type=int,
            default=SamplingParam.num_frames,
            help="Number of frames to generate",
        )
        parser.add_argument(
            "--height",
            type=int,
            default=SamplingParam.height,
            help="Height of generated video",
        )
        parser.add_argument(
            "--width",
            type=int,
            default=SamplingParam.width,
            help="Width of generated video",
        )
        parser.add_argument(
            "--fps",
            type=int,
            default=SamplingParam.fps,
            help="Frames per second for saved video",
        )
        parser.add_argument(
            "--num-inference-steps",
            type=int,
            default=SamplingParam.num_inference_steps,
            help="Number of denoising steps",
        )
        parser.add_argument(
            "--guidance-scale",
            type=float,
            default=SamplingParam.guidance_scale,
            help="Classifier-free guidance scale",
        )
        parser.add_argument(
            "--guidance-rescale",
            type=float,
            default=SamplingParam.guidance_rescale,
            help="Guidance rescale factor",
        )
        parser.add_argument(
            "--boundary-ratio",
            type=float,
            default=SamplingParam.boundary_ratio,
            help="Boundary timestep ratio",
        )
        parser.add_argument(
            "--save-video",
            action="store_true",
            default=SamplingParam.save_video,
            help="Whether to save the video to disk",
        )
        parser.add_argument(
            "--no-save-video",
            action="store_false",
            dest="save_video",
            help="Don't save the video to disk",
        )
        parser.add_argument(
            "--return-frames",
            action="store_true",
            default=False,
            help="Whether to return the raw frames",
        )
        parser.add_argument(
            "--image-path",
            type=str,
            default=SamplingParam.image_path,
            help="Path to input image for image-to-video generation",
        )
        parser.add_argument(
            "--video-path",
            type=str,
            default=SamplingParam.video_path,
            help="Path to input video for video-to-video generation",
        )
        parser.add_argument(
            "--refine-from",
            type=str,
            default=SamplingParam.refine_from,
            help="Path to stage1 video for refinement (LongCat 480p->720p)",
        )
        parser.add_argument(
            "--t-thresh",
            type=float,
            default=SamplingParam.t_thresh,
            help="Threshold for timestep scheduling in refinement (default: 0.5)",
        )
        parser.add_argument(
            "--spatial-refine-only",
            action=StoreBoolean,
            default=SamplingParam.spatial_refine_only,
            help="Only perform spatial super-resolution (no temporal doubling)",
        )
        parser.add_argument(
            "--num-cond-frames",
            type=int,
            default=SamplingParam.num_cond_frames,
            help="Number of conditioning frames for refinement",
        )
        parser.add_argument(
            "--moba-config-path",
            type=str,
            default=None,
            help="Path to a JSON file containing V-MoBA specific configurations.",
        )
        parser.add_argument(
            "--return-trajectory-latents",
            action="store_true",
            default=SamplingParam.return_trajectory_latents,
            help="Whether to return the trajectory",
        )
        parser.add_argument(
            "--return-trajectory-decoded",
            action="store_true",
            default=SamplingParam.return_trajectory_decoded,
            help="Whether to return the decoded trajectory",
        )
        return parser


@dataclass
class CacheParams:
    cache_type: str = "none"
```
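A sketch of how `add_cli_args` and `update` fit together; the filtering step is ours, to skip parser-only flags such as `--moba-config-path` that are not `SamplingParam` fields:

```python
import argparse

from fastvideo.configs.sample.base import SamplingParam

parser = SamplingParam.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args(["--num-frames", "45", "--height", "480", "--width", "832"])

param = SamplingParam.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
# Copy only keys that are real SamplingParam fields; update() re-runs __post_init__.
param.update({k: v for k, v in vars(args).items() if hasattr(param, k)})
param.check_sampling_param()
```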
backend_snapshot/fastvideo/configs/sample/wan.py ADDED

```python
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass

from fastvideo.configs.sample.base import SamplingParam


@dataclass
class WanT2V_1_3B_SamplingParam(SamplingParam):
    # Video parameters
    height: int = 480
    width: int = 832
    num_frames: int = 81
    fps: int = 16

    # Denoising stage
    guidance_scale: float = 3.0
    negative_prompt: str = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
    num_inference_steps: int = 50


@dataclass
class WanT2V_14B_SamplingParam(SamplingParam):
    # Video parameters
    height: int = 720
    width: int = 1280
    num_frames: int = 81
    fps: int = 16

    # Denoising stage
    guidance_scale: float = 5.0
    negative_prompt: str = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
    num_inference_steps: int = 50


@dataclass
class WanI2V_14B_480P_SamplingParam(WanT2V_1_3B_SamplingParam):
    # Denoising stage
    guidance_scale: float = 5.0
    num_inference_steps: int = 40


@dataclass
class WanI2V_14B_720P_SamplingParam(WanT2V_14B_SamplingParam):
    # Denoising stage
    guidance_scale: float = 5.0
    num_inference_steps: int = 40


@dataclass
class FastWanT2V480P_SamplingParam(WanT2V_1_3B_SamplingParam):
    # DMD parameters
    # dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 757, 522])
    num_inference_steps: int = 3
    num_frames: int = 61
    height: int = 448
    width: int = 832
    fps: int = 16


# =============================================
# ============= Wan2.1 Fun Models =============
# =============================================
@dataclass
class Wan2_1_Fun_1_3B_InP_SamplingParam(SamplingParam):
    """Sampling parameters for Wan2.1 Fun 1.3B InP model."""
    height: int = 480
    width: int = 832
    num_frames: int = 81
    fps: int = 16
    negative_prompt: str | None = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
    guidance_scale: float = 6.0
    num_inference_steps: int = 50


@dataclass
class Wan2_1_Fun_1_3B_Control_SamplingParam(SamplingParam):
    fps: int = 16
    num_frames: int = 49
    height: int = 832
```
| 80 |
+
width: int = 480
|
| 81 |
+
guidance_scale: float = 6.0
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# =============================================
|
| 85 |
+
# ============= Wan2.2 TI2V Models =============
|
| 86 |
+
# =============================================
|
| 87 |
+
@dataclass
|
| 88 |
+
class Wan2_2_Base_SamplingParam(SamplingParam):
|
| 89 |
+
"""Sampling parameters for Wan2.2 TI2V 5B model."""
|
| 90 |
+
negative_prompt: str | None = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@dataclass
|
| 94 |
+
class Wan2_2_TI2V_5B_SamplingParam(Wan2_2_Base_SamplingParam):
|
| 95 |
+
"""Sampling parameters for Wan2.2 TI2V 5B model."""
|
| 96 |
+
height: int = 704
|
| 97 |
+
width: int = 1280
|
| 98 |
+
num_frames: int = 121
|
| 99 |
+
fps: int = 24
|
| 100 |
+
guidance_scale: float = 5.0
|
| 101 |
+
num_inference_steps: int = 50
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@dataclass
|
| 105 |
+
class Wan2_2_T2V_A14B_SamplingParam(Wan2_2_Base_SamplingParam):
|
| 106 |
+
guidance_scale: float = 4.0 # high_noise
|
| 107 |
+
guidance_scale_2: float = 3.0 # low_noise
|
| 108 |
+
num_inference_steps: int = 40
|
| 109 |
+
fps: int = 16
|
| 110 |
+
# NOTE(will): default boundary timestep is tracked by PipelineConfig, but
|
| 111 |
+
# can be overridden during sampling
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@dataclass
|
| 115 |
+
class Wan2_2_I2V_A14B_SamplingParam(Wan2_2_Base_SamplingParam):
|
| 116 |
+
guidance_scale: float = 3.5 # high_noise
|
| 117 |
+
guidance_scale_2: float = 3.5 # low_noise
|
| 118 |
+
num_inference_steps: int = 40
|
| 119 |
+
fps: int = 16
|
| 120 |
+
# NOTE(will): default boundary timestep is tracked by PipelineConfig, but
|
| 121 |
+
# can be overridden during sampling
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@dataclass
|
| 125 |
+
class Wan2_2_Fun_A14B_Control_SamplingParam(Wan2_1_Fun_1_3B_Control_SamplingParam):
|
| 126 |
+
num_frames: int = 81
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# =============================================
|
| 130 |
+
# ============= Causal Self-Forcing =============
|
| 131 |
+
# =============================================
|
| 132 |
+
@dataclass
|
| 133 |
+
class SelfForcingWan2_1_T2V_1_3B_480P_SamplingParam(Wan2_1_Fun_1_3B_InP_SamplingParam):
|
| 134 |
+
pass
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@dataclass
|
| 138 |
+
class SelfForcingWan2_2_T2V_A14B_480P_SamplingParam(Wan2_2_T2V_A14B_SamplingParam):
|
| 139 |
+
num_inference_steps: int = 8
|
| 140 |
+
num_frames: int = 81
|
| 141 |
+
height: int = 448
|
| 142 |
+
width: int = 832
|
| 143 |
+
fps: int = 16
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@dataclass
|
| 147 |
+
class MatrixGame2_SamplingParam(SamplingParam):
|
| 148 |
+
height: int = 352
|
| 149 |
+
width: int = 640
|
| 150 |
+
num_frames: int = 57
|
| 151 |
+
fps: int = 25
|
| 152 |
+
guidance_scale: float = 1.0
|
| 153 |
+
num_inference_steps: int = 3
|
| 154 |
+
negative_prompt: str | None = None
|
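These dataclasses compose by inheritance: each subclass overrides only the defaults it changes (for example, the Wan2.2 A14B variants add `guidance_scale_2` for the low-noise expert). A minimal usage sketch, assuming the base `SamplingParam` can be constructed without required arguments:

```python
# Sketch: instantiate a preset and tweak one field while keeping the rest.
from dataclasses import replace

from fastvideo.configs.sample.wan import WanT2V_1_3B_SamplingParam

param = WanT2V_1_3B_SamplingParam()            # 480x832, 81 frames, cfg 3.0
fast = replace(param, num_inference_steps=20)  # override a single default
```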
backend_snapshot/fastvideo/configs/wan_1.3B_t2v_pipeline.json
ADDED
|
@@ -0,0 +1,40 @@
|
| 1 |
+
{
|
| 2 |
+
"embedded_cfg_scale": 6.0,
|
| 3 |
+
"flow_shift": 3,
|
| 4 |
+
"dit_cpu_offload": true,
|
| 5 |
+
"disable_autocast": false,
|
| 6 |
+
"precision": "bf16",
|
| 7 |
+
"vae_precision": "fp32",
|
| 8 |
+
"vae_tiling": false,
|
| 9 |
+
"vae_sp": false,
|
| 10 |
+
"vae_config": {
|
| 11 |
+
"load_encoder": false,
|
| 12 |
+
"load_decoder": true,
|
| 13 |
+
"tile_sample_min_height": 256,
|
| 14 |
+
"tile_sample_min_width": 256,
|
| 15 |
+
"tile_sample_min_num_frames": 16,
|
| 16 |
+
"tile_sample_stride_height": 192,
|
| 17 |
+
"tile_sample_stride_width": 192,
|
| 18 |
+
"tile_sample_stride_num_frames": 12,
|
| 19 |
+
"blend_num_frames": 8,
|
| 20 |
+
"use_tiling": false,
|
| 21 |
+
"use_temporal_tiling": false,
|
| 22 |
+
"use_parallel_tiling": false,
|
| 23 |
+
"use_feature_cache": true
|
| 24 |
+
},
|
| 25 |
+
"dit_config": {
|
| 26 |
+
"prefix": "Wan",
|
| 27 |
+
"quant_config": null
|
| 28 |
+
},
|
| 29 |
+
"text_encoder_precisions": [
|
| 30 |
+
"fp32"
|
| 31 |
+
],
|
| 32 |
+
"text_encoder_configs": [
|
| 33 |
+
{
|
| 34 |
+
"prefix": "t5",
|
| 35 |
+
"quant_config": null,
|
| 36 |
+
"lora_config": null
|
| 37 |
+
}
|
| 38 |
+
],
|
| 39 |
+
"enable_torch_compile": false
|
| 40 |
+
}
|
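The pipeline config above is plain JSON, so it can be inspected directly. How FastVideo consumes it (e.g. via a `PipelineConfig` loader) is not shown in this file, so only stdlib loading is sketched here:

```python
# Sketch: read the pipeline JSON and check two precision settings.
import json

with open("backend_snapshot/fastvideo/configs/wan_1.3B_t2v_pipeline.json") as f:
    cfg = json.load(f)

assert cfg["precision"] == "bf16" and cfg["vae_precision"] == "fp32"
```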
backend_snapshot/fastvideo/entrypoints/cli/generate.py
ADDED
|
@@ -0,0 +1,115 @@
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/serve.py
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import dataclasses
|
| 6 |
+
import os
|
| 7 |
+
from typing import cast
|
| 8 |
+
|
| 9 |
+
from fastvideo import VideoGenerator
|
| 10 |
+
from fastvideo.configs.sample.base import SamplingParam
|
| 11 |
+
from fastvideo.entrypoints.cli.cli_types import CLISubcommand
|
| 12 |
+
from fastvideo.entrypoints.cli.utils import RaiseNotImplementedAction
|
| 13 |
+
from fastvideo.fastvideo_args import FastVideoArgs
|
| 14 |
+
from fastvideo.logger import init_logger
|
| 15 |
+
from fastvideo.utils import FlexibleArgumentParser
|
| 16 |
+
|
| 17 |
+
logger = init_logger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class GenerateSubcommand(CLISubcommand):
|
| 21 |
+
"""The `generate` subcommand for the FastVideo CLI"""
|
| 22 |
+
|
| 23 |
+
def __init__(self) -> None:
|
| 24 |
+
self.name = "generate"
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.init_arg_names = self._get_init_arg_names()
|
| 27 |
+
self.generation_arg_names = self._get_generation_arg_names()
|
| 28 |
+
|
| 29 |
+
def _get_init_arg_names(self) -> list[str]:
|
| 30 |
+
"""Get names of arguments for VideoGenerator initialization"""
|
| 31 |
+
return ["num_gpus", "tp_size", "sp_size", "model_path"]
|
| 32 |
+
|
| 33 |
+
def _get_generation_arg_names(self) -> list[str]:
|
| 34 |
+
"""Get names of arguments for generate_video method"""
|
| 35 |
+
return [field.name for field in dataclasses.fields(SamplingParam)]
|
| 36 |
+
|
| 37 |
+
def cmd(self, args: argparse.Namespace) -> None:
|
| 38 |
+
excluded_args = ['subparser', 'config', 'dispatch_function']
|
| 39 |
+
|
| 40 |
+
provided_args = {}
|
| 41 |
+
for k, v in vars(args).items():
|
| 42 |
+
if (k not in excluded_args and v is not None and hasattr(args, '_provided') and k in args._provided):
|
| 43 |
+
provided_args[k] = v
|
| 44 |
+
|
| 45 |
+
if 'model_path' in vars(args) and args.model_path is not None:
|
| 46 |
+
provided_args['model_path'] = args.model_path
|
| 47 |
+
|
| 48 |
+
if 'prompt' in vars(args) and args.prompt is not None:
|
| 49 |
+
provided_args['prompt'] = args.prompt
|
| 50 |
+
|
| 51 |
+
merged_args = {**provided_args}
|
| 52 |
+
|
| 53 |
+
logger.info('CLI Args: %s', merged_args)
|
| 54 |
+
|
| 55 |
+
if 'model_path' not in merged_args or not merged_args['model_path']:
|
| 56 |
+
raise ValueError("model_path must be provided either in config file or via --model-path")
|
| 57 |
+
|
| 58 |
+
# Check if either prompt or prompt_txt is provided
|
| 59 |
+
has_prompt = 'prompt' in merged_args and merged_args['prompt']
|
| 60 |
+
has_prompt_txt = 'prompt_txt' in merged_args and merged_args['prompt_txt']
|
| 61 |
+
|
| 62 |
+
if not (has_prompt or has_prompt_txt):
|
| 63 |
+
raise ValueError("Either prompt or prompt_txt must be provided")
|
| 64 |
+
|
| 65 |
+
if has_prompt and has_prompt_txt:
|
| 66 |
+
raise ValueError("Cannot provide both 'prompt' and 'prompt_txt'. Use only one of them.")
|
| 67 |
+
|
| 68 |
+
init_args = {k: v for k, v in merged_args.items() if k not in self.generation_arg_names}
|
| 69 |
+
generation_args = {k: v for k, v in merged_args.items() if k in self.generation_arg_names}
|
| 70 |
+
generation_args.setdefault("return_frames", False)
|
| 71 |
+
|
| 72 |
+
model_path = init_args.pop('model_path')
|
| 73 |
+
prompt = generation_args.pop('prompt', None)
|
| 74 |
+
|
| 75 |
+
generator = VideoGenerator.from_pretrained(model_path=model_path, **init_args)
|
| 76 |
+
|
| 77 |
+
# Call generate_video - it handles both single and batch modes
|
| 78 |
+
generator.generate_video(prompt=prompt, **generation_args)
|
| 79 |
+
|
| 80 |
+
def validate(self, args: argparse.Namespace) -> None:
|
| 81 |
+
"""Validate the arguments for this command"""
|
| 82 |
+
if args.num_gpus is not None and args.num_gpus <= 0:
|
| 83 |
+
raise ValueError("Number of gpus must be positive")
|
| 84 |
+
|
| 85 |
+
if args.config and not os.path.exists(args.config):
|
| 86 |
+
raise ValueError(f"Config file not found: {args.config}")
|
| 87 |
+
|
| 88 |
+
def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
|
| 89 |
+
generate_parser = subparsers.add_parser(
|
| 90 |
+
"generate",
|
| 91 |
+
help="Run inference on a model",
|
| 92 |
+
usage="fastvideo generate (--model-path MODEL_PATH_OR_ID --prompt PROMPT) | --config CONFIG_FILE [OPTIONS]")
|
| 93 |
+
|
| 94 |
+
generate_parser.add_argument(
|
| 95 |
+
"--config",
|
| 96 |
+
type=str,
|
| 97 |
+
default='',
|
| 98 |
+
required=False,
|
| 99 |
+
help="Read CLI options from a config JSON or YAML file. If provided, --model-path and --prompt are optional."
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
generate_parser = FastVideoArgs.add_cli_args(generate_parser)
|
| 103 |
+
generate_parser = SamplingParam.add_cli_args(generate_parser)
|
| 104 |
+
|
| 105 |
+
generate_parser.add_argument(
|
| 106 |
+
"--text-encoder-configs",
|
| 107 |
+
action=RaiseNotImplementedAction,
|
| 108 |
+
help="JSON array of text encoder configurations (NOT YET IMPLEMENTED)",
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
return cast(FlexibleArgumentParser, generate_parser)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def cmd_init() -> list[CLISubcommand]:
|
| 115 |
+
return [GenerateSubcommand()]
|
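The key move in `GenerateSubcommand.cmd` is partitioning the merged CLI namespace: anything that is a `SamplingParam` field goes to `generate_video`, everything else goes to `VideoGenerator.from_pretrained`. A small self-contained sketch of that split (field names illustrative):

```python
# Sketch of the init/generation split performed above.
def split_args(merged: dict, sampling_fields: set) -> tuple[dict, dict]:
    init = {k: v for k, v in merged.items() if k not in sampling_fields}
    gen = {k: v for k, v in merged.items() if k in sampling_fields}
    return init, gen

init, gen = split_args(
    {"model_path": "Wan2.1-T2V-1.3B", "num_gpus": 1, "num_frames": 81},
    sampling_fields={"num_frames", "height", "width"},
)
assert init == {"model_path": "Wan2.1-T2V-1.3B", "num_gpus": 1}
assert gen == {"num_frames": 81}
```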
backend_snapshot/fastvideo/entrypoints/video_generator.py
ADDED
|
@@ -0,0 +1,797 @@
|
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""
|
| 3 |
+
VideoGenerator module for FastVideo.
|
| 4 |
+
|
| 5 |
+
This module provides a consolidated interface for generating videos using
|
| 6 |
+
diffusion models.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import re
|
| 11 |
+
import shutil
|
| 12 |
+
import threading
|
| 13 |
+
import time
|
| 14 |
+
import tempfile
|
| 15 |
+
import warnings
|
| 16 |
+
from collections.abc import Mapping
|
| 17 |
+
from copy import deepcopy
|
| 18 |
+
from typing import Any
|
| 19 |
+
|
| 20 |
+
import imageio
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torchvision
|
| 24 |
+
from einops import rearrange
|
| 25 |
+
|
| 26 |
+
from fastvideo.api.compat import (
|
| 27 |
+
expand_request_prompt_batch,
|
| 28 |
+
generator_config_to_fastvideo_args,
|
| 29 |
+
legacy_from_pretrained_to_config,
|
| 30 |
+
load_generator_config_from_file,
|
| 31 |
+
normalize_generation_request,
|
| 32 |
+
normalize_generator_config,
|
| 33 |
+
request_to_pipeline_overrides,
|
| 34 |
+
request_to_sampling_param,
|
| 35 |
+
)
|
| 36 |
+
from fastvideo.api.results import GenerationResult
|
| 37 |
+
from fastvideo.api.schema import GenerationRequest, GeneratorConfig
|
| 38 |
+
from fastvideo.configs.sample import SamplingParam
|
| 39 |
+
from fastvideo.fastvideo_args import FastVideoArgs
|
| 40 |
+
from fastvideo.logger import init_logger
|
| 41 |
+
from fastvideo.pipelines import ForwardBatch
|
| 42 |
+
from fastvideo.utils import align_to, shallow_asdict
|
| 43 |
+
from fastvideo.worker.executor import Executor
|
| 44 |
+
|
| 45 |
+
logger = init_logger(__name__)
|
| 46 |
+
|
| 47 |
+
_FROM_PRETRAINED_CONVENIENCE_KWARGS = frozenset({
|
| 48 |
+
"num_gpus",
|
| 49 |
+
"revision",
|
| 50 |
+
"trust_remote_code",
|
| 51 |
+
"distributed_executor_backend",
|
| 52 |
+
"tp_size",
|
| 53 |
+
"sp_size",
|
| 54 |
+
"hsdp_replicate_dim",
|
| 55 |
+
"hsdp_shard_dim",
|
| 56 |
+
"dist_timeout",
|
| 57 |
+
"use_fsdp_inference",
|
| 58 |
+
"disable_autocast",
|
| 59 |
+
"enable_stage_verification",
|
| 60 |
+
"dit_cpu_offload",
|
| 61 |
+
"dit_layerwise_offload",
|
| 62 |
+
"text_encoder_cpu_offload",
|
| 63 |
+
"image_encoder_cpu_offload",
|
| 64 |
+
"vae_cpu_offload",
|
| 65 |
+
"pin_cpu_memory",
|
| 66 |
+
"enable_torch_compile",
|
| 67 |
+
"torch_compile_kwargs",
|
| 68 |
+
"transformer_quant",
|
| 69 |
+
})
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _infer_latent_batch_size(batch: ForwardBatch) -> int:
|
| 73 |
+
if isinstance(batch.prompt, list):
|
| 74 |
+
latent_batch_size = len(batch.prompt)
|
| 75 |
+
elif batch.prompt is not None:
|
| 76 |
+
latent_batch_size = 1
|
| 77 |
+
elif batch.prompt_embeds is not None and len(batch.prompt_embeds) > 0:
|
| 78 |
+
latent_batch_size = batch.prompt_embeds[0].shape[0]
|
| 79 |
+
else:
|
| 80 |
+
raise ValueError("Cannot infer batch size from batch; no prompt or prompt_embeds found")
|
| 81 |
+
latent_batch_size *= batch.num_videos_per_prompt
|
| 82 |
+
return latent_batch_size
|
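A worked example of the rule implemented above: a list prompt yields one latent per prompt, then everything is multiplied by `num_videos_per_prompt`. `ForwardBatch` is stubbed here because its full constructor is not shown in this snapshot:

```python
# Stub mirroring only the attributes _infer_latent_batch_size reads.
class _Batch:
    prompt = ["a cat", "a dog"]   # list prompt -> len(prompt) latents
    prompt_embeds = None
    num_videos_per_prompt = 2

# 2 prompts * 2 videos per prompt == 4 latents in the batch.
assert len(_Batch.prompt) * _Batch.num_videos_per_prompt == 4
```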
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class VideoGenerator:
|
| 86 |
+
"""
|
| 87 |
+
A unified class for generating videos using diffusion models.
|
| 88 |
+
|
| 89 |
+
This class provides a simple interface for video generation with rich
|
| 90 |
+
customization options, similar to popular frameworks like HF Diffusers.
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
def __init__(
|
| 94 |
+
self,
|
| 95 |
+
fastvideo_args: FastVideoArgs,
|
| 96 |
+
executor_class: type[Executor],
|
| 97 |
+
log_stats: bool,
|
| 98 |
+
*,
|
| 99 |
+
log_queue=None,
|
| 100 |
+
):
|
| 101 |
+
"""
|
| 102 |
+
Initialize the video generator.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
fastvideo_args: The inference arguments
|
| 106 |
+
executor_class: The executor class to use for inference
|
| 107 |
+
log_stats: Whether to log statistics
|
| 108 |
+
log_queue: Optional multiprocessing.Queue to forward worker logs to
|
| 109 |
+
"""
|
| 110 |
+
self.config: GeneratorConfig | None = None
|
| 111 |
+
self.fastvideo_args = fastvideo_args
|
| 112 |
+
self.executor = executor_class(fastvideo_args, log_queue=log_queue)
|
| 113 |
+
|
| 114 |
+
@classmethod
|
| 115 |
+
def from_pretrained(
|
| 116 |
+
cls,
|
| 117 |
+
model_path: str | GeneratorConfig | Mapping[str, Any] | None = None,
|
| 118 |
+
**kwargs,
|
| 119 |
+
) -> "VideoGenerator":
|
| 120 |
+
"""
|
| 121 |
+
Create a video generator from a pretrained model.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
model_path: Path or identifier for the pretrained model
|
| 125 |
+
config: Optional typed GeneratorConfig; mutually exclusive with model_path
|
| 126 |
+
**kwargs: Additional arguments to customize model loading, set any FastVideoArgs or PipelineConfig attributes here.
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
The created video generator
|
| 130 |
+
|
| 131 |
+
Priority level: Default pipeline config < User's pipeline config < User's kwargs
|
| 132 |
+
|
| 133 |
+
Stable convenience kwargs remain supported here for common engine and
|
| 134 |
+
offload settings. Advanced model- or pipeline-specific options should
|
| 135 |
+
move to VideoGenerator.from_config(...).
|
| 136 |
+
"""
|
| 137 |
+
log_queue = kwargs.pop("log_queue", None)
|
| 138 |
+
typed_config = kwargs.pop("config", None)
|
| 139 |
+
if typed_config is not None:
|
| 140 |
+
if model_path is not None:
|
| 141 |
+
raise TypeError("Pass either model_path or config to from_pretrained, not both")
|
| 142 |
+
if kwargs:
|
| 143 |
+
unexpected = ", ".join(sorted(kwargs))
|
| 144 |
+
raise TypeError(f"Unexpected keyword arguments with config: {unexpected}")
|
| 145 |
+
return cls.from_config(typed_config, log_queue=log_queue)
|
| 146 |
+
|
| 147 |
+
if isinstance(model_path, GeneratorConfig | Mapping):
|
| 148 |
+
if kwargs:
|
| 149 |
+
unexpected = ", ".join(sorted(kwargs))
|
| 150 |
+
raise TypeError(f"Unexpected keyword arguments with typed config: {unexpected}")
|
| 151 |
+
return cls.from_config(model_path, log_queue=log_queue)
|
| 152 |
+
|
| 153 |
+
if model_path is None:
|
| 154 |
+
raise TypeError("model_path or config is required")
|
| 155 |
+
|
| 156 |
+
legacy_only_kwargs = sorted(set(kwargs) - _FROM_PRETRAINED_CONVENIENCE_KWARGS)
|
| 157 |
+
if legacy_only_kwargs:
|
| 158 |
+
warnings.warn(
|
| 159 |
+
"VideoGenerator.from_pretrained(...) received legacy-only kwargs "
|
| 160 |
+
f"({', '.join(legacy_only_kwargs)}); prefer VideoGenerator.from_config(...) "
|
| 161 |
+
"for advanced configuration.",
|
| 162 |
+
DeprecationWarning,
|
| 163 |
+
stacklevel=2,
|
| 164 |
+
)
|
| 165 |
+
return cls.from_config(
|
| 166 |
+
legacy_from_pretrained_to_config(model_path, kwargs),
|
| 167 |
+
log_queue=log_queue,
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
@classmethod
|
| 171 |
+
def from_config(
|
| 172 |
+
cls,
|
| 173 |
+
config: GeneratorConfig | Mapping[str, Any],
|
| 174 |
+
*,
|
| 175 |
+
log_queue=None,
|
| 176 |
+
) -> "VideoGenerator":
|
| 177 |
+
normalized = normalize_generator_config(config)
|
| 178 |
+
fastvideo_args = generator_config_to_fastvideo_args(normalized)
|
| 179 |
+
generator = cls.from_fastvideo_args(fastvideo_args, log_queue=log_queue)
|
| 180 |
+
generator.config = normalized
|
| 181 |
+
return generator
|
| 182 |
+
|
| 183 |
+
@classmethod
|
| 184 |
+
def from_file(
|
| 185 |
+
cls,
|
| 186 |
+
path: str,
|
| 187 |
+
overrides: list[str] | Mapping[str, Any] | None = None,
|
| 188 |
+
*,
|
| 189 |
+
log_queue=None,
|
| 190 |
+
) -> "VideoGenerator":
|
| 191 |
+
return cls.from_config(
|
| 192 |
+
load_generator_config_from_file(path, overrides=overrides),
|
| 193 |
+
log_queue=log_queue,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
@classmethod
|
| 197 |
+
def from_fastvideo_args(
|
| 198 |
+
cls,
|
| 199 |
+
fastvideo_args: FastVideoArgs,
|
| 200 |
+
*,
|
| 201 |
+
log_queue=None,
|
| 202 |
+
) -> "VideoGenerator":
|
| 203 |
+
"""
|
| 204 |
+
Create a video generator with the specified arguments.
|
| 205 |
+
|
| 206 |
+
Args:
|
| 207 |
+
fastvideo_args: The inference arguments
|
| 208 |
+
log_queue: Optional multiprocessing.Queue to forward worker logs to
|
| 209 |
+
|
| 210 |
+
Returns:
|
| 211 |
+
The created video generator
|
| 212 |
+
"""
|
| 213 |
+
# Initialize distributed environment if needed
|
| 214 |
+
# initialize_distributed_and_parallelism(fastvideo_args)
|
| 215 |
+
|
| 216 |
+
executor_class = Executor.get_class(fastvideo_args)
|
| 217 |
+
return cls(
|
| 218 |
+
fastvideo_args=fastvideo_args,
|
| 219 |
+
executor_class=executor_class,
|
| 220 |
+
log_stats=False, # TODO: implement
|
| 221 |
+
log_queue=log_queue,
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
def generate(
|
| 225 |
+
self,
|
| 226 |
+
request: GenerationRequest | Mapping[str, Any],
|
| 227 |
+
*,
|
| 228 |
+
log_queue=None,
|
| 229 |
+
) -> GenerationResult | list[GenerationResult]:
|
| 230 |
+
"""
|
| 231 |
+
Generate video or image outputs from a typed inference request.
|
| 232 |
+
|
| 233 |
+
Args:
|
| 234 |
+
request: A `GenerationRequest` instance or a mapping that can be
|
| 235 |
+
parsed into one. This is the primary public inference
|
| 236 |
+
entrypoint for the typed API.
|
| 237 |
+
log_queue: Optional multiprocessing.Queue to forward worker logs to
|
| 238 |
+
during this request.
|
| 239 |
+
|
| 240 |
+
Returns:
|
| 241 |
+
A `GenerationResult` for single-request generation, or a list of
|
| 242 |
+
`GenerationResult` objects when the request expands into multiple
|
| 243 |
+
prompts.
|
| 244 |
+
"""
|
| 245 |
+
normalized_request = normalize_generation_request(request)
|
| 246 |
+
if log_queue:
|
| 247 |
+
self.executor.set_log_queue(log_queue)
|
| 248 |
+
|
| 249 |
+
try:
|
| 250 |
+
return self._generate_request_impl(normalized_request)
|
| 251 |
+
finally:
|
| 252 |
+
if log_queue:
|
| 253 |
+
self.executor.clear_log_queue()
|
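Continuing the `from_pretrained` sketch above, the typed API accepts either a `GenerationRequest` or a mapping that parses into one. Only the `prompt` field is confirmed by this file; other request fields are defined in `fastvideo.api.schema`:

```python
# Sketch: `gen` is the VideoGenerator from the earlier example.
result = gen.generate({"prompt": "A red panda climbing a snowy tree"})
# A single prompt yields one GenerationResult; a prompt list yields a list.
```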
| 254 |
+
|
| 255 |
+
def generate_video(
|
| 256 |
+
self,
|
| 257 |
+
prompt: str | None = None,
|
| 258 |
+
sampling_param: SamplingParam | None = None,
|
| 259 |
+
# Action control inputs (Matrix-Game)
|
| 260 |
+
mouse_cond: torch.Tensor | None = None,
|
| 261 |
+
keyboard_cond: torch.Tensor | None = None,
|
| 262 |
+
grid_sizes: tuple[int, int, int] | list[int] | torch.Tensor
|
| 263 |
+
| None = None,
|
| 264 |
+
**kwargs,
|
| 265 |
+
) -> dict[str, Any] | list[dict[str, Any]]:
|
| 266 |
+
"""
|
| 267 |
+
Generate a video based on the given prompt.
|
| 268 |
+
|
| 269 |
+
Args:
|
| 270 |
+
prompt: The prompt to use for generation (optional if prompt_txt is provided)
|
| 271 |
+
negative_prompt: The negative prompt to use (overrides the one in fastvideo_args)
|
| 272 |
+
output_path: Path to save the video (overrides the one in fastvideo_args)
|
| 273 |
+
prompt_path: Path to prompt file
|
| 274 |
+
save_video: Whether to save the video to disk
|
| 275 |
+
return_frames: Whether to include raw frames in the result dict
|
| 276 |
+
num_inference_steps: Number of denoising steps (overrides fastvideo_args)
|
| 277 |
+
guidance_scale: Classifier-free guidance scale (overrides fastvideo_args)
|
| 278 |
+
num_frames: Number of frames to generate (overrides fastvideo_args)
|
| 279 |
+
height: Height of generated video (overrides fastvideo_args)
|
| 280 |
+
width: Width of generated video (overrides fastvideo_args)
|
| 281 |
+
fps: Frames per second for saved video (overrides fastvideo_args)
|
| 282 |
+
seed: Random seed for generation (overrides fastvideo_args)
|
| 283 |
+
callback: Callback function called after each step
|
| 284 |
+
callback_steps: Number of steps between each callback
|
| 285 |
+
|
| 286 |
+
Returns:
|
| 287 |
+
A metadata dictionary for single-prompt generation, or a list of
|
| 288 |
+
metadata dictionaries for prompt-file batch generation.
|
| 289 |
+
"""
|
| 290 |
+
log_queue = kwargs.pop("log_queue", None)
|
| 291 |
+
warnings.warn(
|
| 292 |
+
"VideoGenerator.generate_video(...) is deprecated; use "
|
| 293 |
+
"VideoGenerator.generate(request=...) instead.",
|
| 294 |
+
DeprecationWarning,
|
| 295 |
+
stacklevel=2,
|
| 296 |
+
)
|
| 297 |
+
if log_queue:
|
| 298 |
+
self.executor.set_log_queue(log_queue)
|
| 299 |
+
|
| 300 |
+
try:
|
| 301 |
+
return self._generate_video_impl(
|
| 302 |
+
prompt=prompt,
|
| 303 |
+
sampling_param=sampling_param,
|
| 304 |
+
mouse_cond=mouse_cond,
|
| 305 |
+
keyboard_cond=keyboard_cond,
|
| 306 |
+
grid_sizes=grid_sizes,
|
| 307 |
+
**kwargs,
|
| 308 |
+
)
|
| 309 |
+
finally:
|
| 310 |
+
if log_queue:
|
| 311 |
+
self.executor.clear_log_queue()
|
| 312 |
+
|
| 313 |
+
def _generate_request_impl(
|
| 314 |
+
self,
|
| 315 |
+
request: GenerationRequest,
|
| 316 |
+
) -> GenerationResult | list[GenerationResult]:
|
| 317 |
+
if isinstance(request.prompt, list):
|
| 318 |
+
if request.inputs.prompt_path is not None:
|
| 319 |
+
raise ValueError("request.prompt list cannot be combined with request.inputs.prompt_path")
|
| 320 |
+
results: list[GenerationResult] = []
|
| 321 |
+
for index, single_request in enumerate(expand_request_prompt_batch(request)):
|
| 322 |
+
prompt = single_request.prompt
|
| 323 |
+
wrapped = self._generate_single_request(single_request)
|
| 324 |
+
if isinstance(wrapped, list):
|
| 325 |
+
results.extend(wrapped)
|
| 326 |
+
continue
|
| 327 |
+
wrapped.prompt_index = index
|
| 328 |
+
if wrapped.prompt is None and isinstance(prompt, str):
|
| 329 |
+
wrapped.prompt = prompt
|
| 330 |
+
results.append(wrapped)
|
| 331 |
+
return results
|
| 332 |
+
|
| 333 |
+
return self._generate_single_request(request)
|
| 334 |
+
|
| 335 |
+
def _generate_single_request(
|
| 336 |
+
self,
|
| 337 |
+
request: GenerationRequest,
|
| 338 |
+
) -> GenerationResult | list[GenerationResult]:
|
| 339 |
+
fastvideo_args = self.fastvideo_args
|
| 340 |
+
pipeline_overrides = request_to_pipeline_overrides(request)
|
| 341 |
+
if pipeline_overrides:
|
| 342 |
+
fastvideo_args = deepcopy(self.fastvideo_args)
|
| 343 |
+
for key, value in pipeline_overrides.items():
|
| 344 |
+
if not hasattr(fastvideo_args.pipeline_config, key):
|
| 345 |
+
raise ValueError(f"Request field {key!r} is not supported by pipeline config overrides")
|
| 346 |
+
setattr(fastvideo_args.pipeline_config, key, deepcopy(value))
|
| 347 |
+
|
| 348 |
+
sampling_param = request_to_sampling_param(
|
| 349 |
+
request,
|
| 350 |
+
model_path=self.fastvideo_args.model_path,
|
| 351 |
+
)
|
| 352 |
+
result = self._generate_video_impl(
|
| 353 |
+
prompt=request.prompt,
|
| 354 |
+
sampling_param=sampling_param,
|
| 355 |
+
fastvideo_args=fastvideo_args,
|
| 356 |
+
)
|
| 357 |
+
return self._wrap_legacy_result(result)
|
| 358 |
+
|
| 359 |
+
def _generate_video_impl(
|
| 360 |
+
self,
|
| 361 |
+
prompt: str | list[str] | None = None,
|
| 362 |
+
sampling_param: SamplingParam | None = None,
|
| 363 |
+
mouse_cond: torch.Tensor | None = None,
|
| 364 |
+
keyboard_cond: torch.Tensor | None = None,
|
| 365 |
+
grid_sizes: tuple[int, int, int] | list[int] | torch.Tensor
|
| 366 |
+
| None = None,
|
| 367 |
+
fastvideo_args: FastVideoArgs | None = None,
|
| 368 |
+
**kwargs,
|
| 369 |
+
) -> dict[str, Any] | list[dict[str, Any]]:
|
| 370 |
+
"""Internal implementation of generate_video."""
|
| 371 |
+
if fastvideo_args is None:
|
| 372 |
+
fastvideo_args = self.fastvideo_args
|
| 373 |
+
|
| 374 |
+
# Handle batch processing from text file
|
| 375 |
+
if sampling_param is None:
|
| 376 |
+
sampling_param = SamplingParam.from_pretrained(fastvideo_args.model_path)
|
| 377 |
+
|
| 378 |
+
# Add action control inputs to kwargs if provided
|
| 379 |
+
if mouse_cond is not None:
|
| 380 |
+
kwargs['mouse_cond'] = mouse_cond
|
| 381 |
+
if keyboard_cond is not None:
|
| 382 |
+
kwargs['keyboard_cond'] = keyboard_cond
|
| 383 |
+
if grid_sizes is not None:
|
| 384 |
+
kwargs['grid_sizes'] = grid_sizes
|
| 385 |
+
|
| 386 |
+
sampling_param.update(kwargs)
|
| 387 |
+
|
| 388 |
+
if fastvideo_args.prompt_txt is not None or sampling_param.prompt_path is not None:
|
| 389 |
+
prompt_txt_path = sampling_param.prompt_path or fastvideo_args.prompt_txt
|
| 390 |
+
if not prompt_txt_path or not os.path.exists(prompt_txt_path):
|
| 391 |
+
raise FileNotFoundError(f"Prompt text file not found: {prompt_txt_path}")
|
| 392 |
+
|
| 393 |
+
# Read prompts from file
|
| 394 |
+
with open(prompt_txt_path, encoding='utf-8') as f:
|
| 395 |
+
prompts = [line.strip() for line in f if line.strip()]
|
| 396 |
+
|
| 397 |
+
if not prompts:
|
| 398 |
+
raise ValueError(f"No prompts found in file: {prompt_txt_path}")
|
| 399 |
+
|
| 400 |
+
logger.info("Found %d prompts in %s", len(prompts), prompt_txt_path)
|
| 401 |
+
|
| 402 |
+
results = []
|
| 403 |
+
for i, batch_prompt in enumerate(prompts):
|
| 404 |
+
logger.info("Processing prompt %d/%d: %s...", i + 1, len(prompts), batch_prompt[:100])
|
| 405 |
+
try:
|
| 406 |
+
# Generate video for this prompt using the same logic below
|
| 407 |
+
output_path = self._prepare_output_path(sampling_param.output_path, batch_prompt)
|
| 408 |
+
kwargs["output_path"] = output_path
|
| 409 |
+
result = self._generate_single_video(
|
| 410 |
+
prompt=batch_prompt,
|
| 411 |
+
sampling_param=sampling_param,
|
| 412 |
+
fastvideo_args=fastvideo_args,
|
| 413 |
+
**kwargs,
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
# Add prompt info to result
|
| 417 |
+
result["prompt_index"] = i
|
| 418 |
+
result["prompt"] = batch_prompt
|
| 419 |
+
|
| 420 |
+
results.append(result)
|
| 421 |
+
logger.info("Successfully generated video for prompt %d", i + 1)
|
| 422 |
+
|
| 423 |
+
except Exception as e:
|
| 424 |
+
logger.error("Failed to generate video for prompt %d: %s", i + 1, e)
|
| 425 |
+
continue
|
| 426 |
+
|
| 427 |
+
logger.info("Completed batch processing. Generated %d videos successfully.", len(results))
|
| 428 |
+
return results
|
| 429 |
+
|
| 430 |
+
# Single prompt generation (original behavior)
|
| 431 |
+
if prompt is None:
|
| 432 |
+
raise ValueError("Either prompt or prompt_txt must be provided")
|
| 433 |
+
if not isinstance(prompt, str):
|
| 434 |
+
raise ValueError("Single-prompt generation expects a string prompt")
|
| 435 |
+
output_path = self._prepare_output_path(sampling_param.output_path, prompt)
|
| 436 |
+
kwargs["output_path"] = output_path
|
| 437 |
+
return self._generate_single_video(
|
| 438 |
+
prompt=prompt,
|
| 439 |
+
sampling_param=sampling_param,
|
| 440 |
+
fastvideo_args=fastvideo_args,
|
| 441 |
+
**kwargs,
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
def _is_image_workload(self) -> bool:
|
| 445 |
+
"""Return True when the workload produces a single image (t2i, i2i …)."""
|
| 446 |
+
args = getattr(self, "fastvideo_args", None)
|
| 447 |
+
if args is None:
|
| 448 |
+
return False
|
| 449 |
+
return args.workload_type.value.endswith("2i")
|
| 450 |
+
|
| 451 |
+
def _prepare_output_path(
|
| 452 |
+
self,
|
| 453 |
+
output_path: str,
|
| 454 |
+
prompt: str,
|
| 455 |
+
) -> str:
|
| 456 |
+
"""Build a unique, sanitized output file path.
|
| 457 |
+
|
| 458 |
+
The file extension is chosen automatically based on the workload type:
|
| 459 |
+
``.png`` for image workloads (``t2i``, ``i2i``, …) and ``.mp4`` for
|
| 460 |
+
video workloads.
|
| 461 |
+
|
| 462 |
+
- If ``output_path`` already carries the correct extension, treat it
|
| 463 |
+
as a file path.
|
| 464 |
+
- Otherwise, treat ``output_path`` as a directory and derive the
|
| 465 |
+
filename from the prompt.
|
| 466 |
+
- Invalid filename characters are removed; if the name changes, a
|
| 467 |
+
warning is logged.
|
| 468 |
+
- If the target path already exists, a numeric suffix is appended.
|
| 469 |
+
"""
|
| 470 |
+
target_ext = ".png" if self._is_image_workload() else ".mp4"
|
| 471 |
+
|
| 472 |
+
def _sanitize_filename_component(name: str) -> str:
|
| 473 |
+
# Remove characters invalid on common filesystems, strip spaces/dots
|
| 474 |
+
sanitized = re.sub(r'[\\/:*?"<>|]', '', name)
|
| 475 |
+
sanitized = sanitized.strip().strip('.')
|
| 476 |
+
sanitized = re.sub(r'\s+', ' ', sanitized)
|
| 477 |
+
return sanitized or "output"
|
| 478 |
+
|
| 479 |
+
base_path, extension = os.path.splitext(output_path)
|
| 480 |
+
extension_lower = extension.lower()
|
| 481 |
+
|
| 482 |
+
if extension_lower == target_ext:
|
| 483 |
+
output_dir = os.path.dirname(output_path)
|
| 484 |
+
base_name = os.path.basename(base_path) # filename without extension
|
| 485 |
+
sanitized_base = _sanitize_filename_component(base_name)
|
| 486 |
+
if sanitized_base != base_name:
|
| 487 |
+
logger.warning(
|
| 488 |
+
"The output name '%s' contained invalid characters. "
|
| 489 |
+
"It has been renamed to '%s%s'",
|
| 490 |
+
os.path.basename(output_path),
|
| 491 |
+
sanitized_base,
|
| 492 |
+
target_ext,
|
| 493 |
+
)
|
| 494 |
+
out_name = f"{sanitized_base}{target_ext}"
|
| 495 |
+
else:
|
| 496 |
+
# Treat as directory; inform if an unexpected extension was
|
| 497 |
+
# provided.
|
| 498 |
+
if extension:
|
| 499 |
+
logger.info(
|
| 500 |
+
"Output path '%s' has extension '%s' which does not "
|
| 501 |
+
"match the target '%s'; treating it as a directory",
|
| 502 |
+
output_path,
|
| 503 |
+
extension,
|
| 504 |
+
target_ext,
|
| 505 |
+
)
|
| 506 |
+
output_dir = output_path
|
| 507 |
+
prompt_component = _sanitize_filename_component(prompt[:100])
|
| 508 |
+
out_name = f"{prompt_component}{target_ext}"
|
| 509 |
+
|
| 510 |
+
if output_dir:
|
| 511 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 512 |
+
|
| 513 |
+
new_output_path = os.path.join(output_dir, out_name)
|
| 514 |
+
counter = 1
|
| 515 |
+
while os.path.exists(new_output_path):
|
| 516 |
+
name_part, ext_part = os.path.splitext(out_name)
|
| 517 |
+
new_name = f"{name_part}_{counter}{ext_part}"
|
| 518 |
+
new_output_path = os.path.join(output_dir, new_name)
|
| 519 |
+
counter += 1
|
| 520 |
+
return new_output_path
|
| 521 |
+
|
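The sanitizer above is easy to verify in isolation. A self-contained replica (illustrative input; the real helper is the nested `_sanitize_filename_component`):

```python
import re

def sanitize(name: str) -> str:
    # Mirrors _sanitize_filename_component: drop invalid filename
    # characters, trim spaces/dots, collapse internal whitespace.
    s = re.sub(r'[\\/:*?"<>|]', '', name).strip().strip('.')
    return re.sub(r'\s+', ' ', s) or "output"

assert sanitize('a: cat?  video') == 'a cat video'
```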
| 522 |
+
def _generate_single_video(
|
| 523 |
+
self,
|
| 524 |
+
prompt: str,
|
| 525 |
+
sampling_param: SamplingParam | None = None,
|
| 526 |
+
fastvideo_args: FastVideoArgs | None = None,
|
| 527 |
+
**kwargs,
|
| 528 |
+
) -> dict[str, Any]:
|
| 529 |
+
"""Internal method for single video generation"""
|
| 530 |
+
if fastvideo_args is None:
|
| 531 |
+
fastvideo_args = self.fastvideo_args
|
| 532 |
+
|
| 533 |
+
# Validate inputs
|
| 534 |
+
if not isinstance(prompt, str):
|
| 535 |
+
raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
|
| 536 |
+
prompt = prompt.strip()
|
| 537 |
+
sampling_param = deepcopy(sampling_param)
|
| 538 |
+
output_path = kwargs["output_path"]
|
| 539 |
+
sampling_param.prompt = prompt
|
| 540 |
+
# Process negative prompt
|
| 541 |
+
if sampling_param.negative_prompt is not None:
|
| 542 |
+
sampling_param.negative_prompt = sampling_param.negative_prompt.strip()
|
| 543 |
+
|
| 544 |
+
# Validate dimensions
|
| 545 |
+
if (sampling_param.height <= 0 or sampling_param.width <= 0 or sampling_param.num_frames <= 0):
|
| 546 |
+
raise ValueError(f"Height, width, and num_frames must be positive integers, got "
|
| 547 |
+
f"height={sampling_param.height}, width={sampling_param.width}, "
|
| 548 |
+
f"num_frames={sampling_param.num_frames}")
|
| 549 |
+
|
| 550 |
+
# Calculate sizes
|
| 551 |
+
target_height = align_to(sampling_param.height, 16)
|
| 552 |
+
target_width = align_to(sampling_param.width, 16)
|
| 553 |
+
|
| 554 |
+
# Calculate latent sizes
|
| 555 |
+
latents_size = [(sampling_param.num_frames - 1) // 4 + 1, sampling_param.height // 8, sampling_param.width // 8]
|
| 556 |
+
n_tokens = latents_size[0] * latents_size[1] * latents_size[2]
|
| 557 |
+
|
| 558 |
+
# Log parameters
|
| 559 |
+
debug_str = f"""
|
| 560 |
+
height: {target_height}
|
| 561 |
+
width: {target_width}
|
| 562 |
+
video_length: {sampling_param.num_frames}
|
| 563 |
+
prompt: {sampling_param.prompt}
|
| 564 |
+
image_path: {sampling_param.image_path}
|
| 565 |
+
neg_prompt: {sampling_param.negative_prompt}
|
| 566 |
+
seed: {sampling_param.seed}
|
| 567 |
+
infer_steps: {sampling_param.num_inference_steps}
|
| 568 |
+
num_videos_per_prompt: {sampling_param.num_videos_per_prompt}
|
| 569 |
+
guidance_scale: {sampling_param.guidance_scale}
|
| 570 |
+
n_tokens: {n_tokens}
|
| 571 |
+
flow_shift: {fastvideo_args.pipeline_config.flow_shift}
|
| 572 |
+
embedded_guidance_scale: {fastvideo_args.pipeline_config.embedded_cfg_scale}
|
| 573 |
+
save_video: {sampling_param.save_video}
|
| 574 |
+
output_path: {output_path}
|
| 575 |
+
""" # type: ignore[attr-defined]
|
| 576 |
+
logger.info(debug_str)
|
| 577 |
+
|
| 578 |
+
# Prepare batch
|
| 579 |
+
batch = ForwardBatch(
|
| 580 |
+
**shallow_asdict(sampling_param),
|
| 581 |
+
eta=0.0,
|
| 582 |
+
n_tokens=n_tokens,
|
| 583 |
+
VSA_sparsity=fastvideo_args.VSA_sparsity,
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
# Run inference
|
| 587 |
+
start_time = time.perf_counter()
|
| 588 |
+
|
| 589 |
+
# Execute forward pass in a new thread for non-blocking tensor
|
| 590 |
+
# allocation. Capture thread exceptions so we can surface the true
|
| 591 |
+
# failure in the main thread instead of later hitting None outputs.
|
| 592 |
+
result_container = {"output_batch": ForwardBatch(data_type=batch.data_type)}
|
| 593 |
+
thread_error: dict[str, BaseException | None] = {"error": None}
|
| 594 |
+
thread_error_traceback: dict[str, str] = {"traceback": ""}
|
| 595 |
+
|
| 596 |
+
def execute_forward_thread():
|
| 597 |
+
import traceback
|
| 598 |
+
try:
|
| 599 |
+
result_container["output_batch"] = self.executor.execute_forward(batch, fastvideo_args)
|
| 600 |
+
except BaseException as error: # noqa: BLE001
|
| 601 |
+
thread_error["error"] = error
|
| 602 |
+
thread_error_traceback["traceback"] = traceback.format_exc()
|
| 603 |
+
|
| 604 |
+
thread = threading.Thread(target=execute_forward_thread)
|
| 605 |
+
thread.start()
|
| 606 |
+
latent_batch_size = _infer_latent_batch_size(batch)
|
| 607 |
+
samples = torch.empty(
|
| 608 |
+
(latent_batch_size, 3, sampling_param.num_frames, sampling_param.height, sampling_param.width),
|
| 609 |
+
device='cpu',
|
| 610 |
+
pin_memory=fastvideo_args.pin_cpu_memory)
|
| 611 |
+
thread.join()
|
| 612 |
+
|
| 613 |
+
if thread_error["error"] is not None:
|
| 614 |
+
raise RuntimeError("Forward execution thread failed.\n"
|
| 615 |
+
f"{thread_error_traceback['traceback']}") from thread_error["error"]
|
| 616 |
+
|
| 617 |
+
output_batch = result_container["output_batch"]
|
| 618 |
+
if output_batch.output is None:
|
| 619 |
+
raise RuntimeError("Forward execution returned no output tensor. "
|
| 620 |
+
"This usually means the executor/pipeline failed earlier.")
|
| 621 |
+
|
| 622 |
+
if output_batch.output.shape == samples.shape:
|
| 623 |
+
samples.copy_(output_batch.output)
|
| 624 |
+
else:
|
| 625 |
+
logger.warning("Output shape %s does not match expected shape %s; use slow path", output_batch.output.shape,
|
| 626 |
+
samples.shape)
|
| 627 |
+
samples = output_batch.output.cpu()
|
| 628 |
+
logging_info = output_batch.logging_info
|
| 629 |
+
|
| 630 |
+
gen_time = time.perf_counter() - start_time
|
| 631 |
+
logger.info("Generated successfully in %.2f seconds", gen_time)
|
| 632 |
+
|
| 633 |
+
# Process outputs
|
| 634 |
+
videos = rearrange(samples, "b c t h w -> t b c h w")
|
| 635 |
+
frames = []
|
| 636 |
+
for x in videos:
|
| 637 |
+
x = torchvision.utils.make_grid(x, nrow=6)
|
| 638 |
+
x = x.permute(1, 2, 0).squeeze(-1)
|
| 639 |
+
x = (x * 255).to(torch.uint8)
|
| 640 |
+
frames.append(x.cpu().numpy())
|
| 641 |
+
|
| 642 |
+
# Save output if requested
|
| 643 |
+
if batch.save_video:
|
| 644 |
+
if self._is_image_workload():
|
| 645 |
+
# Image workloads (t2i, i2i, …): save the first frame as PNG.
|
| 646 |
+
imageio.imwrite(output_path, frames[0])
|
| 647 |
+
logger.info("Saved image to %s", output_path)
|
| 648 |
+
else:
|
| 649 |
+
imageio.mimsave(output_path, frames, fps=batch.fps, format="mp4")
|
| 650 |
+
logger.info("Saved video to %s", output_path)
|
| 651 |
+
audio = output_batch.extra.get("audio")
|
| 652 |
+
audio_sample_rate = output_batch.extra.get("audio_sample_rate")
|
| 653 |
+
if (audio is not None and audio_sample_rate is not None
|
| 654 |
+
and not self._mux_audio(output_path, audio, audio_sample_rate)):
|
| 655 |
+
logger.warning("Audio mux failed; saved video without audio.")
|
| 656 |
+
|
| 657 |
+
result: dict[str, Any] = {
|
| 658 |
+
"prompts": prompt,
|
| 659 |
+
"samples": samples if batch.return_frames else None,
|
| 660 |
+
"frames": frames if batch.return_frames else None,
|
| 661 |
+
"audio": output_batch.extra.get("audio") if batch.return_frames else None,
|
| 662 |
+
"size": (target_height, target_width, batch.num_frames),
|
| 663 |
+
"generation_time": gen_time,
|
| 664 |
+
"logging_info": logging_info,
|
| 665 |
+
"trajectory": output_batch.trajectory_latents,
|
| 666 |
+
"trajectory_timesteps": output_batch.trajectory_timesteps,
|
| 667 |
+
"trajectory_decoded": output_batch.trajectory_decoded,
|
| 668 |
+
"video_path": output_path if batch.save_video else None,
|
| 669 |
+
"peak_memory_mb": output_batch.extra.get("peak_memory_mb"),
|
| 670 |
+
}
|
| 671 |
+
|
| 672 |
+
return result
|
| 673 |
+
|
| 674 |
+
@staticmethod
|
| 675 |
+
def _wrap_legacy_result(
|
| 676 |
+
result: dict[str, Any] | list[dict[str, Any]], ) -> GenerationResult | list[GenerationResult]:
|
| 677 |
+
if isinstance(result, list):
|
| 678 |
+
return [GenerationResult.from_legacy_result(item) for item in result]
|
| 679 |
+
return GenerationResult.from_legacy_result(result)
|
| 680 |
+
|
| 681 |
+
@staticmethod
|
| 682 |
+
def _unwrap_typed_result(
|
| 683 |
+
result: GenerationResult | list[GenerationResult], ) -> dict[str, Any] | list[dict[str, Any]]:
|
| 684 |
+
if isinstance(result, list):
|
| 685 |
+
return [item.to_legacy_dict() for item in result]
|
| 686 |
+
return result.to_legacy_dict()
|
| 687 |
+
|
| 688 |
+
@staticmethod
|
| 689 |
+
def _mux_audio(
|
| 690 |
+
video_path: str,
|
| 691 |
+
audio: torch.Tensor | np.ndarray,
|
| 692 |
+
sample_rate: int,
|
| 693 |
+
) -> bool:
|
| 694 |
+
"""Mux audio into video using PyAV."""
|
| 695 |
+
try:
|
| 696 |
+
import av
|
| 697 |
+
except ImportError:
|
| 698 |
+
logger.warning("PyAV not installed; cannot mux audio. "
|
| 699 |
+
"Install with: pip install av")
|
| 700 |
+
return False
|
| 701 |
+
|
| 702 |
+
if torch.is_tensor(audio):
|
| 703 |
+
audio_np = audio.detach().cpu().float().numpy()
|
| 704 |
+
else:
|
| 705 |
+
audio_np = np.asarray(audio, dtype=np.float32)
|
| 706 |
+
|
| 707 |
+
if audio_np.ndim == 1:
|
| 708 |
+
audio_np = audio_np[:, None]
|
| 709 |
+
elif audio_np.ndim == 2:
|
| 710 |
+
if audio_np.shape[0] <= 8 and audio_np.shape[1] > audio_np.shape[0]:
|
| 711 |
+
audio_np = audio_np.T
|
| 712 |
+
else:
|
| 713 |
+
logger.warning("Unexpected audio shape %s; skipping mux.", audio_np.shape)
|
| 714 |
+
return False
|
| 715 |
+
|
| 716 |
+
audio_np = np.clip(audio_np, -1.0, 1.0)
|
| 717 |
+
audio_int16 = (audio_np * 32767.0).astype(np.int16)
|
| 718 |
+
num_channels = audio_int16.shape[1]
|
| 719 |
+
layout = "stereo" if num_channels == 2 else "mono"
|
| 720 |
+
|
| 721 |
+
try:
|
| 722 |
+
import wave
|
| 723 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 724 |
+
out_path = os.path.join(tmpdir, "muxed.mp4")
|
| 725 |
+
wav_path = os.path.join(tmpdir, "audio.wav")
|
| 726 |
+
|
| 727 |
+
# Write audio to WAV file
|
| 728 |
+
with wave.open(wav_path, "wb") as wav_file:
|
| 729 |
+
wav_file.setnchannels(num_channels)
|
| 730 |
+
wav_file.setsampwidth(2)
|
| 731 |
+
wav_file.setframerate(sample_rate)
|
| 732 |
+
wav_file.writeframes(audio_int16.tobytes())
|
| 733 |
+
|
| 734 |
+
# Open input video and audio
|
| 735 |
+
input_video = av.open(video_path)
|
| 736 |
+
input_audio = av.open(wav_path)
|
| 737 |
+
|
| 738 |
+
# Create output with both streams
|
| 739 |
+
output = av.open(out_path, mode="w")
|
| 740 |
+
|
| 741 |
+
# Add video stream (copy codec from input)
|
| 742 |
+
in_video_stream = input_video.streams.video[0]
|
| 743 |
+
out_video_stream = output.add_stream(
|
| 744 |
+
codec_name=in_video_stream.codec_context.name,
|
| 745 |
+
rate=in_video_stream.average_rate,
|
| 746 |
+
)
|
| 747 |
+
out_video_stream.width = in_video_stream.width
|
| 748 |
+
out_video_stream.height = in_video_stream.height
|
| 749 |
+
out_video_stream.pix_fmt = in_video_stream.pix_fmt
|
| 750 |
+
|
| 751 |
+
# Add audio stream (AAC)
|
| 752 |
+
out_audio_stream = output.add_stream("aac", rate=sample_rate)
|
| 753 |
+
out_audio_stream.layout = layout
|
| 754 |
+
|
| 755 |
+
# Remux video (decode and re-encode to be safe)
|
| 756 |
+
for frame in input_video.decode(video=0):
|
| 757 |
+
for packet in out_video_stream.encode(frame):
|
| 758 |
+
output.mux(packet)
|
| 759 |
+
for packet in out_video_stream.encode():
|
| 760 |
+
output.mux(packet)
|
| 761 |
+
|
| 762 |
+
# Encode audio
|
| 763 |
+
for frame in input_audio.decode(audio=0):
|
| 764 |
+
frame.pts = None # Let encoder assign PTS
|
| 765 |
+
for packet in out_audio_stream.encode(frame):
|
| 766 |
+
output.mux(packet)
|
| 767 |
+
for packet in out_audio_stream.encode():
|
| 768 |
+
output.mux(packet)
|
| 769 |
+
|
| 770 |
+
input_video.close()
|
| 771 |
+
input_audio.close()
|
| 772 |
+
output.close()
|
| 773 |
+
shutil.move(out_path, video_path)
|
| 774 |
+
return True
|
| 775 |
+
except Exception as e:
|
| 776 |
+
logger.warning("Audio mux failed: %s", e)
|
| 777 |
+
return False
|
| 778 |
+
|
| 779 |
+
def set_lora_adapter(self, lora_nickname: str, lora_path: str | None = None) -> None:
|
| 780 |
+
self.executor.set_lora_adapter(lora_nickname, lora_path)
|
| 781 |
+
|
| 782 |
+
def unmerge_lora_weights(self) -> None:
|
| 783 |
+
"""
|
| 784 |
+
Use unmerged weights for inference to produce videos that align with
|
| 785 |
+
validation videos generated during training.
|
| 786 |
+
"""
|
| 787 |
+
self.executor.unmerge_lora_weights()
|
| 788 |
+
|
| 789 |
+
def merge_lora_weights(self) -> None:
|
| 790 |
+
self.executor.merge_lora_weights()
|
| 791 |
+
|
| 792 |
+
def shutdown(self) -> None:
|
| 793 |
+
"""
|
| 794 |
+
Shutdown the video generator.
|
| 795 |
+
"""
|
| 796 |
+
self.executor.shutdown()
|
| 797 |
+
del self.executor
|
backend_snapshot/fastvideo/fastvideo_args.py
ADDED
|
@@ -0,0 +1,1188 @@
|
```python
# SPDX-License-Identifier: Apache-2.0
# Inspired by SGLang: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py
"""The arguments of FastVideo Inference."""
import argparse
import dataclasses
import json
from contextlib import contextmanager
from dataclasses import field
from enum import Enum
from typing import Any, TYPE_CHECKING

from fastvideo.configs.configs import PreprocessConfig
from fastvideo.configs.pipelines.base import PipelineConfig
from fastvideo.configs.utils import clean_cli_args
from fastvideo.layers.quantization import QUANTIZATION_METHODS, QuantizationMethods
from fastvideo.logger import init_logger
from fastvideo.utils import FlexibleArgumentParser, StoreBoolean

if TYPE_CHECKING:
    from ray.runtime_env import RuntimeEnv
    from ray.util.placement_group import PlacementGroup
else:
    RuntimeEnv = Any
    PlacementGroup = Any

logger = init_logger(__name__)


class ExecutionMode(str, Enum):
    """
    Enumeration for different pipeline modes.

    Inherits from str to allow string comparison for backward compatibility.
    """
    INFERENCE = "inference"
    PREPROCESS = "preprocess"
    FINETUNING = "finetuning"
    DISTILLATION = "distillation"

    @classmethod
    def from_string(cls, value: str) -> "ExecutionMode":
        """Convert string to ExecutionMode enum."""
        try:
            return cls(value.lower())
        except ValueError:
            raise ValueError(f"Invalid mode: {value}. Must be one of: {', '.join([m.value for m in cls])}") from None

    @classmethod
    def choices(cls) -> list[str]:
        """Get all available choices as strings for argparse."""
        return [mode.value for mode in cls]


class WorkloadType(str, Enum):
    """
    Enumeration for different workload types.

    Inherits from str to allow string comparison for backward compatibility.
    """
    I2V = "i2v"  # Image to Video
    T2V = "t2v"  # Text to Video
    T2I = "t2i"  # Text to Image
    I2I = "i2i"  # Image to Image

    @classmethod
    def from_string(cls, value: str) -> "WorkloadType":
        """Convert string to WorkloadType enum."""
        try:
            return cls(value.lower())
        except ValueError:
            raise ValueError(
                f"Invalid workload type: {value}. Must be one of: {', '.join([m.value for m in cls])}") from None

    @classmethod
    def choices(cls) -> list[str]:
        """Get all available choices as strings for argparse."""
        return [workload.value for workload in cls]
```
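Both enums subclass `str`, so call sites that still compare against raw strings keep working while newer code gets proper enum members. A minimal sketch of the resulting behavior (not part of the snapshot):

```python
# Sketch: str-subclassing enums compare equal to their raw values, and
# from_string() lower-cases its input before the by-value lookup.
from fastvideo.fastvideo_args import ExecutionMode, WorkloadType

assert ExecutionMode.INFERENCE == "inference"
assert ExecutionMode.from_string("Inference") is ExecutionMode.INFERENCE
assert WorkloadType.choices() == ["i2v", "t2v", "t2i", "i2i"]
```

The snapshot continues with the `FastVideoArgs` dataclass: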
```python
# args for fastvideo framework
@dataclasses.dataclass
class FastVideoArgs:
    # Model and path configuration (for convenience)
    model_path: str

    # Running mode
    mode: ExecutionMode = ExecutionMode.INFERENCE

    # Workload type
    workload_type: WorkloadType = WorkloadType.T2V

    # Distributed executor backend
    distributed_executor_backend: str = "mp"

    # a few attributes for ray related
    ray_placement_group: PlacementGroup | None = None
    ray_runtime_env: RuntimeEnv | None = None

    inference_mode: bool = True  # if False == training mode

    # HuggingFace specific parameters
    trust_remote_code: bool = False
    revision: str | None = None

    # Parallelism
    num_gpus: int = 1
    tp_size: int = -1
    sp_size: int = -1
    hsdp_replicate_dim: int = 1
    hsdp_shard_dim: int = -1
    dist_timeout: int | None = None  # timeout for torch.distributed

    pipeline_config: PipelineConfig = field(default_factory=PipelineConfig)
    preprocess_config: PreprocessConfig | None = None

    # LoRA parameters
    # (Wenxuan) prefer to keep it here instead of in pipeline config to not make it complicated.
    lora_path: str | None = None
    lora_nickname: str = "default"  # for swapping adapters in the pipeline
    # can restrict layers to adapt, e.g. ["q_proj"]
    # Will adapt only q, k, v, o by default.
    lora_target_modules: list[str] | None = None

    output_type: str = "pil"

    # CPU offload parameters
    dit_cpu_offload: bool = True
    use_fsdp_inference: bool = False
    dit_layerwise_offload: bool = True
    text_encoder_cpu_offload: bool = True
    image_encoder_cpu_offload: bool = True
    vae_cpu_offload: bool = True
    pin_cpu_memory: bool = True

    # Compilation
    enable_torch_compile: bool = False
    torch_compile_kwargs: dict[str, Any] = field(default_factory=dict)

    disable_autocast: bool = False

    # VSA parameters
    VSA_sparsity: float = 0.0  # inference/validation sparsity

    # V-MoBA parameters
    moba_config_path: str | None = None
    moba_config: dict[str, Any] = field(default_factory=dict)

    # Master port for distributed training/inference
    master_port: int | None = None

    # Stage verification
    enable_stage_verification: bool = True

    # Prompt text file for batch processing
    prompt_txt: str | None = None

    # LTX-2 VAE tiling overrides
    ltx2_vae_tiling: bool | None = None
    ltx2_vae_spatial_tile_size_in_pixels: int | None = None
    ltx2_vae_spatial_tile_overlap_in_pixels: int | None = None
    ltx2_vae_temporal_tile_size_in_frames: int | None = None
    ltx2_vae_temporal_tile_overlap_in_frames: int | None = None
    ltx2_initial_latent_path: str | None = None

    # model paths for correct deallocation
    model_paths: dict[str, str] = field(default_factory=dict)
    model_loaded: dict[str, bool] = field(default_factory=lambda: {
        "transformer": True,
        "vae": True,
        "upsampler": True,
    })

    override_text_encoder_safetensors: str | None = None  # path to safetensors file for text encoder override
    override_text_encoder_quant: QuantizationMethods = None
    transformer_quant: QuantizationMethods = None

    override_transformer_cls_name: str | None = None
    init_weights_from_safetensors: str = ""  # path to safetensors file for initial weight loading
    init_weights_from_safetensors_2: str = ""  # path to safetensors file for initial weight loading for transformer_2

    override_pipeline_cls_name: str | None = None

    # # DMD parameters
    # dmd_denoising_steps: List[int] | None = field(default=None)

    # MoE parameters used by Wan2.2
    boundary_ratio: float = 0.875

    @property
    def training_mode(self) -> bool:
        return not self.inference_mode

    def __post_init__(self):
        if self.moba_config_path:
            try:
                with open(self.moba_config_path) as f:
                    self.moba_config = json.load(f)
                logger.info("Loaded V-MoBA config from %s", self.moba_config_path)
            except (FileNotFoundError, json.JSONDecodeError) as e:
                logger.error("Failed to load V-MoBA config from %s: %s", self.moba_config_path, e)
                raise
        self._apply_ltx2_vae_overrides()
        self.check_fastvideo_args()

    def __getattr__(self, name: str) -> Any:
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def _apply_ltx2_vae_overrides(self) -> None:
        if self.pipeline_config is None:
            return
        vae_config = self.pipeline_config.vae_config
        has_any = any(value is not None for value in (
            self.ltx2_vae_spatial_tile_size_in_pixels,
            self.ltx2_vae_spatial_tile_overlap_in_pixels,
            self.ltx2_vae_temporal_tile_size_in_frames,
            self.ltx2_vae_temporal_tile_overlap_in_frames,
        ))
        if self.ltx2_vae_tiling is not None and hasattr(self.pipeline_config, "vae_tiling"):
            self.pipeline_config.vae_tiling = self.ltx2_vae_tiling
        elif has_any and hasattr(self.pipeline_config, "vae_tiling"):
            self.pipeline_config.vae_tiling = True

        if hasattr(vae_config,
                   "ltx2_spatial_tile_size_in_pixels") and self.ltx2_vae_spatial_tile_size_in_pixels is not None:
            vae_config.ltx2_spatial_tile_size_in_pixels = (self.ltx2_vae_spatial_tile_size_in_pixels)
        if hasattr(vae_config,
                   "ltx2_spatial_tile_overlap_in_pixels") and self.ltx2_vae_spatial_tile_overlap_in_pixels is not None:
            vae_config.ltx2_spatial_tile_overlap_in_pixels = (self.ltx2_vae_spatial_tile_overlap_in_pixels)
        if hasattr(vae_config,
                   "ltx2_temporal_tile_size_in_frames") and self.ltx2_vae_temporal_tile_size_in_frames is not None:
            vae_config.ltx2_temporal_tile_size_in_frames = (self.ltx2_vae_temporal_tile_size_in_frames)
        if hasattr(
                vae_config,
                "ltx2_temporal_tile_overlap_in_frames") and self.ltx2_vae_temporal_tile_overlap_in_frames is not None:
            vae_config.ltx2_temporal_tile_overlap_in_frames = (self.ltx2_vae_temporal_tile_overlap_in_frames)

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        # Model and path configuration
        parser.add_argument(
            "--model-path",
            type=str,
            help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
        )

        # Running mode
        parser.add_argument(
            "--mode",
            type=str,
            choices=ExecutionMode.choices(),
            default=FastVideoArgs.mode.value,
            help="The mode to run FastVideo",
        )

        # Workload type
        parser.add_argument(
            "--workload-type",
            type=str,
            choices=WorkloadType.choices(),
            default=FastVideoArgs.workload_type.value,
            help="The workload type",
        )

        # distributed_executor_backend
        parser.add_argument(
            "--distributed-executor-backend",
            type=str,
            choices=["mp"],
            default=FastVideoArgs.distributed_executor_backend,
            help="The distributed executor backend to use",
        )

        parser.add_argument(
            "--inference-mode",
            action=StoreBoolean,
            default=FastVideoArgs.inference_mode,
            help="Whether to use inference mode",
        )

        # HuggingFace specific parameters
        parser.add_argument(
            "--trust-remote-code",
            action=StoreBoolean,
            default=FastVideoArgs.trust_remote_code,
            help="Trust remote code when loading HuggingFace models",
        )
        parser.add_argument(
            "--revision",
            type=str,
            default=FastVideoArgs.revision,
            help="The specific model version to use (can be a branch name, tag name, or commit id)",
        )

        # Parallelism
        parser.add_argument(
            "--num-gpus",
            type=int,
            default=FastVideoArgs.num_gpus,
            help="The number of GPUs to use.",
        )
        parser.add_argument(
            "--tp-size",
            type=int,
            default=FastVideoArgs.tp_size,
            help="The tensor parallelism size.",
        )
        parser.add_argument(
            "--sp-size",
            type=int,
            default=FastVideoArgs.sp_size,
            help="The sequence parallelism size.",
        )
        parser.add_argument(
            "--hsdp-replicate-dim",
            type=int,
            default=FastVideoArgs.hsdp_replicate_dim,
            help="The data parallelism size.",
        )
        parser.add_argument(
            "--hsdp-shard-dim",
            type=int,
            default=FastVideoArgs.hsdp_shard_dim,
            help="The data parallelism shards.",
        )
        parser.add_argument(
            "--dist-timeout",
            type=int,
            default=FastVideoArgs.dist_timeout,
            help="Set timeout for torch.distributed initialization.",
        )

        # Output type
        parser.add_argument(
            "--output-type",
            type=str,
            default=FastVideoArgs.output_type,
            choices=["pil"],
            help="Output type for the generated video",
        )

        # Prompt text file for batch processing
        parser.add_argument(
            "--prompt-txt",
            type=str,
            default=FastVideoArgs.prompt_txt,
            help="Path to a text file containing prompts (one per line) for batch processing",
        )

        # LTX-2 VAE tiling overrides
        parser.add_argument(
            "--ltx2-vae-tiling",
            action=StoreBoolean,
            default=FastVideoArgs.ltx2_vae_tiling,
            help="Enable LTX-2 VAE tiling overrides.",
        )
        parser.add_argument(
            "--ltx2-vae-spatial-tile-size-in-pixels",
            type=int,
            default=FastVideoArgs.ltx2_vae_spatial_tile_size_in_pixels,
            help="LTX-2 VAE spatial tile size in pixels.",
        )
        parser.add_argument(
            "--ltx2-vae-spatial-tile-overlap-in-pixels",
            type=int,
            default=FastVideoArgs.ltx2_vae_spatial_tile_overlap_in_pixels,
            help="LTX-2 VAE spatial tile overlap in pixels.",
        )
        parser.add_argument(
            "--ltx2-vae-temporal-tile-size-in-frames",
            type=int,
            default=FastVideoArgs.ltx2_vae_temporal_tile_size_in_frames,
            help="LTX-2 VAE temporal tile size in frames.",
        )
        parser.add_argument(
            "--ltx2-vae-temporal-tile-overlap-in-frames",
            type=int,
            default=FastVideoArgs.ltx2_vae_temporal_tile_overlap_in_frames,
            help="LTX-2 VAE temporal tile overlap in frames.",
        )
        parser.add_argument(
            "--ltx2-initial-latent-path",
            type=str,
            default=FastVideoArgs.ltx2_initial_latent_path,
            help="Path to load/save a precomputed LTX-2 initial latent.",
        )

        # LoRA parameters (inference-time adapter loading)
        parser.add_argument(
            "--lora-path",
            type=str,
            default=FastVideoArgs.lora_path,
            help="Path to a LoRA adapter (directory or HF repo id). If set, LoRA will be applied at inference.",
        )
        parser.add_argument(
            "--lora-nickname",
            type=str,
            default=FastVideoArgs.lora_nickname,
            help="Nickname to refer to the loaded LoRA adapter (useful for swapping).",
        )
        parser.add_argument(
            "--lora-target-modules",
            nargs="+",
            type=str,
            default=FastVideoArgs.lora_target_modules,
            help="Optional list of module name substrings to restrict LoRA injection (e.g. q_proj k_proj v_proj).",
        )

        # BSA runtime control (LongCat)
        parser.add_argument(
            "--enable-bsa",
            action=StoreBoolean,
            help="Enable Block Sparse Attention (BSA) at runtime (overrides config).",
        )
        parser.add_argument(
            "--bsa-sparsity",
            type=float,
            help="BSA sparsity (e.g., 0.9375).",
        )
        parser.add_argument(
            "--bsa-cdf-threshold",
            type=float,
            help="BSA CDF threshold (optional).",
        )
        parser.add_argument(
            "--bsa-chunk-q",
            nargs=3,
            type=int,
            metavar=("T", "H", "W"),
            help="BSA chunk_3d_shape_q as three ints, e.g., 4 4 4.",
        )
        parser.add_argument(
            "--bsa-chunk-k",
            nargs=3,
            type=int,
            metavar=("T", "H", "W"),
            help="BSA chunk_3d_shape_k as three ints, e.g., 4 4 4.",
        )

        parser.add_argument(
            "--enable-torch-compile",
            action=StoreBoolean,
            default=FastVideoArgs.enable_torch_compile,
            help="Use torch.compile to speed up DiT inference." +
            "However, will likely cause precision drifts. See (https://github.com/pytorch/pytorch/issues/145213)",
        )
        parser.add_argument(
            "--torch-compile-kwargs",
            type=str,
            default=None,
            help=
            "JSON string of kwargs to pass to torch.compile. Example: '{\"backend\":\"inductor\",\"mode\":\"reduce-overhead\"}'",
        )

        parser.add_argument(
            "--dit-cpu-offload",
            action=StoreBoolean,
            help="Use CPU offload for DiT inference. Enable if run out of memory with FSDP.",
        )
        parser.add_argument(
            "--dit-layerwise-offload",
            action=StoreBoolean,
            help="Enable layerwise CPU offload with async H2D prefetch overlap.",
        )
        parser.add_argument(
            "--use-fsdp-inference",
            action=StoreBoolean,
            help=
            "Use FSDP for inference by sharding the model weights. Latency is very low due to prefetch--enable if run out of memory.",
        )
        parser.add_argument(
            "--text-encoder-cpu-offload",
            action=StoreBoolean,
            help="Use CPU offload for text encoder. Enable if run out of memory.",
        )
        parser.add_argument(
            "--image-encoder-cpu-offload",
            action=StoreBoolean,
            help="Use CPU offload for image encoder. Enable if run out of memory.",
        )
        parser.add_argument(
            "--vae-cpu-offload",
            action=StoreBoolean,
            help="Use CPU offload for VAE. Enable if run out of memory.",
        )
        parser.add_argument(
            "--pin-cpu-memory",
            action=StoreBoolean,
            help=
            "Pin memory for CPU offload. Only added as a temp workaround if it throws \"CUDA error: invalid argument\". "
            "Should be enabled in almost all cases",
        )
        parser.add_argument(
            "--disable-autocast",
            action=StoreBoolean,
            help="Disable autocast for denoising loop and vae decoding in pipeline sampling",
        )

        # VSA parameters
        parser.add_argument(
            "--VSA-sparsity",
            type=float,
            default=FastVideoArgs.VSA_sparsity,
            help="Validation sparsity for VSA",
        )

        # Master port for distributed training/inference
        parser.add_argument(
            "--master-port",
            type=int,
            default=FastVideoArgs.master_port,
            help="Master port for distributed training/inference",
        )

        # Stage verification
        parser.add_argument(
            "--enable-stage-verification",
            action=StoreBoolean,
            default=FastVideoArgs.enable_stage_verification,
            help="Enable input/output verification for pipeline stages",
        )
        parser.add_argument(
            "--override-text-encoder-safetensors",
            type=str,
            default=FastVideoArgs.override_text_encoder_safetensors,
            help="Path to safetensors file for text encoder override",
        )
        parser.add_argument(
            "--override-text-encoder-quant",
            type=str,
            choices=QUANTIZATION_METHODS,
            default=FastVideoArgs.override_text_encoder_quant,
            help="Quantization method for text encoder override",
        )
        parser.add_argument(
            "--transformer-quant",
            type=str,
            choices=QUANTIZATION_METHODS,
            default=FastVideoArgs.transformer_quant,
            help="Quantization method for transformer loading",
        )
        parser.add_argument(
            "--override-transformer-cls-name",
            type=str,
            default=FastVideoArgs.override_transformer_cls_name,
            help="Override transformer cls name",
        )
        parser.add_argument(
            "--override-pipeline-cls-name",
            type=str,
            default=FastVideoArgs.override_pipeline_cls_name,
            help="Override pipeline cls name",
        )
        parser.add_argument("--init-weights-from-safetensors",
                            type=str,
                            help="Path to safetensors file for initial weight loading")
        parser.add_argument("--init-weights-from-safetensors-2",
                            type=str,
                            help="Path to safetensors file for initial weight loading")

        # Add pipeline configuration arguments
        PipelineConfig.add_cli_args(parser)

        # Add preprocessing configuration arguments
        PreprocessConfig.add_cli_args(parser)

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs":
        provided_args = clean_cli_args(args)
        # Get all fields from the dataclass
        attrs = [attr.name for attr in dataclasses.fields(cls)]

        # Create a dictionary of attribute values, with defaults for missing attributes
        kwargs: dict[str, Any] = {}
        for attr in attrs:
            if attr == 'pipeline_config':
                pipeline_config = PipelineConfig.from_kwargs(provided_args)
                kwargs['pipeline_config'] = pipeline_config
            elif attr == 'preprocess_config':
                preprocess_config = PreprocessConfig.from_kwargs(provided_args)
                kwargs['preprocess_config'] = preprocess_config
            elif attr == 'mode':
                # Convert string to ExecutionMode enum
                mode_value = getattr(args, attr, FastVideoArgs.mode.value)
                kwargs['mode'] = ExecutionMode.from_string(mode_value) if isinstance(mode_value, str) else mode_value
            elif attr == 'torch_compile_kwargs':
                # Parse JSON string for torch.compile kwargs
                torch_compile_kwargs_str = getattr(args, 'torch_compile_kwargs', None)
                if torch_compile_kwargs_str:
                    try:
                        import json
                        kwargs['torch_compile_kwargs'] = json.loads(torch_compile_kwargs_str)
                    except json.JSONDecodeError as e:
                        raise ValueError(f"Invalid JSON for torch_compile_kwargs: {e}") from e
                else:
                    kwargs['torch_compile_kwargs'] = {}
            elif attr == 'workload_type':
                # Convert string to WorkloadType enum
                workload_type_value = getattr(args, 'workload_type', FastVideoArgs.workload_type.value)
                kwargs['workload_type'] = WorkloadType.from_string(workload_type_value) if isinstance(
                    workload_type_value, str) else workload_type_value
            # Use getattr with default value from the dataclass for potentially missing attributes
            else:
                # Get the field to check if it has a default_factory
                field = dataclasses.fields(cls)[next(i for i, f in enumerate(dataclasses.fields(cls))
                                                     if f.name == attr)]
                if field.default_factory is not dataclasses.MISSING:
                    # Use the default_factory to create the default value
                    default_value = field.default_factory()
                else:
                    default_value = getattr(cls, attr, None)
                value = getattr(args, attr, default_value)
                kwargs[attr] = value  # type: ignore

        return cls(**kwargs)  # type: ignore

    @classmethod
    def from_kwargs(cls, **kwargs: Any) -> "FastVideoArgs":
        # Convert mode string to enum if necessary
        if 'mode' in kwargs and isinstance(kwargs['mode'], str):
            kwargs['mode'] = ExecutionMode.from_string(kwargs['mode'])

        # Convert workload_type string to enum if necessary
        if 'workload_type' in kwargs and isinstance(kwargs['workload_type'], str):
            kwargs['workload_type'] = WorkloadType.from_string(kwargs['workload_type'])

        kwargs['pipeline_config'] = PipelineConfig.from_kwargs(kwargs)
        kwargs['preprocess_config'] = PreprocessConfig.from_kwargs(kwargs)
        # Filter to only FastVideoArgs dataclass fields — pipeline-specific CLI
        # args (e.g. enable_bsa, bsa_sparsity) live in PipelineConfig and must
        # not be forwarded to the FastVideoArgs constructor.
        valid_fields = {f.name for f in dataclasses.fields(cls)}
        return cls(**{k: v for k, v in kwargs.items() if k in valid_fields})

    def check_fastvideo_args(self) -> None:
        """Validate inference arguments for consistency"""
        from fastvideo.platforms import current_platform

        if current_platform.is_mps():
            self.use_fsdp_inference = False
            self.dit_layerwise_offload = False

        if self.dit_layerwise_offload:
            if self.use_fsdp_inference:
                logger.warning("dit_layerwise_offload is enabled, automatically disabling use_fsdp_inference.")
                self.use_fsdp_inference = False
            if self.dit_cpu_offload:
                logger.warning("dit_layerwise_offload is enabled, automatically disabling dit_cpu_offload.")
                self.dit_cpu_offload = False

        # Validate mode and inference_mode consistency
        assert isinstance(self.mode, ExecutionMode), f"Mode must be an ExecutionMode enum, got {type(self.mode)}"
        assert self.mode in ExecutionMode.choices(), f"Invalid execution mode: {self.mode}"

        # Validate workload type
        assert isinstance(self.workload_type,
                          WorkloadType), f"Workload type must be a WorkloadType enum, got {type(self.workload_type)}"
        assert self.workload_type in WorkloadType.choices(), f"Invalid workload type: {self.workload_type}"

        if self.mode in [ExecutionMode.DISTILLATION, ExecutionMode.FINETUNING] and self.inference_mode:
            logger.warning("Mode is 'training' but inference_mode is True. Setting inference_mode to False.")
            self.inference_mode = False
        elif self.mode in [ExecutionMode.INFERENCE, ExecutionMode.PREPROCESS] and not self.inference_mode:
            logger.warning("Mode is '%s' but inference_mode is False. Setting inference_mode to True.", self.mode)
            self.inference_mode = True

        if not self.inference_mode:
            assert self.hsdp_replicate_dim != -1, "hsdp_replicate_dim must be set for training"
            assert self.hsdp_shard_dim != -1, "hsdp_shard_dim must be set for training"
            assert self.sp_size != -1, "sp_size must be set for training"

        if self.tp_size == -1:
            self.tp_size = 1
        if self.sp_size == -1:
            self.sp_size = self.num_gpus
        if self.hsdp_shard_dim == -1:
            self.hsdp_shard_dim = self.num_gpus

        assert self.sp_size <= self.num_gpus and self.num_gpus % self.sp_size == 0, "num_gpus must >= and be divisible by sp_size"
        assert self.hsdp_replicate_dim <= self.num_gpus and self.num_gpus % self.hsdp_replicate_dim == 0, "num_gpus must >= and be divisible by hsdp_replicate_dim"
        assert self.hsdp_shard_dim <= self.num_gpus and self.num_gpus % self.hsdp_shard_dim == 0, "num_gpus must >= and be divisible by hsdp_shard_dim"

        if self.num_gpus < max(self.tp_size, self.sp_size):
            self.num_gpus = max(self.tp_size, self.sp_size)

        if self.pipeline_config is None:
            raise ValueError("pipeline_config is not set in FastVideoArgs")

        self.pipeline_config.check_pipeline_config()

        # Add preprocessing config validation if needed
        if self.mode == ExecutionMode.PREPROCESS:
            if self.preprocess_config is None:
                raise ValueError("preprocess_config is not set in FastVideoArgs when mode is PREPROCESS")
            if self.preprocess_config.model_path == "":
                self.preprocess_config.model_path = self.model_path
            if not self.pipeline_config.vae_config.load_encoder:
                self.pipeline_config.vae_config.load_encoder = True
            self.preprocess_config.check_preprocess_config()


_current_fastvideo_args = None


def prepare_fastvideo_args(argv: list[str]) -> FastVideoArgs:
    """
    Prepare the inference arguments from the command line arguments.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The inference arguments.
    """
    parser = FlexibleArgumentParser()
    FastVideoArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    fastvideo_args = FastVideoArgs.from_cli_args(raw_args)
    global _current_fastvideo_args
    _current_fastvideo_args = fastvideo_args
    return fastvideo_args


@contextmanager
def set_current_fastvideo_args(fastvideo_args: FastVideoArgs):
    """
    Temporarily set the current fastvideo config.
    Used during model initialization.
    We save the current fastvideo config in a global variable,
    so that all modules can access it, e.g. custom ops
    can access the fastvideo config to determine how to dispatch.
    """
    global _current_fastvideo_args
    old_fastvideo_args = _current_fastvideo_args
    try:
        _current_fastvideo_args = fastvideo_args
        yield
    finally:
        _current_fastvideo_args = old_fastvideo_args


def get_current_fastvideo_args() -> FastVideoArgs:
    if _current_fastvideo_args is None:
        # in ci, usually when we test custom ops/modules directly,
        # we don't set the fastvideo config. In that case, we set a default
        # config.
        # TODO(will): may need to handle this for CI.
        raise ValueError("Current fastvideo args is not set.")
    return _current_fastvideo_args
```
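`prepare_fastvideo_args`, `set_current_fastvideo_args`, and `get_current_fastvideo_args` give model code a single process-wide view of the active configuration. A minimal usage sketch, assuming the package above is importable (the model id is a placeholder, not from the snapshot):

```python
from fastvideo.fastvideo_args import (FastVideoArgs, get_current_fastvideo_args,
                                      set_current_fastvideo_args)

# Build args programmatically; strings are coerced to ExecutionMode/WorkloadType
# and the pipeline/preprocess sub-configs are split out by from_kwargs().
args = FastVideoArgs.from_kwargs(
    model_path="org/some-wan-checkpoint",  # placeholder model id
    num_gpus=1,
    VSA_sparsity=0.9,
    mode="inference",
    workload_type="t2v",
)

# Custom ops read the active config from the module-level global instead of
# having it threaded through every call; the context manager restores the
# previous value on exit.
with set_current_fastvideo_args(args):
    assert get_current_fastvideo_args().VSA_sparsity == 0.9
```

The snapshot continues with `TrainingArgs`: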
| 754 |
+
@dataclasses.dataclass
|
| 755 |
+
class TrainingArgs(FastVideoArgs):
|
| 756 |
+
"""
|
| 757 |
+
Training arguments. Inherits from FastVideoArgs and adds training-specific
|
| 758 |
+
arguments. If there are any conflicts, the training arguments will take
|
| 759 |
+
precedence.
|
| 760 |
+
"""
|
| 761 |
+
data_path: str = ""
|
| 762 |
+
dataloader_num_workers: int = 0
|
| 763 |
+
num_height: int = 0
|
| 764 |
+
num_width: int = 0
|
| 765 |
+
num_frames: int = 0
|
| 766 |
+
|
| 767 |
+
train_batch_size: int = 0
|
| 768 |
+
num_latent_t: int = 0
|
| 769 |
+
group_frame: bool = False
|
| 770 |
+
group_resolution: bool = False
|
| 771 |
+
|
| 772 |
+
# text encoder & vae & diffusion model
|
| 773 |
+
pretrained_model_name_or_path: str = ""
|
| 774 |
+
|
| 775 |
+
# DMD model paths - separate paths for each network
|
| 776 |
+
real_score_model_path: str = "" # path for real score (teacher) model
|
| 777 |
+
fake_score_model_path: str = "" # path for fake score (critic) model
|
| 778 |
+
|
| 779 |
+
# diffusion setting
|
| 780 |
+
ema_decay: float = 0.0
|
| 781 |
+
ema_start_step: int = 0
|
| 782 |
+
training_cfg_rate: float = 0.0
|
| 783 |
+
precondition_outputs: bool = False
|
| 784 |
+
|
| 785 |
+
# validation & logs
|
| 786 |
+
validation_dataset_file: str = ""
|
| 787 |
+
validation_preprocessed_path: str = ""
|
| 788 |
+
validation_sampling_steps: str = ""
|
| 789 |
+
validation_guidance_scale: str = ""
|
| 790 |
+
validation_steps: float = 0.0
|
| 791 |
+
log_validation: bool = False
|
| 792 |
+
trackers: list[str] = dataclasses.field(default_factory=list)
|
| 793 |
+
tracker_project_name: str = ""
|
| 794 |
+
wandb_run_name: str = ""
|
| 795 |
+
seed: int = 0
|
| 796 |
+
_loading_teacher_critic_model: bool = False
|
| 797 |
+
|
| 798 |
+
# output
|
| 799 |
+
output_dir: str = ""
|
| 800 |
+
checkpoints_total_limit: int = 0
|
| 801 |
+
resume_from_checkpoint: str = "" # specify the checkpoint folder to resume from
|
| 802 |
+
|
| 803 |
+
# optimizer & scheduler
|
| 804 |
+
num_train_epochs: int = 0
|
| 805 |
+
max_train_steps: int = 0
|
| 806 |
+
gradient_accumulation_steps: int = 0
|
| 807 |
+
learning_rate: float = 0.0
|
| 808 |
+
scale_lr: bool = False
|
| 809 |
+
lr_scheduler: str = "constant"
|
| 810 |
+
lr_warmup_steps: int = 0
|
| 811 |
+
max_grad_norm: float = 0.0
|
| 812 |
+
enable_gradient_checkpointing_type: str | None = None
|
| 813 |
+
selective_checkpointing: float = 0.0
|
| 814 |
+
mixed_precision: str = ""
|
| 815 |
+
train_sp_batch_size: int = 0
|
| 816 |
+
fsdp_sharding_startegy: str = ""
|
| 817 |
+
|
| 818 |
+
weighting_scheme: str = ""
|
| 819 |
+
logit_mean: float = 0.0
|
| 820 |
+
logit_std: float = 1.0
|
| 821 |
+
mode_scale: float = 0.0
|
| 822 |
+
|
| 823 |
+
num_euler_timesteps: int = 0
|
| 824 |
+
lr_num_cycles: int = 0
|
| 825 |
+
lr_power: float = 0.0
|
| 826 |
+
min_lr_ratio: float = 0.5 # minimum learning rate ratio for cosine_with_min_lr scheduler
|
| 827 |
+
not_apply_cfg_solver: bool = False
|
| 828 |
+
distill_cfg: float = 0.0
|
| 829 |
+
scheduler_type: str = ""
|
| 830 |
+
linear_quadratic_threshold: float = 0.0
|
| 831 |
+
linear_range: float = 0.0
|
| 832 |
+
weight_decay: float = 0.0
|
| 833 |
+
betas: str = "0.9,0.999" # betas for optimizer, format: "beta1,beta2"
|
| 834 |
+
use_ema: bool = False
|
| 835 |
+
multi_phased_distill_schedule: str = ""
|
| 836 |
+
pred_decay_weight: float = 0.0
|
| 837 |
+
pred_decay_type: str = ""
|
| 838 |
+
hunyuan_teacher_disable_cfg: bool = False
|
| 839 |
+
|
| 840 |
+
# master_weight_type
|
| 841 |
+
master_weight_type: str = ""
|
| 842 |
+
|
| 843 |
+
# VSA training decay parameters
|
| 844 |
+
VSA_decay_rate: float = 0.01 # decay rate -> 0.02
|
| 845 |
+
VSA_decay_interval_steps: int = 1 # decay interval steps -> 50
|
| 846 |
+
VSA_init_sparsity: float = 0.0 # initial sparsity (default 0, ramp from 0)
|
| 847 |
+
VSA_warmup_steps: int = 0 # keep init_sparsity for this many steps before ramping
|
| 848 |
+
|
| 849 |
+
# LoRA training parameters
|
| 850 |
+
lora_rank: int | None = None
|
| 851 |
+
lora_alpha: int | None = None
|
| 852 |
+
lora_training: bool = False
|
| 853 |
+
ltx2_first_frame_conditioning_p: float = 0.1
|
| 854 |
+
|
| 855 |
+
# distillation args
|
| 856 |
+
generator_update_interval: int = 5
|
| 857 |
+
dfake_gen_update_ratio: int = 5 # self-forcing: how often to train generator vs critic
|
| 858 |
+
min_timestep_ratio: float = 0.2
|
| 859 |
+
max_timestep_ratio: float = 0.98
|
| 860 |
+
real_score_guidance_scale: float = 3.5
|
| 861 |
+
fake_score_learning_rate: float = 0.0 # separate learning rate for fake_score_transformer, if 0.0, use learning_rate
|
| 862 |
+
fake_score_lr_scheduler: str = "constant" # separate lr scheduler for fake_score_transformer, if not set, use lr_scheduler
|
| 863 |
+
fake_score_betas: str = "0.9,0.999" # betas for fake score optimizer, format: "beta1,beta2"
|
| 864 |
+
training_state_checkpointing_steps: int = 0 # for resuming training
|
| 865 |
+
weight_only_checkpointing_steps: int = 0 # for inference
|
| 866 |
+
log_visualization: bool = False
|
| 867 |
+
visualization_steps: int = 0
|
| 868 |
+
# simulate generator forward to match inference
|
| 869 |
+
simulate_generator_forward: bool = False
|
| 870 |
+
warp_denoising_step: bool = False
|
| 871 |
+
generator_4bit_attn: bool = False
|
| 872 |
+
generator_4bit_linear: bool = False
|
| 873 |
+
|
| 874 |
+
# Self-forcing specific arguments
|
| 875 |
+
num_frame_per_block: int = 3
|
| 876 |
+
independent_first_frame: bool = False
|
| 877 |
+
enable_gradient_masking: bool = True
|
| 878 |
+
gradient_mask_last_n_frames: int = 21
|
| 879 |
+
same_step_across_blocks: bool = False # Use same exit timestep for all blocks
|
| 880 |
+
last_step_only: bool = False # Only use the last timestep for training
|
| 881 |
+
context_noise: int = 0 # Context noise level for cache updates
|
| 882 |
+
|
| 883 |
+
@classmethod
|
| 884 |
+
def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs":
|
| 885 |
+
provided_args = clean_cli_args(args)
|
| 886 |
+
# Get all fields from the dataclass
|
| 887 |
+
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
| 888 |
+
logger.info(provided_args)
|
| 889 |
+
# Create a dictionary of attribute values, with defaults for missing attributes
|
| 890 |
+
kwargs: dict[str, Any] = {}
|
| 891 |
+
for attr in attrs:
|
| 892 |
+
if attr == 'pipeline_config':
|
| 893 |
+
pipeline_config = PipelineConfig.from_kwargs(provided_args)
|
| 894 |
+
kwargs[attr] = pipeline_config
|
| 895 |
+
elif attr == 'mode':
|
| 896 |
+
# Convert string to ExecutionMode enum
|
| 897 |
+
mode_value = getattr(args, attr, ExecutionMode.FINETUNING.value)
|
| 898 |
+
kwargs[attr] = ExecutionMode.from_string(mode_value) if isinstance(mode_value, str) else mode_value
|
| 899 |
+
elif attr == 'workload_type':
|
| 900 |
+
# Convert string to WorkloadType enum
|
| 901 |
+
workload_type_value = getattr(args, 'workload_type', WorkloadType.T2V.value)
|
| 902 |
+
kwargs[attr] = WorkloadType.from_string(workload_type_value) if isinstance(workload_type_value,
|
| 903 |
+
str) else workload_type_value
|
| 904 |
+
# Use getattr with default value from the dataclass for potentially missing attributes
|
| 905 |
+
else:
|
| 906 |
+
# Get the field to check its default value
|
| 907 |
+
field = dataclasses.fields(cls)[next(i for i, f in enumerate(dataclasses.fields(cls))
|
| 908 |
+
if f.name == attr)]
|
| 909 |
+
|
| 910 |
+
# Check if the attribute is provided in args
|
| 911 |
+
if hasattr(args, attr):
|
| 912 |
+
value = getattr(args, attr)
|
| 913 |
+
else:
|
| 914 |
+
# Use the field's default value
|
| 915 |
+
if field.default_factory is not dataclasses.MISSING:
|
| 916 |
+
value = field.default_factory()
|
| 917 |
+
elif field.default is not dataclasses.MISSING:
|
| 918 |
+
value = field.default
|
| 919 |
+
else:
|
| 920 |
+
# No default value, use None
|
| 921 |
+
value = None
|
| 922 |
+
|
| 923 |
+
kwargs[attr] = value
|
| 924 |
+
|
| 925 |
+
return cls(**kwargs) # type: ignore
|
| 926 |
+
|
| 927 |
+
@staticmethod
|
| 928 |
+
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
| 929 |
+
parser.add_argument("--data-path", type=str, required=True, help="Path to parquet files")
|
| 930 |
+
parser.add_argument("--dataloader-num-workers",
|
| 931 |
+
type=int,
|
| 932 |
+
required=True,
|
| 933 |
+
help="Number of workers for dataloader")
|
| 934 |
+
parser.add_argument("--num-height", type=int, required=True, help="Number of heights")
|
| 935 |
+
parser.add_argument("--num-width", type=int, required=True, help="Number of widths")
|
| 936 |
+
parser.add_argument("--num-frames", type=int, required=True, help="Number of frames")
|
| 937 |
+
|
| 938 |
+
# Training batch and model configuration
|
| 939 |
+
parser.add_argument("--train-batch-size", type=int, required=True, help="Training batch size")
|
| 940 |
+
parser.add_argument("--num-latent-t", type=int, required=True, help="Number of latent time steps")
|
| 941 |
+
parser.add_argument("--group-frame", action=StoreBoolean, help="Whether to group frames during training")
|
| 942 |
+
parser.add_argument("--group-resolution",
|
| 943 |
+
action=StoreBoolean,
|
| 944 |
+
help="Whether to group resolutions during training")
|
| 945 |
+
|
| 946 |
+
# Model paths
|
| 947 |
+
parser.add_argument("--pretrained-model-name-or-path",
|
| 948 |
+
type=str,
|
| 949 |
+
required=True,
|
| 950 |
+
help="Path to pretrained model or model name")
|
| 951 |
+
parser.add_argument("--dit-model-name-or-path",
|
| 952 |
+
type=str,
|
| 953 |
+
required=False,
|
| 954 |
+
help="Path to DiT model or model name")
|
| 955 |
+
parser.add_argument("--cache-dir", type=str, help="Directory to cache models")
|
| 956 |
+
|
| 957 |
+
# DMD model paths - separate paths for each network
|
| 958 |
+
parser.add_argument("--generator-model-path",
|
| 959 |
+
type=str,
|
| 960 |
+
help="Path to generator (student) model for DMD distillation")
|
| 961 |
+
parser.add_argument("--real-score-model-path",
|
| 962 |
+
type=str,
|
| 963 |
+
help="Path to real score (teacher) model for DMD distillation")
|
| 964 |
+
parser.add_argument("--fake-score-model-path",
|
| 965 |
+
type=str,
|
| 966 |
+
help="Path to fake score (critic) model for DMD distillation")
|
| 967 |
+
|
| 968 |
+
# Diffusion settings
|
| 969 |
+
parser.add_argument("--ema-decay", type=float, default=0.999, help="EMA decay rate")
|
| 970 |
+
parser.add_argument("--ema-start-step", type=int, default=0, help="Step to start EMA")
|
| 971 |
+
parser.add_argument("--training-cfg-rate", type=float, help="Classifier-free guidance scale")
|
| 972 |
+
parser.add_argument("--precondition-outputs",
|
| 973 |
+
action=StoreBoolean,
|
| 974 |
+
help="Whether to precondition the outputs of the model")
|
| 975 |
+
|
| 976 |
+
# Validation and logging
|
| 977 |
+
parser.add_argument("--validation-dataset-file", type=str, help="Path to unprocessed validation dataset")
|
| 978 |
+
parser.add_argument("--validation-preprocessed-path", type=str, help="Path to processed validation dataset")
|
| 979 |
+
parser.add_argument("--validation-sampling-steps", type=str, help="Validation sampling steps")
|
| 980 |
+
parser.add_argument("--validation-guidance-scale", type=str, help="Validation guidance scale")
|
| 981 |
+
parser.add_argument("--validation-steps", type=float, help="Number of validation steps")
|
| 982 |
+
parser.add_argument("--log-validation", action=StoreBoolean, help="Whether to log validation results")
|
| 983 |
+
parser.add_argument("--visualization-steps", type=int, help="Number of visualization steps")
|
| 984 |
+
parser.add_argument("--tracker-project-name", type=str, help="Project name for tracking")
|
| 985 |
+
parser.add_argument("--wandb-run-name", type=str, help="Run name for wandb")
|
| 986 |
+
parser.add_argument("--seed", type=int, default=42, help="Seed for deterministic training")
|
| 987 |
+
|
| 988 |
+
# Output configuration
|
| 989 |
+
parser.add_argument("--output-dir", type=str, required=True, help="Output directory for checkpoints and logs")
|
| 990 |
+
parser.add_argument("--checkpoints-total-limit", type=int, help="Maximum number of checkpoints to keep")
|
| 991 |
+
parser.add_argument("--training-state-checkpointing-steps",
|
| 992 |
+
type=int,
|
| 993 |
+
help="Steps between training state checkpoints (for resuming training)")
|
| 994 |
+
parser.add_argument("--weight-only-checkpointing-steps",
|
| 995 |
+
type=int,
|
| 996 |
+
help="Steps between weight-only checkpoints (for inference)")
|
| 997 |
+
parser.add_argument("--resume-from-checkpoint", type=str, help="Path to checkpoint to resume from")
|
| 998 |
+
parser.add_argument("--logging-dir", type=str, help="Directory for logging")
|
| 999 |
+
|
| 1000 |
+
# Training configuration
|
| 1001 |
+
parser.add_argument("--num-train-epochs", type=int, help="Number of training epochs")
|
| 1002 |
+
parser.add_argument("--max-train-steps", type=int, help="Maximum number of training steps")
|
| 1003 |
+
parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of steps to accumulate gradients")
|
| 1004 |
+
parser.add_argument("--learning-rate", type=float, required=True, help="Learning rate")
|
| 1005 |
+
parser.add_argument("--scale-lr", action=StoreBoolean, help="Whether to scale learning rate")
|
| 1006 |
+
parser.add_argument("--lr-scheduler", type=str, default="constant", help="Learning rate scheduler type")
|
| 1007 |
+
parser.add_argument("--lr-warmup-steps", type=int, default=10, help="Number of warmup steps for learning rate")
|
| 1008 |
+
parser.add_argument("--max-grad-norm", type=float, help="Maximum gradient norm")
|
| 1009 |
+
parser.add_argument("--enable-gradient-checkpointing-type",
|
| 1010 |
+
type=str,
|
| 1011 |
+
choices=["full", "ops", "block_skip"],
|
| 1012 |
+
default=None,
|
| 1013 |
+
help="Gradient checkpointing type")
|
| 1014 |
+
parser.add_argument("--selective-checkpointing", type=float, help="Selective checkpointing threshold")
|
| 1015 |
+
parser.add_argument("--mixed-precision", type=str, help="Mixed precision training type")
|
| 1016 |
+
parser.add_argument("--train-sp-batch-size", type=int, help="Training spatial parallelism batch size")
|
| 1017 |
+
|
| 1018 |
+
parser.add_argument("--fsdp-sharding-strategy", type=str, help="FSDP sharding strategy")
|
| 1019 |
+
|
| 1020 |
+
parser.add_argument(
|
| 1021 |
+
"--weighting_scheme",
|
| 1022 |
+
type=str,
|
| 1023 |
+
default="uniform",
|
| 1024 |
+
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
|
| 1025 |
+
)
|
| 1026 |
+
parser.add_argument(
|
| 1027 |
+
"--logit_mean",
|
| 1028 |
+
type=float,
|
| 1029 |
+
default=0.0,
|
| 1030 |
+
help="mean to use when using the `'logit_normal'` weighting scheme.",
|
| 1031 |
+
)
|
| 1032 |
+
parser.add_argument(
|
| 1033 |
+
"--logit_std",
|
| 1034 |
+
type=float,
|
| 1035 |
+
default=1.0,
|
| 1036 |
+
help="std to use when using the `'logit_normal'` weighting scheme.",
|
| 1037 |
+
)
|
| 1038 |
+
parser.add_argument(
|
| 1039 |
+
"--mode_scale",
|
| 1040 |
+
type=float,
|
| 1041 |
+
default=1.29,
|
| 1042 |
+
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
|
| 1043 |
+
)
|
| 1044 |
+
|
| 1045 |
+
# Additional training parameters
|
| 1046 |
+
parser.add_argument("--num-euler-timesteps", type=int, help="Number of Euler timesteps")
|
| 1047 |
+
parser.add_argument("--lr-num-cycles", type=int, help="Number of learning rate cycles")
|
| 1048 |
+
parser.add_argument("--lr-power", type=float, help="Learning rate power")
|
| 1049 |
+
parser.add_argument("--min-lr-ratio",
|
| 1050 |
+
type=float,
|
| 1051 |
+
default=TrainingArgs.min_lr_ratio,
|
| 1052 |
+
help="Minimum learning rate ratio for cosine_with_min_lr scheduler")
|
| 1053 |
+
parser.add_argument("--not-apply-cfg-solver", action=StoreBoolean, help="Whether to not apply CFG solver")
|
| 1054 |
+
parser.add_argument("--distill-cfg", type=float, help="Distillation CFG scale")
|
| 1055 |
+
parser.add_argument("--scheduler-type", type=str, help="Scheduler type")
|
| 1056 |
+
parser.add_argument("--linear-quadratic-threshold", type=float, help="Linear quadratic threshold")
|
| 1057 |
+
parser.add_argument("--linear-range", type=float, help="Linear range")
|
| 1058 |
+
parser.add_argument("--weight-decay", type=float, help="Weight decay")
|
| 1059 |
+
parser.add_argument("--betas",
|
| 1060 |
+
type=str,
|
| 1061 |
+
default=TrainingArgs.betas,
|
| 1062 |
+
help="Betas for optimizer (format: 'beta1,beta2')")
|
| 1063 |
+
parser.add_argument("--use-ema", action=StoreBoolean, help="Whether to use EMA")
|
| 1064 |
+
parser.add_argument("--multi-phased-distill-schedule", type=str, help="Multi-phased distillation schedule")
|
| 1065 |
+
parser.add_argument("--pred-decay-weight", type=float, help="Prediction decay weight")
|
| 1066 |
+
parser.add_argument("--pred-decay-type", type=str, help="Prediction decay type")
|
| 1067 |
+
parser.add_argument("--hunyuan-teacher-disable-cfg",
|
| 1068 |
+
action=StoreBoolean,
|
| 1069 |
+
help="Whether to disable CFG for Hunyuan teacher")
|
| 1070 |
+
parser.add_argument("--master-weight-type", type=str, help="Master weight type")
|
| 1071 |
+
|
| 1072 |
+
# VSA parameters for training with dense to sparse adaption
|
| 1073 |
+
parser.add_argument(
|
| 1074 |
+
"--VSA-decay-rate", # decay rate, how much sparsity you want to decay each step
|
| 1075 |
+
type=float,
|
| 1076 |
+
        default=TrainingArgs.VSA_decay_rate,
        help="VSA decay rate")
    parser.add_argument(
        "--VSA-decay-interval-steps",  # how many steps to train at the current sparsity
        type=int,
        default=TrainingArgs.VSA_decay_interval_steps,
        help="VSA decay interval steps")
    parser.add_argument(
        "--VSA-init-sparsity",
        type=float,
        default=TrainingArgs.VSA_init_sparsity,
        help="Initial sparsity to start from (default 0)")
    parser.add_argument(
        "--VSA-warmup-steps",
        type=int,
        default=TrainingArgs.VSA_warmup_steps,
        help="Keep the initial sparsity for N steps before ramping (default 0)")
    parser.add_argument("--lora-training", action=StoreBoolean, help="Whether to use LoRA training")
    parser.add_argument("--lora-rank", type=int, help="LoRA rank")
    parser.add_argument("--lora-alpha", type=int, help="LoRA alpha")
    parser.add_argument(
        "--ltx2-first-frame-conditioning-p",
        type=float,
        default=TrainingArgs.ltx2_first_frame_conditioning_p,
        help="Probability of conditioning on the first frame during LTX-2 training",
    )

    # V-MoBA parameters
    parser.add_argument(
        "--moba-config-path",
        type=str,
        default=None,
        help="Path to a JSON file containing V-MoBA specific configurations.",
    )

    # Distillation arguments
    parser.add_argument("--generator-update-interval",
                        type=int,
                        default=TrainingArgs.generator_update_interval,
                        help="Ratio of student updates to critic updates.")
    parser.add_argument(
        "--dfake-gen-update-ratio",
        type=int,
        default=TrainingArgs.dfake_gen_update_ratio,
        help="Self-forcing: how often to train the generator vs. the critic (train the generator every N steps).")
    parser.add_argument("--min-timestep-ratio",
                        type=float,
                        default=TrainingArgs.min_timestep_ratio,
                        help="Minimum timestep ratio")
    parser.add_argument("--max-timestep-ratio",
                        type=float,
                        default=TrainingArgs.max_timestep_ratio,
                        help="Maximum timestep ratio")
    parser.add_argument("--real-score-guidance-scale",
                        type=float,
                        default=TrainingArgs.real_score_guidance_scale,
                        help="Teacher guidance scale")
    parser.add_argument("--fake-score-learning-rate",
                        type=float,
                        default=TrainingArgs.fake_score_learning_rate,
                        help="Learning rate for the fake score transformer")
    parser.add_argument("--fake-score-betas",
                        type=str,
                        default=TrainingArgs.fake_score_betas,
                        help="Betas for the fake score optimizer (format: 'beta1,beta2')")
    parser.add_argument("--fake-score-lr-scheduler",
                        type=str,
                        default=TrainingArgs.fake_score_lr_scheduler,
                        help="Learning rate scheduler for the fake score transformer")
    parser.add_argument("--log-visualization", action=StoreBoolean, help="Whether to log visualization")
    parser.add_argument("--simulate-generator-forward",
                        action=StoreBoolean,
                        help="Whether to simulate the generator forward pass to match inference")
    parser.add_argument("--warp-denoising-step",
                        action=StoreBoolean,
                        help="Whether to warp the denoising step according to the scheduler time shift")

    # Self-forcing specific arguments
    parser.add_argument("--num-frame-per-block",
                        type=int,
                        default=TrainingArgs.num_frame_per_block,
                        help="Number of frames per block for causal generation")
    parser.add_argument("--independent-first-frame",
                        action=StoreBoolean,
                        help="Whether the first frame is independent in causal generation")
    parser.add_argument("--enable-gradient-masking",
                        action=StoreBoolean,
                        help="Whether to enable frame-level gradient masking")
    parser.add_argument("--gradient-mask-last-n-frames",
                        type=int,
                        default=TrainingArgs.gradient_mask_last_n_frames,
                        help="Number of last frames to enable gradients for")
    parser.add_argument("--validate-cache-structure",
                        action=StoreBoolean,
                        help="Whether to validate the KV cache structure (debug flag)")
    parser.add_argument("--same-step-across-blocks",
                        action=StoreBoolean,
                        help="Whether to use the same exit timestep for all blocks")
    parser.add_argument("--last-step-only",
                        action=StoreBoolean,
                        help="Whether to only use the last timestep for training")
    parser.add_argument("--context-noise",
                        type=int,
                        default=TrainingArgs.context_noise,
                        help="Context noise level for cache updates")

    return parser


def parse_int_list(value: str) -> list[int]:
    if not value:
        return []
    return [int(x.strip()) for x in value.split(",")]
backend_snapshot/fastvideo/pipelines/basic/wan/__init__.py
ADDED

(empty file)
backend_snapshot/fastvideo/pipelines/basic/wan/wan_pipeline.py
ADDED
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
"""
Wan video diffusion pipeline implementation.

This module contains an implementation of the Wan video diffusion pipeline
using the modular pipeline architecture.
"""

from fastvideo.fastvideo_args import FastVideoArgs
from fastvideo.logger import init_logger
from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import (FlowUniPCMultistepScheduler)
from fastvideo.pipelines import ComposedPipelineBase, LoRAPipeline
from fastvideo.pipelines.stages import (ConditioningStage, DecodingStage, DenoisingStage, InputValidationStage,
                                        LatentPreparationStage, TextEncodingStage, TimestepPreparationStage)

logger = init_logger(__name__)


class WanPipeline(LoRAPipeline, ComposedPipelineBase):
    """
    Wan video diffusion pipeline with LoRA support.
    """

    _required_config_modules = ["text_encoder", "tokenizer", "vae", "transformer", "scheduler"]

    def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
        # We use the UniPC multistep scheduler from the Wan2.1 official repo, not the one in diffusers.
        self.modules["scheduler"] = FlowUniPCMultistepScheduler(shift=fastvideo_args.pipeline_config.flow_shift)

    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
        """Set up pipeline stages with proper dependency injection."""

        self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())

        self.add_stage(stage_name="prompt_encoding_stage",
                       stage=TextEncodingStage(
                           text_encoders=[self.get_module("text_encoder")],
                           tokenizers=[self.get_module("tokenizer")],
                       ))

        self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())

        self.add_stage(stage_name="timestep_preparation_stage",
                       stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))

        self.add_stage(stage_name="latent_preparation_stage",
                       stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
                                                    transformer=self.get_module("transformer", None)))

        self.add_stage(stage_name="denoising_stage",
                       stage=DenoisingStage(transformer=self.get_module("transformer"),
                                            transformer_2=self.get_module("transformer_2", None),
                                            scheduler=self.get_module("scheduler"),
                                            vae=self.get_module("vae"),
                                            pipeline=self))

        self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"), pipeline=self))


EntryClass = WanPipeline
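
Note: a toy sketch of the stage-composition pattern used here (not the FastVideo API): stages are callables that take and return the batch, registered by name and executed in order, mirroring `add_stage()`/`forward()` in `ComposedPipelineBase` below.

```python
# Toy illustration of stage composition; names and types are hypothetical.
from dataclasses import dataclass, field
from typing import Any, Callable

Stage = Callable[[dict[str, Any]], dict[str, Any]]

@dataclass
class ToyPipeline:
    stages: list[tuple[str, Stage]] = field(default_factory=list)

    def add_stage(self, name: str, stage: Stage) -> None:
        self.stages.append((name, stage))

    def forward(self, batch: dict[str, Any]) -> dict[str, Any]:
        for _name, stage in self.stages:
            batch = stage(batch)  # each stage reads and extends the shared batch
        return batch

pipe = ToyPipeline()
pipe.add_stage("prompt_encoding", lambda b: {**b, "embeds": len(b["prompt"])})
pipe.add_stage("denoising", lambda b: {**b, "latents": "denoised"})
print(pipe.forward({"prompt": "a cat"}))
```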
backend_snapshot/fastvideo/pipelines/composed_pipeline_base.py
ADDED
@@ -0,0 +1,474 @@
# SPDX-License-Identifier: Apache-2.0
"""
Base class for composed pipelines.

This module defines the base class for pipelines that are composed of multiple stages.
"""

import argparse
import os
from abc import ABC, abstractmethod
from typing import Any, cast

import torch

from fastvideo.configs.pipelines import PipelineConfig
from fastvideo.distributed import (maybe_init_distributed_environment_and_model_parallel, get_world_group)
from fastvideo.distributed.communication_op import (warmup_sequence_parallel_communication)
from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs
from fastvideo.logger import init_logger
from fastvideo.profiler import get_or_create_profiler
from fastvideo.models.loader.component_loader import PipelineComponentLoader
from fastvideo.pipelines.pipeline_batch_info import ForwardBatch
from fastvideo.pipelines.stages import PipelineStage
import fastvideo.envs as envs
from fastvideo.utils import (maybe_download_model, verify_model_config_and_directory)

logger = init_logger(__name__)


class ComposedPipelineBase(ABC):
    """
    Base class for pipelines composed of multiple stages.

    This class provides the framework for creating pipelines by composing multiple
    stages together. Each stage is responsible for a specific part of the diffusion
    process, and the pipeline orchestrates the execution of these stages.
    """

    is_video_pipeline: bool = False  # To be overridden by video pipelines
    _required_config_modules: list[str] = []
    _extra_config_module_map: dict[str, str] = {}
    training_args: Any = None
    fastvideo_args: Any = None
    modules: dict[str, Any] = {}
    # does not need to include MoE-related transformers
    trainable_transformer_names: list[str] = ["transformer"]
    trainable_transformer_modules: dict[str, torch.nn.Module] = {}
    post_init_called: bool = False

    # TODO(will): args should support both inference args and training args
    def __init__(self,
                 model_path: str,
                 fastvideo_args: FastVideoArgs | TrainingArgs,
                 required_config_modules: list[str] | None = None,
                 loaded_modules: dict[str, torch.nn.Module] | None = None):
        """
        Initialize the pipeline. After __init__, the pipeline should be ready to
        use. The pipeline should be stateless and not hold any batch state.
        """
        self.fastvideo_args = fastvideo_args

        self.model_path: str = model_path
        self._stages: list[PipelineStage] = []
        self._stage_name_mapping: dict[str, PipelineStage] = {}

        if required_config_modules is not None:
            self._required_config_modules = required_config_modules

        if self._required_config_modules is None:
            raise NotImplementedError("Subclass must set _required_config_modules")

        maybe_init_distributed_environment_and_model_parallel(fastvideo_args.tp_size, fastvideo_args.sp_size)

        # Torch profiler. Enabled and configured through env vars:
        # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
        trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
        self.profiler_controller = get_or_create_profiler(trace_dir)
        self.profiler = self.profiler_controller.profiler

        self.local_rank = get_world_group().local_rank

        # Load modules directly in initialization
        logger.info("Loading pipeline modules...")
        with self.profiler_controller.region("profiler_region_model_loading"):
            self.modules = self.load_modules(fastvideo_args, loaded_modules)

    def set_trainable(self) -> None:
        # Only train the DiT
        if getattr(self.fastvideo_args, "training_mode", False):
            for name, module in self.trainable_transformer_modules.items():
                logger.info("Setting %s to requires_grad=True", name)
                if not isinstance(module, torch.nn.Module):
                    logger.info("Skipping %s because it is not a torch.nn.Module", name)
                    continue
                module.requires_grad_(True)
                module.train()

    @staticmethod
    def _compile_with_conditions(
        module: torch.nn.Module,
        compile_kwargs: dict[str, Any],
    ) -> int:
        """Compile submodules that match module._compile_conditions."""
        compile_conditions = getattr(module, "_compile_conditions", None)
        if not compile_conditions:
            return 0

        compiled_count = 0
        for name, submodule in module.named_modules():
            if not name:
                continue
            if any(cond(name, submodule) for cond in compile_conditions):
                submodule.forward = torch.compile(submodule.forward, **compile_kwargs)
                compiled_count += 1
        return compiled_count

    def _maybe_compile_pipeline_module(
        self,
        module_name: str,
        fsdp_module_cls: type | None,
        compile_kwargs: dict[str, Any],
    ) -> None:
        if module_name not in self.modules:
            return

        module = self.modules[module_name]
        if fsdp_module_cls is not None and isinstance(module, fsdp_module_cls):
            logger.info(
                "%s is already FSDP-wrapped; skipping torch.compile in pipeline",
                module_name.capitalize(),
            )
            return

        compiled_count = self._compile_with_conditions(module, compile_kwargs)
        if compiled_count > 0:
            logger.info(
                "Enabled torch.compile for %d submodules in %s via _compile_conditions with kwargs=%s",
                compiled_count,
                module_name,
                compile_kwargs,
            )
            return

        # Backward-compatible fallback: compile the full module if no condition matched.
        logger.info("Enabling torch.compile for %s with kwargs=%s", module_name, compile_kwargs)
        self.modules[module_name] = torch.compile(module, **compile_kwargs)

    def post_init(self) -> None:
        assert self.fastvideo_args is not None, "fastvideo_args must be set"
        if self.post_init_called:
            return
        self.post_init_called = True
        if self.fastvideo_args.training_mode:
            assert isinstance(self.fastvideo_args, TrainingArgs)
            self.training_args = self.fastvideo_args
            assert self.training_args is not None
            self.initialize_training_pipeline(self.training_args)
            if self.training_args.log_validation:
                self.initialize_validation_pipeline(self.training_args)

        self.initialize_pipeline(self.fastvideo_args)
        if self.fastvideo_args.enable_torch_compile:
            if self.fastvideo_args.training_mode:
                logger.info("Torch Compile enabled via FSDP loader for training; skipping additional pipeline compile")
            else:
                fsdp_module_cls = None
                try:
                    from torch.distributed.fsdp import FSDPModule  # type: ignore
                    fsdp_module_cls = FSDPModule
                except Exception:  # pragma: no cover - FSDP not always available
                    fsdp_module_cls = None

                compile_kwargs = self.fastvideo_args.torch_compile_kwargs or {}
                self._maybe_compile_pipeline_module(
                    module_name="transformer",
                    fsdp_module_cls=fsdp_module_cls,
                    compile_kwargs=compile_kwargs,
                )
                self._maybe_compile_pipeline_module(
                    module_name="transformer_2",
                    fsdp_module_cls=fsdp_module_cls,
                    compile_kwargs=compile_kwargs,
                )
                logger.info("Torch Compile enabled for DiT")

        if not self.fastvideo_args.training_mode:
            logger.info("Creating pipeline stages...")
            self.create_pipeline_stages(self.fastvideo_args)

        # Warm up NCCL communicators for sequence parallelism to avoid a
        # slow first forward pass due to lazy initialization
        warmup_sequence_parallel_communication()

    def initialize_training_pipeline(self, training_args: TrainingArgs):
        raise NotImplementedError("if training_mode is True, the pipeline must implement this method")

    def initialize_validation_pipeline(self, training_args: TrainingArgs):
        raise NotImplementedError("if log_validation is True, the pipeline must implement this method")

    @classmethod
    def from_pretrained(cls,
                        model_path: str,
                        device: str | None = None,
                        torch_dtype: torch.dtype | None = None,
                        pipeline_config: str | PipelineConfig | None = None,
                        args: argparse.Namespace | FastVideoArgs | TrainingArgs | None = None,
                        required_config_modules: list[str] | None = None,
                        loaded_modules: dict[str, torch.nn.Module] | None = None,
                        **kwargs) -> "ComposedPipelineBase":
        """
        Load a pipeline from a pretrained model.

        loaded_modules: if provided, these modules are used instead of loading
        from config/pretrained weights.
        """
        if args is None or (isinstance(args, FastVideoArgs) and args.inference_mode):
            kwargs['model_path'] = model_path
            fastvideo_args = FastVideoArgs.from_kwargs(**kwargs)
        else:
            if isinstance(args, TrainingArgs):
                fastvideo_args = args
            else:
                assert isinstance(args, argparse.Namespace), "training mode expects argparse.Namespace args"
                fastvideo_args = TrainingArgs.from_cli_args(args)
            # TODO(will): fix this so that it's not so ugly
            fastvideo_args.model_path = model_path
            for key, value in kwargs.items():
                setattr(fastvideo_args, key, value)

            fastvideo_args.dit_cpu_offload = False
            # We hijack the precision to be the master weight type so that the
            # model is loaded with the correct precision. Subsequently we will
            # use FSDP2's MixedPrecisionPolicy to set the precision for the
            # fwd, bwd, and other operations.
            assert fastvideo_args.pipeline_config.dit_precision == 'fp32', 'only fp32 is supported for training'

        logger.info("fastvideo_args in from_pretrained: %s", fastvideo_args)

        pipe = cls(model_path,
                   fastvideo_args,
                   required_config_modules=required_config_modules,
                   loaded_modules=loaded_modules)
        pipe.post_init()
        return pipe

    def get_module(self, module_name: str, default_value: Any = None) -> Any:
        if module_name not in self.modules:
            return default_value
        return self.modules[module_name]

    def add_module(self, module_name: str, module: Any):
        self.modules[module_name] = module

    def __getattr__(self, name: str) -> Any:
        if "_stage_name_mapping" in self.__dict__ and name in self._stage_name_mapping:
            return self._stage_name_mapping[name]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def _load_config(self, model_path: str) -> dict[str, Any]:
        model_path = maybe_download_model(self.model_path)
        self.model_path = model_path
        # fastvideo_args.downloaded_model_path = model_path
        logger.info("Model path: %s", model_path)
        config = verify_model_config_and_directory(model_path)
        return cast(dict[str, Any], config)

    @property
    def required_config_modules(self) -> list[str]:
        """
        List of modules that are required by the pipeline. The names should match
        the diffusers directory and model_index.json file. These modules will be
        loaded using the PipelineComponentLoader and made available in the
        modules dictionary. Access these modules using the get_module method.

        class ConcretePipeline(ComposedPipelineBase):
            _required_config_modules = ["vae", "text_encoder", "transformer", "scheduler", "tokenizer"]

            @property
            def required_config_modules(self):
                return self._required_config_modules
        """
        return self._required_config_modules

    @property
    def stages(self) -> list[PipelineStage]:
        """
        List of stages in the pipeline.
        """
        return self._stages

    @abstractmethod
    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
        """
        Create the inference pipeline stages.
        """
        raise NotImplementedError

    def create_training_stages(self, training_args: TrainingArgs):
        """
        Create the training pipeline stages.
        """
        raise NotImplementedError

    def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
        """
        Initialize the pipeline.
        """
        return

    def load_modules(self,
                     fastvideo_args: FastVideoArgs,
                     loaded_modules: dict[str, torch.nn.Module] | None = None) -> dict[str, Any]:
        """
        Load the modules from the config.

        loaded_modules: if provided, these modules are used instead of loading
        from config/pretrained weights.
        """

        model_index = self._load_config(self.model_path)
        logger.info("Loading pipeline modules from config: %s", model_index)

        # remove keys that are not pipeline modules
        model_index.pop("_class_name")
        model_index.pop("_diffusers_version")
        model_index.pop("_name_or_path", None)
        model_index.pop("workload_type", None)
        if "boundary_ratio" in model_index and model_index["boundary_ratio"] is not None:
            logger.info("MoE pipeline detected. Adding transformer_2 to self.required_config_modules...")
            self.required_config_modules.append("transformer_2")
            logger.info("MoE pipeline detected. Setting boundary ratio to %s", model_index["boundary_ratio"])
            fastvideo_args.pipeline_config.dit_config.boundary_ratio = model_index["boundary_ratio"]

        model_index.pop("boundary_ratio", None)
        # used by Wan2.2 ti2v
        model_index.pop("expand_timesteps", None)

        # some sanity checks
        assert len(model_index) > 1, "model_index.json must contain at least one pipeline module"

        for module_name in self.required_config_modules:
            if module_name not in model_index and module_name in self._extra_config_module_map:
                extra_module_value = self._extra_config_module_map[module_name]
                logger.warning(
                    "model_index.json does not contain a %s module, but found {%s: %s} in _extra_config_module_map, adding to model_index.",
                    module_name, module_name, extra_module_value)
                if extra_module_value in model_index:
                    logger.info("Using module %s for %s", extra_module_value, module_name)
                    model_index[module_name] = model_index[extra_module_value]
                    continue
                else:
                    raise ValueError(
                        f"Required module key: {module_name} value: {model_index.get(module_name)} was not found in loaded modules {model_index.keys()}"
                    )

        # all the component models used by the pipeline
        required_modules = self.required_config_modules
        logger.info("Loading required modules: %s", required_modules)

        modules = {}
        for module_name, module_spec in model_index.items():
            if not isinstance(module_spec, list | tuple):
                logger.info(
                    "Skipping non-module config entry %s=%s",
                    module_name,
                    module_spec,
                )
                continue
            if len(module_spec) < 1:
                logger.warning(
                    "Skipping module %s due to invalid empty spec in model_index.json",
                    module_name,
                )
                continue
            transformers_or_diffusers = module_spec[0]
            if transformers_or_diffusers is None:
                logger.warning("Module %s in model_index.json has null value, removing from required_config_modules",
                               module_name)
                if module_name in self.required_config_modules:
                    self.required_config_modules.remove(module_name)
                continue
            if module_name not in required_modules:
                logger.info("Skipping module %s", module_name)
                continue
            if loaded_modules is not None and module_name in loaded_modules:
                logger.info("Using module %s already provided", module_name)
                modules[module_name] = loaded_modules[module_name]
                continue

            # we load the module from the extra config module map if it exists
            if module_name in self._extra_config_module_map:
                load_module_name = self._extra_config_module_map[module_name]
            else:
                load_module_name = module_name

            component_model_path = os.path.join(self.model_path, load_module_name)
            module = PipelineComponentLoader.load_module(
                module_name=load_module_name,
                component_model_path=component_model_path,
                transformers_or_diffusers=transformers_or_diffusers,
                fastvideo_args=fastvideo_args,
            )
            logger.info("Loaded module %s from %s", module_name, component_model_path)

            if module_name in modules:
                logger.warning("Overwriting module %s", module_name)
            modules[module_name] = module

        # Check if all required modules were loaded
        for module_name in required_modules:
            if module_name not in modules or modules[module_name] is None:
                raise ValueError(
                    f"Required module key: {module_name} value: {modules.get(module_name)} was not found in loaded modules {modules.keys()}"
                )

        return modules

    def add_stage(self, stage_name: str, stage: PipelineStage):
        assert self.modules is not None, "No modules are registered"
        self._stages.append(stage)
        self._stage_name_mapping[stage_name] = stage
        setattr(self, stage_name, stage)

    def profile(self, is_start: bool = True):
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()
            # only print profiler results on rank 0
            if self.local_rank == 0:
                print(self.profiler.key_averages().table(sort_by="self_cuda_time_total"))

    # TODO(will): don't hardcode no_grad
    @torch.no_grad()
    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        """
        Generate a video or image using the pipeline.

        Args:
            batch: The batch to generate from.
            fastvideo_args: The inference arguments.
        Returns:
            ForwardBatch: The batch with the generated video or image.
        """
        if not self.post_init_called:
            self.post_init()

        # Execute each stage
        logger.info("Running pipeline stages: %s", self._stage_name_mapping.keys())
        # logger.info("Batch: %s", batch)
        for stage in self.stages:
            batch = stage(batch, fastvideo_args)

        # Return the output
        return batch

    def train(self) -> None:
        raise NotImplementedError("if training_mode is True, the pipeline must implement this method")

    def streaming_reset(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch:
        raise NotImplementedError(f"{type(self).__name__} does not support streaming_reset")

    def streaming_step(self, *args: Any, **kwargs: Any) -> ForwardBatch:
        raise NotImplementedError(f"{type(self).__name__} does not support streaming_step")

    def streaming_clear(self) -> None:
        raise NotImplementedError(f"{type(self).__name__} does not support streaming_clear")
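
Note: `_compile_with_conditions` above establishes an opt-in contract: a module can expose `_compile_conditions`, a list of `(qualified_name, submodule) -> bool` predicates, and only matching submodules have their `forward` wrapped in `torch.compile`; otherwise `_maybe_compile_pipeline_module` falls back to compiling the whole module. A minimal sketch of a module opting in; the `ToyDiT` class is hypothetical.

```python
# Hypothetical module illustrating the `_compile_conditions` contract that
# `_compile_with_conditions` consumes: compile only the transformer blocks,
# never the container module itself.
import torch

class ToyDiT(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(2)])
        self._compile_conditions = [
            lambda name, mod: name.startswith("blocks.") and isinstance(mod, torch.nn.Linear),
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x

# Usage against the class above (assuming the snapshot is importable):
#   from fastvideo.pipelines.composed_pipeline_base import ComposedPipelineBase
#   n = ComposedPipelineBase._compile_with_conditions(ToyDiT(), {"mode": "default"})
#   n == 2: both Linear blocks are compiled; the ModuleList container is skipped.
```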
backend_snapshot/manifest.sha256
CHANGED
@@ -1,17 +1,33 @@
-
+033b7cce6eb0ead450a35b910adc9ae1323b8d2323aea6731b72b940e222fb46 ./README.md
+9c7dec8f1b8160954d0566231b0952a5f6a5d81f546affd71d190b2b3fc79cb6 ./examples/inference/basic/basic.py
 9d1d8dc58aab529270fe31eb1735d6a1382c0c6d36fccca122a8dbffa1b714fd ./fastvideo-kernel/python/fastvideo_kernel/block_sparse_attn_ours_p.py
 211c7f0445fbe9488250f01fa83457c6620e83bd6f3877db791fd155de93c08b ./fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py
 3f3a407a88612ea17ad65e1b6b9cf6b7b02df56956d8301c4b13bffa92095016 ./fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py
 56f17c602dede53c7c3677058f81274681530f1b83c086d9d1d44c6b51feefbb ./fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py
+58f4ac013e6755336212a7a6c9948b19dab0dafc00f4a3298591598df270cb39 ./fastvideo/api/compat.py
 2b821b0e2e7bdb3581be6312ebbece42380a6ee28a7a982f0cf2dc71fab849c8 ./fastvideo/attention/backends/sparse_fp4_ours_p_attn.py
 a97adcc52d7558c49f418c09395fd1665e988ad290d2276b95f21dfca0f8eb7d ./fastvideo/attention/backends/video_sparse_attn.py
 79ef6f38ec0f5bfe16b2b98327ad2ccd15f3c863dd87fd03affc5dbdaa0a8224 ./fastvideo/configs/models/dits/base.py
+4bda44746a3626551ea9a9380d890f036087092fb99fce2d302642cce14a97ed ./fastvideo/configs/pipelines/wan.py
+5926e29a594db13b116922f131db50631bf8adbf90fe5cec00a5e2f446bfb4ca ./fastvideo/configs/sample/base.py
+d99adcf607d982b38bbb5a70be60bf87f35d0e9f6f50752f3bceb68b34ce46c2 ./fastvideo/configs/sample/wan.py
+49775ce42fd9643c78d8fad4ab8248c1755c7f1524ad771cbd1863d76c513c38 ./fastvideo/configs/wan_1.3B_t2v_pipeline.json
+ae2d8309472b09927da3e450dea52d9715dcabe5d6722fc2917130ae8d85adb4 ./fastvideo/entrypoints/cli/generate.py
+d0466769626e7fd497376c544904d56ba62847745eb52527896d96b99d76ba03 ./fastvideo/entrypoints/video_generator.py
+73afe6b2ebe0f8cfe0a8ec762a7126161621ad97a64ebad628995f4a164b8b0e ./fastvideo/fastvideo_args.py
 ddcab6f4fd33c9813840571b6bf83bbbcea164b564166951ed4301297db6cef0 ./fastvideo/forward_context.py
+e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 ./fastvideo/pipelines/basic/wan/__init__.py
+deac1e22530a6a41c501629f5e8fce47a7af4e008f321cc8a4d734c5120ef4fe ./fastvideo/pipelines/basic/wan/wan_pipeline.py
+8908223b3ff99cdb3206148a68a730c2a13d554a2fb1316db6f2f9672efac9e8 ./fastvideo/pipelines/composed_pipeline_base.py
 6cfd128e782b7787a27ddd28a5e2d50cb4b0e2e9425d51d9780f14c91e8206f0 ./fastvideo/pipelines/stages/denoising.py
 489388dbdd9e5e3ad24db3012bd9b108794509a9729891d7dd315a102abba828 ./fastvideo/platforms/cuda.py
 c046b1914041b59254bcdfe577aed20d6f007a72632ea1fe1ae92fa678eca760 ./fastvideo/platforms/interface.py
 2456d39ca28019e12bb7ab007774e86348f0582a017bf0e6c91e2a01d654a1a0 ./fastvideo/train/models/wan/wan.py
 bc46e84b732567de6c0325223405daecd1226c623e303be33c7be9b5b7fdec08 ./fastvideo/training/training_pipeline.py
 1d3898fa37e21029df6c37e05dc34ed7805a211c2f87de6642db890e5a8c6f2e ./fastvideo/training/wan_training_pipeline.py
+57fd33c78a16c9b4a239734c76726d94df1daf86f15cf22451c6107fcc197834 ./scripts/inference/run_sfp4_ours_p_checkpoint_750.sh
+0162b26dddb2a249e4e2cc56a7a28eba4fb3ea77e938eaf9c6b16a241edcd3ec ./scripts/inference/run_sfp4_single.sh
+159579109f9fb7d7dded977f5c7fc974583c95f3f895e418d6c466463c036304 ./scripts/inference/run_validate_and_gen.sh
 5c982b64653fae83ebfdeb43fda8f29b3e2cb581fb4daee38cd3cf56aa9d73f5 ./scripts/training/run_sparse_fp4_train_v4_1n_sparse09_hpo_on_ours_p_init2050_interactive.sh
 5c1d5ce9ecc8b90e59ddfc2ddb3e2dae500bcd3acb90429c901444b1630f05fb ./scripts/training/run_sparse_fp4_train_v4_common.sh
+75455829ca55a80daaa7e3c7faa080b6eec3c7109bdbd1e198b722face62eed0 ./training_attention_settings.json
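
Note: the `e3b0c442...` entry for `__init__.py` is the SHA-256 of an empty file, consistent with the empty `__init__.py` added above. A minimal integrity-check sketch, run from the `backend_snapshot/` directory and assuming the `<sha256> <relative path>` layout shown above:

```python
# Verify the snapshot files against manifest.sha256.
import hashlib
from pathlib import Path

def verify_manifest(manifest_path: str = "manifest.sha256") -> bool:
    ok = True
    for line in Path(manifest_path).read_text().splitlines():
        if not line.strip():
            continue
        expected, rel_path = line.split(maxsplit=1)
        digest = hashlib.sha256(Path(rel_path).read_bytes()).hexdigest()
        if digest != expected:
            print(f"MISMATCH {rel_path}: {digest} != {expected}")
            ok = False
    return ok

if __name__ == "__main__":
    print("manifest OK" if verify_manifest() else "manifest verification FAILED")
```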
backend_snapshot/scripts/inference/run_sfp4_ours_p_checkpoint_750.sh
ADDED
@@ -0,0 +1,54 @@
#!/bin/bash

set -euo pipefail

REPO_ROOT="${REPO_ROOT:-/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/users/yitongl/code/FastVideo}"
MODEL_PATH="${MODEL_PATH:-Wan-AI/Wan2.1-T2V-1.3B-Diffusers}"
HF_REPO="${HF_REPO:-yitongl/sparse_quant_exp}"
CHECKPOINT_DIR="${CHECKPOINT_DIR:-${REPO_ROOT}/checkpoints/hf_download/sparse_quant_exp}"
WEIGHTS_PATH="${WEIGHTS_PATH:-${CHECKPOINT_DIR}/transformer/diffusion_pytorch_model.safetensors}"
OUTPUT_DIR="${OUTPUT_DIR:-${REPO_ROOT}/outputs/sfp4_v4_sparse09_checkpoint_750}"
PROMPT="${PROMPT:-A cinematic shot of a futuristic city street at dusk, reflective pavement, soft volumetric light, detailed motion, stable camera.}"
NEGATIVE_PROMPT="${NEGATIVE_PROMPT:-Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards}"

cd "${REPO_ROOT}"

if [[ -f ".venv/bin/activate" ]]; then
    source .venv/bin/activate
fi

export PYTHONPATH="${REPO_ROOT}/fastvideo-kernel/python:${REPO_ROOT}/fastvideo-kernel:${PYTHONPATH:-}"
export FASTVIDEO_ATTENTION_BACKEND=SPARSE_FP4_OURS_P_ATTN
export FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O=1

if [[ ! -f "${WEIGHTS_PATH}" ]]; then
    echo "Missing ${WEIGHTS_PATH}"
    echo "Download the uploaded transformer weights first:"
    echo "  hf download ${HF_REPO} --repo-type model --local-dir ${CHECKPOINT_DIR} --include 'transformer/*'"
    exit 1
fi

mkdir -p "${OUTPUT_DIR}"

fastvideo generate \
    --model-path "${MODEL_PATH}" \
    --init-weights-from-safetensors "${WEIGHTS_PATH}" \
    --sp-size 1 \
    --tp-size 1 \
    --num-gpus 1 \
    --dit-cpu-offload False \
    --vae-cpu-offload False \
    --text-encoder-cpu-offload True \
    --pin-cpu-memory False \
    --height 448 \
    --width 832 \
    --num-frames 77 \
    --num-inference-steps 50 \
    --fps 16 \
    --guidance-scale 5.0 \
    --flow-shift 1.0 \
    --prompt "${PROMPT}" \
    --negative-prompt "${NEGATIVE_PROMPT}" \
    --seed 1000 \
    --VSA-sparsity 0.9 \
    --output-path "${OUTPUT_DIR}/"
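
Note: a back-of-the-envelope sketch of the attention geometry for this generation shape (448x832, 77 frames, VSA sparsity 0.9). The 4x temporal VAE stride is confirmed by `num_latent_t: 20` in `training_attention_settings.json` below; the 8x spatial VAE stride and 1x2x2 DiT patchify are assumptions (standard Wan2.1 values, not stated in this diff).

```python
# Token/tile counts under the assumptions stated above.
num_frames, height, width = 77, 448, 832
t_tok = (num_frames - 1) // 4 + 1                   # 20 latent frames
h_tok, w_tok = height // 8 // 2, width // 8 // 2    # 28 x 52 tokens per frame
tokens = t_tok * h_tok * w_tok                      # 29,120 video tokens
tiles = (t_tok // 4) * (h_tok // 4) * (w_tok // 4)  # 5 * 7 * 13 = 455 tiles of 4x4x4
kept_tiles = tiles - (tiles * 9) // 10              # ~46 K/V tiles per query tile;
                                                    # exact top-k rounding is kernel-defined
print(tokens, tiles, kept_tiles)                    # 29120 455 46
```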
backend_snapshot/scripts/inference/run_sfp4_single.sh
ADDED
@@ -0,0 +1,38 @@
#!/bin/bash
#SBATCH --job-name=sfp4-s0
#SBATCH --account=nvr_elm_llm
#SBATCH --partition=interactive
#SBATCH --nodes=1
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=16
#SBATCH --mem=64G
#SBATCH --time=00:30:00
#SBATCH --output=slurm_logs/sfp4_s0_%j.out
#SBATCH --error=slurm_logs/sfp4_s0_%j.err

set -ex
cd /lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/users/yitongl/code/FastVideo
source .venv/bin/activate
export PYTHONPATH=fastvideo-kernel/python:fastvideo-kernel:$PYTHONPATH
export FASTVIDEO_ATTENTION_BACKEND=SPARSE_FP4_ATTN

mkdir -p outputs_sfp4_s0

# Same prompt, seed, and params as the dense FP4 run
fastvideo generate \
    --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
    --sp-size 1 --tp-size 1 --num-gpus 1 \
    --dit-cpu-offload False --vae-cpu-offload False \
    --text-encoder-cpu-offload True --pin-cpu-memory False \
    --height 480 --width 832 --num-frames 81 \
    --num-inference-steps 50 --fps 16 \
    --guidance-scale 6.0 --flow-shift 8.0 \
    --prompt "Will Smith casually eats noodles, his relaxed demeanor contrasting with the energetic background of a bustling street food market. The scene captures a mix of humor and authenticity. Mid-shot framing, vibrant lighting." \
    --seed 1024 \
    --VSA-sparsity 0.0 \
    --output-path outputs_sfp4_s0/

echo "=== Done ==="
ls -lh outputs_sfp4_s0/*.mp4
echo "--- Dense FP4 reference ---"
ls -lh outputs_dense_fp4/*.mp4
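
Note: an optional sketch for comparing the sparse run against the dense FP4 reference the script lists at the end. It assumes `imageio` with the pyav plugin is installed; the example paths follow the script's output directories and are otherwise arbitrary.

```python
# Frame-aligned PSNR between two generated videos.
import numpy as np
import imageio.v3 as iio

def video_psnr(path_a: str, path_b: str) -> float:
    a = iio.imread(path_a, plugin="pyav").astype(np.float64)
    b = iio.imread(path_b, plugin="pyav").astype(np.float64)
    n = min(len(a), len(b))              # align frame counts defensively
    mse = np.mean((a[:n] - b[:n]) ** 2)
    return float(10 * np.log10(255.0 ** 2 / mse))

# e.g. video_psnr("outputs_sfp4_s0/video.mp4", "outputs_dense_fp4/video.mp4")
```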
backend_snapshot/scripts/inference/run_validate_and_gen.sh
ADDED
@@ -0,0 +1,91 @@
#!/bin/bash
#SBATCH --job-name=sfp4-val-gen
#SBATCH --account=nvr_elm_llm
#SBATCH --partition=interactive
#SBATCH --nodes=1
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task=128
#SBATCH --mem=1440G
#SBATCH --time=02:00:00
#SBATCH --output=slurm_logs/sfp4_val_gen_%j.out
#SBATCH --error=slurm_logs/sfp4_val_gen_%j.err

set -ex

REPO_ROOT="/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/users/yitongl/code/FastVideo"
KERNEL_ROOT="${REPO_ROOT}/fastvideo-kernel"

mkdir -p "${REPO_ROOT}/slurm_logs"
cd "${REPO_ROOT}"
source .venv/bin/activate
export PYTHONPATH="${KERNEL_ROOT}/python:${KERNEL_ROOT}:${PYTHONPATH}"

echo "=== Environment ==="
nvidia-smi -L | head -1
python -c "import torch; print(f'torch={torch.__version__}, cuda={torch.cuda.is_available()}, gpus={torch.cuda.device_count()}')"
python -c "import triton; print(f'triton={triton.__version__}')"

echo ""
echo "######################################################################"
echo "# Generate 8 videos with sparse FP4 attention                       #"
echo "######################################################################"
cd "${REPO_ROOT}"

MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
PROMPT="Will Smith casually eats noodles, his relaxed demeanor contrasting with the energetic background of a bustling street food market. The scene captures a mix of humor and authenticity. Mid-shot framing, vibrant lighting."
NEGATIVE_PROMPT="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
SEED=1024

SPARSITY_LIST=(0.0 0.1 0.2 0.4 0.5 0.7 0.8 0.9)
OUTPUT_BASE="${REPO_ROOT}/outputs_sparse_fp4_sweep"
mkdir -p "${OUTPUT_BASE}"

echo "Sparsity levels: ${SPARSITY_LIST[*]}"

PIDS=()
for i in $(seq 0 7); do
    SPARSITY=${SPARSITY_LIST[$i]}
    OUT_DIR="${OUTPUT_BASE}/sparsity_${SPARSITY}"
    mkdir -p "${OUT_DIR}"
    echo "[GPU ${i}] sparsity=${SPARSITY}"

    (
        export CUDA_VISIBLE_DEVICES=${i}
        export FASTVIDEO_ATTENTION_BACKEND=SPARSE_FP4_ATTN

        fastvideo generate \
            --model-path "${MODEL_PATH}" \
            --sp-size 1 --tp-size 1 --num-gpus 1 \
            --dit-cpu-offload False \
            --vae-cpu-offload False \
            --text-encoder-cpu-offload True \
            --pin-cpu-memory False \
            --height 480 --width 832 --num-frames 81 \
            --num-inference-steps 50 --fps 16 \
            --guidance-scale 6.0 --flow-shift 8.0 \
            --prompt "${PROMPT}" \
            --negative-prompt "${NEGATIVE_PROMPT}" \
            --seed ${SEED} \
            --VSA-sparsity ${SPARSITY} \
            --output-path "${OUT_DIR}/" \
            2>&1 | tee "${OUT_DIR}/log.txt"

        echo "[GPU ${i}] sparsity=${SPARSITY} DONE"
    ) &
    PIDS+=($!)
done

echo "=== Waiting for all 8 jobs ==="
FAIL=0
for i in $(seq 0 7); do
    wait ${PIDS[$i]} || { echo "[GPU ${i}] FAILED"; FAIL=1; }
done

echo ""
if [ $FAIL -eq 0 ]; then
    echo "=== All 8 videos generated ==="
else
    echo "=== Some failed ==="
fi
find "${OUTPUT_BASE}" -name "*.mp4" | sort
echo "=== Done ==="
backend_snapshot/training_attention_settings.json
ADDED
@@ -0,0 +1,62 @@
{
  "run_name": "sfp4_v4_sparse09_hpo_on_ours_p_init2050_1n_interactive",
  "checkpoint": "checkpoint-750",
  "training_method": "legacy_sft_wan_training_pipeline",
  "model_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
  "init_weights_from_safetensors": "checkpoints/init/sfp4_v4_sparse06_hpo_on_ours_p_1n_interactive_v2_ckpt2050/transformer/diffusion_pytorch_model.safetensors",
  "environment": {
    "FASTVIDEO_ATTENTION_BACKEND": "SPARSE_FP4_OURS_P_ATTN",
    "FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O": "1",
    "FASTVIDEO_VALIDATION_ONE_PROMPT_PER_RANK": "1",
    "WANDB_MODE": "online",
    "WANDB_RESUME": "allow"
  },
  "vsa_schedule": {
    "VSA_SPARSITY": 0.9,
    "VSA_INIT_SPARSITY": 0.9,
    "VSA_WARMUP_STEPS": 0,
    "VSA_DECAY_RATE": 0.03,
    "VSA_DECAY_INTERVAL_STEPS": 50,
    "effective_sparsity_from_step_0": 0.9
  },
  "attention_semantics": {
    "selected_backend": "SPARSE_FP4_OURS_P_ATTN",
    "self_attention": {
      "backend_path": "fastvideo/attention/backends/sparse_fp4_ours_p_attn.py",
      "kernel_path": "fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py",
      "tile_size_video": [4, 4, 4],
      "tile_tokens": 64,
      "qkv_quantization": "FP4 fake quantization with STE, no q/k mean subtraction in quantization",
      "block_selection": "top-k blocks from q_c @ k_c tile-mean scores",
      "p_quantization": "group-local exp2(qk - group_max) FP4 fake quantization; compensation multiplies exp2(group_max - running_row_m)",
      "dropped_tile_handling": "tile-level q_mean/k_mean score and mean_v compensation"
    },
    "cross_attention": {
      "backend": "dense_sdpa",
      "reason": "sparse_fp4_ours_p_attn.py treats query_length != key_length as cross attention and returns _dense_sdpa_blhd",
      "quantized": false,
      "sparse": false
    },
    "force_dense": {
      "backend": "dense_sdpa",
      "used_for": "teacher or explicitly forced dense paths, not the normal SFT student self-attention path"
    }
  },
  "validation_and_checkpointing": {
    "save_steps": 50,
    "eval_steps": 50,
    "validation_sampling_steps": 50,
    "validation_guidance_scale": 5.0,
    "checkpoints_total_limit": 5,
    "flow_shift": 1.0
  },
  "training_shape": {
    "num_latent_t": 20,
    "num_frames": 77,
    "height": 448,
    "width": 832,
    "batch_size_per_gpu": 1,
    "sp_size": 1,
    "tp_size": 1
  }
}
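
Note: the `p_quantization` entry compresses a lot. Inside the selected tiles, the Triton kernel forms attention probabilities group-locally as exp2(qk - group_max), fake-quantizes them to FP4, and folds them into the running flash-attention accumulator with an exp2(group_max - running_row_m) correction. A numpy sketch of those numerics for a single query row is below; the e2m1 grid and per-group absmax scaling are assumptions, and the authoritative code is the Triton kernel referenced above.

```python
# Illustrative numerics only; not the kernel implementation.
import numpy as np

FP4_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])  # e2m1 magnitudes (assumed)

def fake_quant_fp4(x: np.ndarray) -> np.ndarray:
    """Fake-quantize non-negative values: scale into the grid, round, rescale."""
    scale = max(float(np.abs(x).max()) / FP4_GRID[-1], 1e-12)
    idx = np.abs(FP4_GRID[None, :] - (x / scale)[:, None]).argmin(axis=1)
    return FP4_GRID[idx] * scale

def online_update(qk_tile, v_tile, m_run, l_run, acc):
    """One selected K/V tile for a single query row, base-2 online softmax."""
    g_max = float(qk_tile.max())          # group-local max over this tile
    p_local = np.exp2(qk_tile - g_max)    # in (0, 1]: a friendly range for FP4
    p_q = fake_quant_fp4(p_local)         # group-local P fake quantization
    m_new = max(m_run, g_max)             # updated running row max
    comp = np.exp2(g_max - m_new)         # exp2(group_max - running_row_m)
    rescale = np.exp2(m_run - m_new)      # shrink the previous accumulator
    acc = acc * rescale + comp * (p_q @ v_tile)
    l_new = l_run * rescale + comp * p_q.sum()
    return m_new, l_new, acc

# Two-tile demo for one query row with head_dim 4:
rng = np.random.default_rng(0)
m, l, acc = -np.inf, 0.0, np.zeros(4)
for _ in range(2):
    m, l, acc = online_update(rng.normal(size=8), rng.normal(size=(8, 4)), m, l, acc)
out = acc / l  # softmax-weighted value for this row
```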