yitongl committed
Commit 697fddf · verified · Parent: 272d402

Add inference code and attention settings for sfp4 checkpoint-750

backend_snapshot/README.md CHANGED
@@ -13,6 +13,17 @@ Key runtime settings:
 - `VSA_WARMUP_STEPS=0`
 - tile size: `4 x 4 x 4 = 64` video tokens
 
+Training attention semantics:
+
+- Video self-attention uses `SPARSE_FP4_OURS_P_ATTN`.
+- Cross-attention is not quantized/sparse in this backend. It falls back to
+  dense SDPA when `query_length != key_length`.
+- `force_dense` paths also use dense SDPA.
+- Q/K/V fake quantization uses FP4 with STE and no q/k mean subtraction.
+- Selected sparse tiles use group-local P quantization in the Triton kernel.
+- Dropped VSA tiles use a tile-level q_mean/k_mean score plus mean_v
+  compensation.
+
 Important files:
 
 - `fastvideo/attention/backends/sparse_fp4_ours_p_attn.py`: Python attention backend, Q/K/V fake quantization, top-k block map, tile mean setup.
@@ -25,6 +36,25 @@ Important files:
 - `fastvideo/training/training_pipeline.py` and `fastvideo/training/wan_training_pipeline.py`: legacy SFT training path used by the launch script.
 - `scripts/training/run_sparse_fp4_train_v4_1n_sparse09_hpo_on_ours_p_init2050_interactive.sh`: exact Slurm wrapper for this run.
 - `scripts/training/run_sparse_fp4_train_v4_common.sh`: common SFT launch/resume script.
+- `training_attention_settings.json`: structured attention/training settings
+  for this checkpoint.
+- `scripts/inference/run_sfp4_ours_p_checkpoint_750.sh`: inference example
+  for the uploaded transformer checkpoint.
+- `fastvideo/entrypoints/cli/generate.py`, `fastvideo/entrypoints/video_generator.py`,
+  `fastvideo/pipelines/basic/wan/wan_pipeline.py`, and
+  `fastvideo/pipelines/stages/denoising.py`: `fastvideo generate` inference
+  path used by the example script.
+
+Example inference flow:
+
+```bash
+hf download yitongl/sparse_quant_exp \
+    --repo-type model \
+    --local-dir checkpoints/hf_download/sparse_quant_exp \
+    --include 'transformer/*'
+
+bash backend_snapshot/scripts/inference/run_sfp4_ours_p_checkpoint_750.sh
+```
 
 Source repo HEAD when staged:
 
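The FP4-with-STE fake quantization named in the semantics list can be pictured with a short sketch. This is a minimal illustration, not the backend kernel in `fastvideo/attention/backends/sparse_fp4_ours_p_attn.py`: the E2M1 value grid, the per-row absmax scale, and the name `fake_quant_fp4_ste` are assumptions made for illustration only.

```python
import torch

# Assumed E2M1 (FP4) representable magnitudes; the backend's grid may differ.
_FP4_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
_FP4_GRID = torch.cat([-_FP4_GRID.flip(0), _FP4_GRID])


def fake_quant_fp4_ste(x: torch.Tensor) -> torch.Tensor:
    """Fake-quantize x to FP4: round to the grid in the forward pass, but let
    gradients pass through unchanged (straight-through estimator)."""
    # Per-row absmax scale (an assumption); no q/k mean subtraction, as noted above.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / 6.0
    grid = _FP4_GRID.to(device=x.device, dtype=x.dtype)
    # Nearest representable FP4 value for each scaled element.
    idx = (x / scale).unsqueeze(-1).sub(grid).abs().argmin(dim=-1)
    x_q = grid[idx] * scale
    # STE: the forward value is x_q, the backward gradient is the identity.
    return x + (x_q - x).detach()


q = torch.randn(2, 8, 64, requires_grad=True)
q_fq = fake_quant_fp4_ste(q)   # used in place of q inside attention
q_fq.sum().backward()          # gradients still reach q despite the rounding
```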
backend_snapshot/examples/inference/basic/basic.py ADDED
@@ -0,0 +1,47 @@
+from fastvideo import VideoGenerator
+
+# from fastvideo.configs.sample import SamplingParam
+
+OUTPUT_PATH = "video_samples"
+def main():
+    # FastVideo will automatically use the optimal default arguments for the
+    # model.
+    # If a local path is provided, FastVideo will make a best-effort
+    # attempt to identify the optimal arguments.
+    generator = VideoGenerator.from_pretrained(
+        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+        # FastVideo will automatically handle distributed setup
+        num_gpus=1,
+        use_fsdp_inference=False,  # set to True if the GPU is out of memory
+        dit_cpu_offload=False,
+        vae_cpu_offload=False,
+        text_encoder_cpu_offload=True,
+        pin_cpu_memory=True,  # set to False if low on CPU RAM or hitting an obscure "CUDA error: Invalid argument"
+        # image_encoder_cpu_offload=False,
+    )
+
+    # sampling_param = SamplingParam.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
+    # sampling_param.num_frames = 45
+    # sampling_param.image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+    # Generate videos with the same simple API, regardless of GPU count.
+    prompt = (
+        "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes "
+        "wide with interest. The playful yet serene atmosphere is complemented by soft "
+        "natural light filtering through the petals. Mid-shot, warm and cheerful tones."
+    )
+    video = generator.generate_video(prompt, output_path=OUTPUT_PATH, save_video=True)
+    # video = generator.generate_video(prompt, sampling_param=sampling_param, output_path="wan_t2v_videos/")
+
+    # Generate another video with a different prompt, without reloading the
+    # model!
+    prompt2 = (
+        "A majestic lion strides across the golden savanna, its powerful frame "
+        "glistening under the warm afternoon sun. The tall grass ripples gently in "
+        "the breeze, enhancing the lion's commanding presence. The tone is vibrant, "
+        "embodying the raw energy of the wild. Low angle, steady tracking shot, "
+        "cinematic.")
+    video2 = generator.generate_video(prompt2, output_path=OUTPUT_PATH, save_video=True)
+
+
+if __name__ == "__main__":
+    main()
backend_snapshot/fastvideo/api/compat.py ADDED
@@ -0,0 +1,503 @@
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+from collections.abc import Mapping
+from copy import deepcopy
+from dataclasses import fields, is_dataclass
+from pathlib import Path
+from typing import Any
+
+from fastvideo.api.overrides import apply_overrides, parse_cli_overrides
+from fastvideo.api.parser import config_to_dict, load_raw_config, parse_config
+from fastvideo.api.schema import (
+    GenerationRequest,
+    GeneratorConfig,
+    InputConfig,
+    OutputConfig,
+    RequestRuntimeConfig,
+    SamplingConfig,
+)
+from fastvideo.configs.sample import SamplingParam
+from fastvideo.fastvideo_args import FastVideoArgs
+from fastvideo.utils import shallow_asdict
+
+_EXPLICIT_REQUEST_ATTR = "_fastvideo_explicit_request"
+_INPUT_FIELD_NAMES = {field.name for field in fields(InputConfig)}
+_SAMPLING_FIELD_NAMES = {field.name for field in fields(SamplingConfig)}
+_RUNTIME_FIELD_NAMES = {field.name for field in fields(RequestRuntimeConfig)}
+_OUTPUT_FIELD_NAMES = {field.name for field in fields(OutputConfig)}
+_MISSING = object()
+_LEGACY_REQUEST_ALIASES = {
+    "neg_prompt": "negative_prompt",
+}
+_REQUEST_PIPELINE_OVERRIDE_FIELDS = frozenset({
+    "embedded_cfg_scale",
+})
+
+
+def normalize_generator_config(config: GeneratorConfig | Mapping[str, Any]) -> GeneratorConfig:
+    if isinstance(config, GeneratorConfig):
+        return config
+    return parse_config(GeneratorConfig, config)
+
+
+def load_generator_config_from_file(
+    path: str | Path,
+    overrides: list[str] | Mapping[str, Any] | None = None,
+) -> GeneratorConfig:
+    raw = load_raw_config(path)
+    normalized_overrides = _normalize_overrides(overrides)
+
+    if _looks_like_run_or_serve_config(raw):
+        if normalized_overrides:
+            raw = apply_overrides(raw, normalized_overrides)
+        return parse_config(GeneratorConfig, raw["generator"])
+
+    if normalized_overrides:
+        adjusted = normalized_overrides
+        if all(key.startswith("generator.") for key in adjusted):
+            adjusted = {key[len("generator."):]: value for key, value in adjusted.items()}
+        raw = apply_overrides(raw, adjusted)
+
+    return parse_config(GeneratorConfig, raw)
+
+
+def legacy_from_pretrained_to_config(
+    model_path: str,
+    kwargs: Mapping[str, Any],
+) -> GeneratorConfig:
+    raw: dict[str, Any] = {"model_path": model_path}
+    engine: dict[str, Any] = {}
+    parallelism: dict[str, Any] = {}
+    offload: dict[str, Any] = {}
+    compile_config: dict[str, Any] = {}
+    pipeline: dict[str, Any] = {}
+    components: dict[str, Any] = {}
+    quantization: dict[str, Any] = {}
+    experimental: dict[str, Any] = {}
+
+    for key, value in kwargs.items():
+        if key == "revision":
+            raw["revision"] = value
+        elif key == "trust_remote_code":
+            raw["trust_remote_code"] = value
+        elif key == "num_gpus":
+            engine["num_gpus"] = value
+        elif key == "distributed_executor_backend":
+            engine["execution_backend"] = value
+        elif key in {"tp_size", "sp_size", "hsdp_replicate_dim", "hsdp_shard_dim", "dist_timeout"}:
+            parallelism[key] = value
+        elif key == "dit_cpu_offload":
+            offload["dit"] = value
+        elif key == "dit_layerwise_offload":
+            offload["dit_layerwise"] = value
+        elif key == "text_encoder_cpu_offload":
+            offload["text_encoder"] = value
+        elif key == "image_encoder_cpu_offload":
+            offload["image_encoder"] = value
+        elif key == "vae_cpu_offload":
+            offload["vae"] = value
+        elif key == "pin_cpu_memory":
+            offload["pin_cpu_memory"] = value
+        elif key == "enable_torch_compile":
+            compile_config["enabled"] = value
+        elif key == "torch_compile_kwargs":
+            compile_config["kwargs"] = deepcopy(value)
+        elif key in {"enable_stage_verification", "use_fsdp_inference", "disable_autocast"}:
+            engine[key] = value
+        elif key == "override_text_encoder_quant":
+            quantization["text_encoder_quant"] = value
+        elif key == "transformer_quant":
+            quantization["transformer_quant"] = value
+        elif key == "workload_type":
+            pipeline["workload_type"] = value
+        elif key == "lora_path":
+            components["lora_path"] = value
+        elif key == "override_pipeline_cls_name":
+            components["override_pipeline_cls_name"] = value
+        elif key == "override_transformer_cls_name":
+            components["override_transformer_cls_name"] = value
+        elif key == "pipeline_config":
+            if isinstance(value, str):
+                components["pipeline_config_path"] = value
+            else:
+                experimental[key] = deepcopy(value)
+        elif key == "override_text_encoder_safetensors":
+            components["text_encoder_weights"] = value
+        elif key == "init_weights_from_safetensors":
+            components["transformer_weights"] = value
+        elif key == "init_weights_from_safetensors_2":
+            components["transformer_2_weights"] = value
+        else:
+            experimental[key] = deepcopy(value)
+
+    if parallelism:
+        engine["parallelism"] = parallelism
+    if offload:
+        engine["offload"] = offload
+    if compile_config:
+        engine["compile"] = compile_config
+    if quantization:
+        engine["quantization"] = quantization
+    if engine:
+        raw["engine"] = engine
+
+    if components:
+        pipeline["components"] = components
+    if experimental:
+        pipeline["experimental"] = experimental
+    if pipeline:
+        raw["pipeline"] = pipeline
+
+    return parse_config(GeneratorConfig, raw)
+
+
+def generator_config_to_fastvideo_args(config: GeneratorConfig | Mapping[str, Any]) -> FastVideoArgs:
+    normalized = normalize_generator_config(config)
+    unsupported = []
+    if normalized.pipeline.profile is not None:
+        unsupported.append("pipeline.profile")
+    if normalized.pipeline.profile_version is not None:
+        unsupported.append("pipeline.profile_version")
+    if normalized.pipeline.components.config_root is not None:
+        unsupported.append("pipeline.components.config_root")
+    if normalized.pipeline.components.vae_weights is not None:
+        unsupported.append("pipeline.components.vae_weights")
+    if normalized.pipeline.components.upsampler_weights is not None:
+        unsupported.append("pipeline.components.upsampler_weights")
+    if unsupported:
+        joined = ", ".join(unsupported)
+        raise NotImplementedError(f"VideoGenerator compatibility adapter does not support {joined} yet")
+
+    engine = normalized.engine
+    kwargs: dict[str, Any] = {
+        "model_path": normalized.model_path,
+        "revision": normalized.revision,
+        "trust_remote_code": normalized.trust_remote_code,
+        "num_gpus": engine.num_gpus,
+        "distributed_executor_backend": engine.execution_backend,
+        "tp_size": engine.parallelism.tp_size,
+        "sp_size": engine.parallelism.sp_size,
+        "hsdp_replicate_dim": engine.parallelism.hsdp_replicate_dim,
+        "hsdp_shard_dim": engine.parallelism.hsdp_shard_dim,
+        "dist_timeout": engine.parallelism.dist_timeout,
+        "dit_cpu_offload": engine.offload.dit,
+        "dit_layerwise_offload": engine.offload.dit_layerwise,
+        "text_encoder_cpu_offload": engine.offload.text_encoder,
+        "image_encoder_cpu_offload": engine.offload.image_encoder,
+        "vae_cpu_offload": engine.offload.vae,
+        "pin_cpu_memory": engine.offload.pin_cpu_memory,
+        "enable_torch_compile": engine.compile.enabled,
+        "torch_compile_kwargs": deepcopy(engine.compile.kwargs),
+        "enable_stage_verification": engine.enable_stage_verification,
+        "use_fsdp_inference": engine.use_fsdp_inference,
+        "disable_autocast": engine.disable_autocast,
+    }
+    if normalized.pipeline.workload_type is not None:
+        kwargs["workload_type"] = normalized.pipeline.workload_type
+
+    quantization = engine.quantization
+    if quantization is not None and quantization.text_encoder_quant is not None:
+        kwargs["override_text_encoder_quant"] = quantization.text_encoder_quant
+    if quantization is not None and quantization.transformer_quant is not None:
+        kwargs["transformer_quant"] = quantization.transformer_quant
+
+    components = normalized.pipeline.components
+    if components.pipeline_config_path is not None:
+        kwargs["pipeline_config"] = components.pipeline_config_path
+    if components.lora_path is not None:
+        kwargs["lora_path"] = components.lora_path
+    if components.override_pipeline_cls_name is not None:
+        kwargs["override_pipeline_cls_name"] = components.override_pipeline_cls_name
+    if components.override_transformer_cls_name is not None:
+        kwargs["override_transformer_cls_name"] = components.override_transformer_cls_name
+    if components.text_encoder_weights is not None:
+        kwargs["override_text_encoder_safetensors"] = components.text_encoder_weights
+    if components.transformer_weights is not None:
+        kwargs["init_weights_from_safetensors"] = components.transformer_weights
+    if components.transformer_2_weights is not None:
+        kwargs["init_weights_from_safetensors_2"] = components.transformer_2_weights
+
+    kwargs.update(deepcopy(normalized.pipeline.profile_overrides))
+    kwargs.update(deepcopy(normalized.pipeline.experimental))
+    return FastVideoArgs.from_kwargs(**kwargs)
+
+
+def normalize_generation_request(request: GenerationRequest | Mapping[str, Any]) -> GenerationRequest:
+    normalized = (request if isinstance(request, GenerationRequest) else parse_config(GenerationRequest, request))
+
+    if not hasattr(normalized, _EXPLICIT_REQUEST_ATTR):
+        setattr(normalized, _EXPLICIT_REQUEST_ATTR, _serialize_generation_request(normalized))
+    return normalized
+
+
+def legacy_generate_call_to_request(
+    prompt: str | None,
+    sampling_param: SamplingParam | None,
+    *,
+    mouse_cond: Any | None = None,
+    keyboard_cond: Any | None = None,
+    grid_sizes: Any | None = None,
+    legacy_kwargs: Mapping[str, Any] | None = None,
+) -> GenerationRequest:
+    raw = _sampling_param_to_request_raw(sampling_param)
+    if prompt is not None:
+        raw["prompt"] = prompt
+
+    for key, value in (legacy_kwargs or {}).items():
+        _apply_request_field(raw, key, value)
+
+    if mouse_cond is not None:
+        raw.setdefault("inputs", {})["mouse_cond"] = mouse_cond
+    if keyboard_cond is not None:
+        raw.setdefault("inputs", {})["keyboard_cond"] = keyboard_cond
+    if grid_sizes is not None:
+        raw.setdefault("inputs", {})["grid_sizes"] = grid_sizes
+
+    normalized = parse_config(GenerationRequest, raw)
+    setattr(normalized, _EXPLICIT_REQUEST_ATTR, deepcopy(raw))
+    return normalized
+
+
+def request_to_sampling_param(
+    request: GenerationRequest,
+    *,
+    model_path: str,
+) -> SamplingParam:
+    if request.plan is not None:
+        raise NotImplementedError("GenerationRequest.plan is not wired into VideoGenerator yet")
+    if request.state is not None:
+        raise NotImplementedError("GenerationRequest.state is not wired into VideoGenerator yet")
+
+    sampling_param = SamplingParam.from_pretrained(model_path)
+    updates = _explicit_request_updates(request)
+
+    for key, value in updates.items():
+        if hasattr(sampling_param, key):
+            setattr(sampling_param, key, deepcopy(value))
+        elif key in _REQUEST_PIPELINE_OVERRIDE_FIELDS or _is_supported_as_default_only(key, value):
+            continue
+        else:
+            raise ValueError(f"Request field {key!r} is not supported by sampling params for {model_path}")
+
+    sampling_param.__post_init__()
+    sampling_param.check_sampling_param()
+    return sampling_param
+
+
+def expand_request_prompt_batch(request: GenerationRequest) -> list[GenerationRequest]:
+    if not isinstance(request.prompt, list):
+        return [request]
+
+    requests: list[GenerationRequest] = []
+    for index, prompt in enumerate(request.prompt):
+        single_request = deepcopy(request)
+        single_request.prompt = prompt
+        _fan_out_batched_input_value(request, single_request, "image_path", index)
+        _fan_out_batched_input_value(request, single_request, "video_path", index)
+        _fan_out_explicit_request_metadata(request, single_request, index, prompt)
+        requests.append(single_request)
+    return requests
+
+
+def _looks_like_run_or_serve_config(raw: Mapping[str, Any]) -> bool:
+    return isinstance(raw.get("generator"), Mapping)
+
+
+def _normalize_overrides(overrides: list[str] | Mapping[str, Any] | None) -> dict[str, Any] | None:
+    if not overrides:
+        return None
+    if isinstance(overrides, list):
+        return parse_cli_overrides(overrides)
+    return dict(overrides)
+
+
+def _sampling_param_to_request_raw(sampling_param: SamplingParam | None) -> dict[str, Any]:
+    if sampling_param is None:
+        return {}
+
+    raw: dict[str, Any] = {}
+    for key, value in shallow_asdict(sampling_param).items():
+        if key == "prompt":
+            continue
+        _apply_request_field(raw, key, deepcopy(value))
+    return raw
+
+
+def _apply_request_field(
+    raw: dict[str, Any],
+    key: str,
+    value: Any,
+) -> None:
+    key = _LEGACY_REQUEST_ALIASES.get(key, key)
+    if key == "negative_prompt":
+        raw["negative_prompt"] = value
+        return
+    if key in _INPUT_FIELD_NAMES:
+        raw.setdefault("inputs", {})[key] = value
+        return
+    if key in _SAMPLING_FIELD_NAMES:
+        raw.setdefault("sampling", {})[key] = value
+        return
+    if key in _RUNTIME_FIELD_NAMES:
+        raw.setdefault("runtime", {})[key] = value
+        return
+    if key in _OUTPUT_FIELD_NAMES:
+        raw.setdefault("output", {})[key] = value
+        return
+    raw.setdefault("extensions", {})[key] = value
+
+
+def request_to_pipeline_overrides(request: GenerationRequest) -> dict[str, Any]:
+    overrides: dict[str, Any] = {}
+    for key, value in _explicit_request_updates(request).items():
+        if key in _REQUEST_PIPELINE_OVERRIDE_FIELDS:
+            overrides[key] = deepcopy(value)
+    return overrides
+
+
+def _explicit_request_updates(request: GenerationRequest) -> dict[str, Any]:
+    raw = getattr(request, _EXPLICIT_REQUEST_ATTR, None)
+    if raw is None:
+        raw = _serialize_generation_request(request)
+
+    return _extract_request_updates(raw)
+
+
+def _extract_request_updates(raw: Mapping[str, Any]) -> dict[str, Any]:
+    updates: dict[str, Any] = {}
+    if "negative_prompt" in raw:
+        updates["negative_prompt"] = deepcopy(raw["negative_prompt"])
+
+    for section_name in ("inputs", "sampling", "runtime", "output"):
+        section = raw.get(section_name)
+        if not isinstance(section, Mapping):
+            continue
+        for key, value in section.items():
+            updates[key] = deepcopy(value)
+
+    stage_overrides = raw.get("stage_overrides")
+    if stage_overrides:
+        updates.update(_flatten_stage_overrides(stage_overrides))
+
+    extensions = raw.get("extensions")
+    if isinstance(extensions, Mapping):
+        for key, value in extensions.items():
+            updates[key] = deepcopy(value)
+
+    return updates
+
+
+def _flatten_stage_overrides(stage_overrides: Any) -> dict[str, Any]:
+    if not isinstance(stage_overrides, Mapping):
+        raise ValueError("GenerationRequest.stage_overrides must be a mapping")
+
+    flattened: dict[str, Any] = {}
+    for stage_name, overrides in stage_overrides.items():
+        if not isinstance(overrides, Mapping):
+            raise ValueError(f"GenerationRequest.stage_overrides.{stage_name} must be a mapping")
+        for key, value in overrides.items():
+            if key in flattened and flattened[key] != value:
+                raise ValueError(f"Conflicting stage override for {key!r} across stages")
+            flattened[key] = deepcopy(value)
+    return flattened
+
+
+def _serialize_generation_request(request: GenerationRequest) -> dict[str, Any]:
+    return deepcopy(config_to_dict(request))
+
+
+def _fan_out_batched_input_value(
+    source_request: GenerationRequest,
+    target_request: GenerationRequest,
+    field_name: str,
+    index: int,
+) -> None:
+    value = getattr(source_request.inputs, field_name)
+    if not isinstance(value, list):
+        return
+    _validate_batched_input_length(source_request.prompt, value, field_name)
+    setattr(target_request.inputs, field_name, deepcopy(value[index]))
+
+
+def _fan_out_explicit_request_metadata(
+    source_request: GenerationRequest,
+    target_request: GenerationRequest,
+    index: int,
+    prompt: str,
+) -> None:
+    raw = getattr(source_request, _EXPLICIT_REQUEST_ATTR, None)
+    if raw is None:
+        return
+
+    raw = deepcopy(raw)
+    raw["prompt"] = prompt
+    inputs = raw.get("inputs")
+    if isinstance(inputs, dict):
+        for field_name in ("image_path", "video_path"):
+            value = inputs.get(field_name)
+            if isinstance(value, list):
+                _validate_batched_input_length(source_request.prompt, value, field_name)
+                inputs[field_name] = deepcopy(value[index])
+
+    setattr(target_request, _EXPLICIT_REQUEST_ATTR, raw)
+
+
+def _validate_batched_input_length(
+    prompts: str | list[str] | None,
+    values: list[Any],
+    field_name: str,
+) -> None:
+    if not isinstance(prompts, list):
+        return
+    if len(values) != len(prompts):
+        raise ValueError(f"GenerationRequest.inputs.{field_name} must have the same length as request.prompt")
+
+
+def _is_supported_as_default_only(key: str, value: Any) -> bool:
+    default_value = _DEFAULT_REQUEST_UPDATES.get(key, _MISSING)
+    return default_value is not _MISSING and _values_equal(value, default_value)
+
+
+def _collect_non_default_fields(
+    value: Any,
+    default: Any,
+) -> dict[str, Any]:
+    if not (is_dataclass(value) and is_dataclass(default)):
+        return {}
+
+    result: dict[str, Any] = {}
+    for field in fields(value):
+        current = getattr(value, field.name)
+        default_value = getattr(default, field.name)
+        if is_dataclass(current) and is_dataclass(default_value):
+            nested = _collect_non_default_fields(current, default_value)
+            if nested:
+                result[field.name] = nested
+            continue
+        if not _values_equal(current, default_value):
+            result[field.name] = deepcopy(current)
+    return result
+
+
+def _values_equal(left: Any, right: Any) -> bool:
+    if left is right:
+        return True
+    try:
+        return bool(left == right)
+    except Exception:
+        return False
+
+
+_DEFAULT_REQUEST_UPDATES = _extract_request_updates(config_to_dict(GenerationRequest()))
+
+__all__ = [
+    "generator_config_to_fastvideo_args",
+    "legacy_from_pretrained_to_config",
+    "legacy_generate_call_to_request",
+    "load_generator_config_from_file",
+    "normalize_generation_request",
+    "normalize_generator_config",
+    "request_to_pipeline_overrides",
+    "request_to_sampling_param",
+]
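As a usage sketch of the adapter above: flat legacy kwargs, as passed to `VideoGenerator.from_pretrained(...)`, are routed into nested sections (`engine`, `engine.offload`, `pipeline`, and so on) and can be lowered back to `FastVideoArgs`. This assumes the `fastvideo.api.schema` dataclasses expose the attributes the adapter reads; those schema files are not part of this commit.

```python
from fastvideo.api.compat import (
    generator_config_to_fastvideo_args,
    legacy_from_pretrained_to_config,
)

# Flat kwargs are sorted into nested config sections by the elif chain above.
config = legacy_from_pretrained_to_config(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    {"num_gpus": 1, "text_encoder_cpu_offload": True, "pin_cpu_memory": True},
)
assert config.engine.num_gpus == 1                 # -> engine section
assert config.engine.offload.text_encoder is True  # -> engine.offload section

# The nested config is then lowered back to flat FastVideoArgs for the workers.
args = generator_config_to_fastvideo_args(config)
```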
backend_snapshot/fastvideo/configs/pipelines/wan.py ADDED
@@ -0,0 +1,203 @@
+# SPDX-License-Identifier: Apache-2.0
+from collections.abc import Callable
+from dataclasses import dataclass, field
+
+import torch
+
+from fastvideo.configs.models import DiTConfig, EncoderConfig, VAEConfig
+from fastvideo.configs.models.dits import WanVideoConfig
+from fastvideo.configs.models.dits.matrixgame import MatrixGameWanVideoConfig
+from fastvideo.configs.models.encoders import (BaseEncoderOutput, CLIPVisionConfig, T5Config,
+                                               WAN2_1ControlCLIPVisionConfig)
+from fastvideo.configs.models.vaes import WanVAEConfig
+from fastvideo.configs.pipelines.base import PipelineConfig
+
+
+def t5_postprocess_text(outputs: BaseEncoderOutput) -> torch.Tensor:
+    mask: torch.Tensor = outputs.attention_mask
+    hidden_state: torch.Tensor = outputs.last_hidden_state
+    seq_lens = mask.gt(0).sum(dim=1).long()
+    assert torch.isnan(hidden_state).sum() == 0
+    prompt_embeds = [u[:v] for u, v in zip(hidden_state, seq_lens, strict=True)]
+    prompt_embeds_tensor: torch.Tensor = torch.stack(
+        [torch.cat([u, u.new_zeros(512 - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0)
+    return prompt_embeds_tensor
+
+
+@dataclass
+class WanT2V480PConfig(PipelineConfig):
+    """Base configuration for the Wan T2V 1.3B pipeline architecture."""
+
+    # WanConfig-specific parameters with defaults
+    # DiT
+    dit_config: DiTConfig = field(default_factory=WanVideoConfig)
+    # VAE
+    vae_config: VAEConfig = field(default_factory=WanVAEConfig)
+    vae_tiling: bool = False
+    vae_sp: bool = False
+
+    # Denoising stage
+    flow_shift: float | None = 3.0
+
+    # Text encoding stage
+    text_encoder_configs: tuple[EncoderConfig, ...] = field(default_factory=lambda: (T5Config(), ))
+    postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor],
+                                  ...] = field(default_factory=lambda: (t5_postprocess_text, ))
+
+    # Precision for each component
+    precision: str = "bf16"
+    vae_precision: str = "fp32"
+    text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32", ))
+
+    # Self-forcing params
+    warp_denoising_step: bool = True
+
+    def __post_init__(self):
+        self.vae_config.load_encoder = False
+        self.vae_config.load_decoder = True
+
+
+@dataclass
+class WanT2V720PConfig(WanT2V480PConfig):
+    """Base configuration for the Wan T2V 14B 720P pipeline architecture."""
+
+    # Denoising stage
+    flow_shift: float | None = 5.0
+
+
+@dataclass
+class WanI2V480PConfig(WanT2V480PConfig):
+    """Base configuration for the Wan I2V 14B 480P pipeline architecture."""
+
+    # Image encoder and its precision
+    image_encoder_config: EncoderConfig = field(default_factory=CLIPVisionConfig)
+    image_encoder_precision: str = "fp32"
+
+    def __post_init__(self) -> None:
+        self.vae_config.load_encoder = True
+        self.vae_config.load_decoder = True
+
+
+@dataclass
+class WanI2V720PConfig(WanI2V480PConfig):
+    """Base configuration for the Wan I2V 14B 720P pipeline architecture."""
+
+    # Denoising stage
+    flow_shift: float | None = 5.0
+
+
+@dataclass
+class WANV2VConfig(WanI2V480PConfig):
+    """Configuration for the WAN2.1 1.3B Control pipeline."""
+
+    image_encoder_config: EncoderConfig = field(default_factory=WAN2_1ControlCLIPVisionConfig)
+    # CLIP encoder precision
+    image_encoder_precision: str = 'bf16'
+
+
+@dataclass
+class FastWan2_1_T2V_480P_Config(WanT2V480PConfig):
+    """Base configuration for the FastWan T2V 1.3B 480P pipeline architecture with DMD."""
+
+    # Denoising stage
+    flow_shift: float | None = 8.0
+    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 757, 522])
+
+
+@dataclass
+class Wan2_2_TI2V_5B_Config(WanT2V480PConfig):
+    flow_shift: float | None = 5.0
+    ti2v_task: bool = True
+    expand_timesteps: bool = True
+
+    def __post_init__(self) -> None:
+        self.vae_config.load_encoder = True
+        self.vae_config.load_decoder = True
+        self.dit_config.expand_timesteps = self.expand_timesteps
+
+
+@dataclass
+class FastWan2_2_TI2V_5B_Config(Wan2_2_TI2V_5B_Config):
+    flow_shift: float | None = 5.0
+    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 757, 522])
+
+
+@dataclass
+class Wan2_2_T2V_A14B_Config(WanT2V480PConfig):
+    flow_shift: float | None = 12.0
+    boundary_ratio: float | None = 0.875
+
+    # Self-forcing params
+    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 750, 500, 250])
+    warp_denoising_step: bool = True
+
+    def __post_init__(self) -> None:
+        self.dit_config.boundary_ratio = self.boundary_ratio
+
+
+@dataclass
+class Wan2_2_I2V_A14B_Config(WanI2V480PConfig):
+    flow_shift: float | None = 5.0
+    boundary_ratio: float | None = 0.900
+
+    def __post_init__(self) -> None:
+        super().__post_init__()
+        self.dit_config.boundary_ratio = self.boundary_ratio
+
+
+# =============================================
+# ============ Causal Self-Forcing ============
+# =============================================
+@dataclass
+class SelfForcingWanT2V480PConfig(WanT2V480PConfig):
+    is_causal: bool = True
+    flow_shift: float | None = 5.0
+    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 750, 500, 250])
+    warp_denoising_step: bool = True
+
+
+@dataclass
+class SelfForcingWan2_2_T2V480PConfig(Wan2_2_T2V_A14B_Config):
+    is_causal: bool = True
+    flow_shift: float | None = 12.0
+    boundary_ratio: float | None = 0.875
+    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 850, 700, 550, 350, 275, 200, 125])
+    warp_denoising_step: bool = True
+
+    def __post_init__(self) -> None:
+        self.vae_config.load_encoder = True
+        self.vae_config.load_decoder = True
+
+
+# =============================================
+# =============== Matrix Game =================
+# =============================================
+@dataclass
+class MatrixGameBaseI2V480PConfig(WanI2V480PConfig):
+    dit_config: DiTConfig = field(default_factory=MatrixGameWanVideoConfig)
+    flow_shift: float | None = 5.0
+
+
+@dataclass
+class MatrixGameI2V480PConfig(WanI2V480PConfig):
+    dit_config: DiTConfig = field(default_factory=MatrixGameWanVideoConfig)
+
+    image_encoder_config: EncoderConfig = field(default_factory=WAN2_1ControlCLIPVisionConfig)
+
+    is_causal: bool = True
+    flow_shift: float | None = 5.0
+    dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 666, 333])
+    warp_denoising_step: bool = True
+    context_noise: int = 0
+    num_frames_per_block: int = 3
+    # sliding_window_num_frames: int = 15
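The inheritance pattern above means each subclass only overrides the defaults it changes, while `__post_init__` flips the VAE loading flags. A small sanity sketch, assuming the module imports as `fastvideo.configs.pipelines.wan` and that `WanVAEConfig` exposes the `load_encoder`/`load_decoder` attributes the code above sets:

```python
from fastvideo.configs.pipelines.wan import WanI2V480PConfig, WanT2V480PConfig

t2v = WanT2V480PConfig()
assert t2v.flow_shift == 3.0
# T2V only decodes latents, so its __post_init__ disables the VAE encoder.
assert t2v.vae_config.load_encoder is False

i2v = WanI2V480PConfig()
assert i2v.flow_shift == 3.0  # inherited from the 480P T2V base
# I2V must encode the conditioning image, so its __post_init__ re-enables it.
assert i2v.vae_config.load_encoder is True
```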
backend_snapshot/fastvideo/configs/sample/base.py ADDED
@@ -0,0 +1,292 @@
+# SPDX-License-Identifier: Apache-2.0
+from dataclasses import dataclass
+from typing import Any
+
+from fastvideo.logger import init_logger
+from fastvideo.utils import StoreBoolean
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class SamplingParam:
+    """
+    Sampling parameters for video generation.
+    """
+    # All fields below are copied from ForwardBatch
+    data_type: str = "video"
+
+    # Image inputs
+    image_path: str | None = None
+    pil_image: Any | None = None
+
+    # Video inputs
+    video_path: str | None = None
+
+    # Action control inputs (Matrix-Game)
+    mouse_cond: Any | None = None  # Shape: (B, T, 2)
+    keyboard_cond: Any | None = None  # Shape: (B, T, K)
+    grid_sizes: Any | None = None  # Shape: (3,) [F, H, W]
+
+    # Camera control inputs (HYWorld)
+    pose: str | None = None  # Camera trajectory: pose string (e.g. 'w-31') or JSON file path
+
+    # Camera control inputs (LingBotWorld)
+    c2ws_plucker_emb: Any | None = None  # Plucker embedding: [B, C, F_lat, H_lat, W_lat]
+
+    # Refine inputs (LongCat 480p -> 720p upscaling)
+    # Path-based refine (load the stage-1 video from disk, e.g. an MP4)
+    refine_from: str | None = None  # Path to the stage-1 video (480p output from distill)
+    t_thresh: float = 0.5  # Threshold for timestep scheduling in refinement
+    spatial_refine_only: bool = False  # If True, only spatial upscaling (no temporal doubling)
+    num_cond_frames: int = 0  # Number of conditioning frames
+    # In-memory refine input (for the two-stage pipeline where stage-1 frames are
+    # already in memory). This mirrors LongCat's demo, where a list of frames
+    # (e.g. np.ndarray or PIL.Image) is passed directly to the refinement
+    # pipeline instead of being reloaded from disk.
+    stage1_video: Any | None = None
+
+    # Text inputs
+    prompt: str | list[str] | None = None
+    negative_prompt: str | None = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+    prompt_path: str | None = None
+    output_path: str = "outputs/"
+    output_video_name: str | None = None
+
+    # Batch info
+    num_videos_per_prompt: int = 1
+    seed: int = 1024
+
+    # Original dimensions (before VAE scaling)
+    num_frames: int = 125
+    height: int = 720
+    width: int = 1280
+    height_sr: int = 1072
+    width_sr: int = 1920
+    fps: int = 24
+
+    # Denoising parameters
+    num_inference_steps: int = 50
+    num_inference_steps_sr: int = 50
+    guidance_scale: float = 1.0
+    guidance_scale_2: float | None = None
+    guidance_rescale: float = 0.0
+    boundary_ratio: float | None = None
+    sigmas: list[float] | None = None
+
+    # TeaCache parameters
+    enable_teacache: bool = False
+
+    # GEN3C camera control
+    trajectory_type: str | None = None
+    movement_distance: float | None = None
+    camera_rotation: str | None = None
+
+    # Misc
+    save_video: bool = True
+    return_frames: bool = True
+    return_trajectory_latents: bool = False  # return all latents for each timestep
+    return_trajectory_decoded: bool = False  # return decoded latents for each timestep
+
+    def __post_init__(self) -> None:
+        self.data_type = "video" if self.num_frames > 1 else "image"
+
+    def __getattr__(self, name: str) -> Any:
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+    def check_sampling_param(self) -> None:
+        if self.prompt_path and not self.prompt_path.endswith(".txt"):
+            raise ValueError("prompt_path must be a txt file")
+
+    def update(self, source_dict: dict[str, Any]) -> None:
+        for key, value in source_dict.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+            else:
+                logger.exception("%s has no attribute %s", type(self).__name__, key)
+
+        self.__post_init__()
+
+    @classmethod
+    def from_pretrained(cls, model_path: str) -> "SamplingParam":
+        from fastvideo.registry import get_sampling_param_cls_for_name
+        sampling_cls = get_sampling_param_cls_for_name(model_path)
+        if sampling_cls is not None:
+            sampling_param: SamplingParam = sampling_cls()
+        else:
+            logger.warning("Couldn't find an optimal sampling param for %s. Using the default sampling param.",
+                           model_path)
+            sampling_param = cls()
+
+        return sampling_param
+
+    @staticmethod
+    def add_cli_args(parser: Any) -> Any:
+        """Add CLI arguments for SamplingParam fields."""
+        parser.add_argument(
+            "--prompt",
+            type=str,
+            default=SamplingParam.prompt,
+            help="Text prompt for video generation",
+        )
+        parser.add_argument(
+            "--negative-prompt",
+            type=str,
+            default=SamplingParam.negative_prompt,
+            help="Negative text prompt for video generation",
+        )
+        parser.add_argument(
+            "--prompt-path",
+            type=str,
+            default=SamplingParam.prompt_path,
+            help="Path to a text file containing the prompt",
+        )
+        parser.add_argument(
+            "--output-path",
+            type=str,
+            default=SamplingParam.output_path,
+            help="Path to save the generated video",
+        )
+        parser.add_argument(
+            "--output-video-name",
+            type=str,
+            default=SamplingParam.output_video_name,
+            help="Name of the output video",
+        )
+        parser.add_argument(
+            "--num-videos-per-prompt",
+            type=int,
+            default=SamplingParam.num_videos_per_prompt,
+            help="Number of videos to generate per prompt",
+        )
+        parser.add_argument(
+            "--seed",
+            type=int,
+            default=SamplingParam.seed,
+            help="Random seed for generation",
+        )
+        parser.add_argument(
+            "--num-frames",
+            type=int,
+            default=SamplingParam.num_frames,
+            help="Number of frames to generate",
+        )
+        parser.add_argument(
+            "--height",
+            type=int,
+            default=SamplingParam.height,
+            help="Height of generated video",
+        )
+        parser.add_argument(
+            "--width",
+            type=int,
+            default=SamplingParam.width,
+            help="Width of generated video",
+        )
+        parser.add_argument(
+            "--fps",
+            type=int,
+            default=SamplingParam.fps,
+            help="Frames per second for saved video",
+        )
+        parser.add_argument(
+            "--num-inference-steps",
+            type=int,
+            default=SamplingParam.num_inference_steps,
+            help="Number of denoising steps",
+        )
+        parser.add_argument(
+            "--guidance-scale",
+            type=float,
+            default=SamplingParam.guidance_scale,
+            help="Classifier-free guidance scale",
+        )
+        parser.add_argument(
+            "--guidance-rescale",
+            type=float,
+            default=SamplingParam.guidance_rescale,
+            help="Guidance rescale factor",
+        )
+        parser.add_argument(
+            "--boundary-ratio",
+            type=float,
+            default=SamplingParam.boundary_ratio,
+            help="Boundary timestep ratio",
+        )
+        parser.add_argument(
+            "--save-video",
+            action="store_true",
+            default=SamplingParam.save_video,
+            help="Whether to save the video to disk",
+        )
+        parser.add_argument(
+            "--no-save-video",
+            action="store_false",
+            dest="save_video",
+            help="Don't save the video to disk",
+        )
+        parser.add_argument(
+            "--return-frames",
+            action="store_true",
+            default=False,
+            help="Whether to return the raw frames",
+        )
+        parser.add_argument(
+            "--image-path",
+            type=str,
+            default=SamplingParam.image_path,
+            help="Path to input image for image-to-video generation",
+        )
+        parser.add_argument(
+            "--video-path",
+            type=str,
+            default=SamplingParam.video_path,
+            help="Path to input video for video-to-video generation",
+        )
+        parser.add_argument(
+            "--refine-from",
+            type=str,
+            default=SamplingParam.refine_from,
+            help="Path to stage1 video for refinement (LongCat 480p->720p)",
+        )
+        parser.add_argument(
+            "--t-thresh",
+            type=float,
+            default=SamplingParam.t_thresh,
+            help="Threshold for timestep scheduling in refinement (default: 0.5)",
+        )
+        parser.add_argument(
+            "--spatial-refine-only",
+            action=StoreBoolean,
+            default=SamplingParam.spatial_refine_only,
+            help="Only perform spatial super-resolution (no temporal doubling)",
+        )
+        parser.add_argument(
+            "--num-cond-frames",
+            type=int,
+            default=SamplingParam.num_cond_frames,
+            help="Number of conditioning frames for refinement",
+        )
+        parser.add_argument(
+            "--moba-config-path",
+            type=str,
+            default=None,
+            help="Path to a JSON file containing V-MoBA specific configurations.",
+        )
+        parser.add_argument(
+            "--return-trajectory-latents",
+            action="store_true",
+            default=SamplingParam.return_trajectory_latents,
+            help="Whether to return the trajectory",
+        )
+        parser.add_argument(
+            "--return-trajectory-decoded",
+            action="store_true",
+            default=SamplingParam.return_trajectory_decoded,
+            help="Whether to return the decoded trajectory",
+        )
+        return parser
+
+
+@dataclass
+class CacheParams:
+    cache_type: str = "none"
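A quick sketch of how the helpers above compose. Note that `from_pretrained` depends on `fastvideo.registry` (not in this commit), so for an unregistered path it simply warns and returns the defaults; the sketch therefore constructs the dataclass directly.

```python
from fastvideo.configs.sample.base import SamplingParam

param = SamplingParam()              # defaults: num_frames=125 -> data_type "video"
param.update({"num_frames": 1, "seed": 42})
# update() re-runs __post_init__, so a single frame flips the data type.
assert param.data_type == "image"

param.prompt_path = "prompts.txt"
param.check_sampling_param()         # passes; a non-.txt path would raise ValueError
```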
backend_snapshot/fastvideo/configs/sample/wan.py ADDED
@@ -0,0 +1,154 @@
+# SPDX-License-Identifier: Apache-2.0
+from dataclasses import dataclass
+
+from fastvideo.configs.sample.base import SamplingParam
+
+
+@dataclass
+class WanT2V_1_3B_SamplingParam(SamplingParam):
+    # Video parameters
+    height: int = 480
+    width: int = 832
+    num_frames: int = 81
+    fps: int = 16
+
+    # Denoising stage
+    guidance_scale: float = 3.0
+    negative_prompt: str = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+    num_inference_steps: int = 50
+
+
+@dataclass
+class WanT2V_14B_SamplingParam(SamplingParam):
+    # Video parameters
+    height: int = 720
+    width: int = 1280
+    num_frames: int = 81
+    fps: int = 16
+
+    # Denoising stage
+    guidance_scale: float = 5.0
+    negative_prompt: str = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+    num_inference_steps: int = 50
+
+
+@dataclass
+class WanI2V_14B_480P_SamplingParam(WanT2V_1_3B_SamplingParam):
+    # Denoising stage
+    guidance_scale: float = 5.0
+    num_inference_steps: int = 40
+
+
+@dataclass
+class WanI2V_14B_720P_SamplingParam(WanT2V_14B_SamplingParam):
+    # Denoising stage
+    guidance_scale: float = 5.0
+    num_inference_steps: int = 40
+
+
+@dataclass
+class FastWanT2V480P_SamplingParam(WanT2V_1_3B_SamplingParam):
+    # DMD parameters
+    # dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 757, 522])
+    num_inference_steps: int = 3
+    num_frames: int = 61
+    height: int = 448
+    width: int = 832
+    fps: int = 16
+
+
+# =============================================
+# ============= Wan2.1 Fun Models =============
+# =============================================
+@dataclass
+class Wan2_1_Fun_1_3B_InP_SamplingParam(SamplingParam):
+    """Sampling parameters for the Wan2.1 Fun 1.3B InP model."""
+    height: int = 480
+    width: int = 832
+    num_frames: int = 81
+    fps: int = 16
+    negative_prompt: str | None = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+    guidance_scale: float = 6.0
+    num_inference_steps: int = 50
+
+
+@dataclass
+class Wan2_1_Fun_1_3B_Control_SamplingParam(SamplingParam):
+    fps: int = 16
+    num_frames: int = 49
+    height: int = 832
+    width: int = 480
+    guidance_scale: float = 6.0
+
+
+# =============================================
+# ============ Wan2.2 TI2V Models =============
+# =============================================
+@dataclass
+class Wan2_2_Base_SamplingParam(SamplingParam):
+    """Base sampling parameters shared by the Wan2.2 models."""
+    negative_prompt: str | None = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+
+
+@dataclass
+class Wan2_2_TI2V_5B_SamplingParam(Wan2_2_Base_SamplingParam):
+    """Sampling parameters for the Wan2.2 TI2V 5B model."""
+    height: int = 704
+    width: int = 1280
+    num_frames: int = 121
+    fps: int = 24
+    guidance_scale: float = 5.0
+    num_inference_steps: int = 50
+
+
+@dataclass
+class Wan2_2_T2V_A14B_SamplingParam(Wan2_2_Base_SamplingParam):
+    guidance_scale: float = 4.0  # high_noise
+    guidance_scale_2: float = 3.0  # low_noise
+    num_inference_steps: int = 40
+    fps: int = 16
+    # NOTE(will): the default boundary timestep is tracked by PipelineConfig, but
+    # can be overridden during sampling
+
+
+@dataclass
+class Wan2_2_I2V_A14B_SamplingParam(Wan2_2_Base_SamplingParam):
+    guidance_scale: float = 3.5  # high_noise
+    guidance_scale_2: float = 3.5  # low_noise
+    num_inference_steps: int = 40
+    fps: int = 16
+    # NOTE(will): the default boundary timestep is tracked by PipelineConfig, but
+    # can be overridden during sampling
+
+
+@dataclass
+class Wan2_2_Fun_A14B_Control_SamplingParam(Wan2_1_Fun_1_3B_Control_SamplingParam):
+    num_frames: int = 81
+
+
+# =============================================
+# ============ Causal Self-Forcing ============
+# =============================================
+@dataclass
+class SelfForcingWan2_1_T2V_1_3B_480P_SamplingParam(Wan2_1_Fun_1_3B_InP_SamplingParam):
+    pass
+
+
+@dataclass
+class SelfForcingWan2_2_T2V_A14B_480P_SamplingParam(Wan2_2_T2V_A14B_SamplingParam):
+    num_inference_steps: int = 8
+    num_frames: int = 81
+    height: int = 448
+    width: int = 832
+    fps: int = 16
+
+
+@dataclass
+class MatrixGame2_SamplingParam(SamplingParam):
+    height: int = 352
+    width: int = 640
+    num_frames: int = 57
+    fps: int = 25
+    guidance_scale: float = 1.0
+    num_inference_steps: int = 3
+    negative_prompt: str | None = None
+
backend_snapshot/fastvideo/configs/wan_1.3B_t2v_pipeline.json ADDED
@@ -0,0 +1,40 @@
+{
+    "embedded_cfg_scale": 6.0,
+    "flow_shift": 3,
+    "dit_cpu_offload": true,
+    "disable_autocast": false,
+    "precision": "bf16",
+    "vae_precision": "fp32",
+    "vae_tiling": false,
+    "vae_sp": false,
+    "vae_config": {
+        "load_encoder": false,
+        "load_decoder": true,
+        "tile_sample_min_height": 256,
+        "tile_sample_min_width": 256,
+        "tile_sample_min_num_frames": 16,
+        "tile_sample_stride_height": 192,
+        "tile_sample_stride_width": 192,
+        "tile_sample_stride_num_frames": 12,
+        "blend_num_frames": 8,
+        "use_tiling": false,
+        "use_temporal_tiling": false,
+        "use_parallel_tiling": false,
+        "use_feature_cache": true
+    },
+    "dit_config": {
+        "prefix": "Wan",
+        "quant_config": null
+    },
+    "text_encoder_precisions": [
+        "fp32"
+    ],
+    "text_encoder_configs": [
+        {
+            "prefix": "t5",
+            "quant_config": null,
+            "lora_config": null
+        }
+    ],
+    "enable_torch_compile": false
+}
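Per the compat adapter earlier in this commit, a string `pipeline_config` kwarg is routed into `pipeline.components.pipeline_config_path`, so this JSON can plausibly be supplied as below. This is hypothetical usage, untested against the staged checkpoint.

```python
from fastvideo import VideoGenerator

# Hypothetical: point the generator at the staged pipeline JSON above.
generator = VideoGenerator.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    pipeline_config="backend_snapshot/fastvideo/configs/wan_1.3B_t2v_pipeline.json",
    num_gpus=1,
)
```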
backend_snapshot/fastvideo/entrypoints/cli/generate.py ADDED
@@ -0,0 +1,115 @@
+# SPDX-License-Identifier: Apache-2.0
+# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/serve.py
+
+import argparse
+import dataclasses
+import os
+from typing import cast
+
+from fastvideo import VideoGenerator
+from fastvideo.configs.sample.base import SamplingParam
+from fastvideo.entrypoints.cli.cli_types import CLISubcommand
+from fastvideo.entrypoints.cli.utils import RaiseNotImplementedAction
+from fastvideo.fastvideo_args import FastVideoArgs
+from fastvideo.logger import init_logger
+from fastvideo.utils import FlexibleArgumentParser
+
+logger = init_logger(__name__)
+
+
+class GenerateSubcommand(CLISubcommand):
+    """The `generate` subcommand for the FastVideo CLI."""
+
+    def __init__(self) -> None:
+        self.name = "generate"
+        super().__init__()
+        self.init_arg_names = self._get_init_arg_names()
+        self.generation_arg_names = self._get_generation_arg_names()
+
+    def _get_init_arg_names(self) -> list[str]:
+        """Get names of arguments for VideoGenerator initialization."""
+        return ["num_gpus", "tp_size", "sp_size", "model_path"]
+
+    def _get_generation_arg_names(self) -> list[str]:
+        """Get names of arguments for the generate_video method."""
+        return [field.name for field in dataclasses.fields(SamplingParam)]
+
+    def cmd(self, args: argparse.Namespace) -> None:
+        excluded_args = ['subparser', 'config', 'dispatch_function']
+
+        provided_args = {}
+        for k, v in vars(args).items():
+            if (k not in excluded_args and v is not None and hasattr(args, '_provided') and k in args._provided):
+                provided_args[k] = v
+
+        if 'model_path' in vars(args) and args.model_path is not None:
+            provided_args['model_path'] = args.model_path
+
+        if 'prompt' in vars(args) and args.prompt is not None:
+            provided_args['prompt'] = args.prompt
+
+        merged_args = {**provided_args}
+
+        logger.info('CLI Args: %s', merged_args)
+
+        if 'model_path' not in merged_args or not merged_args['model_path']:
+            raise ValueError("model_path must be provided either in config file or via --model-path")
+
+        # Check if either prompt or prompt_txt is provided
+        has_prompt = 'prompt' in merged_args and merged_args['prompt']
+        has_prompt_txt = 'prompt_txt' in merged_args and merged_args['prompt_txt']
+
+        if not (has_prompt or has_prompt_txt):
+            raise ValueError("Either prompt or prompt_txt must be provided")
+
+        if has_prompt and has_prompt_txt:
+            raise ValueError("Cannot provide both 'prompt' and 'prompt_txt'. Use only one of them.")
+
+        init_args = {k: v for k, v in merged_args.items() if k not in self.generation_arg_names}
+        generation_args = {k: v for k, v in merged_args.items() if k in self.generation_arg_names}
+        generation_args.setdefault("return_frames", False)
+
+        model_path = init_args.pop('model_path')
+        prompt = generation_args.pop('prompt', None)
+
+        generator = VideoGenerator.from_pretrained(model_path=model_path, **init_args)
+
+        # Call generate_video - it handles both single and batch modes
+        generator.generate_video(prompt=prompt, **generation_args)
+
+    def validate(self, args: argparse.Namespace) -> None:
+        """Validate the arguments for this command."""
+        if args.num_gpus is not None and args.num_gpus <= 0:
+            raise ValueError("Number of GPUs must be positive")
+
+        if args.config and not os.path.exists(args.config):
+            raise ValueError(f"Config file not found: {args.config}")
+
+    def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
+        generate_parser = subparsers.add_parser(
+            "generate",
+            help="Run inference on a model",
+            usage="fastvideo generate (--model-path MODEL_PATH_OR_ID --prompt PROMPT) | --config CONFIG_FILE [OPTIONS]")
+
+        generate_parser.add_argument(
+            "--config",
+            type=str,
+            default='',
+            required=False,
+            help="Read CLI options from a config JSON or YAML file. If provided, --model-path and --prompt are optional."
+        )
+
+        generate_parser = FastVideoArgs.add_cli_args(generate_parser)
+        generate_parser = SamplingParam.add_cli_args(generate_parser)
+
+        generate_parser.add_argument(
+            "--text-encoder-configs",
+            action=RaiseNotImplementedAction,
+            help="JSON array of text encoder configurations (NOT YET IMPLEMENTED)",
+        )
+
+        return cast(FlexibleArgumentParser, generate_parser)
+
+
+def cmd_init() -> list[CLISubcommand]:
+    return [GenerateSubcommand()]
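`GenerateSubcommand.cmd` splits the merged CLI options into constructor arguments and per-call sampling arguments using the `SamplingParam` field names. A standalone sketch of that split (the values are illustrative, not from the commit):

```python
import dataclasses

from fastvideo.configs.sample.base import SamplingParam

generation_arg_names = {f.name for f in dataclasses.fields(SamplingParam)}
merged_args = {"model_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
               "prompt": "A cat surfing", "num_gpus": 1, "num_frames": 81}

# Everything that is a SamplingParam field goes to generate_video();
# the rest goes to VideoGenerator.from_pretrained().
init_args = {k: v for k, v in merged_args.items() if k not in generation_arg_names}
generation_args = {k: v for k, v in merged_args.items() if k in generation_arg_names}

assert "model_path" in init_args and "num_gpus" in init_args
assert "prompt" in generation_args and "num_frames" in generation_args
```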
backend_snapshot/fastvideo/entrypoints/video_generator.py ADDED
@@ -0,0 +1,797 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ VideoGenerator module for FastVideo.
4
+
5
+ This module provides a consolidated interface for generating videos using
6
+ diffusion models.
7
+ """
8
+
9
+ import os
10
+ import re
11
+ import shutil
12
+ import threading
13
+ import time
14
+ import tempfile
15
+ import warnings
16
+ from collections.abc import Mapping
17
+ from copy import deepcopy
18
+ from typing import Any
19
+
20
+ import imageio
21
+ import numpy as np
22
+ import torch
23
+ import torchvision
24
+ from einops import rearrange
25
+
26
+ from fastvideo.api.compat import (
27
+ expand_request_prompt_batch,
28
+ generator_config_to_fastvideo_args,
29
+ legacy_from_pretrained_to_config,
30
+ load_generator_config_from_file,
31
+ normalize_generation_request,
32
+ normalize_generator_config,
33
+ request_to_pipeline_overrides,
34
+ request_to_sampling_param,
35
+ )
36
+ from fastvideo.api.results import GenerationResult
37
+ from fastvideo.api.schema import GenerationRequest, GeneratorConfig
38
+ from fastvideo.configs.sample import SamplingParam
39
+ from fastvideo.fastvideo_args import FastVideoArgs
40
+ from fastvideo.logger import init_logger
41
+ from fastvideo.pipelines import ForwardBatch
42
+ from fastvideo.utils import align_to, shallow_asdict
43
+ from fastvideo.worker.executor import Executor
44
+
45
+ logger = init_logger(__name__)
46
+
47
+ _FROM_PRETRAINED_CONVENIENCE_KWARGS = frozenset({
48
+ "num_gpus",
49
+ "revision",
50
+ "trust_remote_code",
51
+ "distributed_executor_backend",
52
+ "tp_size",
53
+ "sp_size",
54
+ "hsdp_replicate_dim",
55
+ "hsdp_shard_dim",
56
+ "dist_timeout",
57
+ "use_fsdp_inference",
58
+ "disable_autocast",
59
+ "enable_stage_verification",
60
+ "dit_cpu_offload",
61
+ "dit_layerwise_offload",
62
+ "text_encoder_cpu_offload",
63
+ "image_encoder_cpu_offload",
64
+ "vae_cpu_offload",
65
+ "pin_cpu_memory",
66
+ "enable_torch_compile",
67
+ "torch_compile_kwargs",
68
+ "transformer_quant",
69
+ })
70
+
71
+
72
+ def _infer_latent_batch_size(batch: ForwardBatch) -> int:
73
+ if isinstance(batch.prompt, list):
74
+ latent_batch_size = len(batch.prompt)
75
+ elif batch.prompt is not None:
76
+ latent_batch_size = 1
77
+ elif batch.prompt_embeds is not None and len(batch.prompt_embeds) > 0:
78
+ latent_batch_size = batch.prompt_embeds[0].shape[0]
79
+ else:
80
+ raise ValueError("Cannot infer batch size from batch; no prompt or prompt_embeds found")
81
+ latent_batch_size *= batch.num_videos_per_prompt
82
+ return latent_batch_size
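+
+ # Worked example of the rule above (illustrative values): prompt=["a cat",
+ # "a dog"] with num_videos_per_prompt=3 gives 2 * 3 = 6; a single string
+ # prompt with the same setting gives 1 * 3 = 3.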
83
+
84
+
85
+ class VideoGenerator:
86
+ """
87
+ A unified class for generating videos using diffusion models.
88
+
89
+ This class provides a simple interface for video generation with rich
90
+ customization options, similar to popular frameworks like HF Diffusers.
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ fastvideo_args: FastVideoArgs,
96
+ executor_class: type[Executor],
97
+ log_stats: bool,
98
+ *,
99
+ log_queue=None,
100
+ ):
101
+ """
102
+ Initialize the video generator.
103
+
104
+ Args:
105
+ fastvideo_args: The inference arguments
106
+ executor_class: The executor class to use for inference
107
+ log_stats: Whether to log statistics
108
+ log_queue: Optional multiprocessing.Queue to forward worker logs to
109
+ """
110
+ self.config: GeneratorConfig | None = None
111
+ self.fastvideo_args = fastvideo_args
112
+ self.executor = executor_class(fastvideo_args, log_queue=log_queue)
113
+
114
+ @classmethod
115
+ def from_pretrained(
116
+ cls,
117
+ model_path: str | GeneratorConfig | Mapping[str, Any] | None = None,
118
+ **kwargs,
119
+ ) -> "VideoGenerator":
120
+ """
121
+ Create a video generator from a pretrained model.
122
+
123
+ Args:
124
+ model_path: Path or identifier for the pretrained model
125
+ config: Typed GeneratorConfig or mapping; mutually exclusive with model_path
126
+ **kwargs: Additional arguments to customize model loading, set any FastVideoArgs or PipelineConfig attributes here.
127
+
128
+ Returns:
129
+ The created video generator
130
+
131
+ Priority level: Default pipeline config < User's pipeline config < User's kwargs
132
+
133
+ Stable convenience kwargs remain supported here for common engine and
134
+ offload settings. Advanced model- or pipeline-specific options should
135
+ move to VideoGenerator.from_config(...).
136
+ """
137
+ log_queue = kwargs.pop("log_queue", None)
138
+ typed_config = kwargs.pop("config", None)
139
+ if typed_config is not None:
140
+ if model_path is not None:
141
+ raise TypeError("Pass either model_path or config to from_pretrained, not both")
142
+ if kwargs:
143
+ unexpected = ", ".join(sorted(kwargs))
144
+ raise TypeError(f"Unexpected keyword arguments with config: {unexpected}")
145
+ return cls.from_config(typed_config, log_queue=log_queue)
146
+
147
+ if isinstance(model_path, GeneratorConfig | Mapping):
148
+ if kwargs:
149
+ unexpected = ", ".join(sorted(kwargs))
150
+ raise TypeError(f"Unexpected keyword arguments with typed config: {unexpected}")
151
+ return cls.from_config(model_path, log_queue=log_queue)
152
+
153
+ if model_path is None:
154
+ raise TypeError("model_path or config is required")
155
+
156
+ legacy_only_kwargs = sorted(set(kwargs) - _FROM_PRETRAINED_CONVENIENCE_KWARGS)
157
+ if legacy_only_kwargs:
158
+ warnings.warn(
159
+ "VideoGenerator.from_pretrained(...) received legacy-only kwargs "
160
+ f"({', '.join(legacy_only_kwargs)}); prefer VideoGenerator.from_config(...) "
161
+ "for advanced configuration.",
162
+ DeprecationWarning,
163
+ stacklevel=2,
164
+ )
165
+ return cls.from_config(
166
+ legacy_from_pretrained_to_config(model_path, kwargs),
167
+ log_queue=log_queue,
168
+ )
169
+
170
+ @classmethod
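+ # Minimal usage sketch (illustrative paths/objects). Convenience kwargs in
+ # _FROM_PRETRAINED_CONVENIENCE_KWARGS stay supported; anything else emits
+ # the DeprecationWarning above:
+ #
+ #     generator = VideoGenerator.from_pretrained(
+ #         "checkpoints/hf_download/sparse_quant_exp", num_gpus=1)
+ #     generator = VideoGenerator.from_pretrained(config=my_generator_config)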
171
+ def from_config(
172
+ cls,
173
+ config: GeneratorConfig | Mapping[str, Any],
174
+ *,
175
+ log_queue=None,
176
+ ) -> "VideoGenerator":
177
+ normalized = normalize_generator_config(config)
178
+ fastvideo_args = generator_config_to_fastvideo_args(normalized)
179
+ generator = cls.from_fastvideo_args(fastvideo_args, log_queue=log_queue)
180
+ generator.config = normalized
181
+ return generator
182
+
183
+ @classmethod
184
+ def from_file(
185
+ cls,
186
+ path: str,
187
+ overrides: list[str] | Mapping[str, Any] | None = None,
188
+ *,
189
+ log_queue=None,
190
+ ) -> "VideoGenerator":
191
+ return cls.from_config(
192
+ load_generator_config_from_file(path, overrides=overrides),
193
+ log_queue=log_queue,
194
+ )
195
+
196
+ @classmethod
197
+ def from_fastvideo_args(
198
+ cls,
199
+ fastvideo_args: FastVideoArgs,
200
+ *,
201
+ log_queue=None,
202
+ ) -> "VideoGenerator":
203
+ """
204
+ Create a video generator with the specified arguments.
205
+
206
+ Args:
207
+ fastvideo_args: The inference arguments
208
+ log_queue: Optional multiprocessing.Queue to forward worker logs to
209
+
210
+ Returns:
211
+ The created video generator
212
+ """
213
+ # Initialize distributed environment if needed
214
+ # initialize_distributed_and_parallelism(fastvideo_args)
215
+
216
+ executor_class = Executor.get_class(fastvideo_args)
217
+ return cls(
218
+ fastvideo_args=fastvideo_args,
219
+ executor_class=executor_class,
220
+ log_stats=False, # TODO: implement
221
+ log_queue=log_queue,
222
+ )
223
+
224
+ def generate(
225
+ self,
226
+ request: GenerationRequest | Mapping[str, Any],
227
+ *,
228
+ log_queue=None,
229
+ ) -> GenerationResult | list[GenerationResult]:
230
+ """
231
+ Generate video or image outputs from a typed inference request.
232
+
233
+ Args:
234
+ request: A `GenerationRequest` instance or a mapping that can be
235
+ parsed into one. This is the primary public inference
236
+ entrypoint for the typed API.
237
+ log_queue: Optional multiprocessing.Queue to forward worker logs to
238
+ during this request.
239
+
240
+ Returns:
241
+ A `GenerationResult` for single-request generation, or a list of
242
+ `GenerationResult` objects when the request expands into multiple
243
+ prompts.
244
+ """
245
+ normalized_request = normalize_generation_request(request)
246
+ if log_queue:
247
+ self.executor.set_log_queue(log_queue)
248
+
249
+ try:
250
+ return self._generate_request_impl(normalized_request)
251
+ finally:
252
+ if log_queue:
253
+ self.executor.clear_log_queue()
254
+
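+ # Sketch of the typed entrypoint above (field names follow
+ # GenerationRequest; values are illustrative):
+ #
+ #     result = generator.generate({"prompt": "A cat surfing a wave"})
+ #     results = generator.generate({"prompt": ["a cat", "a dog"]})  # list in, list out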
255
+ def generate_video(
256
+ self,
257
+ prompt: str | None = None,
258
+ sampling_param: SamplingParam | None = None,
259
+ # Action control inputs (Matrix-Game)
260
+ mouse_cond: torch.Tensor | None = None,
261
+ keyboard_cond: torch.Tensor | None = None,
262
+ grid_sizes: tuple[int, int, int] | list[int] | torch.Tensor
263
+ | None = None,
264
+ **kwargs,
265
+ ) -> dict[str, Any] | list[dict[str, Any]]:
266
+ """
267
+ Generate a video based on the given prompt.
268
+
269
+ Args (entries other than prompt are forwarded via **kwargs and override fields of sampling_param):
270
+ prompt: The prompt to use for generation (optional if prompt_txt is provided)
271
+ negative_prompt: The negative prompt to use (overrides the one in fastvideo_args)
272
+ output_path: Path to save the video (overrides the one in fastvideo_args)
273
+ prompt_path: Path to prompt file
274
+ save_video: Whether to save the video to disk
275
+ return_frames: Whether to include raw frames in the result dict
276
+ num_inference_steps: Number of denoising steps (overrides fastvideo_args)
277
+ guidance_scale: Classifier-free guidance scale (overrides fastvideo_args)
278
+ num_frames: Number of frames to generate (overrides fastvideo_args)
279
+ height: Height of generated video (overrides fastvideo_args)
280
+ width: Width of generated video (overrides fastvideo_args)
281
+ fps: Frames per second for saved video (overrides fastvideo_args)
282
+ seed: Random seed for generation (overrides fastvideo_args)
283
+ callback: Callback function called after each step
284
+ callback_steps: Number of steps between each callback
285
+
286
+ Returns:
287
+ A metadata dictionary for single-prompt generation, or a list of
288
+ metadata dictionaries for prompt-file batch generation.
289
+ """
290
+ log_queue = kwargs.pop("log_queue", None)
291
+ warnings.warn(
292
+ "VideoGenerator.generate_video(...) is deprecated; use "
293
+ "VideoGenerator.generate(request=...) instead.",
294
+ DeprecationWarning,
295
+ stacklevel=2,
296
+ )
297
+ if log_queue:
298
+ self.executor.set_log_queue(log_queue)
299
+
300
+ try:
301
+ return self._generate_video_impl(
302
+ prompt=prompt,
303
+ sampling_param=sampling_param,
304
+ mouse_cond=mouse_cond,
305
+ keyboard_cond=keyboard_cond,
306
+ grid_sizes=grid_sizes,
307
+ **kwargs,
308
+ )
309
+ finally:
310
+ if log_queue:
311
+ self.executor.clear_log_queue()
312
+
313
+ def _generate_request_impl(
314
+ self,
315
+ request: GenerationRequest,
316
+ ) -> GenerationResult | list[GenerationResult]:
317
+ if isinstance(request.prompt, list):
318
+ if request.inputs.prompt_path is not None:
319
+ raise ValueError("request.prompt list cannot be combined with request.inputs.prompt_path")
320
+ results: list[GenerationResult] = []
321
+ for index, single_request in enumerate(expand_request_prompt_batch(request)):
322
+ prompt = single_request.prompt
323
+ wrapped = self._generate_single_request(single_request)
324
+ if isinstance(wrapped, list):
325
+ results.extend(wrapped)
326
+ continue
327
+ wrapped.prompt_index = index
328
+ if wrapped.prompt is None and isinstance(prompt, str):
329
+ wrapped.prompt = prompt
330
+ results.append(wrapped)
331
+ return results
332
+
333
+ return self._generate_single_request(request)
334
+
335
+ def _generate_single_request(
336
+ self,
337
+ request: GenerationRequest,
338
+ ) -> GenerationResult | list[GenerationResult]:
339
+ fastvideo_args = self.fastvideo_args
340
+ pipeline_overrides = request_to_pipeline_overrides(request)
341
+ if pipeline_overrides:
342
+ fastvideo_args = deepcopy(self.fastvideo_args)
343
+ for key, value in pipeline_overrides.items():
344
+ if not hasattr(fastvideo_args.pipeline_config, key):
345
+ raise ValueError(f"Request field {key!r} is not supported by pipeline config overrides")
346
+ setattr(fastvideo_args.pipeline_config, key, deepcopy(value))
347
+
348
+ sampling_param = request_to_sampling_param(
349
+ request,
350
+ model_path=self.fastvideo_args.model_path,
351
+ )
352
+ result = self._generate_video_impl(
353
+ prompt=request.prompt,
354
+ sampling_param=sampling_param,
355
+ fastvideo_args=fastvideo_args,
356
+ )
357
+ return self._wrap_legacy_result(result)
358
+
359
+ def _generate_video_impl(
360
+ self,
361
+ prompt: str | list[str] | None = None,
362
+ sampling_param: SamplingParam | None = None,
363
+ mouse_cond: torch.Tensor | None = None,
364
+ keyboard_cond: torch.Tensor | None = None,
365
+ grid_sizes: tuple[int, int, int] | list[int] | torch.Tensor
366
+ | None = None,
367
+ fastvideo_args: FastVideoArgs | None = None,
368
+ **kwargs,
369
+ ) -> dict[str, Any] | list[np.ndarray] | list[dict[str, Any]]:
370
+ """Internal implementation of generate_video."""
371
+ if fastvideo_args is None:
372
+ fastvideo_args = self.fastvideo_args
373
+
374
+ # Handle batch processing from text file
375
+ if sampling_param is None:
376
+ sampling_param = SamplingParam.from_pretrained(fastvideo_args.model_path)
377
+
378
+ # Add action control inputs to kwargs if provided
379
+ if mouse_cond is not None:
380
+ kwargs['mouse_cond'] = mouse_cond
381
+ if keyboard_cond is not None:
382
+ kwargs['keyboard_cond'] = keyboard_cond
383
+ if grid_sizes is not None:
384
+ kwargs['grid_sizes'] = grid_sizes
385
+
386
+ sampling_param.update(kwargs)
387
+
388
+ if fastvideo_args.prompt_txt is not None or sampling_param.prompt_path is not None:
389
+ prompt_txt_path = sampling_param.prompt_path or fastvideo_args.prompt_txt
390
+ if not prompt_txt_path or not os.path.exists(prompt_txt_path):
391
+ raise FileNotFoundError(f"Prompt text file not found: {prompt_txt_path}")
392
+
393
+ # Read prompts from file
394
+ with open(prompt_txt_path, encoding='utf-8') as f:
395
+ prompts = [line.strip() for line in f if line.strip()]
396
+
397
+ if not prompts:
398
+ raise ValueError(f"No prompts found in file: {prompt_txt_path}")
399
+
400
+ logger.info("Found %d prompts in %s", len(prompts), prompt_txt_path)
401
+
402
+ results = []
403
+ for i, batch_prompt in enumerate(prompts):
404
+ logger.info("Processing prompt %d/%d: %s...", i + 1, len(prompts), batch_prompt[:100])
405
+ try:
406
+ # Generate video for this prompt using the same logic below
407
+ output_path = self._prepare_output_path(sampling_param.output_path, batch_prompt)
408
+ kwargs["output_path"] = output_path
409
+ result = self._generate_single_video(
410
+ prompt=batch_prompt,
411
+ sampling_param=sampling_param,
412
+ fastvideo_args=fastvideo_args,
413
+ **kwargs,
414
+ )
415
+
416
+ # Add prompt info to result
417
+ result["prompt_index"] = i
418
+ result["prompt"] = batch_prompt
419
+
420
+ results.append(result)
421
+ logger.info("Successfully generated video for prompt %d", i + 1)
422
+
423
+ except Exception as e:
424
+ logger.error("Failed to generate video for prompt %d: %s", i + 1, e)
425
+ continue
426
+
427
+ logger.info("Completed batch processing. Generated %d videos successfully.", len(results))
428
+ return results
429
+
430
+ # Single prompt generation (original behavior)
431
+ if prompt is None:
432
+ raise ValueError("Either prompt or prompt_txt must be provided")
433
+ if not isinstance(prompt, str):
434
+ raise ValueError("Single-prompt generation expects a string prompt")
435
+ output_path = self._prepare_output_path(sampling_param.output_path, prompt)
436
+ kwargs["output_path"] = output_path
437
+ return self._generate_single_video(
438
+ prompt=prompt,
439
+ sampling_param=sampling_param,
440
+ fastvideo_args=fastvideo_args,
441
+ **kwargs,
442
+ )
443
+
444
+ def _is_image_workload(self) -> bool:
445
+ """Return True when the workload produces a single image (t2i, i2i …)."""
446
+ args = getattr(self, "fastvideo_args", None)
447
+ if args is None:
448
+ return False
449
+ return args.workload_type.value.endswith("2i")
450
+
451
+ def _prepare_output_path(
452
+ self,
453
+ output_path: str,
454
+ prompt: str,
455
+ ) -> str:
456
+ """Build a unique, sanitized output file path.
457
+
458
+ The file extension is chosen automatically based on the workload type:
459
+ ``.png`` for image workloads (``t2i``, ``i2i``, …) and ``.mp4`` for
460
+ video workloads.
461
+
462
+ - If ``output_path`` already carries the correct extension, treat it
463
+ as a file path.
464
+ - Otherwise, treat ``output_path`` as a directory and derive the
465
+ filename from the prompt.
466
+ - Invalid filename characters are removed; if the name changes, a
467
+ warning is logged.
468
+ - If the target path already exists, a numeric suffix is appended.
469
+ """
470
+ target_ext = ".png" if self._is_image_workload() else ".mp4"
471
+
472
+ def _sanitize_filename_component(name: str) -> str:
473
+ # Remove characters invalid on common filesystems, strip spaces/dots
474
+ sanitized = re.sub(r'[\\/:*?"<>|]', '', name)
475
+ sanitized = sanitized.strip().strip('.')
476
+ sanitized = re.sub(r'\s+', ' ', sanitized)
477
+ return sanitized or "output"
478
+
479
+ base_path, extension = os.path.splitext(output_path)
480
+ extension_lower = extension.lower()
481
+
482
+ if extension_lower == target_ext:
483
+ output_dir = os.path.dirname(output_path)
484
+ base_name = os.path.basename(base_path) # filename without extension
485
+ sanitized_base = _sanitize_filename_component(base_name)
486
+ if sanitized_base != base_name:
487
+ logger.warning(
488
+ "The output name '%s' contained invalid characters. "
489
+ "It has been renamed to '%s%s'",
490
+ os.path.basename(output_path),
491
+ sanitized_base,
492
+ target_ext,
493
+ )
494
+ out_name = f"{sanitized_base}{target_ext}"
495
+ else:
496
+ # Treat as directory; inform if an unexpected extension was
497
+ # provided.
498
+ if extension:
499
+ logger.info(
500
+ "Output path '%s' has extension '%s' which does not "
501
+ "match the target '%s'; treating it as a directory",
502
+ output_path,
503
+ extension,
504
+ target_ext,
505
+ )
506
+ output_dir = output_path
507
+ prompt_component = _sanitize_filename_component(prompt[:100])
508
+ out_name = f"{prompt_component}{target_ext}"
509
+
510
+ if output_dir:
511
+ os.makedirs(output_dir, exist_ok=True)
512
+
513
+ new_output_path = os.path.join(output_dir, out_name)
514
+ counter = 1
515
+ while os.path.exists(new_output_path):
516
+ name_part, ext_part = os.path.splitext(out_name)
517
+ new_name = f"{name_part}_{counter}{ext_part}"
518
+ new_output_path = os.path.join(output_dir, new_name)
519
+ counter += 1
520
+ return new_output_path
521
+
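+ # Worked example of the rules above (illustrative paths, video workload):
+ # _prepare_output_path("outputs/clip.mp4", "a cat") keeps the file name;
+ # _prepare_output_path("outputs", "a cat") yields "outputs/a cat.mp4"; if
+ # that file already exists, "outputs/a cat_1.mp4" is returned instead.
+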
522
+ def _generate_single_video(
523
+ self,
524
+ prompt: str,
525
+ sampling_param: SamplingParam | None = None,
526
+ fastvideo_args: FastVideoArgs | None = None,
527
+ **kwargs,
528
+ ) -> dict[str, Any]:
529
+ """Internal method for single video generation"""
530
+ if fastvideo_args is None:
531
+ fastvideo_args = self.fastvideo_args
532
+
533
+ # Validate inputs
534
+ if not isinstance(prompt, str):
535
+ raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
536
+ prompt = prompt.strip()
537
+ sampling_param = deepcopy(sampling_param)
538
+ output_path = kwargs["output_path"]
539
+ sampling_param.prompt = prompt
540
+ # Process negative prompt
541
+ if sampling_param.negative_prompt is not None:
542
+ sampling_param.negative_prompt = sampling_param.negative_prompt.strip()
543
+
544
+ # Validate dimensions
545
+ if (sampling_param.height <= 0 or sampling_param.width <= 0 or sampling_param.num_frames <= 0):
546
+ raise ValueError(f"Height, width, and num_frames must be positive integers, got "
547
+ f"height={sampling_param.height}, width={sampling_param.width}, "
548
+ f"num_frames={sampling_param.num_frames}")
549
+
550
+ # Calculate sizes
551
+ target_height = align_to(sampling_param.height, 16)
552
+ target_width = align_to(sampling_param.width, 16)
553
+
554
+ # Calculate latent sizes
555
+ latents_size = [(sampling_param.num_frames - 1) // 4 + 1, sampling_param.height // 8, sampling_param.width // 8]
556
+ n_tokens = latents_size[0] * latents_size[1] * latents_size[2]
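+ # e.g. num_frames=81, height=480, width=832 (illustrative) give latents
+ # [21, 60, 104] and n_tokens = 21 * 60 * 104 = 131040.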
557
+
558
+ # Log parameters
559
+ debug_str = f"""
560
+ height: {target_height}
561
+ width: {target_width}
562
+ video_length: {sampling_param.num_frames}
563
+ prompt: {sampling_param.prompt}
564
+ image_path: {sampling_param.image_path}
565
+ neg_prompt: {sampling_param.negative_prompt}
566
+ seed: {sampling_param.seed}
567
+ infer_steps: {sampling_param.num_inference_steps}
568
+ num_videos_per_prompt: {sampling_param.num_videos_per_prompt}
569
+ guidance_scale: {sampling_param.guidance_scale}
570
+ n_tokens: {n_tokens}
571
+ flow_shift: {fastvideo_args.pipeline_config.flow_shift}
572
+ embedded_guidance_scale: {fastvideo_args.pipeline_config.embedded_cfg_scale}
573
+ save_video: {sampling_param.save_video}
574
+ output_path: {output_path}
575
+ """ # type: ignore[attr-defined]
576
+ logger.info(debug_str)
577
+
578
+ # Prepare batch
579
+ batch = ForwardBatch(
580
+ **shallow_asdict(sampling_param),
581
+ eta=0.0,
582
+ n_tokens=n_tokens,
583
+ VSA_sparsity=fastvideo_args.VSA_sparsity,
584
+ )
585
+
586
+ # Run inference
587
+ start_time = time.perf_counter()
588
+
589
+ # Execute forward pass in a new thread for non-blocking tensor
590
+ # allocation. Capture thread exceptions so we can surface the true
591
+ # failure in the main thread instead of later hitting None outputs.
592
+ result_container = {"output_batch": ForwardBatch(data_type=batch.data_type)}
593
+ thread_error: dict[str, BaseException | None] = {"error": None}
594
+ thread_error_traceback: dict[str, str] = {"traceback": ""}
595
+
596
+ def execute_forward_thread():
597
+ import traceback
598
+ try:
599
+ result_container["output_batch"] = self.executor.execute_forward(batch, fastvideo_args)
600
+ except BaseException as error: # noqa: BLE001
601
+ thread_error["error"] = error
602
+ thread_error_traceback["traceback"] = traceback.format_exc()
603
+
604
+ thread = threading.Thread(target=execute_forward_thread)
605
+ thread.start()
606
+ latent_batch_size = _infer_latent_batch_size(batch)
607
+ samples = torch.empty(
608
+ (latent_batch_size, 3, sampling_param.num_frames, sampling_param.height, sampling_param.width),
609
+ device='cpu',
610
+ pin_memory=fastvideo_args.pin_cpu_memory)
611
+ thread.join()
612
+
613
+ if thread_error["error"] is not None:
614
+ raise RuntimeError("Forward execution thread failed.\n"
615
+ f"{thread_error_traceback['traceback']}") from thread_error["error"]
616
+
617
+ output_batch = result_container["output_batch"]
618
+ if output_batch.output is None:
619
+ raise RuntimeError("Forward execution returned no output tensor. "
620
+ "This usually means the executor/pipeline failed earlier.")
621
+
622
+ if output_batch.output.shape == samples.shape:
623
+ samples.copy_(output_batch.output)
624
+ else:
625
+ logger.warning("Output shape %s does not match expected shape %s; use slow path", output_batch.output.shape,
626
+ samples.shape)
627
+ samples = output_batch.output.cpu()
628
+ logging_info = output_batch.logging_info
629
+
630
+ gen_time = time.perf_counter() - start_time
631
+ logger.info("Generated successfully in %.2f seconds", gen_time)
632
+
633
+ # Process outputs
634
+ videos = rearrange(samples, "b c t h w -> t b c h w")
635
+ frames = []
636
+ for x in videos:
637
+ x = torchvision.utils.make_grid(x, nrow=6)
638
+ x = x.permute(1, 2, 0).squeeze(-1)
639
+ x = (x * 255).to(torch.uint8)
640
+ frames.append(x.cpu().numpy())
641
+
642
+ # Save output if requested
643
+ if batch.save_video:
644
+ if self._is_image_workload():
645
+ # Image workloads (t2i, i2i, …): save the first frame as PNG.
646
+ imageio.imwrite(output_path, frames[0])
647
+ logger.info("Saved image to %s", output_path)
648
+ else:
649
+ imageio.mimsave(output_path, frames, fps=batch.fps, format="mp4")
650
+ logger.info("Saved video to %s", output_path)
651
+ audio = output_batch.extra.get("audio")
652
+ audio_sample_rate = output_batch.extra.get("audio_sample_rate")
653
+ if (audio is not None and audio_sample_rate is not None
654
+ and not self._mux_audio(output_path, audio, audio_sample_rate)):
655
+ logger.warning("Audio mux failed; saved video without audio.")
656
+
657
+ result: dict[str, Any] = {
658
+ "prompts": prompt,
659
+ "samples": samples if batch.return_frames else None,
660
+ "frames": frames if batch.return_frames else None,
661
+ "audio": output_batch.extra.get("audio") if batch.return_frames else None,
662
+ "size": (target_height, target_width, batch.num_frames),
663
+ "generation_time": gen_time,
664
+ "logging_info": logging_info,
665
+ "trajectory": output_batch.trajectory_latents,
666
+ "trajectory_timesteps": output_batch.trajectory_timesteps,
667
+ "trajectory_decoded": output_batch.trajectory_decoded,
668
+ "video_path": output_path if batch.save_video else None,
669
+ "peak_memory_mb": output_batch.extra.get("peak_memory_mb"),
670
+ }
671
+
672
+ return result
673
+
674
+ @staticmethod
675
+ def _wrap_legacy_result(
676
+ result: dict[str, Any] | list[dict[str, Any]], ) -> GenerationResult | list[GenerationResult]:
677
+ if isinstance(result, list):
678
+ return [GenerationResult.from_legacy_result(item) for item in result]
679
+ return GenerationResult.from_legacy_result(result)
680
+
681
+ @staticmethod
682
+ def _unwrap_typed_result(
683
+ result: GenerationResult | list[GenerationResult], ) -> dict[str, Any] | list[dict[str, Any]]:
684
+ if isinstance(result, list):
685
+ return [item.to_legacy_dict() for item in result]
686
+ return result.to_legacy_dict()
687
+
688
+ @staticmethod
689
+ def _mux_audio(
690
+ video_path: str,
691
+ audio: torch.Tensor | np.ndarray,
692
+ sample_rate: int,
693
+ ) -> bool:
694
+ """Mux audio into video using PyAV."""
695
+ try:
696
+ import av
697
+ except ImportError:
698
+ logger.warning("PyAV not installed; cannot mux audio. "
699
+ "Install with: pip install av")
700
+ return False
701
+
702
+ if torch.is_tensor(audio):
703
+ audio_np = audio.detach().cpu().float().numpy()
704
+ else:
705
+ audio_np = np.asarray(audio, dtype=np.float32)
706
+
707
+ if audio_np.ndim == 1:
708
+ audio_np = audio_np[:, None]
709
+ elif audio_np.ndim == 2:
710
+ if audio_np.shape[0] <= 8 and audio_np.shape[1] > audio_np.shape[0]:
711
+ audio_np = audio_np.T
712
+ else:
713
+ logger.warning("Unexpected audio shape %s; skipping mux.", audio_np.shape)
714
+ return False
715
+
716
+ audio_np = np.clip(audio_np, -1.0, 1.0)
717
+ audio_int16 = (audio_np * 32767.0).astype(np.int16)
718
+ num_channels = audio_int16.shape[1]
719
+ layout = "stereo" if num_channels == 2 else "mono"
720
+
721
+ try:
722
+ import wave
723
+ with tempfile.TemporaryDirectory() as tmpdir:
724
+ out_path = os.path.join(tmpdir, "muxed.mp4")
725
+ wav_path = os.path.join(tmpdir, "audio.wav")
726
+
727
+ # Write audio to WAV file
728
+ with wave.open(wav_path, "wb") as wav_file:
729
+ wav_file.setnchannels(num_channels)
730
+ wav_file.setsampwidth(2)
731
+ wav_file.setframerate(sample_rate)
732
+ wav_file.writeframes(audio_int16.tobytes())
733
+
734
+ # Open input video and audio
735
+ input_video = av.open(video_path)
736
+ input_audio = av.open(wav_path)
737
+
738
+ # Create output with both streams
739
+ output = av.open(out_path, mode="w")
740
+
741
+ # Add video stream (copy codec from input)
742
+ in_video_stream = input_video.streams.video[0]
743
+ out_video_stream = output.add_stream(
744
+ codec_name=in_video_stream.codec_context.name,
745
+ rate=in_video_stream.average_rate,
746
+ )
747
+ out_video_stream.width = in_video_stream.width
748
+ out_video_stream.height = in_video_stream.height
749
+ out_video_stream.pix_fmt = in_video_stream.pix_fmt
750
+
751
+ # Add audio stream (AAC)
752
+ out_audio_stream = output.add_stream("aac", rate=sample_rate)
753
+ out_audio_stream.layout = layout
754
+
755
+ # Remux video (decode and re-encode to be safe)
756
+ for frame in input_video.decode(video=0):
757
+ for packet in out_video_stream.encode(frame):
758
+ output.mux(packet)
759
+ for packet in out_video_stream.encode():
760
+ output.mux(packet)
761
+
762
+ # Encode audio
763
+ for frame in input_audio.decode(audio=0):
764
+ frame.pts = None # Let encoder assign PTS
765
+ for packet in out_audio_stream.encode(frame):
766
+ output.mux(packet)
767
+ for packet in out_audio_stream.encode():
768
+ output.mux(packet)
769
+
770
+ input_video.close()
771
+ input_audio.close()
772
+ output.close()
773
+ shutil.move(out_path, video_path)
774
+ return True
775
+ except Exception as e:
776
+ logger.warning("Audio mux failed: %s", e)
777
+ return False
778
+
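+ # Audio conventions assumed by _mux_audio above: float waveforms in
+ # [-1, 1] become int16 PCM via clip(x, -1, 1) * 32767, and a
+ # (channels, n) array with channels <= 8 and n > channels is transposed
+ # to (n, channels) before the temporary WAV is written.
+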
779
+ def set_lora_adapter(self, lora_nickname: str, lora_path: str | None = None) -> None:
780
+ self.executor.set_lora_adapter(lora_nickname, lora_path)
781
+
782
+ def unmerge_lora_weights(self) -> None:
783
+ """
784
+ Use unmerged weights for inference to produce videos that align with
785
+ validation videos generated during training.
786
+ """
787
+ self.executor.unmerge_lora_weights()
788
+
789
+ def merge_lora_weights(self) -> None:
790
+ self.executor.merge_lora_weights()
791
+
792
+ def shutdown(self) -> None:
793
+ """
794
+ Shutdown the video generator.
795
+ """
796
+ self.executor.shutdown()
797
+ del self.executor
backend_snapshot/fastvideo/fastvideo_args.py ADDED
@@ -0,0 +1,1188 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Inspired by SGLang: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py
3
+ """The arguments of FastVideo Inference."""
4
+ import argparse
5
+ import dataclasses
6
+ import json
7
+ from contextlib import contextmanager
8
+ from dataclasses import field
9
+ from enum import Enum
10
+ from typing import Any, TYPE_CHECKING
11
+
12
+ from fastvideo.configs.configs import PreprocessConfig
13
+ from fastvideo.configs.pipelines.base import PipelineConfig
14
+ from fastvideo.configs.utils import clean_cli_args
15
+ from fastvideo.layers.quantization import QUANTIZATION_METHODS, QuantizationMethods
16
+ from fastvideo.logger import init_logger
17
+ from fastvideo.utils import FlexibleArgumentParser, StoreBoolean
18
+
19
+ if TYPE_CHECKING:
20
+ from ray.runtime_env import RuntimeEnv
21
+ from ray.util.placement_group import PlacementGroup
22
+ else:
23
+ RuntimeEnv = Any
24
+ PlacementGroup = Any
25
+
26
+ logger = init_logger(__name__)
27
+
28
+
29
+ class ExecutionMode(str, Enum):
30
+ """
31
+ Enumeration for different pipeline modes.
32
+
33
+ Inherits from str to allow string comparison for backward compatibility.
34
+ """
35
+ INFERENCE = "inference"
36
+ PREPROCESS = "preprocess"
37
+ FINETUNING = "finetuning"
38
+ DISTILLATION = "distillation"
39
+
40
+ @classmethod
41
+ def from_string(cls, value: str) -> "ExecutionMode":
42
+ """Convert string to ExecutionMode enum."""
43
+ try:
44
+ return cls(value.lower())
45
+ except ValueError:
46
+ raise ValueError(f"Invalid mode: {value}. Must be one of: {', '.join([m.value for m in cls])}") from None
47
+
48
+ @classmethod
49
+ def choices(cls) -> list[str]:
50
+ """Get all available choices as strings for argparse."""
51
+ return [mode.value for mode in cls]
52
+
53
+
54
+ class WorkloadType(str, Enum):
55
+ """
56
+ Enumeration for different workload types.
57
+
58
+ Inherits from str to allow string comparison for backward compatibility.
59
+ """
60
+ I2V = "i2v" # Image to Video
61
+ T2V = "t2v" # Text to Video
62
+ T2I = "t2i" # Text to Image
63
+ I2I = "i2i" # Image to Image
64
+
65
+ @classmethod
66
+ def from_string(cls, value: str) -> "WorkloadType":
67
+ """Convert string to WorkloadType enum."""
68
+ try:
69
+ return cls(value.lower())
70
+ except ValueError:
71
+ raise ValueError(
72
+ f"Invalid workload type: {value}. Must be one of: {', '.join([m.value for m in cls])}") from None
73
+
74
+ @classmethod
75
+ def choices(cls) -> list[str]:
76
+ """Get all available choices as strings for argparse."""
77
+ return [workload.value for workload in cls]
78
+
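+ # Example of the string coercion above (illustrative): since the enums
+ # inherit from str, WorkloadType.from_string("T2V") is WorkloadType.T2V,
+ # and WorkloadType.T2V == "t2v" also holds.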
79
+
80
+ # args for fastvideo framework
81
+ @dataclasses.dataclass
82
+ class FastVideoArgs:
83
+ # Model and path configuration (for convenience)
84
+ model_path: str
85
+
86
+ # Running mode
87
+ mode: ExecutionMode = ExecutionMode.INFERENCE
88
+
89
+ # Workload type
90
+ workload_type: WorkloadType = WorkloadType.T2V
91
+
92
+ # Distributed executor backend
93
+ distributed_executor_backend: str = "mp"
94
+
95
+ # a few attributes for ray related
96
+ ray_placement_group: PlacementGroup | None = None
97
+ ray_runtime_env: RuntimeEnv | None = None
98
+
99
+ inference_mode: bool = True # False means training mode
100
+
101
+ # HuggingFace specific parameters
102
+ trust_remote_code: bool = False
103
+ revision: str | None = None
104
+
105
+ # Parallelism
106
+ num_gpus: int = 1
107
+ tp_size: int = -1
108
+ sp_size: int = -1
109
+ hsdp_replicate_dim: int = 1
110
+ hsdp_shard_dim: int = -1
111
+ dist_timeout: int | None = None # timeout for torch.distributed
112
+
113
+ pipeline_config: PipelineConfig = field(default_factory=PipelineConfig)
114
+ preprocess_config: PreprocessConfig | None = None
115
+
116
+ # LoRA parameters
117
+ # (Wenxuan) prefer to keep it here instead of in pipeline config to not make it complicated.
118
+ lora_path: str | None = None
119
+ lora_nickname: str = "default" # for swapping adapters in the pipeline
120
+ # can restrict layers to adapt, e.g. ["q_proj"]
121
+ # Will adapt only q, k, v, o by default.
122
+ lora_target_modules: list[str] | None = None
123
+
124
+ output_type: str = "pil"
125
+
126
+ # CPU offload parameters
127
+ dit_cpu_offload: bool = True
128
+ use_fsdp_inference: bool = False
129
+ dit_layerwise_offload: bool = True
130
+ text_encoder_cpu_offload: bool = True
131
+ image_encoder_cpu_offload: bool = True
132
+ vae_cpu_offload: bool = True
133
+ pin_cpu_memory: bool = True
134
+
135
+ # Compilation
136
+ enable_torch_compile: bool = False
137
+ torch_compile_kwargs: dict[str, Any] = field(default_factory=dict)
138
+
139
+ disable_autocast: bool = False
140
+
141
+ # VSA parameters
142
+ VSA_sparsity: float = 0.0 # inference/validation sparsity
143
+
144
+ # V-MoBA parameters
145
+ moba_config_path: str | None = None
146
+ moba_config: dict[str, Any] = field(default_factory=dict)
147
+
148
+ # Master port for distributed training/inference
149
+ master_port: int | None = None
150
+
151
+ # Stage verification
152
+ enable_stage_verification: bool = True
153
+
154
+ # Prompt text file for batch processing
155
+ prompt_txt: str | None = None
156
+
157
+ # LTX-2 VAE tiling overrides
158
+ ltx2_vae_tiling: bool | None = None
159
+ ltx2_vae_spatial_tile_size_in_pixels: int | None = None
160
+ ltx2_vae_spatial_tile_overlap_in_pixels: int | None = None
161
+ ltx2_vae_temporal_tile_size_in_frames: int | None = None
162
+ ltx2_vae_temporal_tile_overlap_in_frames: int | None = None
163
+ ltx2_initial_latent_path: str | None = None
164
+
165
+ # model paths for correct deallocation
166
+ model_paths: dict[str, str] = field(default_factory=dict)
167
+ model_loaded: dict[str, bool] = field(default_factory=lambda: {
168
+ "transformer": True,
169
+ "vae": True,
170
+ "upsampler": True,
171
+ })
172
+
173
+ override_text_encoder_safetensors: str | None = None # path to safetensors file for text encoder override
174
+ override_text_encoder_quant: QuantizationMethods = None
175
+ transformer_quant: QuantizationMethods = None
176
+
177
+ override_transformer_cls_name: str | None = None
178
+ init_weights_from_safetensors: str = "" # path to safetensors file for initial weight loading
179
+ init_weights_from_safetensors_2: str = "" # path to safetensors file for initial weight loading for transformer_2
180
+
181
+ override_pipeline_cls_name: str | None = None
182
+
183
+ # # DMD parameters
184
+ # dmd_denoising_steps: List[int] | None = field(default=None)
185
+
186
+ # MoE parameters used by Wan2.2
187
+ boundary_ratio: float = 0.875
188
+
189
+ @property
190
+ def training_mode(self) -> bool:
191
+ return not self.inference_mode
192
+
193
+ def __post_init__(self):
194
+ if self.moba_config_path:
195
+ try:
196
+ with open(self.moba_config_path) as f:
197
+ self.moba_config = json.load(f)
198
+ logger.info("Loaded V-MoBA config from %s", self.moba_config_path)
199
+ except (FileNotFoundError, json.JSONDecodeError) as e:
200
+ logger.error("Failed to load V-MoBA config from %s: %s", self.moba_config_path, e)
201
+ raise
202
+ self._apply_ltx2_vae_overrides()
203
+ self.check_fastvideo_args()
204
+
205
+ def __getattr__(self, name: str) -> Any:
206
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
207
+
208
+ def _apply_ltx2_vae_overrides(self) -> None:
209
+ if self.pipeline_config is None:
210
+ return
211
+ vae_config = self.pipeline_config.vae_config
212
+ has_any = any(value is not None for value in (
213
+ self.ltx2_vae_spatial_tile_size_in_pixels,
214
+ self.ltx2_vae_spatial_tile_overlap_in_pixels,
215
+ self.ltx2_vae_temporal_tile_size_in_frames,
216
+ self.ltx2_vae_temporal_tile_overlap_in_frames,
217
+ ))
218
+ if self.ltx2_vae_tiling is not None and hasattr(self.pipeline_config, "vae_tiling"):
219
+ self.pipeline_config.vae_tiling = self.ltx2_vae_tiling
220
+ elif has_any and hasattr(self.pipeline_config, "vae_tiling"):
221
+ self.pipeline_config.vae_tiling = True
222
+
223
+ if hasattr(vae_config,
224
+ "ltx2_spatial_tile_size_in_pixels") and self.ltx2_vae_spatial_tile_size_in_pixels is not None:
225
+ vae_config.ltx2_spatial_tile_size_in_pixels = (self.ltx2_vae_spatial_tile_size_in_pixels)
226
+ if hasattr(vae_config,
227
+ "ltx2_spatial_tile_overlap_in_pixels") and self.ltx2_vae_spatial_tile_overlap_in_pixels is not None:
228
+ vae_config.ltx2_spatial_tile_overlap_in_pixels = (self.ltx2_vae_spatial_tile_overlap_in_pixels)
229
+ if hasattr(vae_config,
230
+ "ltx2_temporal_tile_size_in_frames") and self.ltx2_vae_temporal_tile_size_in_frames is not None:
231
+ vae_config.ltx2_temporal_tile_size_in_frames = (self.ltx2_vae_temporal_tile_size_in_frames)
232
+ if hasattr(
233
+ vae_config,
234
+ "ltx2_temporal_tile_overlap_in_frames") and self.ltx2_vae_temporal_tile_overlap_in_frames is not None:
235
+ vae_config.ltx2_temporal_tile_overlap_in_frames = (self.ltx2_vae_temporal_tile_overlap_in_frames)
236
+
237
+ @staticmethod
238
+ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
239
+ # Model and path configuration
240
+ parser.add_argument(
241
+ "--model-path",
242
+ type=str,
243
+ help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
244
+ )
245
+
246
+ # Running mode
247
+ parser.add_argument(
248
+ "--mode",
249
+ type=str,
250
+ choices=ExecutionMode.choices(),
251
+ default=FastVideoArgs.mode.value,
252
+ help="The mode to run FastVideo",
253
+ )
254
+
255
+ # Workload type
256
+ parser.add_argument(
257
+ "--workload-type",
258
+ type=str,
259
+ choices=WorkloadType.choices(),
260
+ default=FastVideoArgs.workload_type.value,
261
+ help="The workload type",
262
+ )
263
+
264
+ # distributed_executor_backend
265
+ parser.add_argument(
266
+ "--distributed-executor-backend",
267
+ type=str,
268
+ choices=["mp"],
269
+ default=FastVideoArgs.distributed_executor_backend,
270
+ help="The distributed executor backend to use",
271
+ )
272
+
273
+ parser.add_argument(
274
+ "--inference-mode",
275
+ action=StoreBoolean,
276
+ default=FastVideoArgs.inference_mode,
277
+ help="Whether to use inference mode",
278
+ )
279
+
280
+ # HuggingFace specific parameters
281
+ parser.add_argument(
282
+ "--trust-remote-code",
283
+ action=StoreBoolean,
284
+ default=FastVideoArgs.trust_remote_code,
285
+ help="Trust remote code when loading HuggingFace models",
286
+ )
287
+ parser.add_argument(
288
+ "--revision",
289
+ type=str,
290
+ default=FastVideoArgs.revision,
291
+ help="The specific model version to use (can be a branch name, tag name, or commit id)",
292
+ )
293
+
294
+ # Parallelism
295
+ parser.add_argument(
296
+ "--num-gpus",
297
+ type=int,
298
+ default=FastVideoArgs.num_gpus,
299
+ help="The number of GPUs to use.",
300
+ )
301
+ parser.add_argument(
302
+ "--tp-size",
303
+ type=int,
304
+ default=FastVideoArgs.tp_size,
305
+ help="The tensor parallelism size.",
306
+ )
307
+ parser.add_argument(
308
+ "--sp-size",
309
+ type=int,
310
+ default=FastVideoArgs.sp_size,
311
+ help="The sequence parallelism size.",
312
+ )
313
+ parser.add_argument(
314
+ "--hsdp-replicate-dim",
315
+ type=int,
316
+ default=FastVideoArgs.hsdp_replicate_dim,
317
+ help="The data parallelism size.",
318
+ )
319
+ parser.add_argument(
320
+ "--hsdp-shard-dim",
321
+ type=int,
322
+ default=FastVideoArgs.hsdp_shard_dim,
323
+ help="The data parallelism shards.",
324
+ )
325
+ parser.add_argument(
326
+ "--dist-timeout",
327
+ type=int,
328
+ default=FastVideoArgs.dist_timeout,
329
+ help="Set timeout for torch.distributed initialization.",
330
+ )
331
+
332
+ # Output type
333
+ parser.add_argument(
334
+ "--output-type",
335
+ type=str,
336
+ default=FastVideoArgs.output_type,
337
+ choices=["pil"],
338
+ help="Output type for the generated video",
339
+ )
340
+
341
+ # Prompt text file for batch processing
342
+ parser.add_argument(
343
+ "--prompt-txt",
344
+ type=str,
345
+ default=FastVideoArgs.prompt_txt,
346
+ help="Path to a text file containing prompts (one per line) for batch processing",
347
+ )
348
+
349
+ # LTX-2 VAE tiling overrides
350
+ parser.add_argument(
351
+ "--ltx2-vae-tiling",
352
+ action=StoreBoolean,
353
+ default=FastVideoArgs.ltx2_vae_tiling,
354
+ help="Enable LTX-2 VAE tiling overrides.",
355
+ )
356
+ parser.add_argument(
357
+ "--ltx2-vae-spatial-tile-size-in-pixels",
358
+ type=int,
359
+ default=FastVideoArgs.ltx2_vae_spatial_tile_size_in_pixels,
360
+ help="LTX-2 VAE spatial tile size in pixels.",
361
+ )
362
+ parser.add_argument(
363
+ "--ltx2-vae-spatial-tile-overlap-in-pixels",
364
+ type=int,
365
+ default=FastVideoArgs.ltx2_vae_spatial_tile_overlap_in_pixels,
366
+ help="LTX-2 VAE spatial tile overlap in pixels.",
367
+ )
368
+ parser.add_argument(
369
+ "--ltx2-vae-temporal-tile-size-in-frames",
370
+ type=int,
371
+ default=FastVideoArgs.ltx2_vae_temporal_tile_size_in_frames,
372
+ help="LTX-2 VAE temporal tile size in frames.",
373
+ )
374
+ parser.add_argument(
375
+ "--ltx2-vae-temporal-tile-overlap-in-frames",
376
+ type=int,
377
+ default=FastVideoArgs.ltx2_vae_temporal_tile_overlap_in_frames,
378
+ help="LTX-2 VAE temporal tile overlap in frames.",
379
+ )
380
+ parser.add_argument(
381
+ "--ltx2-initial-latent-path",
382
+ type=str,
383
+ default=FastVideoArgs.ltx2_initial_latent_path,
384
+ help="Path to load/save a precomputed LTX-2 initial latent.",
385
+ )
386
+
387
+ # LoRA parameters (inference-time adapter loading)
388
+ parser.add_argument(
389
+ "--lora-path",
390
+ type=str,
391
+ default=FastVideoArgs.lora_path,
392
+ help="Path to a LoRA adapter (directory or HF repo id). If set, LoRA will be applied at inference.",
393
+ )
394
+ parser.add_argument(
395
+ "--lora-nickname",
396
+ type=str,
397
+ default=FastVideoArgs.lora_nickname,
398
+ help="Nickname to refer to the loaded LoRA adapter (useful for swapping).",
399
+ )
400
+ parser.add_argument(
401
+ "--lora-target-modules",
402
+ nargs="+",
403
+ type=str,
404
+ default=FastVideoArgs.lora_target_modules,
405
+ help="Optional list of module name substrings to restrict LoRA injection (e.g. q_proj k_proj v_proj).",
406
+ )
407
+
408
+ # BSA runtime control (LongCat)
409
+ parser.add_argument(
410
+ "--enable-bsa",
411
+ action=StoreBoolean,
412
+ help="Enable Block Sparse Attention (BSA) at runtime (overrides config).",
413
+ )
414
+ parser.add_argument(
415
+ "--bsa-sparsity",
416
+ type=float,
417
+ help="BSA sparsity (e.g., 0.9375).",
418
+ )
419
+ parser.add_argument(
420
+ "--bsa-cdf-threshold",
421
+ type=float,
422
+ help="BSA CDF threshold (optional).",
423
+ )
424
+ parser.add_argument(
425
+ "--bsa-chunk-q",
426
+ nargs=3,
427
+ type=int,
428
+ metavar=("T", "H", "W"),
429
+ help="BSA chunk_3d_shape_q as three ints, e.g., 4 4 4.",
430
+ )
431
+ parser.add_argument(
432
+ "--bsa-chunk-k",
433
+ nargs=3,
434
+ type=int,
435
+ metavar=("T", "H", "W"),
436
+ help="BSA chunk_3d_shape_k as three ints, e.g., 4 4 4.",
437
+ )
438
+
439
+ parser.add_argument(
440
+ "--enable-torch-compile",
441
+ action=StoreBoolean,
442
+ default=FastVideoArgs.enable_torch_compile,
443
+ help="Use torch.compile to speed up DiT inference." +
444
+ "However, will likely cause precision drifts. See (https://github.com/pytorch/pytorch/issues/145213)",
445
+ )
446
+ parser.add_argument(
447
+ "--torch-compile-kwargs",
448
+ type=str,
449
+ default=None,
450
+ help=
451
+ "JSON string of kwargs to pass to torch.compile. Example: '{\"backend\":\"inductor\",\"mode\":\"reduce-overhead\"}'",
452
+ )
453
+
454
+ parser.add_argument(
455
+ "--dit-cpu-offload",
456
+ action=StoreBoolean,
457
+ help="Use CPU offload for DiT inference. Enable if run out of memory with FSDP.",
458
+ )
459
+ parser.add_argument(
460
+ "--dit-layerwise-offload",
461
+ action=StoreBoolean,
462
+ help="Enable layerwise CPU offload with async H2D prefetch overlap.",
463
+ )
464
+ parser.add_argument(
465
+ "--use-fsdp-inference",
466
+ action=StoreBoolean,
467
+ help=
468
+ "Use FSDP for inference by sharding the model weights. Latency is very low due to prefetch--enable if run out of memory.",
469
+ )
470
+ parser.add_argument(
471
+ "--text-encoder-cpu-offload",
472
+ action=StoreBoolean,
473
+ help="Use CPU offload for text encoder. Enable if run out of memory.",
474
+ )
475
+ parser.add_argument(
476
+ "--image-encoder-cpu-offload",
477
+ action=StoreBoolean,
478
+ help="Use CPU offload for image encoder. Enable if run out of memory.",
479
+ )
480
+ parser.add_argument(
481
+ "--vae-cpu-offload",
482
+ action=StoreBoolean,
483
+ help="Use CPU offload for VAE. Enable if run out of memory.",
484
+ )
485
+ parser.add_argument(
486
+ "--pin-cpu-memory",
487
+ action=StoreBoolean,
488
+ help=
489
+ "Pin memory for CPU offload. Only added as a temp workaround if it throws \"CUDA error: invalid argument\". "
490
+ "Should be enabled in almost all cases",
491
+ )
492
+ parser.add_argument(
493
+ "--disable-autocast",
494
+ action=StoreBoolean,
495
+ help="Disable autocast for denoising loop and vae decoding in pipeline sampling",
496
+ )
497
+
498
+ # VSA parameters
499
+ parser.add_argument(
500
+ "--VSA-sparsity",
501
+ type=float,
502
+ default=FastVideoArgs.VSA_sparsity,
503
+ help="Validation sparsity for VSA",
504
+ )
505
+
506
+ # Master port for distributed training/inference
507
+ parser.add_argument(
508
+ "--master-port",
509
+ type=int,
510
+ default=FastVideoArgs.master_port,
511
+ help="Master port for distributed training/inference",
512
+ )
513
+
514
+ # Stage verification
515
+ parser.add_argument(
516
+ "--enable-stage-verification",
517
+ action=StoreBoolean,
518
+ default=FastVideoArgs.enable_stage_verification,
519
+ help="Enable input/output verification for pipeline stages",
520
+ )
521
+ parser.add_argument(
522
+ "--override-text-encoder-safetensors",
523
+ type=str,
524
+ default=FastVideoArgs.override_text_encoder_safetensors,
525
+ help="Path to safetensors file for text encoder override",
526
+ )
527
+ parser.add_argument(
528
+ "--override-text-encoder-quant",
529
+ type=str,
530
+ choices=QUANTIZATION_METHODS,
531
+ default=FastVideoArgs.override_text_encoder_quant,
532
+ help="Quantization method for text encoder override",
533
+ )
534
+ parser.add_argument(
535
+ "--transformer-quant",
536
+ type=str,
537
+ choices=QUANTIZATION_METHODS,
538
+ default=FastVideoArgs.transformer_quant,
539
+ help="Quantization method for transformer loading",
540
+ )
541
+ parser.add_argument(
542
+ "--override-transformer-cls-name",
543
+ type=str,
544
+ default=FastVideoArgs.override_transformer_cls_name,
545
+ help="Override transformer cls name",
546
+ )
547
+ parser.add_argument(
548
+ "--override-pipeline-cls-name",
549
+ type=str,
550
+ default=FastVideoArgs.override_pipeline_cls_name,
551
+ help="Override pipeline cls name",
552
+ )
553
+ parser.add_argument("--init-weights-from-safetensors",
554
+ type=str,
555
+ help="Path to safetensors file for initial weight loading")
556
+ parser.add_argument("--init-weights-from-safetensors-2",
557
+ type=str,
558
+ help="Path to safetensors file for initial weight loading")
559
+
560
+ # Add pipeline configuration arguments
561
+ PipelineConfig.add_cli_args(parser)
562
+
563
+ # Add preprocessing configuration arguments
564
+ PreprocessConfig.add_cli_args(parser)
565
+
566
+ return parser
567
+
568
+ @classmethod
569
+ def from_cli_args(cls, args: argparse.Namespace) -> "FastVideoArgs":
570
+ provided_args = clean_cli_args(args)
571
+ # Get all fields from the dataclass
572
+ attrs = [attr.name for attr in dataclasses.fields(cls)]
573
+
574
+ # Create a dictionary of attribute values, with defaults for missing attributes
575
+ kwargs: dict[str, Any] = {}
576
+ for attr in attrs:
577
+ if attr == 'pipeline_config':
578
+ pipeline_config = PipelineConfig.from_kwargs(provided_args)
579
+ kwargs['pipeline_config'] = pipeline_config
580
+ elif attr == 'preprocess_config':
581
+ preprocess_config = PreprocessConfig.from_kwargs(provided_args)
582
+ kwargs['preprocess_config'] = preprocess_config
583
+ elif attr == 'mode':
584
+ # Convert string to ExecutionMode enum
585
+ mode_value = getattr(args, attr, FastVideoArgs.mode.value)
586
+ kwargs['mode'] = ExecutionMode.from_string(mode_value) if isinstance(mode_value, str) else mode_value
587
+ elif attr == 'torch_compile_kwargs':
588
+ # Parse JSON string for torch.compile kwargs
589
+ torch_compile_kwargs_str = getattr(args, 'torch_compile_kwargs', None)
590
+ if torch_compile_kwargs_str:
591
+ try:
592
+ import json
593
+ kwargs['torch_compile_kwargs'] = json.loads(torch_compile_kwargs_str)
594
+ except json.JSONDecodeError as e:
595
+ raise ValueError(f"Invalid JSON for torch_compile_kwargs: {e}") from e
596
+ else:
597
+ kwargs['torch_compile_kwargs'] = {}
598
+ elif attr == 'workload_type':
599
+ # Convert string to WorkloadType enum
600
+ workload_type_value = getattr(args, 'workload_type', FastVideoArgs.workload_type.value)
601
+ kwargs['workload_type'] = WorkloadType.from_string(workload_type_value) if isinstance(
602
+ workload_type_value, str) else workload_type_value
603
+ # Use getattr with default value from the dataclass for potentially missing attributes
604
+ else:
605
+ # Get the field to check if it has a default_factory
606
+ field = dataclasses.fields(cls)[next(i for i, f in enumerate(dataclasses.fields(cls))
607
+ if f.name == attr)]
608
+ if field.default_factory is not dataclasses.MISSING:
609
+ # Use the default_factory to create the default value
610
+ default_value = field.default_factory()
611
+ else:
612
+ default_value = getattr(cls, attr, None)
613
+ value = getattr(args, attr, default_value)
614
+ kwargs[attr] = value # type: ignore
615
+
616
+ return cls(**kwargs) # type: ignore
617
+
618
+ @classmethod
619
+ def from_kwargs(cls, **kwargs: Any) -> "FastVideoArgs":
620
+ # Convert mode string to enum if necessary
621
+ if 'mode' in kwargs and isinstance(kwargs['mode'], str):
622
+ kwargs['mode'] = ExecutionMode.from_string(kwargs['mode'])
623
+
624
+ # Convert workload_type string to enum if necessary
625
+ if 'workload_type' in kwargs and isinstance(kwargs['workload_type'], str):
626
+ kwargs['workload_type'] = WorkloadType.from_string(kwargs['workload_type'])
627
+
628
+ kwargs['pipeline_config'] = PipelineConfig.from_kwargs(kwargs)
629
+ kwargs['preprocess_config'] = PreprocessConfig.from_kwargs(kwargs)
630
+ # Filter to only FastVideoArgs dataclass fields — pipeline-specific CLI
631
+ # args (e.g. enable_bsa, bsa_sparsity) live in PipelineConfig and must
632
+ # not be forwarded to the FastVideoArgs constructor.
633
+ valid_fields = {f.name for f in dataclasses.fields(cls)}
634
+ return cls(**{k: v for k, v in kwargs.items() if k in valid_fields})
635
+
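+ # Usage sketch (illustrative values). Keys that are not FastVideoArgs
+ # fields are dropped by the filter above, so pipeline-level options can
+ # be passed in the same call:
+ #
+ #     args = FastVideoArgs.from_kwargs(
+ #         model_path="checkpoints/hf_download/sparse_quant_exp",
+ #         mode="inference", num_gpus=1)
+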
636
+ def check_fastvideo_args(self) -> None:
637
+ """Validate inference arguments for consistency"""
638
+ from fastvideo.platforms import current_platform
639
+
640
+ if current_platform.is_mps():
641
+ self.use_fsdp_inference = False
642
+ self.dit_layerwise_offload = False
643
+
644
+ if self.dit_layerwise_offload:
645
+ if self.use_fsdp_inference:
646
+ logger.warning("dit_layerwise_offload is enabled, automatically disabling use_fsdp_inference.")
647
+ self.use_fsdp_inference = False
648
+ if self.dit_cpu_offload:
649
+ logger.warning("dit_layerwise_offload is enabled, automatically disabling dit_cpu_offload.")
650
+ self.dit_cpu_offload = False
651
+
652
+ # Validate mode and inference_mode consistency
653
+ assert isinstance(self.mode, ExecutionMode), f"Mode must be an ExecutionMode enum, got {type(self.mode)}"
654
+ assert self.mode in ExecutionMode.choices(), f"Invalid execution mode: {self.mode}"
655
+
656
+ # Validate workload type
657
+ assert isinstance(self.workload_type,
658
+ WorkloadType), f"Workload type must be a WorkloadType enum, got {type(self.workload_type)}"
659
+ assert self.workload_type in WorkloadType.choices(), f"Invalid workload type: {self.workload_type}"
660
+
661
+ if self.mode in [ExecutionMode.DISTILLATION, ExecutionMode.FINETUNING] and self.inference_mode:
662
+ logger.warning("Mode is 'training' but inference_mode is True. Setting inference_mode to False.")
663
+ self.inference_mode = False
664
+ elif self.mode in [ExecutionMode.INFERENCE, ExecutionMode.PREPROCESS] and not self.inference_mode:
665
+ logger.warning("Mode is '%s' but inference_mode is False. Setting inference_mode to True.", self.mode)
666
+ self.inference_mode = True
667
+
668
+ if not self.inference_mode:
669
+ assert self.hsdp_replicate_dim != -1, "hsdp_replicate_dim must be set for training"
670
+ assert self.hsdp_shard_dim != -1, "hsdp_shard_dim must be set for training"
671
+ assert self.sp_size != -1, "sp_size must be set for training"
672
+
673
+ if self.tp_size == -1:
674
+ self.tp_size = 1
675
+ if self.sp_size == -1:
676
+ self.sp_size = self.num_gpus
677
+ if self.hsdp_shard_dim == -1:
678
+ self.hsdp_shard_dim = self.num_gpus
679
+
680
+ assert self.sp_size <= self.num_gpus and self.num_gpus % self.sp_size == 0, "num_gpus must be >= sp_size and divisible by sp_size"
681
+ assert self.hsdp_replicate_dim <= self.num_gpus and self.num_gpus % self.hsdp_replicate_dim == 0, "num_gpus must be >= hsdp_replicate_dim and divisible by hsdp_replicate_dim"
682
+ assert self.hsdp_shard_dim <= self.num_gpus and self.num_gpus % self.hsdp_shard_dim == 0, "num_gpus must be >= hsdp_shard_dim and divisible by hsdp_shard_dim"
683
+
684
+ if self.num_gpus < max(self.tp_size, self.sp_size):
685
+ self.num_gpus = max(self.tp_size, self.sp_size)
686
+
687
+ if self.pipeline_config is None:
688
+ raise ValueError("pipeline_config is not set in FastVideoArgs")
689
+
690
+ self.pipeline_config.check_pipeline_config()
691
+
692
+ # Add preprocessing config validation if needed
693
+ if self.mode == ExecutionMode.PREPROCESS:
694
+ if self.preprocess_config is None:
695
+ raise ValueError("preprocess_config is not set in FastVideoArgs when mode is PREPROCESS")
696
+ if self.preprocess_config.model_path == "":
697
+ self.preprocess_config.model_path = self.model_path
698
+ if not self.pipeline_config.vae_config.load_encoder:
699
+ self.pipeline_config.vae_config.load_encoder = True
700
+ self.preprocess_config.check_preprocess_config()
701
+
702
+
703
+ _current_fastvideo_args = None
704
+
705
+
706
+ def prepare_fastvideo_args(argv: list[str]) -> FastVideoArgs:
707
+ """
708
+ Prepare the inference arguments from the command line arguments.
709
+
710
+ Args:
711
+ argv: The command line arguments. Typically, it should be `sys.argv[1:]`
712
+ to ensure compatibility with `parse_args` when no arguments are passed.
713
+
714
+ Returns:
715
+ The inference arguments.
716
+ """
717
+ parser = FlexibleArgumentParser()
718
+ FastVideoArgs.add_cli_args(parser)
719
+ raw_args = parser.parse_args(argv)
720
+ fastvideo_args = FastVideoArgs.from_cli_args(raw_args)
721
+ global _current_fastvideo_args
722
+ _current_fastvideo_args = fastvideo_args
723
+ return fastvideo_args
724
+
725
+
726
+ @contextmanager
727
+ def set_current_fastvideo_args(fastvideo_args: FastVideoArgs):
728
+ """
729
+ Temporarily set the current fastvideo config.
730
+ Used during model initialization.
731
+ We save the current fastvideo config in a global variable,
732
+ so that all modules can access it, e.g. custom ops
733
+ can access the fastvideo config to determine how to dispatch.
734
+ """
735
+ global _current_fastvideo_args
736
+ old_fastvideo_args = _current_fastvideo_args
737
+ try:
738
+ _current_fastvideo_args = fastvideo_args
739
+ yield
740
+ finally:
741
+ _current_fastvideo_args = old_fastvideo_args
742
+
743
+
744
+ def get_current_fastvideo_args() -> FastVideoArgs:
745
+ if _current_fastvideo_args is None:
746
+ # In CI, when custom ops/modules are tested directly, the fastvideo
747
+ # config is often not set. For now we raise instead of silently
748
+ # falling back to a default config.
749
+ # TODO(will): may need to handle this for CI.
750
+ raise ValueError("Current fastvideo args is not set.")
751
+ return _current_fastvideo_args
752
+
753
+
754
+ @dataclasses.dataclass
755
+ class TrainingArgs(FastVideoArgs):
756
+ """
757
+ Training arguments. Inherits from FastVideoArgs and adds training-specific
758
+ arguments. If there are any conflicts, the training arguments will take
759
+ precedence.
760
+ """
761
+ data_path: str = ""
762
+ dataloader_num_workers: int = 0
763
+ num_height: int = 0
764
+ num_width: int = 0
765
+ num_frames: int = 0
766
+
767
+ train_batch_size: int = 0
768
+ num_latent_t: int = 0
769
+ group_frame: bool = False
770
+ group_resolution: bool = False
771
+
772
+ # text encoder & vae & diffusion model
773
+ pretrained_model_name_or_path: str = ""
774
+
775
+ # DMD model paths - separate paths for each network
776
+ real_score_model_path: str = "" # path for real score (teacher) model
777
+ fake_score_model_path: str = "" # path for fake score (critic) model
778
+
779
+ # diffusion setting
780
+ ema_decay: float = 0.0
781
+ ema_start_step: int = 0
782
+ training_cfg_rate: float = 0.0
783
+ precondition_outputs: bool = False
784
+
785
+ # validation & logs
786
+ validation_dataset_file: str = ""
787
+ validation_preprocessed_path: str = ""
788
+ validation_sampling_steps: str = ""
789
+ validation_guidance_scale: str = ""
790
+ validation_steps: float = 0.0
791
+ log_validation: bool = False
792
+ trackers: list[str] = dataclasses.field(default_factory=list)
793
+ tracker_project_name: str = ""
794
+ wandb_run_name: str = ""
795
+ seed: int = 0
796
+ _loading_teacher_critic_model: bool = False
797
+
798
+ # output
799
+ output_dir: str = ""
800
+ checkpoints_total_limit: int = 0
801
+ resume_from_checkpoint: str = "" # specify the checkpoint folder to resume from
802
+
803
+ # optimizer & scheduler
804
+ num_train_epochs: int = 0
805
+ max_train_steps: int = 0
806
+ gradient_accumulation_steps: int = 0
807
+ learning_rate: float = 0.0
808
+ scale_lr: bool = False
809
+ lr_scheduler: str = "constant"
810
+ lr_warmup_steps: int = 0
811
+ max_grad_norm: float = 0.0
812
+ enable_gradient_checkpointing_type: str | None = None
813
+ selective_checkpointing: float = 0.0
814
+ mixed_precision: str = ""
815
+ train_sp_batch_size: int = 0
816
+ fsdp_sharding_strategy: str = ""
817
+
818
+ weighting_scheme: str = ""
819
+ logit_mean: float = 0.0
820
+ logit_std: float = 1.0
821
+ mode_scale: float = 0.0
822
+
823
+ num_euler_timesteps: int = 0
824
+ lr_num_cycles: int = 0
825
+ lr_power: float = 0.0
826
+ min_lr_ratio: float = 0.5 # minimum learning rate ratio for cosine_with_min_lr scheduler
827
+ not_apply_cfg_solver: bool = False
828
+ distill_cfg: float = 0.0
829
+ scheduler_type: str = ""
830
+ linear_quadratic_threshold: float = 0.0
831
+ linear_range: float = 0.0
832
+ weight_decay: float = 0.0
833
+ betas: str = "0.9,0.999" # betas for optimizer, format: "beta1,beta2"
834
+ use_ema: bool = False
835
+ multi_phased_distill_schedule: str = ""
836
+ pred_decay_weight: float = 0.0
837
+ pred_decay_type: str = ""
838
+ hunyuan_teacher_disable_cfg: bool = False
839
+
840
+ # master_weight_type
841
+ master_weight_type: str = ""
842
+
843
+ # VSA training decay parameters
844
+ VSA_decay_rate: float = 0.01 # decay rate -> 0.02
845
+ VSA_decay_interval_steps: int = 1 # decay interval steps -> 50
846
+ VSA_init_sparsity: float = 0.0 # initial sparsity (default 0, ramp from 0)
847
+ VSA_warmup_steps: int = 0 # keep init_sparsity for this many steps before ramping
848
+
849
+ # LoRA training parameters
850
+ lora_rank: int | None = None
851
+ lora_alpha: int | None = None
852
+ lora_training: bool = False
853
+ ltx2_first_frame_conditioning_p: float = 0.1
854
+
855
+ # distillation args
856
+ generator_update_interval: int = 5
857
+ dfake_gen_update_ratio: int = 5 # self-forcing: how often to train generator vs critic
858
+ min_timestep_ratio: float = 0.2
859
+ max_timestep_ratio: float = 0.98
860
+ real_score_guidance_scale: float = 3.5
861
+ fake_score_learning_rate: float = 0.0 # separate learning rate for fake_score_transformer, if 0.0, use learning_rate
862
+ fake_score_lr_scheduler: str = "constant" # separate lr scheduler for fake_score_transformer, if not set, use lr_scheduler
863
+ fake_score_betas: str = "0.9,0.999" # betas for fake score optimizer, format: "beta1,beta2"
864
+ training_state_checkpointing_steps: int = 0 # for resuming training
865
+ weight_only_checkpointing_steps: int = 0 # for inference
866
+ log_visualization: bool = False
867
+ visualization_steps: int = 0
868
+ # simulate generator forward to match inference
869
+ simulate_generator_forward: bool = False
870
+ warp_denoising_step: bool = False
871
+ generator_4bit_attn: bool = False
872
+ generator_4bit_linear: bool = False
873
+
874
+ # Self-forcing specific arguments
875
+ num_frame_per_block: int = 3
876
+ independent_first_frame: bool = False
877
+ enable_gradient_masking: bool = True
878
+ gradient_mask_last_n_frames: int = 21
879
+ same_step_across_blocks: bool = False # Use same exit timestep for all blocks
880
+ last_step_only: bool = False # Only use the last timestep for training
881
+ context_noise: int = 0 # Context noise level for cache updates
882
+
883
+ @classmethod
884
+ def from_cli_args(cls, args: argparse.Namespace) -> "TrainingArgs":
885
+ provided_args = clean_cli_args(args)
886
+ # Get all fields from the dataclass
887
+ attrs = [attr.name for attr in dataclasses.fields(cls)]
888
+ logger.info(provided_args)
889
+ # Create a dictionary of attribute values, with defaults for missing attributes
890
+ kwargs: dict[str, Any] = {}
891
+ for attr in attrs:
892
+ if attr == 'pipeline_config':
893
+ pipeline_config = PipelineConfig.from_kwargs(provided_args)
894
+ kwargs[attr] = pipeline_config
895
+ elif attr == 'mode':
896
+ # Convert string to ExecutionMode enum
897
+ mode_value = getattr(args, attr, ExecutionMode.FINETUNING.value)
898
+ kwargs[attr] = ExecutionMode.from_string(mode_value) if isinstance(mode_value, str) else mode_value
899
+ elif attr == 'workload_type':
900
+ # Convert string to WorkloadType enum
901
+ workload_type_value = getattr(args, 'workload_type', WorkloadType.T2V.value)
902
+ kwargs[attr] = WorkloadType.from_string(workload_type_value) if isinstance(workload_type_value,
903
+ str) else workload_type_value
904
+ # Use getattr with default value from the dataclass for potentially missing attributes
905
+ else:
906
+ # Get the field to check its default value
907
+ field = dataclasses.fields(cls)[next(i for i, f in enumerate(dataclasses.fields(cls))
908
+ if f.name == attr)]
909
+
910
+ # Check if the attribute is provided in args
911
+ if hasattr(args, attr):
912
+ value = getattr(args, attr)
913
+ else:
914
+ # Use the field's default value
915
+ if field.default_factory is not dataclasses.MISSING:
916
+ value = field.default_factory()
917
+ elif field.default is not dataclasses.MISSING:
918
+ value = field.default
919
+ else:
920
+ # No default value, use None
921
+ value = None
922
+
923
+ kwargs[attr] = value
924
+
925
+ return cls(**kwargs) # type: ignore
926
+
927
+ @staticmethod
928
+ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
929
+ parser.add_argument("--data-path", type=str, required=True, help="Path to parquet files")
930
+ parser.add_argument("--dataloader-num-workers",
931
+ type=int,
932
+ required=True,
933
+ help="Number of workers for dataloader")
934
+ parser.add_argument("--num-height", type=int, required=True, help="Number of heights")
935
+ parser.add_argument("--num-width", type=int, required=True, help="Number of widths")
936
+ parser.add_argument("--num-frames", type=int, required=True, help="Number of frames")
937
+
938
+ # Training batch and model configuration
939
+ parser.add_argument("--train-batch-size", type=int, required=True, help="Training batch size")
940
+ parser.add_argument("--num-latent-t", type=int, required=True, help="Number of latent time steps")
941
+ parser.add_argument("--group-frame", action=StoreBoolean, help="Whether to group frames during training")
942
+ parser.add_argument("--group-resolution",
943
+ action=StoreBoolean,
944
+ help="Whether to group resolutions during training")
945
+
946
+ # Model paths
947
+ parser.add_argument("--pretrained-model-name-or-path",
948
+ type=str,
949
+ required=True,
950
+ help="Path to pretrained model or model name")
951
+ parser.add_argument("--dit-model-name-or-path",
952
+ type=str,
953
+ required=False,
954
+ help="Path to DiT model or model name")
955
+ parser.add_argument("--cache-dir", type=str, help="Directory to cache models")
956
+
957
+ # DMD model paths - separate paths for each network
958
+ parser.add_argument("--generator-model-path",
959
+ type=str,
960
+ help="Path to generator (student) model for DMD distillation")
961
+ parser.add_argument("--real-score-model-path",
962
+ type=str,
963
+ help="Path to real score (teacher) model for DMD distillation")
964
+ parser.add_argument("--fake-score-model-path",
965
+ type=str,
966
+ help="Path to fake score (critic) model for DMD distillation")
967
+
968
+ # Diffusion settings
969
+ parser.add_argument("--ema-decay", type=float, default=0.999, help="EMA decay rate")
970
+ parser.add_argument("--ema-start-step", type=int, default=0, help="Step to start EMA")
971
+ parser.add_argument("--training-cfg-rate", type=float, help="Classifier-free guidance scale")
972
+ parser.add_argument("--precondition-outputs",
973
+ action=StoreBoolean,
974
+ help="Whether to precondition the outputs of the model")
975
+
976
+ # Validation and logging
977
+ parser.add_argument("--validation-dataset-file", type=str, help="Path to unprocessed validation dataset")
978
+ parser.add_argument("--validation-preprocessed-path", type=str, help="Path to processed validation dataset")
979
+ parser.add_argument("--validation-sampling-steps", type=str, help="Validation sampling steps")
980
+ parser.add_argument("--validation-guidance-scale", type=str, help="Validation guidance scale")
981
+ parser.add_argument("--validation-steps", type=float, help="Number of validation steps")
982
+ parser.add_argument("--log-validation", action=StoreBoolean, help="Whether to log validation results")
983
+ parser.add_argument("--visualization-steps", type=int, help="Number of visualization steps")
984
+ parser.add_argument("--tracker-project-name", type=str, help="Project name for tracking")
985
+ parser.add_argument("--wandb-run-name", type=str, help="Run name for wandb")
986
+ parser.add_argument("--seed", type=int, default=42, help="Seed for deterministic training")
987
+
988
+ # Output configuration
989
+ parser.add_argument("--output-dir", type=str, required=True, help="Output directory for checkpoints and logs")
990
+ parser.add_argument("--checkpoints-total-limit", type=int, help="Maximum number of checkpoints to keep")
991
+ parser.add_argument("--training-state-checkpointing-steps",
992
+ type=int,
993
+ help="Steps between training state checkpoints (for resuming training)")
994
+ parser.add_argument("--weight-only-checkpointing-steps",
995
+ type=int,
996
+ help="Steps between weight-only checkpoints (for inference)")
997
+ parser.add_argument("--resume-from-checkpoint", type=str, help="Path to checkpoint to resume from")
998
+ parser.add_argument("--logging-dir", type=str, help="Directory for logging")
999
+
1000
+ # Training configuration
1001
+ parser.add_argument("--num-train-epochs", type=int, help="Number of training epochs")
1002
+ parser.add_argument("--max-train-steps", type=int, help="Maximum number of training steps")
1003
+ parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of steps to accumulate gradients")
1004
+ parser.add_argument("--learning-rate", type=float, required=True, help="Learning rate")
1005
+ parser.add_argument("--scale-lr", action=StoreBoolean, help="Whether to scale learning rate")
1006
+ parser.add_argument("--lr-scheduler", type=str, default="constant", help="Learning rate scheduler type")
1007
+ parser.add_argument("--lr-warmup-steps", type=int, default=10, help="Number of warmup steps for learning rate")
1008
+ parser.add_argument("--max-grad-norm", type=float, help="Maximum gradient norm")
1009
+ parser.add_argument("--enable-gradient-checkpointing-type",
1010
+ type=str,
1011
+ choices=["full", "ops", "block_skip"],
1012
+ default=None,
1013
+ help="Gradient checkpointing type")
1014
+ parser.add_argument("--selective-checkpointing", type=float, help="Selective checkpointing threshold")
1015
+ parser.add_argument("--mixed-precision", type=str, help="Mixed precision training type")
1016
+ parser.add_argument("--train-sp-batch-size", type=int, help="Training spatial parallelism batch size")
1017
+
1018
+ parser.add_argument("--fsdp-sharding-strategy", type=str, help="FSDP sharding strategy")
1019
+
1020
+ parser.add_argument(
1021
+ "--weighting_scheme",
1022
+ type=str,
1023
+ default="uniform",
1024
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "uniform"],
1025
+ )
1026
+ parser.add_argument(
1027
+ "--logit_mean",
1028
+ type=float,
1029
+ default=0.0,
1030
+ help="mean to use when using the `'logit_normal'` weighting scheme.",
1031
+ )
1032
+ parser.add_argument(
1033
+ "--logit_std",
1034
+ type=float,
1035
+ default=1.0,
1036
+ help="std to use when using the `'logit_normal'` weighting scheme.",
1037
+ )
1038
+ parser.add_argument(
1039
+ "--mode_scale",
1040
+ type=float,
1041
+ default=1.29,
1042
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
1043
+ )
1044
+
1045
+ # Additional training parameters
1046
+ parser.add_argument("--num-euler-timesteps", type=int, help="Number of Euler timesteps")
1047
+ parser.add_argument("--lr-num-cycles", type=int, help="Number of learning rate cycles")
1048
+ parser.add_argument("--lr-power", type=float, help="Learning rate power")
1049
+ parser.add_argument("--min-lr-ratio",
1050
+ type=float,
1051
+ default=TrainingArgs.min_lr_ratio,
1052
+ help="Minimum learning rate ratio for cosine_with_min_lr scheduler")
1053
+ parser.add_argument("--not-apply-cfg-solver", action=StoreBoolean, help="Whether to not apply CFG solver")
1054
+ parser.add_argument("--distill-cfg", type=float, help="Distillation CFG scale")
1055
+ parser.add_argument("--scheduler-type", type=str, help="Scheduler type")
1056
+ parser.add_argument("--linear-quadratic-threshold", type=float, help="Linear quadratic threshold")
1057
+ parser.add_argument("--linear-range", type=float, help="Linear range")
1058
+ parser.add_argument("--weight-decay", type=float, help="Weight decay")
1059
+ parser.add_argument("--betas",
1060
+ type=str,
1061
+ default=TrainingArgs.betas,
1062
+ help="Betas for optimizer (format: 'beta1,beta2')")
1063
+ parser.add_argument("--use-ema", action=StoreBoolean, help="Whether to use EMA")
1064
+ parser.add_argument("--multi-phased-distill-schedule", type=str, help="Multi-phased distillation schedule")
1065
+ parser.add_argument("--pred-decay-weight", type=float, help="Prediction decay weight")
1066
+ parser.add_argument("--pred-decay-type", type=str, help="Prediction decay type")
1067
+ parser.add_argument("--hunyuan-teacher-disable-cfg",
1068
+ action=StoreBoolean,
1069
+ help="Whether to disable CFG for Hunyuan teacher")
1070
+ parser.add_argument("--master-weight-type", type=str, help="Master weight type")
1071
+
1072
+ # VSA parameters for training with dense to sparse adaption
1073
+ parser.add_argument(
1074
+ "--VSA-decay-rate", # decay rate, how much sparsity you want to decay each step
1075
+ type=float,
1076
+ default=TrainingArgs.VSA_decay_rate,
1077
+ help="VSA decay rate")
1078
+ parser.add_argument(
1079
+ "--VSA-decay-interval-steps", # how many steps for training with current sparsity
1080
+ type=int,
1081
+ default=TrainingArgs.VSA_decay_interval_steps,
1082
+ help="VSA decay interval steps")
1083
+ parser.add_argument(
1084
+ "--VSA-init-sparsity",
1085
+ type=float,
1086
+ default=TrainingArgs.VSA_init_sparsity,
1087
+ help="Initial sparsity to start from (default 0)")
1088
+ parser.add_argument(
1089
+ "--VSA-warmup-steps",
1090
+ type=int,
1091
+ default=TrainingArgs.VSA_warmup_steps,
1092
+ help="Keep init sparsity for N steps before ramping (default 0)")
1093
+ parser.add_argument("--lora-training", action=StoreBoolean, help="Whether to use LoRA training")
1094
+ parser.add_argument("--lora-rank", type=int, help="LoRA rank")
1095
+ parser.add_argument("--lora-alpha", type=int, help="LoRA alpha")
1096
+ parser.add_argument(
1097
+ "--ltx2-first-frame-conditioning-p",
1098
+ type=float,
1099
+ default=TrainingArgs.ltx2_first_frame_conditioning_p,
1100
+ help="Probability of conditioning on the first frame during LTX-2 training",
1101
+ )
1102
+
1103
+ # V-MoBA parameters
1104
+ parser.add_argument(
1105
+ "--moba-config-path",
1106
+ type=str,
1107
+ default=None,
1108
+ help="Path to a JSON file containing V-MoBA specific configurations.",
1109
+ )
1110
+
1111
+ # Distillation arguments
1112
+ parser.add_argument("--generator-update-interval",
1113
+ type=int,
1114
+ default=TrainingArgs.generator_update_interval,
1115
+ help="Ratio of student updates to critic updates.")
1116
+ parser.add_argument(
1117
+ "--dfake-gen-update-ratio",
1118
+ type=int,
1119
+ default=TrainingArgs.dfake_gen_update_ratio,
1120
+ help="Self-forcing: How often to train generator vs critic (train generator every N steps).")
1121
+ parser.add_argument("--min-timestep-ratio",
1122
+ type=float,
1123
+ default=TrainingArgs.min_timestep_ratio,
1124
+ help="Minimum step ratio")
1125
+ parser.add_argument("--max-timestep-ratio",
1126
+ type=float,
1127
+ default=TrainingArgs.max_timestep_ratio,
1128
+ help="Maximum step ratio")
1129
+ parser.add_argument("--real-score-guidance-scale",
1130
+ type=float,
1131
+ default=TrainingArgs.real_score_guidance_scale,
1132
+ help="Teacher guidance scale")
1133
+ parser.add_argument("--fake-score-learning-rate",
1134
+ type=float,
1135
+ default=TrainingArgs.fake_score_learning_rate,
1136
+ help="Learning rate for fake score transformer")
1137
+ parser.add_argument("--fake-score-betas",
1138
+ type=str,
1139
+ default=TrainingArgs.fake_score_betas,
1140
+ help="Betas for fake score optimizer (format: 'beta1,beta2')")
1141
+ parser.add_argument("--fake-score-lr-scheduler",
1142
+ type=str,
1143
+ default=TrainingArgs.fake_score_lr_scheduler,
1144
+ help="Learning rate scheduler for fake score transformer")
1145
+ parser.add_argument("--log-visualization", action=StoreBoolean, help="Whether to log visualization")
1146
+ parser.add_argument("--simulate-generator-forward",
1147
+ action=StoreBoolean,
1148
+ help="Whether to simulate generator forward to match inference")
1149
+ parser.add_argument("--warp-denoising-step",
1150
+ action=StoreBoolean,
1151
+ help="Whether to warp denoising step according to the scheduler time shift")
1152
+
1153
+ # Self-forcing specific arguments
1154
+ parser.add_argument("--num-frame-per-block",
1155
+ type=int,
1156
+ default=TrainingArgs.num_frame_per_block,
1157
+ help="Number of frames per block for causal generation")
1158
+ parser.add_argument("--independent-first-frame",
1159
+ action=StoreBoolean,
1160
+ help="Whether the first frame is independent in causal generation")
1161
+ parser.add_argument("--enable-gradient-masking",
1162
+ action=StoreBoolean,
1163
+ help="Whether to enable frame-level gradient masking")
1164
+ parser.add_argument("--gradient-mask-last-n-frames",
1165
+ type=int,
1166
+ default=TrainingArgs.gradient_mask_last_n_frames,
1167
+ help="Number of last frames to enable gradients for")
1168
+ parser.add_argument("--validate-cache-structure",
1169
+ action=StoreBoolean,
1170
+ help="Whether to validate KV cache structure (debug flag)")
1171
+ parser.add_argument("--same-step-across-blocks",
1172
+ action=StoreBoolean,
1173
+ help="Whether to use the same exit timestep for all blocks")
1174
+ parser.add_argument("--last-step-only",
1175
+ action=StoreBoolean,
1176
+ help="Whether to only use the last timestep for training")
1177
+ parser.add_argument("--context-noise",
1178
+ type=int,
1179
+ default=TrainingArgs.context_noise,
1180
+ help="Context noise level for cache updates")
1181
+
1182
+ return parser
1183
+
1184
+
1185
+ def parse_int_list(value: str) -> list[int]:
1186
+ if not value:
1187
+ return []
1188
+ return [int(x.strip()) for x in value.split(",")]
backend_snapshot/fastvideo/pipelines/basic/wan/__init__.py ADDED
File without changes
backend_snapshot/fastvideo/pipelines/basic/wan/wan_pipeline.py ADDED
@@ -0,0 +1,60 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ Wan video diffusion pipeline implementation.
4
+
5
+ This module contains an implementation of the Wan video diffusion pipeline
6
+ using the modular pipeline architecture.
7
+ """
8
+
9
+ from fastvideo.fastvideo_args import FastVideoArgs
10
+ from fastvideo.logger import init_logger
11
+ from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import (FlowUniPCMultistepScheduler)
12
+ from fastvideo.pipelines import ComposedPipelineBase, LoRAPipeline
13
+ from fastvideo.pipelines.stages import (ConditioningStage, DecodingStage, DenoisingStage, InputValidationStage,
14
+ LatentPreparationStage, TextEncodingStage, TimestepPreparationStage)
15
+
16
+ logger = init_logger(__name__)
17
+
18
+
19
+ class WanPipeline(LoRAPipeline, ComposedPipelineBase):
20
+ """
21
+ Wan video diffusion pipeline with LoRA support.
22
+ """
23
+
24
+ _required_config_modules = ["text_encoder", "tokenizer", "vae", "transformer", "scheduler"]
25
+
26
+ def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
27
+ # We use the UniPC scheduler (FlowUniPCMultistepScheduler) from the Wan2.1 official repo, not the one in diffusers.
28
+ self.modules["scheduler"] = FlowUniPCMultistepScheduler(shift=fastvideo_args.pipeline_config.flow_shift)
29
+
30
+ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None:
31
+ """Set up pipeline stages with proper dependency injection."""
32
+
33
+ self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage())
34
+
35
+ self.add_stage(stage_name="prompt_encoding_stage",
36
+ stage=TextEncodingStage(
37
+ text_encoders=[self.get_module("text_encoder")],
38
+ tokenizers=[self.get_module("tokenizer")],
39
+ ))
40
+
41
+ self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage())
42
+
43
+ self.add_stage(stage_name="timestep_preparation_stage",
44
+ stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")))
45
+
46
+ self.add_stage(stage_name="latent_preparation_stage",
47
+ stage=LatentPreparationStage(scheduler=self.get_module("scheduler"),
48
+ transformer=self.get_module("transformer", None)))
49
+
50
+ self.add_stage(stage_name="denoising_stage",
51
+ stage=DenoisingStage(transformer=self.get_module("transformer"),
52
+ transformer_2=self.get_module("transformer_2", None),
53
+ scheduler=self.get_module("scheduler"),
54
+ vae=self.get_module("vae"),
55
+ pipeline=self))
56
+
57
+ self.add_stage(stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae"), pipeline=self))
58
+
59
+
60
+ EntryClass = WanPipeline
backend_snapshot/fastvideo/pipelines/composed_pipeline_base.py ADDED
@@ -0,0 +1,474 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ Base class for composed pipelines.
4
+
5
+ This module defines the base class for pipelines that are composed of multiple stages.
6
+ """
7
+
8
+ import argparse
9
+ import os
10
+ from abc import ABC, abstractmethod
11
+ from typing import Any, cast
12
+
13
+ import torch
14
+
15
+ from fastvideo.configs.pipelines import PipelineConfig
16
+ from fastvideo.distributed import (maybe_init_distributed_environment_and_model_parallel, get_world_group)
17
+ from fastvideo.distributed.communication_op import (warmup_sequence_parallel_communication)
18
+ from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs
19
+ from fastvideo.logger import init_logger
20
+ from fastvideo.profiler import get_or_create_profiler
21
+ from fastvideo.models.loader.component_loader import PipelineComponentLoader
22
+ from fastvideo.pipelines.pipeline_batch_info import ForwardBatch
23
+ from fastvideo.pipelines.stages import PipelineStage
24
+ import fastvideo.envs as envs
25
+ from fastvideo.utils import (maybe_download_model, verify_model_config_and_directory)
26
+
27
+ logger = init_logger(__name__)
28
+
29
+
30
+ class ComposedPipelineBase(ABC):
31
+ """
32
+ Base class for pipelines composed of multiple stages.
33
+
34
+ This class provides the framework for creating pipelines by composing multiple
35
+ stages together. Each stage is responsible for a specific part of the diffusion
36
+ process, and the pipeline orchestrates the execution of these stages.
37
+ """
38
+
39
+ is_video_pipeline: bool = False # To be overridden by video pipelines
40
+ _required_config_modules: list[str] = []
41
+ _extra_config_module_map: dict[str, str] = {}
42
+ training_args: Any = None
43
+ fastvideo_args: Any = None
44
+ modules: dict[str, Any] = {}
45
+ # do not need to include moe related transformers
46
+ trainable_transformer_names: list[str] = ["transformer"]
47
+ trainable_transformer_modules: dict[str, torch.nn.Module] = {}
48
+ post_init_called: bool = False
49
+
50
+ # TODO(will): args should support both inference args and training args
51
+ def __init__(self,
52
+ model_path: str,
53
+ fastvideo_args: FastVideoArgs | TrainingArgs,
54
+ required_config_modules: list[str] | None = None,
55
+ loaded_modules: dict[str, torch.nn.Module] | None = None):
56
+ """
57
+ Initialize the pipeline. After __init__, the pipeline should be ready to
58
+ use. The pipeline should be stateless and not hold any batch state.
59
+ """
60
+ self.fastvideo_args = fastvideo_args
61
+
62
+ self.model_path: str = model_path
63
+ self._stages: list[PipelineStage] = []
64
+ self._stage_name_mapping: dict[str, PipelineStage] = {}
65
+
66
+ if required_config_modules is not None:
67
+ self._required_config_modules = required_config_modules
68
+
69
+ if self._required_config_modules is None:
70
+ raise NotImplementedError("Subclass must set _required_config_modules")
71
+
72
+ maybe_init_distributed_environment_and_model_parallel(fastvideo_args.tp_size, fastvideo_args.sp_size)
73
+
74
+ # Torch profiler. Enabled and configured through env vars:
75
+ # FASTVIDEO_TORCH_PROFILER_DIR=/path/to/save/trace
76
+ trace_dir = envs.FASTVIDEO_TORCH_PROFILER_DIR
77
+ self.profiler_controller = get_or_create_profiler(trace_dir)
78
+ self.profiler = self.profiler_controller.profiler
79
+
80
+ self.local_rank = get_world_group().local_rank
81
+
82
+ # Load modules directly in initialization
83
+ logger.info("Loading pipeline modules...")
84
+ with self.profiler_controller.region("profiler_region_model_loading"):
85
+ self.modules = self.load_modules(fastvideo_args, loaded_modules)
86
+
87
+ def set_trainable(self) -> None:
88
+ # Only train DiT
89
+ if getattr(self.fastvideo_args, "training_mode", False):
90
+ for name, module in self.trainable_transformer_modules.items():
91
+ logger.info("Setting %s to requires_grad=True", name)
92
+ if not isinstance(module, torch.nn.Module):
93
+ logger.info("Skipping %s because it is not a torch.nn.Module", name)
94
+ continue
95
+ module.requires_grad_(True)
96
+ module.train()
97
+
98
+ @staticmethod
99
+ def _compile_with_conditions(
100
+ module: torch.nn.Module,
101
+ compile_kwargs: dict[str, Any],
102
+ ) -> int:
103
+ """Compile submodules that match module._compile_conditions."""
104
+ compile_conditions = getattr(module, "_compile_conditions", None)
105
+ if not compile_conditions:
106
+ return 0
107
+
108
+ compiled_count = 0
109
+ for name, submodule in module.named_modules():
110
+ if not name:
111
+ continue
112
+ if any(cond(name, submodule) for cond in compile_conditions):
113
+ submodule.forward = torch.compile(submodule.forward, **compile_kwargs)
114
+ compiled_count += 1
115
+ return compiled_count
116
+
117
+ def _maybe_compile_pipeline_module(
118
+ self,
119
+ module_name: str,
120
+ fsdp_module_cls: type | None,
121
+ compile_kwargs: dict[str, Any],
122
+ ) -> None:
123
+ if module_name not in self.modules:
124
+ return
125
+
126
+ module = self.modules[module_name]
127
+ if fsdp_module_cls is not None and isinstance(module, fsdp_module_cls):
128
+ logger.info(
129
+ "%s is already FSDP-wrapped; skipping torch.compile in pipeline",
130
+ module_name.capitalize(),
131
+ )
132
+ return
133
+
134
+ compiled_count = self._compile_with_conditions(module, compile_kwargs)
135
+ if compiled_count > 0:
136
+ logger.info(
137
+ "Enabled torch.compile for %d submodules in %s via _compile_conditions with kwargs=%s",
138
+ compiled_count,
139
+ module_name,
140
+ compile_kwargs,
141
+ )
142
+ return
143
+
144
+ # Backward-compatible fallback: compile full module if no condition matched.
145
+ logger.info("Enabling torch.compile for %s with kwargs=%s", module_name, compile_kwargs)
146
+ self.modules[module_name] = torch.compile(module, **compile_kwargs)
147
+
148
+ def post_init(self) -> None:
149
+ assert self.fastvideo_args is not None, "fastvideo_args must be set"
150
+ if self.post_init_called:
151
+ return
152
+ self.post_init_called = True
153
+ if self.fastvideo_args.training_mode:
154
+ assert isinstance(self.fastvideo_args, TrainingArgs)
155
+ self.training_args = self.fastvideo_args
156
+ assert self.training_args is not None
157
+ self.initialize_training_pipeline(self.training_args)
158
+ if self.training_args.log_validation:
159
+ self.initialize_validation_pipeline(self.training_args)
160
+
161
+ self.initialize_pipeline(self.fastvideo_args)
162
+ if self.fastvideo_args.enable_torch_compile:
163
+ if self.fastvideo_args.training_mode:
164
+ logger.info("Torch Compile enabled via FSDP loader for training; skipping additional pipeline compile")
165
+ else:
166
+ fsdp_module_cls = None
167
+ try:
168
+ from torch.distributed.fsdp import FSDPModule # type: ignore
169
+ fsdp_module_cls = FSDPModule
170
+ except Exception: # pragma: no cover - FSDP not always available
171
+ fsdp_module_cls = None
172
+
173
+ compile_kwargs = self.fastvideo_args.torch_compile_kwargs or {}
174
+ self._maybe_compile_pipeline_module(
175
+ module_name="transformer",
176
+ fsdp_module_cls=fsdp_module_cls,
177
+ compile_kwargs=compile_kwargs,
178
+ )
179
+ self._maybe_compile_pipeline_module(
180
+ module_name="transformer_2",
181
+ fsdp_module_cls=fsdp_module_cls,
182
+ compile_kwargs=compile_kwargs,
183
+ )
184
+ logger.info("Torch Compile enabled for DiT")
185
+
186
+ if not self.fastvideo_args.training_mode:
187
+ logger.info("Creating pipeline stages...")
188
+ self.create_pipeline_stages(self.fastvideo_args)
189
+
190
+ # Warmup NCCL communicators for sequence parallelism to avoid
191
+ # slow first forward pass due to lazy initialization
192
+ warmup_sequence_parallel_communication()
193
+
194
+ def initialize_training_pipeline(self, training_args: TrainingArgs):
195
+ raise NotImplementedError("if training_mode is True, the pipeline must implement this method")
196
+
197
+ def initialize_validation_pipeline(self, training_args: TrainingArgs):
198
+ raise NotImplementedError("if log_validation is True, the pipeline must implement this method")
199
+
200
+ @classmethod
201
+ def from_pretrained(cls,
202
+ model_path: str,
203
+ device: str | None = None,
204
+ torch_dtype: torch.dtype | None = None,
205
+ pipeline_config: str | PipelineConfig | None = None,
206
+ args: argparse.Namespace | FastVideoArgs | TrainingArgs | None = None,
207
+ required_config_modules: list[str] | None = None,
208
+ loaded_modules: dict[str, torch.nn.Module]
209
+ | None = None,
210
+ **kwargs) -> "ComposedPipelineBase":
211
+ """
212
+ Load a pipeline from a pretrained model.
213
+ loaded_modules: Optional[Dict[str, torch.nn.Module]] = None,
214
+ If provided, loaded_modules will be used instead of loading from config/pretrained weights.
215
+ """
216
+ if args is None or (isinstance(args, FastVideoArgs) and args.inference_mode):
217
+
218
+ kwargs['model_path'] = model_path
219
+ fastvideo_args = FastVideoArgs.from_kwargs(**kwargs)
220
+ else:
221
+ if isinstance(args, TrainingArgs):
222
+ fastvideo_args = args
223
+ else:
224
+ assert isinstance(args, argparse.Namespace), "training mode expects argparse.Namespace args"
225
+ fastvideo_args = TrainingArgs.from_cli_args(args)
226
+ # TODO(will): fix this so that its not so ugly
227
+ fastvideo_args.model_path = model_path
228
+ for key, value in kwargs.items():
229
+ setattr(fastvideo_args, key, value)
230
+
231
+ fastvideo_args.dit_cpu_offload = False
232
+ # we hijack the precision to be the master weight type so that the
233
+ # model is loaded with the correct precision. Subsequently we will
234
+ # use FSDP2's MixedPrecisionPolicy to set the precision for the
235
+ # fwd, bwd, and other operations' precision.
236
+ assert fastvideo_args.pipeline_config.dit_precision == 'fp32', 'only fp32 is supported for training'
237
+
238
+ logger.info("fastvideo_args in from_pretrained: %s", fastvideo_args)
239
+
240
+ pipe = cls(model_path,
241
+ fastvideo_args,
242
+ required_config_modules=required_config_modules,
243
+ loaded_modules=loaded_modules)
244
+ pipe.post_init()
245
+ return pipe
246
+
247
+ def get_module(self, module_name: str, default_value: Any = None) -> Any:
248
+ if module_name not in self.modules:
249
+ return default_value
250
+ return self.modules[module_name]
251
+
252
+ def add_module(self, module_name: str, module: Any):
253
+ self.modules[module_name] = module
254
+
255
+ def __getattr__(self, name: str) -> Any:
256
+ if "_stage_name_mapping" in self.__dict__ and name in self._stage_name_mapping:
257
+ return self._stage_name_mapping[name]
258
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
259
+
260
+ def _load_config(self, model_path: str) -> dict[str, Any]:
261
+ model_path = maybe_download_model(self.model_path)
262
+ self.model_path = model_path
263
+ # fastvideo_args.downloaded_model_path = model_path
264
+ logger.info("Model path: %s", model_path)
265
+ config = verify_model_config_and_directory(model_path)
266
+ return cast(dict[str, Any], config)
267
+
268
+ @property
269
+ def required_config_modules(self) -> list[str]:
270
+ """
271
+ List of modules that are required by the pipeline. The names should match
272
+ the diffusers directory and model_index.json file. These modules will be
273
+ loaded using the PipelineComponentLoader and made available in the
274
+ modules dictionary. Access these modules using the get_module method.
275
+
276
+ class ConcretePipeline(ComposedPipelineBase):
277
+ _required_config_modules = ["vae", "text_encoder", "transformer", "scheduler", "tokenizer"]
278
+
279
+
280
+ @property
281
+ def required_config_modules(self):
282
+ return self._required_config_modules
283
+ """
284
+ return self._required_config_modules
285
+
286
+ @property
287
+ def stages(self) -> list[PipelineStage]:
288
+ """
289
+ List of stages in the pipeline.
290
+ """
291
+ return self._stages
292
+
293
+ @abstractmethod
294
+ def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
295
+ """
296
+ Create the inference pipeline stages.
297
+ """
298
+ raise NotImplementedError
299
+
300
+ def create_training_stages(self, training_args: TrainingArgs):
301
+ """
302
+ Create the training pipeline stages.
303
+ """
304
+ raise NotImplementedError
305
+
306
+ def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
307
+ """
308
+ Initialize the pipeline.
309
+ """
310
+ return
311
+
312
+ def load_modules(self,
313
+ fastvideo_args: FastVideoArgs,
314
+ loaded_modules: dict[str, torch.nn.Module] | None = None) -> dict[str, Any]:
315
+ """
316
+ Load the modules from the config.
317
+ loaded_modules: Optional[Dict[str, torch.nn.Module]] = None,
318
+ If provided, loaded_modules will be used instead of loading from config/pretrained weights.
319
+ """
320
+
321
+ model_index = self._load_config(self.model_path)
322
+ logger.info("Loading pipeline modules from config: %s", model_index)
323
+
324
+ # remove keys that are not pipeline modules
325
+ model_index.pop("_class_name")
326
+ model_index.pop("_diffusers_version")
327
+ model_index.pop("_name_or_path", None)
328
+ model_index.pop("workload_type", None)
329
+ if "boundary_ratio" in model_index and model_index["boundary_ratio"] is not None:
330
+ logger.info("MoE pipeline detected. Adding transformer_2 to self.required_config_modules...")
331
+ self.required_config_modules.append("transformer_2")
332
+ logger.info("MoE pipeline detected. Setting boundary ratio to %s", model_index["boundary_ratio"])
333
+ fastvideo_args.pipeline_config.dit_config.boundary_ratio = model_index["boundary_ratio"]
334
+
335
+ model_index.pop("boundary_ratio", None)
336
+ # used by Wan2.2 ti2v
337
+ model_index.pop("expand_timesteps", None)
338
+
339
+ # some sanity checks
340
+ assert len(model_index) > 1, "model_index.json must contain at least one pipeline module"
341
+
342
+ for module_name in self.required_config_modules:
343
+ if module_name not in model_index and module_name in self._extra_config_module_map:
344
+ extra_module_value = self._extra_config_module_map[module_name]
345
+ logger.warning(
346
+ "model_index.json does not contain a %s module, but found {%s: %s} in _extra_config_module_map, adding to model_index.",
347
+ module_name, module_name, extra_module_value)
348
+ if extra_module_value in model_index:
349
+ logger.info("Using module %s for %s", extra_module_value, module_name)
350
+ model_index[module_name] = model_index[extra_module_value]
351
+ continue
352
+ else:
353
+ raise ValueError(
354
+ f"Required module key: {module_name} value: {model_index.get(module_name)} was not found in loaded modules {model_index.keys()}"
355
+ )
356
+
357
+ # all the component models used by the pipeline
358
+ required_modules = self.required_config_modules
359
+ logger.info("Loading required modules: %s", required_modules)
360
+
361
+ modules = {}
362
+ for module_name, module_spec in model_index.items():
363
+ if not isinstance(module_spec, list | tuple):
364
+ logger.info(
365
+ "Skipping non-module config entry %s=%s",
366
+ module_name,
367
+ module_spec,
368
+ )
369
+ continue
370
+ if len(module_spec) < 1:
371
+ logger.warning(
372
+ "Skipping module %s due to invalid empty spec in model_index.json",
373
+ module_name,
374
+ )
375
+ continue
376
+ transformers_or_diffusers = module_spec[0]
377
+ if transformers_or_diffusers is None:
378
+ logger.warning("Module %s in model_index.json has null value, removing from required_config_modules",
379
+ module_name)
380
+ if module_name in self.required_config_modules:
381
+ self.required_config_modules.remove(module_name)
382
+ continue
383
+ if module_name not in required_modules:
384
+ logger.info("Skipping module %s", module_name)
385
+ continue
386
+ if loaded_modules is not None and module_name in loaded_modules:
387
+ logger.info("Using module %s already provided", module_name)
388
+ modules[module_name] = loaded_modules[module_name]
389
+ continue
390
+
391
+ # we load the module from the extra config module map if it exists
392
+ if module_name in self._extra_config_module_map:
393
+ load_module_name = self._extra_config_module_map[module_name]
394
+ else:
395
+ load_module_name = module_name
396
+
397
+ component_model_path = os.path.join(self.model_path, load_module_name)
398
+ module = PipelineComponentLoader.load_module(
399
+ module_name=load_module_name,
400
+ component_model_path=component_model_path,
401
+ transformers_or_diffusers=transformers_or_diffusers,
402
+ fastvideo_args=fastvideo_args,
403
+ )
404
+ logger.info("Loaded module %s from %s", module_name, component_model_path)
405
+
406
+ if module_name in modules:
407
+ logger.warning("Overwriting module %s", module_name)
408
+ modules[module_name] = module
409
+
410
+ # Check if all required modules were loaded
411
+ for module_name in required_modules:
412
+ if module_name not in modules or modules[module_name] is None:
413
+ raise ValueError(
414
+ f"Required module key: {module_name} value: {modules.get(module_name)} was not found in loaded modules {modules.keys()}"
415
+ )
416
+
417
+ return modules
418
+
419
+ def add_stage(self, stage_name: str, stage: PipelineStage):
420
+ assert self.modules is not None, "No modules are registered"
421
+ self._stages.append(stage)
422
+ self._stage_name_mapping[stage_name] = stage
423
+ setattr(self, stage_name, stage)
424
+
425
+ def profile(self, is_start: bool = True):
426
+ if self.profiler is None:
427
+ raise RuntimeError("Profiler is not enabled.")
428
+ if is_start:
429
+ self.profiler.start()
430
+ else:
431
+ self.profiler.stop()
432
+ # only print profiler results on rank 0
433
+ if self.local_rank == 0:
434
+ print(self.profiler.key_averages().table(sort_by="self_cuda_time_total"))
435
+
436
+ # TODO(will): don't hardcode no_grad
437
+ @torch.no_grad()
438
+ def forward(
439
+ self,
440
+ batch: ForwardBatch,
441
+ fastvideo_args: FastVideoArgs,
442
+ ) -> ForwardBatch:
443
+ """
444
+ Generate a video or image using the pipeline.
445
+
446
+ Args:
447
+ batch: The batch to generate from.
448
+ fastvideo_args: The inference arguments.
449
+ Returns:
450
+ ForwardBatch: The batch with the generated video or image.
451
+ """
452
+ if not self.post_init_called:
453
+ self.post_init()
454
+
455
+ # Execute each stage
456
+ logger.info("Running pipeline stages: %s", self._stage_name_mapping.keys())
457
+ # logger.info("Batch: %s", batch)
458
+ for stage in self.stages:
459
+ batch = stage(batch, fastvideo_args)
460
+
461
+ # Return the output
462
+ return batch
463
+
464
+ def train(self) -> None:
465
+ raise NotImplementedError("if training_mode is True, the pipeline must implement this method")
466
+
467
+ def streaming_reset(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch:
468
+ raise NotImplementedError(f"{type(self).__name__} does not support streaming_reset")
469
+
470
+ def streaming_step(self, *args: Any, **kwargs: Any) -> ForwardBatch:
471
+ raise NotImplementedError(f"{type(self).__name__} does not support streaming_step")
472
+
473
+ def streaming_clear(self) -> None:
474
+ raise NotImplementedError(f"{type(self).__name__} does not support streaming_clear")
backend_snapshot/manifest.sha256 CHANGED
@@ -1,17 +1,33 @@
- 45ff4b677a84fad92bd2ff596bf432cb1b9386c5923b6c0824f896074e7cfbc6 ./README.md
+ 033b7cce6eb0ead450a35b910adc9ae1323b8d2323aea6731b72b940e222fb46 ./README.md
+ 9c7dec8f1b8160954d0566231b0952a5f6a5d81f546affd71d190b2b3fc79cb6 ./examples/inference/basic/basic.py
  9d1d8dc58aab529270fe31eb1735d6a1382c0c6d36fccca122a8dbffa1b714fd ./fastvideo-kernel/python/fastvideo_kernel/block_sparse_attn_ours_p.py
  211c7f0445fbe9488250f01fa83457c6620e83bd6f3877db791fd155de93c08b ./fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py
  3f3a407a88612ea17ad65e1b6b9cf6b7b02df56956d8301c4b13bffa92095016 ./fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py
  56f17c602dede53c7c3677058f81274681530f1b83c086d9d1d44c6b51feefbb ./fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py
+ 58f4ac013e6755336212a7a6c9948b19dab0dafc00f4a3298591598df270cb39 ./fastvideo/api/compat.py
  2b821b0e2e7bdb3581be6312ebbece42380a6ee28a7a982f0cf2dc71fab849c8 ./fastvideo/attention/backends/sparse_fp4_ours_p_attn.py
  a97adcc52d7558c49f418c09395fd1665e988ad290d2276b95f21dfca0f8eb7d ./fastvideo/attention/backends/video_sparse_attn.py
  79ef6f38ec0f5bfe16b2b98327ad2ccd15f3c863dd87fd03affc5dbdaa0a8224 ./fastvideo/configs/models/dits/base.py
+ 4bda44746a3626551ea9a9380d890f036087092fb99fce2d302642cce14a97ed ./fastvideo/configs/pipelines/wan.py
+ 5926e29a594db13b116922f131db50631bf8adbf90fe5cec00a5e2f446bfb4ca ./fastvideo/configs/sample/base.py
+ d99adcf607d982b38bbb5a70be60bf87f35d0e9f6f50752f3bceb68b34ce46c2 ./fastvideo/configs/sample/wan.py
+ 49775ce42fd9643c78d8fad4ab8248c1755c7f1524ad771cbd1863d76c513c38 ./fastvideo/configs/wan_1.3B_t2v_pipeline.json
+ ae2d8309472b09927da3e450dea52d9715dcabe5d6722fc2917130ae8d85adb4 ./fastvideo/entrypoints/cli/generate.py
+ d0466769626e7fd497376c544904d56ba62847745eb52527896d96b99d76ba03 ./fastvideo/entrypoints/video_generator.py
+ 73afe6b2ebe0f8cfe0a8ec762a7126161621ad97a64ebad628995f4a164b8b0e ./fastvideo/fastvideo_args.py
  ddcab6f4fd33c9813840571b6bf83bbbcea164b564166951ed4301297db6cef0 ./fastvideo/forward_context.py
+ e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 ./fastvideo/pipelines/basic/wan/__init__.py
+ deac1e22530a6a41c501629f5e8fce47a7af4e008f321cc8a4d734c5120ef4fe ./fastvideo/pipelines/basic/wan/wan_pipeline.py
+ 8908223b3ff99cdb3206148a68a730c2a13d554a2fb1316db6f2f9672efac9e8 ./fastvideo/pipelines/composed_pipeline_base.py
  6cfd128e782b7787a27ddd28a5e2d50cb4b0e2e9425d51d9780f14c91e8206f0 ./fastvideo/pipelines/stages/denoising.py
  489388dbdd9e5e3ad24db3012bd9b108794509a9729891d7dd315a102abba828 ./fastvideo/platforms/cuda.py
  c046b1914041b59254bcdfe577aed20d6f007a72632ea1fe1ae92fa678eca760 ./fastvideo/platforms/interface.py
  2456d39ca28019e12bb7ab007774e86348f0582a017bf0e6c91e2a01d654a1a0 ./fastvideo/train/models/wan/wan.py
  bc46e84b732567de6c0325223405daecd1226c623e303be33c7be9b5b7fdec08 ./fastvideo/training/training_pipeline.py
  1d3898fa37e21029df6c37e05dc34ed7805a211c2f87de6642db890e5a8c6f2e ./fastvideo/training/wan_training_pipeline.py
+ 57fd33c78a16c9b4a239734c76726d94df1daf86f15cf22451c6107fcc197834 ./scripts/inference/run_sfp4_ours_p_checkpoint_750.sh
+ 0162b26dddb2a249e4e2cc56a7a28eba4fb3ea77e938eaf9c6b16a241edcd3ec ./scripts/inference/run_sfp4_single.sh
+ 159579109f9fb7d7dded977f5c7fc974583c95f3f895e418d6c466463c036304 ./scripts/inference/run_validate_and_gen.sh
  5c982b64653fae83ebfdeb43fda8f29b3e2cb581fb4daee38cd3cf56aa9d73f5 ./scripts/training/run_sparse_fp4_train_v4_1n_sparse09_hpo_on_ours_p_init2050_interactive.sh
  5c1d5ce9ecc8b90e59ddfc2ddb3e2dae500bcd3acb90429c901444b1630f05fb ./scripts/training/run_sparse_fp4_train_v4_common.sh
+ 75455829ca55a80daaa7e3c7faa080b6eec3c7109bdbd1e198b722face62eed0 ./training_attention_settings.json
backend_snapshot/scripts/inference/run_sfp4_ours_p_checkpoint_750.sh ADDED
@@ -0,0 +1,54 @@
1
+ #!/bin/bash
2
+
3
+ set -euo pipefail
4
+
5
+ REPO_ROOT="${REPO_ROOT:-/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/users/yitongl/code/FastVideo}"
6
+ MODEL_PATH="${MODEL_PATH:-Wan-AI/Wan2.1-T2V-1.3B-Diffusers}"
7
+ HF_REPO="${HF_REPO:-yitongl/sparse_quant_exp}"
8
+ CHECKPOINT_DIR="${CHECKPOINT_DIR:-${REPO_ROOT}/checkpoints/hf_download/sparse_quant_exp}"
9
+ WEIGHTS_PATH="${WEIGHTS_PATH:-${CHECKPOINT_DIR}/transformer/diffusion_pytorch_model.safetensors}"
10
+ OUTPUT_DIR="${OUTPUT_DIR:-${REPO_ROOT}/outputs/sfp4_v4_sparse09_checkpoint_750}"
11
+ PROMPT="${PROMPT:-A cinematic shot of a futuristic city street at dusk, reflective pavement, soft volumetric light, detailed motion, stable camera.}"
12
+ NEGATIVE_PROMPT="${NEGATIVE_PROMPT:-Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards}"
13
+
14
+ cd "${REPO_ROOT}"
15
+
16
+ if [[ -f ".venv/bin/activate" ]]; then
17
+ source .venv/bin/activate
18
+ fi
19
+
20
+ export PYTHONPATH="${REPO_ROOT}/fastvideo-kernel/python:${REPO_ROOT}/fastvideo-kernel:${PYTHONPATH:-}"
21
+ export FASTVIDEO_ATTENTION_BACKEND=SPARSE_FP4_OURS_P_ATTN
22
+ export FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O=1
23
+
24
+ if [[ ! -f "${WEIGHTS_PATH}" ]]; then
25
+ echo "Missing ${WEIGHTS_PATH}"
26
+ echo "Download the uploaded transformer weights first:"
27
+ echo " hf download ${HF_REPO} --repo-type model --local-dir ${CHECKPOINT_DIR} --include 'transformer/*'"
28
+ exit 1
29
+ fi
30
+
31
+ mkdir -p "${OUTPUT_DIR}"
32
+
33
+ fastvideo generate \
34
+ --model-path "${MODEL_PATH}" \
35
+ --init-weights-from-safetensors "${WEIGHTS_PATH}" \
36
+ --sp-size 1 \
37
+ --tp-size 1 \
38
+ --num-gpus 1 \
39
+ --dit-cpu-offload False \
40
+ --vae-cpu-offload False \
41
+ --text-encoder-cpu-offload True \
42
+ --pin-cpu-memory False \
43
+ --height 448 \
44
+ --width 832 \
45
+ --num-frames 77 \
46
+ --num-inference-steps 50 \
47
+ --fps 16 \
48
+ --guidance-scale 5.0 \
49
+ --flow-shift 1.0 \
50
+ --prompt "${PROMPT}" \
51
+ --negative-prompt "${NEGATIVE_PROMPT}" \
52
+ --seed 1000 \
53
+ --VSA-sparsity 0.9 \
54
+ --output-path "${OUTPUT_DIR}/"
backend_snapshot/scripts/inference/run_sfp4_single.sh ADDED
@@ -0,0 +1,38 @@
+ #!/bin/bash
+ #SBATCH --job-name=sfp4-s0
+ #SBATCH --account=nvr_elm_llm
+ #SBATCH --partition=interactive
+ #SBATCH --nodes=1
+ #SBATCH --gres=gpu:1
+ #SBATCH --cpus-per-task=16
+ #SBATCH --mem=64G
+ #SBATCH --time=00:30:00
+ #SBATCH --output=slurm_logs/sfp4_s0_%j.out
+ #SBATCH --error=slurm_logs/sfp4_s0_%j.err
+
+ set -ex
+ cd /lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/users/yitongl/code/FastVideo
+ source .venv/bin/activate
+ export PYTHONPATH=fastvideo-kernel/python:fastvideo-kernel:$PYTHONPATH
+ export FASTVIDEO_ATTENTION_BACKEND=SPARSE_FP4_ATTN
+
+ mkdir -p outputs_sfp4_s0
+
+ # Same prompt, seed, params as dense FP4 run
+ fastvideo generate \
+   --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
+   --sp-size 1 --tp-size 1 --num-gpus 1 \
+   --dit-cpu-offload False --vae-cpu-offload False \
+   --text-encoder-cpu-offload True --pin-cpu-memory False \
+   --height 480 --width 832 --num-frames 81 \
+   --num-inference-steps 50 --fps 16 \
+   --guidance-scale 6.0 --flow-shift 8.0 \
+   --prompt "Will Smith casually eats noodles, his relaxed demeanor contrasting with the energetic background of a bustling street food market. The scene captures a mix of humor and authenticity. Mid-shot framing, vibrant lighting." \
+   --seed 1024 \
+   --VSA-sparsity 0.0 \
+   --output-path outputs_sfp4_s0/
+
+ echo "=== Done ==="
+ ls -lh outputs_sfp4_s0/*.mp4
+ echo "--- Dense FP4 reference ---"
+ ls -lh outputs_dense_fp4/*.mp4
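
The script ends by listing the sparse output alongside the dense FP4 reference. If a numeric comparison is wanted rather than an eyeball check, something like the following works; the `.mp4` file names are placeholders for whatever the two runs actually produced.

```python
# Illustrative A/B check between the sparse run's output and the dense FP4
# reference; both file paths are placeholders for the real generated names.
import torch
from torchvision.io import read_video

sparse, _, _ = read_video("outputs_sfp4_s0/video.mp4", pts_unit="sec")
dense, _, _ = read_video("outputs_dense_fp4/video.mp4", pts_unit="sec")

# Compare the overlapping frames as float tensors in [0, 255].
n = min(len(sparse), len(dense))
mse = torch.mean((sparse[:n].float() - dense[:n].float()) ** 2)
psnr = 10 * torch.log10(255.0 ** 2 / mse)
print(f"frames={n}  mse={mse.item():.3f}  psnr={psnr.item():.2f} dB")
```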
backend_snapshot/scripts/inference/run_validate_and_gen.sh ADDED
@@ -0,0 +1,91 @@
+ #!/bin/bash
+ #SBATCH --job-name=sfp4-val-gen
+ #SBATCH --account=nvr_elm_llm
+ #SBATCH --partition=interactive
+ #SBATCH --nodes=1
+ #SBATCH --gres=gpu:8
+ #SBATCH --cpus-per-task=128
+ #SBATCH --mem=1440G
+ #SBATCH --time=02:00:00
+ #SBATCH --output=slurm_logs/sfp4_val_gen_%j.out
+ #SBATCH --error=slurm_logs/sfp4_val_gen_%j.err
+
+ set -ex
+
+ REPO_ROOT="/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/users/yitongl/code/FastVideo"
+ KERNEL_ROOT="${REPO_ROOT}/fastvideo-kernel"
+
+ mkdir -p "${REPO_ROOT}/slurm_logs"
+ cd "${REPO_ROOT}"
+ source .venv/bin/activate
+ export PYTHONPATH="${KERNEL_ROOT}/python:${KERNEL_ROOT}:${PYTHONPATH}"
+
+ echo "=== Environment ==="
+ nvidia-smi -L | head -1
+ python -c "import torch; print(f'torch={torch.__version__}, cuda={torch.cuda.is_available()}, gpus={torch.cuda.device_count()}')"
+ python -c "import triton; print(f'triton={triton.__version__}')"
+
+ echo ""
+ echo "######################################################################"
+ echo "# Generate 8 videos with sparse FP4 attention #"
+ echo "######################################################################"
+ cd "${REPO_ROOT}"
+
+ MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+ PROMPT="Will Smith casually eats noodles, his relaxed demeanor contrasting with the energetic background of a bustling street food market. The scene captures a mix of humor and authenticity. Mid-shot framing, vibrant lighting."
+ NEGATIVE_PROMPT="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+ SEED=1024
+
+ SPARSITY_LIST=(0.0 0.1 0.2 0.4 0.5 0.7 0.8 0.9)
+ OUTPUT_BASE="${REPO_ROOT}/outputs_sparse_fp4_sweep"
+ mkdir -p "${OUTPUT_BASE}"
+
+ echo "Sparsity levels: ${SPARSITY_LIST[*]}"
+
+ PIDS=()
+ for i in $(seq 0 7); do
+   SPARSITY=${SPARSITY_LIST[$i]}
+   OUT_DIR="${OUTPUT_BASE}/sparsity_${SPARSITY}"
+   mkdir -p "${OUT_DIR}"
+   echo "[GPU ${i}] sparsity=${SPARSITY}"
+
+   (
+     export CUDA_VISIBLE_DEVICES=${i}
+     export FASTVIDEO_ATTENTION_BACKEND=SPARSE_FP4_ATTN
+
+     fastvideo generate \
+       --model-path "${MODEL_PATH}" \
+       --sp-size 1 --tp-size 1 --num-gpus 1 \
+       --dit-cpu-offload False \
+       --vae-cpu-offload False \
+       --text-encoder-cpu-offload True \
+       --pin-cpu-memory False \
+       --height 480 --width 832 --num-frames 81 \
+       --num-inference-steps 50 --fps 16 \
+       --guidance-scale 6.0 --flow-shift 8.0 \
+       --prompt "${PROMPT}" \
+       --negative-prompt "${NEGATIVE_PROMPT}" \
+       --seed ${SEED} \
+       --VSA-sparsity ${SPARSITY} \
+       --output-path "${OUT_DIR}/" \
+       2>&1 | tee "${OUT_DIR}/log.txt"
+
+     echo "[GPU ${i}] sparsity=${SPARSITY} DONE"
+   ) &
+   PIDS+=($!)
+ done
+
+ echo "=== Waiting for all 8 jobs ==="
+ FAIL=0
+ for i in $(seq 0 7); do
+   wait ${PIDS[$i]} || { echo "[GPU ${i}] FAILED"; FAIL=1; }
+ done
+
+ echo ""
+ if [ $FAIL -eq 0 ]; then
+   echo "=== All 8 videos generated ==="
+ else
+   echo "=== Some failed ==="
+ fi
+ find "${OUTPUT_BASE}" -name "*.mp4" | sort
+ echo "=== Done ==="
backend_snapshot/training_attention_settings.json ADDED
@@ -0,0 +1,62 @@
+ {
+   "run_name": "sfp4_v4_sparse09_hpo_on_ours_p_init2050_1n_interactive",
+   "checkpoint": "checkpoint-750",
+   "training_method": "legacy_sft_wan_training_pipeline",
+   "model_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+   "init_weights_from_safetensors": "checkpoints/init/sfp4_v4_sparse06_hpo_on_ours_p_1n_interactive_v2_ckpt2050/transformer/diffusion_pytorch_model.safetensors",
+   "environment": {
+     "FASTVIDEO_ATTENTION_BACKEND": "SPARSE_FP4_OURS_P_ATTN",
+     "FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O": "1",
+     "FASTVIDEO_VALIDATION_ONE_PROMPT_PER_RANK": "1",
+     "WANDB_MODE": "online",
+     "WANDB_RESUME": "allow"
+   },
+   "vsa_schedule": {
+     "VSA_SPARSITY": 0.9,
+     "VSA_INIT_SPARSITY": 0.9,
+     "VSA_WARMUP_STEPS": 0,
+     "VSA_DECAY_RATE": 0.03,
+     "VSA_DECAY_INTERVAL_STEPS": 50,
+     "effective_sparsity_from_step_0": 0.9
+   },
+   "attention_semantics": {
+     "selected_backend": "SPARSE_FP4_OURS_P_ATTN",
+     "self_attention": {
+       "backend_path": "fastvideo/attention/backends/sparse_fp4_ours_p_attn.py",
+       "kernel_path": "fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py",
+       "tile_size_video": [4, 4, 4],
+       "tile_tokens": 64,
+       "qkv_quantization": "FP4 fake quantization with STE, no q/k mean subtraction in quantization",
+       "block_selection": "top-k blocks from q_c @ k_c tile-mean scores",
+       "p_quantization": "group-local exp2(qk - group_max) FP4 fake quantization; compensation multiplies exp2(group_max - running_row_m)",
+       "dropped_tile_handling": "tile-level q_mean/k_mean score and mean_v compensation"
+     },
+     "cross_attention": {
+       "backend": "dense_sdpa",
+       "reason": "sparse_fp4_ours_p_attn.py treats query_length != key_length as cross attention and returns _dense_sdpa_blhd",
+       "quantized": false,
+       "sparse": false
+     },
+     "force_dense": {
+       "backend": "dense_sdpa",
+       "used_for": "teacher or explicitly forced dense paths, not the normal SFT student self-attention path"
+     }
+   },
+   "validation_and_checkpointing": {
+     "save_steps": 50,
+     "eval_steps": 50,
+     "validation_sampling_steps": 50,
+     "validation_guidance_scale": 5.0,
+     "checkpoints_total_limit": 5,
+     "flow_shift": 1.0
+   },
+   "training_shape": {
+     "num_latent_t": 20,
+     "num_frames": 77,
+     "height": 448,
+     "width": 832,
+     "batch_size_per_gpu": 1,
+     "sp_size": 1,
+     "tp_size": 1
+   }
+ }
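
A note on the `attention_semantics` block: its strings compress a fair amount of mechanism. As a rough, self-contained illustration (not the repo's Triton kernel), the sketch below shows what "FP4 fake quantization with STE" and "top-k blocks from q_c @ k_c tile-mean scores" amount to; all function names, shapes, and the per-tensor scale granularity are invented for the sketch, and the real logic lives in `sparse_fp4_ours_p_attn.py` and `block_sparse_attn_triton_ours_p.py`.

```python
# Minimal sketch of two mechanisms named in attention_semantics above.
# Not the repo's implementation: names and shapes are illustrative only.
import torch

# Positive E2M1 (FP4) magnitudes; the format is symmetric around zero.
FP4_LEVELS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fp4_fake_quant_ste(x: torch.Tensor) -> torch.Tensor:
    """Round x to the nearest FP4 value after scaling its max |value| to 6.

    Straight-through estimator: forward returns the quantized tensor,
    backward passes gradients through unchanged.
    """
    levels = FP4_LEVELS.to(x.device, x.dtype)
    scale = x.abs().amax().clamp(min=1e-12) / levels[-1]
    # Snap each scaled magnitude to the nearest representable level.
    idx = ((x / scale).abs().unsqueeze(-1) - levels).abs().argmin(dim=-1)
    q = torch.sign(x) * levels[idx] * scale
    return x + (q - x).detach()

def topk_tile_map(q: torch.Tensor, k: torch.Tensor,
                  tile: int = 64, keep_ratio: float = 0.1) -> torch.Tensor:
    """Boolean [n_tiles, n_tiles] map of retained attention tiles.

    Tiles are scored by (tile-mean q) @ (tile-mean k)^T, and the top-k
    key tiles are kept per query tile; keep_ratio=0.1 ~ 90% sparsity.
    q, k: [seq, dim] with seq divisible by `tile` (64 = 4 x 4 x 4 tokens).
    """
    q_c = q.view(-1, tile, q.shape[-1]).mean(dim=1)   # [n_tiles, dim]
    k_c = k.view(-1, tile, k.shape[-1]).mean(dim=1)
    scores = q_c @ k_c.T                              # [n_tiles, n_tiles]
    k_keep = max(1, int(keep_ratio * scores.shape[-1]))
    top = scores.topk(k_keep, dim=-1).indices
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask.scatter_(-1, top, torch.ones_like(top, dtype=torch.bool))
    return mask

if __name__ == "__main__":
    q = torch.randn(4096, 128, requires_grad=True)
    k = torch.randn(4096, 128)
    mask = topk_tile_map(fp4_fake_quant_ste(q), fp4_fake_quant_ste(k))
    fp4_fake_quant_ste(q).sum().backward()  # STE: gradient reaches q
    print(f"kept tiles: {mask.float().mean():.3f}, grad norm: {q.grad.norm():.1f}")
```

On the `vsa_schedule` entries: since `VSA_INIT_SPARSITY` equals `VSA_SPARSITY` and `VSA_WARMUP_STEPS` is 0, the decay parameters are inert for this run, and the kept-tile ratio is 0.1 (sparsity 0.9) from the first step, which is what `effective_sparsity_from_step_0` records.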