yitongl committed
Commit 4db877c · verified · 1 Parent(s): 9d27bcd

Upload backend snapshot for sfp4 checkpoint-750
backend_snapshot/README.md ADDED
@@ -0,0 +1,35 @@

# Backend snapshot for checkpoint-750

This directory is the code snapshot for the training backend used by:

`sfp4_v4_sparse09_hpo_on_ours_p_init2050_1n_interactive/checkpoint-750`

Key runtime settings (a minimal export sketch follows the list):

- `FASTVIDEO_ATTENTION_BACKEND=SPARSE_FP4_OURS_P_ATTN`
- `FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O=1`
- `VSA_SPARSITY=0.9`
- `VSA_INIT_SPARSITY=0.9`
- `VSA_WARMUP_STEPS=0`
- tile size: `4 x 4 x 4 = 64` video tokens
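The settings above are plain environment variables. A minimal sketch of pinning them from Python before FastVideo selects its attention backend (illustrative only; the real run goes through the Slurm scripts listed under "Important files", and `apply_snapshot_env` is a hypothetical helper, not part of this repo):

```python
import os

# Environment used for this checkpoint, copied from the list above.
SNAPSHOT_ENV = {
    "FASTVIDEO_ATTENTION_BACKEND": "SPARSE_FP4_OURS_P_ATTN",
    "FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O": "1",
    "VSA_SPARSITY": "0.9",
    "VSA_INIT_SPARSITY": "0.9",
    "VSA_WARMUP_STEPS": "0",
}


def apply_snapshot_env() -> None:
    """Set the snapshot's env vars, keeping any values already exported."""
    for key, value in SNAPSHOT_ENV.items():
        os.environ.setdefault(key, value)
```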
Important files:

- `fastvideo/attention/backends/sparse_fp4_ours_p_attn.py`: Python attention backend, Q/K/V fake quantization, top-k block map, tile-mean setup.
- `fastvideo-kernel/python/fastvideo_kernel/block_sparse_attn_ours_p.py`: PyTorch custom op and autograd wrapper.
- `fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py`: Triton forward/backward kernels.
- `fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py`: FP4 quant/dequant utilities used by the kernels.
- `fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py`: Q/K/V fake-quant kernels.
- `fastvideo/attention/backends/video_sparse_attn.py`: VSA metadata and tile-size helper.
- `fastvideo/platforms/interface.py` and `fastvideo/platforms/cuda.py`: backend enum and CUDA backend-selection wiring.
- `fastvideo/training/training_pipeline.py` and `fastvideo/training/wan_training_pipeline.py`: legacy SFT training path used by the launch script.
- `scripts/training/run_sparse_fp4_train_v4_1n_sparse09_hpo_on_ours_p_init2050_interactive.sh`: exact Slurm wrapper for this run.
- `scripts/training/run_sparse_fp4_train_v4_common.sh`: common SFT launch/resume script.

Source repo HEAD when staged:

`3f818d0fc532ec6494b465967d5f485150917d0c`

Note: several backend files were uncommitted or locally modified when this
snapshot was staged, so the files here are the authoritative copy for this
checkpoint rather than the clean git commit alone.
backend_snapshot/fastvideo-kernel/python/fastvideo_kernel/block_sparse_attn_ours_p.py ADDED
@@ -0,0 +1,270 @@

from __future__ import annotations

import os

import torch


def _use_high_prec_output_for_backward() -> bool:
    value = os.environ.get("FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O", "1")
    return value.lower() not in ("0", "false", "no", "off")


def _map_to_index(block_map: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    if block_map.dim() == 3:
        block_map = block_map.unsqueeze(0)
    if block_map.dim() != 4:
        raise ValueError(
            f"block_map must be [B,H,Q,KV] or [H,Q,KV], got {tuple(block_map.shape)}"
        )
    if block_map.dtype != torch.bool:
        block_map = block_map.to(torch.bool)
    if not block_map.is_cuda:
        raise RuntimeError("block_map must be a CUDA tensor.")

    try:
        from fastvideo_kernel.triton_kernels.index import map_to_index as triton_map_to_index
    except Exception as e:
        raise ImportError("Triton map_to_index is required for ours-P Sparse FP4.") from e
    return triton_map_to_index(block_map)


@torch.library.custom_op(
    "fastvideo_kernel::block_sparse_attn_ours_p_triton",
    mutates_args=(),
    device_types="cuda",
)
def block_sparse_attn_ours_p_triton(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    block_map = block_map.to(torch.bool)
    q2k_idx, q2k_num = _map_to_index(block_map)

    from fastvideo_kernel.triton_kernels.block_sparse_attn_triton_ours_p import (
        triton_block_sparse_attn_forward,
    )

    return triton_block_sparse_attn_forward(
        q, k, v, q2k_idx, q2k_num, variable_block_sizes, is_qat=True
    )


@torch.library.register_fake("fastvideo_kernel::block_sparse_attn_ours_p_triton")
def _block_sparse_attn_ours_p_triton_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    o = torch.empty_like(q)
    high_prec_o = torch.empty_like(q)
    M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
    return o, M, high_prec_o


@torch.library.custom_op(
    "fastvideo_kernel::block_sparse_attn_ours_p_backward_triton",
    mutates_args=(),
    device_types="cuda",
)
def block_sparse_attn_ours_p_backward_triton(
    grad_output: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    M: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    grad_output = grad_output.contiguous()
    block_map = block_map.to(torch.bool)
    q2k_idx, q2k_num = _map_to_index(block_map)
    k2q_idx, k2q_num = _map_to_index(block_map.transpose(-1, -2).contiguous())

    from fastvideo_kernel.triton_kernels.block_sparse_attn_triton_ours_p import (
        triton_block_sparse_attn_backward,
    )

    return triton_block_sparse_attn_backward(
        grad_output,
        q,
        k,
        v,
        o,
        M,
        q2k_idx,
        q2k_num,
        k2q_idx,
        k2q_num,
        variable_block_sizes,
        is_qat=True,
    )


@torch.library.register_fake(
    "fastvideo_kernel::block_sparse_attn_ours_p_backward_triton"
)
def _block_sparse_attn_ours_p_backward_triton_fake(
    grad_output: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    M: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    return torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)


def _backward_triton(ctx, grad_o, grad_M, grad_high_prec_o):
    q, k, v, o_for_bwd, M, block_map, variable_block_sizes = ctx.saved_tensors
    dq, dk, dv = block_sparse_attn_ours_p_backward_triton(
        grad_o, q, k, v, o_for_bwd, M, block_map, variable_block_sizes
    )
    return dq, dk, dv, None, None


def _setup_context_triton(ctx, inputs, output):
    q, k, v, block_map, variable_block_sizes = inputs
    o, M, high_prec_o = output
    o_for_bwd = high_prec_o if _use_high_prec_output_for_backward() else o
    ctx.save_for_backward(q, k, v, o_for_bwd, M, block_map, variable_block_sizes)


block_sparse_attn_ours_p_triton.register_autograd(
    _backward_triton, setup_context=_setup_context_triton
)


class _BlockSparseAttnOursPTileComp(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, q_mean, k_mean, v_mean, block_map, variable_block_sizes):
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        q_mean = q_mean.contiguous()
        k_mean = k_mean.contiguous()
        v_mean = v_mean.contiguous()
        block_map = block_map.to(torch.bool)
        dropped_block_map = torch.logical_not(block_map)

        q2k_idx, q2k_num = _map_to_index(block_map)
        dropped_q2k_idx, dropped_q2k_num = _map_to_index(dropped_block_map)

        from fastvideo_kernel.triton_kernels.block_sparse_attn_triton_ours_p import (
            triton_block_sparse_attn_forward,
        )

        o, M, high_prec_o = triton_block_sparse_attn_forward(
            q,
            k,
            v,
            q2k_idx,
            q2k_num,
            variable_block_sizes,
            is_qat=True,
            q_mean=q_mean,
            k_mean=k_mean,
            v_mean=v_mean,
            dropped_q2k_index=dropped_q2k_idx,
            dropped_q2k_num=dropped_q2k_num,
        )
        o_for_bwd = high_prec_o if _use_high_prec_output_for_backward() else o
        ctx.save_for_backward(
            q,
            k,
            v,
            q_mean,
            k_mean,
            v_mean,
            o_for_bwd,
            M,
            block_map,
            dropped_block_map,
            variable_block_sizes,
        )
        return o, M

    @staticmethod
    def backward(ctx, grad_o, grad_M):
        (
            q,
            k,
            v,
            q_mean,
            k_mean,
            v_mean,
            o_for_bwd,
            M,
            block_map,
            dropped_block_map,
            variable_block_sizes,
        ) = ctx.saved_tensors

        q2k_idx, q2k_num = _map_to_index(block_map)
        k2q_idx, k2q_num = _map_to_index(block_map.transpose(-1, -2).contiguous())
        dropped_q2k_idx, dropped_q2k_num = _map_to_index(dropped_block_map)
        dropped_k2q_idx, dropped_k2q_num = _map_to_index(
            dropped_block_map.transpose(-1, -2).contiguous()
        )

        from fastvideo_kernel.triton_kernels.block_sparse_attn_triton_ours_p import (
            triton_block_sparse_attn_backward,
        )

        dq, dk, dv = triton_block_sparse_attn_backward(
            grad_o.contiguous(),
            q,
            k,
            v,
            o_for_bwd,
            M,
            q2k_idx,
            q2k_num,
            k2q_idx,
            k2q_num,
            variable_block_sizes,
            is_qat=True,
            q_mean=q_mean,
            k_mean=k_mean,
            v_mean=v_mean,
            dropped_q2k_index=dropped_q2k_idx,
            dropped_q2k_num=dropped_q2k_num,
            dropped_k2q_index=dropped_k2q_idx,
            dropped_k2q_num=dropped_k2q_num,
        )
        return dq, dk, dv, None, None, None, None, None


def block_sparse_attn_ours_p(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
    q_mean: torch.Tensor | None = None,
    k_mean: torch.Tensor | None = None,
    v_mean: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    if (q_mean is not None) or (k_mean is not None) or (v_mean is not None):
        if q_mean is None or k_mean is None or v_mean is None:
            raise ValueError("q_mean, k_mean, and v_mean must be provided together")
        return _BlockSparseAttnOursPTileComp.apply(
            q, k, v, q_mean, k_mean, v_mean, block_map, variable_block_sizes
        )

    o, M, _ = block_sparse_attn_ours_p_triton(
        q, k, v, block_map, variable_block_sizes
    )
    return o, M
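For orientation, a minimal sketch of calling the public wrapper above on toy tensors. It assumes a CUDA device with the Triton kernels importable; the shapes follow the asserts in `triton_block_sparse_attn_forward` (sequence lengths divisible by 64, a boolean `[B, H, Q_tiles, KV_tiles]` block map, and one `variable_block_sizes` entry per 64-token KV tile):

```python
import torch
from fastvideo_kernel.block_sparse_attn_ours_p import block_sparse_attn_ours_p

B, H, T, D = 1, 4, 256, 64                  # T must be a multiple of 64
n_tiles = T // 64

q = torch.randn(B, H, T, D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Keep roughly 10% of KV tiles per Q tile, always including the diagonal tile.
block_map = torch.rand(B, H, n_tiles, n_tiles, device="cuda") < 0.1
block_map |= torch.eye(n_tiles, device="cuda", dtype=torch.bool)

# Every KV tile fully populated with 64 valid tokens.
variable_block_sizes = torch.full((n_tiles,), 64, device="cuda", dtype=torch.int32)

o, M = block_sparse_attn_ours_p(q, k, v, block_map, variable_block_sizes)
o.sum().backward()                          # exercises the registered backward
```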
backend_snapshot/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py ADDED
@@ -0,0 +1,1155 @@

"""
Fused Attention
===============

This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
(https://tridao.me/publications/flash2/flash2.pdf)

Credits: OpenAI kernel team
"""

import torch
import triton
import triton.language as tl
from .quant_utils import fake_quantize

# ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
import math  # small utility needed by the sparse wrapper
# ──────────────────────────── SPARSE ADDITION END ─────────────────────────────

# We don't run auto-tuning every time to keep the tutorial fast. Keeping
# the code below and commenting out the equivalent parameters is convenient for
# re-tuning.
configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
    for BM in [64] \
    for BN in [64] \
    for s in [3, 4, 7] \
    for w in [4, 8] \
]


# ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
@triton.autotune(configs, key=["N_CTX_Q", "HEAD_DIM"])
@triton.jit
def _attn_fwd_sparse(
        Q,
        K,
        V,
        QMean,
        KMean,
        VMean,
        sm_scale,  #
        q2k_index,
        q2k_num,
        max_kv_blks,  #
        dropped_q2k_index,
        dropped_q2k_num,
        max_dropped_kv_blks,  #
        variable_block_sizes,
        M,
        Out,  #
        HighPrecOut,  #
        stride_qz,
        stride_qh,
        stride_qm,
        stride_qk,
        stride_kz,
        stride_kh,
        stride_kn,
        stride_kk,
        stride_vz,
        stride_vh,
        stride_vk,
        stride_vn,
        stride_oz,
        stride_oh,
        stride_om,
        stride_on,
        Z,
        H,
        N_CTX_Q,  #
        N_CTX_KV,  #
        HEAD_DIM: tl.constexpr,  #
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        STAGE: tl.constexpr,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    """
    64x64 block-sparse forward kernel for the independent "ours P quant" path.

    P quantization is group-local: each selected KV tile quantizes
    exp2(logit - tile_row_max), then applies exp2(tile_row_max - online_max)
    after the FP4 PV GEMM. This intentionally differs from the QAT-style
    backend, which quantizes exp2(logit - online_max) directly.
    """

    # ----- program-id mapping -----
    q_blk = tl.program_id(0)  # Q-tile index
    off_hz = tl.program_id(1)  # fused (batch, head)
    b = off_hz // H
    h = off_hz % H
    q_tiles = N_CTX_Q // BLOCK_M
    meta_base = ((b * H + h) * q_tiles + q_blk)

    kv_blocks = tl.load(q2k_num + meta_base)  # int32
    kv_ptr = q2k_index + meta_base * max_kv_blks  # ptr to list
    dropped_kv_blocks = tl.load(dropped_q2k_num + meta_base)
    dropped_kv_ptr = dropped_q2k_index + meta_base * max_dropped_kv_blks

    # ----- base pointers -----
    q_off = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
    k_off = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
    v_off = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
    o_off = (b.to(tl.int64) * stride_oz + h.to(tl.int64) * stride_oh)

    Q_ptr = tl.make_block_ptr(base=Q + q_off,
                              shape=(N_CTX_Q, HEAD_DIM),
                              strides=(stride_qm, stride_qk),
                              offsets=(q_blk * BLOCK_M, 0),
                              block_shape=(BLOCK_M, HEAD_DIM),
                              order=(1, 0))

    K_base = tl.make_block_ptr(base=K + k_off,
                               shape=(HEAD_DIM, N_CTX_KV),
                               strides=(stride_kk, stride_kn),
                               offsets=(0, 0),
                               block_shape=(HEAD_DIM, BLOCK_N),
                               order=(0, 1))

    v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
    V_base = tl.make_block_ptr(base=V + v_off,
                               shape=(N_CTX_KV, HEAD_DIM),
                               strides=(stride_vk, stride_vn),
                               offsets=(0, 0),
                               block_shape=(BLOCK_N, HEAD_DIM),
                               order=v_order)

    O_ptr = tl.make_block_ptr(base=Out + o_off,
                              shape=(N_CTX_Q, HEAD_DIM),
                              strides=(stride_om, stride_on),
                              offsets=(q_blk * BLOCK_M, 0),
                              block_shape=(BLOCK_M, HEAD_DIM),
                              order=(1, 0))
    HPO_ptr = tl.make_block_ptr(base=HighPrecOut + o_off,
                                shape=(N_CTX_Q, HEAD_DIM),
                                strides=(stride_om, stride_on),
                                offsets=(q_blk * BLOCK_M, 0),
                                block_shape=(BLOCK_M, HEAD_DIM),
                                order=(1, 0))

    # ----- accumulators -----
    offs_m = q_blk * BLOCK_M + tl.arange(0, BLOCK_M)
    m_i = tl.full([BLOCK_M], -float("inf"), tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    high_prec_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    qk_scale = sm_scale * 1.44269504  # 1/ln2
    q = tl.load(Q_ptr)
    offs_d = tl.arange(0, HEAD_DIM)

    # ----- sparse loop over valid K/V tiles -----
    for i in range(0, kv_blocks):
        kv_idx = tl.load(kv_ptr + i).to(tl.int32)
        block_size = tl.load(variable_block_sizes + kv_idx)
        K_ptr = tl.advance(K_base, (0, kv_idx * BLOCK_N))
        V_ptr = tl.advance(V_base, (kv_idx * BLOCK_N, 0))

        k = tl.load(K_ptr)
        mask = tl.arange(0, BLOCK_N) < block_size
        qk = tl.dot(q, k) * qk_scale
        # mask out invalid columns
        qk = tl.where(mask[None, :], qk, -float("inf"))
        group_m = tl.max(qk, 1)
        m_ij = tl.maximum(m_i, group_m)

        p_local = tl.math.exp2(qk - group_m[:, None])
        p_local = tl.where(mask[None, :], p_local, 0.0)
        p_comp = tl.math.exp2(group_m - m_ij)
        p_valid = mask[None, :] & (
            tl.full(shape=p_local.shape, value=1.0,
                    dtype=p_local.dtype) == 1.0
        )
        p_quant, high_prec_p = fake_quantize(
            src_tensor=p_local, valid_src_mask=p_valid,
            BLOCK_SIZE_OUT_DIM=BLOCK_M, BLOCK_SIZE_QUANT_DIM=BLOCK_N,
            dst_dtype=tl.bfloat16, use_global_sf=False,
        )
        l_ij = tl.sum(high_prec_p, 1) * p_comp

        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        acc = acc * alpha[:, None]
        high_prec_acc = high_prec_acc * alpha[:, None]

        v = tl.load(V_ptr)
        acc = acc + tl.dot(
            p_quant.to(tl.bfloat16),
            v.to(tl.bfloat16),
        ) * p_comp[:, None]
        high_prec_acc = high_prec_acc + tl.dot(
            high_prec_p.to(tl.bfloat16),
            v.to(tl.bfloat16),
        ) * p_comp[:, None]
        m_i = m_ij

    if USE_TILE_COMP:
        q_mean_base = (off_hz * q_tiles + q_blk).to(tl.int64) * HEAD_DIM
        q_mean = tl.load(QMean + q_mean_base + offs_d).to(tl.float32)
        kv_tiles = N_CTX_KV // BLOCK_N

        for i in range(0, dropped_kv_blocks):
            kv_idx = tl.load(dropped_kv_ptr + i).to(tl.int32)
            block_size = tl.load(variable_block_sizes + kv_idx).to(tl.float32)
            kv_mean_base = (off_hz * kv_tiles + kv_idx).to(tl.int64) * HEAD_DIM
            k_mean = tl.load(KMean + kv_mean_base + offs_d).to(tl.float32)
            v_mean = tl.load(VMean + kv_mean_base + offs_d).to(tl.float32)

            score = tl.sum(q_mean * k_mean, axis=0) * qk_scale
            m_ij = tl.maximum(m_i, score)
            alpha = tl.math.exp2(m_i - m_ij)
            beta = tl.math.exp2(score - m_ij)

            l_i = l_i * alpha + block_size * beta
            comp = (block_size * beta)[:, None] * v_mean[None, :]
            acc = acc * alpha[:, None] + comp
            high_prec_acc = high_prec_acc * alpha[:, None] + comp
            m_i = m_ij

    # ----- epilogue -----
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    high_prec_acc = high_prec_acc / l_i[:, None]
    tl.store(M + off_hz * N_CTX_Q + offs_m, m_i)
    tl.store(O_ptr, acc.to(Out.type.element_ty))
    tl.store(HPO_ptr, high_prec_acc.to(HighPrecOut.type.element_ty))


# ──────────────────────────── SPARSE ADDITION END ─────────────────────────────


@triton.jit
def _attn_bwd_preprocess(
        O,
        DO,  #
        Delta,  #
        Z,
        H,
        N_CTX,  #
        BLOCK_M: tl.constexpr,
        HEAD_DIM: tl.constexpr  #
):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_hz = tl.program_id(1)
    off_n = tl.arange(0, HEAD_DIM)
    # load
    o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM +
                off_n[None, :])
    do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM +
                 off_n[None, :]).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_hz * N_CTX + off_m, delta)


# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(
        dk,
        dv,  #
        Q,
        k,
        v,
        QMean,
        KMean,
        VMean,
        sm_scale,  #
        DO,  #
        M,
        D,  #
        k2q_index,
        k2q_num,
        max_q_blks,
        dropped_k2q_index,
        dropped_k2q_num,
        max_dropped_q_blks,
        variable_block_sizes,
        # shared by Q/K/V/DO.
        stride_tok,
        stride_d,  #
        H,
        N_CTX_KV,
        BLOCK_M1: tl.constexpr,  #
        BLOCK_N1: tl.constexpr,  #
        HEAD_DIM: tl.constexpr,  #
        # Filled in by the wrapper.
        start_n,
        start_m,
        num_steps,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    offs_m = start_m + tl.arange(0, BLOCK_M1)
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_k = tl.arange(0, HEAD_DIM)
    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
    do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    step_m = BLOCK_M1
    kv_blk = tl.program_id(0)  # KV-tile index
    off_hz = tl.program_id(2)  # fused (batch, head)
    b = off_hz // H
    h = off_hz % H
    kv_tiles = N_CTX_KV // BLOCK_N1
    meta_base = ((b * H + h) * kv_tiles + kv_blk)

    q_blocks = tl.load(k2q_num + meta_base)  # int32
    q_ptr = k2q_index + meta_base * max_q_blks  # ptr to list
    dropped_q_blocks = tl.load(dropped_k2q_num + meta_base)
    dropped_q_ptr = dropped_k2q_index + meta_base * max_dropped_q_blks
    block_size = tl.load(variable_block_sizes + kv_blk)
    block_size_f = block_size.to(tl.float32)

    for blk_idx in range(q_blocks * 2):
        block_sparse_offset = (tl.load(q_ptr + blk_idx // 2).to(tl.int32) * 2 +
                               blk_idx % 2) * step_m
        qT = tl.load(qT_ptrs + block_sparse_offset * stride_tok)
        # Load m before computing qk to reduce pipeline stall.
        offs_m = start_m + block_sparse_offset + tl.arange(0, BLOCK_M1)
        m = tl.load(M + offs_m)
        qkT = tl.dot(k.to(tl.bfloat16), qT)
        qkT = qkT * sm_scale * 1.44269504
        mask = tl.arange(0, BLOCK_N1) < block_size
        qkT = tl.where(mask[:, None], qkT, -float("inf"))
        group_m = tl.max(qkT, 0)
        pT = tl.math.exp2(qkT - m[None, :])
        pT = tl.where(mask[:, None], pT, 0.0)

        do = tl.load(do_ptrs + block_sparse_offset * stride_tok)
        # Compute dV with group-local P quantization:
        # quantize exp2(logit - tile_col_max), then multiply dO by
        # exp2(tile_col_max - final_lse) to recover the final softmax scale.
        p_local_T = tl.math.exp2(qkT - group_m[None, :])
        p_local_T = tl.where(mask[:, None], p_local_T, 0.0)
        p_comp = tl.math.exp2(group_m - m)
        p_for_quant = tl.trans(p_local_T)
        p_valid = mask[None, :] & (
            tl.full(
                shape=p_for_quant.shape,
                value=1.0,
                dtype=p_for_quant.dtype,
            ) == 1.0
        )
        p_quant, _ = fake_quantize(
            src_tensor=p_for_quant, valid_src_mask=p_valid,
            BLOCK_SIZE_OUT_DIM=BLOCK_M1, BLOCK_SIZE_QUANT_DIM=BLOCK_N1,
            dst_dtype=p_for_quant.dtype, use_global_sf=False,
        )
        dv += tl.dot(
            tl.trans(p_quant.to(tl.bfloat16)),
            (do * p_comp[:, None]).to(tl.bfloat16),
        )
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m)
        # Compute dP and dS.
        dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
        dsT = pT * (dpT - Di[None, :])
        dsT = dsT.to(tl.bfloat16)
        dk += tl.dot(dsT, tl.trans(qT))
        # Increment pointers.

    if USE_TILE_COMP:
        k_mean = tl.load(KMean + kv_blk * HEAD_DIM + offs_k).to(tl.float32)
        v_mean = tl.load(VMean + kv_blk * HEAD_DIM + offs_k).to(tl.float32)
        qk_scale = sm_scale * 1.44269504

        for blk_idx in range(dropped_q_blocks * 2):
            q_blk_idx = tl.load(dropped_q_ptr + blk_idx // 2).to(tl.int32)
            half = (blk_idx % 2).to(tl.int32)
            block_sparse_offset = (q_blk_idx * 2 + half) * step_m
            offs_m = start_m + block_sparse_offset + tl.arange(0, BLOCK_M1)
            q_mean = tl.load(QMean + q_blk_idx * HEAD_DIM +
                             offs_k).to(tl.float32)
            m = tl.load(M + offs_m)
            do = tl.load(do_ptrs + block_sparse_offset * stride_tok)
            Di = tl.load(D + offs_m)
            q_block_size = tl.load(variable_block_sizes +
                                   q_blk_idx).to(tl.float32)

            score = tl.sum(q_mean * k_mean, axis=0) * qk_scale
            p = tl.math.exp2(score - m)
            dp = tl.sum(do.to(tl.float32) * v_mean[None, :], axis=1)
            ds = block_size_f * p * (dp - Di)

            dk_mean = tl.sum(ds[:, None] * q_mean[None, :],
                             axis=0) / block_size_f
            dv_mean = tl.sum(p[:, None] * do.to(tl.float32), axis=0)
            dk += dk_mean[None, :]
            dv += dv_mean[None, :]
    return dk, dv


# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(
        dq,
        q,
        K,
        V,  #
        QMean,
        KMean,
        VMean,
        do,
        m,
        m_vec,
        D,
        # shared by Q/K/V/DO.
        q2k_index,
        q2k_num,
        max_kv_blks,
        dropped_q2k_index,
        dropped_q2k_num,
        max_dropped_kv_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,  #
        H,
        N_CTX,  #
        BLOCK_M2: tl.constexpr,  #
        BLOCK_N2: tl.constexpr,  #
        HEAD_DIM: tl.constexpr,
        # Filled in by the wrapper.
        start_m,
        start_n,
        num_steps,
        sm_scale=1.0,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_n = start_n + tl.arange(0, BLOCK_N2)
    offs_k = tl.arange(0, HEAD_DIM)
    kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    # D (= delta) is pre-divided by ds_scale.
    Di = tl.load(D + offs_m)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    step_n = BLOCK_N2

    q_blk = tl.program_id(0)  # Q-tile index
    off_hz = tl.program_id(2)  # fused (batch, head)
    b = off_hz // H
    h = off_hz % H
    q_tiles = N_CTX // BLOCK_M2
    meta_base = ((b * H + h) * q_tiles + q_blk)

    kv_blocks = tl.load(q2k_num + meta_base)  # int32
    kv_ptr = q2k_index + meta_base * max_kv_blks  # ptr to list
    dropped_kv_blocks = tl.load(dropped_q2k_num + meta_base)
    dropped_kv_ptr = dropped_q2k_index + meta_base * max_dropped_kv_blks

    for blk_idx in range(kv_blocks * 2):
        kv_idx = tl.load(kv_ptr + blk_idx // 2).to(tl.int32)
        # variable_block_sizes is defined per KV block (tile). Mask must therefore
        # use kv_idx (not q_blk). Also, because we split each 64-token block into
        # two 32-token halves, the mask must account for the half-block offset.
        block_size = tl.load(variable_block_sizes + kv_idx).to(tl.int32)
        half = (blk_idx % 2).to(tl.int32)
        block_sparse_offset = (kv_idx * 2 + half) * step_n * stride_tok
        kT = tl.load(kT_ptrs + block_sparse_offset)
        vT = tl.load(vT_ptrs + block_sparse_offset)
        qk = tl.dot(q, kT)
        qk = qk * sm_scale * 1.44269504
        p = tl.math.exp2(qk - m)
        offs_in_block = half * step_n + tl.arange(0, BLOCK_N2)
        mask = offs_in_block < block_size
        p = tl.where(mask[None, :], p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = ds.to(tl.bfloat16)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.

    if USE_TILE_COMP:
        q_mean = tl.load(QMean + q_blk * HEAD_DIM + offs_k).to(tl.float32)
        q_block_size = tl.load(variable_block_sizes + q_blk).to(tl.float32)
        qk_scale = sm_scale * 1.44269504
        dq_mean = tl.zeros([HEAD_DIM], dtype=tl.float32)

        for blk_idx in range(dropped_kv_blocks):
            kv_idx = tl.load(dropped_kv_ptr + blk_idx).to(tl.int32)
            block_size = tl.load(variable_block_sizes + kv_idx).to(tl.float32)
            k_mean = tl.load(KMean + kv_idx * HEAD_DIM +
                             offs_k).to(tl.float32)
            v_mean = tl.load(VMean + kv_idx * HEAD_DIM +
                             offs_k).to(tl.float32)

            score = tl.sum(q_mean * k_mean, axis=0) * qk_scale
            p = tl.math.exp2(score - m_vec)
            dp = tl.sum(do.to(tl.float32) * v_mean[None, :], axis=1)
            ds = block_size * p * (dp - Di)
            dq_mean = dq_mean + tl.sum(ds, axis=0) * k_mean

        dq += dq_mean[None, :] / q_block_size
    return dq


@triton.jit
def _attn_bwd(
        Q,
        K,
        V,
        sm_scale,  #
        DO,  #
        DQ,
        DK,
        DV,  #
        M,
        D,
        q2k_index,
        q2k_num,
        max_kv_blks,
        k2q_index,
        k2q_num,
        max_q_blks,
        variable_block_sizes,
        # shared by Q/K/V/DO.
        stride_z,
        stride_h,
        stride_tok,
        stride_d,  #
        H,
        N_CTX,  #
        BLOCK_M1: tl.constexpr,  #
        BLOCK_N1: tl.constexpr,  #
        BLOCK_M2: tl.constexpr,  #
        BLOCK_N2: tl.constexpr,  #
        HEAD_DIM: tl.constexpr,
        IS_QAT: tl.constexpr = False):
    LN2 = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    off_chz = (bhid * N_CTX).to(tl.int64)
    adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += adj
    V += adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    # load scales
    offs_k = tl.arange(0, HEAD_DIM)

    start_n = pid * BLOCK_N1
    start_m = 0

    offs_n = start_n + tl.arange(0, BLOCK_N1)

    dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    # load K and V: they stay in SRAM throughout the inner loop.
    k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
    v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)

    num_steps = N_CTX // BLOCK_M1

    dk, dv = _attn_bwd_dkdv(  #
        dk,
        dv,  #
        Q,
        k,
        v,
        Q,
        K,
        V,
        sm_scale,  #
        DO,  #
        M,
        D,  #
        k2q_index,
        k2q_num,
        max_q_blks,
        k2q_index,
        k2q_num,
        max_q_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,  #
        H,
        N_CTX,  #
        BLOCK_M1,
        BLOCK_N1,
        HEAD_DIM,  #
        start_n,
        start_m,
        num_steps,  #
        IS_QAT=IS_QAT,
        USE_TILE_COMP=False,
    )

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv)

    # Write back dK.
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk)

    # THIS BLOCK DOES DQ:
    start_m = pid * BLOCK_M2
    end_n = 0

    offs_m = start_m + tl.arange(0, BLOCK_M2)

    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)

    m_vec = tl.load(M + offs_m)
    m = m_vec[:, None]

    num_steps = N_CTX // BLOCK_N2
    dq = _attn_bwd_dq(
        dq,
        q,
        K,
        V,  #
        Q,
        K,
        V,
        do,
        m,
        m_vec,
        D,  #
        q2k_index,
        q2k_num,
        max_kv_blks,
        q2k_index,
        q2k_num,
        max_kv_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,  #
        H,
        N_CTX,  #
        BLOCK_M2,
        BLOCK_N2,
        HEAD_DIM,  #
        start_m,
        end_n,
        num_steps,  #
        sm_scale=sm_scale,
        IS_QAT=IS_QAT,
        USE_TILE_COMP=False,
    )
    # Write back dQ.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq *= sm_scale
    tl.store(dq_ptrs, dq)


@triton.jit
def _attn_bwd_dkdv_kernel(
        Q,
        K,
        V,
        QMean,
        KMean,
        VMean,
        sm_scale,  #
        DO,  #
        DK,
        DV,  #
        M,
        D,
        k2q_index,
        k2q_num,
        max_q_blks,
        dropped_k2q_index,
        dropped_k2q_num,
        max_dropped_q_blks,
        variable_block_sizes,
        # shared token/dim strides (assumed contiguous along token and dim)
        stride_tok,
        stride_d,  #
        # batch/head strides (may differ between Q and KV)
        stride_qz,
        stride_qh,
        stride_kz,
        stride_kh,
        stride_vz,
        stride_vh,
        stride_doz,
        stride_doh,
        stride_dkz,
        stride_dkh,
        stride_dvz,
        stride_dvh,
        H,
        N_CTX_Q,
        N_CTX_KV,
        BLOCK_M1: tl.constexpr,  #
        BLOCK_N1: tl.constexpr,  #
        HEAD_DIM: tl.constexpr,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    """
    Backward kernel that computes dK and dV for each KV block (64 tokens).
    Grid:
      pid0: kv_blk in [0, N_CTX_KV/BLOCK_N1)
      pid2: fused (batch, head) in [0, B*H)
    """
    bhid = tl.program_id(2)
    b = bhid // H
    h = bhid % H
    kv_blk = tl.program_id(0)

    q_adj = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
    kv_adj_k = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
    kv_adj_v = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
    do_adj = (b.to(tl.int64) * stride_doz + h.to(tl.int64) * stride_doh)
    dk_adj = (b.to(tl.int64) * stride_dkz + h.to(tl.int64) * stride_dkh)
    dv_adj = (b.to(tl.int64) * stride_dvz + h.to(tl.int64) * stride_dvh)

    Q = Q + q_adj
    K = K + kv_adj_k
    V = V + kv_adj_v
    DO = DO + do_adj
    DK = DK + dk_adj
    DV = DV + dv_adj

    q_tiles = N_CTX_Q // BLOCK_M1 // 2
    kv_tiles = N_CTX_KV // BLOCK_N1
    mean_q_adj = (bhid * q_tiles * HEAD_DIM).to(tl.int64)
    mean_kv_adj = (bhid * kv_tiles * HEAD_DIM).to(tl.int64)
    QMean = QMean + mean_q_adj
    KMean = KMean + mean_kv_adj
    VMean = VMean + mean_kv_adj

    # M and D (delta) are always sized by Q length.
    M = M + (bhid * N_CTX_Q).to(tl.int64)
    D = D + (bhid * N_CTX_Q).to(tl.int64)

    offs_k = tl.arange(0, HEAD_DIM)
    start_n = kv_blk * BLOCK_N1
    offs_n = start_n + tl.arange(0, BLOCK_N1)

    # load K and V: they stay in SRAM throughout the inner loop.
    k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
    v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)

    dv_acc = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk_acc = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    num_steps = N_CTX_Q // BLOCK_M1
    dk_acc, dv_acc = _attn_bwd_dkdv(
        dk_acc,
        dv_acc,
        Q,
        k,
        v,
        QMean,
        KMean,
        VMean,
        sm_scale,
        DO,
        M,
        D,
        k2q_index,
        k2q_num,
        max_q_blks,
        dropped_k2q_index,
        dropped_k2q_num,
        max_dropped_q_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,
        H,
        N_CTX_KV,
        BLOCK_M1=BLOCK_M1,
        BLOCK_N1=BLOCK_N1,
        HEAD_DIM=HEAD_DIM,
        start_n=start_n,
        start_m=0,
        num_steps=num_steps,
        IS_QAT=IS_QAT,
        USE_TILE_COMP=USE_TILE_COMP,
    )

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv_acc)

    dk_acc *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk_acc)


@triton.jit
def _attn_bwd_dq_kernel(
        Q,
        K,
        V,
        QMean,
        KMean,
        VMean,
        DO,  #
        DQ,
        M,
        D,
        q2k_index,
        q2k_num,
        max_kv_blks,
        dropped_q2k_index,
        dropped_q2k_num,
        max_dropped_kv_blks,
        variable_block_sizes,
        # shared token/dim strides (assumed contiguous along token and dim)
        stride_tok,
        stride_d,  #
        # batch/head strides (may differ between Q and KV)
        stride_qz,
        stride_qh,
        stride_kz,
        stride_kh,
        stride_vz,
        stride_vh,
        stride_doz,
        stride_doh,
        stride_dqz,
        stride_dqh,
        H,
        N_CTX_Q,
        sm_scale,
        BLOCK_M2: tl.constexpr,  #
        BLOCK_N2: tl.constexpr,  #
        HEAD_DIM: tl.constexpr,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    """
    Backward kernel that computes dQ for each Q block (64 tokens).
    Grid:
      pid0: q_blk in [0, N_CTX_Q/BLOCK_M2)
      pid2: fused (batch, head) in [0, B*H)
    """
    LN2 = 0.6931471824645996  # = ln(2)
    bhid = tl.program_id(2)
    b = bhid // H
    h = bhid % H
    q_blk = tl.program_id(0)

    q_adj = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
    kv_adj_k = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
    kv_adj_v = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
    do_adj = (b.to(tl.int64) * stride_doz + h.to(tl.int64) * stride_doh)
    dq_adj = (b.to(tl.int64) * stride_dqz + h.to(tl.int64) * stride_dqh)

    Q = Q + q_adj
    K = K + kv_adj_k
    V = V + kv_adj_v
    DO = DO + do_adj
    DQ = DQ + dq_adj

    q_tiles = N_CTX_Q // BLOCK_M2
    kv_tiles = N_CTX_Q // 64
    mean_q_adj = (bhid * q_tiles * HEAD_DIM).to(tl.int64)
    mean_kv_adj = (bhid * kv_tiles * HEAD_DIM).to(tl.int64)
    QMean = QMean + mean_q_adj
    KMean = KMean + mean_kv_adj
    VMean = VMean + mean_kv_adj

    M = M + (bhid * N_CTX_Q).to(tl.int64)
    D = D + (bhid * N_CTX_Q).to(tl.int64)

    offs_k = tl.arange(0, HEAD_DIM)
    start_m = q_blk * BLOCK_M2
    offs_m = start_m + tl.arange(0, BLOCK_M2)

    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    m_vec = tl.load(M + offs_m)
    m = m_vec[:, None]

    dq_acc = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
    num_steps = 0  # unused in _attn_bwd_dq
    dq_acc = _attn_bwd_dq(
        dq_acc,
        q,
        K,
        V,
        QMean,
        KMean,
        VMean,
        do,
        m,
        m_vec,
        D,
        q2k_index,
        q2k_num,
        max_kv_blks,
        dropped_q2k_index,
        dropped_q2k_num,
        max_dropped_kv_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,
        H,
        N_CTX_Q,
        BLOCK_M2=BLOCK_M2,
        BLOCK_N2=BLOCK_N2,
        HEAD_DIM=HEAD_DIM,
        start_m=start_m,
        start_n=0,
        num_steps=num_steps,
        sm_scale=sm_scale,
        IS_QAT=IS_QAT,
        USE_TILE_COMP=USE_TILE_COMP,
    )

    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq_acc *= sm_scale
    tl.store(dq_ptrs, dq_acc)


# ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
def triton_block_sparse_attn_forward(q, k, v, q2k_index, q2k_num,
                                     variable_block_sizes, is_qat=False,
                                     q_mean=None, k_mean=None, v_mean=None,
                                     dropped_q2k_index=None,
                                     dropped_q2k_num=None):
    B, H, Tq, D = q.shape
    Tkv = k.shape[2]
    sm_scale = 1.0 / math.sqrt(D)
    max_kv_blks = q2k_index.shape[-1]
    use_tile_comp = q_mean is not None
    if use_tile_comp:
        assert k_mean is not None and v_mean is not None
        assert dropped_q2k_index is not None and dropped_q2k_num is not None
        q_mean = q_mean.contiguous()
        k_mean = k_mean.contiguous()
        v_mean = v_mean.contiguous()
        max_dropped_kv_blks = dropped_q2k_index.shape[-1]
    else:
        q_mean = q
        k_mean = k
        v_mean = v
        dropped_q2k_index = q2k_index
        dropped_q2k_num = q2k_num
        max_dropped_kv_blks = max_kv_blks
    assert Tq % 64 == 0, f"q length must be a multiple of 64, but got {Tq}"
    assert Tkv % 64 == 0, f"kv length must be a multiple of 64, but got {Tkv}"
    assert q2k_num.shape[-1] == Tq // 64, (
        f"shape mismatch, Tq // 64 = {Tq // 64}, "
        f"q2k_num.shape[-1] = {q2k_num.shape[-1]}"
    )
    assert variable_block_sizes.numel() == Tkv // 64, (
        f"shape mismatch, variable_block_sizes must have length {Tkv // 64}, "
        f"got {variable_block_sizes.numel()}"
    )
    o = torch.empty_like(q)
    high_prec_o = torch.empty_like(q)
    M = torch.empty((B, H, Tq), dtype=torch.float32, device=q.device)

    grid = lambda _: (triton.cdiv(Tq, 64), B * H, 1)
    _attn_fwd_sparse[grid](q,
                           k,
                           v,
                           q_mean,
                           k_mean,
                           v_mean,
                           sm_scale,
                           q2k_index,
                           q2k_num,
                           max_kv_blks,
                           dropped_q2k_index,
                           dropped_q2k_num,
                           max_dropped_kv_blks,
                           variable_block_sizes,
                           M,
                           o,
                           high_prec_o,
                           q.stride(0),
                           q.stride(1),
                           q.stride(2),
                           q.stride(3),
                           k.stride(0),
                           k.stride(1),
                           k.stride(2),
                           k.stride(3),
                           v.stride(0),
                           v.stride(1),
                           v.stride(2),
                           v.stride(3),
                           o.stride(0),
                           o.stride(1),
                           o.stride(2),
                           o.stride(3),
                           B,
                           H,
                           Tq,
                           Tkv,
                           HEAD_DIM=D,
                           STAGE=3,
                           IS_QAT=is_qat,
                           USE_TILE_COMP=use_tile_comp)

    return o, M, high_prec_o


def triton_block_sparse_attn_backward(do, q, k, v, o, M, q2k_index, q2k_num,
                                      k2q_index, k2q_num, variable_block_sizes,
                                      is_qat=False, q_mean=None, k_mean=None,
                                      v_mean=None, dropped_q2k_index=None,
                                      dropped_q2k_num=None,
                                      dropped_k2q_index=None,
                                      dropped_k2q_num=None):
    assert do.is_contiguous()

    B, H, Tq, D = q.shape
    Tkv = k.shape[2]
    sm_scale = 1.0 / math.sqrt(D)
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    BATCH, N_HEAD = q.shape[:2]
    BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
    RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
    # Ours-P mode keeps K unscaled and applies sm_scale inside the bwd kernels.
    arg_k = k
    PRE_BLOCK = 64
    assert Tq % PRE_BLOCK == 0
    pre_grid = (Tq // PRE_BLOCK, BATCH * N_HEAD)
    delta = torch.empty_like(M)
    _attn_bwd_preprocess[pre_grid](
        o,
        do,  #
        delta,  #
        BATCH,
        N_HEAD,
        Tq,  #
        BLOCK_M=PRE_BLOCK,
        HEAD_DIM=D  #
    )

    max_q_blks = k2q_index.shape[-1]
    max_kv_blks = q2k_index.shape[-1]
    use_tile_comp = q_mean is not None
    if use_tile_comp:
        assert k_mean is not None and v_mean is not None
        assert dropped_q2k_index is not None and dropped_q2k_num is not None
        assert dropped_k2q_index is not None and dropped_k2q_num is not None
        q_mean = q_mean.contiguous()
        k_mean = k_mean.contiguous()
        v_mean = v_mean.contiguous()
        max_dropped_kv_blks = dropped_q2k_index.shape[-1]
        max_dropped_q_blks = dropped_k2q_index.shape[-1]
    else:
        q_mean = q
        k_mean = k
        v_mean = v
        dropped_q2k_index = q2k_index
        dropped_q2k_num = q2k_num
        dropped_k2q_index = k2q_index
        dropped_k2q_num = k2q_num
        max_dropped_kv_blks = max_kv_blks
        max_dropped_q_blks = max_q_blks

    # dK/dV kernel: grid over KV blocks
    grid_kv = (Tkv // BLOCK_N1, 1, BATCH * N_HEAD)
    _attn_bwd_dkdv_kernel[grid_kv](
        q,
        arg_k,
        v,
        q_mean,
        k_mean,
        v_mean,
        sm_scale,
        do,
        dk,
        dv,
        M,
        delta,
        k2q_index,
        k2q_num,
        max_q_blks,
        dropped_k2q_index,
        dropped_k2q_num,
        max_dropped_q_blks,
        variable_block_sizes,
        q.stride(2),
        q.stride(3),
        q.stride(0),
        q.stride(1),
        arg_k.stride(0),
        arg_k.stride(1),
        v.stride(0),
        v.stride(1),
        do.stride(0),
        do.stride(1),
        dk.stride(0),
        dk.stride(1),
        dv.stride(0),
        dv.stride(1),
        N_HEAD,
        Tq,
        Tkv,
        BLOCK_M1=BLOCK_M1,
        BLOCK_N1=BLOCK_N1,
        HEAD_DIM=D,
        IS_QAT=is_qat,
        USE_TILE_COMP=use_tile_comp,
    )

    # dQ kernel: grid over Q blocks
    grid_q = (Tq // BLOCK_M2, 1, BATCH * N_HEAD)
    _attn_bwd_dq_kernel[grid_q](
        q,
        arg_k,
        v,
        q_mean,
        k_mean,
        v_mean,
        do,
        dq,
        M,
        delta,
        q2k_index,
        q2k_num,
        max_kv_blks,
        dropped_q2k_index,
        dropped_q2k_num,
        max_dropped_kv_blks,
        variable_block_sizes,
        q.stride(2),
        q.stride(3),
        q.stride(0),
        q.stride(1),
        arg_k.stride(0),
        arg_k.stride(1),
        v.stride(0),
        v.stride(1),
        do.stride(0),
        do.stride(1),
        dq.stride(0),
        dq.stride(1),
        N_HEAD,
        Tq,
        sm_scale,
        BLOCK_M2=BLOCK_M2,
        BLOCK_N2=BLOCK_N2,
        HEAD_DIM=D,
        IS_QAT=is_qat,
        USE_TILE_COMP=use_tile_comp,
    )

    return dq, dk, dv
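The group-local rescaling that `_attn_fwd_sparse`'s docstring describes rests on a simple exponent identity: quantizing `exp2(logit - tile_max)` and applying `exp2(tile_max - online_max)` after the PV GEMM recovers `exp2(logit - online_max)` exactly (up to quantization of the local term). A plain-PyTorch check of just the algebra, with no FP4 involved:

```python
import torch

logits = torch.randn(4, 64, dtype=torch.float64) * 4.0   # one Q row block vs. one KV tile
tile_max = logits.max(dim=-1, keepdim=True).values       # per-row max within the tile
online_max = tile_max + 1.5                               # stand-in for the running row max

p_direct = torch.exp2(logits - online_max)                # QAT-style path
p_local = torch.exp2(logits - tile_max)                   # group-local P, peaks at 1.0
p_rescaled = p_local * torch.exp2(tile_max - online_max)  # factor applied after the PV GEMM

assert torch.allclose(p_direct, p_rescaled)
```

The payoff is that `p_local` peaks at exactly 1.0 in every selected tile, so the FP4 block scales see a fixed dynamic range regardless of how far the tile's logits sit below the global row max.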
backend_snapshot/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Adapted from https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
3
+ # and https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
4
+
5
+ import triton
6
+ import triton.language as tl
7
+ try:
8
+ from triton.language.target_info import cuda_capability_geq
9
+ _HAS_CAPABILITY_CHECK = True
10
+ except ImportError:
11
+ cuda_capability_geq = None
12
+ _HAS_CAPABILITY_CHECK = False
13
+
14
+ MXFP_BLOCK_SIZE = tl.constexpr(16)
15
+
16
+ @triton.jit
17
+ def _compute_quant_and_scale(
18
+ src_tensor,
19
+ valid_src_mask,
20
+ mx_tensor_dtype: tl.constexpr = tl.uint8,
21
+ use_global_sf=True,
22
+ two_level_quant_P=False,
23
+ IS_BLACKWELL: tl.constexpr = False,
24
+ ):
25
+ BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
26
+ BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
27
+ BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE
28
+ is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
29
+
30
+ is_fp8e4: tl.constexpr = mx_tensor_dtype == tl.float8e4nv
31
+ is_fp8e5: tl.constexpr = mx_tensor_dtype == tl.float8e5
32
+ tl.static_assert(
33
+ is_fp4 or (is_fp8e4 or is_fp8e5),
34
+ "mx_tensor_dtype must be uint8, float8e4nv, or float8e5",
35
+ )
36
+
37
+ # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
38
+ f32_tensor = src_tensor.to(tl.float32)
39
+ abs_tensor = tl.abs(f32_tensor)
40
+ abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0) # Don't consider padding tensors in scale computation
41
+
42
+ if two_level_quant_P:
43
+ # row max from SageAttn3 paper
44
+ global_max_val = tl.max(f32_tensor, axis=1, keep_dims=True) # (BLOCK_SIZE_OUT_DIM, 1)
45
+ global_max_val = tl.maximum(global_max_val, 1e-8)
46
+ s_enc = ((6 * 448) / global_max_val).reshape([BLOCK_SIZE_OUT_DIM, 1, 1])
47
+ s_dec = (1 / s_enc)
48
+
49
+ abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
50
+
51
+ if use_global_sf and not two_level_quant_P:
52
+ global_max_val = tl.max(abs_tensor)
53
+ # Avoid division by zero: if all values are padding (max is 0), use a default scale
54
+ global_max_val = tl.maximum(global_max_val, 1e-8)
55
+ s_enc = (6 * 448) / global_max_val
56
+ s_dec = (1 / s_enc)
57
+ elif not two_level_quant_P and not use_global_sf:
58
+ s_dec = 1.0
59
+ s_enc = 1.0
60
+
61
+ max_val = tl.max(abs_tensor, axis=2, keep_dims=True) # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1) # per block maxima
62
+ s_dec_b = max_val / 6 # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1)
63
+ s_dec_b_e4m3 = (s_dec_b * s_enc).to(tl.float8e4nv) # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1)
64
+ s_enc_b = 1 / (s_dec_b_e4m3.to(tl.float32) * s_dec) # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1)
65
+
66
+ f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
67
+ quant_tensor = f32_tensor * s_enc_b
68
+
69
+ # Reshape the tensors after scaling
70
+ quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
71
+ # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
72
+ quant_tensor = tl.where(valid_src_mask, quant_tensor, 0.0)
73
+ dequant_scale = s_dec_b_e4m3.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])
74
+
75
+ if is_fp4 and IS_BLACKWELL:
76
+ # Convert scaled values to two f32 lanes and use PTX cvt to e2m1x2 with two f32 operands.
77
+ pairs = tl.reshape(quant_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
78
+ lo_f, hi_f = tl.split(pairs)
79
+ lo_f32 = lo_f.to(tl.float32)
80
+ hi_f32 = hi_f.to(tl.float32)
81
+
82
+ # Inline PTX: cvt.rn.satfinite.e2m1x2.f32 takes two f32 sources and produces one .b8 packed e2m1x2.
83
+ out_tensor = tl.inline_asm_elementwise(
84
+ """
85
+ {
86
+ .reg .b8 r;
87
+ cvt.rn.satfinite.e2m1x2.f32 r, $1, $2;
88
+ mov.b32 $0, {r, r, r, r};
89
+ }
90
+ """,
91
+ constraints="=r,f,f",
92
+ args=[hi_f32, lo_f32],
93
+ dtype=tl.uint8,
94
+ is_pure=True,
95
+ pack=1,
96
+ )
97
+ elif is_fp4:
98
+ quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
99
+ signs = quant_tensor & 0x80000000
100
+ exponents = (quant_tensor >> 23) & 0xFF
101
+ mantissas_orig = (quant_tensor & 0x7FFFFF)
102
+
103
+ # For RTNE: 0.25 < x < 0.75 maps to 0.5 (denormal); exactly 0.25 maps to 0.0
104
+ E8_BIAS = 127
105
+ E2_BIAS = 1
106
+ # Move implicit bit 1 at the beginning to mantissa for denormals
107
+ is_subnormal = exponents < E8_BIAS
108
+ adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False)
109
+ mantissas_pre = (0x400000 | (mantissas_orig >> 1))
110
+ mantissas = tl.where(is_subnormal, mantissas_pre >> adjusted_exponents, mantissas_orig)
111
+
112
+ # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
113
+ exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
114
+
115
+ # Combine sign, exponent, and mantissa, while saturating
116
+ # Round to nearest, ties to even (RTNE): use guard/sticky and LSB to decide increment
117
+ m2bits = mantissas >> 21
118
+ lsb_keep = (m2bits >> 1) & 0x1
119
+ guard = m2bits & 0x1
120
+ IS_SRC_FP32: tl.constexpr = src_tensor.dtype == tl.float32
121
+ if IS_SRC_FP32:
122
+ bit0_dropped = (mantissas_orig & 0x1) != 0
123
+ mask = (1 << tl.minimum(adjusted_exponents, 31)) - 1
124
+ dropped_post = (mantissas_pre & mask) != 0
125
+ sticky = is_subnormal & (bit0_dropped | dropped_post)
126
+ sticky |= ((mantissas & 0x1FFFFF) != 0).to(tl.uint32)
127
+ else:
128
+ sticky = ((mantissas & 0x1FFFFF) != 0).to(tl.uint32)
129
+ round_inc = guard & (sticky | lsb_keep)
130
+ e2m1_tmp = tl.minimum((((exponents << 2) | m2bits) + round_inc) >> 1, 0x7)
131
+ e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)
132
+
133
+ e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
134
+ evens, odds = tl.split(e2m1_value)
135
+ out_tensor = evens | (odds << 4)
136
+ else:
137
+ out_tensor = quant_tensor.to(mx_tensor_dtype)
138
+
139
+ return out_tensor, dequant_scale, s_dec
140
+
141
+ @triton.jit
142
+ def _compute_dequant(
143
+ mx_tensor,
144
+ scale,
145
+ s_dec,
146
+ BLOCK_SIZE_OUT_DIM: tl.constexpr,
147
+ BLOCK_SIZE_QUANT_DIM: tl.constexpr,
148
+ dst_dtype: tl.constexpr,
149
+ IS_BLACKWELL: tl.constexpr = False,
150
+ ):
151
+ tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"Block size along quantization block must be a multiple of {MXFP_BLOCK_SIZE=}")
152
+ # uint8 signifies two fp4 e2m1 values packed into a single byte
153
+ mx_tensor_dtype: tl.constexpr = mx_tensor.dtype
154
+ _is_f16: tl.constexpr = dst_dtype == tl.float16
155
+ _is_bf16: tl.constexpr = dst_dtype == tl.bfloat16
156
+ _is_f32: tl.constexpr = dst_dtype == tl.float32
157
+ tl.static_assert(_is_f16 or (_is_bf16 or _is_f32))
158
+ _is_u8: tl.constexpr = mx_tensor_dtype == tl.uint8
159
+ _is_e4: tl.constexpr = mx_tensor_dtype == tl.float8e4nv
160
+ _is_e5: tl.constexpr = mx_tensor_dtype == tl.float8e5
161
+ _is_dst: tl.constexpr = mx_tensor_dtype == dst_dtype
162
+ tl.static_assert(
163
+ _is_u8 or ((_is_e4 or _is_e5) or _is_dst),
164
+ "mx_tensor_ptr must be uint8 or float8 or dst_dtype")
165
+ tl.static_assert(scale.dtype == tl.float8e4nv, "scale must be float8e4nv")
166
+
167
+ # Determine if we are dealing with fp8 types.
168
+ is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
169
+ BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
170
+
171
+ # Upcast the scale to the destination type.
172
+ if dst_dtype == tl.bfloat16:
173
+ dst_scale = scale.to(tl.bfloat16)
174
+ else:
175
+ dst_scale = scale.to(tl.float32)
176
+ if dst_dtype == tl.float16:
177
+ dst_scale = dst_scale.to(tl.float16)
178
+
179
+ # Now upcast the tensor.
180
+ intermediate_dtype: tl.constexpr = tl.bfloat16 if dst_dtype == tl.float32 else dst_dtype
181
+ if IS_BLACKWELL:
182
+ assert is_fp4
183
+ packed_u32 = tl.inline_asm_elementwise(
184
+ asm="""
185
+ {
186
+ .reg .b8 in_8;
187
+ .reg .f16x2 out;
188
+ cvt.u8.u32 in_8, $1;
189
+ cvt.rn.f16x2.e2m1x2 out, in_8;
190
+ mov.b32 $0, out;
191
+ }
192
+ """,
193
+ constraints="=r,r",
194
+ args=[mx_tensor], # tl.uint8 passed in as a 32-bit reg with value in low 8 bits
195
+ dtype=tl.uint32,
196
+ is_pure=True,
197
+ pack=1,
198
+ )
199
+ lo_u16 = (packed_u32 & 0xFFFF).to(tl.uint16)
200
+ hi_u16 = (packed_u32 >> 16).to(tl.uint16)
201
+ lo_f16 = lo_u16.to(tl.float16, bitcast=True)
202
+ hi_f16 = hi_u16.to(tl.float16, bitcast=True)
203
+
204
+ if intermediate_dtype == tl.float16:
205
+ x0, x1 = lo_f16, hi_f16
206
+ else:
207
+ x0 = lo_f16.to(intermediate_dtype)
208
+ x1 = hi_f16.to(intermediate_dtype)
209
+
210
+ dst_tensor = tl.interleave(x0, x1)
211
+
212
+ else:
213
+ assert is_fp4
214
+ dst_bias: tl.constexpr = 127 if intermediate_dtype == tl.bfloat16 else 15 # exponent bias
215
+ dst_0p5: tl.constexpr = 16128 if intermediate_dtype == tl.bfloat16 else 0x3800
216
+ dst_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10 # mantissa bits
217
+ # e2m1
218
+ em0 = mx_tensor & 0x07
219
+ em1 = mx_tensor & 0x70
220
+ x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((mx_tensor & 0x08).to(tl.uint16) << 12)
221
+ x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((mx_tensor & 0x80).to(tl.uint16) << 8)
222
+ # Three cases:
223
+ # 1) x is normal and non-zero: Correct bias
224
+ x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
225
+ x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
226
+ # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
227
+ x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
228
+ x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
229
+ # 3) x is zero, do nothing
230
+ dst_tensor = tl.interleave(x0, x1).to(intermediate_dtype, bitcast=True)
231
+
232
+ dst_tensor = dst_tensor.to(dst_dtype)
233
+
234
+ # Reshape for proper broadcasting: the scale was stored with a 16-sized "inner" grouping.
235
+ dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
236
+ dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
237
+ scale = scale.reshape(dst_scale.shape)
238
+
239
+ out_tensor = dst_tensor * dst_scale * s_dec # NVFP4 has the additional global scale factor
240
+ if dst_dtype == tl.float32:
241
+ max_fin = 3.4028234663852886e+38
242
+ elif dst_dtype == tl.bfloat16:
243
+ max_fin = 3.3895313892515355e+38
244
+ else:
245
+ tl.static_assert(dst_dtype == tl.float16)
246
+ max_fin = 65504
247
+ out_tensor = tl.clamp(out_tensor, min=-max_fin, max=max_fin)
248
+ out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
249
+ out_tensor = out_tensor.to(dst_dtype)
250
+ return out_tensor
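+
+
+ # --- Editorial reference sketch (added by the editor; not part of the
+ # snapshot): a host-side decoder for packed E2M1 nibbles, handy for
+ # unit-testing the Triton dequant path above. The packing assumed here
+ # matches the quant path (even-index element in the low nibble); per-group
+ # FP8 scales and the global s_dec factor still have to be applied on top.
+ def _unpack_e2m1_reference(packed):
+     import torch
+     # E2M1 magnitudes indexed by the three low bits; bit 3 is the sign.
+     table = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
+                          device=packed.device)
+     lo, hi = packed & 0x0F, packed >> 4
+     def _dec(nib):
+         mag = table[(nib & 0x7).long()]
+         return torch.where((nib & 0x8) != 0, -mag, mag)
+     # Interleave back to the original element order: [..., D//2, 2] -> [..., D]
+     return torch.stack([_dec(lo), _dec(hi)], dim=-1).flatten(-2)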
backend_snapshot/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py ADDED
@@ -0,0 +1,80 @@
1
+ import triton
2
+ import triton.language as tl
3
+
4
+ from .nvfp4_utils import _compute_quant_and_scale, _compute_dequant
5
+
6
+ @triton.jit
7
+ def fake_quantize(src_tensor, valid_src_mask, BLOCK_SIZE_OUT_DIM: tl.constexpr,
8
+ BLOCK_SIZE_QUANT_DIM: tl.constexpr,
9
+ dst_dtype: tl.constexpr,
10
+ mx_tensor_dtype: tl.constexpr = tl.uint8,
11
+ use_global_sf: tl.constexpr = True,
12
+ two_level_quant_P: tl.constexpr = False):
13
+ high_prec_src_tensor = src_tensor
14
+ src_tensor, src_scale, src_s_dec = _compute_quant_and_scale(src_tensor=src_tensor,
15
+ valid_src_mask=valid_src_mask,
16
+ mx_tensor_dtype=mx_tensor_dtype,
17
+ use_global_sf=use_global_sf,
18
+ two_level_quant_P=two_level_quant_P)
19
+ src_tensor = _compute_dequant(mx_tensor=src_tensor,
20
+ scale=src_scale,
21
+ s_dec=src_s_dec,
22
+ BLOCK_SIZE_OUT_DIM=BLOCK_SIZE_OUT_DIM,
23
+ BLOCK_SIZE_QUANT_DIM=BLOCK_SIZE_QUANT_DIM,
24
+ dst_dtype=dst_dtype)
25
+ return src_tensor, high_prec_src_tensor.to(src_tensor.dtype)
26
+
27
+ @triton.jit
28
+ def fake_quantize_q(Q, fake_Q, stride_z_q, stride_h_q,
29
+ stride_tok_q, stride_d_q,
30
+ fake_stride_z_q, fake_stride_h_q,
31
+ fake_stride_tok_q, fake_stride_d_q,
32
+ H, N_CTX_Q,
33
+ BLOCK_M: tl.constexpr,
34
+ HEAD_DIM: tl.constexpr,
35
+ use_global_sf: tl.constexpr = True):
36
+ bhid = tl.program_id(1)
37
+ adj_q = (stride_h_q * (bhid % H) + stride_z_q * (bhid // H))
38
+ fake_adj_q = (fake_stride_h_q * (bhid % H) + fake_stride_z_q * (bhid // H))
39
+ Q += adj_q
40
+ fake_Q += fake_adj_q
41
+
42
+ pid = tl.program_id(0)
43
+ start_m = pid * BLOCK_M
44
+ offs_m = start_m + tl.arange(0, BLOCK_M)
45
+ offs_k = tl.arange(0, HEAD_DIM)
46
+
47
+ q_valid = offs_m < N_CTX_Q
48
+ q = tl.load(Q + offs_m[:, None] * stride_tok_q + offs_k[None, :] * stride_d_q, mask=q_valid[:, None], other=0.0)
49
+ q, _ = fake_quantize(src_tensor=q, valid_src_mask=q_valid[:, None], BLOCK_SIZE_OUT_DIM=BLOCK_M, BLOCK_SIZE_QUANT_DIM=HEAD_DIM, dst_dtype=q.dtype, use_global_sf=use_global_sf)
50
+ tl.store(fake_Q + offs_m[:, None] * fake_stride_tok_q + offs_k[None, :] * fake_stride_d_q, q, mask=q_valid[:, None])
51
+
52
+ @triton.jit
53
+ def fake_quantize_kv(K, V, fake_K, fake_V, stride_z_kv, stride_h_kv,
54
+ stride_tok_kv, stride_d_kv,
55
+ fake_stride_z_kv, fake_stride_h_kv,
56
+ fake_stride_tok_kv, fake_stride_d_kv,
57
+ H, N_CTX_KV,
58
+ BLOCK_N: tl.constexpr,
59
+ HEAD_DIM: tl.constexpr,
60
+ use_global_sf: tl.constexpr = True):
61
+ bhid = tl.program_id(1)
62
+ adj_kv = (stride_h_kv * (bhid % H) + stride_z_kv * (bhid // H))
63
+ fake_adj_kv = (fake_stride_h_kv * (bhid % H) + fake_stride_z_kv * (bhid // H))
64
+ K += adj_kv
65
+ V += adj_kv
66
+ fake_K += fake_adj_kv
67
+ fake_V += fake_adj_kv
68
+
69
+ pid = tl.program_id(0)
70
+ start_n = pid * BLOCK_N
71
+ offs_n = start_n + tl.arange(0, BLOCK_N)
72
+ offs_k = tl.arange(0, HEAD_DIM)
73
+
74
+ kv_valid = offs_n < N_CTX_KV
75
+ k_block = tl.load(K + offs_n[:, None] * stride_tok_kv + offs_k[None, :] * stride_d_kv, mask=kv_valid[:, None], other=0.0)
76
+ v_block = tl.load(V + offs_n[:, None] * stride_tok_kv + offs_k[None, :] * stride_d_kv, mask=kv_valid[:, None], other=0.0)
77
+ k, _ = fake_quantize(src_tensor=k_block, valid_src_mask=kv_valid[:, None], BLOCK_SIZE_OUT_DIM=BLOCK_N, BLOCK_SIZE_QUANT_DIM=HEAD_DIM, dst_dtype=k_block.dtype, use_global_sf=use_global_sf)
78
+ v, _ = fake_quantize(src_tensor=v_block, valid_src_mask=kv_valid[:, None], BLOCK_SIZE_OUT_DIM=BLOCK_N, BLOCK_SIZE_QUANT_DIM=HEAD_DIM, dst_dtype=v_block.dtype, use_global_sf=use_global_sf)
79
+ tl.store(fake_K + offs_n[:, None] * fake_stride_tok_kv + offs_k[None, :] * fake_stride_d_kv, k, mask=kv_valid[:, None])
80
+ tl.store(fake_V + offs_n[:, None] * fake_stride_tok_kv + offs_k[None, :] * fake_stride_d_kv, v, mask=kv_valid[:, None])
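+
+
+ # --- Editorial usage sketch (added by the editor; not part of the snapshot):
+ # launching the Q kernel on a BHLD tensor and measuring the FP4 fake-quant
+ # error, mirroring `_quantize_qkv_bhld` in sparse_fp4_ours_p_attn.py.
+ # Shapes are illustrative; requires a CUDA device.
+ def _fake_quantize_q_demo():
+     import torch
+     B, H, L, D = 1, 2, 128, 64
+     q = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)
+     fake_q = torch.empty_like(q)
+     grid = (triton.cdiv(L, 32), B * H, 1)
+     fake_quantize_q[grid](
+         q, fake_q,
+         q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+         fake_q.stride(0), fake_q.stride(1), fake_q.stride(2), fake_q.stride(3),
+         H, L, BLOCK_M=32, HEAD_DIM=D, use_global_sf=False,
+     )
+     # Mean relative error introduced by FP4 fake quantization.
+     return ((q - fake_q).abs().mean() / q.abs().mean()).item()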
backend_snapshot/fastvideo/attention/backends/sparse_fp4_ours_p_attn.py ADDED
@@ -0,0 +1,192 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """Sparse FP4 Attention backend with the independent ours-P quant kernel."""
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import triton
9
+
10
+ from fastvideo_kernel.triton_kernels.quant_utils import (
11
+ fake_quantize_q,
12
+ fake_quantize_kv,
13
+ )
14
+ from fastvideo_kernel.block_sparse_attn_ours_p import block_sparse_attn_ours_p
15
+ from fastvideo.forward_context import get_forward_context
16
+
17
+ from fastvideo.attention.backends.abstract import (
18
+ AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataBuilder,
19
+ )
20
+ from fastvideo.attention.backends.video_sparse_attn import (
21
+ VideoSparseAttentionMetadata,
22
+ VideoSparseAttentionMetadataBuilder,
23
+ VSA_TILE_SIZE,
24
+ )
25
+ from fastvideo.distributed import get_sp_group
26
+ from fastvideo.logger import init_logger
27
+
28
+ logger = init_logger(__name__)
29
+
30
+
31
+ def _dense_sdpa_blhd(query, key, value):
32
+ q = query.transpose(1, 2)
33
+ k = key.transpose(1, 2)
34
+ v = value.transpose(1, 2)
35
+ out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
36
+ return out.transpose(1, 2)
37
+
38
+
39
+ def _quantize_qkv_bhld(q, k, v):
40
+ """FP4 fake quantize Q/K/V in BHLD layout, same as attn_qat_train."""
41
+ H = q.shape[1]
42
+ N_Q = q.shape[2]
43
+ N_KV = k.shape[2]
44
+ D = q.shape[3]
45
+ BLOCK = 32
46
+
47
+ fake_q = torch.empty_like(q)
48
+ fake_k = torch.empty_like(k)
49
+ fake_v = torch.empty_like(v)
50
+
51
+ grid_q = (triton.cdiv(N_Q, BLOCK), q.shape[0] * H, 1)
52
+ grid_kv = (triton.cdiv(N_KV, BLOCK), q.shape[0] * H, 1)
53
+
54
+ fake_quantize_q[grid_q](
55
+ q, fake_q,
56
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
57
+ fake_q.stride(0), fake_q.stride(1), fake_q.stride(2), fake_q.stride(3),
58
+ H, N_Q, BLOCK_M=BLOCK, HEAD_DIM=D, use_global_sf=False,
59
+ )
60
+ fake_quantize_kv[grid_kv](
61
+ k, v, fake_k, fake_v,
62
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
63
+ fake_k.stride(0), fake_k.stride(1), fake_k.stride(2), fake_k.stride(3),
64
+ H, N_KV, BLOCK_N=BLOCK, HEAD_DIM=D, use_global_sf=False,
65
+ )
66
+ return fake_q, fake_k, fake_v
67
+
68
+
69
+ class SparseFP4OursPAttentionBackend(AttentionBackend):
70
+ accept_output_buffer: bool = True
71
+
72
+ @staticmethod
73
+ def get_supported_head_sizes() -> list[int]:
74
+ return [64, 96, 128, 160, 192, 224, 256]
75
+
76
+ @staticmethod
77
+ def get_name() -> str:
78
+ return "SPARSE_FP4_OURS_P_ATTN"
79
+
80
+ @staticmethod
81
+ def get_impl_cls() -> type["SparseFP4OursPAttentionImpl"]:
82
+ return SparseFP4OursPAttentionImpl
83
+
84
+ @staticmethod
85
+ def get_metadata_cls() -> type["VideoSparseAttentionMetadata"]:
86
+ return VideoSparseAttentionMetadata
87
+
88
+ @staticmethod
89
+ def get_builder_cls() -> type["VideoSparseAttentionMetadataBuilder"]:
90
+ return VideoSparseAttentionMetadataBuilder
91
+
92
+
93
+ class SparseFP4OursPAttentionImpl(AttentionImpl):
94
+
95
+ def __init__(self, num_heads, head_size, causal, softmax_scale,
96
+ num_kv_heads=None, prefix="", **extra):
97
+ self.prefix = prefix
98
+ self.sp_size = get_sp_group().world_size
99
+
100
+ def tile(self, x, num_tiles, tile_partition_indices, non_pad_index):
101
+ t_p = num_tiles[0] * VSA_TILE_SIZE[0]
102
+ h_p = num_tiles[1] * VSA_TILE_SIZE[1]
103
+ w_p = num_tiles[2] * VSA_TILE_SIZE[2]
104
+ out = torch.zeros(
105
+ (x.shape[0], t_p * h_p * w_p, x.shape[-2], x.shape[-1]),
106
+ device=x.device, dtype=x.dtype,
107
+ )
108
+ out[:, non_pad_index] = x[:, tile_partition_indices]
109
+ return out
110
+
111
+ def untile(self, x, reverse_tile_partition_indices, non_pad_index):
112
+ return x[:, non_pad_index][:, reverse_tile_partition_indices]
113
+
114
+ def _is_force_dense(self) -> bool:
115
+ ctx = get_forward_context()
116
+ return ctx.force_dense
117
+
118
+ def preprocess_qkv(self, qkv, attn_metadata):
119
+ if attn_metadata is None or self._is_force_dense():
120
+ return qkv
121
+ return self.tile(qkv, attn_metadata.num_tiles,
122
+ attn_metadata.tile_partition_indices,
123
+ attn_metadata.non_pad_index)
124
+
125
+ def postprocess_output(self, output, attn_metadata):
126
+ if attn_metadata is None or self._is_force_dense():
127
+ return output
128
+ return self.untile(output,
129
+ attn_metadata.reverse_tile_partition_indices,
130
+ attn_metadata.non_pad_index)
131
+
132
+ def forward(self, query, key, value,
133
+ gate_compress_or_metadata=None, attn_metadata=None):
134
+ # Handle both call conventions
135
+ if attn_metadata is None and isinstance(
136
+ gate_compress_or_metadata, (VideoSparseAttentionMetadata, type(None))):
137
+ attn_metadata = gate_compress_or_metadata
138
+
139
+ # ── force_dense: true dense BF16 SDPA (for teacher in distillation) ──
140
+ ctx = get_forward_context()
141
+ if ctx.force_dense:
142
+ return _dense_sdpa_blhd(query, key, value)
143
+
144
+ is_cross = query.shape[1] != key.shape[1]
145
+
146
+ # ── Cross-attention/no metadata: keep dense. The sparse VSA metadata only
147
+ # applies to tiled video self-attention.
148
+ if attn_metadata is None or is_cross:
149
+ return _dense_sdpa_blhd(query, key, value)
150
+
151
+ # ── Self-attention: FP4 quant Q/K/V + block-sparse attention ──
152
+ # BLHD → BHLD
153
+ q = query.transpose(1, 2).contiguous()
154
+ k = key.transpose(1, 2).contiguous()
155
+ v = value.transpose(1, 2).contiguous()
156
+
157
+ # Step 1: FP4 fake quantize Q/K/V with STE (straight-through estimator)
158
+ with torch.no_grad():
159
+ fq, fk, fv = _quantize_qkv_bhld(q, k, v)
160
+ # STE: forward uses quantized values, backward passes gradient through as-is
161
+ fq = q + (fq - q).detach()
162
+ fk = k + (fk - k).detach()
163
+ fv = v + (fv - v).detach()
164
+
165
+ # Step 2: Build sparse block map
166
+ B, H, S, D = fq.shape
167
+ block_elements = math.prod(VSA_TILE_SIZE)
168
+ num_blocks = S // block_elements
169
+
170
+ VSA_sparsity = attn_metadata.VSA_sparsity
171
+ cur_topk = max(1, math.ceil((1 - VSA_sparsity) * num_blocks))
172
+ logger.info("[SFP4] S=%s num_blocks=%s sparsity=%s topk=%s/%s", S, num_blocks, VSA_sparsity, cur_topk, num_blocks)
173
+
174
+ block_sizes = attn_metadata.variable_block_sizes.to(
175
+ device=fq.device, dtype=torch.float32).clamp_min(1)
176
+ block_sizes = block_sizes.view(1, 1, num_blocks, 1)
177
+ q_c = (fq.view(B, H, num_blocks, block_elements, D).float().sum(3) /
178
+ block_sizes).to(fq.dtype)
179
+ k_c = (fk.view(B, H, num_blocks, block_elements, D).float().sum(3) /
180
+ block_sizes).to(fk.dtype)
181
+ v_c = (fv.view(B, H, num_blocks, block_elements, D).float().sum(3) /
182
+ block_sizes).to(fv.dtype)
183
+ scores = torch.matmul(q_c, k_c.transpose(-2, -1)) / (D ** 0.5)
184
+ topk_idx = torch.topk(scores, cur_topk, dim=-1).indices
185
+ block_map = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, topk_idx, True)
186
+
187
+ # Step 3: Block-sparse attention with independent group-local P quant.
188
+ out, _ = block_sparse_attn_ours_p(fq, fk, fv, block_map,
189
+ attn_metadata.variable_block_sizes,
190
+ q_c, k_c, v_c)
191
+
192
+ return out.transpose(1, 2) # BHLD → BLHD
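+
+
+ # --- Editorial sketch (added by the editor; not part of the snapshot): the
+ # Step-2 top-k block map on toy shapes. Every row of `block_map` ends up
+ # with exactly `cur_topk` True entries, one per selected key tile.
+ def _block_map_demo(num_blocks: int = 8, sparsity: float = 0.9):
+     B, H, D = 1, 1, 4
+     q_c = torch.randn(B, H, num_blocks, D)
+     k_c = torch.randn(B, H, num_blocks, D)
+     cur_topk = max(1, math.ceil((1 - sparsity) * num_blocks))  # 1 for these args
+     scores = torch.matmul(q_c, k_c.transpose(-2, -1)) / (D ** 0.5)
+     topk_idx = torch.topk(scores, cur_topk, dim=-1).indices
+     block_map = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, topk_idx, True)
+     assert (block_map.sum(-1) == cur_topk).all()
+     return block_map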
backend_snapshot/fastvideo/attention/backends/video_sparse_attn.py ADDED
@@ -0,0 +1,262 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ import functools
3
+ import math
4
+ from dataclasses import dataclass
5
+
6
+ import torch
7
+
8
+ try:
9
+ from fastvideo_kernel import video_sparse_attn
10
+ except ImportError:
11
+ video_sparse_attn = None
12
+
13
+ from typing import Any
14
+
15
+ from fastvideo.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata,
16
+ AttentionMetadataBuilder)
17
+ from fastvideo.distributed import get_sp_group
18
+ from fastvideo.logger import init_logger
19
+
20
+ logger = init_logger(__name__)
21
+ VSA_TILE_SIZE = (4, 4, 4)
22
+
23
+
24
+ @functools.lru_cache(maxsize=10)
25
+ def get_tile_partition_indices(
26
+ dit_seq_shape: tuple[int, int, int],
27
+ tile_size: tuple[int, int, int],
28
+ device: torch.device,
29
+ ) -> torch.LongTensor:
30
+ T, H, W = dit_seq_shape
31
+ ts, hs, ws = tile_size
32
+ indices = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W)
33
+ ls = []
34
+ for t in range(math.ceil(T / ts)):
35
+ for h in range(math.ceil(H / hs)):
36
+ for w in range(math.ceil(W / ws)):
37
+ ls.append(indices[t * ts:min(t * ts + ts, T), h * hs:min(h * hs + hs, H),
38
+ w * ws:min(w * ws + ws, W)].flatten())
39
+ index = torch.cat(ls, dim=0)
40
+ return index
41
+
42
+
43
+ @functools.lru_cache(maxsize=10)
44
+ def get_reverse_tile_partition_indices(
45
+ dit_seq_shape: tuple[int, int, int],
46
+ tile_size: tuple[int, int, int],
47
+ device: torch.device,
48
+ ) -> torch.LongTensor:
49
+ return torch.argsort(get_tile_partition_indices(dit_seq_shape, tile_size, device))
50
+
51
+
52
+ @functools.lru_cache(maxsize=10)
53
+ def construct_variable_block_sizes(
54
+ dit_seq_shape: tuple[int, int, int],
55
+ num_tiles: tuple[int, int, int],
56
+ device: torch.device,
57
+ ) -> torch.LongTensor:
58
+ """
59
+ Compute the number of valid (non‑padded) tokens inside every
60
+ (ts_t × ts_h × ts_w) tile after padding ‑‑ flattened in the order
61
+ (t‑tile, h‑tile, w‑tile) that `rearrange` uses.
62
+
63
+ Returns
64
+ -------
65
+ torch.LongTensor # shape: [∏ full_window_size]
66
+ """
67
+ # unpack
68
+ t, h, w = dit_seq_shape
69
+ ts_t, ts_h, ts_w = VSA_TILE_SIZE
70
+ n_t, n_h, n_w = num_tiles
71
+
72
+ def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor:
73
+ """Vector with the size of each tile along one dimension."""
74
+ sizes = torch.full((n_tiles, ), tile, dtype=torch.int, device=device)
75
+ # size of last (possibly partial) tile
76
+ remainder = dim_len - (n_tiles - 1) * tile
77
+ sizes[-1] = remainder if remainder > 0 else tile
78
+ return sizes
79
+
80
+ t_sizes = _sizes(t, ts_t, n_t) # [n_t]
81
+ h_sizes = _sizes(h, ts_h, n_h) # [n_h]
82
+ w_sizes = _sizes(w, ts_w, n_w) # [n_w]
83
+
84
+ # broadcast-multiply to get voxels per tile, then flatten
85
+ block_sizes = (
86
+ t_sizes[:, None, None] # [n_t, 1, 1]
87
+ * h_sizes[None, :, None] # [1, n_h, 1]
88
+ * w_sizes[None, None, :] # [1, 1, n_w]
89
+ ).reshape(-1) # [n_t * n_h * n_w]
90
+
91
+ return block_sizes
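+ # Worked example (editorial note): dit_seq_shape=(9, 8, 8) with the
+ # (4, 4, 4) tile gives num_tiles=(3, 2, 2); t_sizes=[4, 4, 1] and
+ # h_sizes=w_sizes=[4, 4], so eight tiles hold 64 tokens while the four
+ # tiles in the last temporal slab hold 1*4*4 = 16 tokens each.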
92
+
93
+
94
+ @functools.lru_cache(maxsize=10)
95
+ def get_non_pad_index(
96
+ variable_block_sizes: torch.LongTensor,
97
+ max_block_size: int,
98
+ ):
99
+ n_win = variable_block_sizes.shape[0]
100
+ device = variable_block_sizes.device
101
+ starts_pad = torch.arange(n_win, device=device) * max_block_size
102
+ index_pad = starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :]
103
+ index_mask = torch.arange(max_block_size, device=device)[None, :] < variable_block_sizes[:, None]
104
+ return index_pad[index_mask]
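+ # Example (editorial note): variable_block_sizes=[64, 16] with
+ # max_block_size=64 keeps positions 0..63 of the first padded window and
+ # 64..79 of the second, i.e. 80 non-pad slots out of 128.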
105
+
106
+
107
+ class VideoSparseAttentionBackend(AttentionBackend):
108
+
109
+ accept_output_buffer: bool = True
110
+
111
+ @staticmethod
112
+ def get_supported_head_sizes() -> list[int]:
113
+ return [64, 128]
114
+
115
+ @staticmethod
116
+ def get_name() -> str:
117
+ return "VIDEO_SPARSE_ATTN"
118
+
119
+ @staticmethod
120
+ def get_impl_cls() -> type["VideoSparseAttentionImpl"]:
121
+ return VideoSparseAttentionImpl
122
+
123
+ @staticmethod
124
+ def get_metadata_cls() -> type["VideoSparseAttentionMetadata"]:
125
+ return VideoSparseAttentionMetadata
126
+
127
+ @staticmethod
128
+ def get_builder_cls() -> type["VideoSparseAttentionMetadataBuilder"]:
129
+ return VideoSparseAttentionMetadataBuilder
130
+
131
+
132
+ @dataclass
133
+ class VideoSparseAttentionMetadata(AttentionMetadata):
134
+ current_timestep: int
135
+ dit_seq_shape: list[int]
136
+ num_tiles: list[int]
137
+ total_seq_length: int
138
+ tile_partition_indices: torch.LongTensor
139
+ reverse_tile_partition_indices: torch.LongTensor
140
+ variable_block_sizes: torch.LongTensor
141
+ non_pad_index: torch.LongTensor
142
+
143
+
144
+ class VideoSparseAttentionMetadataBuilder(AttentionMetadataBuilder):
145
+
146
+ def __init__(self) -> None:
147
+ pass
148
+
149
+ def prepare(self) -> None:
150
+ pass
151
+
152
+ def build( # type: ignore
153
+ self,
154
+ current_timestep: int,
155
+ raw_latent_shape: tuple[int, int, int],
156
+ patch_size: tuple[int, int, int],
157
+ VSA_sparsity: float,
158
+ device: torch.device,
159
+ **kwargs: dict[str, Any],
160
+ ) -> VideoSparseAttentionMetadata:
161
+ dit_seq_shape = (raw_latent_shape[0] // patch_size[0], raw_latent_shape[1] // patch_size[1],
163
+ raw_latent_shape[2] // patch_size[2])
164
+
165
+ num_tiles = (math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]), math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
166
+ math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]))
167
+ total_seq_length = math.prod(dit_seq_shape)
168
+
169
+ tile_partition_indices = get_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
170
+ reverse_tile_partition_indices = get_reverse_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
171
+ variable_block_sizes = construct_variable_block_sizes(dit_seq_shape, num_tiles, device)
172
+ non_pad_index = get_non_pad_index(variable_block_sizes, math.prod(VSA_TILE_SIZE))
173
+
174
+ return VideoSparseAttentionMetadata(
175
+ current_timestep=current_timestep,
176
+ dit_seq_shape=dit_seq_shape, # type: ignore
177
+ VSA_sparsity=VSA_sparsity, # type: ignore
178
+ num_tiles=num_tiles, # type: ignore
179
+ total_seq_length=total_seq_length, # type: ignore
180
+ tile_partition_indices=tile_partition_indices, # type: ignore
181
+ reverse_tile_partition_indices=reverse_tile_partition_indices,
182
+ variable_block_sizes=variable_block_sizes,
183
+ non_pad_index=non_pad_index)
184
+
185
+
186
+ class VideoSparseAttentionImpl(AttentionImpl):
187
+
188
+ def __init__(
189
+ self,
190
+ num_heads: int,
191
+ head_size: int,
192
+ causal: bool,
193
+ softmax_scale: float,
194
+ num_kv_heads: int | None = None,
195
+ prefix: str = "",
196
+ **extra_impl_args,
197
+ ) -> None:
198
+ self.prefix = prefix
199
+ sp_group = get_sp_group()
200
+ self.sp_size = sp_group.world_size
201
+
202
+ def tile(self, x: torch.Tensor, num_tiles: list[int], tile_partition_indices: torch.LongTensor,
203
+ non_pad_index: torch.LongTensor) -> torch.Tensor:
204
+ t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
205
+ h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
206
+ w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]
207
+
208
+ x_padded = torch.zeros((x.shape[0], t_padded_size * h_padded_size * w_padded_size, x.shape[-2], x.shape[-1]),
209
+ device=x.device,
210
+ dtype=x.dtype)
211
+ x_padded[:, non_pad_index] = x[:, tile_partition_indices]
212
+ return x_padded
213
+
214
+ def untile(self, x: torch.Tensor, reverse_tile_partition_indices: torch.LongTensor,
215
+ non_pad_index: torch.LongTensor) -> torch.Tensor:
216
+ x = x[:, non_pad_index][:, reverse_tile_partition_indices]
217
+ return x
218
+
219
+ def preprocess_qkv(
220
+ self,
221
+ qkv: torch.Tensor,
222
+ attn_metadata: VideoSparseAttentionMetadata,
223
+ ) -> torch.Tensor:
224
+ return self.tile(qkv, attn_metadata.num_tiles, attn_metadata.tile_partition_indices,
225
+ attn_metadata.non_pad_index)
226
+
227
+ def postprocess_output(
228
+ self,
229
+ output: torch.Tensor,
230
+ attn_metadata: VideoSparseAttentionMetadata,
231
+ ) -> torch.Tensor:
232
+ return self.untile(output, attn_metadata.reverse_tile_partition_indices, attn_metadata.non_pad_index)
233
+
234
+ def forward( # type: ignore[override]
235
+ self,
236
+ query: torch.Tensor,
237
+ key: torch.Tensor,
238
+ value: torch.Tensor,
239
+ gate_compress: torch.Tensor,
240
+ attn_metadata: VideoSparseAttentionMetadata,
241
+ ) -> torch.Tensor:
242
+ query = query.transpose(1, 2).contiguous()
243
+ key = key.transpose(1, 2).contiguous()
244
+ value = value.transpose(1, 2).contiguous()
245
+ gate_compress = gate_compress.transpose(1, 2).contiguous()
246
+
247
+ VSA_sparsity = attn_metadata.VSA_sparsity
248
+
249
+ cur_topk = math.ceil((1 - VSA_sparsity) * (attn_metadata.total_seq_length / math.prod(VSA_TILE_SIZE)))
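+ # Editorial note: cur_topk is ceil((1 - sparsity) * num_tiles); e.g.
+ # VSA_sparsity=0.9 keeps roughly 10% of the 64-token key tiles per query tile.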
250
+
251
+ if video_sparse_attn is None:
252
+ raise NotImplementedError("video_sparse_attn is not installed")
253
+ hidden_states = video_sparse_attn(query,
254
+ key,
255
+ value,
256
+ attn_metadata.variable_block_sizes,
257
+ attn_metadata.variable_block_sizes,
258
+ cur_topk,
259
+ block_size=VSA_TILE_SIZE,
260
+ compress_attn_weight=gate_compress).transpose(1, 2)
261
+
262
+ return hidden_states
backend_snapshot/fastvideo/configs/models/dits/base.py ADDED
@@ -0,0 +1,79 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ from dataclasses import dataclass, field
3
+ from typing import Any
4
+
5
+ from fastvideo.configs.models.base import ArchConfig, ModelConfig
6
+ from fastvideo.layers.quantization import QuantizationConfig
7
+ from fastvideo.platforms import AttentionBackendEnum
8
+
9
+
10
+ @dataclass
11
+ class DiTArchConfig(ArchConfig):
12
+ _fsdp_shard_conditions: list = field(default_factory=list)
13
+ _compile_conditions: list = field(default_factory=list)
14
+ param_names_mapping: dict = field(default_factory=dict)
15
+ reverse_param_names_mapping: dict = field(default_factory=dict)
16
+ lora_param_names_mapping: dict = field(default_factory=dict)
17
+ _supported_attention_backends: tuple[AttentionBackendEnum,
18
+ ...] = (AttentionBackendEnum.SAGE_ATTN, AttentionBackendEnum.FLASH_ATTN,
19
+ AttentionBackendEnum.TORCH_SDPA,
20
+ AttentionBackendEnum.VIDEO_SPARSE_ATTN,
21
+ AttentionBackendEnum.VMOBA_ATTN, AttentionBackendEnum.SAGE_ATTN_THREE,
22
+ AttentionBackendEnum.ATTN_QAT_INFER,
23
+ AttentionBackendEnum.ATTN_QAT_TRAIN, AttentionBackendEnum.SLA_ATTN,
24
+ AttentionBackendEnum.SAGE_SLA_ATTN,
25
+ AttentionBackendEnum.SPARSE_FP4_ATTN,
26
+ AttentionBackendEnum.SPARSE_FP4_OURS_P_ATTN)
27
+
28
+ hidden_size: int = 0
29
+ num_attention_heads: int = 0
30
+ num_channels_latents: int = 0
31
+ in_channels: int | None = 0
32
+ out_channels: int | None = 0
33
+ patch_size: int | tuple[int, int, int] | None = None
34
+ expand_timesteps: bool = False
35
+ num_layers: int = 0
36
+ ffn_dim: int = 0
37
+ exclude_lora_layers: list[str] = field(default_factory=list)
38
+ boundary_ratio: float | None = None
39
+
40
+ def __post_init__(self) -> None:
41
+ if not self._compile_conditions:
42
+ self._compile_conditions = self._fsdp_shard_conditions.copy()
43
+
44
+
45
+ @dataclass
46
+ class DiTConfig(ModelConfig):
47
+ arch_config: DiTArchConfig = field(default_factory=DiTArchConfig)
48
+
49
+ # FastVideoDiT-specific parameters
50
+ prefix: str = ""
51
+ quant_config: QuantizationConfig | None = None
52
+ expand_timesteps: bool = False
53
+ boundary_ratio: float | None = None
54
+
55
+ def __post_init__(self) -> None:
56
+ super().__post_init__()
57
+ self.arch_config.expand_timesteps = self.expand_timesteps
58
+ self.arch_config.boundary_ratio = self.boundary_ratio
59
+
60
+ @staticmethod
61
+ def add_cli_args(parser: Any, prefix: str = "dit-config") -> Any:
62
+ """Add CLI arguments for DiTConfig fields"""
63
+ parser.add_argument(
64
+ f"--{prefix}.prefix",
65
+ type=str,
66
+ dest=f"{prefix.replace('-', '_')}.prefix",
67
+ default=DiTConfig.prefix,
68
+ help="Prefix for the DiT model",
69
+ )
70
+
71
+ parser.add_argument(
72
+ f"--{prefix}.quant-config",
73
+ type=str,
74
+ dest=f"{prefix.replace('-', '_')}.quant_config",
75
+ default=None,
76
+ help="Quantization configuration for the DiT model",
77
+ )
78
+
79
+ return parser
backend_snapshot/fastvideo/forward_context.py ADDED
@@ -0,0 +1,100 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/forward_context.py
3
+
4
+ import time
5
+ from collections import defaultdict
6
+ from contextlib import contextmanager
7
+ from dataclasses import dataclass
8
+ from typing import TYPE_CHECKING, Optional
9
+
10
+ import torch
11
+
12
+ from fastvideo.logger import init_logger
13
+
14
+ if TYPE_CHECKING:
15
+ from fastvideo.attention import AttentionMetadata
16
+ from fastvideo.pipelines import ForwardBatch
17
+
18
+ logger = init_logger(__name__)
19
+
20
+ # TODO(will): check if this is needed
21
+ # track_batchsize: bool = envs.FASTVIDEO_LOG_BATCHSIZE_INTERVAL >= 0
22
+ track_batchsize: bool = False
23
+ last_logging_time: float = 0
24
+ forward_start_time: float = 0
25
+ # batchsize_logging_interval: float = envs.FASTVIDEO_LOG_BATCHSIZE_INTERVAL
26
+ batchsize_logging_interval: float = 1000
27
+ batchsize_forward_time: defaultdict = defaultdict(list)
28
+
29
+
30
+ #
31
+ @dataclass
32
+ class ForwardContext:
33
+ current_timestep: int
34
+ # TODO(will): check this arg
35
+ # copy from vllm_config.compilation_config.static_forward_context
36
+ # attn_layers: Dict[str, Any]
37
+ # TODO: extend to support per-layer dynamic forward context
38
+ attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
39
+ forward_batch: Optional["ForwardBatch"] = None
40
+ force_dense: bool = False
41
+
42
+
43
+ _forward_context: Optional["ForwardContext"] = None
44
+
45
+
46
+ def get_forward_context() -> "ForwardContext":
47
+ """Get the current forward context."""
48
+ assert _forward_context is not None, ("Forward context is not set. "
49
+ "Please use `set_forward_context` to set the forward context.")
50
+ return _forward_context
51
+
52
+
53
+ # TODO(will): finalize the interface
54
+ @contextmanager
55
+ def set_forward_context(current_timestep, attn_metadata, forward_batch: Optional["ForwardBatch"] = None, force_dense: bool = False):
56
+ """A context manager that stores the current forward context,
57
+ can be attention metadata, etc.
58
+ Here we can inject common logic for every model forward pass.
59
+ """
60
+ global forward_start_time
61
+ need_to_track_batchsize = track_batchsize and attn_metadata is not None
62
+ if need_to_track_batchsize:
63
+ forward_start_time = time.perf_counter()
64
+ global _forward_context
65
+ prev_context = _forward_context
66
+ _forward_context = ForwardContext(current_timestep=current_timestep,
67
+ attn_metadata=attn_metadata,
68
+ forward_batch=forward_batch,
69
+ force_dense=force_dense)
70
+
71
+ try:
72
+ yield
73
+ finally:
74
+ global last_logging_time, batchsize_logging_interval
75
+ if need_to_track_batchsize:
76
+ if hasattr(attn_metadata, "num_prefill_tokens"):
77
+ # for v0 attention backends
78
+ batchsize = attn_metadata.num_prefill_tokens + \
79
+ attn_metadata.num_decode_tokens
80
+ else:
81
+ # for v1 attention backends
82
+ batchsize = attn_metadata.num_input_tokens
83
+ now = time.perf_counter()
84
+ # time measurement is in milliseconds
85
+ batchsize_forward_time[batchsize].append((now - forward_start_time) * 1000)
86
+ if now - last_logging_time > batchsize_logging_interval:
87
+ last_logging_time = now
88
+ forward_stats = []
89
+ for bs, times in batchsize_forward_time.items():
90
+ if len(times) <= 1:
91
+ # can be cudagraph / profiling run
92
+ continue
93
+ median = torch.quantile(torch.tensor(times), q=0.5).item()
94
+ median = round(median, 2)
95
+ forward_stats.append((bs, len(times), median))
96
+ forward_stats.sort(key=lambda x: x[1], reverse=True)
97
+ if forward_stats:
98
+ logger.info(("Batchsize forward time stats "
99
+ "(batchsize, count, median_time(ms)): %s"), forward_stats)
100
+ _forward_context = prev_context
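+
+
+ # --- Editorial usage sketch (added by the editor; not part of the snapshot):
+ # the sparse FP4 backend reads `force_dense` from this context to route a
+ # dense BF16 teacher pass during distillation. `transformer` and its inputs
+ # below are hypothetical names.
+ #
+ # with set_forward_context(current_timestep=i, attn_metadata=metadata,
+ #                          force_dense=True):
+ #     teacher_pred = transformer(latent_model_input, prompt_embeds, t_expand)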
backend_snapshot/fastvideo/pipelines/stages/denoising.py ADDED
@@ -0,0 +1,1184 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ """
3
+ Denoising stage for diffusion pipelines.
4
+ """
5
+
6
+ import inspect
7
+ import weakref
8
+ from collections.abc import Iterable
9
+ from typing import Any
10
+
11
+ import torch
12
+ from tqdm.auto import tqdm
13
+
14
+ from fastvideo.attention import get_attn_backend
15
+ from fastvideo.distributed import (get_local_torch_device, get_world_group)
16
+ from fastvideo.fastvideo_args import FastVideoArgs
17
+ from fastvideo.forward_context import set_forward_context
18
+ from fastvideo.logger import init_logger
19
+ from fastvideo.models.loader.component_loader import TransformerLoader
20
+ from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import (FlowMatchEulerDiscreteScheduler)
21
+ from fastvideo.models.utils import pred_noise_to_pred_video
22
+ from fastvideo.pipelines.pipeline_batch_info import ForwardBatch
23
+ from fastvideo.pipelines.stages.base import PipelineStage
24
+ from fastvideo.pipelines.stages.validators import StageValidators as V
25
+ from fastvideo.pipelines.stages.validators import VerificationResult
26
+ from fastvideo.platforms import AttentionBackendEnum
27
+ from fastvideo.utils import dict_to_3d_list, masks_like
28
+
29
+ try:
30
+ from fastvideo.attention.backends.vmoba import VMOBAAttentionBackend
31
+ from fastvideo.utils import is_vmoba_available
32
+ vmoba_attn_available = is_vmoba_available()
33
+ except ImportError:
34
+ vmoba_attn_available = False
35
+
36
+ try:
37
+ from fastvideo.attention.backends.video_sparse_attn import (VideoSparseAttentionBackend)
38
+ vsa_available = True
39
+ except ImportError:
40
+ vsa_available = False
41
+
42
+ try:
43
+ from fastvideo.attention.backends.sparse_fp4_attn import (SparseFP4AttentionBackend)
44
+ except ImportError:
45
+ SparseFP4AttentionBackend = None # type: ignore[assignment]
46
+
47
+ try:
48
+ from fastvideo.attention.backends.sparse_fp4_ours_p_attn import (SparseFP4OursPAttentionBackend)
49
+ except ImportError:
50
+ SparseFP4OursPAttentionBackend = None # type: ignore[assignment]
51
+
52
+ sparse_fp4_backends = tuple(
53
+ backend for backend in (
54
+ SparseFP4AttentionBackend,
55
+ SparseFP4OursPAttentionBackend,
56
+ ) if backend is not None)
57
+ sparse_fp4_available = bool(sparse_fp4_backends)
58
+
59
+ logger = init_logger(__name__)
60
+
61
+
62
+ class DenoisingStage(PipelineStage):
63
+ """
64
+ Stage for running the denoising loop in diffusion pipelines.
65
+
66
+ This stage handles the iterative denoising process that transforms
67
+ the initial noise into the final output.
68
+ """
69
+
70
+ def __init__(self, transformer, scheduler, pipeline=None, transformer_2=None, vae=None) -> None:
71
+ super().__init__()
72
+ self.transformer = transformer
73
+ self.transformer_2 = transformer_2
74
+ self.scheduler = scheduler
75
+ self.vae = vae
76
+ self.pipeline = weakref.ref(pipeline) if pipeline else None
77
+ attn_head_size = self.transformer.hidden_size // self.transformer.num_attention_heads
78
+ self.attn_backend = get_attn_backend(
79
+ head_size=attn_head_size,
80
+ dtype=torch.float16, # TODO(will): hack
81
+ supported_attention_backends=(
82
+ AttentionBackendEnum.VIDEO_SPARSE_ATTN, AttentionBackendEnum.BSA_ATTN, AttentionBackendEnum.VMOBA_ATTN,
83
+ AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.SAGE_ATTN_THREE,
84
+ AttentionBackendEnum.ATTN_QAT_INFER, AttentionBackendEnum.ATTN_QAT_TRAIN,
85
+ AttentionBackendEnum.SPARSE_FP4_ATTN, AttentionBackendEnum.SPARSE_FP4_OURS_P_ATTN) # hack
86
+ )
87
+
88
+ def forward(
89
+ self,
90
+ batch: ForwardBatch,
91
+ fastvideo_args: FastVideoArgs,
92
+ ) -> ForwardBatch:
93
+ """
94
+ Run the denoising loop.
95
+
96
+ Args:
97
+ batch: The current batch information.
98
+ fastvideo_args: The inference arguments.
99
+
100
+ Returns:
101
+ The batch with denoised latents.
102
+ """
103
+ pipeline = self.pipeline() if self.pipeline else None
104
+ if not fastvideo_args.model_loaded["transformer"]:
105
+ loader = TransformerLoader()
106
+ self.transformer = loader.load(fastvideo_args.model_paths["transformer"], fastvideo_args)
107
+ if pipeline:
108
+ pipeline.add_module("transformer", self.transformer)
109
+ fastvideo_args.model_loaded["transformer"] = True
110
+
111
+ # Prepare extra step kwargs for scheduler
112
+ extra_step_kwargs = self.prepare_extra_func_kwargs(
113
+ self.scheduler.step,
114
+ {
115
+ "generator": batch.generator,
116
+ "eta": batch.eta
117
+ },
118
+ )
119
+
120
+ # Setup precision and autocast settings
121
+ # TODO(will): make the precision configurable for inference
122
+ # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
123
+ target_dtype = torch.bfloat16
124
+ autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast
125
+
126
+ # Get timesteps and calculate warmup steps
127
+ timesteps = batch.timesteps
128
+ # TODO(will): remove this once we add input/output validation for stages
129
+ if timesteps is None:
130
+ raise ValueError("Timesteps must be provided")
131
+ num_inference_steps = batch.num_inference_steps
132
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
133
+
134
+ # Prepare image latents and embeddings for I2V generation
135
+ image_embeds = batch.image_embeds
136
+ if len(image_embeds) > 0:
137
+ assert not torch.isnan(image_embeds[0]).any(), "image_embeds contains nan"
138
+ image_embeds = [image_embed.to(target_dtype) for image_embed in image_embeds]
139
+
140
+ image_kwargs = self.prepare_extra_func_kwargs(
141
+ self.transformer.forward,
142
+ {
143
+ "encoder_hidden_states_image": image_embeds,
144
+ "mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24)
145
+ },
146
+ )
147
+
148
+ pos_cond_kwargs = self.prepare_extra_func_kwargs(
149
+ self.transformer.forward,
150
+ {
151
+ "encoder_hidden_states_2": batch.clip_embedding_pos,
152
+ "encoder_attention_mask": batch.prompt_attention_mask,
153
+ },
154
+ )
155
+
156
+ neg_cond_kwargs = self.prepare_extra_func_kwargs(
157
+ self.transformer.forward,
158
+ {
159
+ "encoder_hidden_states_2": batch.clip_embedding_neg,
160
+ "encoder_attention_mask": batch.negative_attention_mask,
161
+ },
162
+ )
163
+
164
+ action_kwargs = self.prepare_extra_func_kwargs(
165
+ self.transformer.forward,
166
+ {
167
+ "mouse_cond": batch.mouse_cond,
168
+ "keyboard_cond": batch.keyboard_cond,
169
+ "c2ws_plucker_emb": batch.c2ws_plucker_emb,
170
+ },
171
+ )
172
+
173
+ camera_kwargs = self.prepare_extra_func_kwargs(
174
+ self.transformer.forward,
175
+ {
176
+ "camera_states": batch.camera_states,
177
+ },
178
+ )
179
+
180
+ # Get latents and embeddings
181
+ latents = batch.latents
182
+ prompt_embeds = batch.prompt_embeds
183
+ assert not torch.isnan(prompt_embeds[0]).any(), "prompt_embeds contains nan"
184
+ if batch.do_classifier_free_guidance:
185
+ neg_prompt_embeds = batch.negative_prompt_embeds
186
+ assert neg_prompt_embeds is not None
187
+ assert not torch.isnan(neg_prompt_embeds[0]).any(), "neg_prompt_embeds contains nan"
188
+
189
+ # (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
190
+ boundary_ratio = fastvideo_args.pipeline_config.dit_config.boundary_ratio
191
+ if batch.boundary_ratio is not None:
192
+ logger.info("Overriding boundary ratio from %s to %s", boundary_ratio, batch.boundary_ratio)
193
+ boundary_ratio = batch.boundary_ratio
194
+
195
+ boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps if boundary_ratio is not None else None
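+ # Editorial note: steps with t >= boundary_timestep use the high-noise
+ # expert (self.transformer); later steps switch to the low-noise expert
+ # (self.transformer_2), as handled per-iteration below.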
196
+ latent_model_input = latents.to(target_dtype)
197
+ assert latent_model_input.shape[0] == 1, "only support batch size 1"
198
+
199
+ if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
200
+ # TI2V directly replaces the first frame of the latent with
201
+ # the image latent instead of appending along the channel dim
202
+ assert batch.image_latent is None, "TI2V task should not have image latents"
203
+ assert self.vae is not None, "VAE is not provided for TI2V task"
204
+ z = self.vae.encode(batch.pil_image).mean.float()
205
+ if (hasattr(self.vae, "shift_factor") and self.vae.shift_factor is not None):
206
+ if isinstance(self.vae.shift_factor, torch.Tensor):
207
+ z -= self.vae.shift_factor.to(z.device, z.dtype)
208
+ else:
209
+ z -= self.vae.shift_factor
210
+
211
+ if isinstance(self.vae.scaling_factor, torch.Tensor):
212
+ z = z * self.vae.scaling_factor.to(z.device, z.dtype)
213
+ else:
214
+ z = z * self.vae.scaling_factor
215
+
216
+ latent_model_input = latent_model_input.squeeze(0)
217
+ _, mask2 = masks_like([latent_model_input], zero=True)
218
+
219
+ latent_model_input = (1. - mask2[0]) * z + mask2[0] * latent_model_input
220
+ # latent_model_input = latent_model_input.unsqueeze(0)
221
+ latent_model_input = latent_model_input.to(get_local_torch_device())
222
+ latents = latent_model_input
223
+ F = batch.num_frames
224
+ temporal_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_temporal
225
+ spatial_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
226
+ patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
227
+ if not isinstance(patch_size, tuple):
228
+ raise ValueError(f"Expected 3D patch_size tuple for denoising, got {patch_size!r}")
229
+ seq_len = ((F - 1) // temporal_scale + 1) * (batch.height // spatial_scale) * (
230
+ batch.width // spatial_scale) // (patch_size[1] * patch_size[2])
231
+
232
+ # Initialize lists for ODE trajectory
233
+ trajectory_timesteps: list[torch.Tensor] = []
234
+ trajectory_latents: list[torch.Tensor] = []
235
+
236
+ # Run denoising loop
237
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
238
+ for i, t in enumerate(timesteps):
239
+ # Skip if interrupted
240
+ if hasattr(self, 'interrupt') and self.interrupt:
241
+ continue
242
+
243
+ if boundary_timestep is None or t >= boundary_timestep:
244
+ if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
245
+ and self.transformer_2 is not None
246
+ and next(self.transformer_2.parameters()).device.type == 'cuda'):
247
+ self.transformer_2.to('cpu')
248
+ current_model = self.transformer
249
+ if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
250
+ and not fastvideo_args.use_fsdp_inference and current_model is not None):
251
+ transformer_device = next(current_model.parameters()).device.type
252
+ if transformer_device == 'cpu':
253
+ current_model.to(get_local_torch_device())
254
+ current_guidance_scale = batch.guidance_scale
255
+ else:
256
+ # low-noise stage in wan2.2
257
+ if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
258
+ and next(self.transformer.parameters()).device.type == 'cuda'):
259
+ self.transformer.to('cpu')
260
+ current_model = self.transformer_2
261
+ if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
262
+ and not fastvideo_args.use_fsdp_inference and current_model is not None):
263
+ transformer_2_device = next(current_model.parameters()).device.type
264
+ if transformer_2_device == 'cpu':
265
+ current_model.to(get_local_torch_device())
266
+ current_guidance_scale = batch.guidance_scale_2
267
+ assert current_model is not None, "current_model is None"
268
+
269
+ # Expand latents for V2V/I2V
270
+ latent_model_input = latents.to(target_dtype)
271
+ if batch.video_latent is not None:
272
+ latent_model_input = torch.cat([latent_model_input, batch.video_latent,
273
+ torch.zeros_like(latents)],
274
+ dim=1).to(target_dtype)
275
+ elif batch.image_latent is not None:
276
+ assert not fastvideo_args.pipeline_config.ti2v_task, "image latents should not be provided for TI2V task"
277
+ latent_model_input = torch.cat([latent_model_input, batch.image_latent], dim=1).to(target_dtype)
278
+
279
+ assert not torch.isnan(latent_model_input).any(), "latent_model_input contains nan"
280
+ if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
281
+ timestep = torch.stack([t]).to(get_local_torch_device())
282
+ temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
283
+ temp_ts = torch.cat([temp_ts, temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep])
284
+ timestep = temp_ts.unsqueeze(0)
285
+ t_expand = timestep.repeat(latent_model_input.shape[0], 1)
286
+ else:
287
+ t_expand = t.repeat(latent_model_input.shape[0])
288
+ t_expand = t_expand.to(get_local_torch_device())
289
+
290
+ use_meanflow = getattr(self.transformer.config, "use_meanflow", False)
291
+ if use_meanflow:
292
+ if i == len(timesteps) - 1:
293
+ timesteps_r = torch.tensor([0.0], device=get_local_torch_device())
294
+ else:
295
+ timesteps_r = timesteps[i + 1]
296
+ timesteps_r = timesteps_r.repeat(latent_model_input.shape[0])
297
+ else:
298
+ timesteps_r = None
299
+
300
+ timesteps_r_kwarg = self.prepare_extra_func_kwargs(
301
+ self.transformer.forward,
302
+ {
303
+ "timestep_r": timesteps_r,
304
+ },
305
+ )
306
+
307
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
308
+
309
+ # Prepare inputs for transformer
310
+ guidance_expand = (torch.tensor(
311
+ [fastvideo_args.pipeline_config.embedded_cfg_scale] * latent_model_input.shape[0],
312
+ dtype=torch.float32,
313
+ device=get_local_torch_device(),
314
+ ).to(target_dtype) * 1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale is not None else None)
315
+
316
+ # Predict noise residual
317
+ with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
318
+ if (vsa_available and self.attn_backend == VideoSparseAttentionBackend) or \
319
+ (sparse_fp4_available and self.attn_backend in sparse_fp4_backends):
320
+ self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()
321
+
322
+ if self.attn_metadata_builder_cls is not None:
323
+ self.attn_metadata_builder = self.attn_metadata_builder_cls()
324
+ # TODO(will): clean this up
325
+ attn_metadata = self.attn_metadata_builder.build( # type: ignore
326
+ current_timestep=i, # type: ignore
327
+ raw_latent_shape=batch.raw_latent_shape[2:5], # type: ignore
328
+ patch_size=fastvideo_args.pipeline_config. # type: ignore
329
+ dit_config.patch_size, # type: ignore
330
+ VSA_sparsity=fastvideo_args.VSA_sparsity, # type: ignore
331
+ device=get_local_torch_device(),
332
+ )
333
+ assert attn_metadata is not None, "attn_metadata cannot be None"
334
+ else:
335
+ attn_metadata = None
336
+ elif (vmoba_attn_available and self.attn_backend == VMOBAAttentionBackend):
337
+ self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()
338
+ if self.attn_metadata_builder_cls is not None:
339
+ self.attn_metadata_builder = self.attn_metadata_builder_cls()
340
+ # Prepare V-MoBA parameters from config
341
+ moba_params = fastvideo_args.moba_config.copy()
342
+ assert batch.raw_latent_shape is not None, "raw_latent_shape must be set for V-MoBA"
343
+ moba_params.update({
344
+ "current_timestep": i,
345
+ "raw_latent_shape": batch.raw_latent_shape[2:5],
346
+ "patch_size": fastvideo_args.pipeline_config.dit_config.patch_size,
347
+ "device": get_local_torch_device(),
348
+ })
349
+ attn_metadata = self.attn_metadata_builder.build(**moba_params)
350
+ assert attn_metadata is not None, "attn_metadata cannot be None"
351
+ else:
352
+ attn_metadata = None
353
+ else:
354
+ attn_metadata = None
355
+ # TODO(will): finalize the interface. vLLM uses this to
356
+ # support torch dynamo compilation. They pass in
357
+ # attn_metadata, vllm_config, and num_tokens. We can pass in
358
+ # fastvideo_args or training_args, and attn_metadata.
359
+ batch.is_cfg_negative = False
360
+ with set_forward_context(
361
+ current_timestep=i,
362
+ attn_metadata=attn_metadata,
363
+ forward_batch=batch,
364
+ # fastvideo_args=fastvideo_args
365
+ ):
366
+ # Run transformer
367
+ noise_pred = current_model(
368
+ latent_model_input,
369
+ prompt_embeds,
370
+ t_expand,
371
+ guidance=guidance_expand,
372
+ **image_kwargs,
373
+ **pos_cond_kwargs,
374
+ **action_kwargs,
375
+ **camera_kwargs,
376
+ **timesteps_r_kwarg,
377
+ )
378
+
379
+ if batch.do_classifier_free_guidance:
380
+ batch.is_cfg_negative = True
381
+ with set_forward_context(
382
+ current_timestep=i,
383
+ attn_metadata=attn_metadata,
384
+ forward_batch=batch,
385
+ ):
386
+ noise_pred_uncond = current_model(
387
+ latent_model_input,
388
+ neg_prompt_embeds,
389
+ t_expand,
390
+ guidance=guidance_expand,
391
+ **image_kwargs,
392
+ **neg_cond_kwargs,
393
+ **action_kwargs,
394
+ **camera_kwargs,
395
+ **timesteps_r_kwarg,
396
+ )
397
+
398
+ noise_pred_text = noise_pred
399
+ noise_pred = noise_pred_uncond + current_guidance_scale * (noise_pred_text - noise_pred_uncond)
400
+
401
+ # Apply guidance rescale if needed
402
+ if batch.guidance_rescale > 0.0:
403
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
404
+ noise_pred = self.rescale_noise_cfg(
405
+ noise_pred,
406
+ noise_pred_text,
407
+ guidance_rescale=batch.guidance_rescale,
408
+ )
409
+ # Compute the previous noisy sample
410
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
411
+ if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
412
+ latents = latents.squeeze(0)
413
+ latents = (1. - mask2[0]) * z + mask2[0] * latents
414
+ # latents = latents.unsqueeze(0)
415
+
416
+ # save trajectory latents if needed
417
+ if batch.return_trajectory_latents:
418
+ trajectory_timesteps.append(t)
419
+ trajectory_latents.append(latents)
420
+
421
+ # Update progress bar
422
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
423
+ (i + 1) % self.scheduler.order == 0 and progress_bar is not None):
424
+ progress_bar.update()
425
+
426
+ trajectory_tensor: torch.Tensor | None = None
427
+ if trajectory_latents:
428
+ trajectory_tensor = torch.stack(trajectory_latents, dim=1)
429
+ trajectory_timesteps_tensor = torch.stack(trajectory_timesteps, dim=0)
430
+ else:
431
+ trajectory_tensor = None
432
+ trajectory_timesteps_tensor = None
433
+
434
+ if trajectory_tensor is not None and trajectory_timesteps_tensor is not None:
435
+ batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu()
436
+ batch.trajectory_latents = trajectory_tensor.cpu()
437
+
438
+ # Update batch with final latents
439
+ batch.latents = latents
440
+
441
+ if fastvideo_args.dit_layerwise_offload:
442
+ mgr = getattr(self.transformer, "_layerwise_offload_manager", None)
443
+ if mgr is not None and getattr(mgr, "enabled", False):
444
+ mgr.release_all()
445
+ if self.transformer_2 is not None:
446
+ mgr2 = getattr(self.transformer_2, "_layerwise_offload_manager", None)
447
+ if mgr2 is not None and getattr(mgr2, "enabled", False):
448
+ mgr2.release_all()
449
+
450
+ # deallocate transformer if on mps
451
+ if torch.backends.mps.is_available():
452
+ logger.info("Memory before deallocating transformer: %s", torch.mps.current_allocated_memory())
453
+ del self.transformer
454
+ if pipeline is not None and "transformer" in pipeline.modules:
455
+ del pipeline.modules["transformer"]
456
+ fastvideo_args.model_loaded["transformer"] = False
457
+ logger.info("Memory after deallocating transformer: %s", torch.mps.current_allocated_memory())
458
+
459
+ return batch
460
+
461
+ def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]:
462
+ """
463
+ Prepare extra kwargs for the scheduler step / denoise step.
464
+
465
+ Args:
466
+ func: The function to prepare kwargs for.
467
+ kwargs: The kwargs to prepare.
468
+
469
+ Returns:
470
+ The prepared kwargs.
471
+ """
472
+ extra_step_kwargs = {}
473
+ for k, v in kwargs.items():
474
+ accepts = k in set(inspect.signature(func).parameters.keys())
475
+ if accepts:
476
+ extra_step_kwargs[k] = v
477
+ return extra_step_kwargs
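+ # Example (editorial note): a scheduler whose `step` signature lacks `eta`
+ # (flow-match Euler, for instance) gets only {"generator": ...} back here;
+ # passing `eta` unconditionally would raise a TypeError.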
478
+
479
+ def progress_bar(self, iterable: Iterable | None = None, total: int | None = None) -> tqdm:
480
+ """
481
+ Create a progress bar for the denoising process.
482
+
483
+ Args:
484
+ iterable: The iterable to iterate over.
485
+ total: The total number of items.
486
+
487
+ Returns:
488
+ A tqdm progress bar.
489
+ """
490
+ local_rank = get_world_group().local_rank
491
+ if local_rank == 0:
492
+ return tqdm(iterable=iterable, total=total)
493
+ else:
494
+ return tqdm(iterable=iterable, total=total, disable=True)
495
+
496
+ def rescale_noise_cfg(self, noise_cfg, noise_pred_text, guidance_rescale=0.0) -> torch.Tensor:
497
+ """
498
+ Rescale noise prediction according to guidance_rescale.
499
+
500
+ Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed"
501
+ (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.
502
+
503
+ Args:
504
+ noise_cfg: The noise prediction with guidance.
505
+ noise_pred_text: The text-conditioned noise prediction.
506
+ guidance_rescale: The guidance rescale factor.
507
+
508
+ Returns:
509
+ The rescaled noise prediction.
510
+ """
511
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
512
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
513
+ # Rescale the results from guidance (fixes overexposure)
514
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
515
+ # Mix with the original results from guidance by factor guidance_rescale
516
+ noise_cfg = (guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg)
517
+ return noise_cfg
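+ # In symbols (editorial note): with g = guidance_rescale,
+ # out = g * noise_cfg * (std_text / std_cfg) + (1 - g) * noise_cfg,
+ # pulling the guided prediction's per-sample std back toward the
+ # text-conditioned branch to counter overexposure.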
518
+
519
+ def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
520
+ """Verify denoising stage inputs."""
521
+ result = VerificationResult()
522
+ result.add_check("timesteps", batch.timesteps, [V.is_tensor, V.min_dims(1)])
523
+ result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
524
+ result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
525
+ result.add_check("image_embeds", batch.image_embeds, V.is_list)
526
+ result.add_check("image_latent", batch.image_latent, V.none_or_tensor_with_dims(5))
527
+ result.add_check("num_inference_steps", batch.num_inference_steps, V.positive_int)
528
+ result.add_check("guidance_scale", batch.guidance_scale, V.positive_float)
529
+ result.add_check("eta", batch.eta, V.non_negative_float)
530
+ result.add_check("generator", batch.generator, V.generator_or_list_generators)
531
+ result.add_check("do_classifier_free_guidance", batch.do_classifier_free_guidance, V.bool_value)
532
+ result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
533
+ lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x))
534
+ return result
535
+
536
+ def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
537
+ """Verify denoising stage outputs."""
538
+ result = VerificationResult()
539
+ result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
540
+ return result
541
+
542
+
543
+ class CosmosDenoisingStage(DenoisingStage):
544
+ """
545
+ Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler.
546
+ """
547
+
548
+ def __init__(self, transformer, scheduler, pipeline=None) -> None:
549
+ super().__init__(transformer, scheduler, pipeline)
550
+
551
+ def forward(
552
+ self,
553
+ batch: ForwardBatch,
554
+ fastvideo_args: FastVideoArgs,
555
+ ) -> ForwardBatch:
556
+ pipeline = self.pipeline() if self.pipeline else None
557
+ if not fastvideo_args.model_loaded["transformer"]:
558
+ loader = TransformerLoader()
559
+ self.transformer = loader.load(fastvideo_args.model_paths["transformer"], fastvideo_args)
560
+ if pipeline:
561
+ pipeline.add_module("transformer", self.transformer)
562
+ fastvideo_args.model_loaded["transformer"] = True
563
+
564
+ extra_step_kwargs = self.prepare_extra_func_kwargs(
565
+ self.scheduler.step,
566
+ {
567
+ "generator": batch.generator,
568
+ "eta": batch.eta
569
+ },
570
+ )
571
+
572
+ if hasattr(self.transformer, 'module'):
573
+ transformer_dtype = next(self.transformer.module.parameters()).dtype
574
+ else:
575
+ transformer_dtype = next(self.transformer.parameters()).dtype
576
+ target_dtype = transformer_dtype
577
+ autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast
578
+
579
+ latents = batch.latents
580
+ num_inference_steps = batch.num_inference_steps
581
+ guidance_scale = batch.guidance_scale
582
+
583
+ sigma_max = 80.0
584
+ sigma_min = 0.002
585
+ sigma_data = 1.0
586
+ final_sigmas_type = "sigma_min"
587
+
588
+ if self.scheduler is not None:
589
+ self.scheduler.register_to_config(
590
+ sigma_max=sigma_max,
591
+ sigma_min=sigma_min,
592
+ sigma_data=sigma_data,
593
+ final_sigmas_type=final_sigmas_type,
594
+ )
595
+
596
+ self.scheduler.set_timesteps(num_inference_steps, device=latents.device)
597
+ timesteps = self.scheduler.timesteps
598
+
599
+ if (hasattr(self.scheduler.config, 'final_sigmas_type')
600
+ and self.scheduler.config.final_sigmas_type == "sigma_min" and len(self.scheduler.sigmas) > 1):
601
+ self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
602
+
603
+ conditioning_latents = getattr(batch, 'conditioning_latents', None)
604
+ unconditioning_latents = conditioning_latents
605
+
606
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
607
+ for i, t in enumerate(timesteps):
608
+ if hasattr(self, 'interrupt') and self.interrupt:
609
+ continue
610
+
611
+ current_sigma = self.scheduler.sigmas[i]
612
+ current_t = current_sigma / (current_sigma + 1)
613
+ c_in = 1 - current_t
614
+ c_skip = 1 - current_t
615
+ c_out = -current_t
616
+
617
+ timestep = current_t.view(1, 1, 1, 1, 1).expand(latents.size(0), -1, latents.size(2), -1,
618
+ -1) # [B, 1, T, 1, 1]
619
+
620
+ with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
621
+
622
+ cond_latent = latents * c_in
623
+
624
+ if hasattr(
625
+ batch,
626
+ 'cond_indicator') and batch.cond_indicator is not None and conditioning_latents is not None:
627
+ cond_latent = batch.cond_indicator * conditioning_latents + (1 -
628
+ batch.cond_indicator) * cond_latent
629
+ else:
630
+ logger.warning(
631
+ "Step %s: Missing conditioning data - cond_indicator: %s, conditioning_latents: %s", i,
632
+ hasattr(batch, 'cond_indicator'), conditioning_latents is not None)
633
+
634
+ cond_latent = cond_latent.to(target_dtype)
635
+
636
+ cond_timestep = timestep
637
+ if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None:
638
+ sigma_conditioning = 0.0001
639
+ t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
640
+ cond_timestep = batch.cond_indicator * t_conditioning + (1 - batch.cond_indicator) * timestep
641
+ cond_timestep = cond_timestep.to(target_dtype)
642
+
643
+ with set_forward_context(
644
+ current_timestep=i,
645
+ attn_metadata=None,
646
+ forward_batch=batch,
647
+ ):
648
+ # Use conditioning masks from CosmosLatentPreparationStage
649
+ condition_mask = batch.cond_mask.to(target_dtype) if hasattr(batch, 'cond_mask') else None
650
+ padding_mask = torch.zeros(1,
651
+ 1,
652
+ batch.height,
653
+ batch.width,
654
+ device=cond_latent.device,
655
+ dtype=target_dtype)
656
+
657
+ # Fallback if masks not available
658
+ if condition_mask is None:
659
+ batch_size, num_channels, num_frames, height, width = cond_latent.shape
660
+ condition_mask = torch.zeros(batch_size,
661
+ 1,
662
+ num_frames,
663
+ height,
664
+ width,
665
+ device=cond_latent.device,
666
+ dtype=target_dtype)
667
+
668
+ noise_pred = self.transformer(
669
+ hidden_states=cond_latent,
670
+ timestep=cond_timestep.to(target_dtype),
671
+ encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
672
+ fps=24, # TODO: get fps from batch or config
673
+ condition_mask=condition_mask,
674
+ padding_mask=padding_mask,
675
+ return_dict=False,
676
+ )[0]
677
+
678
+ cond_pred = (c_skip * latents + c_out * noise_pred.float()).to(target_dtype)
679
+
680
+ if hasattr(
681
+ batch,
682
+ 'cond_indicator') and batch.cond_indicator is not None and conditioning_latents is not None:
683
+ cond_pred = batch.cond_indicator * conditioning_latents + (1 - batch.cond_indicator) * cond_pred
684
+
685
+ if batch.do_classifier_free_guidance and batch.negative_prompt_embeds is not None:
686
+ uncond_latent = latents * c_in
687
+
688
+ if hasattr(batch, 'uncond_indicator'
689
+ ) and batch.uncond_indicator is not None and unconditioning_latents is not None:
690
+ uncond_latent = batch.uncond_indicator * unconditioning_latents + (
691
+ 1 - batch.uncond_indicator) * uncond_latent
692
+
693
+ with set_forward_context(
694
+ current_timestep=i,
695
+ attn_metadata=None,
696
+ forward_batch=batch,
697
+ ):
698
+ uncond_condition_mask = batch.uncond_mask.to(target_dtype) if hasattr(
699
+ batch, 'uncond_mask') and batch.uncond_mask is not None else condition_mask
700
+
701
+ uncond_timestep = timestep
702
+ if hasattr(batch, 'uncond_indicator') and batch.uncond_indicator is not None:
703
+ sigma_conditioning = 0.0001
704
+ t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
705
+ uncond_timestep = batch.uncond_indicator * t_conditioning + (
706
+ 1 - batch.uncond_indicator) * timestep
707
+ uncond_timestep = uncond_timestep.to(target_dtype)
708
+
709
+ noise_pred_uncond = self.transformer(
710
+ hidden_states=uncond_latent.to(target_dtype),
711
+ timestep=uncond_timestep.to(target_dtype),
712
+ encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
713
+ fps=24, # TODO: get fps from batch or config
714
+ condition_mask=uncond_condition_mask,
715
+ padding_mask=padding_mask,
716
+ return_dict=False,
717
+ )[0]
718
+
719
+ uncond_pred = (c_skip * latents + c_out * noise_pred_uncond.float()).to(target_dtype)
720
+
721
+ if hasattr(batch, 'uncond_indicator'
722
+ ) and batch.uncond_indicator is not None and unconditioning_latents is not None:
723
+ uncond_pred = batch.uncond_indicator * unconditioning_latents + (
724
+ 1 - batch.uncond_indicator) * uncond_pred
725
+
726
+ guidance_diff = cond_pred - uncond_pred
727
+ final_pred = cond_pred + guidance_scale * guidance_diff
728
+ else:
729
+ final_pred = cond_pred
730
+
731
+ # Convert to noise for scheduler step
732
+ if current_sigma > 1e-8:
733
+ noise_for_scheduler = (latents - final_pred) / current_sigma
734
+ else:
735
+ logger.warning("Step %s: current_sigma too small (%s), using final_pred directly", i, current_sigma)
736
+ noise_for_scheduler = final_pred
737
+
738
+ if torch.isnan(noise_for_scheduler).sum() > 0:
739
+ logger.error("Step %s: NaN detected in noise_for_scheduler, sum: %s", i,
740
+ noise_for_scheduler.float().sum().item())
741
+ logger.error("Step %s: latents sum: %s, final_pred sum: %s, current_sigma: %s", i,
742
+ latents.float().sum().item(),
743
+ final_pred.float().sum().item(), current_sigma)
744
+
745
+ latents = self.scheduler.step(noise_for_scheduler, t, latents, **extra_step_kwargs,
746
+ return_dict=False)[0]
747
+
748
+ progress_bar.update()
749
+
750
+ batch.latents = latents
751
+
752
+ return batch
753
+
754
+ def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
755
+ """Verify Cosmos denoising stage inputs."""
756
+ result = VerificationResult()
757
+ result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
758
+ result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
759
+ result.add_check("num_inference_steps", batch.num_inference_steps, V.positive_int)
760
+ result.add_check("guidance_scale", batch.guidance_scale, V.positive_float)
761
+ result.add_check("do_classifier_free_guidance", batch.do_classifier_free_guidance, V.bool_value)
762
+ result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
763
+ lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x))
764
+ return result
765
+
766
+ def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
767
+ """Verify Cosmos denoising stage outputs."""
768
+ result = VerificationResult()
769
+ result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
770
+ return result
771
+
772
+
773
+ class Cosmos25DenoisingStage(CosmosDenoisingStage):
774
+ """Denoising stage for Cosmos 2.5 DiT (expects 1D/2D timestep, not 5D)."""
775
+
776
+ def forward(
777
+ self,
778
+ batch: ForwardBatch,
779
+ fastvideo_args: FastVideoArgs,
780
+ ) -> ForwardBatch:
781
+ pipeline = self.pipeline() if self.pipeline else None
782
+ if not fastvideo_args.model_loaded["transformer"]:
783
+ loader = TransformerLoader()
784
+ self.transformer = loader.load(fastvideo_args.model_paths["transformer"], fastvideo_args)
785
+ if pipeline:
786
+ pipeline.add_module("transformer", self.transformer)
787
+ fastvideo_args.model_loaded["transformer"] = True
788
+
789
+ extra_step_kwargs = self.prepare_extra_func_kwargs(
790
+ self.scheduler.step,
791
+ {
792
+ "generator": batch.generator,
793
+ "eta": batch.eta
794
+ },
795
+ )
796
+
797
+ if hasattr(self.transformer, 'module'):
798
+ transformer_dtype = next(self.transformer.module.parameters()).dtype
799
+ else:
800
+ transformer_dtype = next(self.transformer.parameters()).dtype
801
+ target_dtype = transformer_dtype
802
+ autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast
803
+
804
+ latents = batch.latents
805
+ if latents is None:
806
+ raise ValueError("latents must be provided for Cosmos25DenoisingStage")
807
+ guidance_scale = batch.guidance_scale
808
+
809
+ if batch.timesteps is None:
810
+ self.scheduler.set_timesteps(batch.num_inference_steps, device=latents.device)
811
+ timesteps = self.scheduler.timesteps
812
+ else:
813
+ timesteps = batch.timesteps.to(latents.device)
814
+
815
+ cfg = fastvideo_args.pipeline_config
816
+
817
+ if batch.fps is None:
818
+ gen = batch.generator
819
+ if isinstance(gen, list) and len(gen) > 0:
820
+ gen = gen[0]
821
+ fps_tensor = torch.randint(
822
+ 16,
823
+ 32,
824
+ (1, ),
825
+ generator=gen if isinstance(gen, torch.Generator) else None,
826
+ device=latents.device,
827
+ ).float().to(dtype=target_dtype)
828
+ else:
829
+ fps_val = batch.fps
830
+ fps_tensor = torch.tensor(
831
+ [fps_val],
832
+ device=latents.device,
833
+ dtype=target_dtype,
834
+ )
835
+
836
+ latents_4d = latents[0]
837
+
838
+ # Masks are optional for T2W.
839
+ cond_mask = getattr(batch, "cond_mask", None)
840
+ condition_mask = cond_mask.to(target_dtype) if isinstance(cond_mask, torch.Tensor) else None
841
+ pad_mask = getattr(batch, "padding_mask", None)
842
+ padding_mask = pad_mask.to(target_dtype) if isinstance(pad_mask, torch.Tensor) else None
843
+
844
+ # Conditioning fields are attached by latent preparation stage.
845
+ conditioning_latents = getattr(batch, "conditioning_latents", None)
846
+ cond_indicator = getattr(batch, "cond_indicator", None)
847
+ # Infer whether this is a conditioned run (V2W/I2W) purely from the presence
848
+ # of conditioning latents. Avoid carrying explicit mode flags on the batch.
849
+ is_conditioned = (conditioning_latents is not None)
850
+
851
+ init_noise_4d = latents_4d.clone()
852
+ if condition_mask is None:
853
+ _, t, h, w = latents_4d.shape
854
+ condition_mask = torch.zeros(1, 1, t, h, w, device=latents.device, dtype=target_dtype)
855
+ if padding_mask is None:
856
+ _, _, h, w = latents_4d.shape
857
+ padding_default = 0.0 if is_conditioned else 1.0
858
+ padding_mask = torch.full(
859
+ (1, 1, h, w),
860
+ float(padding_default),
861
+ device=latents.device,
862
+ dtype=target_dtype,
863
+ )
864
+
865
+ timestep_scale = 0.001
866
+
867
+ state_dtype = torch.float32
868
+
869
+ conditional_frame_timestep = 0.1
870
+ latents_4d = latents_4d.to(state_dtype)
871
+ init_noise_4d = init_noise_4d.to(state_dtype)
872
+
873
+ clamp_every_step = bool(getattr(cfg, "cosmos25_clamp_every_step", True)) if is_conditioned else False
874
+
875
+ with self.progress_bar(total=len(timesteps)) as progress_bar:
876
+ for i, t in enumerate(timesteps):
877
+ t_val = float(t)
878
+ if is_conditioned:
879
+ t_frames = int(latents_4d.shape[1])
880
+ timestep = torch.full(
881
+ (1, t_frames),
882
+ float(t_val * timestep_scale),
883
+ device=latents.device,
884
+ dtype=torch.float32,
885
+ )
886
+ if cond_indicator is not None and t_frames > 0:
887
+ cond_t = cond_indicator[0, 0, :t_frames, 0, 0]
888
+ cond_mask_t = (cond_t > 0.5)
889
+ if bool(cond_mask_t.any().item()):
890
+ timestep[0, cond_mask_t] = float(conditional_frame_timestep)
891
+ else:
892
+ timestep_val = t_val * timestep_scale
893
+ timestep = torch.tensor(
894
+ [[float(timestep_val)]],
895
+ device=latents.device,
896
+ dtype=target_dtype,
897
+ )
898
+
899
+ # Conditioned runs: replace x_t with GT x0 on the conditioned frames.
900
+ if (is_conditioned and cond_indicator is not None and conditioning_latents is not None
901
+ and (clamp_every_step or i == 0)):
902
+ cond_ind_4d = cond_indicator[0].to(state_dtype)
903
+ gt_x0 = conditioning_latents[0].to(state_dtype)
904
+ latents_4d = gt_x0 * cond_ind_4d + latents_4d * (1 - cond_ind_4d)
905
+
906
+ model_hidden_states = latents_4d.unsqueeze(0)
907
+
908
+ with (
909
+ set_forward_context(current_timestep=int(t_val), attn_metadata=None, forward_batch=batch),
910
+ torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled),
911
+ ):
912
+ cond_v = self.transformer(
913
+ hidden_states=model_hidden_states.to(target_dtype),
914
+ encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
915
+ timestep=timestep,
916
+ fps=fps_tensor,
917
+ condition_mask=condition_mask,
918
+ padding_mask=padding_mask,
919
+ return_dict=False,
920
+ )[0]
921
+
922
+ if batch.do_classifier_free_guidance and batch.negative_prompt_embeds:
923
+ uncond_v = self.transformer(
924
+ hidden_states=model_hidden_states.to(target_dtype),
925
+ encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
926
+ timestep=timestep,
927
+ fps=fps_tensor,
928
+ condition_mask=condition_mask,
929
+ padding_mask=padding_mask,
930
+ return_dict=False,
931
+ )[0]
932
+ if is_conditioned:
933
+ v = cond_v + guidance_scale * (cond_v - uncond_v)
934
+ else:
935
+ v = uncond_v + guidance_scale * (cond_v - uncond_v)
936
+ else:
937
+ v = cond_v
938
+
939
+ # Conditioned runs: replace velocity on conditioned frames with GT velocity.
940
+ if (is_conditioned and cond_indicator is not None and conditioning_latents is not None):
941
+ cond_ind_4d = cond_indicator[0].to(state_dtype)
942
+ gt_x0 = conditioning_latents[0].to(state_dtype)
943
+ gt_v = init_noise_4d.to(state_dtype) - gt_x0
944
+ v = cond_ind_4d * gt_v + (1 - cond_ind_4d) * v.to(state_dtype)
945
+
946
+ prev = self.scheduler.step(v.unsqueeze(0),
947
+ t,
948
+ latents_4d.unsqueeze(0),
949
+ **extra_step_kwargs,
950
+ return_dict=False)[0]
951
+ latents_4d = prev.squeeze(0)
952
+
953
+ progress_bar.update()
954
+
955
+ batch.latents = latents_4d.to(target_dtype).unsqueeze(0)
956
+ return batch
957
+
958
+
959
+ class Cosmos25T2WDenoisingStage(Cosmos25DenoisingStage):
960
+ """Cosmos 2.5 Text2World denoising stage."""
961
+
962
+ _CONDITIONING_FIELDS = (
963
+ "conditioning_latents",
964
+ "cond_indicator",
965
+ "uncond_indicator",
966
+ )
967
+
968
+ def forward(
969
+ self,
970
+ batch: ForwardBatch,
971
+ fastvideo_args: FastVideoArgs,
972
+ ) -> ForwardBatch:
973
+ for name in self._CONDITIONING_FIELDS:
974
+ if hasattr(batch, name):
975
+ setattr(batch, name, None)
976
+ return super().forward(batch, fastvideo_args)
977
+
978
+
979
+ class Cosmos25V2WDenoisingStage(Cosmos25DenoisingStage):
980
+ """Cosmos 2.5 Video2World denoising stage."""
981
+
982
+ def forward(
983
+ self,
984
+ batch: ForwardBatch,
985
+ fastvideo_args: FastVideoArgs,
986
+ ) -> ForwardBatch:
987
+ return super().forward(batch, fastvideo_args)
988
+
989
+
990
+ class Cosmos25AutoDenoisingStage(PipelineStage):
991
+ """Route Cosmos 2.5 denoising to T2W vs V2W/I2W."""
992
+
993
+ def __init__(self, transformer, scheduler) -> None:
994
+ super().__init__()
995
+ self._t2w = Cosmos25T2WDenoisingStage(transformer=transformer, scheduler=scheduler)
996
+ self._v2w = Cosmos25V2WDenoisingStage(transformer=transformer, scheduler=scheduler)
997
+
998
+ def pipeline(self):
999
+ return self._v2w.pipeline() if self._v2w.pipeline else None
1000
+
1001
+ def forward(
1002
+ self,
1003
+ batch: ForwardBatch,
1004
+ fastvideo_args: FastVideoArgs,
1005
+ ) -> ForwardBatch:
1006
+ conditioning_latents = getattr(batch, "conditioning_latents", None)
1007
+ if conditioning_latents is not None:
1008
+ return self._v2w.forward(batch, fastvideo_args)
1009
+ return self._t2w.forward(batch, fastvideo_args)
1010
+
1011
+ def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
1012
+ conditioning_latents = getattr(batch, "conditioning_latents", None)
1013
+ if conditioning_latents is not None:
1014
+ return self._v2w.verify_input(batch, fastvideo_args)
1015
+ return self._t2w.verify_input(batch, fastvideo_args)
1016
+
1017
+ def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
1018
+ conditioning_latents = getattr(batch, "conditioning_latents", None)
1019
+ if conditioning_latents is not None:
1020
+ return self._v2w.verify_output(batch, fastvideo_args)
1021
+ return self._t2w.verify_output(batch, fastvideo_args)
1022
+
1023
+
1024
+ class DmdDenoisingStage(DenoisingStage):
1025
+ """
1026
+ Denoising stage for DMD.
1027
+ """
1028
+
1029
+ def __init__(self, transformer, scheduler) -> None:
1030
+ super().__init__(transformer, scheduler)
1031
+ self.scheduler = FlowMatchEulerDiscreteScheduler(shift=8.0)
1032
+
1033
+ def forward(
1034
+ self,
1035
+ batch: ForwardBatch,
1036
+ fastvideo_args: FastVideoArgs,
1037
+ ) -> ForwardBatch:
1038
+ """
1039
+ Run the denoising loop.
1040
+
1041
+ Args:
1042
+ batch: The current batch information.
1043
+ fastvideo_args: The inference arguments.
1044
+
1045
+ Returns:
1046
+ The batch with denoised latents.
1047
+ """
1048
+ # Setup precision and autocast settings
1049
+ # TODO(will): make the precision configurable for inference
1050
+ # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
1051
+ target_dtype = torch.bfloat16
1052
+ autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast
1053
+
1054
+ # Get timesteps and calculate warmup steps
1055
+ timesteps = batch.timesteps
1056
+
1057
+ # TODO(will): remove this once we add input/output validation for stages
1058
+ if timesteps is None:
1059
+ raise ValueError("Timesteps must be provided")
1060
+ num_inference_steps = batch.num_inference_steps
1061
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1062
+
1063
+ # Prepare image latents and embeddings for I2V generation
1064
+ image_embeds = batch.image_embeds
1065
+ if len(image_embeds) > 0:
1066
+ assert torch.isnan(image_embeds[0]).sum() == 0
1067
+ image_embeds = [image_embed.to(target_dtype) for image_embed in image_embeds]
1068
+
1069
+ image_kwargs = self.prepare_extra_func_kwargs(
1070
+ self.transformer.forward,
1071
+ {
1072
+ "encoder_hidden_states_image": image_embeds,
1073
+ "mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24)
1074
+ },
1075
+ )
1076
+
1077
+ pos_cond_kwargs = self.prepare_extra_func_kwargs(
1078
+ self.transformer.forward,
1079
+ {
1080
+ "encoder_hidden_states_2": batch.clip_embedding_pos,
1081
+ "encoder_attention_mask": batch.prompt_attention_mask,
1082
+ },
1083
+ )
1084
+
1085
+ # Get latents and embeddings
1086
+ assert batch.latents is not None, "latents must be provided"
1087
+ latents = batch.latents
1088
+
1089
+ video_raw_latent_shape = latents.shape
1090
+ prompt_embeds = batch.prompt_embeds
1091
+ assert not torch.isnan(prompt_embeds[0]).any(), "prompt_embeds contains nan"
1092
+ timesteps = torch.tensor(fastvideo_args.pipeline_config.dmd_denoising_steps,
1093
+ dtype=torch.long,
1094
+ device=get_local_torch_device())
1095
+
1096
+ # Run denoising loop
1097
+ with self.progress_bar(total=len(timesteps)) as progress_bar:
1098
+ for i, t in enumerate(timesteps):
1099
+ # Skip if interrupted
1100
+ if hasattr(self, 'interrupt') and self.interrupt:
1101
+ continue
1102
+ # Expand latents for I2V
1103
+ noise_latents = latents.clone()
1104
+ latent_model_input = latents.to(target_dtype)
1105
+
1106
+ if batch.image_latent is not None:
1107
+ latent_model_input = torch.cat(
1108
+ [latent_model_input, batch.image_latent.permute(0, 2, 1, 3, 4)], dim=2).to(target_dtype)
1109
+ assert not torch.isnan(latent_model_input).any(), "latent_model_input contains nan"
1110
+
1111
+ # Prepare inputs for transformer
1112
+ t_expand = t.repeat(latent_model_input.shape[0])
1113
+ guidance_expand = (torch.tensor(
1114
+ [fastvideo_args.pipeline_config.embedded_cfg_scale] * latent_model_input.shape[0],
1115
+ dtype=torch.float32,
1116
+ device=get_local_torch_device(),
1117
+ ).to(target_dtype) * 1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale is not None else None)
1118
+
1119
+ # Predict noise residual
1120
+ with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
1121
+ if (vsa_available and self.attn_backend == VideoSparseAttentionBackend) or \
1122
+ (sparse_fp4_available and self.attn_backend in sparse_fp4_backends):
1123
+ self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()
1124
+
1125
+ if self.attn_metadata_builder_cls is not None:
1126
+ self.attn_metadata_builder = self.attn_metadata_builder_cls()
1127
+ # TODO(will): clean this up
1128
+ attn_metadata = self.attn_metadata_builder.build( # type: ignore
1129
+ current_timestep=i, # type: ignore
1130
+ raw_latent_shape=batch.raw_latent_shape[2:5], # type: ignore
1131
+ patch_size=fastvideo_args.pipeline_config. # type: ignore
1132
+ dit_config.patch_size, # type: ignore
1133
+ VSA_sparsity=fastvideo_args.VSA_sparsity, # type: ignore
1134
+ device=get_local_torch_device(), # type: ignore
1135
+ ) # type: ignore
1136
+ assert attn_metadata is not None, "attn_metadata cannot be None"
1137
+ else:
1138
+ attn_metadata = None
1139
+ else:
1140
+ attn_metadata = None
1141
+
1142
+ batch.is_cfg_negative = False
1143
+ with set_forward_context(
1144
+ current_timestep=i,
1145
+ attn_metadata=attn_metadata,
1146
+ forward_batch=batch,
1147
+ # fastvideo_args=fastvideo_args
1148
+ ):
1149
+ # Run transformer
1150
+ pred_noise = self.transformer(
1151
+ latent_model_input.permute(0, 2, 1, 3, 4),
1152
+ prompt_embeds,
1153
+ t_expand,
1154
+ guidance=guidance_expand,
1155
+ **image_kwargs,
1156
+ **pos_cond_kwargs,
1157
+ ).permute(0, 2, 1, 3, 4)
1158
+
1159
+ pred_video = pred_noise_to_pred_video(pred_noise=pred_noise.flatten(0, 1),
1160
+ noise_input_latent=noise_latents.flatten(0, 1),
1161
+ timestep=t_expand,
1162
+ scheduler=self.scheduler).unflatten(0, pred_noise.shape[:2])
1163
+
1164
+ if i < len(timesteps) - 1:
1165
+ next_timestep = timesteps[i + 1] * torch.ones([1], dtype=torch.long, device=pred_video.device)
1166
+ noise_generator = batch.generator[0] if isinstance(batch.generator, list) else batch.generator
1167
+ noise = torch.randn(video_raw_latent_shape, dtype=pred_video.dtype,
1168
+ generator=noise_generator).to(self.device)
1169
+ latents = self.scheduler.add_noise(pred_video.flatten(0, 1), noise.flatten(0, 1),
1170
+ next_timestep).unflatten(0, pred_video.shape[:2])
1171
+ else:
1172
+ latents = pred_video
1173
+
1174
+ # Update progress bar
1175
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
1176
+ (i + 1) % self.scheduler.order == 0 and progress_bar is not None):
1177
+ progress_bar.update()
1178
+
1179
+ # Gather results if using sequence parallelism
1180
+ latents = latents.permute(0, 2, 1, 3, 4)
1181
+ # Update batch with final latents
1182
+ batch.latents = latents
1183
+
1184
+ return batch
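
Note: the Cosmos denoising loops above use EDM-style preconditioning. With `current_t = sigma / (sigma + 1)`, the model input is scaled by `c_in = 1 - current_t`, the x0 estimate is `c_skip * latents + c_out * noise_pred` with `c_skip = 1 - current_t` and `c_out = -current_t`, and the scheduler consumes `(latents - x0_pred) / sigma`. A minimal numeric sketch of that arithmetic (illustration only, not a file in this snapshot):

```python
import torch

sigma = torch.tensor(80.0)           # sigma_max used by CosmosDenoisingStage
t = sigma / (sigma + 1)              # `current_t` in the loop above
c_in, c_skip, c_out = 1 - t, 1 - t, -t

x_t = torch.randn(4)                 # stand-in for `latents`
raw = torch.randn(4)                 # stand-in for the transformer output `noise_pred`
x0_pred = c_skip * x_t + c_out * raw          # `cond_pred`
noise = (x_t - x0_pred) / sigma               # `noise_for_scheduler`
assert noise.shape == x_t.shape
```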
backend_snapshot/fastvideo/platforms/cuda.py ADDED
@@ -0,0 +1,440 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/cuda.py
+ """Code inside this file can safely assume cuda platform, e.g. importing
+ pynvml. However, it should not initialize cuda context.
+ """
+
+ import os
+ from collections.abc import Callable
+ from functools import lru_cache, wraps
+ from typing import TypeVar
+
+ import torch
+ from typing_extensions import ParamSpec
+
+ import fastvideo.envs as envs
+ from fastvideo.logger import init_logger
+ from fastvideo.platforms.interface import (AttentionBackendEnum, DeviceCapability, Platform, PlatformEnum)
+ from fastvideo.utils import import_pynvml
+
+ logger = init_logger(__name__)
+
+ _P = ParamSpec("_P")
+ _R = TypeVar("_R")
+
+ pynvml = import_pynvml()  # type: ignore[no-untyped-call]
+
+ # pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
+ # see https://github.com/huggingface/diffusers/issues/9704 for details
+ torch.backends.cuda.enable_cudnn_sdp(False)
+
+
+ def device_id_to_physical_device_id(device_id: int) -> int:
+     if "CUDA_VISIBLE_DEVICES" in os.environ:
+         device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
+         if device_ids == [""]:
+             msg = ("CUDA_VISIBLE_DEVICES is set to empty string, which means"
+                    " GPU support is disabled. If you are using ray, please unset"
+                    " the environment variable `CUDA_VISIBLE_DEVICES` inside the"
+                    " worker/actor. "
+                    "Check https://github.com/vllm-project/vllm/issues/8402 for"
+                    " more information.")
+             raise RuntimeError(msg)
+         physical_device_id = device_ids[device_id]
+         return int(physical_device_id)
+     else:
+         return device_id
+
+
+ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
+
+     @wraps(fn)
+     def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
+         pynvml.nvmlInit()
+         try:
+             return fn(*args, **kwargs)
+         finally:
+             pynvml.nvmlShutdown()
+
+     return wrapper
+
+
+ class CudaPlatformBase(Platform):
+     _enum = PlatformEnum.CUDA
+     device_name: str = "cuda"
+     device_type: str = "cuda"
+     dispatch_key: str = "CUDA"
+     ray_device_key: str = "GPU"
+     device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
+
+     @classmethod
+     def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
+         raise NotImplementedError
+
+     @classmethod
+     def get_device_name(cls, device_id: int = 0) -> str:
+         raise NotImplementedError
+
+     @classmethod
+     def get_device_total_memory(cls, device_id: int = 0) -> int:
+         raise NotImplementedError
+
+     @classmethod
+     def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
+         if enforce_eager:
+             logger.warning("To see benefits of async output processing, enable CUDA "
+                            "graph. Since enforce-eager is enabled, the async output "
+                            "processor cannot be used.")
+             return False
+         return True
+
+     @classmethod
+     def is_full_nvlink(cls, device_ids: list[int]) -> bool:
+         raise NotImplementedError
+
+     @classmethod
+     def log_warnings(cls) -> None:
+         pass
+
+     @classmethod
+     def get_current_memory_usage(cls, device: torch.types.Device | None = None) -> float:
+         torch.cuda.reset_peak_memory_stats(device)
+         return float(torch.cuda.max_memory_allocated(device))
+
+     @classmethod
+     def get_torch_device(cls) -> object:
+         """
+         Return torch.cuda
+         """
+         return torch.cuda
+
+     @classmethod
+     def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None, head_size: int,
+                              dtype: torch.dtype) -> str:
+         # TODO(will): maybe come up with a more general interface for local attention
+         # if distributed is False, we always try to use Flash attn
+
+         logger.info("Trying FASTVIDEO_ATTENTION_BACKEND=%s", envs.FASTVIDEO_ATTENTION_BACKEND)
+         logger.info("Selected backend: %s", selected_backend)
+         if selected_backend == AttentionBackendEnum.SAGE_ATTN:
+             try:
+                 from sageattention import sageattn  # noqa: F401
+
+                 from fastvideo.attention.backends.sage_attn import (  # noqa: F401
+                     SageAttentionBackend)
+                 logger.info("Using Sage Attention backend.")
+
+                 return "fastvideo.attention.backends.sage_attn.SageAttentionBackend"
+             except ImportError as e:
+                 logger.info(e)
+                 logger.info("Sage Attention backend is not installed. Falling back to Flash Attention.")
+         elif selected_backend == AttentionBackendEnum.SAGE_ATTN_THREE:
+             try:
+                 from sageattn3 import sageattn3_blackwell  # noqa: F401
+
+                 from fastvideo.attention.backends.sage_attn3 import (  # noqa: F401
+                     SageAttention3Backend)
+                 logger.info("Using Sage Attention 3 backend.")
+
+                 return "fastvideo.attention.backends.sage_attn3.SageAttention3Backend"
+             except ImportError as e:
+                 logger.info(e)
+                 logger.info("Sage Attention 3 backend is not installed. Falling back to Flash Attention.")
+         elif selected_backend == AttentionBackendEnum.ATTN_QAT_INFER:
+             try:
+                 from fastvideo.attention.backends.attn_qat_infer import (  # noqa: F401
+                     AttnQatInferBackend, is_attn_qat_infer_available,
+                 )
+                 if not is_attn_qat_infer_available():
+                     raise ImportError("attn_qat_infer could not be imported.")
+                 logger.info("Using attn_qat_infer backend.")
+
+                 return "fastvideo.attention.backends.attn_qat_infer.AttnQatInferBackend"
+             except ImportError as e:
+                 logger.info(e)
+                 logger.info("attn_qat_infer backend is not installed. Falling back to Flash Attention.")
+         elif selected_backend == AttentionBackendEnum.ATTN_QAT_TRAIN:
+             try:
+                 from fastvideo_kernel.triton_kernels.attn_qat_train import attention  # noqa: F401
+
+                 from fastvideo.attention.backends.attn_qat_train import (  # noqa: F401
+                     AttnQatTrainBackend)
+                 logger.info("Using attn_qat_train backend.")
+
+                 return "fastvideo.attention.backends.attn_qat_train.AttnQatTrainBackend"
+             except ImportError as e:
+                 logger.info(e)
+                 logger.info("attn_qat_train backend is not installed. Falling back to Flash Attention.")
+         elif selected_backend == AttentionBackendEnum.VIDEO_SPARSE_ATTN:
+             try:
+                 from fastvideo_kernel import video_sparse_attn  # noqa: F401
+
+                 from fastvideo.attention.backends.video_sparse_attn import (  # noqa: F401
+                     VideoSparseAttentionBackend)
+                 logger.info("Using Video Sparse Attention backend.")
+
+                 return "fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionBackend"
+             except ImportError as e:
+                 logger.error("Failed to import Video Sparse Attention backend: %s", str(e))
+                 raise ImportError("The Video Sparse Attention backend is not installed. "
+                                   "To install it, please follow the instructions at: "
+                                   "https://hao-ai-lab.github.io/FastVideo/video_sparse_attention/installation") from e
+         elif selected_backend == AttentionBackendEnum.SPARSE_FP4_ATTN:
+             try:
+                 from fastvideo.attention.backends.sparse_fp4_attn import (  # noqa: F401
+                     SparseFP4AttentionBackend)
+                 logger.info("Using Sparse FP4 Attention backend (FP4 quant + VSA).")
+                 return "fastvideo.attention.backends.sparse_fp4_attn.SparseFP4AttentionBackend"
+             except ImportError as e:
+                 logger.error("Failed to import Sparse FP4 Attention backend: %s", str(e))
+                 raise ImportError("Sparse FP4 Attention backend is not available.") from e
+         elif selected_backend == AttentionBackendEnum.SPARSE_FP4_OURS_P_ATTN:
+             try:
+                 from fastvideo.attention.backends.sparse_fp4_ours_p_attn import (  # noqa: F401
+                     SparseFP4OursPAttentionBackend)
+                 logger.info(
+                     "Using Sparse FP4 Ours-P Attention backend (group-local P quant + VSA)."
+                 )
+                 return "fastvideo.attention.backends.sparse_fp4_ours_p_attn.SparseFP4OursPAttentionBackend"
+             except ImportError as e:
+                 logger.error("Failed to import Sparse FP4 Ours-P Attention backend: %s", str(e))
+                 raise ImportError("Sparse FP4 Ours-P Attention backend is not available.") from e
+         elif selected_backend == AttentionBackendEnum.BSA_ATTN:
+             try:
+                 from fastvideo.attention.backends.bsa_attn import (  # noqa: F401
+                     BSAAttentionBackend)
+                 logger.info("Using BSA Attention backend.")
+
+                 return "fastvideo.attention.backends.bsa_attn.BSAAttentionBackend"
+             except ImportError as e:
+                 logger.error("Failed to import BSA Attention backend: %s", str(e))
+                 raise ImportError("The BSA Attention backend failed to import.") from e
+         elif selected_backend == AttentionBackendEnum.VMOBA_ATTN:
+             try:
+                 from fastvideo_kernel import moba_attn_varlen  # noqa: F401
+                 from fastvideo.attention.backends.vmoba import (  # noqa: F401
+                     VMOBAAttentionBackend)
+                 logger.info("Using Video MoBA Attention backend.")
+
+                 return "fastvideo.attention.backends.vmoba.VMOBAAttentionBackend"
+             except ImportError as e:
+                 logger.error("Failed to import Video MoBA Attention backend: %s", str(e))
+                 raise ImportError("Video MoBA Attention backend is not installed.") from e
+         elif selected_backend == AttentionBackendEnum.SLA_ATTN:
+             try:
+                 from fastvideo.attention.backends.sla import (  # noqa: F401
+                     SLAAttentionBackend)
+                 logger.info("Using SLA (Sparse-Linear Attention) backend.")
+
+                 return "fastvideo.attention.backends.sla.SLAAttentionBackend"
+             except ImportError as e:
+                 logger.error("Failed to import SLA Attention backend: %s", str(e))
+                 raise ImportError("SLA Attention backend is not available.") from e
+         elif selected_backend == AttentionBackendEnum.SAGE_SLA_ATTN:
+             try:
+                 from fastvideo.attention.backends.sla import (  # noqa: F401
+                     SageSLAAttentionBackend)
+                 logger.info("Using SageSLA (Quantized Sparse-Linear Attention) backend.")
+
+                 return "fastvideo.attention.backends.sla.SageSLAAttentionBackend"
+             except ImportError as e:
+                 logger.error("Failed to import SageSLA Attention backend: %s", str(e))
+                 raise ImportError("SageSLA Attention backend requires spas_sage_attn. "
+                                   "Install with: pip install git+https://github.com/thu-ml/SpargeAttn.git") from e
+         elif selected_backend == AttentionBackendEnum.TORCH_SDPA:
+             logger.info("Using Torch SDPA backend.")
+             return "fastvideo.attention.backends.sdpa.SDPABackend"
+         elif selected_backend == AttentionBackendEnum.FLASH_ATTN or selected_backend is None:
+             pass
+         elif selected_backend:
+             raise ValueError(f"Invalid attention backend for {cls.device_name}")
+
+         target_backend = AttentionBackendEnum.FLASH_ATTN
+         if not cls.has_device_capability(80):
+             logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
+                         "GPUs.")
+             target_backend = AttentionBackendEnum.TORCH_SDPA
+         elif dtype not in (torch.float16, torch.bfloat16):
+             logger.info("Cannot use FlashAttention-2 backend for dtype other than "
+                         "torch.float16 or torch.bfloat16.")
+             target_backend = AttentionBackendEnum.TORCH_SDPA
+
+         # FlashAttn is valid for the model; check whether the package is
+         # installed.
+         if target_backend == AttentionBackendEnum.FLASH_ATTN:
+             try:
+                 import flash_attn  # noqa: F401
+
+                 from fastvideo.attention.backends.flash_attn import (  # noqa: F401
+                     FlashAttentionBackend)
+
+                 supported_sizes = \
+                     FlashAttentionBackend.get_supported_head_sizes()
+                 if head_size not in supported_sizes:
+                     logger.info("Cannot use FlashAttention-2 backend for head size %d.", head_size)
+                     target_backend = AttentionBackendEnum.TORCH_SDPA
+             except ImportError:
+                 logger.info("Cannot use FlashAttention-2 backend because the "
+                             "flash_attn package is not found. "
+                             "Make sure that flash_attn was built and installed "
+                             "(on by default).")
+                 target_backend = AttentionBackendEnum.TORCH_SDPA
+
+         if target_backend == AttentionBackendEnum.TORCH_SDPA:
+             logger.info("Using Torch SDPA backend.")
+
+             return "fastvideo.attention.backends.sdpa.SDPABackend"
+
+         logger.info("Using Flash Attention backend.")
+
+         return "fastvideo.attention.backends.flash_attn.FlashAttentionBackend"
+
+     @classmethod
+     def get_device_communicator_cls(cls) -> str:
+         return "fastvideo.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
+
+
+ # NVML utils
+ # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
+ # all the related functions work on real physical device ids.
+ # the major benefit of using NVML is that it will not initialize CUDA
+ class NvmlCudaPlatform(CudaPlatformBase):
+
+     @classmethod
+     @lru_cache(maxsize=8)
+     @with_nvml_context
+     def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
+         try:
+             physical_device_id = device_id_to_physical_device_id(device_id)
+             handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
+             major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
+             return DeviceCapability(major=major, minor=minor)
+         except RuntimeError:
+             return None
+
+     @classmethod
+     @lru_cache(maxsize=8)
+     @with_nvml_context
+     def has_device_capability(
+         cls,
+         capability: tuple[int, int] | int,
+         device_id: int = 0,
+     ) -> bool:
+         try:
+             return bool(super().has_device_capability(capability, device_id))
+         except RuntimeError:
+             return False
+
+     @classmethod
+     @lru_cache(maxsize=8)
+     @with_nvml_context
+     def get_device_name(cls, device_id: int = 0) -> str:
+         physical_device_id = device_id_to_physical_device_id(device_id)
+         return cls._get_physical_device_name(physical_device_id)
+
+     @classmethod
+     @lru_cache(maxsize=8)
+     @with_nvml_context
+     def get_device_uuid(cls, device_id: int = 0) -> str:
+         physical_device_id = device_id_to_physical_device_id(device_id)
+         handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
+         return str(pynvml.nvmlDeviceGetUUID(handle))
+
+     @classmethod
+     @lru_cache(maxsize=8)
+     @with_nvml_context
+     def get_device_total_memory(cls, device_id: int = 0) -> int:
+         physical_device_id = device_id_to_physical_device_id(device_id)
+         handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
+         return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
+
+     @classmethod
+     @with_nvml_context
+     def is_full_nvlink(cls, physical_device_ids: list[int]) -> bool:
+         """
+         Query whether the set of GPUs is fully connected by NVLink (1 hop).
+         """
+         handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
+         for i, handle in enumerate(handles):
+             for j, peer_handle in enumerate(handles):
+                 if i < j:
+                     try:
+                         p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                             handle,
+                             peer_handle,
+                             pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
+                         )
+                         if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                             return False
+                     except pynvml.NVMLError:
+                         logger.exception("NVLink detection failed. This is normal if"
+                                          " your machine has no NVLink equipped.")
+                         return False
+         return True
+
+     @classmethod
+     def _get_physical_device_name(cls, device_id: int = 0) -> str:
+         handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+         return str(pynvml.nvmlDeviceGetName(handle))
+
+     @classmethod
+     @with_nvml_context
+     def log_warnings(cls) -> None:
+         device_ids: int = pynvml.nvmlDeviceGetCount()
+         if device_ids > 1:
+             device_names = [cls._get_physical_device_name(i) for i in range(device_ids)]
+             if (len(set(device_names)) > 1 and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
+                 logger.warning(
+                     "Detected different devices in the system: %s. Please"
+                     " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
+                     "avoid unexpected behavior.",
+                     ", ".join(device_names),
+                 )
+
+
+ class NonNvmlCudaPlatform(CudaPlatformBase):
+
+     @classmethod
+     def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
+         major, minor = torch.cuda.get_device_capability(device_id)
+         return DeviceCapability(major=major, minor=minor)
+
+     @classmethod
+     def get_device_name(cls, device_id: int = 0) -> str:
+         return str(torch.cuda.get_device_name(device_id))
+
+     @classmethod
+     def get_device_total_memory(cls, device_id: int = 0) -> int:
+         device_props = torch.cuda.get_device_properties(device_id)
+         return int(device_props.total_memory)
+
+     @classmethod
+     def is_full_nvlink(cls, physical_device_ids: list[int]) -> bool:
+         logger.exception("NVLink detection not possible, as context support was"
+                          " not found. Assuming no NVLink available.")
+         return False
+
+
+ # Autodetect either NVML-enabled or non-NVML platform
+ # based on whether NVML is available.
+ nvml_available = False
+ try:
+     try:
+         pynvml.nvmlInit()
+         nvml_available = True
+     except Exception:
+         # On Jetson, NVML is not supported.
+         nvml_available = False
+ finally:
+     if nvml_available:
+         pynvml.nvmlShutdown()
+
+ CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform
+
+ try:
+     from sphinx.ext.autodoc.mock import _MockModule
+
+     if not isinstance(pynvml, _MockModule):
+         CudaPlatform.log_warnings()
+ except ModuleNotFoundError:
+     CudaPlatform.log_warnings()
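
Note: for this checkpoint the dispatch above is driven by `FASTVIDEO_ATTENTION_BACKEND=SPARSE_FP4_OURS_P_ATTN`, so `get_attn_backend_cls` returns the dotted path of `SparseFP4OursPAttentionBackend`. A sketch of exercising the resolver directly on a CUDA machine (the `head_size` and `dtype` values here are illustrative, not taken from the run config):

```python
import importlib

import torch

from fastvideo.platforms.cuda import CudaPlatform
from fastvideo.platforms.interface import AttentionBackendEnum

# Resolve the dotted class path exactly as the selection logic above does.
path = CudaPlatform.get_attn_backend_cls(
    AttentionBackendEnum.SPARSE_FP4_OURS_P_ATTN,
    head_size=128,          # illustrative value
    dtype=torch.bfloat16,
)
module_name, cls_name = path.rsplit(".", 1)
backend_cls = getattr(importlib.import_module(module_name), cls_name)
print(backend_cls.__name__)  # SparseFP4OursPAttentionBackend
```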
backend_snapshot/fastvideo/platforms/interface.py ADDED
@@ -0,0 +1,255 @@
+ import enum
+ import random
+ from typing import Any, NamedTuple
+
+ import numpy as np
+ import torch
+
+ from fastvideo.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class AttentionBackendEnum(enum.Enum):
+     FLASH_ATTN = enum.auto()
+     TORCH_SDPA = enum.auto()
+     SAGE_ATTN = enum.auto()
+     SAGE_ATTN_THREE = enum.auto()
+     ATTN_QAT_INFER = enum.auto()
+     ATTN_QAT_TRAIN = enum.auto()
+     VIDEO_SPARSE_ATTN = enum.auto()
+     BSA_ATTN = enum.auto()
+     VMOBA_ATTN = enum.auto()
+     SLA_ATTN = enum.auto()
+     SAGE_SLA_ATTN = enum.auto()
+     SPARSE_FP4_ATTN = enum.auto()
+     SPARSE_FP4_OURS_P_ATTN = enum.auto()
+     NO_ATTENTION = enum.auto()
+
+
+ class PlatformEnum(enum.Enum):
+     CUDA = enum.auto()
+     ROCM = enum.auto()
+     TPU = enum.auto()
+     XPU = enum.auto()
+     CPU = enum.auto()
+     MPS = enum.auto()
+     OOT = enum.auto()
+     UNSPECIFIED = enum.auto()
+     NPU = enum.auto()
+
+
+ class CpuArchEnum(enum.Enum):
+     X86 = enum.auto()
+     ARM = enum.auto()
+     UNSPECIFIED = enum.auto()
+
+
+ class DeviceCapability(NamedTuple):
+     major: int
+     minor: int
+
+     def as_version_str(self) -> str:
+         return f"{self.major}.{self.minor}"
+
+     def to_int(self) -> int:
+         """
+         Express device capability as an integer ``<major><minor>``.
+
+         It is assumed that the minor version is always a single digit.
+         """
+         assert 0 <= self.minor < 10
+         return self.major * 10 + self.minor
+
+
+ class Platform:
+     _enum: PlatformEnum
+     device_name: str
+     device_type: str
+
+     dispatch_key: str = "CPU"
+
+     # platform-agnostic way to specify the device control environment variable,
+     # e.g. CUDA_VISIBLE_DEVICES for CUDA.
+     # hint: search for "get_visible_accelerator_ids_env_var" in
+     # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators  # noqa
+     device_control_env_var: str = "FASTVIDEO_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"
+
+     # available ray device keys:
+     # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438  # noqa
+     # empty string means the device does not support ray
+     ray_device_key: str = ""
+     # The torch.compile backend for compiling simple and
+     # standalone functions. The default value is "inductor" to keep
+     # the same behavior as PyTorch.
+     # NOTE: for the forward part of the model, vLLM has another separate
+     # compilation strategy.
+     simple_compile_backend: str = "inductor"
+
+     supported_quantization: list[str] = []
+
+     additional_env_vars: list[str] = []
+
+     def is_cuda(self) -> bool:
+         return self._enum == PlatformEnum.CUDA
+
+     def is_rocm(self) -> bool:
+         return self._enum == PlatformEnum.ROCM
+
+     def is_tpu(self) -> bool:
+         return self._enum == PlatformEnum.TPU
+
+     def is_xpu(self) -> bool:
+         return self._enum == PlatformEnum.XPU
+
+     def is_cpu(self) -> bool:
+         return self._enum == PlatformEnum.CPU
+
+     def is_out_of_tree(self) -> bool:
+         return self._enum == PlatformEnum.OOT
+
+     def is_cuda_alike(self) -> bool:
+         """Stateless version of :func:`torch.cuda.is_available`."""
+         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
+
+     def is_mps(self) -> bool:
+         return self._enum == PlatformEnum.MPS
+
+     def is_npu(self) -> bool:
+         return self._enum == PlatformEnum.NPU
+
+     @classmethod
+     def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None, head_size: int,
+                              dtype: torch.dtype) -> str:
+         """Get the attention backend class of a device."""
+         return ""
+
+     @classmethod
+     def get_device_capability(
+         cls,
+         device_id: int = 0,
+     ) -> DeviceCapability | None:
+         """Stateless version of :func:`torch.cuda.get_device_capability`."""
+         return None
+
+     @classmethod
+     def has_device_capability(
+         cls,
+         capability: tuple[int, int] | int,
+         device_id: int = 0,
+     ) -> bool:
+         """
+         Test whether this platform is compatible with a device capability.
+
+         The ``capability`` argument can either be:
+
+         - A tuple ``(major, minor)``.
+         - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
+         """
+         current_capability = cls.get_device_capability(device_id=device_id)
+         if current_capability is None:
+             return False
+
+         if isinstance(capability, tuple):
+             return current_capability >= capability
+
+         return current_capability.to_int() >= capability
+
+     @classmethod
+     def get_device_name(cls, device_id: int = 0) -> str:
+         """Get the name of a device."""
+         raise NotImplementedError
+
+     @classmethod
+     def get_device_uuid(cls, device_id: int = 0) -> str:
+         """Get the uuid of a device, e.g. the PCI bus ID."""
+         raise NotImplementedError
+
+     @classmethod
+     def get_device_total_memory(cls, device_id: int = 0) -> int:
+         """Get the total memory of a device in bytes."""
+         raise NotImplementedError
+
+     @classmethod
+     def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
+         """
+         Check if the current platform supports async output.
+         """
+         raise NotImplementedError
+
+     @classmethod
+     def get_torch_device(cls) -> Any:
+         """
+         Return the torch device module for the current platform.
+         """
+         raise NotImplementedError
+
+     @classmethod
+     def inference_mode(cls):
+         """A device-specific wrapper of `torch.inference_mode`.
+
+         This wrapper is recommended because some hardware backends such as TPU
+         do not support `torch.inference_mode`. In such a case, they will fall
+         back to `torch.no_grad` by overriding this method.
+         """
+         return torch.inference_mode(mode=True)
+
+     @classmethod
+     def seed_everything(cls, seed: int | None = None) -> None:
+         """
+         Set the seed of each random module.
+         `torch.manual_seed` will set the seed on all devices.
+
+         Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
+         """
+         if seed is not None:
+             random.seed(seed)
+             np.random.seed(seed)
+             torch.manual_seed(seed)
+             torch.cuda.manual_seed_all(seed)
+
+     @classmethod
+     def verify_model_arch(cls, model_arch: str) -> None:
+         """
+         Verify whether the current platform supports the specified model
+         architecture.
+
+         - This will raise an Error or Warning based on the model support on
+           the current platform.
+         - By default all models are considered supported.
+         """
+         pass
+
+     @classmethod
+     def verify_quantization(cls, quant: str) -> None:
+         """
+         Verify whether the quantization is supported by the current platform.
+         """
+         if cls.supported_quantization and \
+                 quant not in cls.supported_quantization:
+             raise ValueError(f"{quant} quantization is currently not supported in "
+                              f"{cls.device_name}.")
+
+     @classmethod
+     def get_current_memory_usage(cls, device: torch.types.Device | None = None) -> float:
+         """
+         Return the memory usage in bytes.
+         """
+         raise NotImplementedError
+
+     @classmethod
+     def get_device_communicator_cls(cls) -> str:
+         """
+         Get device specific communicator class for distributed communication.
+         """
+         return "fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa
+
+     @classmethod
+     def get_cpu_architecture(cls) -> CpuArchEnum:
+         """Get the CPU architecture of the current platform."""
+         return CpuArchEnum.UNSPECIFIED
+
+
+ class UnspecifiedPlatform(Platform):
+     _enum = PlatformEnum.UNSPECIFIED
+     device_type = ""
backend_snapshot/fastvideo/train/models/wan/wan.py ADDED
@@ -0,0 +1,680 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """Wan model plugin (per-role instance)."""
+
+ from __future__ import annotations
+
+ import copy
+ import gc
+ from typing import Any, Literal, TYPE_CHECKING
+
+ import torch
+
+ import fastvideo.envs as envs
+ from fastvideo.configs.sample import SamplingParam
+ from fastvideo.distributed import (
+     get_sp_group,
+     get_world_group,
+ )
+ from fastvideo.forward_context import set_forward_context
+ from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import (
+     FlowMatchEulerDiscreteScheduler, )
+ from fastvideo.pipelines import TrainingBatch
+ from fastvideo.pipelines.basic.wan.wan_pipeline import (
+     WanPipeline, )
+ from fastvideo.pipelines.pipeline_batch_info import (
+     ForwardBatch, )
+ from fastvideo.training.activation_checkpoint import (
+     apply_activation_checkpointing, )
+ from fastvideo.training.training_utils import (
+     compute_density_for_timestep_sampling,
+     get_sigmas,
+     normalize_dit_input,
+     shift_timestep,
+ )
+ from fastvideo.utils import (
+     is_vmoba_available,
+     is_vsa_available,
+ )
+
+ from fastvideo.train.models.base import ModelBase
+ from fastvideo.train.utils.module_state import (
+     apply_trainable, )
+ from fastvideo.train.utils.moduleloader import (
+     load_module_from_path, )
+
+ if TYPE_CHECKING:
+     from fastvideo.train.utils.training_config import (
+         TrainingConfig, )
+
+ VideoSparseAttentionMetadataBuilder: type[Any] | None
+ VideoMobaAttentionMetadataBuilder: type[Any] | None
+
+ try:
+     from fastvideo.attention.backends.video_sparse_attn import (
+         VideoSparseAttentionMetadataBuilder as _VideoSparseAttentionMetadataBuilder, )
+     from fastvideo.attention.backends.vmoba import (
+         VideoMobaAttentionMetadataBuilder as _VideoMobaAttentionMetadataBuilder, )
+     VideoSparseAttentionMetadataBuilder = _VideoSparseAttentionMetadataBuilder
+     VideoMobaAttentionMetadataBuilder = _VideoMobaAttentionMetadataBuilder
+ except Exception:
+     VideoSparseAttentionMetadataBuilder = None
+     VideoMobaAttentionMetadataBuilder = None
+
+
+ class WanModel(ModelBase):
+     """Wan per-role model: owns transformer + noise_scheduler."""
+
+     _transformer_cls_name: str = "WanTransformer3DModel"
+
+     def __init__(
+         self,
+         *,
+         init_from: str,
+         training_config: TrainingConfig,
+         trainable: bool = True,
+         disable_custom_init_weights: bool = False,
+         flow_shift: float = 3.0,
+         enable_gradient_checkpointing_type: str
+         | None = None,
+         transformer_override_safetensor: str
+         | None = None,
+     ) -> None:
+         self._init_from = str(init_from)
+         self._trainable = bool(trainable)
+
+         self.transformer = self._load_transformer(
+             init_from=self._init_from,
+             trainable=self._trainable,
+             disable_custom_init_weights=(disable_custom_init_weights),
+             enable_gradient_checkpointing_type=(enable_gradient_checkpointing_type),
+             training_config=training_config,
+             transformer_override_safetensor=(transformer_override_safetensor),
+         )
+
+         self.noise_scheduler = (FlowMatchEulerDiscreteScheduler(shift=float(flow_shift)))
+
+         # Filled by init_preprocessors (student only).
+         self.vae: Any = None
+         self.training_config: TrainingConfig = training_config
+         self.dataloader: Any = None
+         self.validator: Any = None
+         self.start_step: int = 0
+
+         self.world_group: Any = None
+         self.sp_group: Any = None
+
+         self.negative_prompt_embeds: (torch.Tensor | None) = None
+         self.negative_prompt_attention_mask: (torch.Tensor | None) = None
+
+         # Timestep mechanics.
+         self.timestep_shift: float = float(flow_shift)
+         self.num_train_timestep: int = int(self.noise_scheduler.num_train_timesteps)
+         self.min_timestep: int = 0
+         self.max_timestep: int = self.num_train_timestep
+
+     def _load_transformer(
+         self,
+         *,
+         init_from: str,
+         trainable: bool,
+         disable_custom_init_weights: bool,
+         enable_gradient_checkpointing_type: str | None,
+         training_config: TrainingConfig,
+         transformer_override_safetensor: str | None = None,
+     ) -> torch.nn.Module:
+         transformer = load_module_from_path(
+             model_path=init_from,
+             module_type="transformer",
+             training_config=training_config,
+             disable_custom_init_weights=(disable_custom_init_weights),
+             override_transformer_cls_name=(self._transformer_cls_name),
+             transformer_override_safetensor=(transformer_override_safetensor),
+         )
+         transformer = apply_trainable(transformer, trainable=trainable)
+         # Fall back to training_config.model if not set on the
+         # model YAML section directly.
+         ckpt_type = (enable_gradient_checkpointing_type or getattr(
+             getattr(training_config, "model", None),
+             "enable_gradient_checkpointing_type",
+             None,
+         ))
+         if trainable and ckpt_type:
+             transformer = apply_activation_checkpointing(
+                 transformer,
+                 checkpointing_type=ckpt_type,
+             )
+         return transformer
+
+     # ------------------------------------------------------------------
+     # Lifecycle
+     # ------------------------------------------------------------------
+
+     def init_preprocessors(self, training_config: TrainingConfig) -> None:
+         self.vae = load_module_from_path(
+             model_path=str(training_config.model_path),
+             module_type="vae",
+             training_config=training_config,
+         )
+
+         self.world_group = get_world_group()
+         self.sp_group = get_sp_group()
+
+         self._init_timestep_mechanics()
+
+         from fastvideo.dataset.dataloader.schema import (
+             pyarrow_schema_t2v, )
+         from fastvideo.train.utils.dataloader import (
+             build_parquet_t2v_train_dataloader, )
+
+         text_len = (
+             training_config.pipeline_config.text_encoder_configs[  # type: ignore[union-attr]
+                 0].arch_config.text_len)
+         self.dataloader = build_parquet_t2v_train_dataloader(
+             training_config.data,
+             text_len=int(text_len),
+             parquet_schema=pyarrow_schema_t2v,
+         )
+         self.start_step = 0
+
+     @property
+     def num_train_timesteps(self) -> int:
+         return int(self.num_train_timestep)
+
+     def shift_and_clamp_timestep(self, timestep: torch.Tensor) -> torch.Tensor:
+         timestep = shift_timestep(
+             timestep,
+             self.timestep_shift,
+             self.num_train_timestep,
+         )
+         return timestep.clamp(self.min_timestep, self.max_timestep)
+
+     def on_train_start(self) -> None:
+         self.ensure_negative_conditioning()
+
+     # ------------------------------------------------------------------
+     # Runtime primitives
+     # ------------------------------------------------------------------
+
+     def prepare_batch(
+         self,
+         raw_batch: dict[str, Any],
+         *,
+         generator: torch.Generator,
+         latents_source: Literal["data", "zeros"] = "data",
+     ) -> TrainingBatch:
+         self.ensure_negative_conditioning()
+         assert self.training_config is not None
+         tc = self.training_config
208
+
209
+ dtype = self._get_training_dtype()
210
+ device = self.device
211
+
212
+ training_batch = TrainingBatch()
213
+ encoder_hidden_states = raw_batch["text_embedding"]
214
+ encoder_attention_mask = raw_batch["text_attention_mask"]
215
+ infos = raw_batch.get("info_list")
216
+
217
+ if latents_source == "zeros":
218
+ batch_size = encoder_hidden_states.shape[0]
219
+ vae_config = (
220
+ tc.pipeline_config.vae_config.arch_config # type: ignore[union-attr]
221
+ )
222
+ num_channels = vae_config.z_dim
223
+ spatial_compression_ratio = (vae_config.spatial_compression_ratio)
224
+ latent_height = (tc.data.num_height // spatial_compression_ratio)
225
+ latent_width = (tc.data.num_width // spatial_compression_ratio)
226
+ latents = torch.zeros(
227
+ batch_size,
228
+ num_channels,
229
+ tc.data.num_latent_t,
230
+ latent_height,
231
+ latent_width,
232
+ device=device,
233
+ dtype=dtype,
234
+ )
235
+ elif latents_source == "data":
236
+ if "vae_latent" not in raw_batch:
237
+ raise ValueError("vae_latent not found in batch "
238
+ "and latents_source='data'")
239
+ latents = raw_batch["vae_latent"]
240
+ latents = latents[:, :, :tc.data.num_latent_t]
241
+ latents = latents.to(device, dtype=dtype)
242
+ else:
243
+ raise ValueError(f"Unknown latents_source: "
244
+ f"{latents_source!r}")
245
+
246
+ training_batch.latents = latents
247
+ training_batch.encoder_hidden_states = (encoder_hidden_states.to(device, dtype=dtype))
248
+ training_batch.encoder_attention_mask = (encoder_attention_mask.to(device, dtype=dtype))
249
+ training_batch.infos = infos
250
+
251
+ training_batch.latents = normalize_dit_input("wan", training_batch.latents, self.vae)
252
+ training_batch = self._prepare_dit_inputs(training_batch, generator)
253
+ training_batch = self._build_attention_metadata(training_batch)
254
+
255
+ training_batch.attn_metadata_vsa = copy.deepcopy(training_batch.attn_metadata)
256
+ if training_batch.attn_metadata is not None:
257
+ training_batch.attn_metadata.VSA_sparsity = 0.0 # type: ignore[attr-defined]
258
+
259
+ return training_batch
260
+
261
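+     # NOTE: prepare_batch deliberately keeps two views of the attention
+     # metadata built above: `attn_metadata_vsa` retains the scheduled VSA
+     # sparsity for the sparse branch, while the default `attn_metadata` copy
+     # has VSA_sparsity forced to 0.0 so a "dense" forward through the same
+     # backend attends to every block. predict_noise() below selects between
+     # the two via its `attn_kind` argument.
+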
+     def add_noise(
+         self,
+         clean_latents: torch.Tensor,
+         noise: torch.Tensor,
+         timestep: torch.Tensor,
+     ) -> torch.Tensor:
+         b, t = clean_latents.shape[:2]
+         noisy = self.noise_scheduler.add_noise(
+             clean_latents.flatten(0, 1),
+             noise.flatten(0, 1),
+             timestep,
+         ).unflatten(0, (b, t))
+         return noisy
+
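+     # NOTE: add_noise merges the leading (batch, time) dims so one scheduler
+     # call can apply a per-row timestep, then restores the layout. For
+     # example, clean_latents of shape [2, 4, C, H, W] is noised as 8
+     # independent rows and unflattened back to [2, 4, C, H, W].
+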
+     def predict_noise(
+         self,
+         noisy_latents: torch.Tensor,
+         timestep: torch.Tensor,
+         batch: TrainingBatch,
+         *,
+         conditional: bool,
+         cfg_uncond: dict[str, Any] | None = None,
+         attn_kind: Literal["dense", "vsa", "sparse_fp4"] = "dense",
+         force_dense: bool = False,
+     ) -> torch.Tensor:
+         device_type = self.device.type
+         dtype = noisy_latents.dtype
+         if conditional:
+             text_dict = batch.conditional_dict
+             if text_dict is None:
+                 raise RuntimeError("Missing conditional_dict in TrainingBatch")
+         else:
+             text_dict = self._get_uncond_text_dict(batch, cfg_uncond=cfg_uncond)
+
+         if attn_kind == "dense":
+             attn_metadata = batch.attn_metadata
+         elif attn_kind in ("vsa", "sparse_fp4"):
+             attn_metadata = batch.attn_metadata_vsa
+         else:
+             raise ValueError(f"Unknown attn_kind: {attn_kind!r}")
+
+         with torch.autocast(device_type, dtype=dtype), set_forward_context(
+                 current_timestep=batch.timesteps,
+                 attn_metadata=attn_metadata,
+                 force_dense=force_dense,
+         ):
+             input_kwargs = self._build_distill_input_kwargs(noisy_latents, timestep, text_dict)
+             transformer = self._get_transformer(timestep)
+             pred_noise = transformer(**input_kwargs).permute(0, 2, 1, 3, 4)
+         return pred_noise
+
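+     # Illustrative usage sketch (hypothetical variable names; the real call
+     # sites live in the trainer, not in this file):
+     #
+     #     cond = model.predict_noise(x_t, t, batch, conditional=True,
+     #                                attn_kind="vsa")
+     #     uncond = model.predict_noise(x_t, t, batch, conditional=False,
+     #                                  attn_kind="vsa")
+     #     guided = uncond + guidance_scale * (cond - uncond)
+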
+     def backward(
+         self,
+         loss: torch.Tensor,
+         ctx: Any,
+         *,
+         grad_accum_rounds: int,
+     ) -> None:
+         timesteps, attn_metadata = ctx
+         with set_forward_context(
+                 current_timestep=timesteps,
+                 attn_metadata=attn_metadata,
+         ):
+             (loss / max(1, int(grad_accum_rounds))).backward()
+
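+     # NOTE: scaling by grad_accum_rounds before .backward() makes the summed
+     # gradients over one accumulation window equal the gradient of the mean
+     # loss: with grad_accum_rounds=4, each micro-batch contributes loss / 4,
+     # so four backward() calls followed by a single optimizer step average
+     # the four micro-batch gradients.
+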
+     # ------------------------------------------------------------------
+     # Internal helpers
+     # ------------------------------------------------------------------
+
+     def _get_training_dtype(self) -> torch.dtype:
+         return torch.bfloat16
+
+     def _init_timestep_mechanics(self) -> None:
+         assert self.training_config is not None
+         tc = self.training_config
+         flow_shift = tc.pipeline_config.flow_shift
+         self.timestep_shift = float(0.0 if flow_shift is None else flow_shift)
+         self.num_train_timestep = int(self.noise_scheduler.num_train_timesteps)
+         # min/max timestep ratios now come from method_config;
+         # default to full range.
+         self.min_timestep = 0
+         self.max_timestep = self.num_train_timestep
+
+     def ensure_negative_conditioning(self) -> None:
+         if self.negative_prompt_embeds is not None:
+             return
+
+         assert self.training_config is not None
+         tc = self.training_config
+         world_group = self.world_group
+         device = self.device
+         dtype = self._get_training_dtype()
+
+         from fastvideo.train.utils.moduleloader import (
+             make_inference_args, )
+
+         neg_embeds: torch.Tensor | None = None
+         neg_mask: torch.Tensor | None = None
+
+         if world_group.rank_in_group == 0:
+             sampling_param = SamplingParam.from_pretrained(tc.model_path)
+             negative_prompt = sampling_param.negative_prompt
+
+             inference_args = make_inference_args(tc, model_path=tc.model_path)
+
+             prompt_pipeline = WanPipeline.from_pretrained(
+                 tc.model_path,
+                 args=inference_args,
+                 inference_mode=True,
+                 loaded_modules={"transformer": self.transformer},
+                 tp_size=tc.distributed.tp_size,
+                 sp_size=tc.distributed.sp_size,
+                 num_gpus=tc.distributed.num_gpus,
+                 pin_cpu_memory=tc.distributed.pin_cpu_memory,
+                 dit_cpu_offload=True,
+             )
+
+             batch_negative = ForwardBatch(
+                 data_type="video",
+                 prompt=negative_prompt,
+                 prompt_embeds=[],
+                 prompt_attention_mask=[],
+             )
+             result_batch = prompt_pipeline.prompt_encoding_stage(  # type: ignore[attr-defined]
+                 batch_negative,
+                 inference_args,
+             )
+
+             neg_embeds = result_batch.prompt_embeds[0].to(device=device, dtype=dtype)
+             neg_mask = result_batch.prompt_attention_mask[0].to(device=device, dtype=dtype)
+
+             del prompt_pipeline
+             gc.collect()
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+         meta = torch.zeros((2, ), device=device, dtype=torch.int64)
+         if world_group.rank_in_group == 0:
+             assert neg_embeds is not None
+             assert neg_mask is not None
+             meta[0] = neg_embeds.ndim
+             meta[1] = neg_mask.ndim
+         world_group.broadcast(meta, src=0)
+         embed_ndim, mask_ndim = int(meta[0].item()), int(meta[1].item())
+
+         max_ndim = 8
+         embed_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64)
+         mask_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64)
+         if world_group.rank_in_group == 0:
+             assert neg_embeds is not None
+             assert neg_mask is not None
+             embed_shape[:embed_ndim] = torch.tensor(
+                 list(neg_embeds.shape),
+                 device=device,
+                 dtype=torch.int64,
+             )
+             mask_shape[:mask_ndim] = torch.tensor(
+                 list(neg_mask.shape),
+                 device=device,
+                 dtype=torch.int64,
+             )
+         world_group.broadcast(embed_shape, src=0)
+         world_group.broadcast(mask_shape, src=0)
+
+         embed_sizes = tuple(int(x) for x in embed_shape[:embed_ndim].tolist())
+         mask_sizes = tuple(int(x) for x in mask_shape[:mask_ndim].tolist())
+
+         if world_group.rank_in_group != 0:
+             neg_embeds = torch.empty(embed_sizes, device=device, dtype=dtype)
+             neg_mask = torch.empty(mask_sizes, device=device, dtype=dtype)
+         assert neg_embeds is not None
+         assert neg_mask is not None
+
+         world_group.broadcast(neg_embeds, src=0)
+         world_group.broadcast(neg_mask, src=0)
+
+         self.negative_prompt_embeds = neg_embeds
+         self.negative_prompt_attention_mask = neg_mask
+
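+     # NOTE: ensure_negative_conditioning uses a three-phase broadcast so only
+     # rank 0 has to run the text encoder: (1) broadcast each tensor's ndim,
+     # (2) broadcast fixed-width (max_ndim = 8) shape vectors padded with -1,
+     # (3) allocate matching empty tensors on the other ranks and broadcast
+     # the payloads. Every rank then holds identical negative-prompt
+     # embeddings without loading the encoder itself.
+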
+     def _sample_timesteps(
+         self,
+         batch_size: int,
+         device: torch.device,
+         generator: torch.Generator,
+     ) -> torch.Tensor:
+         assert self.training_config is not None
+         tc = self.training_config
+
+         u = compute_density_for_timestep_sampling(
+             weighting_scheme=tc.model.weighting_scheme,
+             batch_size=batch_size,
+             generator=generator,
+             device=device,
+             logit_mean=tc.model.logit_mean,
+             logit_std=tc.model.logit_std,
+             mode_scale=tc.model.mode_scale,
+         )
+         indices = (u * self.noise_scheduler.config.num_train_timesteps).long()
+         return self.noise_scheduler.timesteps[indices.cpu()].to(device=device)
+
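+     # NOTE: `u` is drawn in [0, 1) under the configured weighting scheme and
+     # quantized onto the scheduler grid; with 1000 train timesteps, for
+     # instance, u = 0.8503 maps to index 850. The `indices.cpu()` lookup
+     # indexes the scheduler's CPU-resident timesteps table before moving the
+     # result to `device`.
+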
+     def _build_attention_metadata(self, training_batch: TrainingBatch) -> TrainingBatch:
+         assert self.training_config is not None
+         tc = self.training_config
+         latents_shape = training_batch.raw_latent_shape
+         patch_size = (
+             tc.pipeline_config.dit_config.patch_size  # type: ignore[union-attr]
+         )
+         assert latents_shape is not None
+         assert training_batch.timesteps is not None
+
+         if envs.FASTVIDEO_ATTENTION_BACKEND in (
+                 "VIDEO_SPARSE_ATTN", "SPARSE_FP4_ATTN", "SPARSE_FP4_OURS_P_ATTN",
+         ):
+             if not is_vsa_available() or VideoSparseAttentionMetadataBuilder is None:
+                 raise ImportError(
+                     f"FASTVIDEO_ATTENTION_BACKEND is "
+                     f"{envs.FASTVIDEO_ATTENTION_BACKEND}, but "
+                     f"fastvideo_kernel is not correctly "
+                     f"installed or detected.")
+             training_batch.attn_metadata = VideoSparseAttentionMetadataBuilder().build(  # type: ignore[misc]
+                 raw_latent_shape=latents_shape[2:5],
+                 current_timestep=training_batch.timesteps,
+                 patch_size=patch_size,
+                 VSA_sparsity=tc.vsa_sparsity,
+                 device=self.device,
+             )
+         elif envs.FASTVIDEO_ATTENTION_BACKEND == "VMOBA_ATTN":
+             if not is_vmoba_available() or VideoMobaAttentionMetadataBuilder is None:
+                 raise ImportError("FASTVIDEO_ATTENTION_BACKEND is "
+                                   "VMOBA_ATTN, but fastvideo_kernel "
+                                   "(or flash_attn>=2.7.4) is not "
+                                   "correctly installed.")
+             moba_params = tc.model.moba_config.copy()
+             assert training_batch.raw_latent_shape is not None
+             moba_params.update({
+                 "current_timestep": training_batch.timesteps,
+                 "raw_latent_shape": training_batch.raw_latent_shape[2:5],
+                 "patch_size": patch_size,
+                 "device": self.device,
+             })
+             training_batch.attn_metadata = VideoMobaAttentionMetadataBuilder().build(**moba_params)  # type: ignore[misc]
+         else:
+             training_batch.attn_metadata = None
+
+         return training_batch
+
+     def _prepare_dit_inputs(
+         self,
+         training_batch: TrainingBatch,
+         generator: torch.Generator,
+     ) -> TrainingBatch:
+         assert self.training_config is not None
+         tc = self.training_config
+         latents = training_batch.latents
+         assert isinstance(latents, torch.Tensor)
+         batch_size = latents.shape[0]
+
+         noise = torch.randn(
+             latents.shape,
+             generator=generator,
+             device=latents.device,
+             dtype=latents.dtype,
+         )
+         timesteps = self._sample_timesteps(
+             batch_size,
+             latents.device,
+             generator,
+         )
+         if int(tc.distributed.sp_size or 1) > 1:
+             self.sp_group.broadcast(timesteps, src=0)
+
+         sigmas = get_sigmas(
+             self.noise_scheduler,
+             latents.device,
+             timesteps,
+             n_dim=latents.ndim,
+             dtype=latents.dtype,
+         )
+         noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
+
+         training_batch.noisy_model_input = noisy_model_input
+         training_batch.timesteps = timesteps
+         training_batch.sigmas = sigmas
+         training_batch.noise = noise
+         training_batch.raw_latent_shape = latents.shape
+
+         training_batch.conditional_dict = {
+             "encoder_hidden_states": training_batch.encoder_hidden_states,
+             "encoder_attention_mask": training_batch.encoder_attention_mask,
+         }
+
+         if (self.negative_prompt_embeds is not None
+                 and self.negative_prompt_attention_mask is not None):
+             neg_embeds = self.negative_prompt_embeds
+             neg_mask = self.negative_prompt_attention_mask
+             if neg_embeds.shape[0] == 1 and batch_size > 1:
+                 neg_embeds = neg_embeds.expand(batch_size, *neg_embeds.shape[1:]).contiguous()
+             if neg_mask.shape[0] == 1 and batch_size > 1:
+                 neg_mask = neg_mask.expand(batch_size, *neg_mask.shape[1:]).contiguous()
+             training_batch.unconditional_dict = {
+                 "encoder_hidden_states": neg_embeds,
+                 "encoder_attention_mask": neg_mask,
+             }
+
+         training_batch.latents = training_batch.latents.permute(0, 2, 1, 3, 4)
+         return training_batch
+
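+     # NOTE: (1.0 - sigmas) * latents + sigmas * noise is the flow-matching
+     # (rectified-flow) forward process: sigma = 0 recovers the clean latents
+     # and sigma = 1 pure noise. The matching velocity target, noise - x0, is
+     # formed in the SFT loss (see training_pipeline.py in this snapshot).
+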
+     def _build_distill_input_kwargs(
+         self,
+         noise_input: torch.Tensor,
+         timestep: torch.Tensor,
+         text_dict: dict[str, torch.Tensor] | None,
+     ) -> dict[str, Any]:
+         if text_dict is None:
+             raise ValueError("text_dict cannot be None for Wan distillation")
+         return {
+             "hidden_states": noise_input.permute(0, 2, 1, 3, 4),
+             "encoder_hidden_states": text_dict["encoder_hidden_states"],
+             "encoder_attention_mask": text_dict["encoder_attention_mask"],
+             "timestep": timestep,
+             "return_dict": False,
+         }
+
+     def _get_transformer(self, timestep: torch.Tensor) -> torch.nn.Module:
+         return self.transformer
+
+     def _get_uncond_text_dict(
+         self,
+         batch: TrainingBatch,
+         *,
+         cfg_uncond: dict[str, Any] | None,
+     ) -> dict[str, torch.Tensor]:
+         if cfg_uncond is None:
+             text_dict = getattr(batch, "unconditional_dict", None)
+             if text_dict is None:
+                 raise RuntimeError("Missing unconditional_dict; "
+                                    "ensure_negative_conditioning() "
+                                    "may have failed")
+             return text_dict
+
+         on_missing_raw = cfg_uncond.get("on_missing", "error")
+         if not isinstance(on_missing_raw, str):
+             raise ValueError("method_config.cfg_uncond.on_missing "
+                              "must be a string, got "
+                              f"{type(on_missing_raw).__name__}")
+         on_missing = on_missing_raw.strip().lower()
+         if on_missing not in {"error", "ignore"}:
+             raise ValueError("method_config.cfg_uncond.on_missing "
+                              "must be one of {error, ignore}, got "
+                              f"{on_missing_raw!r}")
+
+         for channel, policy_raw in cfg_uncond.items():
+             if channel in {"on_missing", "text"}:
+                 continue
+             if policy_raw is None:
+                 continue
+             if not isinstance(policy_raw, str):
+                 raise ValueError("method_config.cfg_uncond values "
+                                  "must be strings, got "
+                                  f"{channel}="
+                                  f"{type(policy_raw).__name__}")
+             policy = policy_raw.strip().lower()
+             if policy == "keep":
+                 continue
+             if on_missing == "ignore":
+                 continue
+             raise ValueError("WanModel does not support "
+                              "cfg_uncond channel "
+                              f"{channel!r} (policy={policy!r}). "
+                              "Set cfg_uncond.on_missing=ignore or "
+                              "remove the channel.")
+
+         text_policy_raw = cfg_uncond.get("text", None)
+         if text_policy_raw is None:
+             text_policy = "negative_prompt"
+         elif not isinstance(text_policy_raw, str):
+             raise ValueError("method_config.cfg_uncond.text must be "
+                              "a string, got "
+                              f"{type(text_policy_raw).__name__}")
+         else:
+             text_policy = text_policy_raw.strip().lower()
+
+         if text_policy in {"negative_prompt"}:
+             text_dict = getattr(batch, "unconditional_dict", None)
+             if text_dict is None:
+                 raise RuntimeError("Missing unconditional_dict; "
+                                    "ensure_negative_conditioning() "
+                                    "may have failed")
+             return text_dict
+         if text_policy == "keep":
+             if batch.conditional_dict is None:
+                 raise RuntimeError("Missing conditional_dict in TrainingBatch")
+             return batch.conditional_dict
+         if text_policy == "zero":
+             if batch.conditional_dict is None:
+                 raise RuntimeError("Missing conditional_dict in TrainingBatch")
+             cond = batch.conditional_dict
+             enc = cond["encoder_hidden_states"]
+             mask = cond["encoder_attention_mask"]
+             if not torch.is_tensor(enc) or not torch.is_tensor(mask):
+                 raise TypeError("conditional_dict must contain tensor text inputs")
+             return {
+                 "encoder_hidden_states": torch.zeros_like(enc),
+                 "encoder_attention_mask": torch.zeros_like(mask),
+             }
+         if text_policy == "drop":
+             raise ValueError("cfg_uncond.text=drop is not supported "
+                              "for Wan. Use "
+                              "{negative_prompt, keep, zero}.")
+         raise ValueError("cfg_uncond.text must be one of "
+                          "{negative_prompt, keep, zero, drop}, got "
+                          f"{text_policy_raw!r}")
backend_snapshot/fastvideo/training/training_pipeline.py ADDED
@@ -0,0 +1,1044 @@
+ # SPDX-License-Identifier: Apache-2.0
+ from dataclasses import asdict
+ from contextlib import AbstractContextManager, nullcontext
+ import math
+ import os
+ import shutil
+ import tempfile
+ import time
+ from abc import ABC, abstractmethod
+ from collections import deque
+ from collections.abc import Iterator
+ from typing import Any
+ from fastvideo.profiler import profile_region
+ import imageio
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+ import torchvision
+ from einops import rearrange
+ from torch.utils.data import DataLoader
+ from torchdata.stateful_dataloader import StatefulDataLoader
+ from tqdm.auto import tqdm
+ from diffusers import FlowMatchEulerDiscreteScheduler
+
+ import fastvideo.envs as envs
+ try:
+     from fastvideo.attention.backends.video_sparse_attn import VideoSparseAttentionMetadataBuilder
+     from fastvideo.attention.backends.vmoba import VideoMobaAttentionMetadataBuilder
+ except Exception:
+     # Keep the names defined so code guarded by vsa_available / vmoba_available
+     # fails with a clear error rather than a NameError.
+     VideoSparseAttentionMetadataBuilder = None  # type: ignore[assignment]
+     VideoMobaAttentionMetadataBuilder = None  # type: ignore[assignment]
+ from fastvideo.configs.sample import SamplingParam
+ from fastvideo.dataset import build_parquet_map_style_dataloader
+ from fastvideo.dataset.dataloader.schema import pyarrow_schema_t2v
+ from fastvideo.dataset.validation_dataset import ValidationDataset
+ from fastvideo.distributed import (cleanup_dist_env_and_memory, get_local_torch_device, get_sp_group, get_world_group)
+ from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs
+ from fastvideo.forward_context import set_forward_context
+ from fastvideo.logger import init_logger
+ from fastvideo.attention.selector import global_force_attn_backend_context_manager
+ from fastvideo.pipelines import (ComposedPipelineBase, ForwardBatch, LoRAPipeline, TrainingBatch)
+ from fastvideo.platforms import AttentionBackendEnum, current_platform
+ from fastvideo.training.activation_checkpoint import apply_activation_checkpointing
+ from fastvideo.training.trackers import (DummyTracker, TrackerType, initialize_trackers, Trackers)
+ from fastvideo.training.training_utils import (clip_grad_norm_while_handling_failing_dtensor_cases,
+                                                compute_density_for_timestep_sampling, count_trainable, get_scheduler,
+                                                get_sigmas, load_checkpoint, normalize_dit_input, save_checkpoint,
+                                                swap_fp4_linear, traverse_swap_module)
+ from fastvideo.utils import (is_vmoba_available, is_vsa_available, set_random_seed, shallow_asdict)
+
+ try:
+     vsa_available = is_vsa_available()
+     vmoba_available = is_vmoba_available()
+ except Exception:
+     vsa_available = False
+     vmoba_available = False
+
+ logger = init_logger(__name__)
+
+
+ class TrainingPipeline(LoRAPipeline, ABC):
+     """
+     Base pipeline for training a model. All training pipelines should inherit
+     from this class, and reusable components and logic should live here.
+     """
+     _required_config_modules = ["scheduler", "transformer"]
+     validation_pipeline: ComposedPipelineBase
+     train_dataloader: StatefulDataLoader
+     train_loader_iter: Iterator[dict[str, Any]]
+     current_epoch: int = 0
+     train_transformer_2: bool = False
+     tracker: TrackerType
+
+     def __init__(self,
+                  model_path: str,
+                  fastvideo_args: TrainingArgs,
+                  required_config_modules: list[str] | None = None,
+                  loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
+         fastvideo_args.inference_mode = False
+         self.lora_training = fastvideo_args.lora_training
+         if self.lora_training and fastvideo_args.lora_rank is None:
+             raise ValueError("lora_rank must be set when using LoRA training")
+
+         set_random_seed(fastvideo_args.seed)  # for LoRA param init
+         super().__init__(model_path, fastvideo_args, required_config_modules, loaded_modules)  # type: ignore
+         self.tracker = DummyTracker()
+
+     def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
+         raise RuntimeError("create_pipeline_stages should not be called for training pipeline")
+
+     @staticmethod
+     def _should_force_generator_attn_qat_train(fastvideo_args: FastVideoArgs) -> bool:
+         if not isinstance(fastvideo_args, TrainingArgs):
+             return False
+         return (fastvideo_args.generator_4bit_attn or envs.FASTVIDEO_ATTENTION_BACKEND == "ATTN_QAT_TRAIN")
+
+     def load_modules(self,
+                      fastvideo_args: FastVideoArgs,
+                      loaded_modules: dict[str, torch.nn.Module] | None = None) -> dict[str, Any]:
+         force_generator_qat = self._should_force_generator_attn_qat_train(fastvideo_args)
+         load_context: AbstractContextManager[None] = nullcontext()
+         if force_generator_qat:
+             logger.info("Forcing generator attention backend to ATTN_QAT_TRAIN during module loading")
+             load_context = global_force_attn_backend_context_manager(AttentionBackendEnum.ATTN_QAT_TRAIN)
+
+         with load_context:
+             return super().load_modules(fastvideo_args, loaded_modules)
+
+     def set_schemas(self) -> None:
+         self.train_dataset_schema = pyarrow_schema_t2v
+
+     def initialize_training_pipeline(self, training_args: TrainingArgs):
+         logger.info("Initializing training pipeline...")
+         self.device = get_local_torch_device()
+         self.training_args = training_args
+         world_group = get_world_group()
+         self.world_size = world_group.world_size
+         self.global_rank = world_group.rank
+         self.sp_group = get_sp_group()
+         self.rank_in_sp_group = self.sp_group.rank_in_group
+         self.sp_world_size = self.sp_group.world_size
+         self.local_rank = world_group.local_rank
+         self.transformer = self.get_module("transformer")
+         self.transformer_2 = self.get_module("transformer_2", None)
+         self.seed = training_args.seed
+         self.set_schemas()
+
+         # Set random seeds for deterministic training
+         assert self.seed is not None, "seed must be set"
+         set_random_seed(self.seed + self.global_rank)
+         self.transformer.train()
+         if training_args.enable_gradient_checkpointing_type is not None:
+             self.transformer = apply_activation_checkpointing(
+                 self.transformer, checkpointing_type=training_args.enable_gradient_checkpointing_type)
+             if self.transformer_2 is not None:
+                 self.transformer_2 = apply_activation_checkpointing(
+                     self.transformer_2, checkpointing_type=training_args.enable_gradient_checkpointing_type)
+
+         if training_args.generator_4bit_linear:
+             num_swaps = traverse_swap_module(self.transformer, swap_fn=swap_fp4_linear)
+             logger.info("Swapped %s linear layers to the FP4 forward path in self.transformer", num_swaps)
+         noise_scheduler = self.modules["scheduler"]
+         self.set_trainable()
+         params_to_optimize = self.transformer.parameters()
+         params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize))
+         # Parse betas from string format "beta1,beta2"
+         betas_str = training_args.betas
+         betas = tuple(float(x.strip()) for x in betas_str.split(","))
+
+         self.optimizer = torch.optim.AdamW(
+             params_to_optimize,
+             lr=training_args.learning_rate,
+             betas=betas,
+             weight_decay=training_args.weight_decay,
+             eps=1e-8,
+         )
+
+         self.init_steps = 0
+         logger.info("optimizer: %s", self.optimizer)
+
+         self.lr_scheduler = get_scheduler(
+             training_args.lr_scheduler,
+             optimizer=self.optimizer,
+             num_warmup_steps=training_args.lr_warmup_steps,
+             num_training_steps=training_args.max_train_steps,
+             num_cycles=training_args.lr_num_cycles,
+             power=training_args.lr_power,
+             min_lr_ratio=training_args.min_lr_ratio,
+             last_epoch=self.init_steps - 1,
+         )
+         if self.transformer_2 is not None:
+             # Ensure transformer_2 has trainable parameters before creating optimizer
+             params_to_optimize_2 = self.transformer_2.parameters()
+             params_to_optimize_2 = list(filter(lambda p: p.requires_grad, params_to_optimize_2))
+             self.optimizer_2 = torch.optim.AdamW(
+                 params_to_optimize_2,
+                 lr=training_args.learning_rate,
+                 betas=(0.9, 0.999),
+                 weight_decay=training_args.weight_decay,
+                 eps=1e-8,
+             )
+             self.lr_scheduler_2 = get_scheduler(
+                 training_args.lr_scheduler,
+                 optimizer=self.optimizer_2,
+                 num_warmup_steps=training_args.lr_warmup_steps,
+                 num_training_steps=training_args.max_train_steps,
+                 num_cycles=training_args.lr_num_cycles,
+                 power=training_args.lr_power,
+                 min_lr_ratio=training_args.min_lr_ratio,
+                 last_epoch=self.init_steps - 1,
+             )
+
+         self.train_dataset, self.train_dataloader = build_parquet_map_style_dataloader(
+             training_args.data_path,
+             training_args.train_batch_size,
+             parquet_schema=self.train_dataset_schema,
+             num_data_workers=training_args.dataloader_num_workers,
+             cfg_rate=training_args.training_cfg_rate,
+             drop_last=True,
+             text_padding_length=training_args.pipeline_config.text_encoder_configs[0].arch_config.
+             text_len,  # type: ignore[attr-defined]
+             seed=self.seed)
+
+         self.noise_scheduler = noise_scheduler
+         if self.training_args.boundary_ratio is not None:
+             self.boundary_timestep = self.training_args.boundary_ratio * self.noise_scheduler.num_train_timesteps
+         else:
+             self.boundary_timestep = None
+
+         logger.info("train_dataloader length: %s", len(self.train_dataloader))
+         logger.info("train_sp_batch_size: %s", training_args.train_sp_batch_size)
+         logger.info("gradient_accumulation_steps: %s", training_args.gradient_accumulation_steps)
+         logger.info("sp_size: %s", training_args.sp_size)
+
+         self.num_update_steps_per_epoch = math.ceil(
+             len(self.train_dataloader) / training_args.gradient_accumulation_steps * training_args.sp_size /
+             training_args.train_sp_batch_size)
+         self.num_train_epochs = math.ceil(training_args.max_train_steps / self.num_update_steps_per_epoch)
+
+         # TODO(will): is there a cleaner way to track epochs?
+         self.current_epoch = 0
+
+         trackers = list(training_args.trackers)
+         if not trackers and training_args.tracker_project_name:
+             trackers.append(Trackers.WANDB.value)
+         if self.global_rank != 0:
+             trackers = []
+
+         tracker_log_dir = training_args.output_dir or os.getcwd()
+         if trackers:
+             tracker_log_dir = os.path.join(tracker_log_dir, "tracker")
+
+         tracker_config = asdict(training_args) if trackers else None
+         tracker_run_name = training_args.wandb_run_name or None
+         project = training_args.tracker_project_name or "fastvideo"
+         self.tracker = initialize_trackers(
+             trackers,
+             experiment_name=project,
+             config=tracker_config,
+             log_dir=tracker_log_dir,
+             run_name=tracker_run_name,
+         )
+
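+     # NOTE (worked example with illustrative numbers, not from this run):
+     # with a dataloader of 1024 batches, gradient_accumulation_steps=4,
+     # sp_size=2 and train_sp_batch_size=1, num_update_steps_per_epoch =
+     # ceil(1024 / 4 * 2 / 1) = 512 optimizer steps; num_train_epochs then
+     # follows from max_train_steps. Likewise betas="0.9,0.95" parses to
+     # (0.9, 0.95) for AdamW.
+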
+     @abstractmethod
+     def initialize_validation_pipeline(self, training_args: TrainingArgs):
+         raise NotImplementedError("Training pipelines must implement this method")
+
+     def _prepare_training(self, training_batch: TrainingBatch) -> TrainingBatch:
+         self.optimizer.zero_grad()
+         if self.transformer_2 is not None:
+             self.optimizer_2.zero_grad()
+         training_batch.total_loss = 0.0
+         return training_batch
+
+     def _get_next_batch(self, training_batch: TrainingBatch) -> TrainingBatch:
+         with self.tracker.timed("timing/get_next_batch"):
+             batch = next(self.train_loader_iter, None)  # type: ignore
+             if batch is None:
+                 self.current_epoch += 1
+                 logger.info("Starting epoch %s", self.current_epoch)
+                 # Reset iterator for next epoch
+                 self.train_loader_iter = iter(self.train_dataloader)
+                 # Get first batch of new epoch
+                 batch = next(self.train_loader_iter)
+
+         latents = batch['vae_latent']
+         latents = latents[:, :, :self.training_args.num_latent_t]
+         encoder_hidden_states = batch['text_embedding']
+         encoder_attention_mask = batch['text_attention_mask']
+         infos = batch['info_list']
+
+         training_batch.latents = latents.to(
+             get_local_torch_device(),
+             dtype=torch.bfloat16,
+             non_blocking=True,
+         )
+         training_batch.encoder_hidden_states = encoder_hidden_states.to(
+             get_local_torch_device(),
+             dtype=torch.bfloat16,
+             non_blocking=True,
+         )
+         training_batch.encoder_attention_mask = encoder_attention_mask.to(
+             get_local_torch_device(),
+             dtype=torch.bfloat16,
+             non_blocking=True,
+         )
+         training_batch.infos = infos
+
+         return training_batch
+
+     def _normalize_dit_input(self, training_batch: TrainingBatch) -> TrainingBatch:
+         # TODO(will): support other models
+         with self.tracker.timed("timing/normalize_input"):
+             training_batch.latents = normalize_dit_input(
+                 'wan',
+                 training_batch.latents,
+                 self.get_module("vae"),
+             )
+         return training_batch
+
+     def _prepare_dit_inputs(self, training_batch: TrainingBatch) -> TrainingBatch:
+         assert self.training_args is not None, "training_args must be set"
+         with self.tracker.timed("timing/prepare_dit_inputs"):
+             latents = training_batch.latents
+             batch_size = latents.shape[0]
+             noise = torch.randn(latents.shape,
+                                 generator=self.noise_gen_cuda,
+                                 device=latents.device,
+                                 dtype=latents.dtype)
+             timesteps = self._sample_timesteps(batch_size, latents.device)
+
+             if self.training_args.sp_size > 1:
+                 # Make sure that the timesteps are the same across all sp processes.
+                 sp_group = get_sp_group()
+                 sp_group.broadcast(timesteps, src=0)
+                 sp_group.broadcast(noise, src=0)
+             sigmas = get_sigmas(
+                 self.noise_scheduler,
+                 latents.device,
+                 timesteps,
+                 n_dim=latents.ndim,
+                 dtype=latents.dtype,
+             )
+             noisy_model_input = (1.0 - sigmas) * training_batch.latents + sigmas * noise
+
+             training_batch.noisy_model_input = noisy_model_input
+             training_batch.timesteps = timesteps
+             training_batch.sigmas = sigmas
+             training_batch.noise = noise
+             training_batch.raw_latent_shape = training_batch.latents.shape
+
+         return training_batch
+
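+     # NOTE: the interpolation above plus the `noise - latents` target in
+     # _transformer_forward_and_compute_loss trains the DiT as a velocity
+     # predictor. With precondition_outputs the prediction is mapped back to
+     # x0-space (x_t - sigma * pred) and regressed against the clean latents
+     # instead; the two objectives differ only by a sigma^2 per-timestep
+     # weighting.
+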
+     def _sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
+         # Determine which model to train based on the boundary timestep
+         if (self.transformer_2 is not None and self.boundary_timestep is not None
+                 and torch.rand(1, generator=self.noise_random_generator).item() <= self.training_args.boundary_ratio):
+             self.train_transformer_2 = True
+         else:
+             self.train_transformer_2 = False
+
+         # Broadcast the decision to all processes
+         decision = torch.tensor(1.0 if self.train_transformer_2 else 0.0, device=self.device)
+         dist.broadcast(decision, src=0)
+         self.train_transformer_2 = decision.item() == 1.0
+
+         # Sample u from the appropriate range
+         u = compute_density_for_timestep_sampling(
+             weighting_scheme=self.training_args.weighting_scheme,
+             batch_size=batch_size,
+             generator=self.noise_random_generator,
+             logit_mean=self.training_args.logit_mean,
+             logit_std=self.training_args.logit_std,
+             mode_scale=self.training_args.mode_scale,
+         )
+
+         boundary_ratio = self.training_args.boundary_ratio
+         if self.train_transformer_2:
+             u = (1 - boundary_ratio) + u * boundary_ratio  # min: 1 - boundary_ratio, max: 1
+         # elif self.transformer_2 is not None:
+         #     u = u * (1 - boundary_ratio)  # min: 0, max: 1 - boundary_ratio
+         # else:  # patch for now to align with non-MoE timestep logic
+         #     pass
+
+         indices = (u * self.noise_scheduler.config.num_train_timesteps).long()
+         return self.noise_scheduler.timesteps[indices].to(device=device)
+
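+     # NOTE: remapping u into [1 - boundary_ratio, 1) pushes transformer_2's
+     # samples toward the tail of the timestep table. Assuming the scheduler's
+     # `timesteps` are stored in descending order (as in diffusers), that tail
+     # corresponds to small timesteps, i.e. the low-noise refinement regime.
+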
+     def _build_attention_metadata(self, training_batch: TrainingBatch) -> TrainingBatch:
+         latents_shape = training_batch.raw_latent_shape
+         patch_size = self.training_args.pipeline_config.dit_config.patch_size
+         current_vsa_sparsity = training_batch.current_vsa_sparsity
+         assert latents_shape is not None
+         assert isinstance(patch_size, tuple), f"Expected tuple patch_size, got {patch_size!r}"
+         assert training_batch.timesteps is not None
+         if envs.FASTVIDEO_ATTENTION_BACKEND in (
+                 "VIDEO_SPARSE_ATTN",
+                 "SPARSE_FP4_ATTN",
+                 "SPARSE_FP4_OURS_P_ATTN",
+         ):
+             if not vsa_available:
+                 raise ImportError(f"FASTVIDEO_ATTENTION_BACKEND is set to "
+                                   f"{envs.FASTVIDEO_ATTENTION_BACKEND}, "
+                                   "but fastvideo_kernel is not correctly installed or detected. "
+                                   "Please ensure fastvideo-kernel is installed.")
+             training_batch.attn_metadata = VideoSparseAttentionMetadataBuilder(  # type: ignore
+             ).build(  # type: ignore
+                 raw_latent_shape=latents_shape[2:5],
+                 current_timestep=training_batch.timesteps,
+                 patch_size=patch_size,
+                 VSA_sparsity=current_vsa_sparsity,
+                 device=get_local_torch_device())
+         elif envs.FASTVIDEO_ATTENTION_BACKEND == "VMOBA_ATTN":
+             if not vmoba_available:
+                 raise ImportError("FASTVIDEO_ATTENTION_BACKEND is set to VMOBA_ATTN, "
+                                   "but fastvideo_kernel (or flash_attn>=2.7.4) is not correctly installed.")
+             moba_params = self.training_args.moba_config.copy()
+             moba_params.update({
+                 "current_timestep": training_batch.timesteps,
+                 "raw_latent_shape": latents_shape[2:5],
+                 "patch_size": self.training_args.pipeline_config.dit_config.patch_size,
+                 "device": get_local_torch_device(),
+             })
+             training_batch.attn_metadata = VideoMobaAttentionMetadataBuilder().build(**moba_params)
+         else:
+             training_batch.attn_metadata = None
+
+         return training_batch
+
+     def _build_input_kwargs(self, training_batch: TrainingBatch) -> TrainingBatch:
+         training_batch.input_kwargs = {
+             "hidden_states": training_batch.noisy_model_input,
+             "encoder_hidden_states": training_batch.encoder_hidden_states,
+             "timestep": training_batch.timesteps.to(get_local_torch_device(), dtype=torch.bfloat16),
+             "encoder_attention_mask": training_batch.encoder_attention_mask,
+             "return_dict": False,
+         }
+         return training_batch
+
+     def _transformer_forward_and_compute_loss(self, training_batch: TrainingBatch) -> TrainingBatch:
+         if vsa_available and envs.FASTVIDEO_ATTENTION_BACKEND in (
+                 "VIDEO_SPARSE_ATTN",
+                 "SPARSE_FP4_ATTN",
+                 "SPARSE_FP4_OURS_P_ATTN",
+         ) or vmoba_available and envs.FASTVIDEO_ATTENTION_BACKEND == "VMOBA_ATTN":
+             assert training_batch.attn_metadata is not None
+         else:
+             assert training_batch.attn_metadata is None
+         input_kwargs = training_batch.input_kwargs
+
+         # if 'hunyuan' in self.training_args.model_type:
+         #     input_kwargs["guidance"] = torch.tensor(
+         #         [1000.0],
+         #         device=training_batch.noisy_model_input.device,
+         #         dtype=torch.bfloat16)
+         current_model = self.transformer_2 if self.train_transformer_2 else self.transformer
+
+         with self.tracker.timed("timing/forward_backward"), set_forward_context(
+                 current_timestep=training_batch.current_timestep, attn_metadata=training_batch.attn_metadata):
+             model_pred = current_model(**input_kwargs)
+             if self.training_args.precondition_outputs:
+                 assert training_batch.sigmas is not None
+                 model_pred = training_batch.noisy_model_input - model_pred * training_batch.sigmas
+             assert training_batch.latents is not None
+             assert training_batch.noise is not None
+             target = training_batch.latents if self.training_args.precondition_outputs else training_batch.noise - training_batch.latents
+
+             # make sure no implicit broadcasting happens
+             assert model_pred.shape == target.shape, f"model_pred.shape: {model_pred.shape}, target.shape: {target.shape}"
+
+             loss = (torch.mean(
+                 (model_pred.float() - target.float())**2) / self.training_args.gradient_accumulation_steps)
+
+             loss.backward()
+
+         avg_loss = loss.detach().clone()
+
+         # Reduce across ranks without forcing a CPU sync
+         with self.tracker.timed("timing/reduce_loss"):
+             world_group = get_world_group()
+             avg_loss = world_group.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
+             # Accumulate on GPU; materialize to CPU only once after
+             # all gradient-accumulation iterations (see train_one_step).
+             training_batch.total_loss += avg_loss
+
+         return training_batch
+
+     def _clip_grad_norm(self, training_batch: TrainingBatch) -> TrainingBatch:
+         max_grad_norm = self.training_args.max_grad_norm
+
+         # TODO(will): perhaps move this into transformer api so that we can do
+         # the following:
+         # grad_norm = transformer.clip_grad_norm_(max_grad_norm)
+         if max_grad_norm is not None:
+             with self.tracker.timed("timing/clip_grad_norm"):
+                 # Only clip gradients for the model that is currently training
+                 if self.train_transformer_2 and self.transformer_2 is not None:
+                     model_parts = [self.transformer_2]
+                 else:
+                     model_parts = [self.transformer]
+
+                 grad_norm = clip_grad_norm_while_handling_failing_dtensor_cases(
+                     [p for m in model_parts for p in m.parameters()],
+                     max_grad_norm,
+                     foreach=None,
+                 )
+                 grad_norm = grad_norm.item() if grad_norm is not None else 0.0
+                 assert not (math.isnan(grad_norm) or math.isinf(grad_norm)), \
+                     f"grad_norm is not finite: {grad_norm}"
+         else:
+             grad_norm = 0.0
+         training_batch.grad_norm = grad_norm
+         return training_batch
+
+     @profile_region("profiler_region_training_train_one_step")
+     def train_one_step(self, training_batch: TrainingBatch) -> TrainingBatch:
+         training_batch = self._prepare_training(training_batch)
+
+         for _ in range(self.training_args.gradient_accumulation_steps):
+             training_batch = self._get_next_batch(training_batch)
+
+             # Normalize DIT input
+             training_batch = self._normalize_dit_input(training_batch)
+             # Create noisy model input
+             training_batch = self._prepare_dit_inputs(training_batch)
+             assert training_batch.latents is not None
+             assert training_batch.noisy_model_input is not None
+             assert training_batch.noise is not None
+
+             # old sharding code, need to shard latents and noise but not input
+             # Shard latents across sp groups
+             training_batch.latents = training_batch.latents[:, :, :self.training_args.num_latent_t]
+             # shard noisy_model_input to match
+             training_batch.noisy_model_input = training_batch.noisy_model_input[:, :, :self.training_args.num_latent_t]
+             # shard noise to match latents
+             training_batch.noise = training_batch.noise[:, :, :self.training_args.num_latent_t]
+
+             training_batch = self._build_attention_metadata(training_batch)
+             training_batch = self._build_input_kwargs(training_batch)
+
+             training_batch = self._transformer_forward_and_compute_loss(training_batch)
+
+         training_batch = self._clip_grad_norm(training_batch)
+
+         # Only step the optimizer and scheduler for the model that is currently training
+         with self.tracker.timed("timing/optimizer_step"):
+             if self.train_transformer_2 and self.transformer_2 is not None:
+                 self.optimizer_2.step()
+                 self.lr_scheduler_2.step()
+             else:
+                 self.optimizer.step()
+                 self.lr_scheduler.step()
+
+         return training_batch
+
+     def _compute_current_sparsity(self, step: int) -> float:
+         """Compute the VSA sparsity for a given step using the decay schedule."""
+         vsa_sparsity = self.training_args.VSA_sparsity
+         vsa_decay_rate = self.training_args.VSA_decay_rate
+         vsa_decay_interval = self.training_args.VSA_decay_interval_steps
+         vsa_init = getattr(self.training_args, 'VSA_init_sparsity', 0.0)
+         vsa_warmup = getattr(self.training_args, 'VSA_warmup_steps', 0)
+         if step <= vsa_warmup:
+             return vsa_init
+         ramp_step = step - vsa_warmup
+         max_times = int((vsa_sparsity - vsa_init) / vsa_decay_rate) if vsa_decay_rate > 0 else 0
+         times = min(ramp_step // vsa_decay_interval, max_times)
+         return vsa_init + times * vsa_decay_rate
+
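+     # NOTE (worked example): with VSA_init_sparsity=0.1, VSA_sparsity=0.9,
+     # VSA_decay_rate=0.2 and VSA_decay_interval_steps=100, the schedule steps
+     # through 0.1 -> 0.3 -> 0.5 -> 0.7 -> 0.9 at 100-step intervals after
+     # warmup. For the checkpoint-750 run snapshotted here (VSA_SPARSITY=0.9,
+     # VSA_INIT_SPARSITY=0.9, VSA_WARMUP_STEPS=0), max_times is 0 and the
+     # sparsity is a constant 0.9 from step 0.
+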
+     def _resolve_checkpoint_path(self, path: str) -> str | None:
+         """Resolve 'latest' to the most recent checkpoint in output_dir."""
+         import glob
+         if path == "latest":
+             output_dir = self.training_args.output_dir
+             ckpt_dirs = sorted(
+                 glob.glob(os.path.join(output_dir, "checkpoint-*")),
+                 key=lambda d: int(d.split("-")[-1]) if d.split("-")[-1].isdigit() else 0,
+             )
+             if ckpt_dirs:
+                 latest = ckpt_dirs[-1]
+                 logger.info("Auto-resolved 'latest' to %s", latest)
+                 return latest
+             logger.info("No checkpoints found in %s, starting from scratch", output_dir)
+             return None
+         return path
+
+     def _resume_from_checkpoint(self) -> None:
+         ckpt_path = self._resolve_checkpoint_path(self.training_args.resume_from_checkpoint)
+         if ckpt_path is None:
+             logger.info("No checkpoint to resume from, starting from step 0")
+             return
+
+         safetensors_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model.safetensors")
+         step = int(os.path.basename(os.path.normpath(ckpt_path)).split('-')[-1])
+
+         resumed_step = load_checkpoint(self.transformer, self.global_rank, ckpt_path,
+                                        self.optimizer, self.train_dataloader,
+                                        self.lr_scheduler, self.noise_random_generator)
+         if resumed_step > 0 or step == 0:
+             self.init_steps = resumed_step
+             logger.info("Successfully resumed full training state from step %s", resumed_step)
+             return
+
+         if os.path.exists(safetensors_path):
+             self.init_steps = step
+             logger.warning("Distributed checkpoint resume failed; falling back to safetensors weights at step %s",
+                            step)
+             return
+
+         logger.warning("No usable checkpoint state found at %s; starting from step 0", ckpt_path)
+         self.init_steps = 0
+
+ def train(self) -> None:
591
+ assert self.seed is not None, "seed must be set"
592
+ assert self.training_args is not None, "training_args must be set"
593
+ set_random_seed(self.seed + self.global_rank)
594
+ logger.info('rank: %s: start training', self.global_rank, local_main_process_only=False)
595
+ if not self.post_init_called:
596
+ self.post_init()
597
+ num_trainable_params = count_trainable(self.transformer)
598
+ logger.info("Starting training with %s B trainable parameters", round(num_trainable_params / 1e9, 3))
599
+
600
+ if getattr(self, "transformer_2", None) is not None:
601
+ num_trainable_params = count_trainable(self.transformer_2)
602
+ logger.info("Transformer 2: Starting training with %s B trainable parameters",
603
+ round(num_trainable_params / 1e9, 3))
604
+
605
+ # Set random seeds for deterministic training
606
+ self.noise_random_generator = torch.Generator(device="cpu").manual_seed(self.seed + self.global_rank)
607
+ self.noise_gen_cuda = torch.Generator(device=current_platform.device_name).manual_seed(self.seed +
608
+ self.global_rank)
609
+ self.validation_random_generator = torch.Generator(device="cpu").manual_seed(self.seed + self.global_rank)
610
+ logger.info("Initialized random seeds with seed: %s", self.seed + self.global_rank)
611
+ self.noise_scheduler = FlowMatchEulerDiscreteScheduler()
612
+
613
+ if self.training_args.resume_from_checkpoint:
614
+ self._resume_from_checkpoint()
615
+
616
+ self.train_loader_iter = iter(self.train_dataloader)
617
+
618
+ step_times: deque[float] = deque(maxlen=100)
619
+
620
+ self._log_training_info()
621
+
622
+ # Validation at init uses the sparsity corresponding to init_steps
623
+ saved_sparsity = self.training_args.VSA_sparsity
624
+ self.training_args.VSA_sparsity = self._compute_current_sparsity(self.init_steps)
625
+ self._log_validation(self.transformer, self.training_args, self.init_steps)
626
+ self.training_args.VSA_sparsity = saved_sparsity
627
+
628
+ # Train!
629
+ progress_bar = tqdm(
630
+ range(0, self.training_args.max_train_steps),
631
+ initial=self.init_steps,
632
+ desc="Steps",
633
+ # Only show the progress bar once on each machine.
634
+ disable=self.local_rank > 0,
635
+ )
636
+ for step in range(self.init_steps + 1, self.training_args.max_train_steps + 1):
637
+ start_time = time.perf_counter()
638
+ if vsa_available:
639
+ vsa_sparsity = self.training_args.VSA_sparsity
640
+ vsa_decay_rate = self.training_args.VSA_decay_rate
641
+ vsa_decay_interval_steps = self.training_args.VSA_decay_interval_steps
642
+ vsa_init_sparsity = getattr(self.training_args, 'VSA_init_sparsity', 0.0)
643
+ vsa_warmup_steps = getattr(self.training_args, 'VSA_warmup_steps', 0)
644
+ if step <= vsa_warmup_steps:
645
+ current_vsa_sparsity = vsa_init_sparsity
646
+ else:
647
+ ramp_step = step - vsa_warmup_steps
648
+ max_decay_times = int((vsa_sparsity - vsa_init_sparsity) / vsa_decay_rate)
649
+ current_decay_times = min(ramp_step // vsa_decay_interval_steps, max_decay_times)
650
+ current_vsa_sparsity = vsa_init_sparsity + current_decay_times * vsa_decay_rate
651
+ elif vmoba_available:
652
+ #TODO: add vmoba sparsity scheduling here
653
+ current_vsa_sparsity = 0.0
654
+ else:
655
+ current_vsa_sparsity = 0.0
656
+
657
+ training_batch = TrainingBatch()
658
+ training_batch.current_timestep = step
659
+ training_batch.current_vsa_sparsity = current_vsa_sparsity
660
+ training_batch = self.train_one_step(training_batch)
661
+
662
+ loss = float(training_batch.total_loss)
663
+ grad_norm = training_batch.grad_norm
664
+
665
+ step_time = time.perf_counter() - start_time
666
+ step_times.append(step_time)
667
+ avg_step_time = sum(step_times) / len(step_times)
668
+
669
+ progress_bar.set_postfix({
670
+ "loss": f"{loss:.4f}",
671
+ "step_time": f"{step_time:.2f}s",
672
+ "grad_norm": grad_norm,
673
+ })
674
+ progress_bar.update(1)
675
+ if self.global_rank == 0:
676
+ metrics = {
677
+ "train_loss": loss,
678
+ "learning_rate": self.lr_scheduler.get_last_lr()[0],
679
+ "step_time": step_time,
680
+ "avg_step_time": avg_step_time,
681
+ "grad_norm": grad_norm,
682
+ "vsa_sparsity": current_vsa_sparsity,
683
+ }
684
+ try:
685
+ assert training_batch.raw_latent_shape is not None
686
+ metrics["batch_size"] = int(training_batch.raw_latent_shape[0])
687
+
688
+ patch_size = self.training_args.pipeline_config.dit_config.patch_size
689
+ assert isinstance(patch_size, tuple), f"Expected tuple patch_size, got {patch_size!r}"
690
+ patch_t, patch_h, patch_w = patch_size
691
+ seq_len = (training_batch.raw_latent_shape[2] // patch_t) * (
692
+ training_batch.raw_latent_shape[3] // patch_h) * (training_batch.raw_latent_shape[4] // patch_w)
693
+ if training_batch.encoder_hidden_states is not None:
694
+ context_len = int(training_batch.encoder_hidden_states.shape[1])
695
+ else:
696
+ context_len = 0
697
+
698
+ metrics["dit_seq_len"] = int(seq_len)
699
+ metrics["context_len"] = context_len
700
+
701
+ arch_config = self.training_args.pipeline_config.dit_config.arch_config
702
+
703
+ metrics["hidden_dim"] = arch_config.hidden_size
704
+ metrics["num_layers"] = arch_config.num_layers
705
+ metrics["ffn_dim"] = arch_config.ffn_dim
706
+ except Exception:
707
+ pass
708
+
709
+ self.tracker.log(metrics, step)
710
+ if step % self.training_args.training_state_checkpointing_steps == 0:
711
+ with self.profiler_controller.region("profiler_region_training_save_checkpoint"):
712
+ save_checkpoint(self.transformer, self.global_rank, self.training_args.output_dir, step,
713
+ self.optimizer, self.train_dataloader, self.lr_scheduler,
714
+ self.noise_random_generator,
715
+ self.training_args.checkpoints_total_limit)
716
+ self.transformer.train()
717
+ self.sp_group.barrier()
718
+
719
+ if self.training_args.log_visualization and step % self.training_args.visualization_steps == 0:
720
+ self.visualize_intermediate_latents(training_batch, self.training_args, step)
721
+
722
+ if self.training_args.log_validation and step % self.training_args.validation_steps == 0:
723
+ with self.profiler_controller.region("profiler_region_training_validation"):
724
+ saved_sparsity = self.training_args.VSA_sparsity
725
+ self.training_args.VSA_sparsity = current_vsa_sparsity
726
+ self._log_validation(self.transformer, self.training_args, step)
727
+ self.training_args.VSA_sparsity = saved_sparsity
728
+ gpu_memory_usage = current_platform.get_torch_device().memory_allocated() / 1024**2
729
+ trainable_params = round(count_trainable(self.transformer) / 1e9, 3)
730
+ logger.info("GPU memory usage after validation: %s MB, trainable params: %sB", gpu_memory_usage,
731
+ trainable_params)
732
+
733
+ self.tracker.finish()
734
+ save_checkpoint(self.transformer, self.global_rank, self.training_args.output_dir,
735
+ self.training_args.max_train_steps, self.optimizer, self.train_dataloader, self.lr_scheduler,
736
+ self.noise_random_generator, self.training_args.checkpoints_total_limit)
737
+
738
+ if envs.FASTVIDEO_TORCH_PROFILER_DIR:
739
+ logger.info("Stopping profiler...")
740
+ self.profiler_controller.stop()
741
+ logger.info("Profiler stopped.")
742
+
743
+ if get_sp_group():
744
+ cleanup_dist_env_and_memory()
745
+
746
+     def _log_training_info(self) -> None:
+         assert self.training_args is not None, "training_args must be set"
+         total_batch_size = (self.world_size * self.training_args.gradient_accumulation_steps /
+                             self.training_args.sp_size * self.training_args.train_sp_batch_size)
+         logger.info("***** Running training *****")
+         logger.info("  Num examples = %s", len(self.train_dataset))
+         logger.info("  Dataloader size = %s", len(self.train_dataloader))
+         logger.info("  Num Epochs = %s", self.num_train_epochs)
+         logger.info("  Resume training from step %s", self.init_steps)  # type: ignore
+         logger.info("  Instantaneous batch size per device = %s", self.training_args.train_batch_size)
+         logger.info("  Total train batch size (w. data & sequence parallel, accumulation) = %s", total_batch_size)
+         logger.info("  Gradient Accumulation steps = %s", self.training_args.gradient_accumulation_steps)
+         logger.info("  Total optimization steps = %s", self.training_args.max_train_steps)
+         logger.info("  Total training parameters per FSDP shard = %s B",
+                     round(count_trainable(self.transformer) / 1e9, 3))
+         # print dtype
+         logger.info("  Master weight dtype: %s", next(self.transformer.parameters()).dtype)
+
+         gpu_memory_usage = current_platform.get_torch_device().memory_allocated() / 1024**2
+         logger.info("GPU memory usage before train_one_step: %s MB", gpu_memory_usage)
+         logger.info("VSA validation sparsity: %s", self.training_args.VSA_sparsity)
+
+ def _prepare_validation_batch(self, sampling_param: SamplingParam, training_args: TrainingArgs,
769
+ validation_batch: dict[str, Any], num_inference_steps: int) -> ForwardBatch:
770
+ sampling_param.prompt = validation_batch['prompt']
771
+ sampling_param.height = training_args.num_height
772
+ sampling_param.width = training_args.num_width
773
+ sampling_param.num_inference_steps = num_inference_steps
774
+ sampling_param.data_type = "video"
775
+ if training_args.validation_guidance_scale:
776
+ sampling_param.guidance_scale = float(training_args.validation_guidance_scale)
777
+ assert self.seed is not None
778
+ sampling_param.seed = self.seed
779
+
780
+ latents_size = [(sampling_param.num_frames - 1) // 4 + 1, sampling_param.height // 8, sampling_param.width // 8]
781
+ n_tokens = latents_size[0] * latents_size[1] * latents_size[2]
782
+ temporal_compression_factor = training_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio
783
+ num_frames = (training_args.num_latent_t - 1) * temporal_compression_factor + 1
784
+ sampling_param.num_frames = num_frames
785
+ batch = ForwardBatch(
786
+ **shallow_asdict(sampling_param),
787
+ latents=None,
788
+ generator=self.validation_random_generator,
789
+ n_tokens=n_tokens,
790
+ eta=0.0,
791
+ VSA_sparsity=training_args.VSA_sparsity,
792
+ )
793
+
794
+ return batch
795
+
796
+ @torch.no_grad()
797
+ def _log_validation(self, transformer, training_args, global_step) -> None:
798
+ """
799
+ Generate a validation video and log it to the configured tracker to check the quality during training.
800
+ """
801
+ training_args.inference_mode = True
802
+ training_args.dit_cpu_offload = False
803
+ if not training_args.log_validation:
804
+ return
805
+ if self.validation_pipeline is None:
806
+ raise ValueError("Validation pipeline is not set")
807
+
808
+ logger.info("Starting validation")
809
+
810
+ # Create sampling parameters if not provided
811
+ sampling_param = SamplingParam.from_pretrained(training_args.model_path)
812
+
813
+ # Prepare validation prompts
814
+ logger.info('rank: %s: fastvideo_args.validation_dataset_file: %s',
815
+ self.global_rank,
816
+ training_args.validation_dataset_file,
817
+ local_main_process_only=False)
818
+ validation_dataset = ValidationDataset(training_args.validation_dataset_file)
819
+ validation_dataloader = DataLoader(validation_dataset, batch_size=None, num_workers=0)
820
+
821
+ self.transformer.eval()
822
+ if getattr(self, "transformer_2", None) is not None:
823
+ self.transformer_2.eval()
824
+
825
+ validation_steps = training_args.validation_sampling_steps.split(",")
826
+ validation_steps = [int(step) for step in validation_steps]
827
+ validation_steps = [step for step in validation_steps if step > 0]
828
+ # Log validation results for this step
829
+ world_group = get_world_group()
830
+ num_sp_groups = world_group.world_size // self.sp_group.world_size
831
+ one_prompt_per_rank = os.environ.get(
832
+ "FASTVIDEO_VALIDATION_ONE_PROMPT_PER_RANK",
833
+ "",
834
+ ).lower() in {"1", "true", "yes", "on"}
835
+
836
+ # Process each validation prompt for each validation step
837
+ for num_inference_steps in validation_steps:
838
+ logger.info("rank: %s: num_inference_steps: %s",
839
+ self.global_rank,
840
+ num_inference_steps,
841
+ local_main_process_only=False)
842
+ step_videos: list[np.ndarray] = []
843
+ step_captions: list[str] = []
844
+
845
+ step_audio: list[np.ndarray | None] = []
846
+ step_sample_rates: list[int | None] = []
847
+
848
+ for prompt_idx, validation_batch in enumerate(validation_dataloader):
849
+ if one_prompt_per_rank and prompt_idx > 0:
850
+ continue
851
+
852
+ batch = self._prepare_validation_batch(sampling_param, training_args, validation_batch,
853
+ num_inference_steps)
854
+ logger.info("rank: %s: rank_in_sp_group: %s, batch.prompt: %s",
855
+ self.global_rank,
856
+ self.rank_in_sp_group,
857
+ batch.prompt,
858
+ local_main_process_only=False)
859
+
860
+ assert batch.prompt is not None and isinstance(batch.prompt, str)
861
+ step_captions.append(batch.prompt)
862
+
863
+ # Run validation inference
864
+ output_batch = self.validation_pipeline.forward(batch, training_args)
865
+ samples = output_batch.output.cpu()
866
+
867
+ # Capture audio if available
868
+ audio = output_batch.extra.get("audio")
869
+ sample_rate = output_batch.extra.get("audio_sample_rate")
870
+
871
+ if audio is not None and torch.is_tensor(audio):
872
+ audio = audio.detach().cpu().float().numpy()
873
+
874
+ step_audio.append(audio)
875
+ step_sample_rates.append(sample_rate)
876
+
877
+ if self.rank_in_sp_group != 0:
878
+ continue
879
+
880
+ # Process outputs
881
+ video = rearrange(samples, "b c t h w -> t b c h w")
882
+ frames = []
883
+ for x in video:
884
+ x = torchvision.utils.make_grid(x, nrow=6)
885
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
886
+ frames.append((x * 255).numpy().astype(np.uint8))
887
+ step_videos.append(frames)
888
+
889
+ # Only sp_group leaders (rank_in_sp_group == 0) need to send their
890
+ # results to global rank 0
891
+ if self.rank_in_sp_group == 0 and self.global_rank == 0:
892
+ # Global rank 0 collects results from all sp_group leaders
893
+ all_videos = step_videos # Start with own results
894
+ all_captions = step_captions
895
+ all_audios = step_audio
896
+ all_sample_rates = step_sample_rates
897
+
898
+ # Receive from other sp_group leaders
899
+ for sp_group_idx in range(1, num_sp_groups):
900
+ src_rank = sp_group_idx * self.sp_world_size # Global rank of other sp_group leaders
901
+ recv_videos = world_group.recv_object(src=src_rank)
902
+ recv_captions = world_group.recv_object(src=src_rank)
903
+ recv_audios = world_group.recv_object(src=src_rank)
904
+ recv_sample_rates = world_group.recv_object(src=src_rank)
905
+
906
+ all_videos.extend(recv_videos)
907
+ all_captions.extend(recv_captions)
908
+ all_audios.extend(recv_audios)
909
+ all_sample_rates.extend(recv_sample_rates)
910
+
911
+ video_filenames = []
912
+ for i, (video, caption, audio, sample_rate) in enumerate(
913
+ zip(all_videos, all_captions, all_audios, all_sample_rates, strict=True)):
914
+ os.makedirs(training_args.output_dir, exist_ok=True)
915
+ filename = os.path.join(
916
+ training_args.output_dir,
917
+ f"validation_step_{global_step}_inference_steps_{num_inference_steps}_video_{i}.mp4")
918
+ imageio.mimsave(filename, video, fps=sampling_param.fps)
919
+ # Mux audio if available
920
+ if (audio is not None and sample_rate is not None and not self._mux_audio(
921
+ filename,
922
+ audio,
923
+ sample_rate,
924
+ )):
925
+ logger.warning("Audio mux failed for validation video %s; saved video without audio.", filename)
926
+ video_filenames.append(filename)
927
+
928
+ artifacts = []
929
+ for filename, caption in zip(video_filenames, all_captions, strict=True):
930
+ video_artifact = self.tracker.video(filename, caption=caption)
931
+ if video_artifact is not None:
932
+ artifacts.append(video_artifact)
933
+ if artifacts:
934
+ logs = {f"validation_videos_{num_inference_steps}_steps": artifacts}
935
+ self.tracker.log_artifacts(logs, global_step)
936
+ elif self.rank_in_sp_group == 0:
937
+ # Other sp_group leaders send their results to global rank 0
938
+ world_group.send_object(step_videos, dst=0)
939
+ world_group.send_object(step_captions, dst=0)
940
+ world_group.send_object(step_audio, dst=0)
941
+ world_group.send_object(step_sample_rates, dst=0)
942
+
943
+ world_group.barrier()
944
+
945
+ # Re-enable gradients for training
946
+ training_args.inference_mode = False
947
+ self.transformer.train()
948
+ if getattr(self, "transformer_2", None) is not None:
949
+ self.transformer_2.train()
950
+
951
+ @staticmethod
952
+ def _mux_audio(
953
+ video_path: str,
954
+ audio: torch.Tensor | np.ndarray,
955
+ sample_rate: int,
956
+ ) -> bool:
957
+ """Mux audio into video using PyAV."""
958
+ try:
959
+ import av
960
+ except ImportError:
961
+ logger.warning("PyAV not installed; cannot mux audio. "
962
+ "Install with: pip install av")
963
+ return False
964
+
965
+ if torch.is_tensor(audio):
966
+ audio_np = audio.detach().cpu().float().numpy()
967
+ else:
968
+ audio_np = np.asarray(audio, dtype=np.float32)
969
+
970
+ if audio_np.ndim == 1:
971
+ audio_np = audio_np[:, None]
972
+ elif audio_np.ndim == 2:
973
+ if audio_np.shape[0] <= 8 and audio_np.shape[1] > audio_np.shape[0]:
974
+ audio_np = audio_np.T
975
+ else:
976
+ logger.warning("Unexpected audio shape %s; skipping mux.", audio_np.shape)
977
+ return False
978
+
979
+ audio_np = np.clip(audio_np, -1.0, 1.0)
980
+ audio_int16 = (audio_np * 32767.0).astype(np.int16)
981
+ num_channels = audio_int16.shape[1]
982
+ layout = "stereo" if num_channels == 2 else "mono"
983
+
984
+ try:
985
+ import wave
986
+ with tempfile.TemporaryDirectory() as tmpdir:
987
+ out_path = os.path.join(tmpdir, "muxed.mp4")
988
+ wav_path = os.path.join(tmpdir, "audio.wav")
989
+
990
+ # Write audio to WAV file
991
+ with wave.open(wav_path, "wb") as wav_file:
992
+ wav_file.setnchannels(num_channels)
993
+ wav_file.setsampwidth(2)
994
+ wav_file.setframerate(sample_rate)
995
+ wav_file.writeframes(audio_int16.tobytes())
996
+
997
+ # Open input video and audio
998
+ input_video = av.open(video_path)
999
+ input_audio = av.open(wav_path)
1000
+
1001
+ # Create output with both streams
1002
+ output = av.open(out_path, mode="w")
1003
+
1004
+ # Add video stream (copy codec from input)
1005
+ in_video_stream = input_video.streams.video[0]
1006
+ out_video_stream = output.add_stream(
1007
+ codec_name=in_video_stream.codec_context.name,
1008
+ rate=in_video_stream.average_rate,
1009
+ )
1010
+ out_video_stream.width = in_video_stream.width
1011
+ out_video_stream.height = in_video_stream.height
1012
+ out_video_stream.pix_fmt = in_video_stream.pix_fmt
1013
+
1014
+ # Add audio stream (AAC)
1015
+ out_audio_stream = output.add_stream("aac", rate=sample_rate)
1016
+ out_audio_stream.layout = layout
1017
+
1018
+ # Remux video (decode and re-encode to be safe)
1019
+ for frame in input_video.decode(video=0):
1020
+ for packet in out_video_stream.encode(frame):
1021
+ output.mux(packet)
1022
+ for packet in out_video_stream.encode():
1023
+ output.mux(packet)
1024
+
1025
+ # Encode audio
1026
+ for frame in input_audio.decode(audio=0):
1027
+ frame.pts = None # Let encoder assign PTS
1028
+ for packet in out_audio_stream.encode(frame):
1029
+ output.mux(packet)
1030
+ for packet in out_audio_stream.encode():
1031
+ output.mux(packet)
1032
+
1033
+ input_video.close()
1034
+ input_audio.close()
1035
+ output.close()
1036
+ shutil.move(out_path, video_path)
1037
+ return True
1038
+ except Exception as e:
1039
+ logger.warning("Audio mux failed: %s", e)
1040
+ return False
1041
+
1042
+ def visualize_intermediate_latents(self, training_batch: TrainingBatch, training_args: TrainingArgs, step: int):
1043
+ """Add visualization data to tracker logging and save frames to disk."""
1044
+ raise NotImplementedError("Visualize intermediate latents is not implemented for training pipeline")
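A quick sanity check of the `dit_seq_len` metric computed at the top of this file, using this run's shapes. The latent grid (20 x 56 x 104, from `--num_latent_t 20`, 448/8, 832/8 in the launch script below) comes from this snapshot; the (1, 2, 2) patch size is an assumption about the Wan2.1 DiT, not something read from these files:

```bash
# Hedged arithmetic check, not part of the snapshot.
# Assumed latent grid 20 x 56 x 104 and assumed patch size (1, 2, 2):
echo $(( (20 / 1) * (56 / 2) * (104 / 2) ))  # prints 29120, the expected dit_seq_len
```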
backend_snapshot/fastvideo/training/wan_training_pipeline.py ADDED
@@ -0,0 +1,74 @@
+ # SPDX-License-Identifier: Apache-2.0
+ import sys
+ from copy import deepcopy
+
+ from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs
+ from fastvideo.logger import init_logger
+ from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import (FlowUniPCMultistepScheduler)
+ from fastvideo.pipelines.basic.wan.wan_pipeline import WanPipeline
+ from fastvideo.training.training_pipeline import TrainingPipeline
+ from fastvideo.utils import is_vsa_available
+
+ try:
+     vsa_available = is_vsa_available()
+ except Exception:
+     vsa_available = False
+
+ logger = init_logger(__name__)
+
+
+ class WanTrainingPipeline(TrainingPipeline):
+     """
+     A training pipeline for Wan.
+     """
+     _required_config_modules = ["scheduler", "transformer", "vae"]
+
+     def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
+         self.modules["scheduler"] = FlowUniPCMultistepScheduler(shift=fastvideo_args.pipeline_config.flow_shift)
+
+     def create_training_stages(self, training_args: TrainingArgs):
+         """
+         May be used in future refactors.
+         """
+         pass
+
+     def initialize_validation_pipeline(self, training_args: TrainingArgs):
+         logger.info("Initializing validation pipeline...")
+         args_copy = deepcopy(training_args)
+
+         args_copy.inference_mode = True
+         validation_pipeline = WanPipeline.from_pretrained(
+             training_args.model_path,
+             args=args_copy,  # type: ignore
+             inference_mode=True,
+             loaded_modules={
+                 "transformer": self.get_module("transformer"),
+             },
+             tp_size=training_args.tp_size,
+             sp_size=training_args.sp_size,
+             num_gpus=training_args.num_gpus,
+             pin_cpu_memory=training_args.pin_cpu_memory,
+             dit_cpu_offload=True)
+
+         self.validation_pipeline = validation_pipeline
+
+
+ def main(args) -> None:
+     logger.info("Starting training pipeline...")
+
+     pipeline = WanTrainingPipeline.from_pretrained(args.pretrained_model_name_or_path, args=args)
+     args = pipeline.training_args
+     pipeline.train()
+     logger.info("Training pipeline done")
+
+
+ if __name__ == "__main__":
+     argv = sys.argv
+     from fastvideo.fastvideo_args import TrainingArgs
+     from fastvideo.utils import FlexibleArgumentParser
+     parser = FlexibleArgumentParser()
+     parser = TrainingArgs.add_cli_args(parser)
+     parser = FastVideoArgs.add_cli_args(parser)
+     args = parser.parse_args()
+     args.dit_cpu_offload = False
+     main(args)
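Since `wan_training_pipeline.py` doubles as the torchrun entrypoint (see `run_training` in the common script later in this snapshot for the exact launch), a quick way to confirm the CLI wiring without starting a run is to ask the parser for help. This assumes the repo root and virtualenv that the common script activates:

```bash
# Smoke-test the entrypoint's argument parser only; starts no training.
python fastvideo/training/wan_training_pipeline.py --help
```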
backend_snapshot/manifest.sha256 ADDED
@@ -0,0 +1,17 @@
+ 45ff4b677a84fad92bd2ff596bf432cb1b9386c5923b6c0824f896074e7cfbc6 ./README.md
+ 9d1d8dc58aab529270fe31eb1735d6a1382c0c6d36fccca122a8dbffa1b714fd ./fastvideo-kernel/python/fastvideo_kernel/block_sparse_attn_ours_p.py
+ 211c7f0445fbe9488250f01fa83457c6620e83bd6f3877db791fd155de93c08b ./fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py
+ 3f3a407a88612ea17ad65e1b6b9cf6b7b02df56956d8301c4b13bffa92095016 ./fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py
+ 56f17c602dede53c7c3677058f81274681530f1b83c086d9d1d44c6b51feefbb ./fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py
+ 2b821b0e2e7bdb3581be6312ebbece42380a6ee28a7a982f0cf2dc71fab849c8 ./fastvideo/attention/backends/sparse_fp4_ours_p_attn.py
+ a97adcc52d7558c49f418c09395fd1665e988ad290d2276b95f21dfca0f8eb7d ./fastvideo/attention/backends/video_sparse_attn.py
+ 79ef6f38ec0f5bfe16b2b98327ad2ccd15f3c863dd87fd03affc5dbdaa0a8224 ./fastvideo/configs/models/dits/base.py
+ ddcab6f4fd33c9813840571b6bf83bbbcea164b564166951ed4301297db6cef0 ./fastvideo/forward_context.py
+ 6cfd128e782b7787a27ddd28a5e2d50cb4b0e2e9425d51d9780f14c91e8206f0 ./fastvideo/pipelines/stages/denoising.py
+ 489388dbdd9e5e3ad24db3012bd9b108794509a9729891d7dd315a102abba828 ./fastvideo/platforms/cuda.py
+ c046b1914041b59254bcdfe577aed20d6f007a72632ea1fe1ae92fa678eca760 ./fastvideo/platforms/interface.py
+ 2456d39ca28019e12bb7ab007774e86348f0582a017bf0e6c91e2a01d654a1a0 ./fastvideo/train/models/wan/wan.py
+ bc46e84b732567de6c0325223405daecd1226c623e303be33c7be9b5b7fdec08 ./fastvideo/training/training_pipeline.py
+ 1d3898fa37e21029df6c37e05dc34ed7805a211c2f87de6642db890e5a8c6f2e ./fastvideo/training/wan_training_pipeline.py
+ 5c982b64653fae83ebfdeb43fda8f29b3e2cb581fb4daee38cd3cf56aa9d73f5 ./scripts/training/run_sparse_fp4_train_v4_1n_sparse09_hpo_on_ours_p_init2050_interactive.sh
+ 5c1d5ce9ecc8b90e59ddfc2ddb3e2dae500bcd3acb90429c901444b1630f05fb ./scripts/training/run_sparse_fp4_train_v4_common.sh
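The manifest can be used to check the staged files for drift. A minimal sketch, assuming GNU coreutils and that the entries follow the standard `sha256sum` output format:

```bash
# Verify every file listed in the manifest (paths are relative to backend_snapshot/).
cd backend_snapshot
sha256sum -c manifest.sha256
```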
backend_snapshot/scripts/training/run_sparse_fp4_train_v4_1n_sparse09_hpo_on_ours_p_init2050_interactive.sh ADDED
@@ -0,0 +1,32 @@
+ #!/bin/bash
+ #SBATCH --job-name=sfp4-s09-oursp-i2050
+ #SBATCH --account=nvr_elm_llm
+ #SBATCH --partition=interactive
+ #SBATCH --nodes=1
+ #SBATCH --gres=gpu:8
+ #SBATCH --ntasks-per-node=1
+ #SBATCH --cpus-per-task=128
+ #SBATCH --mem=1440G
+ #SBATCH --time=02:00:00
+ #SBATCH --output=slurm_logs/sfp4_sparse09_ours_p_init2050_1n_interactive_%j.out
+ #SBATCH --error=slurm_logs/sfp4_sparse09_ours_p_init2050_1n_interactive_%j.err
+
+ export RUN_NAME="sfp4_v4_sparse09_hpo_on_ours_p_init2050_1n_interactive"
+ export WANDB_RUN_ID="sfp4v4-sparse09-hpo-on-ours-p-init2050-1n-interactive"
+ export FASTVIDEO_ATTENTION_BACKEND=SPARSE_FP4_OURS_P_ATTN
+ export FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O=1
+ export CHECKPOINT_LIMIT=5
+ export SAVE_STEPS=50
+ export EVAL_STEPS=50
+ export VALIDATION_SAMPLING_STEPS=50
+ export USE_SRUN=0
+
+ export VSA_SPARSITY=0.9
+ export VSA_INIT_SPARSITY=0.9
+ export VSA_WARMUP_STEPS=0
+ export VSA_DECAY_RATE=0.03
+ export VSA_DECAY_INTERVAL_STEPS=50
+
+ export INIT_WEIGHTS_FROM_SAFETENSORS="checkpoints/init/sfp4_v4_sparse06_hpo_on_ours_p_1n_interactive_v2_ckpt2050/transformer/diffusion_pytorch_model.safetensors"
+
+ exec bash scripts/training/run_sparse_fp4_train_v4_common.sh
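This wrapper carries a full `#SBATCH` header but also sets `USE_SRUN=0`, so it can either be submitted as a batch job or run directly inside an interactive allocation. Both invocations below are a sketch of the intended usage, not commands recorded in the snapshot:

```bash
# Submit as a batch job (Slurm reads the #SBATCH header):
sbatch scripts/training/run_sparse_fp4_train_v4_1n_sparse09_hpo_on_ours_p_init2050_interactive.sh

# Or, from a shell inside an existing allocation (USE_SRUN=0 skips srun):
bash scripts/training/run_sparse_fp4_train_v4_1n_sparse09_hpo_on_ours_p_init2050_interactive.sh
```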
backend_snapshot/scripts/training/run_sparse_fp4_train_v4_common.sh ADDED
@@ -0,0 +1,199 @@
+ #!/bin/bash
+
+ set -euo pipefail
+ set -x
+
+ cd /lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/users/yitongl/code/FastVideo
+ source .venv/bin/activate
+
+ : "${RUN_NAME:?RUN_NAME must be set by the Slurm wrapper}"
+ : "${WANDB_RUN_ID:?WANDB_RUN_ID must be set by the Slurm wrapper}"
+
+ export PYTHONPATH=fastvideo-kernel/python:fastvideo-kernel:${PYTHONPATH:-}
+ export FASTVIDEO_ATTENTION_BACKEND="${FASTVIDEO_ATTENTION_BACKEND:-SPARSE_FP4_ATTN}"
+ export FASTVIDEO_VALIDATION_ONE_PROMPT_PER_RANK="${FASTVIDEO_VALIDATION_ONE_PROMPT_PER_RANK:-1}"
+ export WANDB_MODE=online
+ export WANDB_BASE_URL="https://api.wandb.ai"
+ export WANDB_RESUME=allow
+ export WANDB_NAME="${RUN_NAME}"
+ export TOKENIZERS_PARALLELISM=false
+ export NCCL_P2P_DISABLE=1
+ export TORCH_NCCL_ENABLE_MONITORING=0
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ export TRITON_CACHE_DIR="/tmp/triton_cache_${SLURM_JOB_ID:-manual}"
+
+ if [[ -n "${SLURM_JOB_NODELIST:-}" ]]; then
+     MASTER_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1)
+ else
+     MASTER_ADDR=127.0.0.1
+ fi
+ MASTER_PORT=$((20000 + (${SLURM_JOB_ID:-0} % 20000)))
+ export MASTER_ADDR MASTER_PORT
+
+ NUM_GPUS_PER_NODE=8
+ NNODES=${SLURM_NNODES:-1}
+ TOTAL_GPUS=$((NNODES * NUM_GPUS_PER_NODE))
+ OUTPUT_DIR="checkpoints/${RUN_NAME}"
+ CHECKPOINT_LIMIT="${CHECKPOINT_LIMIT:-5}"
+ MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+ DATA_DIR="data/Wan-Syn_77x448x832_600k"
+ VALIDATION_DATASET_FILE="examples/training/finetune/Wan2.1-VSA/Wan-Syn-Data/validation_64.json"
+ SAVE_STEPS="${SAVE_STEPS:-50}"
+ EVAL_STEPS="${EVAL_STEPS:-50}"
+ VALIDATION_SAMPLING_STEPS="${VALIDATION_SAMPLING_STEPS:-50}"
+ MAX_TRAIN_STEPS="${MAX_TRAIN_STEPS:-100000}"
+ VSA_SPARSITY="${VSA_SPARSITY:-0.9}"
+ VSA_INIT_SPARSITY="${VSA_INIT_SPARSITY:-0.6}"
+ VSA_WARMUP_STEPS="${VSA_WARMUP_STEPS:-0}"
+ VSA_DECAY_RATE="${VSA_DECAY_RATE:-0.03}"
+ VSA_DECAY_INTERVAL_STEPS="${VSA_DECAY_INTERVAL_STEPS:-50}"
+ INIT_WEIGHTS_FROM_SAFETENSORS="${INIT_WEIGHTS_FROM_SAFETENSORS:-}"
+
+ mkdir -p slurm_logs "${OUTPUT_DIR}"
+
+ find_latest_checkpoint() {
+     if [[ ! -d "${OUTPUT_DIR}" ]]; then
+         return 1
+     fi
+
+     mapfile -t checkpoint_steps < <(find "${OUTPUT_DIR}" -maxdepth 1 -type d -name 'checkpoint-*' -printf '%f\n' \
+         | sed 's/checkpoint-//' \
+         | sort -nr)
+
+     local step
+     for step in "${checkpoint_steps[@]}"; do
+         if [[ -f "${OUTPUT_DIR}/checkpoint-${step}/transformer/diffusion_pytorch_model.safetensors" ]]; then
+             echo "${OUTPUT_DIR}/checkpoint-${step}"
+             return 0
+         fi
+     done
+
+     return 1
+ }
+
+ prune_checkpoints() {
+     if [[ ! -d "${OUTPUT_DIR}" ]]; then
+         return 0
+     fi
+
+     mapfile -t checkpoint_steps < <(find "${OUTPUT_DIR}" -maxdepth 1 -type d -name 'checkpoint-*' -printf '%f\n' \
+         | sed 's/checkpoint-//' \
+         | sort -n)
+
+     local count=${#checkpoint_steps[@]}
+     if (( count <= CHECKPOINT_LIMIT )); then
+         return 0
+     fi
+
+     local remove_count=$((count - CHECKPOINT_LIMIT))
+     local step
+     for step in "${checkpoint_steps[@]:0:remove_count}"; do
+         rm -rf "${OUTPUT_DIR}/checkpoint-${step}"
+     done
+ }
+
+ RESUME_ARGS=()
+ refresh_resume_args() {
+     RESUME_ARGS=()
+
+     local latest_ckpt
+     if latest_ckpt=$(find_latest_checkpoint); then
+         RESUME_ARGS=(
+             --resume_from_checkpoint latest
+             --init_weights_from_safetensors "${latest_ckpt}/transformer/diffusion_pytorch_model.safetensors"
+         )
+         echo "=== Resuming from ${latest_ckpt} ==="
+     fi
+ }
+
+ COMMON_ARGS=(
+     --tracker_project_name "wan_t2v_sparse_fp4"
+     --wandb_run_name "${RUN_NAME}"
+     --output_dir "${OUTPUT_DIR}"
+     --train_batch_size 1
+     --train_sp_batch_size 1
+     --gradient_accumulation_steps 1
+     --num_latent_t 20
+     --num_height 448
+     --num_width 832
+     --num_frames 77
+     --enable_gradient_checkpointing_type "full"
+     --num_gpus "${TOTAL_GPUS}"
+     --sp_size 1
+     --tp_size 1
+     --hsdp_replicate_dim "${TOTAL_GPUS}"
+     --hsdp_shard_dim 1
+     --model_path "${MODEL_PATH}"
+     --pretrained_model_name_or_path "${MODEL_PATH}"
+     --data_path "${DATA_DIR}"
+     --dataloader_num_workers 4
+     --log_validation
+     --validation_dataset_file "${VALIDATION_DATASET_FILE}"
+     --validation_steps "${EVAL_STEPS}"
+     --validation_sampling_steps "${VALIDATION_SAMPLING_STEPS}"
+     --validation_guidance_scale "5.0"
+     --learning_rate 1e-6
+     --mixed_precision "bf16"
+     --weight_only_checkpointing_steps "${SAVE_STEPS}"
+     --training_state_checkpointing_steps "${SAVE_STEPS}"
+     --weight_decay 0.01
+     --max_grad_norm 1.0
+     --inference_mode False
+     --checkpoints_total_limit "${CHECKPOINT_LIMIT}"
+     --training_cfg_rate 0.1
+     --dit_precision "fp32"
+     --ema_start_step 0
+     --flow_shift 1
+     --seed 1000
+     --VSA-sparsity "${VSA_SPARSITY}"
+     --VSA-init-sparsity "${VSA_INIT_SPARSITY}"
+     --VSA-warmup-steps "${VSA_WARMUP_STEPS}"
+     --VSA-decay-rate "${VSA_DECAY_RATE}"
+     --VSA-decay-interval-steps "${VSA_DECAY_INTERVAL_STEPS}"
+ )
+
+ if [[ -n "${INIT_WEIGHTS_FROM_SAFETENSORS}" ]]; then
+     COMMON_ARGS+=(
+         --init_weights_from_safetensors "${INIT_WEIGHTS_FROM_SAFETENSORS}"
+     )
+ fi
+
+ run_training() {
+     local max_steps=$1
+
+     local torchrun_cmd=(
+         torchrun
+         --nnodes="${NNODES}" \
+         --nproc_per_node="${NUM_GPUS_PER_NODE}" \
+         --rdzv_backend=c10d \
+         --rdzv_endpoint="${MASTER_ADDR}:${MASTER_PORT}" \
+         fastvideo/training/wan_training_pipeline.py \
+         --max_train_steps "${max_steps}" \
+         "${COMMON_ARGS[@]}" \
+         "${RESUME_ARGS[@]}"
+     )
+
+     if [[ "${USE_SRUN:-1}" == "1" ]]; then
+         srun "${torchrun_cmd[@]}"
+     else
+         "${torchrun_cmd[@]}"
+     fi
+ }
+
+ echo "=== ${RUN_NAME} array=${SLURM_ARRAY_TASK_ID:-0} nodes=${NNODES} gpus=${TOTAL_GPUS} ==="
+ echo "=== save_steps=${SAVE_STEPS} eval_steps=${EVAL_STEPS} checkpoint_limit=${CHECKPOINT_LIMIT} ==="
+ echo "=== master=${MASTER_ADDR}:${MASTER_PORT} validation_one_prompt_per_rank=1 ==="
+
+ if [[ ! -f "${OUTPUT_DIR}/checkpoint-0/transformer/diffusion_pytorch_model.safetensors" ]]; then
+     if ! find_latest_checkpoint >/dev/null; then
+         echo "=== Creating step-0 validation and checkpoint ==="
+         run_training 0
+         prune_checkpoints
+     fi
+ fi
+
+ refresh_resume_args
+ run_training "${MAX_TRAIN_STEPS}"
+ prune_checkpoints
+
+ echo "=== Done arr=${SLURM_ARRAY_TASK_ID:-0} ==="