Upload backend snapshot for sfp4 checkpoint-750
- backend_snapshot/README.md +35 -0
- backend_snapshot/fastvideo-kernel/python/fastvideo_kernel/block_sparse_attn_ours_p.py +270 -0
- backend_snapshot/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py +1155 -0
- backend_snapshot/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py +250 -0
- backend_snapshot/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py +80 -0
- backend_snapshot/fastvideo/attention/backends/sparse_fp4_ours_p_attn.py +192 -0
- backend_snapshot/fastvideo/attention/backends/video_sparse_attn.py +262 -0
- backend_snapshot/fastvideo/configs/models/dits/base.py +79 -0
- backend_snapshot/fastvideo/forward_context.py +100 -0
- backend_snapshot/fastvideo/pipelines/stages/denoising.py +1184 -0
- backend_snapshot/fastvideo/platforms/cuda.py +440 -0
- backend_snapshot/fastvideo/platforms/interface.py +255 -0
- backend_snapshot/fastvideo/train/models/wan/wan.py +680 -0
- backend_snapshot/fastvideo/training/training_pipeline.py +1044 -0
- backend_snapshot/fastvideo/training/wan_training_pipeline.py +74 -0
- backend_snapshot/manifest.sha256 +17 -0
- backend_snapshot/scripts/training/run_sparse_fp4_train_v4_1n_sparse09_hpo_on_ours_p_init2050_interactive.sh +32 -0
- backend_snapshot/scripts/training/run_sparse_fp4_train_v4_common.sh +199 -0
backend_snapshot/README.md
ADDED
@@ -0,0 +1,35 @@
# Backend snapshot for checkpoint-750

This directory is the code snapshot for the training backend used by:

`sfp4_v4_sparse09_hpo_on_ours_p_init2050_1n_interactive/checkpoint-750`

Key runtime settings:

- `FASTVIDEO_ATTENTION_BACKEND=SPARSE_FP4_OURS_P_ATTN`
- `FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O=1`
- `VSA_SPARSITY=0.9`
- `VSA_INIT_SPARSITY=0.9`
- `VSA_WARMUP_STEPS=0`
- tile size: `4 x 4 x 4 = 64` video tokens
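
As a minimal sketch of how these settings reach the backend, the same configuration can be applied from Python before `fastvideo` is imported; the variable names are taken from the list above, and the launch scripts below remain the authoritative place they are set:

```python
import os

# Mirror the run's environment in an interactive session (illustrative only;
# the Slurm wrapper scripts export the authoritative values).
os.environ["FASTVIDEO_ATTENTION_BACKEND"] = "SPARSE_FP4_OURS_P_ATTN"
os.environ["FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O"] = "1"
os.environ["VSA_SPARSITY"] = "0.9"
os.environ["VSA_INIT_SPARSITY"] = "0.9"
os.environ["VSA_WARMUP_STEPS"] = "0"
```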

Important files:

- `fastvideo/attention/backends/sparse_fp4_ours_p_attn.py`: Python attention backend, Q/K/V fake quantization, top-k block map, tile mean setup.
- `fastvideo-kernel/python/fastvideo_kernel/block_sparse_attn_ours_p.py`: PyTorch custom op and autograd wrapper.
- `fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py`: Triton forward/backward kernel.
- `fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py`: FP4 quant/dequant utilities used by the kernel.
- `fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py`: Q/K/V fake quant kernels.
- `fastvideo/attention/backends/video_sparse_attn.py`: VSA metadata and tile-size helper.
- `fastvideo/platforms/interface.py` and `fastvideo/platforms/cuda.py`: backend enum and CUDA backend selection wiring.
- `fastvideo/training/training_pipeline.py` and `fastvideo/training/wan_training_pipeline.py`: legacy SFT training path used by the launch script.
- `scripts/training/run_sparse_fp4_train_v4_1n_sparse09_hpo_on_ours_p_init2050_interactive.sh`: exact Slurm wrapper for this run.
- `scripts/training/run_sparse_fp4_train_v4_common.sh`: common SFT launch/resume script.

Source repo HEAD when staged:

`3f818d0fc532ec6494b465967d5f485150917d0c`

Note: several backend files were uncommitted or locally modified when this snapshot was staged, so the files here are the authoritative copy for this checkpoint rather than the clean git commit alone.
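
To check the snapshot against `manifest.sha256` (listed above), a minimal verification sketch; it assumes the manifest uses the conventional `sha256sum` layout of `<hexdigest>  <relative path>` per line, which should be confirmed against the file itself:

```python
import hashlib
from pathlib import Path

def verify_manifest(root: Path, manifest: Path) -> bool:
    """Recompute SHA-256 digests and compare them to the staged manifest."""
    ok = True
    for line in manifest.read_text().splitlines():
        if not line.strip():
            continue
        digest, _, rel = line.partition("  ")  # assumed "<hexdigest>  <path>"
        if hashlib.sha256((root / rel).read_bytes()).hexdigest() != digest:
            print(f"mismatch: {rel}")
            ok = False
    return ok
```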
backend_snapshot/fastvideo-kernel/python/fastvideo_kernel/block_sparse_attn_ours_p.py
ADDED
@@ -0,0 +1,270 @@
from __future__ import annotations

import os

import torch


def _use_high_prec_output_for_backward() -> bool:
    value = os.environ.get("FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O", "1")
    return value.lower() not in ("0", "false", "no", "off")


def _map_to_index(block_map: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    if block_map.dim() == 3:
        block_map = block_map.unsqueeze(0)
    if block_map.dim() != 4:
        raise ValueError(
            f"block_map must be [B,H,Q,KV] or [H,Q,KV], got {tuple(block_map.shape)}"
        )
    if block_map.dtype != torch.bool:
        block_map = block_map.to(torch.bool)
    if not block_map.is_cuda:
        raise RuntimeError("block_map must be a CUDA tensor.")

    try:
        from fastvideo_kernel.triton_kernels.index import map_to_index as triton_map_to_index
    except Exception as e:
        raise ImportError("Triton map_to_index is required for ours-P Sparse FP4.") from e
    return triton_map_to_index(block_map)


@torch.library.custom_op(
    "fastvideo_kernel::block_sparse_attn_ours_p_triton",
    mutates_args=(),
    device_types="cuda",
)
def block_sparse_attn_ours_p_triton(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    block_map = block_map.to(torch.bool)
    q2k_idx, q2k_num = _map_to_index(block_map)

    from fastvideo_kernel.triton_kernels.block_sparse_attn_triton_ours_p import (
        triton_block_sparse_attn_forward,
    )

    return triton_block_sparse_attn_forward(
        q, k, v, q2k_idx, q2k_num, variable_block_sizes, is_qat=True
    )


@torch.library.register_fake("fastvideo_kernel::block_sparse_attn_ours_p_triton")
def _block_sparse_attn_ours_p_triton_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    o = torch.empty_like(q)
    high_prec_o = torch.empty_like(q)
    M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
    return o, M, high_prec_o


@torch.library.custom_op(
    "fastvideo_kernel::block_sparse_attn_ours_p_backward_triton",
    mutates_args=(),
    device_types="cuda",
)
def block_sparse_attn_ours_p_backward_triton(
    grad_output: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    M: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    grad_output = grad_output.contiguous()
    block_map = block_map.to(torch.bool)
    q2k_idx, q2k_num = _map_to_index(block_map)
    k2q_idx, k2q_num = _map_to_index(block_map.transpose(-1, -2).contiguous())

    from fastvideo_kernel.triton_kernels.block_sparse_attn_triton_ours_p import (
        triton_block_sparse_attn_backward,
    )

    return triton_block_sparse_attn_backward(
        grad_output,
        q,
        k,
        v,
        o,
        M,
        q2k_idx,
        q2k_num,
        k2q_idx,
        k2q_num,
        variable_block_sizes,
        is_qat=True,
    )


@torch.library.register_fake(
    "fastvideo_kernel::block_sparse_attn_ours_p_backward_triton"
)
def _block_sparse_attn_ours_p_backward_triton_fake(
    grad_output: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    M: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    return torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)


def _backward_triton(ctx, grad_o, grad_M, grad_high_prec_o):
    q, k, v, o_for_bwd, M, block_map, variable_block_sizes = ctx.saved_tensors
    dq, dk, dv = block_sparse_attn_ours_p_backward_triton(
        grad_o, q, k, v, o_for_bwd, M, block_map, variable_block_sizes
    )
    return dq, dk, dv, None, None


def _setup_context_triton(ctx, inputs, output):
    q, k, v, block_map, variable_block_sizes = inputs
    o, M, high_prec_o = output
    o_for_bwd = high_prec_o if _use_high_prec_output_for_backward() else o
    ctx.save_for_backward(q, k, v, o_for_bwd, M, block_map, variable_block_sizes)


block_sparse_attn_ours_p_triton.register_autograd(
    _backward_triton, setup_context=_setup_context_triton
)


class _BlockSparseAttnOursPTileComp(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, q_mean, k_mean, v_mean, block_map, variable_block_sizes):
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        q_mean = q_mean.contiguous()
        k_mean = k_mean.contiguous()
        v_mean = v_mean.contiguous()
        block_map = block_map.to(torch.bool)
        dropped_block_map = torch.logical_not(block_map)

        q2k_idx, q2k_num = _map_to_index(block_map)
        dropped_q2k_idx, dropped_q2k_num = _map_to_index(dropped_block_map)

        from fastvideo_kernel.triton_kernels.block_sparse_attn_triton_ours_p import (
            triton_block_sparse_attn_forward,
        )

        o, M, high_prec_o = triton_block_sparse_attn_forward(
            q,
            k,
            v,
            q2k_idx,
            q2k_num,
            variable_block_sizes,
            is_qat=True,
            q_mean=q_mean,
            k_mean=k_mean,
            v_mean=v_mean,
            dropped_q2k_index=dropped_q2k_idx,
            dropped_q2k_num=dropped_q2k_num,
        )
        o_for_bwd = high_prec_o if _use_high_prec_output_for_backward() else o
        ctx.save_for_backward(
            q,
            k,
            v,
            q_mean,
            k_mean,
            v_mean,
            o_for_bwd,
            M,
            block_map,
            dropped_block_map,
            variable_block_sizes,
        )
        return o, M

    @staticmethod
    def backward(ctx, grad_o, grad_M):
        (
            q,
            k,
            v,
            q_mean,
            k_mean,
            v_mean,
            o_for_bwd,
            M,
            block_map,
            dropped_block_map,
            variable_block_sizes,
        ) = ctx.saved_tensors

        q2k_idx, q2k_num = _map_to_index(block_map)
        k2q_idx, k2q_num = _map_to_index(block_map.transpose(-1, -2).contiguous())
        dropped_q2k_idx, dropped_q2k_num = _map_to_index(dropped_block_map)
        dropped_k2q_idx, dropped_k2q_num = _map_to_index(
            dropped_block_map.transpose(-1, -2).contiguous()
        )

        from fastvideo_kernel.triton_kernels.block_sparse_attn_triton_ours_p import (
            triton_block_sparse_attn_backward,
        )

        dq, dk, dv = triton_block_sparse_attn_backward(
            grad_o.contiguous(),
            q,
            k,
            v,
            o_for_bwd,
            M,
            q2k_idx,
            q2k_num,
            k2q_idx,
            k2q_num,
            variable_block_sizes,
            is_qat=True,
            q_mean=q_mean,
            k_mean=k_mean,
            v_mean=v_mean,
            dropped_q2k_index=dropped_q2k_idx,
            dropped_q2k_num=dropped_q2k_num,
            dropped_k2q_index=dropped_k2q_idx,
            dropped_k2q_num=dropped_k2q_num,
        )
        return dq, dk, dv, None, None, None, None, None


def block_sparse_attn_ours_p(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_map: torch.Tensor,
    variable_block_sizes: torch.Tensor,
    q_mean: torch.Tensor | None = None,
    k_mean: torch.Tensor | None = None,
    v_mean: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    if (q_mean is not None) or (k_mean is not None) or (v_mean is not None):
        if q_mean is None or k_mean is None or v_mean is None:
            raise ValueError("q_mean, k_mean, and v_mean must be provided together")
        return _BlockSparseAttnOursPTileComp.apply(
            q, k, v, q_mean, k_mean, v_mean, block_map, variable_block_sizes
        )

    o, M, _ = block_sparse_attn_ours_p_triton(
        q, k, v, block_map, variable_block_sizes
    )
    return o, M
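
A hedged usage sketch of the wrapper above, for orientation. Shapes follow the kernel wrapper's assertions (`[B, H, T, D]` with `T` a multiple of 64, one block-map entry and one block size per 64-token tile); the random block map here is an illustrative stand-in for the top-k map that `sparse_fp4_ours_p_attn.py` actually builds:

```python
import torch
from fastvideo_kernel.block_sparse_attn_ours_p import block_sparse_attn_ours_p

B, H, T, D = 1, 12, 4096, 128  # T must be a multiple of 64
n_tiles = T // 64
q = torch.randn(B, H, T, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
# Keep ~10% of KV tiles per Q tile, matching VSA_SPARSITY=0.9 (illustrative).
block_map = torch.rand(B, H, n_tiles, n_tiles, device="cuda") < 0.1
variable_block_sizes = torch.full((n_tiles,), 64, device="cuda",
                                  dtype=torch.int32)

o, lse = block_sparse_attn_ours_p(q, k, v, block_map, variable_block_sizes)
print(o.shape, lse.shape)  # (B, H, T, D) and (B, H, T)
```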
backend_snapshot/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py
ADDED
@@ -0,0 +1,1155 @@
"""
Fused Attention
===============

This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
(https://tridao.me/publications/flash2/flash2.pdf)

Credits: OpenAI kernel team
"""

import torch
import triton
import triton.language as tl
from .quant_utils import fake_quantize

# ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
import math  # small utility needed by the sparse wrapper
# ──────────────────────────── SPARSE ADDITION END ─────────────────────────────

# We don't run auto-tuning every time to keep the tutorial fast. Keeping
# the code below and commenting out the equivalent parameters is convenient for
# re-tuning.
configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
    for BM in [64]\
    for BN in [64]\
    for s in [3, 4, 7]\
    for w in [4, 8]\
]

# ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
@triton.autotune(configs, key=["N_CTX_Q", "HEAD_DIM"])
@triton.jit
def _attn_fwd_sparse(
        Q,
        K,
        V,
        QMean,
        KMean,
        VMean,
        sm_scale,  #
        q2k_index,
        q2k_num,
        max_kv_blks,  #
        dropped_q2k_index,
        dropped_q2k_num,
        max_dropped_kv_blks,  #
        variable_block_sizes,
        M,
        Out,  #
        HighPrecOut,  #
        stride_qz,
        stride_qh,
        stride_qm,
        stride_qk,
        stride_kz,
        stride_kh,
        stride_kn,
        stride_kk,
        stride_vz,
        stride_vh,
        stride_vk,
        stride_vn,
        stride_oz,
        stride_oh,
        stride_om,
        stride_on,
        Z,
        H,
        N_CTX_Q,  #
        N_CTX_KV,  #
        HEAD_DIM: tl.constexpr,  #
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        STAGE: tl.constexpr,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    """
    64x64 block-sparse forward kernel for the independent "ours P quant" path.

    P quantization is group-local: each selected KV tile quantizes
    exp2(logit - tile_row_max), then applies exp2(tile_row_max - online_max)
    after the FP4 PV GEMM. This intentionally differs from the QAT-style
    backend, which quantizes exp2(logit - online_max) directly.
    """

    # ----- program-id mapping -----
    q_blk = tl.program_id(0)  # Q-tile index
    off_hz = tl.program_id(1)  # fused (batch, head)
    b = off_hz // H
    h = off_hz % H
    q_tiles = N_CTX_Q // BLOCK_M
    meta_base = ((b * H + h) * q_tiles + q_blk)

    kv_blocks = tl.load(q2k_num + meta_base)  # int32
    kv_ptr = q2k_index + meta_base * max_kv_blks  # ptr to list
    dropped_kv_blocks = tl.load(dropped_q2k_num + meta_base)
    dropped_kv_ptr = dropped_q2k_index + meta_base * max_dropped_kv_blks

    # ----- base pointers -----
    q_off = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
    k_off = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
    v_off = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
    o_off = (b.to(tl.int64) * stride_oz + h.to(tl.int64) * stride_oh)

    Q_ptr = tl.make_block_ptr(base=Q + q_off,
                              shape=(N_CTX_Q, HEAD_DIM),
                              strides=(stride_qm, stride_qk),
                              offsets=(q_blk * BLOCK_M, 0),
                              block_shape=(BLOCK_M, HEAD_DIM),
                              order=(1, 0))

    K_base = tl.make_block_ptr(base=K + k_off,
                               shape=(HEAD_DIM, N_CTX_KV),
                               strides=(stride_kk, stride_kn),
                               offsets=(0, 0),
                               block_shape=(HEAD_DIM, BLOCK_N),
                               order=(0, 1))

    v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
    V_base = tl.make_block_ptr(base=V + v_off,
                               shape=(N_CTX_KV, HEAD_DIM),
                               strides=(stride_vk, stride_vn),
                               offsets=(0, 0),
                               block_shape=(BLOCK_N, HEAD_DIM),
                               order=v_order)

    O_ptr = tl.make_block_ptr(base=Out + o_off,
                              shape=(N_CTX_Q, HEAD_DIM),
                              strides=(stride_om, stride_on),
                              offsets=(q_blk * BLOCK_M, 0),
                              block_shape=(BLOCK_M, HEAD_DIM),
                              order=(1, 0))
    HPO_ptr = tl.make_block_ptr(base=HighPrecOut + o_off,
                                shape=(N_CTX_Q, HEAD_DIM),
                                strides=(stride_om, stride_on),
                                offsets=(q_blk * BLOCK_M, 0),
                                block_shape=(BLOCK_M, HEAD_DIM),
                                order=(1, 0))

    # ----- accumulators -----
    offs_m = q_blk * BLOCK_M + tl.arange(0, BLOCK_M)
    m_i = tl.full([BLOCK_M], -float("inf"), tl.float32)
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    high_prec_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    qk_scale = sm_scale * 1.44269504  # 1/ln2
    q = tl.load(Q_ptr)
    offs_d = tl.arange(0, HEAD_DIM)

    # ----- sparse loop over valid K/V tiles -----
    for i in range(0, kv_blocks):
        kv_idx = tl.load(kv_ptr + i).to(tl.int32)
        block_size = tl.load(variable_block_sizes + kv_idx)
        K_ptr = tl.advance(K_base, (0, kv_idx * BLOCK_N))
        V_ptr = tl.advance(V_base, (kv_idx * BLOCK_N, 0))

        k = tl.load(K_ptr)
        mask = tl.arange(0, BLOCK_N) < block_size
        qk = tl.dot(q, k) * qk_scale
        # mask out invalid columns
        qk = tl.where(mask[None, :], qk, -float("inf"))
        group_m = tl.max(qk, 1)
        m_ij = tl.maximum(m_i, group_m)

        p_local = tl.math.exp2(qk - group_m[:, None])
        p_local = tl.where(mask[None, :], p_local, 0.0)
        p_comp = tl.math.exp2(group_m - m_ij)
        p_valid = mask[None, :] & (
            tl.full(shape=p_local.shape, value=1.0,
                    dtype=p_local.dtype) == 1.0
        )
        p_quant, high_prec_p = fake_quantize(
            src_tensor=p_local, valid_src_mask=p_valid,
            BLOCK_SIZE_OUT_DIM=BLOCK_M, BLOCK_SIZE_QUANT_DIM=BLOCK_N,
            dst_dtype=tl.bfloat16, use_global_sf=False,
        )
        l_ij = tl.sum(high_prec_p, 1) * p_comp

        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        acc = acc * alpha[:, None]
        high_prec_acc = high_prec_acc * alpha[:, None]

        v = tl.load(V_ptr)
        acc = acc + tl.dot(
            p_quant.to(tl.bfloat16),
            v.to(tl.bfloat16),
        ) * p_comp[:, None]
        high_prec_acc = high_prec_acc + tl.dot(
            high_prec_p.to(tl.bfloat16),
            v.to(tl.bfloat16),
        ) * p_comp[:, None]
        m_i = m_ij

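    # Mean-based compensation for dropped tiles (loop below): each dropped KV
    # tile is modeled as `block_size` identical tokens sitting at the tile
    # means, so its logits collapse to the single score q_mean . k_mean, which
    # contributes `block_size * exp2(score - m)` of softmax mass and
    # `block_size * beta * v_mean` to the output accumulators.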
    if USE_TILE_COMP:
        q_mean_base = (off_hz * q_tiles + q_blk).to(tl.int64) * HEAD_DIM
        q_mean = tl.load(QMean + q_mean_base + offs_d).to(tl.float32)
        kv_tiles = N_CTX_KV // BLOCK_N

        for i in range(0, dropped_kv_blocks):
            kv_idx = tl.load(dropped_kv_ptr + i).to(tl.int32)
            block_size = tl.load(variable_block_sizes + kv_idx).to(tl.float32)
            kv_mean_base = (off_hz * kv_tiles + kv_idx).to(tl.int64) * HEAD_DIM
            k_mean = tl.load(KMean + kv_mean_base + offs_d).to(tl.float32)
            v_mean = tl.load(VMean + kv_mean_base + offs_d).to(tl.float32)

            score = tl.sum(q_mean * k_mean, axis=0) * qk_scale
            m_ij = tl.maximum(m_i, score)
            alpha = tl.math.exp2(m_i - m_ij)
            beta = tl.math.exp2(score - m_ij)

            l_i = l_i * alpha + block_size * beta
            comp = (block_size * beta)[:, None] * v_mean[None, :]
            acc = acc * alpha[:, None] + comp
            high_prec_acc = high_prec_acc * alpha[:, None] + comp
            m_i = m_ij

    # ----- epilogue -----
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    high_prec_acc = high_prec_acc / l_i[:, None]
    tl.store(M + off_hz * N_CTX_Q + offs_m, m_i)
    tl.store(O_ptr, acc.to(Out.type.element_ty))
    tl.store(HPO_ptr, high_prec_acc.to(HighPrecOut.type.element_ty))


# ──────────────────────────── SPARSE ADDITION END ─────────────────────────────

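# The rescaling in the forward kernel above is exact in exact arithmetic:
#   exp2(logit - tile_max) * exp2(tile_max - online_max)
#       == exp2(logit - online_max)
# so quantizing the tile-local probabilities and multiplying the PV result by
# the scalar correction exp2(tile_max - online_max) reproduces the standard
# online-softmax update, up to the quantization error of P itself.
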
@triton.jit
def _attn_bwd_preprocess(
        O,
        DO,  #
        Delta,  #
        Z,
        H,
        N_CTX,  #
        BLOCK_M: tl.constexpr,
        HEAD_DIM: tl.constexpr  #
):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_hz = tl.program_id(1)
    off_n = tl.arange(0, HEAD_DIM)
    # load
    o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM +
                off_n[None, :])
    do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM +
                 off_n[None, :]).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_hz * N_CTX + off_m, delta)


# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(
        dk,
        dv,  #
        Q,
        k,
        v,
        QMean,
        KMean,
        VMean,
        sm_scale,  #
        DO,  #
        M,
        D,  #
        k2q_index,
        k2q_num,
        max_q_blks,
        dropped_k2q_index,
        dropped_k2q_num,
        max_dropped_q_blks,
        variable_block_sizes,
        # shared by Q/K/V/DO.
        stride_tok,
        stride_d,  #
        H,
        N_CTX_KV,
        BLOCK_M1: tl.constexpr,  #
        BLOCK_N1: tl.constexpr,  #
        HEAD_DIM: tl.constexpr,  #
        # Filled in by the wrapper.
        start_n,
        start_m,
        num_steps,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    offs_m = start_m + tl.arange(0, BLOCK_M1)
    offs_n = start_n + tl.arange(0, BLOCK_N1)
    offs_k = tl.arange(0, HEAD_DIM)
    qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
    do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    step_m = BLOCK_M1
    kv_blk = tl.program_id(0)  # KV-tile index
    off_hz = tl.program_id(2)  # fused (batch, head)
    b = off_hz // H
    h = off_hz % H
    kv_tiles = N_CTX_KV // BLOCK_N1
    meta_base = ((b * H + h) * kv_tiles + kv_blk)

    q_blocks = tl.load(k2q_num + meta_base)  # int32
    q_ptr = k2q_index + meta_base * max_q_blks  # ptr to list
    dropped_q_blocks = tl.load(dropped_k2q_num + meta_base)
    dropped_q_ptr = dropped_k2q_index + meta_base * max_dropped_q_blks
    block_size = tl.load(variable_block_sizes + kv_blk)
    block_size_f = block_size.to(tl.float32)

    for blk_idx in range(q_blocks * 2):
        block_sparse_offset = (tl.load(q_ptr + blk_idx // 2).to(tl.int32) * 2 +
                               blk_idx % 2) * step_m
        qT = tl.load(qT_ptrs + block_sparse_offset * stride_tok)
        # Load m before computing qk to reduce pipeline stall.
        offs_m = start_m + block_sparse_offset + tl.arange(0, BLOCK_M1)
        m = tl.load(M + offs_m)
        qkT = tl.dot(k.to(tl.bfloat16), qT)
        qkT = qkT * sm_scale * 1.44269504
        mask = tl.arange(0, BLOCK_N1) < block_size
        qkT = tl.where(mask[:, None], qkT, -float("inf"))
        group_m = tl.max(qkT, 0)
        pT = tl.math.exp2(qkT - m[None, :])
        pT = tl.where(mask[:, None], pT, 0.0)

        do = tl.load(do_ptrs + block_sparse_offset * stride_tok)
        # Compute dV with group-local P quantization:
        # quantize exp2(logit - tile_col_max), then multiply dO by
        # exp2(tile_col_max - final_lse) to recover the final softmax scale.
        p_local_T = tl.math.exp2(qkT - group_m[None, :])
        p_local_T = tl.where(mask[:, None], p_local_T, 0.0)
        p_comp = tl.math.exp2(group_m - m)
        p_for_quant = tl.trans(p_local_T)
        p_valid = mask[None, :] & (
            tl.full(
                shape=p_for_quant.shape,
                value=1.0,
                dtype=p_for_quant.dtype,
            ) == 1.0
        )
        p_quant, _ = fake_quantize(
            src_tensor=p_for_quant, valid_src_mask=p_valid,
            BLOCK_SIZE_OUT_DIM=BLOCK_M1, BLOCK_SIZE_QUANT_DIM=BLOCK_N1,
            dst_dtype=p_for_quant.dtype, use_global_sf=False,
        )
        dv += tl.dot(
            tl.trans(p_quant.to(tl.bfloat16)),
            (do * p_comp[:, None]).to(tl.bfloat16),
        )
        # D (= delta) is pre-divided by ds_scale.
        Di = tl.load(D + offs_m)
        # Compute dP and dS.
        dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
        dsT = pT * (dpT - Di[None, :])
        dsT = dsT.to(tl.bfloat16)
        dk += tl.dot(dsT, tl.trans(qT))
        # Increment pointers.

    if USE_TILE_COMP:
        k_mean = tl.load(KMean + kv_blk * HEAD_DIM + offs_k).to(tl.float32)
        v_mean = tl.load(VMean + kv_blk * HEAD_DIM + offs_k).to(tl.float32)
        qk_scale = sm_scale * 1.44269504

        for blk_idx in range(dropped_q_blocks * 2):
            q_blk_idx = tl.load(dropped_q_ptr + blk_idx // 2).to(tl.int32)
            half = (blk_idx % 2).to(tl.int32)
            block_sparse_offset = (q_blk_idx * 2 + half) * step_m
            offs_m = start_m + block_sparse_offset + tl.arange(0, BLOCK_M1)
            q_mean = tl.load(QMean + q_blk_idx * HEAD_DIM + offs_k).to(tl.float32)
            m = tl.load(M + offs_m)
            do = tl.load(do_ptrs + block_sparse_offset * stride_tok)
            Di = tl.load(D + offs_m)
            q_block_size = tl.load(variable_block_sizes + q_blk_idx).to(tl.float32)

            score = tl.sum(q_mean * k_mean, axis=0) * qk_scale
            p = tl.math.exp2(score - m)
            dp = tl.sum(do.to(tl.float32) * v_mean[None, :], axis=1)
            ds = block_size_f * p * (dp - Di)

            dk_mean = tl.sum(ds[:, None] * q_mean[None, :], axis=0) / block_size_f
            dv_mean = tl.sum(p[:, None] * do.to(tl.float32), axis=0)
            dk += dk_mean[None, :]
            dv += dv_mean[None, :]
    return dk, dv

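# The sparse backward loops traverse each selected 64-token tile in two
# 32-token halves (BLOCK_M1 = 32 in the dK/dV loop above, BLOCK_N2 = 32 in the
# dQ loop below): `blk_idx // 2` picks the 64-token tile from the index list
# and `blk_idx % 2` picks the half, so `tile * 2 + half` is the offset in
# 32-token steps.
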
# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(
        dq,
        q,
        K,
        V,  #
        QMean,
        KMean,
        VMean,
        do,
        m,
        m_vec,
        D,
        # shared by Q/K/V/DO.
        q2k_index,
        q2k_num,
        max_kv_blks,
        dropped_q2k_index,
        dropped_q2k_num,
        max_dropped_kv_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,  #
        H,
        N_CTX,  #
        BLOCK_M2: tl.constexpr,  #
        BLOCK_N2: tl.constexpr,  #
        HEAD_DIM: tl.constexpr,
        # Filled in by the wrapper.
        start_m,
        start_n,
        num_steps,
        sm_scale=1.0,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    offs_m = start_m + tl.arange(0, BLOCK_M2)
    offs_n = start_n + tl.arange(0, BLOCK_N2)
    offs_k = tl.arange(0, HEAD_DIM)
    kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    # D (= delta) is pre-divided by ds_scale.
    Di = tl.load(D + offs_m)
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    step_n = BLOCK_N2

    q_blk = tl.program_id(0)  # Q-tile index
    off_hz = tl.program_id(2)  # fused (batch, head)
    b = off_hz // H
    h = off_hz % H
    q_tiles = N_CTX // BLOCK_M2
    meta_base = ((b * H + h) * q_tiles + q_blk)

    kv_blocks = tl.load(q2k_num + meta_base)  # int32
    kv_ptr = q2k_index + meta_base * max_kv_blks  # ptr to list
    dropped_kv_blocks = tl.load(dropped_q2k_num + meta_base)
    dropped_kv_ptr = dropped_q2k_index + meta_base * max_dropped_kv_blks

    for blk_idx in range(kv_blocks * 2):
        kv_idx = tl.load(kv_ptr + blk_idx // 2).to(tl.int32)
        # variable_block_sizes is defined per KV block (tile). Mask must therefore
        # use kv_idx (not q_blk). Also, because we split each 64-token block into
        # two 32-token halves, the mask must account for the half-block offset.
        block_size = tl.load(variable_block_sizes + kv_idx).to(tl.int32)
        half = (blk_idx % 2).to(tl.int32)
        block_sparse_offset = (kv_idx * 2 + half) * step_n * stride_tok
        kT = tl.load(kT_ptrs + block_sparse_offset)
        vT = tl.load(vT_ptrs + block_sparse_offset)
        qk = tl.dot(q, kT)
        qk = qk * sm_scale * 1.44269504
        p = tl.math.exp2(qk - m)
        offs_in_block = half * step_n + tl.arange(0, BLOCK_N2)
        mask = offs_in_block < block_size
        p = tl.where(mask[None, :], p, 0.0)
        # Compute dP and dS.
        dp = tl.dot(do, vT).to(tl.float32)
        ds = p * (dp - Di[:, None])
        ds = ds.to(tl.bfloat16)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.

    if USE_TILE_COMP:
        q_mean = tl.load(QMean + q_blk * HEAD_DIM + offs_k).to(tl.float32)
        q_block_size = tl.load(variable_block_sizes + q_blk).to(tl.float32)
        qk_scale = sm_scale * 1.44269504
        dq_mean = tl.zeros([HEAD_DIM], dtype=tl.float32)

        for blk_idx in range(dropped_kv_blocks):
            kv_idx = tl.load(dropped_kv_ptr + blk_idx).to(tl.int32)
            block_size = tl.load(variable_block_sizes + kv_idx).to(tl.float32)
            k_mean = tl.load(KMean + kv_idx * HEAD_DIM + offs_k).to(tl.float32)
            v_mean = tl.load(VMean + kv_idx * HEAD_DIM + offs_k).to(tl.float32)

            score = tl.sum(q_mean * k_mean, axis=0) * qk_scale
            p = tl.math.exp2(score - m_vec)
            dp = tl.sum(do.to(tl.float32) * v_mean[None, :], axis=1)
            ds = block_size * p * (dp - Di)
            dq_mean = dq_mean + tl.sum(ds, axis=0) * k_mean

        dq += dq_mean[None, :] / q_block_size
    return dq


@triton.jit
def _attn_bwd(
        Q,
        K,
        V,
        sm_scale,  #
        DO,  #
        DQ,
        DK,
        DV,  #
        M,
        D,
        q2k_index,
        q2k_num,
        max_kv_blks,
        k2q_index,
        k2q_num,
        max_q_blks,
        variable_block_sizes,
        # shared by Q/K/V/DO.
        stride_z,
        stride_h,
        stride_tok,
        stride_d,  #
        H,
        N_CTX,  #
        BLOCK_M1: tl.constexpr,  #
        BLOCK_N1: tl.constexpr,  #
        BLOCK_M2: tl.constexpr,  #
        BLOCK_N2: tl.constexpr,  #
        HEAD_DIM: tl.constexpr,
        IS_QAT: tl.constexpr = False):
    LN2 = 0.6931471824645996  # = ln(2)

    bhid = tl.program_id(2)
    off_chz = (bhid * N_CTX).to(tl.int64)
    adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
    pid = tl.program_id(0)

    # offset pointers for batch/head
    Q += adj
    K += adj
    V += adj
    DO += adj
    DQ += adj
    DK += adj
    DV += adj
    M += off_chz
    D += off_chz

    # load scales
    offs_k = tl.arange(0, HEAD_DIM)

    start_n = pid * BLOCK_N1
    start_m = 0

    offs_n = start_n + tl.arange(0, BLOCK_N1)

    dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    # load K and V: they stay in SRAM throughout the inner loop.
    k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
    v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)

    num_steps = N_CTX // BLOCK_M1

    dk, dv = _attn_bwd_dkdv(  #
        dk,
        dv,  #
        Q,
        k,
        v,
        Q,  # QMean/KMean/VMean are unused placeholders when USE_TILE_COMP=False
        K,
        V,
        sm_scale,  #
        DO,  #
        M,
        D,  #
        k2q_index,
        k2q_num,
        max_q_blks,
        k2q_index,  # dropped_* lists are likewise unused in this legacy path
        k2q_num,
        max_q_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,  #
        H,
        N_CTX,  #
        BLOCK_M1,
        BLOCK_N1,
        HEAD_DIM,  #
        start_n,
        start_m,
        num_steps,  #
        IS_QAT=IS_QAT,
        USE_TILE_COMP=False,
    )

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv)

    # Write back dK.
    dk *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk)

    # THIS BLOCK DOES DQ:
    start_m = pid * BLOCK_M2
    end_n = 0

    offs_m = start_m + tl.arange(0, BLOCK_M2)

    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)

    m_vec = tl.load(M + offs_m)
    m = m_vec[:, None]

    num_steps = N_CTX // BLOCK_N2
    dq = _attn_bwd_dq(
        dq,
        q,
        K,
        V,  #
        Q,  # QMean/KMean/VMean placeholders; USE_TILE_COMP=False below
        K,
        V,
        do,
        m,
        m_vec,
        D,  #
        q2k_index,
        q2k_num,
        max_kv_blks,
        q2k_index,
        q2k_num,
        max_kv_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,  #
        H,
        N_CTX,  #
        BLOCK_M2,
        BLOCK_N2,
        HEAD_DIM,  #
        start_m,
        end_n,
        num_steps,  #
        sm_scale=sm_scale,
        IS_QAT=IS_QAT,
        USE_TILE_COMP=False,
    )
    # Write back dQ.
    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq *= sm_scale
    tl.store(dq_ptrs, dq)

@triton.jit
def _attn_bwd_dkdv_kernel(
        Q,
        K,
        V,
        QMean,
        KMean,
        VMean,
        sm_scale,  #
        DO,  #
        DK,
        DV,  #
        M,
        D,
        k2q_index,
        k2q_num,
        max_q_blks,
        dropped_k2q_index,
        dropped_k2q_num,
        max_dropped_q_blks,
        variable_block_sizes,
        # shared token/dim strides (assumed contiguous along token and dim)
        stride_tok,
        stride_d,  #
        # batch/head strides (may differ between Q and KV)
        stride_qz,
        stride_qh,
        stride_kz,
        stride_kh,
        stride_vz,
        stride_vh,
        stride_doz,
        stride_doh,
        stride_dkz,
        stride_dkh,
        stride_dvz,
        stride_dvh,
        H,
        N_CTX_Q,
        N_CTX_KV,
        BLOCK_M1: tl.constexpr,  #
        BLOCK_N1: tl.constexpr,  #
        HEAD_DIM: tl.constexpr,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    """
    Backward kernel that computes dK and dV for each KV block (64 tokens).
    Grid:
      pid0: kv_blk in [0, N_CTX_KV/BLOCK_N1)
      pid2: fused (batch, head) in [0, B*H)
    """
    bhid = tl.program_id(2)
    b = bhid // H
    h = bhid % H
    kv_blk = tl.program_id(0)

    q_adj = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
    kv_adj_k = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
    kv_adj_v = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
    do_adj = (b.to(tl.int64) * stride_doz + h.to(tl.int64) * stride_doh)
    dk_adj = (b.to(tl.int64) * stride_dkz + h.to(tl.int64) * stride_dkh)
    dv_adj = (b.to(tl.int64) * stride_dvz + h.to(tl.int64) * stride_dvh)

    Q = Q + q_adj
    K = K + kv_adj_k
    V = V + kv_adj_v
    DO = DO + do_adj
    DK = DK + dk_adj
    DV = DV + dv_adj

    q_tiles = N_CTX_Q // BLOCK_M1 // 2
    kv_tiles = N_CTX_KV // BLOCK_N1
    mean_q_adj = (bhid * q_tiles * HEAD_DIM).to(tl.int64)
    mean_kv_adj = (bhid * kv_tiles * HEAD_DIM).to(tl.int64)
    QMean = QMean + mean_q_adj
    KMean = KMean + mean_kv_adj
    VMean = VMean + mean_kv_adj

    # M and D (delta) are always sized by Q length.
    M = M + (bhid * N_CTX_Q).to(tl.int64)
    D = D + (bhid * N_CTX_Q).to(tl.int64)

    offs_k = tl.arange(0, HEAD_DIM)
    start_n = kv_blk * BLOCK_N1
    offs_n = start_n + tl.arange(0, BLOCK_N1)

    # load K and V: they stay in SRAM throughout the inner loop.
    k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
    v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)

    dv_acc = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
    dk_acc = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)

    num_steps = N_CTX_Q // BLOCK_M1
    dk_acc, dv_acc = _attn_bwd_dkdv(
        dk_acc,
        dv_acc,
        Q,
        k,
        v,
        QMean,
        KMean,
        VMean,
        sm_scale,
        DO,
        M,
        D,
        k2q_index,
        k2q_num,
        max_q_blks,
        dropped_k2q_index,
        dropped_k2q_num,
        max_dropped_q_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,
        H,
        N_CTX_KV,
        BLOCK_M1=BLOCK_M1,
        BLOCK_N1=BLOCK_N1,
        HEAD_DIM=HEAD_DIM,
        start_n=start_n,
        start_m=0,
        num_steps=num_steps,
        IS_QAT=IS_QAT,
        USE_TILE_COMP=USE_TILE_COMP,
    )

    dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dv_ptrs, dv_acc)

    dk_acc *= sm_scale
    dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    tl.store(dk_ptrs, dk_acc)


@triton.jit
def _attn_bwd_dq_kernel(
        Q,
        K,
        V,
        QMean,
        KMean,
        VMean,
        DO,  #
        DQ,
        M,
        D,
        q2k_index,
        q2k_num,
        max_kv_blks,
        dropped_q2k_index,
        dropped_q2k_num,
        max_dropped_kv_blks,
        variable_block_sizes,
        # shared token/dim strides (assumed contiguous along token and dim)
        stride_tok,
        stride_d,  #
        # batch/head strides (may differ between Q and KV)
        stride_qz,
        stride_qh,
        stride_kz,
        stride_kh,
        stride_vz,
        stride_vh,
        stride_doz,
        stride_doh,
        stride_dqz,
        stride_dqh,
        H,
        N_CTX_Q,
        sm_scale,
        BLOCK_M2: tl.constexpr,  #
        BLOCK_N2: tl.constexpr,  #
        HEAD_DIM: tl.constexpr,
        IS_QAT: tl.constexpr = False,
        USE_TILE_COMP: tl.constexpr = False):
    """
    Backward kernel that computes dQ for each Q block (64 tokens).
    Grid:
      pid0: q_blk in [0, N_CTX_Q/BLOCK_M2)
      pid2: fused (batch, head) in [0, B*H)
    """
    LN2 = 0.6931471824645996  # = ln(2)
    bhid = tl.program_id(2)
    b = bhid // H
    h = bhid % H
    q_blk = tl.program_id(0)

    q_adj = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
    kv_adj_k = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
    kv_adj_v = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
    do_adj = (b.to(tl.int64) * stride_doz + h.to(tl.int64) * stride_doh)
    dq_adj = (b.to(tl.int64) * stride_dqz + h.to(tl.int64) * stride_dqh)

    Q = Q + q_adj
    K = K + kv_adj_k
    V = V + kv_adj_v
    DO = DO + do_adj
    DQ = DQ + dq_adj

    q_tiles = N_CTX_Q // BLOCK_M2
    kv_tiles = N_CTX_Q // 64
    mean_q_adj = (bhid * q_tiles * HEAD_DIM).to(tl.int64)
    mean_kv_adj = (bhid * kv_tiles * HEAD_DIM).to(tl.int64)
    QMean = QMean + mean_q_adj
    KMean = KMean + mean_kv_adj
    VMean = VMean + mean_kv_adj

    M = M + (bhid * N_CTX_Q).to(tl.int64)
    D = D + (bhid * N_CTX_Q).to(tl.int64)

    offs_k = tl.arange(0, HEAD_DIM)
    start_m = q_blk * BLOCK_M2
    offs_m = start_m + tl.arange(0, BLOCK_M2)

    q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
    m_vec = tl.load(M + offs_m)
    m = m_vec[:, None]

    dq_acc = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
    num_steps = 0  # unused in _attn_bwd_dq
    dq_acc = _attn_bwd_dq(
        dq_acc,
        q,
        K,
        V,
        QMean,
        KMean,
        VMean,
        do,
        m,
        m_vec,
        D,
        q2k_index,
        q2k_num,
        max_kv_blks,
        dropped_q2k_index,
        dropped_q2k_num,
        max_dropped_kv_blks,
        variable_block_sizes,
        stride_tok,
        stride_d,
        H,
        N_CTX_Q,
        BLOCK_M2=BLOCK_M2,
        BLOCK_N2=BLOCK_N2,
        HEAD_DIM=HEAD_DIM,
        start_m=start_m,
        start_n=0,
        num_steps=num_steps,
        sm_scale=sm_scale,
        IS_QAT=IS_QAT,
        USE_TILE_COMP=USE_TILE_COMP,
    )

    dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    dq_acc *= sm_scale
    tl.store(dq_ptrs, dq_acc)

# ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
def triton_block_sparse_attn_forward(q, k, v, q2k_index, q2k_num,
                                     variable_block_sizes, is_qat=False,
                                     q_mean=None, k_mean=None, v_mean=None,
                                     dropped_q2k_index=None,
                                     dropped_q2k_num=None):
    B, H, Tq, D = q.shape
    Tkv = k.shape[2]
    sm_scale = 1.0 / math.sqrt(D)
    max_kv_blks = q2k_index.shape[-1]
    use_tile_comp = q_mean is not None
    if use_tile_comp:
        assert k_mean is not None and v_mean is not None
        assert dropped_q2k_index is not None and dropped_q2k_num is not None
        q_mean = q_mean.contiguous()
        k_mean = k_mean.contiguous()
        v_mean = v_mean.contiguous()
        max_dropped_kv_blks = dropped_q2k_index.shape[-1]
    else:
        q_mean = q
        k_mean = k
        v_mean = v
        dropped_q2k_index = q2k_index
        dropped_q2k_num = q2k_num
        max_dropped_kv_blks = max_kv_blks
    assert Tq % 64 == 0, f"q length must be a multiple of 64, but got {Tq}"
    assert Tkv % 64 == 0, f"kv length must be a multiple of 64, but got {Tkv}"
    assert q2k_num.shape[-1] == Tq // 64, (
        f"shape mismatch, Tq // 64 = {Tq // 64}, "
        f"q2k_num.shape[-1] = {q2k_num.shape[-1]}"
    )
    assert variable_block_sizes.numel() == Tkv // 64, (
        f"shape mismatch, variable_block_sizes must have length {Tkv // 64}, "
        f"got {variable_block_sizes.numel()}"
    )
    o = torch.empty_like(q)
    high_prec_o = torch.empty_like(q)
    M = torch.empty((B, H, Tq), dtype=torch.float32, device=q.device)

    grid = lambda _: (triton.cdiv(Tq, 64), B * H, 1)
    _attn_fwd_sparse[grid](q,
                           k,
                           v,
                           q_mean,
                           k_mean,
                           v_mean,
                           sm_scale,
                           q2k_index,
                           q2k_num,
                           max_kv_blks,
                           dropped_q2k_index,
                           dropped_q2k_num,
                           max_dropped_kv_blks,
                           variable_block_sizes,
                           M,
                           o,
                           high_prec_o,
                           q.stride(0),
                           q.stride(1),
                           q.stride(2),
                           q.stride(3),
                           k.stride(0),
                           k.stride(1),
                           k.stride(2),
                           k.stride(3),
                           v.stride(0),
                           v.stride(1),
                           v.stride(2),
                           v.stride(3),
                           o.stride(0),
                           o.stride(1),
                           o.stride(2),
                           o.stride(3),
                           B,
                           H,
                           Tq,
                           Tkv,
                           HEAD_DIM=D,
                           STAGE=3,
                           IS_QAT=is_qat,
                           USE_TILE_COMP=use_tile_comp)

    return o, M, high_prec_o


def triton_block_sparse_attn_backward(do, q, k, v, o, M, q2k_index, q2k_num,
                                      k2q_index, k2q_num, variable_block_sizes,
                                      is_qat=False, q_mean=None, k_mean=None,
                                      v_mean=None, dropped_q2k_index=None,
                                      dropped_q2k_num=None,
                                      dropped_k2q_index=None,
                                      dropped_k2q_num=None):
    assert do.is_contiguous()

    B, H, Tq, D = q.shape
    Tkv = k.shape[2]
    sm_scale = 1.0 / math.sqrt(D)
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    BATCH, N_HEAD = q.shape[:2]
    BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
    RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
    # Ours-P mode keeps K unscaled and applies sm_scale inside the bwd kernels.
    arg_k = k
    PRE_BLOCK = 64
    assert Tq % PRE_BLOCK == 0
    pre_grid = (Tq // PRE_BLOCK, BATCH * N_HEAD)
    delta = torch.empty_like(M)
    _attn_bwd_preprocess[pre_grid](
        o,
        do,  #
        delta,  #
        BATCH,
        N_HEAD,
        Tq,  #
        BLOCK_M=PRE_BLOCK,
        HEAD_DIM=D  #
    )

    max_q_blks = k2q_index.shape[-1]
    max_kv_blks = q2k_index.shape[-1]
    use_tile_comp = q_mean is not None
    if use_tile_comp:
        assert k_mean is not None and v_mean is not None
        assert dropped_q2k_index is not None and dropped_q2k_num is not None
|
| 1050 |
+
assert dropped_k2q_index is not None and dropped_k2q_num is not None
|
| 1051 |
+
q_mean = q_mean.contiguous()
|
| 1052 |
+
k_mean = k_mean.contiguous()
|
| 1053 |
+
v_mean = v_mean.contiguous()
|
| 1054 |
+
max_dropped_kv_blks = dropped_q2k_index.shape[-1]
|
| 1055 |
+
max_dropped_q_blks = dropped_k2q_index.shape[-1]
|
| 1056 |
+
else:
|
| 1057 |
+
q_mean = q
|
| 1058 |
+
k_mean = k
|
| 1059 |
+
v_mean = v
|
| 1060 |
+
dropped_q2k_index = q2k_index
|
| 1061 |
+
dropped_q2k_num = q2k_num
|
| 1062 |
+
dropped_k2q_index = k2q_index
|
| 1063 |
+
dropped_k2q_num = k2q_num
|
| 1064 |
+
max_dropped_kv_blks = max_kv_blks
|
| 1065 |
+
max_dropped_q_blks = max_q_blks
|
| 1066 |
+
|
| 1067 |
+
# dK/dV kernel: grid over KV blocks
|
| 1068 |
+
grid_kv = (Tkv // BLOCK_N1, 1, BATCH * N_HEAD)
|
| 1069 |
+
_attn_bwd_dkdv_kernel[grid_kv](
|
| 1070 |
+
q,
|
| 1071 |
+
arg_k,
|
| 1072 |
+
v,
|
| 1073 |
+
q_mean,
|
| 1074 |
+
k_mean,
|
| 1075 |
+
v_mean,
|
| 1076 |
+
sm_scale,
|
| 1077 |
+
do,
|
| 1078 |
+
dk,
|
| 1079 |
+
dv,
|
| 1080 |
+
M,
|
| 1081 |
+
delta,
|
| 1082 |
+
k2q_index,
|
| 1083 |
+
k2q_num,
|
| 1084 |
+
max_q_blks,
|
| 1085 |
+
dropped_k2q_index,
|
| 1086 |
+
dropped_k2q_num,
|
| 1087 |
+
max_dropped_q_blks,
|
| 1088 |
+
variable_block_sizes,
|
| 1089 |
+
q.stride(2),
|
| 1090 |
+
q.stride(3),
|
| 1091 |
+
q.stride(0),
|
| 1092 |
+
q.stride(1),
|
| 1093 |
+
arg_k.stride(0),
|
| 1094 |
+
arg_k.stride(1),
|
| 1095 |
+
v.stride(0),
|
| 1096 |
+
v.stride(1),
|
| 1097 |
+
do.stride(0),
|
| 1098 |
+
do.stride(1),
|
| 1099 |
+
dk.stride(0),
|
| 1100 |
+
dk.stride(1),
|
| 1101 |
+
dv.stride(0),
|
| 1102 |
+
dv.stride(1),
|
| 1103 |
+
N_HEAD,
|
| 1104 |
+
Tq,
|
| 1105 |
+
Tkv,
|
| 1106 |
+
BLOCK_M1=BLOCK_M1,
|
| 1107 |
+
BLOCK_N1=BLOCK_N1,
|
| 1108 |
+
HEAD_DIM=D,
|
| 1109 |
+
IS_QAT=is_qat,
|
| 1110 |
+
USE_TILE_COMP=use_tile_comp,
|
| 1111 |
+
)
|
| 1112 |
+
|
| 1113 |
+
# dQ kernel: grid over Q blocks
|
| 1114 |
+
grid_q = (Tq // BLOCK_M2, 1, BATCH * N_HEAD)
|
| 1115 |
+
_attn_bwd_dq_kernel[grid_q](
|
| 1116 |
+
q,
|
| 1117 |
+
arg_k,
|
| 1118 |
+
v,
|
| 1119 |
+
q_mean,
|
| 1120 |
+
k_mean,
|
| 1121 |
+
v_mean,
|
| 1122 |
+
do,
|
| 1123 |
+
dq,
|
| 1124 |
+
M,
|
| 1125 |
+
delta,
|
| 1126 |
+
q2k_index,
|
| 1127 |
+
q2k_num,
|
| 1128 |
+
max_kv_blks,
|
| 1129 |
+
dropped_q2k_index,
|
| 1130 |
+
dropped_q2k_num,
|
| 1131 |
+
max_dropped_kv_blks,
|
| 1132 |
+
variable_block_sizes,
|
| 1133 |
+
q.stride(2),
|
| 1134 |
+
q.stride(3),
|
| 1135 |
+
q.stride(0),
|
| 1136 |
+
q.stride(1),
|
| 1137 |
+
arg_k.stride(0),
|
| 1138 |
+
arg_k.stride(1),
|
| 1139 |
+
v.stride(0),
|
| 1140 |
+
v.stride(1),
|
| 1141 |
+
do.stride(0),
|
| 1142 |
+
do.stride(1),
|
| 1143 |
+
dq.stride(0),
|
| 1144 |
+
dq.stride(1),
|
| 1145 |
+
N_HEAD,
|
| 1146 |
+
Tq,
|
| 1147 |
+
sm_scale,
|
| 1148 |
+
BLOCK_M2=BLOCK_M2,
|
| 1149 |
+
BLOCK_N2=BLOCK_N2,
|
| 1150 |
+
HEAD_DIM=D,
|
| 1151 |
+
IS_QAT=is_qat,
|
| 1152 |
+
USE_TILE_COMP=use_tile_comp,
|
| 1153 |
+
)
|
| 1154 |
+
|
| 1155 |
+
return dq, dk, dv
|
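For orientation, here is a hypothetical host-side driver for `triton_block_sparse_attn_forward`. The `block_map_to_index` helper is an assumption made up for this sketch (the snapshot's real map-to-index conversion lives in `block_sparse_attn_ours_p.py`); what is grounded in the asserts above is the layout: sequence lengths divisible by 64, one row of KV-block indices per 64-token query tile, and a `q2k_num` count telling the kernel how many of the padded index slots are valid.

```python
import torch

# Hypothetical helper (illustration only): turn a dense boolean
# (B, H, n_q_blocks, n_kv_blocks) block map into padded index/count tensors.
def block_map_to_index(block_map: torch.Tensor):
    q2k_num = block_map.sum(-1, dtype=torch.int32)        # kept KV blocks per Q block
    max_kv_blks = int(q2k_num.max().item())
    # Stable descending sort on 0/1 values puts the kept (True) columns first,
    # preserving ascending column order among them.
    order = torch.argsort(block_map.int(), dim=-1, descending=True, stable=True)
    q2k_index = order[..., :max_kv_blks].to(torch.int32)  # slots past q2k_num are padding
    return q2k_index, q2k_num

B, H, Tq, D = 1, 2, 256, 64                               # Tq % 64 == 0
q, k, v = (torch.randn(B, H, Tq, D, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))
nblk = Tq // 64
block_map = torch.rand(B, H, nblk, nblk, device="cuda") > 0.5
block_map[..., torch.arange(nblk), torch.arange(nblk)] = True  # always keep the diagonal
q2k_index, q2k_num = block_map_to_index(block_map)
variable_block_sizes = torch.full((nblk,), 64, device="cuda", dtype=torch.int32)

o, M, high_prec_o = triton_block_sparse_attn_forward(
    q, k, v, q2k_index, q2k_num, variable_block_sizes)
```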
backend_snapshot/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py
ADDED
@@ -0,0 +1,250 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
# and https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py

import triton
import triton.language as tl
try:
    from triton.language.target_info import cuda_capability_geq
    _HAS_CAPABILITY_CHECK = True
except ImportError:
    cuda_capability_geq = None
    _HAS_CAPABILITY_CHECK = False

MXFP_BLOCK_SIZE = tl.constexpr(16)

@triton.jit
def _compute_quant_and_scale(
    src_tensor,
    valid_src_mask,
    mx_tensor_dtype: tl.constexpr = tl.uint8,
    use_global_sf=True,
    two_level_quant_P=False,
    IS_BLACKWELL: tl.constexpr = False,
):
    BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
    BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE
    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8

    is_fp8e4: tl.constexpr = mx_tensor_dtype == tl.float8e4nv
    is_fp8e5: tl.constexpr = mx_tensor_dtype == tl.float8e5
    tl.static_assert(
        is_fp4 or (is_fp8e4 or is_fp8e5),
        "mx_tensor_dtype must be uint8, float8e4nv, or float8e5",
    )

    # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
    f32_tensor = src_tensor.to(tl.float32)
    abs_tensor = tl.abs(f32_tensor)
    abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0)  # Don't consider padding tensors in scale computation

    if two_level_quant_P:
        # row max from SageAttn3 paper
        global_max_val = tl.max(f32_tensor, axis=1, keep_dims=True)  # (BLOCK_SIZE_OUT_DIM, 1)
        global_max_val = tl.maximum(global_max_val, 1e-8)
        s_enc = ((6 * 448) / global_max_val).reshape([BLOCK_SIZE_OUT_DIM, 1, 1])
        s_dec = (1 / s_enc)

    abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])

    if use_global_sf and not two_level_quant_P:
        global_max_val = tl.max(abs_tensor)
        # Avoid division by zero: if all values are padding (max is 0), use a default scale
        global_max_val = tl.maximum(global_max_val, 1e-8)
        s_enc = (6 * 448) / global_max_val
        s_dec = (1 / s_enc)
    elif not two_level_quant_P and not use_global_sf:
        s_dec = 1.0
        s_enc = 1.0

    max_val = tl.max(abs_tensor, axis=2, keep_dims=True)  # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1) per-block maxima
    s_dec_b = max_val / 6  # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1)
    s_dec_b_e4m3 = (s_dec_b * s_enc).to(tl.float8e4nv)  # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1)
    s_enc_b = 1 / (s_dec_b_e4m3.to(tl.float32) * s_dec)  # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1)

    f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    quant_tensor = f32_tensor * s_enc_b

    # Reshape the tensors after scaling
    quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
    # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
    quant_tensor = tl.where(valid_src_mask, quant_tensor, 0.0)
    dequant_scale = s_dec_b_e4m3.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])

    if is_fp4 and IS_BLACKWELL:
        # Convert scaled values to two f32 lanes and use PTX cvt to e2m1x2 with two f32 operands.
        pairs = tl.reshape(quant_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
        lo_f, hi_f = tl.split(pairs)
        lo_f32 = lo_f.to(tl.float32)
        hi_f32 = hi_f.to(tl.float32)

        # Inline PTX: cvt.rn.satfinite.e2m1x2.f32 takes two f32 sources and produces one .b8 packed e2m1x2.
        out_tensor = tl.inline_asm_elementwise(
            """
            {
            .reg .b8 r;
            cvt.rn.satfinite.e2m1x2.f32 r, $1, $2;
            mov.b32 $0, {r, r, r, r};
            }
            """,
            constraints="=r,f,f",
            args=[hi_f32, lo_f32],
            dtype=tl.uint8,
            is_pure=True,
            pack=1,
        )
    elif is_fp4:
        quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
        signs = quant_tensor & 0x80000000
        exponents = (quant_tensor >> 23) & 0xFF
        mantissas_orig = (quant_tensor & 0x7FFFFF)

        # For RTNE: 0.25 < x < 0.75 maps to 0.5 (denormal); exactly 0.25 maps to 0.0
        E8_BIAS = 127
        E2_BIAS = 1
        # Move implicit bit 1 at the beginning to mantissa for denormals
        is_subnormal = exponents < E8_BIAS
        adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False)
        mantissas_pre = (0x400000 | (mantissas_orig >> 1))
        mantissas = tl.where(is_subnormal, mantissas_pre >> adjusted_exponents, mantissas_orig)

        # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
        exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)

        # Combine sign, exponent, and mantissa, while saturating
        # Round to nearest, ties to even (RTNE): use guard/sticky and LSB to decide increment
        m2bits = mantissas >> 21
        lsb_keep = (m2bits >> 1) & 0x1
        guard = m2bits & 0x1
        IS_SRC_FP32: tl.constexpr = src_tensor.dtype == tl.float32
        if IS_SRC_FP32:
            bit0_dropped = (mantissas_orig & 0x1) != 0
            mask = (1 << tl.minimum(adjusted_exponents, 31)) - 1
            dropped_post = (mantissas_pre & mask) != 0
            sticky = is_subnormal & (bit0_dropped | dropped_post)
            sticky |= ((mantissas & 0x1FFFFF) != 0).to(tl.uint32)
        else:
            sticky = ((mantissas & 0x1FFFFF) != 0).to(tl.uint32)
        round_inc = guard & (sticky | lsb_keep)
        e2m1_tmp = tl.minimum((((exponents << 2) | m2bits) + round_inc) >> 1, 0x7)
        e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)

        e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
        evens, odds = tl.split(e2m1_value)
        out_tensor = evens | (odds << 4)
    else:
        out_tensor = quant_tensor.to(mx_tensor_dtype)

    return out_tensor, dequant_scale, s_dec

@triton.jit
def _compute_dequant(
    mx_tensor,
    scale,
    s_dec,
    BLOCK_SIZE_OUT_DIM: tl.constexpr,
    BLOCK_SIZE_QUANT_DIM: tl.constexpr,
    dst_dtype: tl.constexpr,
    IS_BLACKWELL: tl.constexpr = False,
):
    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"Block size along quantization block must be a multiple of {MXFP_BLOCK_SIZE=}")
    # uint8 signifies two fp4 e2m1 values packed into a single byte
    mx_tensor_dtype: tl.constexpr = mx_tensor.dtype
    _is_f16: tl.constexpr = dst_dtype == tl.float16
    _is_bf16: tl.constexpr = dst_dtype == tl.bfloat16
    _is_f32: tl.constexpr = dst_dtype == tl.float32
    tl.static_assert(_is_f16 or (_is_bf16 or _is_f32))
    _is_u8: tl.constexpr = mx_tensor_dtype == tl.uint8
    _is_e4: tl.constexpr = mx_tensor_dtype == tl.float8e4nv
    _is_e5: tl.constexpr = mx_tensor_dtype == tl.float8e5
    _is_dst: tl.constexpr = mx_tensor_dtype == dst_dtype
    tl.static_assert(
        _is_u8 or ((_is_e4 or _is_e5) or _is_dst),
        "mx_tensor_ptr must be uint8 or float8 or dst_dtype")
    tl.static_assert(scale.dtype == tl.float8e4nv, "scale must be float8e4nv")

    # Determine if we are dealing with fp8 types.
    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE

    # Upcast the scale to the destination type.
    if dst_dtype == tl.bfloat16:
        dst_scale = scale.to(tl.bfloat16)
    else:
        dst_scale = scale.to(tl.float32)
        if dst_dtype == tl.float16:
            dst_scale = dst_scale.to(tl.float16)

    # Now upcast the tensor.
    intermediate_dtype: tl.constexpr = tl.bfloat16 if dst_dtype == tl.float32 else dst_dtype
    if IS_BLACKWELL:
        assert is_fp4
        packed_u32 = tl.inline_asm_elementwise(
            asm="""
            {
            .reg .b8 in_8;
            .reg .f16x2 out;
            cvt.u8.u32 in_8, $1;
            cvt.rn.f16x2.e2m1x2 out, in_8;
            mov.b32 $0, out;
            }
            """,
            constraints="=r,r",
            args=[mx_tensor],  # tl.uint8 passed in as a 32-bit reg with value in low 8 bits
            dtype=tl.uint32,
            is_pure=True,
            pack=1,
        )
        lo_u16 = (packed_u32 & 0xFFFF).to(tl.uint16)
        hi_u16 = (packed_u32 >> 16).to(tl.uint16)
        lo_f16 = lo_u16.to(tl.float16, bitcast=True)
        hi_f16 = hi_u16.to(tl.float16, bitcast=True)

        if intermediate_dtype == tl.float16:
            x0, x1 = lo_f16, hi_f16
        else:
            x0 = lo_f16.to(intermediate_dtype)
            x1 = hi_f16.to(intermediate_dtype)

        dst_tensor = tl.interleave(x0, x1)

    else:
        assert is_fp4
        dst_bias: tl.constexpr = 127 if intermediate_dtype == tl.bfloat16 else 15  # exponent bias
        dst_0p5: tl.constexpr = 16128 if intermediate_dtype == tl.bfloat16 else 0x3800
        dst_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10  # mantissa bits
        # e2m1
        em0 = mx_tensor & 0x07
        em1 = mx_tensor & 0x70
        x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((mx_tensor & 0x08).to(tl.uint16) << 12)
        x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((mx_tensor & 0x80).to(tl.uint16) << 8)
        # Three cases:
        # 1) x is normal and non-zero: Correct bias
        x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
        x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
        # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
        x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
        x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
        # 3) x is zero, do nothing
        dst_tensor = tl.interleave(x0, x1).to(intermediate_dtype, bitcast=True)

    dst_tensor = dst_tensor.to(dst_dtype)

    # Reshape for proper broadcasting: the scale was stored with a 16-sized "inner" grouping.
    dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
    scale = scale.reshape(dst_scale.shape)

    out_tensor = dst_tensor * dst_scale * s_dec  # NVFP4 has the additional global scale factor
    if dst_dtype == tl.float32:
        max_fin = 3.4028234663852886e+38
    elif dst_dtype == tl.bfloat16:
        max_fin = 3.3895313892515355e+38
    else:
        tl.static_assert(dst_dtype == tl.float16)
        max_fin = 65504
    out_tensor = tl.clamp(out_tensor, min=-max_fin, max=max_fin)
    out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
    out_tensor = out_tensor.to(dst_dtype)
    return out_tensor
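The bit-level e2m1 path above is easier to sanity-check against a plain PyTorch reference. The sketch below is an assumption written for this note, not part of the snapshot; it mirrors the `use_global_sf=False` configuration: each 16-element block is scaled so its max magnitude maps to 6 (the largest e2m1 value), the decode scale is rounded to e4m3, and dequantization multiplies back. Ties may round differently from the kernel's RTNE, but values land on the same grid.

```python
import torch

E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])  # |e2m1| values

def ref_fake_quant_row(x: torch.Tensor, block: int = 16) -> torch.Tensor:
    """Reference NVFP4 fake-quant for a 1-D row whose length is a multiple of `block`."""
    xb = x.float().reshape(-1, block)
    s_dec = (xb.abs().amax(dim=1, keepdim=True) / 6).clamp_min(1e-12)  # block max -> 6
    s_dec = s_dec.to(torch.float8_e4m3fn).float()                      # scale stored as e4m3
    scaled = xb / s_dec
    # Round to the nearest representable e2m1 magnitude, keep the sign.
    idx = (scaled.abs().unsqueeze(-1) - E2M1_GRID).abs().argmin(-1)
    quant = E2M1_GRID[idx] * torch.sign(scaled)
    return (quant * s_dec).reshape_as(x)

x = torch.randn(64)
err = (x - ref_fake_quant_row(x)).abs().max()
print(err)  # bounded by half an e2m1 step times the per-block scale
```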
backend_snapshot/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py
ADDED
@@ -0,0 +1,80 @@
import triton
import triton.language as tl

from .nvfp4_utils import _compute_quant_and_scale, _compute_dequant

@triton.jit
def fake_quantize(src_tensor, valid_src_mask, BLOCK_SIZE_OUT_DIM: tl.constexpr,
                  BLOCK_SIZE_QUANT_DIM: tl.constexpr,
                  dst_dtype: tl.constexpr,
                  mx_tensor_dtype: tl.constexpr = tl.uint8,
                  use_global_sf: tl.constexpr = True,
                  two_level_quant_P: tl.constexpr = False):
    high_prec_src_tensor = src_tensor
    src_tensor, src_scale, src_s_dec = _compute_quant_and_scale(src_tensor=src_tensor,
                                                                valid_src_mask=valid_src_mask,
                                                                mx_tensor_dtype=mx_tensor_dtype,
                                                                use_global_sf=use_global_sf,
                                                                two_level_quant_P=two_level_quant_P)
    src_tensor = _compute_dequant(mx_tensor=src_tensor,
                                  scale=src_scale,
                                  s_dec=src_s_dec,
                                  BLOCK_SIZE_OUT_DIM=BLOCK_SIZE_OUT_DIM,
                                  BLOCK_SIZE_QUANT_DIM=BLOCK_SIZE_QUANT_DIM,
                                  dst_dtype=dst_dtype)
    return src_tensor, high_prec_src_tensor.to(src_tensor.dtype)

@triton.jit
def fake_quantize_q(Q, fake_Q, stride_z_q, stride_h_q,
                    stride_tok_q, stride_d_q,
                    fake_stride_z_q, fake_stride_h_q,
                    fake_stride_tok_q, fake_stride_d_q,
                    H, N_CTX_Q,
                    BLOCK_M: tl.constexpr,
                    HEAD_DIM: tl.constexpr,
                    use_global_sf: tl.constexpr = True):
    bhid = tl.program_id(1)
    adj_q = (stride_h_q * (bhid % H) + stride_z_q * (bhid // H))
    fake_adj_q = (fake_stride_h_q * (bhid % H) + fake_stride_z_q * (bhid // H))
    Q += adj_q
    fake_Q += fake_adj_q

    pid = tl.program_id(0)
    start_m = pid * BLOCK_M
    offs_m = start_m + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, HEAD_DIM)

    q_valid = offs_m < N_CTX_Q
    q = tl.load(Q + offs_m[:, None] * stride_tok_q + offs_k[None, :] * stride_d_q, mask=q_valid[:, None], other=0.0)
    q, _ = fake_quantize(src_tensor=q, valid_src_mask=q_valid[:, None], BLOCK_SIZE_OUT_DIM=BLOCK_M, BLOCK_SIZE_QUANT_DIM=HEAD_DIM, dst_dtype=q.dtype, use_global_sf=use_global_sf)
    tl.store(fake_Q + offs_m[:, None] * fake_stride_tok_q + offs_k[None, :] * fake_stride_d_q, q, mask=q_valid[:, None])

@triton.jit
def fake_quantize_kv(K, V, fake_K, fake_V, stride_z_kv, stride_h_kv,
                     stride_tok_kv, stride_d_kv,
                     fake_stride_z_kv, fake_stride_h_kv,
                     fake_stride_tok_kv, fake_stride_d_kv,
                     H, N_CTX_KV,
                     BLOCK_N: tl.constexpr,
                     HEAD_DIM: tl.constexpr,
                     use_global_sf: tl.constexpr = True):
    bhid = tl.program_id(1)
    adj_kv = (stride_h_kv * (bhid % H) + stride_z_kv * (bhid // H))
    fake_adj_kv = (fake_stride_h_kv * (bhid % H) + fake_stride_z_kv * (bhid // H))
    K += adj_kv
    V += adj_kv
    fake_K += fake_adj_kv
    fake_V += fake_adj_kv

    pid = tl.program_id(0)
    start_n = pid * BLOCK_N
    offs_n = start_n + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, HEAD_DIM)

    kv_valid = offs_n < N_CTX_KV
    k_block = tl.load(K + offs_n[:, None] * stride_tok_kv + offs_k[None, :] * stride_d_kv, mask=kv_valid[:, None], other=0.0)
    v_block = tl.load(V + offs_n[:, None] * stride_tok_kv + offs_k[None, :] * stride_d_kv, mask=kv_valid[:, None], other=0.0)
    k, _ = fake_quantize(src_tensor=k_block, valid_src_mask=kv_valid[:, None], BLOCK_SIZE_OUT_DIM=BLOCK_N, BLOCK_SIZE_QUANT_DIM=HEAD_DIM, dst_dtype=k_block.dtype, use_global_sf=use_global_sf)
    v, _ = fake_quantize(src_tensor=v_block, valid_src_mask=kv_valid[:, None], BLOCK_SIZE_OUT_DIM=BLOCK_N, BLOCK_SIZE_QUANT_DIM=HEAD_DIM, dst_dtype=v_block.dtype, use_global_sf=use_global_sf)
    tl.store(fake_K + offs_n[:, None] * fake_stride_tok_kv + offs_k[None, :] * fake_stride_d_kv, k, mask=kv_valid[:, None])
    tl.store(fake_V + offs_n[:, None] * fake_stride_tok_kv + offs_k[None, :] * fake_stride_d_kv, v, mask=kv_valid[:, None])
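A standalone way to exercise `fake_quantize_q` (a test sketch under assumed shapes; the grid convention, `program_id(0)` over token tiles and `program_id(1)` over batch×heads, follows the kernel above):

```python
import torch
import triton
from fastvideo_kernel.triton_kernels.quant_utils import fake_quantize_q

B, H, N, D = 1, 2, 128, 64
BLOCK_M = 32
q = torch.randn(B, H, N, D, device="cuda", dtype=torch.bfloat16)
fake_q = torch.empty_like(q)

grid = (triton.cdiv(N, BLOCK_M), B * H, 1)
fake_quantize_q[grid](
    q, fake_q,
    q.stride(0), q.stride(1), q.stride(2), q.stride(3),
    fake_q.stride(0), fake_q.stride(1), fake_q.stride(2), fake_q.stride(3),
    H, N, BLOCK_M=BLOCK_M, HEAD_DIM=D, use_global_sf=False,
)

# Fake quantization is lossy but bounded; the round trip stays close to q.
print((q.float() - fake_q.float()).abs().max())
```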
backend_snapshot/fastvideo/attention/backends/sparse_fp4_ours_p_attn.py
ADDED
@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
"""Sparse FP4 Attention backend with the independent ours-P quant kernel."""

import math

import torch
import torch.nn.functional as F
import triton

from fastvideo_kernel.triton_kernels.quant_utils import (
    fake_quantize_q,
    fake_quantize_kv,
)
from fastvideo_kernel.block_sparse_attn_ours_p import block_sparse_attn_ours_p
from fastvideo.forward_context import get_forward_context

from fastvideo.attention.backends.abstract import (
    AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataBuilder,
)
from fastvideo.attention.backends.video_sparse_attn import (
    VideoSparseAttentionMetadata,
    VideoSparseAttentionMetadataBuilder,
    VSA_TILE_SIZE,
)
from fastvideo.distributed import get_sp_group
from fastvideo.logger import init_logger

logger = init_logger(__name__)


def _dense_sdpa_blhd(query, key, value):
    q = query.transpose(1, 2)
    k = key.transpose(1, 2)
    v = value.transpose(1, 2)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
    return out.transpose(1, 2)


def _quantize_qkv_bhld(q, k, v):
    """FP4 fake quantize Q/K/V in BHLD layout, same as attn_qat_train."""
    H = q.shape[1]
    N_Q = q.shape[2]
    N_KV = k.shape[2]
    D = q.shape[3]
    BLOCK = 32

    fake_q = torch.empty_like(q)
    fake_k = torch.empty_like(k)
    fake_v = torch.empty_like(v)

    grid_q = (triton.cdiv(N_Q, BLOCK), q.shape[0] * H, 1)
    grid_kv = (triton.cdiv(N_KV, BLOCK), q.shape[0] * H, 1)

    fake_quantize_q[grid_q](
        q, fake_q,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        fake_q.stride(0), fake_q.stride(1), fake_q.stride(2), fake_q.stride(3),
        H, N_Q, BLOCK_M=BLOCK, HEAD_DIM=D, use_global_sf=False,
    )
    fake_quantize_kv[grid_kv](
        k, v, fake_k, fake_v,
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        fake_k.stride(0), fake_k.stride(1), fake_k.stride(2), fake_k.stride(3),
        H, N_KV, BLOCK_N=BLOCK, HEAD_DIM=D, use_global_sf=False,
    )
    return fake_q, fake_k, fake_v


class SparseFP4OursPAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "SPARSE_FP4_OURS_P_ATTN"

    @staticmethod
    def get_impl_cls() -> type["SparseFP4OursPAttentionImpl"]:
        return SparseFP4OursPAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["VideoSparseAttentionMetadata"]:
        return VideoSparseAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["VideoSparseAttentionMetadataBuilder"]:
        return VideoSparseAttentionMetadataBuilder


class SparseFP4OursPAttentionImpl(AttentionImpl):

    def __init__(self, num_heads, head_size, causal, softmax_scale,
                 num_kv_heads=None, prefix="", **extra):
        self.prefix = prefix
        self.sp_size = get_sp_group().world_size

    def tile(self, x, num_tiles, tile_partition_indices, non_pad_index):
        t_p = num_tiles[0] * VSA_TILE_SIZE[0]
        h_p = num_tiles[1] * VSA_TILE_SIZE[1]
        w_p = num_tiles[2] * VSA_TILE_SIZE[2]
        out = torch.zeros(
            (x.shape[0], t_p * h_p * w_p, x.shape[-2], x.shape[-1]),
            device=x.device, dtype=x.dtype,
        )
        out[:, non_pad_index] = x[:, tile_partition_indices]
        return out

    def untile(self, x, reverse_tile_partition_indices, non_pad_index):
        return x[:, non_pad_index][:, reverse_tile_partition_indices]

    def _is_force_dense(self) -> bool:
        ctx = get_forward_context()
        return ctx.force_dense

    def preprocess_qkv(self, qkv, attn_metadata):
        if attn_metadata is None or self._is_force_dense():
            return qkv
        return self.tile(qkv, attn_metadata.num_tiles,
                         attn_metadata.tile_partition_indices,
                         attn_metadata.non_pad_index)

    def postprocess_output(self, output, attn_metadata):
        if attn_metadata is None or self._is_force_dense():
            return output
        return self.untile(output,
                           attn_metadata.reverse_tile_partition_indices,
                           attn_metadata.non_pad_index)

    def forward(self, query, key, value,
                gate_compress_or_metadata=None, attn_metadata=None):
        # Handle both call conventions
        if attn_metadata is None and isinstance(
                gate_compress_or_metadata, (VideoSparseAttentionMetadata, type(None))):
            attn_metadata = gate_compress_or_metadata

        # ── force_dense: true dense BF16 SDPA (for teacher in distillation) ──
        ctx = get_forward_context()
        if ctx.force_dense:
            return _dense_sdpa_blhd(query, key, value)

        is_cross = query.shape[1] != key.shape[1]

        # ── Cross-attention/no metadata: keep dense. The sparse VSA metadata only
        # applies to tiled video self-attention.
        if attn_metadata is None or is_cross:
            return _dense_sdpa_blhd(query, key, value)

        # ── Self-attention: FP4 quant Q/K/V + block-sparse attention ──
        # BLHD → BHLD
        q = query.transpose(1, 2).contiguous()
        k = key.transpose(1, 2).contiguous()
        v = value.transpose(1, 2).contiguous()

        # Step 1: FP4 fake quantize Q/K/V with STE (straight-through estimator)
        with torch.no_grad():
            fq, fk, fv = _quantize_qkv_bhld(q, k, v)
        # STE: forward uses quantized values, backward passes gradient through as-is
        fq = q + (fq - q).detach()
        fk = k + (fk - k).detach()
        fv = v + (fv - v).detach()

        # Step 2: Build sparse block map
        B, H, S, D = fq.shape
        block_elements = math.prod(VSA_TILE_SIZE)
        num_blocks = S // block_elements

        VSA_sparsity = attn_metadata.VSA_sparsity
        cur_topk = max(1, math.ceil((1 - VSA_sparsity) * num_blocks))
        logger.info(f"[SFP4] S={S} num_blocks={num_blocks} sparsity={VSA_sparsity} topk={cur_topk}/{num_blocks}")

        block_sizes = attn_metadata.variable_block_sizes.to(
            device=fq.device, dtype=torch.float32).clamp_min(1)
        block_sizes = block_sizes.view(1, 1, num_blocks, 1)
        q_c = (fq.view(B, H, num_blocks, block_elements, D).float().sum(3) /
               block_sizes).to(fq.dtype)
        k_c = (fk.view(B, H, num_blocks, block_elements, D).float().sum(3) /
               block_sizes).to(fk.dtype)
        v_c = (fv.view(B, H, num_blocks, block_elements, D).float().sum(3) /
               block_sizes).to(fv.dtype)
        scores = torch.matmul(q_c, k_c.transpose(-2, -1)) / (D ** 0.5)
        topk_idx = torch.topk(scores, cur_topk, dim=-1).indices
        block_map = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, topk_idx, True)

        # Step 3: Block-sparse attention with independent group-local P quant.
        out, _ = block_sparse_attn_ours_p(fq, fk, fv, block_map,
                                          attn_metadata.variable_block_sizes,
                                          q_c, k_c, v_c)

        return out.transpose(1, 2)  # BHLD → BLHD
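The straight-through estimator used in Step 1 above is simple enough to isolate. A minimal sketch with `torch.round` standing in for the FP4 kernel: the forward value is the quantized one, while autograd sees an identity, so gradients flow to the unquantized tensors and hence to the model weights.

```python
import torch

x = torch.tensor([0.3, 1.7, -2.4], requires_grad=True)
with torch.no_grad():
    xq = torch.round(x)           # stand-in for the FP4 fake-quant kernel

x_ste = x + (xq - x).detach()     # forward: xq ; backward: d(x_ste)/dx = 1
x_ste.sum().backward()
print(x_ste.detach())             # tensor([ 0.,  2., -2.])
print(x.grad)                     # tensor([1., 1., 1.])
```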
backend_snapshot/fastvideo/attention/backends/video_sparse_attn.py
ADDED
@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
import functools
import math
from dataclasses import dataclass

import torch

try:
    from fastvideo_kernel import video_sparse_attn
except ImportError:
    video_sparse_attn = None

from typing import Any

from fastvideo.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata,
                                                   AttentionMetadataBuilder)
from fastvideo.distributed import get_sp_group
from fastvideo.logger import init_logger

logger = init_logger(__name__)
VSA_TILE_SIZE = (4, 4, 4)


@functools.lru_cache(maxsize=10)
def get_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    T, H, W = dit_seq_shape
    ts, hs, ws = tile_size
    indices = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W)
    ls = []
    for t in range(math.ceil(T / ts)):
        for h in range(math.ceil(H / hs)):
            for w in range(math.ceil(W / ws)):
                ls.append(indices[t * ts:min(t * ts + ts, T), h * hs:min(h * hs + hs, H),
                                  w * ws:min(w * ws + ws, W)].flatten())
    index = torch.cat(ls, dim=0)
    return index


@functools.lru_cache(maxsize=10)
def get_reverse_tile_partition_indices(
    dit_seq_shape: tuple[int, int, int],
    tile_size: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    return torch.argsort(get_tile_partition_indices(dit_seq_shape, tile_size, device))


@functools.lru_cache(maxsize=10)
def construct_variable_block_sizes(
    dit_seq_shape: tuple[int, int, int],
    num_tiles: tuple[int, int, int],
    device: torch.device,
) -> torch.LongTensor:
    """
    Compute the number of valid (non-padded) tokens inside every
    (ts_t × ts_h × ts_w) tile after padding, flattened in the order
    (t-tile, h-tile, w-tile) that `rearrange` uses.

    Returns
    -------
    torch.LongTensor  # shape: [∏ full_window_size]
    """
    # unpack
    t, h, w = dit_seq_shape
    ts_t, ts_h, ts_w = VSA_TILE_SIZE
    n_t, n_h, n_w = num_tiles

    def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor:
        """Vector with the size of each tile along one dimension."""
        sizes = torch.full((n_tiles, ), tile, dtype=torch.int, device=device)
        # size of last (possibly partial) tile
        remainder = dim_len - (n_tiles - 1) * tile
        sizes[-1] = remainder if remainder > 0 else tile
        return sizes

    t_sizes = _sizes(t, ts_t, n_t)  # [n_t]
    h_sizes = _sizes(h, ts_h, n_h)  # [n_h]
    w_sizes = _sizes(w, ts_w, n_w)  # [n_w]

    # broadcast-multiply to get voxels per tile, then flatten
    block_sizes = (
        t_sizes[:, None, None]  # [n_t, 1, 1]
        * h_sizes[None, :, None]  # [1, n_h, 1]
        * w_sizes[None, None, :]  # [1, 1, n_w]
    ).reshape(-1)  # [n_t * n_h * n_w]

    return block_sizes


@functools.lru_cache(maxsize=10)
def get_non_pad_index(
    variable_block_sizes: torch.LongTensor,
    max_block_size: int,
):
    n_win = variable_block_sizes.shape[0]
    device = variable_block_sizes.device
    starts_pad = torch.arange(n_win, device=device) * max_block_size
    index_pad = starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :]
    index_mask = torch.arange(max_block_size, device=device)[None, :] < variable_block_sizes[:, None]
    return index_pad[index_mask]


class VideoSparseAttentionBackend(AttentionBackend):

    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [64, 128]

    @staticmethod
    def get_name() -> str:
        return "VIDEO_SPARSE_ATTN"

    @staticmethod
    def get_impl_cls() -> type["VideoSparseAttentionImpl"]:
        return VideoSparseAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["VideoSparseAttentionMetadata"]:
        return VideoSparseAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["VideoSparseAttentionMetadataBuilder"]:
        return VideoSparseAttentionMetadataBuilder


@dataclass
class VideoSparseAttentionMetadata(AttentionMetadata):
    current_timestep: int
    dit_seq_shape: list[int]
    VSA_sparsity: float  # read back by the VSA and sparse-FP4 attention impls
    num_tiles: list[int]
    total_seq_length: int
    tile_partition_indices: torch.LongTensor
    reverse_tile_partition_indices: torch.LongTensor
    variable_block_sizes: torch.LongTensor
    non_pad_index: torch.LongTensor


class VideoSparseAttentionMetadataBuilder(AttentionMetadataBuilder):

    def __init__(self) -> None:
        pass

    def prepare(self) -> None:
        pass

    def build(  # type: ignore
        self,
        current_timestep: int,
        raw_latent_shape: tuple[int, int, int],
        patch_size: tuple[int, int, int],
        VSA_sparsity: float,
        device: torch.device,
        **kwargs: dict[str, Any],
    ) -> VideoSparseAttentionMetadata:
        dit_seq_shape = (raw_latent_shape[0] // patch_size[0], raw_latent_shape[1] // patch_size[1],
                         raw_latent_shape[2] // patch_size[2])

        num_tiles = (math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]), math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]),
                     math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]))
        total_seq_length = math.prod(dit_seq_shape)

        tile_partition_indices = get_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
        reverse_tile_partition_indices = get_reverse_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
        variable_block_sizes = construct_variable_block_sizes(dit_seq_shape, num_tiles, device)
        non_pad_index = get_non_pad_index(variable_block_sizes, math.prod(VSA_TILE_SIZE))

        return VideoSparseAttentionMetadata(
            current_timestep=current_timestep,
            dit_seq_shape=dit_seq_shape,  # type: ignore
            VSA_sparsity=VSA_sparsity,
            num_tiles=num_tiles,  # type: ignore
            total_seq_length=total_seq_length,  # type: ignore
            tile_partition_indices=tile_partition_indices,  # type: ignore
            reverse_tile_partition_indices=reverse_tile_partition_indices,
            variable_block_sizes=variable_block_sizes,
            non_pad_index=non_pad_index)


class VideoSparseAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        causal: bool,
        softmax_scale: float,
        num_kv_heads: int | None = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        self.prefix = prefix
        sp_group = get_sp_group()
        self.sp_size = sp_group.world_size

    def tile(self, x: torch.Tensor, num_tiles: list[int], tile_partition_indices: torch.LongTensor,
             non_pad_index: torch.LongTensor) -> torch.Tensor:
        t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
        h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
        w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]

        x_padded = torch.zeros((x.shape[0], t_padded_size * h_padded_size * w_padded_size, x.shape[-2], x.shape[-1]),
                               device=x.device,
                               dtype=x.dtype)
        x_padded[:, non_pad_index] = x[:, tile_partition_indices]
        return x_padded

    def untile(self, x: torch.Tensor, reverse_tile_partition_indices: torch.LongTensor,
               non_pad_index: torch.LongTensor) -> torch.Tensor:
        x = x[:, non_pad_index][:, reverse_tile_partition_indices]
        return x

    def preprocess_qkv(
        self,
        qkv: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        return self.tile(qkv, attn_metadata.num_tiles, attn_metadata.tile_partition_indices,
                         attn_metadata.non_pad_index)

    def postprocess_output(
        self,
        output: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        return self.untile(output, attn_metadata.reverse_tile_partition_indices, attn_metadata.non_pad_index)

    def forward(  # type: ignore[override]
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        gate_compress: torch.Tensor,
        attn_metadata: VideoSparseAttentionMetadata,
    ) -> torch.Tensor:
        query = query.transpose(1, 2).contiguous()
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()
        gate_compress = gate_compress.transpose(1, 2).contiguous()

        VSA_sparsity = attn_metadata.VSA_sparsity

        cur_topk = math.ceil((1 - VSA_sparsity) * (attn_metadata.total_seq_length / math.prod(VSA_TILE_SIZE)))

        if video_sparse_attn is None:
            raise NotImplementedError("video_sparse_attn is not installed")
        hidden_states = video_sparse_attn(query,
                                          key,
                                          value,
                                          attn_metadata.variable_block_sizes,
                                          attn_metadata.variable_block_sizes,
                                          cur_topk,
                                          block_size=VSA_TILE_SIZE,
                                          compress_attn_weight=gate_compress).transpose(1, 2)

        return hidden_states
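A worked example makes the tiling bookkeeping concrete. For a hypothetical post-patchify grid `dit_seq_shape = (6, 8, 8)` with the 4×4×4 tile, the time axis splits into tiles of 4 and 2 valid frames while H and W tile evenly, so the flattened per-tile token counts come out as four 64s followed by four 32s:

```python
import torch

dit_seq_shape = (6, 8, 8)    # (T, H, W) tokens after patchification (hypothetical)
num_tiles = (2, 2, 2)        # ceil(dim / 4) per axis

sizes = construct_variable_block_sizes(dit_seq_shape, num_tiles, torch.device("cpu"))
print(sizes.tolist())
# [64, 64, 64, 64, 32, 32, 32, 32]
# first t-tile: 4 valid frames -> 4*4*4 = 64 tokens per (h, w) tile;
# last t-tile: only 2 valid frames -> 2*4*4 = 32.
```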
backend_snapshot/fastvideo/configs/models/dits/base.py
ADDED
@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Any

from fastvideo.configs.models.base import ArchConfig, ModelConfig
from fastvideo.layers.quantization import QuantizationConfig
from fastvideo.platforms import AttentionBackendEnum


@dataclass
class DiTArchConfig(ArchConfig):
    _fsdp_shard_conditions: list = field(default_factory=list)
    _compile_conditions: list = field(default_factory=list)
    param_names_mapping: dict = field(default_factory=dict)
    reverse_param_names_mapping: dict = field(default_factory=dict)
    lora_param_names_mapping: dict = field(default_factory=dict)
    _supported_attention_backends: tuple[AttentionBackendEnum,
                                         ...] = (AttentionBackendEnum.SAGE_ATTN, AttentionBackendEnum.FLASH_ATTN,
                                                 AttentionBackendEnum.TORCH_SDPA,
                                                 AttentionBackendEnum.VIDEO_SPARSE_ATTN,
                                                 AttentionBackendEnum.VMOBA_ATTN, AttentionBackendEnum.SAGE_ATTN_THREE,
                                                 AttentionBackendEnum.ATTN_QAT_INFER,
                                                 AttentionBackendEnum.ATTN_QAT_TRAIN, AttentionBackendEnum.SLA_ATTN,
                                                 AttentionBackendEnum.SAGE_SLA_ATTN,
                                                 AttentionBackendEnum.SPARSE_FP4_ATTN,
                                                 AttentionBackendEnum.SPARSE_FP4_OURS_P_ATTN)

    hidden_size: int = 0
    num_attention_heads: int = 0
    num_channels_latents: int = 0
    in_channels: int | None = 0
    out_channels: int | None = 0
    patch_size: int | tuple[int, int, int] | None = None
    expand_timesteps: bool = False
    num_layers: int = 0
    ffn_dim: int = 0
    exclude_lora_layers: list[str] = field(default_factory=list)
    boundary_ratio: float | None = None

    def __post_init__(self) -> None:
        if not self._compile_conditions:
            self._compile_conditions = self._fsdp_shard_conditions.copy()


@dataclass
class DiTConfig(ModelConfig):
    arch_config: DiTArchConfig = field(default_factory=DiTArchConfig)

    # FastVideoDiT-specific parameters
    prefix: str = ""
    quant_config: QuantizationConfig | None = None
    expand_timesteps: bool = False
    boundary_ratio: float | None = None

    def __post_init__(self) -> None:
        super().__post_init__()
        self.arch_config.expand_timesteps = self.expand_timesteps
        self.arch_config.boundary_ratio = self.boundary_ratio

    @staticmethod
    def add_cli_args(parser: Any, prefix: str = "dit-config") -> Any:
        """Add CLI arguments for DiTConfig fields"""
        parser.add_argument(
            f"--{prefix}.prefix",
            type=str,
            dest=f"{prefix.replace('-', '_')}.prefix",
            default=DiTConfig.prefix,
            help="Prefix for the DiT model",
        )

        parser.add_argument(
            f"--{prefix}.quant-config",
            type=str,
            dest=f"{prefix.replace('-', '_')}.quant_config",
            default=None,
            help="Quantization configuration for the DiT model",
        )

        return parser
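One wrinkle worth noting: `add_cli_args` registers dotted `dest` strings, so the parsed values are not attribute-accessible and must be read back with `getattr`. A small usage sketch (the argument value is made up):

```python
import argparse

parser = argparse.ArgumentParser()
DiTConfig.add_cli_args(parser, prefix="dit-config")

args = parser.parse_args(["--dit-config.prefix", "wan_dit"])
# The dest is the literal string "dit_config.prefix", which is not a valid
# Python identifier, so plain attribute access won't work:
print(getattr(args, "dit_config.prefix"))  # "wan_dit"
```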
backend_snapshot/fastvideo/forward_context.py
ADDED
@@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/forward_context.py

import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import torch

from fastvideo.logger import init_logger

if TYPE_CHECKING:
    from fastvideo.attention import AttentionMetadata
    from fastvideo.pipelines import ForwardBatch

logger = init_logger(__name__)

# TODO(will): check if this is needed
# track_batchsize: bool = envs.FASTVIDEO_LOG_BATCHSIZE_INTERVAL >= 0
track_batchsize: bool = False
last_logging_time: float = 0
forward_start_time: float = 0
# batchsize_logging_interval: float = envs.FASTVIDEO_LOG_BATCHSIZE_INTERVAL
batchsize_logging_interval: float = 1000
batchsize_forward_time: defaultdict = defaultdict(list)


@dataclass
class ForwardContext:
    current_timestep: int
    # TODO(will): check this arg
    # copy from vllm_config.compilation_config.static_forward_context
    # attn_layers: Dict[str, Any]
    # TODO: extend to support per-layer dynamic forward context
    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
    forward_batch: Optional["ForwardBatch"] = None
    force_dense: bool = False


_forward_context: Optional["ForwardContext"] = None


def get_forward_context() -> "ForwardContext":
    """Get the current forward context."""
    assert _forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return _forward_context


# TODO(will): finalize the interface
@contextmanager
def set_forward_context(current_timestep, attn_metadata, forward_batch: Optional["ForwardBatch"] = None, force_dense: bool = False):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.
    """
    global forward_start_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()
    global _forward_context
    prev_context = _forward_context
    _forward_context = ForwardContext(current_timestep=current_timestep,
                                      attn_metadata=attn_metadata,
                                      forward_batch=forward_batch,
                                      force_dense=force_dense)

    try:
        yield
    finally:
        global last_logging_time, batchsize_logging_interval
        if need_to_track_batchsize:
            if hasattr(attn_metadata, "num_prefill_tokens"):
                # for v0 attention backends
                batchsize = attn_metadata.num_prefill_tokens + \
                    attn_metadata.num_decode_tokens
            else:
                # for v1 attention backends
                batchsize = attn_metadata.num_input_tokens
            now = time.perf_counter()
            # time measurement is in milliseconds
            batchsize_forward_time[batchsize].append((now - forward_start_time) * 1000)
            if now - last_logging_time > batchsize_logging_interval:
                last_logging_time = now
                forward_stats = []
                for bs, times in batchsize_forward_time.items():
                    if len(times) <= 1:
                        # can be cudagraph / profiling run
                        continue
                    median = torch.quantile(torch.tensor(times), q=0.5).item()
                    median = round(median, 2)
                    forward_stats.append((bs, len(times), median))
                forward_stats.sort(key=lambda x: x[1], reverse=True)
                if forward_stats:
                    logger.info(("Batchsize forward time stats "
                                 "(batchsize, count, median_time(ms)): %s"), forward_stats)
        _forward_context = prev_context
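Callers pair the two halves of this module as a scoped set/get. A minimal usage sketch (the model call and metadata object are placeholders):

```python
from fastvideo.forward_context import get_forward_context, set_forward_context

def run_denoising_step(model, latents, t, attn_metadata):
    # Anything invoked inside the block (attention backends included) can read
    # the same per-step state via get_forward_context().
    with set_forward_context(current_timestep=t, attn_metadata=attn_metadata):
        return model(latents, t)

# e.g. inside a backend:
#   if get_forward_context().force_dense:
#       ... fall back to dense SDPA ...
```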
backend_snapshot/fastvideo/pipelines/stages/denoising.py
ADDED
|
@@ -0,0 +1,1184 @@
# SPDX-License-Identifier: Apache-2.0
"""
Denoising stage for diffusion pipelines.
"""

import inspect
import weakref
from collections.abc import Iterable
from typing import Any

import torch
from tqdm.auto import tqdm

from fastvideo.attention import get_attn_backend
from fastvideo.distributed import (get_local_torch_device, get_world_group)
from fastvideo.fastvideo_args import FastVideoArgs
from fastvideo.forward_context import set_forward_context
from fastvideo.logger import init_logger
from fastvideo.models.loader.component_loader import TransformerLoader
from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import (FlowMatchEulerDiscreteScheduler)
from fastvideo.models.utils import pred_noise_to_pred_video
from fastvideo.pipelines.pipeline_batch_info import ForwardBatch
from fastvideo.pipelines.stages.base import PipelineStage
from fastvideo.pipelines.stages.validators import StageValidators as V
from fastvideo.pipelines.stages.validators import VerificationResult
from fastvideo.platforms import AttentionBackendEnum
from fastvideo.utils import dict_to_3d_list, masks_like

try:
    from fastvideo.attention.backends.vmoba import VMOBAAttentionBackend
    from fastvideo.utils import is_vmoba_available
    vmoba_attn_available = is_vmoba_available()
except ImportError:
    vmoba_attn_available = False

try:
    from fastvideo.attention.backends.video_sparse_attn import (VideoSparseAttentionBackend)
    vsa_available = True
except ImportError:
    vsa_available = False

try:
    from fastvideo.attention.backends.sparse_fp4_attn import (SparseFP4AttentionBackend)
except ImportError:
    SparseFP4AttentionBackend = None  # type: ignore[assignment]

try:
    from fastvideo.attention.backends.sparse_fp4_ours_p_attn import (SparseFP4OursPAttentionBackend)
except ImportError:
    SparseFP4OursPAttentionBackend = None  # type: ignore[assignment]

sparse_fp4_backends = tuple(
    backend for backend in (
        SparseFP4AttentionBackend,
        SparseFP4OursPAttentionBackend,
    ) if backend is not None)
sparse_fp4_available = bool(sparse_fp4_backends)

logger = init_logger(__name__)

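# NOTE: the try/except imports above make each sparse attention backend
# optional, so this module still imports cleanly when a given kernel extension
# is not built. `sparse_fp4_backends` therefore holds whichever FP4 backend
# classes did import (zero, one, or both), which is why the stages below gate
# on `sparse_fp4_available and self.attn_backend in sparse_fp4_backends`
# rather than comparing against a single class.
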
class DenoisingStage(PipelineStage):
    """
    Stage for running the denoising loop in diffusion pipelines.

    This stage handles the iterative denoising process that transforms
    the initial noise into the final output.
    """

    def __init__(self, transformer, scheduler, pipeline=None, transformer_2=None, vae=None) -> None:
        super().__init__()
        self.transformer = transformer
        self.transformer_2 = transformer_2
        self.scheduler = scheduler
        self.vae = vae
        self.pipeline = weakref.ref(pipeline) if pipeline else None
        attn_head_size = self.transformer.hidden_size // self.transformer.num_attention_heads
        self.attn_backend = get_attn_backend(
            head_size=attn_head_size,
            dtype=torch.float16,  # TODO(will): hack
            supported_attention_backends=(
                AttentionBackendEnum.VIDEO_SPARSE_ATTN, AttentionBackendEnum.BSA_ATTN, AttentionBackendEnum.VMOBA_ATTN,
                AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.SAGE_ATTN_THREE,
                AttentionBackendEnum.ATTN_QAT_INFER, AttentionBackendEnum.ATTN_QAT_TRAIN,
                AttentionBackendEnum.SPARSE_FP4_ATTN, AttentionBackendEnum.SPARSE_FP4_OURS_P_ATTN)  # hack
        )

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        """
        Run the denoising loop.

        Args:
            batch: The current batch information.
            fastvideo_args: The inference arguments.

        Returns:
            The batch with denoised latents.
        """
        pipeline = self.pipeline() if self.pipeline else None
        if not fastvideo_args.model_loaded["transformer"]:
            loader = TransformerLoader()
            self.transformer = loader.load(fastvideo_args.model_paths["transformer"], fastvideo_args)
            if pipeline:
                pipeline.add_module("transformer", self.transformer)
            fastvideo_args.model_loaded["transformer"] = True

        # Prepare extra step kwargs for scheduler
        extra_step_kwargs = self.prepare_extra_func_kwargs(
            self.scheduler.step,
            {
                "generator": batch.generator,
                "eta": batch.eta
            },
        )

        # Setup precision and autocast settings
        # TODO(will): make the precision configurable for inference
        # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
        target_dtype = torch.bfloat16
        autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast

        # Get timesteps and calculate warmup steps
        timesteps = batch.timesteps
        # TODO(will): remove this once we add input/output validation for stages
        if timesteps is None:
            raise ValueError("Timesteps must be provided")
        num_inference_steps = batch.num_inference_steps
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # Prepare image latents and embeddings for I2V generation
        image_embeds = batch.image_embeds
        if len(image_embeds) > 0:
            assert not torch.isnan(image_embeds[0]).any(), "image_embeds contains nan"
            image_embeds = [image_embed.to(target_dtype) for image_embed in image_embeds]

        image_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "encoder_hidden_states_image": image_embeds,
                "mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24)
            },
        )

        pos_cond_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "encoder_hidden_states_2": batch.clip_embedding_pos,
                "encoder_attention_mask": batch.prompt_attention_mask,
            },
        )

        neg_cond_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "encoder_hidden_states_2": batch.clip_embedding_neg,
                "encoder_attention_mask": batch.negative_attention_mask,
            },
        )

        action_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "mouse_cond": batch.mouse_cond,
                "keyboard_cond": batch.keyboard_cond,
                "c2ws_plucker_emb": batch.c2ws_plucker_emb,
            },
        )

        camera_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "camera_states": batch.camera_states,
            },
        )

        # Get latents and embeddings
        latents = batch.latents
        prompt_embeds = batch.prompt_embeds
        assert not torch.isnan(prompt_embeds[0]).any(), "prompt_embeds contains nan"
        if batch.do_classifier_free_guidance:
            neg_prompt_embeds = batch.negative_prompt_embeds
            assert neg_prompt_embeds is not None
            assert not torch.isnan(neg_prompt_embeds[0]).any(), "neg_prompt_embeds contains nan"

        # (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert
        boundary_ratio = fastvideo_args.pipeline_config.dit_config.boundary_ratio
        if batch.boundary_ratio is not None:
            logger.info("Overriding boundary ratio from %s to %s", boundary_ratio, batch.boundary_ratio)
            boundary_ratio = batch.boundary_ratio

        boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps if boundary_ratio is not None else None
        latent_model_input = latents.to(target_dtype)
        assert latent_model_input.shape[0] == 1, "only support batch size 1"

        if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
            # TI2V directly replaces the first frame of the latent with
            # the image latent instead of appending along the channel dim
            assert batch.image_latent is None, "TI2V task should not have image latents"
            assert self.vae is not None, "VAE is not provided for TI2V task"
            z = self.vae.encode(batch.pil_image).mean.float()
            if (hasattr(self.vae, "shift_factor") and self.vae.shift_factor is not None):
                if isinstance(self.vae.shift_factor, torch.Tensor):
                    z -= self.vae.shift_factor.to(z.device, z.dtype)
                else:
                    z -= self.vae.shift_factor

            if isinstance(self.vae.scaling_factor, torch.Tensor):
                z = z * self.vae.scaling_factor.to(z.device, z.dtype)
            else:
                z = z * self.vae.scaling_factor

            latent_model_input = latent_model_input.squeeze(0)
            _, mask2 = masks_like([latent_model_input], zero=True)

            latent_model_input = (1. - mask2[0]) * z + mask2[0] * latent_model_input
            # latent_model_input = latent_model_input.unsqueeze(0)
            latent_model_input = latent_model_input.to(get_local_torch_device())
            latents = latent_model_input
        F = batch.num_frames
        temporal_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_temporal
        spatial_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial
        patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size
        if not isinstance(patch_size, tuple):
            raise ValueError(f"Expected 3D patch_size tuple for denoising, got {patch_size!r}")
        seq_len = ((F - 1) // temporal_scale + 1) * (batch.height // spatial_scale) * (
            batch.width // spatial_scale) // (patch_size[1] * patch_size[2])

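        # Sanity check of the token count above with hypothetical shapes
        # (81 frames at 720x1280, temporal_scale=4, spatial_scale=8,
        # patch_size=(1, 2, 2) -- not necessarily this run's settings):
        #   latent frames: (81 - 1) // 4 + 1 = 21
        #   latent grid:   720 // 8 = 90 by 1280 // 8 = 160
        #   seq_len:       21 * 90 * 160 // (2 * 2) = 75600 tokens
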
        # Initialize lists for ODE trajectory
        trajectory_timesteps: list[torch.Tensor] = []
        trajectory_latents: list[torch.Tensor] = []

        # Run denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Skip if interrupted
                if hasattr(self, 'interrupt') and self.interrupt:
                    continue

                if boundary_timestep is None or t >= boundary_timestep:
                    if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
                            and self.transformer_2 is not None
                            and next(self.transformer_2.parameters()).device.type == 'cuda'):
                        self.transformer_2.to('cpu')
                    current_model = self.transformer
                    if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
                            and not fastvideo_args.use_fsdp_inference and current_model is not None):
                        transformer_device = next(current_model.parameters()).device.type
                        if transformer_device == 'cpu':
                            current_model.to(get_local_torch_device())
                    current_guidance_scale = batch.guidance_scale
                else:
                    # low-noise stage in wan2.2
                    if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
                            and next(self.transformer.parameters()).device.type == 'cuda'):
                        self.transformer.to('cpu')
                    current_model = self.transformer_2
                    if (fastvideo_args.dit_cpu_offload and not fastvideo_args.dit_layerwise_offload
                            and not fastvideo_args.use_fsdp_inference and current_model is not None):
                        transformer_2_device = next(current_model.parameters()).device.type
                        if transformer_2_device == 'cpu':
                            current_model.to(get_local_torch_device())
                    current_guidance_scale = batch.guidance_scale_2
                assert current_model is not None, "current_model is None"

                # Expand latents for V2V/I2V
                latent_model_input = latents.to(target_dtype)
                if batch.video_latent is not None:
                    latent_model_input = torch.cat([latent_model_input, batch.video_latent,
                                                    torch.zeros_like(latents)],
                                                   dim=1).to(target_dtype)
                elif batch.image_latent is not None:
                    assert not fastvideo_args.pipeline_config.ti2v_task, "image latents should not be provided for TI2V task"
                    latent_model_input = torch.cat([latent_model_input, batch.image_latent], dim=1).to(target_dtype)

                assert not torch.isnan(latent_model_input).any(), "latent_model_input contains nan"
                if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
                    timestep = torch.stack([t]).to(get_local_torch_device())
                    temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
                    temp_ts = torch.cat([temp_ts, temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep])
                    timestep = temp_ts.unsqueeze(0)
                    t_expand = timestep.repeat(latent_model_input.shape[0], 1)
                else:
                    t_expand = t.repeat(latent_model_input.shape[0])
                t_expand = t_expand.to(get_local_torch_device())

                use_meanflow = getattr(self.transformer.config, "use_meanflow", False)
                if use_meanflow:
                    if i == len(timesteps) - 1:
                        timesteps_r = torch.tensor([0.0], device=get_local_torch_device())
                    else:
                        timesteps_r = timesteps[i + 1]
                    timesteps_r = timesteps_r.repeat(latent_model_input.shape[0])
                else:
                    timesteps_r = None

                timesteps_r_kwarg = self.prepare_extra_func_kwargs(
                    self.transformer.forward,
                    {
                        "timestep_r": timesteps_r,
                    },
                )

                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Prepare inputs for transformer
                guidance_expand = (torch.tensor(
                    [fastvideo_args.pipeline_config.embedded_cfg_scale] * latent_model_input.shape[0],
                    dtype=torch.float32,
                    device=get_local_torch_device(),
                ).to(target_dtype) * 1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale is not None else None)

                # Predict noise residual
                with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
                    if (vsa_available and self.attn_backend == VideoSparseAttentionBackend) or \
                            (sparse_fp4_available and self.attn_backend in sparse_fp4_backends):
                        self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()

                        if self.attn_metadata_builder_cls is not None:
                            self.attn_metadata_builder = self.attn_metadata_builder_cls()
                            # TODO(will): clean this up
                            attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                                current_timestep=i,  # type: ignore
                                raw_latent_shape=batch.raw_latent_shape[2:5],  # type: ignore
                                patch_size=fastvideo_args.pipeline_config.  # type: ignore
                                dit_config.patch_size,  # type: ignore
                                VSA_sparsity=fastvideo_args.VSA_sparsity,  # type: ignore
                                device=get_local_torch_device(),
                            )
                            assert attn_metadata is not None, "attn_metadata cannot be None"
                        else:
                            attn_metadata = None
                    elif (vmoba_attn_available and self.attn_backend == VMOBAAttentionBackend):
                        self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()
                        if self.attn_metadata_builder_cls is not None:
                            self.attn_metadata_builder = self.attn_metadata_builder_cls()
                            # Prepare V-MoBA parameters from config
                            moba_params = fastvideo_args.moba_config.copy()
                            assert batch.raw_latent_shape is not None, "raw_latent_shape must be set for V-MoBA"
                            moba_params.update({
                                "current_timestep": i,
                                "raw_latent_shape": batch.raw_latent_shape[2:5],
                                "patch_size": fastvideo_args.pipeline_config.dit_config.patch_size,
                                "device": get_local_torch_device(),
                            })
                            attn_metadata = self.attn_metadata_builder.build(**moba_params)
                            assert attn_metadata is not None, "attn_metadata cannot be None"
                        else:
                            attn_metadata = None
                    else:
                        attn_metadata = None
                    # TODO(will): finalize the interface. vLLM uses this to
                    # support torch dynamo compilation. They pass in
                    # attn_metadata, vllm_config, and num_tokens. We can pass in
                    # fastvideo_args or training_args, and attn_metadata.
                    batch.is_cfg_negative = False
                    with set_forward_context(
                            current_timestep=i,
                            attn_metadata=attn_metadata,
                            forward_batch=batch,
                            # fastvideo_args=fastvideo_args
                    ):
                        # Run transformer
                        noise_pred = current_model(
                            latent_model_input,
                            prompt_embeds,
                            t_expand,
                            guidance=guidance_expand,
                            **image_kwargs,
                            **pos_cond_kwargs,
                            **action_kwargs,
                            **camera_kwargs,
                            **timesteps_r_kwarg,
                        )

                    if batch.do_classifier_free_guidance:
                        batch.is_cfg_negative = True
                        with set_forward_context(
                                current_timestep=i,
                                attn_metadata=attn_metadata,
                                forward_batch=batch,
                        ):
                            noise_pred_uncond = current_model(
                                latent_model_input,
                                neg_prompt_embeds,
                                t_expand,
                                guidance=guidance_expand,
                                **image_kwargs,
                                **neg_cond_kwargs,
                                **action_kwargs,
                                **camera_kwargs,
                                **timesteps_r_kwarg,
                            )

                        noise_pred_text = noise_pred
                        noise_pred = noise_pred_uncond + current_guidance_scale * (noise_pred_text - noise_pred_uncond)

                        # Apply guidance rescale if needed
                        if batch.guidance_rescale > 0.0:
                            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                            noise_pred = self.rescale_noise_cfg(
                                noise_pred,
                                noise_pred_text,
                                guidance_rescale=batch.guidance_rescale,
                            )

                # Compute the previous noisy sample
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
                if fastvideo_args.pipeline_config.ti2v_task and batch.pil_image is not None:
                    latents = latents.squeeze(0)
                    latents = (1. - mask2[0]) * z + mask2[0] * latents
                    # latents = latents.unsqueeze(0)

                # save trajectory latents if needed
                if batch.return_trajectory_latents:
                    trajectory_timesteps.append(t)
                    trajectory_latents.append(latents)

                # Update progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
                                               (i + 1) % self.scheduler.order == 0 and progress_bar is not None):
                    progress_bar.update()

        trajectory_tensor: torch.Tensor | None = None
        if trajectory_latents:
            trajectory_tensor = torch.stack(trajectory_latents, dim=1)
            trajectory_timesteps_tensor = torch.stack(trajectory_timesteps, dim=0)
        else:
            trajectory_tensor = None
            trajectory_timesteps_tensor = None

        if trajectory_tensor is not None and trajectory_timesteps_tensor is not None:
            batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu()
            batch.trajectory_latents = trajectory_tensor.cpu()

        # Update batch with final latents
        batch.latents = latents

        if fastvideo_args.dit_layerwise_offload:
            mgr = getattr(self.transformer, "_layerwise_offload_manager", None)
            if mgr is not None and getattr(mgr, "enabled", False):
                mgr.release_all()
            if self.transformer_2 is not None:
                mgr2 = getattr(self.transformer_2, "_layerwise_offload_manager", None)
                if mgr2 is not None and getattr(mgr2, "enabled", False):
                    mgr2.release_all()

        # deallocate transformer if on mps
        if torch.backends.mps.is_available():
            logger.info("Memory before deallocating transformer: %s", torch.mps.current_allocated_memory())
            del self.transformer
            if pipeline is not None and "transformer" in pipeline.modules:
                del pipeline.modules["transformer"]
            fastvideo_args.model_loaded["transformer"] = False
            logger.info("Memory after deallocating transformer: %s", torch.mps.current_allocated_memory())

        return batch

    def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]:
        """
        Prepare extra kwargs for the scheduler step / denoise step.

        Args:
            func: The function to prepare kwargs for.
            kwargs: The kwargs to prepare.

        Returns:
            The prepared kwargs.
        """
        extra_step_kwargs = {}
        for k, v in kwargs.items():
            accepts = k in set(inspect.signature(func).parameters.keys())
            if accepts:
                extra_step_kwargs[k] = v
        return extra_step_kwargs

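    # Illustration of the signature filtering above (hypothetical callable):
    # given `def step(self, model_output, timestep, sample, generator=None)`,
    # prepare_extra_func_kwargs(step, {"generator": g, "eta": 0.0}) returns
    # {"generator": g} -- "eta" is dropped because it is not a parameter of
    # step, so schedulers with different signatures can share one call site.
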
    def progress_bar(self, iterable: Iterable | None = None, total: int | None = None) -> tqdm:
        """
        Create a progress bar for the denoising process.

        Args:
            iterable: The iterable to iterate over.
            total: The total number of items.

        Returns:
            A tqdm progress bar.
        """
        local_rank = get_world_group().local_rank
        if local_rank == 0:
            return tqdm(iterable=iterable, total=total)
        else:
            return tqdm(iterable=iterable, total=total, disable=True)

    def rescale_noise_cfg(self, noise_cfg, noise_pred_text, guidance_rescale=0.0) -> torch.Tensor:
        """
        Rescale noise prediction according to guidance_rescale.

        Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed"
        (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.

        Args:
            noise_cfg: The noise prediction with guidance.
            noise_pred_text: The text-conditioned noise prediction.
            guidance_rescale: The guidance rescale factor.

        Returns:
            The rescaled noise prediction.
        """
        std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
        std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
        # Rescale the results from guidance (fixes overexposure)
        noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
        # Mix with the original results from guidance by factor guidance_rescale
        noise_cfg = (guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg)
        return noise_cfg

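    # Numeric sketch of the rescale above (illustrative values): with
    # guidance_rescale=0.7, std_text=1.0 and std_cfg=2.0, the guided output is
    # first scaled by 1.0 / 2.0 = 0.5, then mixed back as
    # 0.7 * 0.5 + (1 - 0.7) = 0.65 of the original noise_cfg, pulling the
    # CFG output's per-sample std back toward the text-conditioned branch.
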
    def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
        """Verify denoising stage inputs."""
        result = VerificationResult()
        result.add_check("timesteps", batch.timesteps, [V.is_tensor, V.min_dims(1)])
        result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
        result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
        result.add_check("image_embeds", batch.image_embeds, V.is_list)
        result.add_check("image_latent", batch.image_latent, V.none_or_tensor_with_dims(5))
        result.add_check("num_inference_steps", batch.num_inference_steps, V.positive_int)
        result.add_check("guidance_scale", batch.guidance_scale, V.positive_float)
        result.add_check("eta", batch.eta, V.non_negative_float)
        result.add_check("generator", batch.generator, V.generator_or_list_generators)
        result.add_check("do_classifier_free_guidance", batch.do_classifier_free_guidance, V.bool_value)
        result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
                         lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x))
        return result

    def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
        """Verify denoising stage outputs."""
        result = VerificationResult()
        result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
        return result


class CosmosDenoisingStage(DenoisingStage):
    """
    Denoising stage for Cosmos models using FlowMatchEulerDiscreteScheduler.
    """

    def __init__(self, transformer, scheduler, pipeline=None) -> None:
        super().__init__(transformer, scheduler, pipeline)

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        pipeline = self.pipeline() if self.pipeline else None
        if not fastvideo_args.model_loaded["transformer"]:
            loader = TransformerLoader()
            self.transformer = loader.load(fastvideo_args.model_paths["transformer"], fastvideo_args)
            if pipeline:
                pipeline.add_module("transformer", self.transformer)
            fastvideo_args.model_loaded["transformer"] = True

        extra_step_kwargs = self.prepare_extra_func_kwargs(
            self.scheduler.step,
            {
                "generator": batch.generator,
                "eta": batch.eta
            },
        )

        if hasattr(self.transformer, 'module'):
            transformer_dtype = next(self.transformer.module.parameters()).dtype
        else:
            transformer_dtype = next(self.transformer.parameters()).dtype
        target_dtype = transformer_dtype
        autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast

        latents = batch.latents
        num_inference_steps = batch.num_inference_steps
        guidance_scale = batch.guidance_scale

        sigma_max = 80.0
        sigma_min = 0.002
        sigma_data = 1.0
        final_sigmas_type = "sigma_min"

        if self.scheduler is not None:
            self.scheduler.register_to_config(
                sigma_max=sigma_max,
                sigma_min=sigma_min,
                sigma_data=sigma_data,
                final_sigmas_type=final_sigmas_type,
            )

        self.scheduler.set_timesteps(num_inference_steps, device=latents.device)
        timesteps = self.scheduler.timesteps

        if (hasattr(self.scheduler.config, 'final_sigmas_type')
                and self.scheduler.config.final_sigmas_type == "sigma_min" and len(self.scheduler.sigmas) > 1):
            self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]

        conditioning_latents = getattr(batch, 'conditioning_latents', None)
        unconditioning_latents = conditioning_latents

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if hasattr(self, 'interrupt') and self.interrupt:
                    continue

                current_sigma = self.scheduler.sigmas[i]
                current_t = current_sigma / (current_sigma + 1)
                c_in = 1 - current_t
                c_skip = 1 - current_t
                c_out = -current_t

                timestep = current_t.view(1, 1, 1, 1, 1).expand(latents.size(0), -1, latents.size(2), -1,
                                                                -1)  # [B, 1, T, 1, 1]

                with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):

                    cond_latent = latents * c_in

                    if hasattr(
                            batch,
                            'cond_indicator') and batch.cond_indicator is not None and conditioning_latents is not None:
                        cond_latent = batch.cond_indicator * conditioning_latents + (1 -
                                                                                     batch.cond_indicator) * cond_latent
                    else:
                        logger.warning(
                            "Step %s: Missing conditioning data - cond_indicator: %s, conditioning_latents: %s", i,
                            hasattr(batch, 'cond_indicator'), conditioning_latents is not None)

                    cond_latent = cond_latent.to(target_dtype)

                    cond_timestep = timestep
                    if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None:
                        sigma_conditioning = 0.0001
                        t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
                        cond_timestep = batch.cond_indicator * t_conditioning + (1 - batch.cond_indicator) * timestep
                        cond_timestep = cond_timestep.to(target_dtype)

                    with set_forward_context(
                            current_timestep=i,
                            attn_metadata=None,
                            forward_batch=batch,
                    ):
                        # Use conditioning masks from CosmosLatentPreparationStage
                        condition_mask = batch.cond_mask.to(target_dtype) if hasattr(batch, 'cond_mask') else None
                        padding_mask = torch.zeros(1,
                                                   1,
                                                   batch.height,
                                                   batch.width,
                                                   device=cond_latent.device,
                                                   dtype=target_dtype)

                        # Fallback if masks not available
                        if condition_mask is None:
                            batch_size, num_channels, num_frames, height, width = cond_latent.shape
                            condition_mask = torch.zeros(batch_size,
                                                         1,
                                                         num_frames,
                                                         height,
                                                         width,
                                                         device=cond_latent.device,
                                                         dtype=target_dtype)

                        noise_pred = self.transformer(
                            hidden_states=cond_latent,
                            timestep=cond_timestep.to(target_dtype),
                            encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
                            fps=24,  # TODO: get fps from batch or config
                            condition_mask=condition_mask,
                            padding_mask=padding_mask,
                            return_dict=False,
                        )[0]

                    cond_pred = (c_skip * latents + c_out * noise_pred.float()).to(target_dtype)

                    if hasattr(
                            batch,
                            'cond_indicator') and batch.cond_indicator is not None and conditioning_latents is not None:
                        cond_pred = batch.cond_indicator * conditioning_latents + (1 - batch.cond_indicator) * cond_pred

                    if batch.do_classifier_free_guidance and batch.negative_prompt_embeds is not None:
                        uncond_latent = latents * c_in

                        if hasattr(batch, 'uncond_indicator'
                                   ) and batch.uncond_indicator is not None and unconditioning_latents is not None:
                            uncond_latent = batch.uncond_indicator * unconditioning_latents + (
                                1 - batch.uncond_indicator) * uncond_latent

                        with set_forward_context(
                                current_timestep=i,
                                attn_metadata=None,
                                forward_batch=batch,
                        ):
                            uncond_condition_mask = batch.uncond_mask.to(target_dtype) if hasattr(
                                batch, 'uncond_mask') and batch.uncond_mask is not None else condition_mask

                            uncond_timestep = timestep
                            if hasattr(batch, 'uncond_indicator') and batch.uncond_indicator is not None:
                                sigma_conditioning = 0.0001
                                t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
                                uncond_timestep = batch.uncond_indicator * t_conditioning + (
                                    1 - batch.uncond_indicator) * timestep
                                uncond_timestep = uncond_timestep.to(target_dtype)

                            noise_pred_uncond = self.transformer(
                                hidden_states=uncond_latent.to(target_dtype),
                                timestep=uncond_timestep.to(target_dtype),
                                encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
                                fps=24,  # TODO: get fps from batch or config
                                condition_mask=uncond_condition_mask,
                                padding_mask=padding_mask,
                                return_dict=False,
                            )[0]

                            uncond_pred = (c_skip * latents + c_out * noise_pred_uncond.float()).to(target_dtype)

                            if hasattr(batch, 'uncond_indicator'
                                       ) and batch.uncond_indicator is not None and unconditioning_latents is not None:
                                uncond_pred = batch.uncond_indicator * unconditioning_latents + (
                                    1 - batch.uncond_indicator) * uncond_pred

                        guidance_diff = cond_pred - uncond_pred
                        final_pred = cond_pred + guidance_scale * guidance_diff
                    else:
                        final_pred = cond_pred

                # Convert to noise for scheduler step
                if current_sigma > 1e-8:
                    noise_for_scheduler = (latents - final_pred) / current_sigma
                else:
                    logger.warning("Step %s: current_sigma too small (%s), using final_pred directly", i, current_sigma)
                    noise_for_scheduler = final_pred

                if torch.isnan(noise_for_scheduler).sum() > 0:
                    logger.error("Step %s: NaN detected in noise_for_scheduler, sum: %s", i,
                                 noise_for_scheduler.float().sum().item())
                    logger.error("Step %s: latents sum: %s, final_pred sum: %s, current_sigma: %s", i,
                                 latents.float().sum().item(),
                                 final_pred.float().sum().item(), current_sigma)

                latents = self.scheduler.step(noise_for_scheduler, t, latents, **extra_step_kwargs,
                                              return_dict=False)[0]

                progress_bar.update()

        batch.latents = latents

        return batch

    def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
        """Verify Cosmos denoising stage inputs."""
        result = VerificationResult()
        result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
        result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty)
        result.add_check("num_inference_steps", batch.num_inference_steps, V.positive_int)
        result.add_check("guidance_scale", batch.guidance_scale, V.positive_float)
        result.add_check("do_classifier_free_guidance", batch.do_classifier_free_guidance, V.bool_value)
        result.add_check("negative_prompt_embeds", batch.negative_prompt_embeds,
                         lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x))
        return result

    def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
        """Verify Cosmos denoising stage outputs."""
        result = VerificationResult()
        result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)])
        return result

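# The Cosmos stage above applies an EDM-style preconditioning with
# t = sigma / (sigma + 1): the network input is scaled by c_in = 1 - t and the
# x0 estimate is recombined as cond_pred = c_skip * x_t + c_out * F(x), with
# c_skip = 1 - t and c_out = -t. For example, at sigma_max = 80,
# t = 80 / 81 ~= 0.9877, so c_in ~= 0.0123: the heavily noised latent is
# scaled down before entering the network. (These coefficients are read off
# the code above; no separate derivation ships with this snapshot.)
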
class Cosmos25DenoisingStage(CosmosDenoisingStage):
    """Denoising stage for Cosmos 2.5 DiT (expects 1D/2D timestep, not 5D)."""

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        pipeline = self.pipeline() if self.pipeline else None
        if not fastvideo_args.model_loaded["transformer"]:
            loader = TransformerLoader()
            self.transformer = loader.load(fastvideo_args.model_paths["transformer"], fastvideo_args)
            if pipeline:
                pipeline.add_module("transformer", self.transformer)
            fastvideo_args.model_loaded["transformer"] = True

        extra_step_kwargs = self.prepare_extra_func_kwargs(
            self.scheduler.step,
            {
                "generator": batch.generator,
                "eta": batch.eta
            },
        )

        if hasattr(self.transformer, 'module'):
            transformer_dtype = next(self.transformer.module.parameters()).dtype
        else:
            transformer_dtype = next(self.transformer.parameters()).dtype
        target_dtype = transformer_dtype
        autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast

        latents = batch.latents
        if latents is None:
            raise ValueError("latents must be provided for Cosmos25DenoisingStage")
        guidance_scale = batch.guidance_scale

        if batch.timesteps is None:
            self.scheduler.set_timesteps(batch.num_inference_steps, device=latents.device)
            timesteps = self.scheduler.timesteps
        else:
            timesteps = batch.timesteps.to(latents.device)

        cfg = fastvideo_args.pipeline_config

        if batch.fps is None:
            gen = batch.generator
            if isinstance(gen, list) and len(gen) > 0:
                gen = gen[0]
            fps_tensor = torch.randint(
                16,
                32,
                (1, ),
                generator=gen if isinstance(gen, torch.Generator) else None,
                device=latents.device,
            ).float().to(dtype=target_dtype)
        else:
            fps_val = batch.fps
            fps_tensor = torch.tensor(
                [fps_val],
                device=latents.device,
                dtype=target_dtype,
            )

        latents_4d = latents[0]

        # Masks are optional for T2W.
        cond_mask = getattr(batch, "cond_mask", None)
        condition_mask = cond_mask.to(target_dtype) if isinstance(cond_mask, torch.Tensor) else None
        pad_mask = getattr(batch, "padding_mask", None)
        padding_mask = pad_mask.to(target_dtype) if isinstance(pad_mask, torch.Tensor) else None

        # Conditioning fields are attached by latent preparation stage.
        conditioning_latents = getattr(batch, "conditioning_latents", None)
        cond_indicator = getattr(batch, "cond_indicator", None)
        # Infer whether this is a conditioned run (V2W/I2W) purely from the presence
        # of conditioning latents. Avoid carrying explicit mode flags on the batch.
        is_conditioned = (conditioning_latents is not None)

        init_noise_4d = latents_4d.clone()
        if condition_mask is None:
            _, t, h, w = latents_4d.shape
            condition_mask = torch.zeros(1, 1, t, h, w, device=latents.device, dtype=target_dtype)
        if padding_mask is None:
            _, _, h, w = latents_4d.shape
            padding_default = 0.0 if is_conditioned else 1.0
            padding_mask = torch.full(
                (1, 1, h, w),
                float(padding_default),
                device=latents.device,
                dtype=target_dtype,
            )

        timestep_scale = 0.001

        state_dtype = torch.float32

        conditional_frame_timestep = 0.1
        latents_4d = latents_4d.to(state_dtype)
        init_noise_4d = init_noise_4d.to(state_dtype)

        clamp_every_step = bool(getattr(cfg, "cosmos25_clamp_every_step", True)) if is_conditioned else False

        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for i, t in enumerate(timesteps):
                t_val = float(t)
                if is_conditioned:
                    t_frames = int(latents_4d.shape[1])
                    timestep = torch.full(
                        (1, t_frames),
                        float(t_val * timestep_scale),
                        device=latents.device,
                        dtype=torch.float32,
                    )
                    if cond_indicator is not None and t_frames > 0:
                        cond_t = cond_indicator[0, 0, :t_frames, 0, 0]
                        cond_mask_t = (cond_t > 0.5)
                        if bool(cond_mask_t.any().item()):
                            timestep[0, cond_mask_t] = float(conditional_frame_timestep)
                else:
                    timestep_val = t_val * timestep_scale
                    timestep = torch.tensor(
                        [[float(timestep_val)]],
                        device=latents.device,
                        dtype=target_dtype,
                    )

                # Conditioned runs: replace x_t with GT x0 on the conditioned frames.
                if (is_conditioned and cond_indicator is not None and conditioning_latents is not None
                        and (clamp_every_step or i == 0)):
                    cond_ind_4d = cond_indicator[0].to(state_dtype)
                    gt_x0 = conditioning_latents[0].to(state_dtype)
                    latents_4d = gt_x0 * cond_ind_4d + latents_4d * (1 - cond_ind_4d)

                model_hidden_states = latents_4d.unsqueeze(0)

                with (
                        set_forward_context(current_timestep=int(t_val), attn_metadata=None, forward_batch=batch),
                        torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled),
                ):
                    cond_v = self.transformer(
                        hidden_states=model_hidden_states.to(target_dtype),
                        encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
                        timestep=timestep,
                        fps=fps_tensor,
                        condition_mask=condition_mask,
                        padding_mask=padding_mask,
                        return_dict=False,
                    )[0]

                    if batch.do_classifier_free_guidance and batch.negative_prompt_embeds:
                        uncond_v = self.transformer(
                            hidden_states=model_hidden_states.to(target_dtype),
                            encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
                            timestep=timestep,
                            fps=fps_tensor,
                            condition_mask=condition_mask,
                            padding_mask=padding_mask,
                            return_dict=False,
                        )[0]
                        if is_conditioned:
                            v = cond_v + guidance_scale * (cond_v - uncond_v)
                        else:
                            v = uncond_v + guidance_scale * (cond_v - uncond_v)
                    else:
                        v = cond_v

                # Conditioned runs: replace velocity on conditioned frames with GT velocity.
                if (is_conditioned and cond_indicator is not None and conditioning_latents is not None):
                    cond_ind_4d = cond_indicator[0].to(state_dtype)
                    gt_x0 = conditioning_latents[0].to(state_dtype)
                    gt_v = init_noise_4d.to(state_dtype) - gt_x0
                    v = cond_ind_4d * gt_v + (1 - cond_ind_4d) * v.to(state_dtype)

                prev = self.scheduler.step(v.unsqueeze(0),
                                           t,
                                           latents_4d.unsqueeze(0),
                                           **extra_step_kwargs,
                                           return_dict=False)[0]
                latents_4d = prev.squeeze(0)

                progress_bar.update()

        batch.latents = latents_4d.to(target_dtype).unsqueeze(0)
        return batch


class Cosmos25T2WDenoisingStage(Cosmos25DenoisingStage):
    """Cosmos 2.5 Text2World denoising stage."""

    _CONDITIONING_FIELDS = (
        "conditioning_latents",
        "cond_indicator",
        "uncond_indicator",
    )

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        for name in self._CONDITIONING_FIELDS:
            if hasattr(batch, name):
                setattr(batch, name, None)
        return super().forward(batch, fastvideo_args)


class Cosmos25V2WDenoisingStage(Cosmos25DenoisingStage):
    """Cosmos 2.5 Video2World denoising stage."""

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        return super().forward(batch, fastvideo_args)


class Cosmos25AutoDenoisingStage(PipelineStage):
    """Route Cosmos 2.5 denoising to T2W vs V2W/I2W."""

    def __init__(self, transformer, scheduler) -> None:
        super().__init__()
        self._t2w = Cosmos25T2WDenoisingStage(transformer=transformer, scheduler=scheduler)
        self._v2w = Cosmos25V2WDenoisingStage(transformer=transformer, scheduler=scheduler)

    def pipeline(self):
        return self._v2w.pipeline() if self._v2w.pipeline else None

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        conditioning_latents = getattr(batch, "conditioning_latents", None)
        if conditioning_latents is not None:
            return self._v2w.forward(batch, fastvideo_args)
        return self._t2w.forward(batch, fastvideo_args)

    def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
        conditioning_latents = getattr(batch, "conditioning_latents", None)
        if conditioning_latents is not None:
            return self._v2w.verify_input(batch, fastvideo_args)
        return self._t2w.verify_input(batch, fastvideo_args)

    def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:
        conditioning_latents = getattr(batch, "conditioning_latents", None)
        if conditioning_latents is not None:
            return self._v2w.verify_output(batch, fastvideo_args)
        return self._t2w.verify_output(batch, fastvideo_args)

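# Routing sketch for the auto stage above (hypothetical wiring; the actual
# pipeline assembly lives elsewhere in this snapshot):
#
#     stage = Cosmos25AutoDenoisingStage(transformer, scheduler)
#     batch.conditioning_latents = None    # -> T2W path
#     batch = stage.forward(batch, fastvideo_args)
#     batch.conditioning_latents = cond    # -> V2W / I2W path
#     batch = stage.forward(batch, fastvideo_args)
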
class DmdDenoisingStage(DenoisingStage):
    """
    Denoising stage for DMD.
    """

    def __init__(self, transformer, scheduler) -> None:
        super().__init__(transformer, scheduler)
        self.scheduler = FlowMatchEulerDiscreteScheduler(shift=8.0)

    def forward(
        self,
        batch: ForwardBatch,
        fastvideo_args: FastVideoArgs,
    ) -> ForwardBatch:
        """
        Run the denoising loop.

        Args:
            batch: The current batch information.
            fastvideo_args: The inference arguments.

        Returns:
            The batch with denoised latents.
        """
        # Setup precision and autocast settings
        # TODO(will): make the precision configurable for inference
        # target_dtype = PRECISION_TO_TYPE[fastvideo_args.precision]
        target_dtype = torch.bfloat16
        autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast

        # Get timesteps and calculate warmup steps
        timesteps = batch.timesteps

        # TODO(will): remove this once we add input/output validation for stages
        if timesteps is None:
            raise ValueError("Timesteps must be provided")
        num_inference_steps = batch.num_inference_steps
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # Prepare image latents and embeddings for I2V generation
        image_embeds = batch.image_embeds
        if len(image_embeds) > 0:
            assert torch.isnan(image_embeds[0]).sum() == 0
            image_embeds = [image_embed.to(target_dtype) for image_embed in image_embeds]

        image_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "encoder_hidden_states_image": image_embeds,
                "mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24)
            },
        )

        pos_cond_kwargs = self.prepare_extra_func_kwargs(
            self.transformer.forward,
            {
                "encoder_hidden_states_2": batch.clip_embedding_pos,
                "encoder_attention_mask": batch.prompt_attention_mask,
            },
        )

        # Get latents and embeddings
        assert batch.latents is not None, "latents must be provided"
        latents = batch.latents

        video_raw_latent_shape = latents.shape
        prompt_embeds = batch.prompt_embeds
        assert not torch.isnan(prompt_embeds[0]).any(), "prompt_embeds contains nan"
        timesteps = torch.tensor(fastvideo_args.pipeline_config.dmd_denoising_steps,
                                 dtype=torch.long,
                                 device=get_local_torch_device())

        # Run denoising loop
        with self.progress_bar(total=len(timesteps)) as progress_bar:
            for i, t in enumerate(timesteps):
                # Skip if interrupted
                if hasattr(self, 'interrupt') and self.interrupt:
                    continue
                # Expand latents for I2V
                noise_latents = latents.clone()
                latent_model_input = latents.to(target_dtype)

                if batch.image_latent is not None:
                    latent_model_input = torch.cat(
                        [latent_model_input, batch.image_latent.permute(0, 2, 1, 3, 4)], dim=2).to(target_dtype)
                assert not torch.isnan(latent_model_input).any(), "latent_model_input contains nan"

                # Prepare inputs for transformer
                t_expand = t.repeat(latent_model_input.shape[0])
                guidance_expand = (torch.tensor(
                    [fastvideo_args.pipeline_config.embedded_cfg_scale] * latent_model_input.shape[0],
                    dtype=torch.float32,
                    device=get_local_torch_device(),
                ).to(target_dtype) * 1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale is not None else None)

                # Predict noise residual
                with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
                    if (vsa_available and self.attn_backend == VideoSparseAttentionBackend) or \
                            (sparse_fp4_available and self.attn_backend in sparse_fp4_backends):
                        self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls()

                        if self.attn_metadata_builder_cls is not None:
                            self.attn_metadata_builder = self.attn_metadata_builder_cls()
                            # TODO(will): clean this up
                            attn_metadata = self.attn_metadata_builder.build(  # type: ignore
                                current_timestep=i,  # type: ignore
                                raw_latent_shape=batch.raw_latent_shape[2:5],  # type: ignore
                                patch_size=fastvideo_args.pipeline_config.  # type: ignore
                                dit_config.patch_size,  # type: ignore
                                VSA_sparsity=fastvideo_args.VSA_sparsity,  # type: ignore
                                device=get_local_torch_device(),  # type: ignore
                            )  # type: ignore
                            assert attn_metadata is not None, "attn_metadata cannot be None"
                        else:
                            attn_metadata = None
                    else:
                        attn_metadata = None

                    batch.is_cfg_negative = False
                    with set_forward_context(
                            current_timestep=i,
                            attn_metadata=attn_metadata,
                            forward_batch=batch,
                            # fastvideo_args=fastvideo_args
                    ):
                        # Run transformer
                        pred_noise = self.transformer(
                            latent_model_input.permute(0, 2, 1, 3, 4),
|
| 1152 |
+
prompt_embeds,
|
| 1153 |
+
t_expand,
|
| 1154 |
+
guidance=guidance_expand,
|
| 1155 |
+
**image_kwargs,
|
| 1156 |
+
**pos_cond_kwargs,
|
| 1157 |
+
).permute(0, 2, 1, 3, 4)
|
| 1158 |
+
|
| 1159 |
+
pred_video = pred_noise_to_pred_video(pred_noise=pred_noise.flatten(0, 1),
|
| 1160 |
+
noise_input_latent=noise_latents.flatten(0, 1),
|
| 1161 |
+
timestep=t_expand,
|
| 1162 |
+
scheduler=self.scheduler).unflatten(0, pred_noise.shape[:2])
|
| 1163 |
+
|
| 1164 |
+
if i < len(timesteps) - 1:
|
| 1165 |
+
next_timestep = timesteps[i + 1] * torch.ones([1], dtype=torch.long, device=pred_video.device)
|
| 1166 |
+
noise_generator = batch.generator[0] if isinstance(batch.generator, list) else batch.generator
|
| 1167 |
+
noise = torch.randn(video_raw_latent_shape, dtype=pred_video.dtype,
|
| 1168 |
+
generator=noise_generator).to(self.device)
|
| 1169 |
+
latents = self.scheduler.add_noise(pred_video.flatten(0, 1), noise.flatten(0, 1),
|
| 1170 |
+
next_timestep).unflatten(0, pred_video.shape[:2])
|
| 1171 |
+
else:
|
| 1172 |
+
latents = pred_video
|
| 1173 |
+
|
| 1174 |
+
# Update progress bar
|
| 1175 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
|
| 1176 |
+
(i + 1) % self.scheduler.order == 0 and progress_bar is not None):
|
| 1177 |
+
progress_bar.update()
|
| 1178 |
+
|
| 1179 |
+
# Gather results if using sequence parallelism
|
| 1180 |
+
latents = latents.permute(0, 2, 1, 3, 4)
|
| 1181 |
+
# Update batch with final latents
|
| 1182 |
+
batch.latents = latents
|
| 1183 |
+
|
| 1184 |
+
return batch
|
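Note on the DMD update above: each step converts the model's prediction into a clean-video estimate (`pred_noise_to_pred_video`) and, unless it is the last step, re-noises that estimate to the next timestep with `scheduler.add_noise`. A minimal sketch of the two operations, assuming the standard flow-matching convention x_t = (1 - sigma_t) * x_0 + sigma_t * eps with the model predicting the velocity v = eps - x_0; the helper names are illustrative, not the snapshot's actual implementation:

import torch

def pred_to_video_sketch(v: torch.Tensor, x_t: torch.Tensor,
                         sigma_t: torch.Tensor) -> torch.Tensor:
    # Invert x_t = (1 - sigma_t) * x_0 + sigma_t * eps given v = eps - x_0:
    # x_t - sigma_t * v = (1 - sigma_t) * x_0 + sigma_t * x_0 = x_0.
    return x_t - sigma_t * v

def renoise_sketch(x_0: torch.Tensor, eps: torch.Tensor,
                   sigma_next: torch.Tensor) -> torch.Tensor:
    # Mirror of scheduler.add_noise: push the clean estimate back to the
    # next (lower) noise level before the following denoising step.
    return (1.0 - sigma_next) * x_0 + sigma_next * eps
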
backend_snapshot/fastvideo/platforms/cuda.py
ADDED
@@ -0,0 +1,440 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/cuda.py
"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

import os
from collections.abc import Callable
from functools import lru_cache, wraps
from typing import TypeVar

import torch
from typing_extensions import ParamSpec

import fastvideo.envs as envs
from fastvideo.logger import init_logger
from fastvideo.platforms.interface import (AttentionBackendEnum, DeviceCapability, Platform, PlatformEnum)
from fastvideo.utils import import_pynvml

logger = init_logger(__name__)

_P = ParamSpec("_P")
_R = TypeVar("_R")

pynvml = import_pynvml()  # type: ignore[no-untyped-call]

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)


def device_id_to_physical_device_id(device_id: int) -> int:
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        if device_ids == [""]:
            msg = ("CUDA_VISIBLE_DEVICES is set to empty string, which means"
                   " GPU support is disabled. If you are using ray, please unset"
                   " the environment variable `CUDA_VISIBLE_DEVICES` inside the"
                   " worker/actor. "
                   "Check https://github.com/vllm-project/vllm/issues/8402 for"
                   " more information.")
            raise RuntimeError(msg)
        physical_device_id = device_ids[device_id]
        return int(physical_device_id)
    else:
        return device_id


def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper


class CudaPlatformBase(Platform):
    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
        if enforce_eager:
            logger.warning("To see benefits of async output processing, enable CUDA "
                           "graph. Since, enforce-eager is enabled, async output "
                           "processor cannot be used")
            return False
        return True

    @classmethod
    def is_full_nvlink(cls, device_ids: list[int]) -> bool:
        raise NotImplementedError

    @classmethod
    def log_warnings(cls) -> None:
        pass

    @classmethod
    def get_current_memory_usage(cls, device: torch.types.Device | None = None) -> float:
        torch.cuda.reset_peak_memory_stats(device)
        return float(torch.cuda.max_memory_allocated(device))

    @classmethod
    def get_torch_device(cls) -> object:
        """
        Return torch.cuda
        """
        return torch.cuda

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None, head_size: int,
                             dtype: torch.dtype) -> str:
        # TODO(will): maybe come up with a more general interface for local attention
        # if distributed is False, we always try to use Flash attn

        logger.info("Trying FASTVIDEO_ATTENTION_BACKEND=%s", envs.FASTVIDEO_ATTENTION_BACKEND)
        logger.info("Selected backend: %s", selected_backend)
        if selected_backend == AttentionBackendEnum.SAGE_ATTN:
            try:
                from sageattention import sageattn  # noqa: F401

                from fastvideo.attention.backends.sage_attn import (  # noqa: F401
                    SageAttentionBackend)
                logger.info("Using Sage Attention backend.")

                return "fastvideo.attention.backends.sage_attn.SageAttentionBackend"
            except ImportError as e:
                logger.info(e)
                logger.info("Sage Attention backend is not installed. Fall back to Flash Attention.")
        elif selected_backend == AttentionBackendEnum.SAGE_ATTN_THREE:
            try:
                from sageattn3 import sageattn3_blackwell  # noqa: F401

                from fastvideo.attention.backends.sage_attn3 import (  # noqa: F401
                    SageAttention3Backend)
                logger.info("Using Sage Attention 3 backend.")

                return "fastvideo.attention.backends.sage_attn3.SageAttention3Backend"
            except ImportError as e:
                logger.info(e)
                logger.info("Sage Attention 3 backend is not installed. Fall back to Flash Attention.")
        elif selected_backend == AttentionBackendEnum.ATTN_QAT_INFER:
            try:
                from fastvideo.attention.backends.attn_qat_infer import (  # noqa: F401
                    AttnQatInferBackend, is_attn_qat_infer_available,
                )
                if not is_attn_qat_infer_available():
                    raise ImportError("attn_qat_infer could not be imported.")
                logger.info("Using attn_qat_infer backend.")

                return "fastvideo.attention.backends.attn_qat_infer.AttnQatInferBackend"
            except ImportError as e:
                logger.info(e)
                logger.info("attn_qat_infer backend is not installed. Fall back to Flash Attention.")
        elif selected_backend == AttentionBackendEnum.ATTN_QAT_TRAIN:
            try:
                from fastvideo_kernel.triton_kernels.attn_qat_train import attention  # noqa: F401

                from fastvideo.attention.backends.attn_qat_train import (  # noqa: F401
                    AttnQatTrainBackend)
                logger.info("Using attn_qat_train backend.")

                return "fastvideo.attention.backends.attn_qat_train.AttnQatTrainBackend"
            except ImportError as e:
                logger.info(e)
                logger.info("attn_qat_train backend is not installed. Fall back to Flash Attention.")
        elif selected_backend == AttentionBackendEnum.VIDEO_SPARSE_ATTN:
            try:
                from fastvideo_kernel import video_sparse_attn  # noqa: F401

                from fastvideo.attention.backends.video_sparse_attn import (  # noqa: F401
                    VideoSparseAttentionBackend)
                logger.info("Using Video Sparse Attention backend.")

                return "fastvideo.attention.backends.video_sparse_attn.VideoSparseAttentionBackend"
            except ImportError as e:
                logger.error("Failed to import Video Sparse Attention backend: %s", str(e))
                raise ImportError("The Video Sparse Attention backend is not installed. "
                                  "To install it, please follow the instructions at: "
                                  "https://hao-ai-lab.github.io/FastVideo/video_sparse_attention/installation ") from e
        elif selected_backend == AttentionBackendEnum.SPARSE_FP4_ATTN:
            try:
                from fastvideo.attention.backends.sparse_fp4_attn import (  # noqa: F401
                    SparseFP4AttentionBackend)
                logger.info("Using Sparse FP4 Attention backend (FP4 quant + VSA).")
                return "fastvideo.attention.backends.sparse_fp4_attn.SparseFP4AttentionBackend"
            except ImportError as e:
                logger.error("Failed to import Sparse FP4 Attention backend: %s", str(e))
                raise ImportError("Sparse FP4 Attention backend is not available.") from e
        elif selected_backend == AttentionBackendEnum.SPARSE_FP4_OURS_P_ATTN:
            try:
                from fastvideo.attention.backends.sparse_fp4_ours_p_attn import (  # noqa: F401
                    SparseFP4OursPAttentionBackend)
                logger.info(
                    "Using Sparse FP4 Ours-P Attention backend (group-local P quant + VSA)."
                )
                return "fastvideo.attention.backends.sparse_fp4_ours_p_attn.SparseFP4OursPAttentionBackend"
            except ImportError as e:
                logger.error("Failed to import Sparse FP4 Ours-P Attention backend: %s", str(e))
                raise ImportError("Sparse FP4 Ours-P Attention backend is not available.") from e
        elif selected_backend == AttentionBackendEnum.BSA_ATTN:
            try:
                from fastvideo.attention.backends.bsa_attn import (  # noqa: F401
                    BSAAttentionBackend)
                logger.info("Using BSA Attention backend.")

                return "fastvideo.attention.backends.bsa_attn.BSAAttentionBackend"
            except ImportError as e:
                logger.error("Failed to import BSA Attention backend: %s", str(e))
                raise ImportError("The BSA Attention backend failed to import.") from e
        elif selected_backend == AttentionBackendEnum.VMOBA_ATTN:
            try:
                from fastvideo_kernel import moba_attn_varlen  # noqa: F401
                from fastvideo.attention.backends.vmoba import (  # noqa: F401
                    VMOBAAttentionBackend)
                logger.info("Using Video MOBA Attention backend.")

                return "fastvideo.attention.backends.vmoba.VMOBAAttentionBackend"
            except ImportError as e:
                logger.error("Failed to import Video MoBA Attention backend: %s", str(e))
                raise ImportError("Video MoBA Attention backend is not installed. ") from e
        elif selected_backend == AttentionBackendEnum.SLA_ATTN:
            try:
                from fastvideo.attention.backends.sla import (  # noqa: F401
                    SLAAttentionBackend)
                logger.info("Using SLA (Sparse-Linear Attention) backend.")

                return "fastvideo.attention.backends.sla.SLAAttentionBackend"
            except ImportError as e:
                logger.error("Failed to import SLA Attention backend: %s", str(e))
                raise ImportError("SLA Attention backend is not available. ") from e
        elif selected_backend == AttentionBackendEnum.SAGE_SLA_ATTN:
            try:
                from fastvideo.attention.backends.sla import (  # noqa: F401
                    SageSLAAttentionBackend)
                logger.info("Using SageSLA (Quantized Sparse-Linear Attention) backend.")

                return "fastvideo.attention.backends.sla.SageSLAAttentionBackend"
            except ImportError as e:
                logger.error("Failed to import SageSLA Attention backend: %s", str(e))
                raise ImportError("SageSLA Attention backend requires spas_sage_attn. "
                                  "Install with: pip install git+https://github.com/thu-ml/SpargeAttn.git") from e
        elif selected_backend == AttentionBackendEnum.TORCH_SDPA:
            logger.info("Using Torch SDPA backend.")
            return "fastvideo.attention.backends.sdpa.SDPABackend"
        elif selected_backend == AttentionBackendEnum.FLASH_ATTN or selected_backend is None:
            pass
        elif selected_backend:
            raise ValueError(f"Invalid attention backend for {cls.device_name}")

        target_backend = AttentionBackendEnum.FLASH_ATTN
        if not cls.has_device_capability(80):
            logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
                        "GPUs.")
            target_backend = AttentionBackendEnum.TORCH_SDPA
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info("Cannot use FlashAttention-2 backend for dtype other than "
                        "torch.float16 or torch.bfloat16.")
            target_backend = AttentionBackendEnum.TORCH_SDPA

        # FlashAttn is valid for the model, checking if the package is
        # installed.
        if target_backend == AttentionBackendEnum.FLASH_ATTN:
            try:
                import flash_attn  # noqa: F401

                from fastvideo.attention.backends.flash_attn import (  # noqa: F401
                    FlashAttentionBackend)

                supported_sizes = \
                    FlashAttentionBackend.get_supported_head_sizes()
                if head_size not in supported_sizes:
                    logger.info("Cannot use FlashAttention-2 backend for head size %d.", head_size)
                    target_backend = AttentionBackendEnum.TORCH_SDPA
            except ImportError:
                logger.info("Cannot use FlashAttention-2 backend because the "
                            "flash_attn package is not found. "
                            "Make sure that flash_attn was built and installed "
                            "(on by default).")
                target_backend = AttentionBackendEnum.TORCH_SDPA

        if target_backend == AttentionBackendEnum.TORCH_SDPA:
            logger.info("Using Torch SDPA backend.")

            return "fastvideo.attention.backends.sdpa.SDPABackend"

        logger.info("Using Flash Attention backend.")

        return "fastvideo.attention.backends.flash_attn.FlashAttentionBackend"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "fastvideo.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA
class NvmlCudaPlatform(CudaPlatformBase):

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        try:
            physical_device_id = device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        try:
            return bool(super().has_device_capability(capability, device_id))
        except RuntimeError:
            return False

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        physical_device_id = device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_uuid(cls, device_id: int = 0) -> str:
        physical_device_id = device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return str(pynvml.nvmlDeviceGetUUID(handle))

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        physical_device_id = device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_full_nvlink(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                            handle,
                            peer_handle,
                            pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                        )
                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                            return False
                    except pynvml.NVMLError:
                        logger.exception("NVLink detection failed. This is normal if"
                                         " your machine has no NVLink equipped.")
                        return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return str(pynvml.nvmlDeviceGetName(handle))

    @classmethod
    @with_nvml_context
    def log_warnings(cls) -> None:
        device_ids: int = pynvml.nvmlDeviceGetCount()
        if device_ids > 1:
            device_names = [cls._get_physical_device_name(i) for i in range(device_ids)]
            if (len(set(device_names)) > 1 and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
                logger.warning(
                    "Detected different devices in the system: %s. Please"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    ", ".join(device_names),
                )


class NonNvmlCudaPlatform(CudaPlatformBase):

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return str(torch.cuda.get_device_name(device_id))

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        device_props = torch.cuda.get_device_properties(device_id)
        return int(device_props.total_memory)

    @classmethod
    def is_full_nvlink(cls, physical_device_ids: list[int]) -> bool:
        logger.exception("NVLink detection not possible, as context support was"
                         " not found. Assuming no NVLink available.")
        return False


# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
nvml_available = False
try:
    try:
        pynvml.nvmlInit()
        nvml_available = True
    except Exception:
        # On Jetson, NVML is not supported.
        nvml_available = False
finally:
    if nvml_available:
        pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

try:
    from sphinx.ext.autodoc.mock import _MockModule

    if not isinstance(pynvml, _MockModule):
        CudaPlatform.log_warnings()
except ModuleNotFoundError:
    CudaPlatform.log_warnings()
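`get_attn_backend_cls` deliberately returns a dotted class path (a string) rather than a class object, so heavy backend modules are imported only after a backend has been chosen and validated. A minimal sketch of how a caller can turn such a path back into a class; `resolve_backend_cls` is an illustrative helper, not part of this snapshot:

import importlib

def resolve_backend_cls(dotted_path: str) -> type:
    # "pkg.module.ClassName" -> import pkg.module lazily, then fetch ClassName.
    module_path, _, cls_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, cls_name)

# Example: resolve the Torch SDPA fallback path returned above.
# backend = resolve_backend_cls("fastvideo.attention.backends.sdpa.SDPABackend")
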
backend_snapshot/fastvideo/platforms/interface.py
ADDED
@@ -0,0 +1,255 @@
import enum
import random
from typing import Any, NamedTuple

import numpy as np
import torch

from fastvideo.logger import init_logger

logger = init_logger(__name__)


class AttentionBackendEnum(enum.Enum):
    FLASH_ATTN = enum.auto()
    TORCH_SDPA = enum.auto()
    SAGE_ATTN = enum.auto()
    SAGE_ATTN_THREE = enum.auto()
    ATTN_QAT_INFER = enum.auto()
    ATTN_QAT_TRAIN = enum.auto()
    VIDEO_SPARSE_ATTN = enum.auto()
    BSA_ATTN = enum.auto()
    VMOBA_ATTN = enum.auto()
    SLA_ATTN = enum.auto()
    SAGE_SLA_ATTN = enum.auto()
    SPARSE_FP4_ATTN = enum.auto()
    SPARSE_FP4_OURS_P_ATTN = enum.auto()
    NO_ATTENTION = enum.auto()


class PlatformEnum(enum.Enum):
    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    XPU = enum.auto()
    CPU = enum.auto()
    MPS = enum.auto()
    OOT = enum.auto()
    UNSPECIFIED = enum.auto()
    NPU = enum.auto()


class CpuArchEnum(enum.Enum):
    X86 = enum.auto()
    ARM = enum.auto()
    UNSPECIFIED = enum.auto()


class DeviceCapability(NamedTuple):
    major: int
    minor: int

    def as_version_str(self) -> str:
        return f"{self.major}.{self.minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer ``<major><minor>``.

        It is assumed that the minor version is always a single digit.
        """
        assert 0 <= self.minor < 10
        return self.major * 10 + self.minor


class Platform:
    _enum: PlatformEnum
    device_name: str
    device_type: str

    dispatch_key: str = "CPU"

    # platform-agnostic way to specify the device control environment variable,
    # .e.g. CUDA_VISIBLE_DEVICES for CUDA.
    # hint: search for "get_visible_accelerator_ids_env_var" in
    # https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators  # noqa
    device_control_env_var: str = "FASTVIDEO_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"

    # available ray device keys:
    # https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438  # noqa
    # empty string means the device does not support ray
    ray_device_key: str = ""
    # The torch.compile backend for compiling simple and
    # standalone functions. The default value is "inductor" to keep
    # the same behavior as PyTorch.
    # NOTE: for the forward part of the model, vLLM has another separate
    # compilation strategy.
    simple_compile_backend: str = "inductor"

    supported_quantization: list[str] = []

    additional_env_vars: list[str] = []

    def is_cuda(self) -> bool:
        return self._enum == PlatformEnum.CUDA

    def is_rocm(self) -> bool:
        return self._enum == PlatformEnum.ROCM

    def is_tpu(self) -> bool:
        return self._enum == PlatformEnum.TPU

    def is_xpu(self) -> bool:
        return self._enum == PlatformEnum.XPU

    def is_cpu(self) -> bool:
        return self._enum == PlatformEnum.CPU

    def is_out_of_tree(self) -> bool:
        return self._enum == PlatformEnum.OOT

    def is_cuda_alike(self) -> bool:
        """Stateless version of :func:`torch.cuda.is_available`."""
        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

    def is_mps(self) -> bool:
        return self._enum == PlatformEnum.MPS

    def is_npu(self) -> bool:
        return self._enum == PlatformEnum.NPU

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None, head_size: int,
                             dtype: torch.dtype) -> str:
        """Get the attention backend class of a device."""
        return ""

    @classmethod
    def get_device_capability(
        cls,
        device_id: int = 0,
    ) -> DeviceCapability | None:
        """Stateless version of :func:`torch.cuda.get_device_capability`."""
        return None

    @classmethod
    def has_device_capability(
        cls,
        capability: tuple[int, int] | int,
        device_id: int = 0,
    ) -> bool:
        """
        Test whether this platform is compatible with a device capability.

        The ``capability`` argument can either be:

        - A tuple ``(major, minor)``.
        - An integer ``<major><minor>``. (See :meth:`DeviceCapability.to_int`)
        """
        current_capability = cls.get_device_capability(device_id=device_id)
        if current_capability is None:
            return False

        if isinstance(capability, tuple):
            return current_capability >= capability

        return current_capability.to_int() >= capability

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
        raise NotImplementedError

    @classmethod
    def get_device_uuid(cls, device_id: int = 0) -> str:
        """Get the uuid of a device, e.g. the PCI bus ID."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Get the total memory of a device in bytes."""
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
        """
        Check if the current platform supports async output.
        """
        raise NotImplementedError

    @classmethod
    def get_torch_device(cls) -> Any:
        """
        Check if the current platform supports torch device.
        """
        raise NotImplementedError

    @classmethod
    def inference_mode(cls):
        """A device-specific wrapper of `torch.inference_mode`.

        This wrapper is recommended because some hardware backends such as TPU
        do not support `torch.inference_mode`. In such a case, they will fall
        back to `torch.no_grad` by overriding this method.
        """
        return torch.inference_mode(mode=True)

    @classmethod
    def seed_everything(cls, seed: int | None = None) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        """
        Verify whether the current platform supports the specified model
        architecture.

        - This will raise an Error or Warning based on the model support on
          the current platform.
        - By default all models are considered supported.
        """
        pass

    @classmethod
    def verify_quantization(cls, quant: str) -> None:
        """
        Verify whether the quantization is supported by the current platform.
        """
        if cls.supported_quantization and \
                quant not in cls.supported_quantization:
            raise ValueError(f"{quant} quantization is currently not supported in "
                             f"{cls.device_name}.")

    @classmethod
    def get_current_memory_usage(cls, device: torch.types.Device | None = None) -> float:
        """
        Return the memory usage in bytes.
        """
        raise NotImplementedError

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """
        Get device specific communicator class for distributed communication.
        """
        return "fastvideo.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa

    @classmethod
    def get_cpu_architecture(cls) -> CpuArchEnum:
        """Get the CPU architecture of the current platform."""
        return CpuArchEnum.UNSPECIFIED


class UnspecifiedPlatform(Platform):
    _enum = PlatformEnum.UNSPECIFIED
    device_type = ""
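Because `DeviceCapability` is a `NamedTuple`, `has_device_capability` can compare against a `(major, minor)` tuple lexicographically, or against the packed integer form from `to_int()`. A small standalone illustration of both forms, using only the definitions above:

from fastvideo.platforms.interface import DeviceCapability

hopper = DeviceCapability(major=9, minor=0)
assert hopper >= (8, 0)          # NamedTuple compares like a tuple: (9, 0) >= (8, 0)
assert hopper.to_int() >= 80     # packed form: SM 9.0 encodes as 90
assert hopper.as_version_str() == "9.0"
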
backend_snapshot/fastvideo/train/models/wan/wan.py
ADDED
@@ -0,0 +1,680 @@
| 1 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 2 |
+
"""Wan model plugin (per-role instance)."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import copy
|
| 7 |
+
import gc
|
| 8 |
+
from typing import Any, Literal, TYPE_CHECKING
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
import fastvideo.envs as envs
|
| 13 |
+
from fastvideo.configs.sample import SamplingParam
|
| 14 |
+
from fastvideo.distributed import (
|
| 15 |
+
get_sp_group,
|
| 16 |
+
get_world_group,
|
| 17 |
+
)
|
| 18 |
+
from fastvideo.forward_context import set_forward_context
|
| 19 |
+
from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import (
|
| 20 |
+
FlowMatchEulerDiscreteScheduler, )
|
| 21 |
+
from fastvideo.pipelines import TrainingBatch
|
| 22 |
+
from fastvideo.pipelines.basic.wan.wan_pipeline import (
|
| 23 |
+
WanPipeline, )
|
| 24 |
+
from fastvideo.pipelines.pipeline_batch_info import (
|
| 25 |
+
ForwardBatch, )
|
| 26 |
+
from fastvideo.training.activation_checkpoint import (
|
| 27 |
+
apply_activation_checkpointing, )
|
| 28 |
+
from fastvideo.training.training_utils import (
|
| 29 |
+
compute_density_for_timestep_sampling,
|
| 30 |
+
get_sigmas,
|
| 31 |
+
normalize_dit_input,
|
| 32 |
+
shift_timestep,
|
| 33 |
+
)
|
| 34 |
+
from fastvideo.utils import (
|
| 35 |
+
is_vmoba_available,
|
| 36 |
+
is_vsa_available,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
from fastvideo.train.models.base import ModelBase
|
| 40 |
+
from fastvideo.train.utils.module_state import (
|
| 41 |
+
apply_trainable, )
|
| 42 |
+
from fastvideo.train.utils.moduleloader import (
|
| 43 |
+
load_module_from_path, )
|
| 44 |
+
|
| 45 |
+
if TYPE_CHECKING:
|
| 46 |
+
from fastvideo.train.utils.training_config import (
|
| 47 |
+
TrainingConfig, )
|
| 48 |
+
|
| 49 |
+
VideoSparseAttentionMetadataBuilder: type[Any] | None
|
| 50 |
+
VideoMobaAttentionMetadataBuilder: type[Any] | None
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
from fastvideo.attention.backends.video_sparse_attn import (
|
| 54 |
+
VideoSparseAttentionMetadataBuilder as _VideoSparseAttentionMetadataBuilder, )
|
| 55 |
+
from fastvideo.attention.backends.vmoba import (
|
| 56 |
+
VideoMobaAttentionMetadataBuilder as _VideoMobaAttentionMetadataBuilder, )
|
| 57 |
+
VideoSparseAttentionMetadataBuilder = _VideoSparseAttentionMetadataBuilder
|
| 58 |
+
VideoMobaAttentionMetadataBuilder = _VideoMobaAttentionMetadataBuilder
|
| 59 |
+
except Exception:
|
| 60 |
+
VideoSparseAttentionMetadataBuilder = None
|
| 61 |
+
VideoMobaAttentionMetadataBuilder = None
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class WanModel(ModelBase):
|
| 65 |
+
"""Wan per-role model: owns transformer + noise_scheduler."""
|
| 66 |
+
|
| 67 |
+
_transformer_cls_name: str = "WanTransformer3DModel"
|
| 68 |
+
|
| 69 |
+
def __init__(
|
| 70 |
+
self,
|
| 71 |
+
*,
|
| 72 |
+
init_from: str,
|
| 73 |
+
training_config: TrainingConfig,
|
| 74 |
+
trainable: bool = True,
|
| 75 |
+
disable_custom_init_weights: bool = False,
|
| 76 |
+
flow_shift: float = 3.0,
|
| 77 |
+
enable_gradient_checkpointing_type: str
|
| 78 |
+
| None = None,
|
| 79 |
+
transformer_override_safetensor: str
|
| 80 |
+
| None = None,
|
| 81 |
+
) -> None:
|
| 82 |
+
self._init_from = str(init_from)
|
| 83 |
+
self._trainable = bool(trainable)
|
| 84 |
+
|
| 85 |
+
self.transformer = self._load_transformer(
|
| 86 |
+
init_from=self._init_from,
|
| 87 |
+
trainable=self._trainable,
|
| 88 |
+
disable_custom_init_weights=(disable_custom_init_weights),
|
| 89 |
+
enable_gradient_checkpointing_type=(enable_gradient_checkpointing_type),
|
| 90 |
+
training_config=training_config,
|
| 91 |
+
transformer_override_safetensor=(transformer_override_safetensor),
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
self.noise_scheduler = (FlowMatchEulerDiscreteScheduler(shift=float(flow_shift)))
|
| 95 |
+
|
| 96 |
+
# Filled by init_preprocessors (student only).
|
| 97 |
+
self.vae: Any = None
|
| 98 |
+
self.training_config: TrainingConfig = training_config
|
| 99 |
+
self.dataloader: Any = None
|
| 100 |
+
self.validator: Any = None
|
| 101 |
+
self.start_step: int = 0
|
| 102 |
+
|
| 103 |
+
self.world_group: Any = None
|
| 104 |
+
self.sp_group: Any = None
|
| 105 |
+
|
| 106 |
+
self.negative_prompt_embeds: (torch.Tensor | None) = None
|
| 107 |
+
self.negative_prompt_attention_mask: (torch.Tensor | None) = None
|
| 108 |
+
|
| 109 |
+
# Timestep mechanics.
|
| 110 |
+
self.timestep_shift: float = float(flow_shift)
|
| 111 |
+
self.num_train_timestep: int = int(self.noise_scheduler.num_train_timesteps)
|
| 112 |
+
self.min_timestep: int = 0
|
| 113 |
+
self.max_timestep: int = self.num_train_timestep
|
| 114 |
+
|
| 115 |
+
def _load_transformer(
|
| 116 |
+
self,
|
| 117 |
+
*,
|
| 118 |
+
init_from: str,
|
| 119 |
+
trainable: bool,
|
| 120 |
+
disable_custom_init_weights: bool,
|
| 121 |
+
enable_gradient_checkpointing_type: str | None,
|
| 122 |
+
training_config: TrainingConfig,
|
| 123 |
+
transformer_override_safetensor: str | None = None,
|
| 124 |
+
) -> torch.nn.Module:
|
| 125 |
+
transformer = load_module_from_path(
|
| 126 |
+
model_path=init_from,
|
| 127 |
+
module_type="transformer",
|
| 128 |
+
training_config=training_config,
|
| 129 |
+
disable_custom_init_weights=(disable_custom_init_weights),
|
| 130 |
+
override_transformer_cls_name=(self._transformer_cls_name),
|
| 131 |
+
transformer_override_safetensor=(transformer_override_safetensor),
|
| 132 |
+
)
|
| 133 |
+
transformer = apply_trainable(transformer, trainable=trainable)
|
| 134 |
+
# Fall back to training_config.model if not set on the
|
| 135 |
+
# model YAML section directly.
|
| 136 |
+
ckpt_type = (enable_gradient_checkpointing_type or getattr(
|
| 137 |
+
getattr(training_config, "model", None),
|
| 138 |
+
"enable_gradient_checkpointing_type",
|
| 139 |
+
None,
|
| 140 |
+
))
|
| 141 |
+
if trainable and ckpt_type:
|
| 142 |
+
transformer = apply_activation_checkpointing(
|
| 143 |
+
transformer,
|
| 144 |
+
checkpointing_type=ckpt_type,
|
| 145 |
+
)
|
| 146 |
+
return transformer
|
| 147 |
+
|
| 148 |
+
# ------------------------------------------------------------------
|
| 149 |
+
# Lifecycle
|
| 150 |
+
# ------------------------------------------------------------------
|
| 151 |
+
|
| 152 |
+
def init_preprocessors(self, training_config: TrainingConfig) -> None:
|
| 153 |
+
self.vae = load_module_from_path(
|
| 154 |
+
model_path=str(training_config.model_path),
|
| 155 |
+
module_type="vae",
|
| 156 |
+
training_config=training_config,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
self.world_group = get_world_group()
|
| 160 |
+
self.sp_group = get_sp_group()
|
| 161 |
+
|
| 162 |
+
self._init_timestep_mechanics()
|
| 163 |
+
|
| 164 |
+
from fastvideo.dataset.dataloader.schema import (
|
| 165 |
+
pyarrow_schema_t2v, )
|
| 166 |
+
from fastvideo.train.utils.dataloader import (
|
| 167 |
+
build_parquet_t2v_train_dataloader, )
|
| 168 |
+
|
| 169 |
+
text_len = (
|
| 170 |
+
training_config.pipeline_config.text_encoder_configs[ # type: ignore[union-attr]
|
| 171 |
+
0].arch_config.text_len)
|
| 172 |
+
self.dataloader = build_parquet_t2v_train_dataloader(
|
| 173 |
+
training_config.data,
|
| 174 |
+
text_len=int(text_len),
|
| 175 |
+
parquet_schema=pyarrow_schema_t2v,
|
| 176 |
+
)
|
| 177 |
+
self.start_step = 0
|
| 178 |
+
|
| 179 |
+
@property
|
| 180 |
+
def num_train_timesteps(self) -> int:
|
| 181 |
+
return int(self.num_train_timestep)
|
| 182 |
+
|
| 183 |
+
def shift_and_clamp_timestep(self, timestep: torch.Tensor) -> torch.Tensor:
|
| 184 |
+
timestep = shift_timestep(
|
| 185 |
+
timestep,
|
| 186 |
+
self.timestep_shift,
|
| 187 |
+
self.num_train_timestep,
|
| 188 |
+
)
|
| 189 |
+
return timestep.clamp(self.min_timestep, self.max_timestep)
|
| 190 |
+
|
| 191 |
+
def on_train_start(self) -> None:
|
| 192 |
+
self.ensure_negative_conditioning()
|
| 193 |
+
|
| 194 |
+
# ------------------------------------------------------------------
|
| 195 |
+
# Runtime primitives
|
| 196 |
+
# ------------------------------------------------------------------
|
| 197 |
+
|
| 198 |
+
def prepare_batch(
|
| 199 |
+
self,
|
| 200 |
+
raw_batch: dict[str, Any],
|
| 201 |
+
*,
|
| 202 |
+
generator: torch.Generator,
|
| 203 |
+
latents_source: Literal["data", "zeros"] = "data",
|
| 204 |
+
) -> TrainingBatch:
|
| 205 |
+
self.ensure_negative_conditioning()
|
| 206 |
+
assert self.training_config is not None
|
| 207 |
+
tc = self.training_config
|
| 208 |
+
|
| 209 |
+
dtype = self._get_training_dtype()
|
| 210 |
+
device = self.device
|
| 211 |
+
|
| 212 |
+
training_batch = TrainingBatch()
|
| 213 |
+
encoder_hidden_states = raw_batch["text_embedding"]
|
| 214 |
+
encoder_attention_mask = raw_batch["text_attention_mask"]
|
| 215 |
+
infos = raw_batch.get("info_list")
|
| 216 |
+
|
| 217 |
+
if latents_source == "zeros":
|
| 218 |
+
batch_size = encoder_hidden_states.shape[0]
|
| 219 |
+
vae_config = (
|
| 220 |
+
tc.pipeline_config.vae_config.arch_config # type: ignore[union-attr]
|
| 221 |
+
)
|
| 222 |
+
num_channels = vae_config.z_dim
|
| 223 |
+
spatial_compression_ratio = (vae_config.spatial_compression_ratio)
|
| 224 |
+
latent_height = (tc.data.num_height // spatial_compression_ratio)
|
| 225 |
+
latent_width = (tc.data.num_width // spatial_compression_ratio)
|
| 226 |
+
latents = torch.zeros(
|
| 227 |
+
batch_size,
|
| 228 |
+
num_channels,
|
| 229 |
+
tc.data.num_latent_t,
|
| 230 |
+
latent_height,
|
| 231 |
+
latent_width,
|
| 232 |
+
device=device,
|
| 233 |
+
dtype=dtype,
|
| 234 |
+
)
|
| 235 |
+
elif latents_source == "data":
|
| 236 |
+
if "vae_latent" not in raw_batch:
|
| 237 |
+
raise ValueError("vae_latent not found in batch "
|
| 238 |
+
"and latents_source='data'")
|
| 239 |
+
latents = raw_batch["vae_latent"]
|
| 240 |
+
latents = latents[:, :, :tc.data.num_latent_t]
|
| 241 |
+
latents = latents.to(device, dtype=dtype)
|
| 242 |
+
else:
|
| 243 |
+
raise ValueError(f"Unknown latents_source: "
|
| 244 |
+
f"{latents_source!r}")
|
| 245 |
+
|
| 246 |
+
training_batch.latents = latents
|
| 247 |
+
training_batch.encoder_hidden_states = (encoder_hidden_states.to(device, dtype=dtype))
|
| 248 |
+
training_batch.encoder_attention_mask = (encoder_attention_mask.to(device, dtype=dtype))
|
| 249 |
+
training_batch.infos = infos
|
| 250 |
+
|
| 251 |
+
training_batch.latents = normalize_dit_input("wan", training_batch.latents, self.vae)
|
| 252 |
+
training_batch = self._prepare_dit_inputs(training_batch, generator)
|
| 253 |
+
training_batch = self._build_attention_metadata(training_batch)
|
| 254 |
+
|
| 255 |
+
training_batch.attn_metadata_vsa = copy.deepcopy(training_batch.attn_metadata)
|
| 256 |
+
if training_batch.attn_metadata is not None:
|
| 257 |
+
training_batch.attn_metadata.VSA_sparsity = 0.0 # type: ignore[attr-defined]
|
| 258 |
+
|
| 259 |
+
return training_batch
|
| 260 |
+
|
| 261 |
+
def add_noise(
|
| 262 |
+
self,
|
| 263 |
+
clean_latents: torch.Tensor,
|
| 264 |
+
noise: torch.Tensor,
|
| 265 |
+
timestep: torch.Tensor,
|
| 266 |
+
) -> torch.Tensor:
|
| 267 |
+
b, t = clean_latents.shape[:2]
|
| 268 |
+
noisy = self.noise_scheduler.add_noise(
|
| 269 |
+
clean_latents.flatten(0, 1),
|
| 270 |
+
noise.flatten(0, 1),
|
| 271 |
+
timestep,
|
| 272 |
+
).unflatten(0, (b, t))
|
| 273 |
+
return noisy
|
| 274 |
+
|
| 275 |
+
def predict_noise(
|
| 276 |
+
self,
|
| 277 |
+
noisy_latents: torch.Tensor,
|
| 278 |
+
timestep: torch.Tensor,
|
| 279 |
+
batch: TrainingBatch,
|
| 280 |
+
*,
|
| 281 |
+
conditional: bool,
|
| 282 |
+
cfg_uncond: dict[str, Any] | None = None,
|
| 283 |
+
attn_kind: Literal["dense", "vsa"] = "dense",
|
| 284 |
+
force_dense: bool = False,
|
| 285 |
+
) -> torch.Tensor:
|
| 286 |
+
device_type = self.device.type
|
| 287 |
+
dtype = noisy_latents.dtype
|
| 288 |
+
if conditional:
|
| 289 |
+
text_dict = batch.conditional_dict
|
| 290 |
+
if text_dict is None:
|
| 291 |
+
raise RuntimeError("Missing conditional_dict in "
|
| 292 |
+
"TrainingBatch")
|
| 293 |
+
else:
|
| 294 |
+
text_dict = self._get_uncond_text_dict(batch, cfg_uncond=cfg_uncond)
|
| 295 |
+
|
| 296 |
+
if attn_kind == "dense":
|
| 297 |
+
attn_metadata = batch.attn_metadata
|
| 298 |
+
elif attn_kind in ("vsa", "sparse_fp4"):
|
| 299 |
+
attn_metadata = batch.attn_metadata_vsa
|
| 300 |
+
else:
|
| 301 |
+
raise ValueError(f"Unknown attn_kind: {attn_kind!r}")
|
| 302 |
+
|
| 303 |
+
with torch.autocast(device_type, dtype=dtype), set_forward_context(
|
| 304 |
+
current_timestep=batch.timesteps,
|
| 305 |
+
attn_metadata=attn_metadata,
|
| 306 |
+
force_dense=force_dense,
|
| 307 |
+
):
|
| 308 |
+
input_kwargs = (self._build_distill_input_kwargs(noisy_latents, timestep, text_dict))
|
| 309 |
+
transformer = self._get_transformer(timestep)
|
| 310 |
+
pred_noise = transformer(**input_kwargs).permute(0, 2, 1, 3, 4)
|
| 311 |
+
return pred_noise
|
| 312 |
+
|
| 313 |
+
def backward(
|
| 314 |
+
self,
|
| 315 |
+
loss: torch.Tensor,
|
| 316 |
+
ctx: Any,
|
| 317 |
+
*,
|
| 318 |
+
grad_accum_rounds: int,
|
| 319 |
+
) -> None:
|
| 320 |
+
timesteps, attn_metadata = ctx
|
| 321 |
+
with set_forward_context(
|
| 322 |
+
current_timestep=timesteps,
|
| 323 |
+
attn_metadata=attn_metadata,
|
| 324 |
+
):
|
| 325 |
+
(loss / max(1, int(grad_accum_rounds))).backward()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _get_training_dtype(self) -> torch.dtype:
        return torch.bfloat16

    def _init_timestep_mechanics(self) -> None:
        assert self.training_config is not None
        tc = self.training_config
        flow_shift = tc.pipeline_config.flow_shift
        self.timestep_shift = float(0.0 if flow_shift is None else flow_shift)
        self.num_train_timestep = int(self.noise_scheduler.num_train_timesteps)
        # min/max timestep ratios now come from method_config;
        # default to full range.
        self.min_timestep = 0
        self.max_timestep = self.num_train_timestep

    def ensure_negative_conditioning(self) -> None:
        if self.negative_prompt_embeds is not None:
            return

        assert self.training_config is not None
        tc = self.training_config
        world_group = self.world_group
        device = self.device
        dtype = self._get_training_dtype()

        from fastvideo.train.utils.moduleloader import make_inference_args

        neg_embeds: torch.Tensor | None = None
        neg_mask: torch.Tensor | None = None

        if world_group.rank_in_group == 0:
            sampling_param = SamplingParam.from_pretrained(tc.model_path)
            negative_prompt = sampling_param.negative_prompt

            inference_args = make_inference_args(tc, model_path=tc.model_path)

            prompt_pipeline = WanPipeline.from_pretrained(
                tc.model_path,
                args=inference_args,
                inference_mode=True,
                loaded_modules={"transformer": self.transformer},
                tp_size=tc.distributed.tp_size,
                sp_size=tc.distributed.sp_size,
                num_gpus=tc.distributed.num_gpus,
                pin_cpu_memory=tc.distributed.pin_cpu_memory,
                dit_cpu_offload=True,
            )

            batch_negative = ForwardBatch(
                data_type="video",
                prompt=negative_prompt,
                prompt_embeds=[],
                prompt_attention_mask=[],
            )
            result_batch = prompt_pipeline.prompt_encoding_stage(  # type: ignore[attr-defined]
                batch_negative,
                inference_args,
            )

            neg_embeds = result_batch.prompt_embeds[0].to(device=device, dtype=dtype)
            neg_mask = result_batch.prompt_attention_mask[0].to(device=device, dtype=dtype)

            del prompt_pipeline
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        meta = torch.zeros((2, ), device=device, dtype=torch.int64)
        if world_group.rank_in_group == 0:
            assert neg_embeds is not None
            assert neg_mask is not None
            meta[0] = neg_embeds.ndim
            meta[1] = neg_mask.ndim
        world_group.broadcast(meta, src=0)
        embed_ndim, mask_ndim = int(meta[0].item()), int(meta[1].item())

        max_ndim = 8
        embed_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64)
        mask_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64)
        if world_group.rank_in_group == 0:
            assert neg_embeds is not None
            assert neg_mask is not None
            embed_shape[:embed_ndim] = torch.tensor(list(neg_embeds.shape), device=device, dtype=torch.int64)
            mask_shape[:mask_ndim] = torch.tensor(list(neg_mask.shape), device=device, dtype=torch.int64)
        world_group.broadcast(embed_shape, src=0)
        world_group.broadcast(mask_shape, src=0)

        embed_sizes = tuple(int(x) for x in embed_shape[:embed_ndim].tolist())
        mask_sizes = tuple(int(x) for x in mask_shape[:mask_ndim].tolist())

        if world_group.rank_in_group != 0:
            neg_embeds = torch.empty(embed_sizes, device=device, dtype=dtype)
            neg_mask = torch.empty(mask_sizes, device=device, dtype=dtype)
        assert neg_embeds is not None
        assert neg_mask is not None

        world_group.broadcast(neg_embeds, src=0)
        world_group.broadcast(neg_mask, src=0)

        self.negative_prompt_embeds = neg_embeds
        self.negative_prompt_attention_mask = neg_mask
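        # Broadcast protocol: rank 0 encodes the negative prompt, then three
        # rounds of broadcasts replicate it: (1) each tensor's ndim, (2) the
        # shapes padded to max_ndim, (3) the tensors themselves into buffers
        # pre-allocated on the other ranks. This keeps the text encoder off
        # all ranks except rank 0.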

    def _sample_timesteps(
        self,
        batch_size: int,
        device: torch.device,
        generator: torch.Generator,
    ) -> torch.Tensor:
        assert self.training_config is not None
        tc = self.training_config

        u = compute_density_for_timestep_sampling(
            weighting_scheme=tc.model.weighting_scheme,
            batch_size=batch_size,
            generator=generator,
            device=device,
            logit_mean=tc.model.logit_mean,
            logit_std=tc.model.logit_std,
            mode_scale=tc.model.mode_scale,
        )
        indices = (u * self.noise_scheduler.config.num_train_timesteps).long()
        return self.noise_scheduler.timesteps[indices.cpu()].to(device=device)
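        # u lies in [0, 1); scaling by num_train_timesteps and truncating with
        # .long() yields indices into the scheduler's timestep table. Indexing
        # goes through indices.cpu() and the result is moved back to `device`,
        # which suggests the scheduler keeps its timesteps tensor on CPU.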

    def _build_attention_metadata(self, training_batch: TrainingBatch) -> TrainingBatch:
        assert self.training_config is not None
        tc = self.training_config
        latents_shape = training_batch.raw_latent_shape
        patch_size = tc.pipeline_config.dit_config.patch_size  # type: ignore[union-attr]
        assert latents_shape is not None
        assert training_batch.timesteps is not None

        if envs.FASTVIDEO_ATTENTION_BACKEND in (
                "VIDEO_SPARSE_ATTN", "SPARSE_FP4_ATTN", "SPARSE_FP4_OURS_P_ATTN",
        ):
            if not is_vsa_available() or VideoSparseAttentionMetadataBuilder is None:
                raise ImportError(
                    f"FASTVIDEO_ATTENTION_BACKEND is {envs.FASTVIDEO_ATTENTION_BACKEND}, "
                    "but fastvideo_kernel is not correctly installed or detected.")
            training_batch.attn_metadata = VideoSparseAttentionMetadataBuilder().build(  # type: ignore[misc]
                raw_latent_shape=latents_shape[2:5],
                current_timestep=training_batch.timesteps,
                patch_size=patch_size,
                VSA_sparsity=tc.vsa_sparsity,
                device=self.device,
            )
        elif envs.FASTVIDEO_ATTENTION_BACKEND == "VMOBA_ATTN":
            if not is_vmoba_available() or VideoMobaAttentionMetadataBuilder is None:
                raise ImportError("FASTVIDEO_ATTENTION_BACKEND is VMOBA_ATTN, "
                                  "but fastvideo_kernel (or flash_attn>=2.7.4) is not "
                                  "correctly installed.")
            moba_params = tc.model.moba_config.copy()
            assert training_batch.raw_latent_shape is not None
            moba_params.update({
                "current_timestep": training_batch.timesteps,
                "raw_latent_shape": training_batch.raw_latent_shape[2:5],
                "patch_size": patch_size,
                "device": self.device,
            })
            training_batch.attn_metadata = VideoMobaAttentionMetadataBuilder().build(**moba_params)  # type: ignore[misc]
        else:
            training_batch.attn_metadata = None

        return training_batch

    def _prepare_dit_inputs(
        self,
        training_batch: TrainingBatch,
        generator: torch.Generator,
    ) -> TrainingBatch:
        assert self.training_config is not None
        tc = self.training_config
        latents = training_batch.latents
        assert isinstance(latents, torch.Tensor)
        batch_size = latents.shape[0]

        noise = torch.randn(
            latents.shape,
            generator=generator,
            device=latents.device,
            dtype=latents.dtype,
        )
        timesteps = self._sample_timesteps(
            batch_size,
            latents.device,
            generator,
        )
        if int(tc.distributed.sp_size or 1) > 1:
            self.sp_group.broadcast(timesteps, src=0)

        sigmas = get_sigmas(
            self.noise_scheduler,
            latents.device,
            timesteps,
            n_dim=latents.ndim,
            dtype=latents.dtype,
        )
        noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise

        training_batch.noisy_model_input = noisy_model_input
        training_batch.timesteps = timesteps
        training_batch.sigmas = sigmas
        training_batch.noise = noise
        training_batch.raw_latent_shape = latents.shape

        training_batch.conditional_dict = {
            "encoder_hidden_states": training_batch.encoder_hidden_states,
            "encoder_attention_mask": training_batch.encoder_attention_mask,
        }

        if (self.negative_prompt_embeds is not None and self.negative_prompt_attention_mask is not None):
            neg_embeds = self.negative_prompt_embeds
            neg_mask = self.negative_prompt_attention_mask
            if neg_embeds.shape[0] == 1 and batch_size > 1:
                neg_embeds = neg_embeds.expand(batch_size, *neg_embeds.shape[1:]).contiguous()
            if neg_mask.shape[0] == 1 and batch_size > 1:
                neg_mask = neg_mask.expand(batch_size, *neg_mask.shape[1:]).contiguous()
            training_batch.unconditional_dict = {
                "encoder_hidden_states": neg_embeds,
                "encoder_attention_mask": neg_mask,
            }

        training_batch.latents = training_batch.latents.permute(0, 2, 1, 3, 4)
        return training_batch
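        # Flow-matching interpolation: x_t = (1 - sigma) * x_0 + sigma * noise,
        # so sigma = 0 reproduces the clean latents and sigma = 1 is pure
        # noise. The final permute swaps axes 1 and 2, mirroring the permutes
        # applied around the transformer call in predict_noise().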

    def _build_distill_input_kwargs(
        self,
        noise_input: torch.Tensor,
        timestep: torch.Tensor,
        text_dict: dict[str, torch.Tensor] | None,
    ) -> dict[str, Any]:
        if text_dict is None:
            raise ValueError("text_dict cannot be None for Wan distillation")
        return {
            "hidden_states": noise_input.permute(0, 2, 1, 3, 4),
            "encoder_hidden_states": text_dict["encoder_hidden_states"],
            "encoder_attention_mask": text_dict["encoder_attention_mask"],
            "timestep": timestep,
            "return_dict": False,
        }

    def _get_transformer(self, timestep: torch.Tensor) -> torch.nn.Module:
        return self.transformer

    def _get_uncond_text_dict(
        self,
        batch: TrainingBatch,
        *,
        cfg_uncond: dict[str, Any] | None,
    ) -> dict[str, torch.Tensor]:
        if cfg_uncond is None:
            text_dict = getattr(batch, "unconditional_dict", None)
            if text_dict is None:
                raise RuntimeError("Missing unconditional_dict; "
                                   "ensure_negative_conditioning() may have failed")
            return text_dict

        on_missing_raw = cfg_uncond.get("on_missing", "error")
        if not isinstance(on_missing_raw, str):
            raise ValueError("method_config.cfg_uncond.on_missing must be a string, "
                             f"got {type(on_missing_raw).__name__}")
        on_missing = on_missing_raw.strip().lower()
        if on_missing not in {"error", "ignore"}:
            raise ValueError("method_config.cfg_uncond.on_missing must be one of "
                             f"{{error, ignore}}, got {on_missing_raw!r}")

        for channel, policy_raw in cfg_uncond.items():
            if channel in {"on_missing", "text"}:
                continue
            if policy_raw is None:
                continue
            if not isinstance(policy_raw, str):
                raise ValueError("method_config.cfg_uncond values must be strings, "
                                 f"got {channel}={type(policy_raw).__name__}")
            policy = policy_raw.strip().lower()
            if policy == "keep":
                continue
            if on_missing == "ignore":
                continue
            raise ValueError(f"WanModel does not support cfg_uncond channel {channel!r} "
                             f"(policy={policy!r}). Set cfg_uncond.on_missing=ignore or "
                             "remove the channel.")

        text_policy_raw = cfg_uncond.get("text", None)
        if text_policy_raw is None:
            text_policy = "negative_prompt"
        elif not isinstance(text_policy_raw, str):
            raise ValueError("method_config.cfg_uncond.text must be a string, "
                             f"got {type(text_policy_raw).__name__}")
        else:
            text_policy = text_policy_raw.strip().lower()

        if text_policy in {"negative_prompt"}:
            text_dict = getattr(batch, "unconditional_dict", None)
            if text_dict is None:
                raise RuntimeError("Missing unconditional_dict; "
                                   "ensure_negative_conditioning() may have failed")
            return text_dict
        if text_policy == "keep":
            if batch.conditional_dict is None:
                raise RuntimeError("Missing conditional_dict in TrainingBatch")
            return batch.conditional_dict
        if text_policy == "zero":
            if batch.conditional_dict is None:
                raise RuntimeError("Missing conditional_dict in TrainingBatch")
            cond = batch.conditional_dict
            enc = cond["encoder_hidden_states"]
            mask = cond["encoder_attention_mask"]
            if not torch.is_tensor(enc) or not torch.is_tensor(mask):
                raise TypeError("conditional_dict must contain tensor text inputs")
            return {
                "encoder_hidden_states": torch.zeros_like(enc),
                "encoder_attention_mask": torch.zeros_like(mask),
            }
        if text_policy == "drop":
            raise ValueError("cfg_uncond.text=drop is not supported for Wan. "
                             "Use {negative_prompt, keep, zero}.")
        raise ValueError("cfg_uncond.text must be one of "
                         f"{{negative_prompt, keep, zero, drop}}, got {text_policy_raw!r}")
backend_snapshot/fastvideo/training/training_pipeline.py
ADDED
@@ -0,0 +1,1044 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import asdict
from contextlib import AbstractContextManager, nullcontext
import math
import os
import shutil
import tempfile
import time
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterator
from typing import Any
from fastvideo.profiler import profile_region
import imageio
import numpy as np
import torch
import torch.distributed as dist
import torchvision
from einops import rearrange
from torch.utils.data import DataLoader
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm.auto import tqdm
from diffusers import FlowMatchEulerDiscreteScheduler

import fastvideo.envs as envs
try:
    from fastvideo.attention.backends.video_sparse_attn import VideoSparseAttentionMetadataBuilder
    from fastvideo.attention.backends.vmoba import VideoMobaAttentionMetadataBuilder
except Exception:
    pass
from fastvideo.configs.sample import SamplingParam
from fastvideo.dataset import build_parquet_map_style_dataloader
from fastvideo.dataset.dataloader.schema import pyarrow_schema_t2v
from fastvideo.dataset.validation_dataset import ValidationDataset
from fastvideo.distributed import (cleanup_dist_env_and_memory, get_local_torch_device, get_sp_group, get_world_group)
from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs
from fastvideo.forward_context import set_forward_context
from fastvideo.logger import init_logger
from fastvideo.attention.selector import global_force_attn_backend_context_manager
from fastvideo.pipelines import (ComposedPipelineBase, ForwardBatch, LoRAPipeline, TrainingBatch)
from fastvideo.platforms import AttentionBackendEnum, current_platform
from fastvideo.training.activation_checkpoint import apply_activation_checkpointing
from fastvideo.training.trackers import (DummyTracker, TrackerType, initialize_trackers, Trackers)
from fastvideo.training.training_utils import (clip_grad_norm_while_handling_failing_dtensor_cases,
                                               compute_density_for_timestep_sampling, count_trainable, get_scheduler,
                                               get_sigmas, load_checkpoint, normalize_dit_input, save_checkpoint,
                                               swap_fp4_linear, traverse_swap_module)
from fastvideo.utils import (is_vmoba_available, is_vsa_available, set_random_seed, shallow_asdict)

try:
    vsa_available = is_vsa_available()
    vmoba_available = is_vmoba_available()
except Exception:
    vsa_available = False
    vmoba_available = False

logger = init_logger(__name__)


class TrainingPipeline(LoRAPipeline, ABC):
    """
    A pipeline for training a model. All training pipelines should inherit from this class.
    All reusable components and code should be implemented in this class.
    """
    _required_config_modules = ["scheduler", "transformer"]
    validation_pipeline: ComposedPipelineBase
    train_dataloader: StatefulDataLoader
    train_loader_iter: Iterator[dict[str, Any]]
    current_epoch: int = 0
    train_transformer_2: bool = False
    tracker: TrackerType

    def __init__(self,
                 model_path: str,
                 fastvideo_args: TrainingArgs,
                 required_config_modules: list[str] | None = None,
                 loaded_modules: dict[str, torch.nn.Module] | None = None) -> None:
        fastvideo_args.inference_mode = False
        self.lora_training = fastvideo_args.lora_training
        if self.lora_training and fastvideo_args.lora_rank is None:
            raise ValueError("lora rank must be set when using lora training")

        set_random_seed(fastvideo_args.seed)  # for lora param init
        super().__init__(model_path, fastvideo_args, required_config_modules, loaded_modules)  # type: ignore
        self.tracker = DummyTracker()

    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
        raise RuntimeError("create_pipeline_stages should not be called for training pipeline")

    @staticmethod
    def _should_force_generator_attn_qat_train(fastvideo_args: FastVideoArgs) -> bool:
        if not isinstance(fastvideo_args, TrainingArgs):
            return False
        return (fastvideo_args.generator_4bit_attn or envs.FASTVIDEO_ATTENTION_BACKEND == "ATTN_QAT_TRAIN")

    def load_modules(self,
                     fastvideo_args: FastVideoArgs,
                     loaded_modules: dict[str, torch.nn.Module] | None = None) -> dict[str, Any]:
        force_generator_qat = self._should_force_generator_attn_qat_train(fastvideo_args)
        load_context: AbstractContextManager[None] = nullcontext()
        if force_generator_qat:
            logger.info("Forcing generator attention backend to ATTN_QAT_TRAIN during module loading")
            load_context = global_force_attn_backend_context_manager(AttentionBackendEnum.ATTN_QAT_TRAIN)

        with load_context:
            return super().load_modules(fastvideo_args, loaded_modules)

    def set_schemas(self) -> None:
        self.train_dataset_schema = pyarrow_schema_t2v

    def initialize_training_pipeline(self, training_args: TrainingArgs):
        logger.info("Initializing training pipeline...")
        self.device = get_local_torch_device()
        self.training_args = training_args
        world_group = get_world_group()
        self.world_size = world_group.world_size
        self.global_rank = world_group.rank
        self.sp_group = get_sp_group()
        self.rank_in_sp_group = self.sp_group.rank_in_group
        self.sp_world_size = self.sp_group.world_size
        self.local_rank = world_group.local_rank
        self.transformer = self.get_module("transformer")
        self.transformer_2 = self.get_module("transformer_2", None)
        self.seed = training_args.seed
        self.set_schemas()

        # Set random seeds for deterministic training
        assert self.seed is not None, "seed must be set"
        set_random_seed(self.seed + self.global_rank)
        self.transformer.train()
        if training_args.enable_gradient_checkpointing_type is not None:
            self.transformer = apply_activation_checkpointing(
                self.transformer, checkpointing_type=training_args.enable_gradient_checkpointing_type)
            if self.transformer_2 is not None:
                self.transformer_2 = apply_activation_checkpointing(
                    self.transformer_2, checkpointing_type=training_args.enable_gradient_checkpointing_type)

        if training_args.generator_4bit_linear:
            num_swaps = traverse_swap_module(self.transformer, swap_fn=swap_fp4_linear)
            logger.info("Swapped %s linear layers to the FP4 forward path in self.transformer", num_swaps)
        noise_scheduler = self.modules["scheduler"]
        self.set_trainable()
        params_to_optimize = self.transformer.parameters()
        params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize))
        # Parse betas from string format "beta1,beta2"
        betas_str = training_args.betas
        betas = tuple(float(x.strip()) for x in betas_str.split(","))
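        # Illustrative parse (values assumed): betas="0.9,0.95" -> (0.9, 0.95).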

        self.optimizer = torch.optim.AdamW(
            params_to_optimize,
            lr=training_args.learning_rate,
            betas=betas,
            weight_decay=training_args.weight_decay,
            eps=1e-8,
        )

        self.init_steps = 0
        logger.info("optimizer: %s", self.optimizer)

        self.lr_scheduler = get_scheduler(
            training_args.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=training_args.lr_warmup_steps,
            num_training_steps=training_args.max_train_steps,
            num_cycles=training_args.lr_num_cycles,
            power=training_args.lr_power,
            min_lr_ratio=training_args.min_lr_ratio,
            last_epoch=self.init_steps - 1,
        )
        if self.transformer_2 is not None:
            # Ensure transformer_2 has trainable parameters before creating optimizer
            params_to_optimize_2 = self.transformer_2.parameters()
            params_to_optimize_2 = list(filter(lambda p: p.requires_grad, params_to_optimize_2))
            self.optimizer_2 = torch.optim.AdamW(
                params_to_optimize_2,
                lr=training_args.learning_rate,
                betas=(0.9, 0.999),
                weight_decay=training_args.weight_decay,
                eps=1e-8,
            )
            self.lr_scheduler_2 = get_scheduler(
                training_args.lr_scheduler,
                optimizer=self.optimizer_2,
                num_warmup_steps=training_args.lr_warmup_steps,
                num_training_steps=training_args.max_train_steps,
                num_cycles=training_args.lr_num_cycles,
                power=training_args.lr_power,
                min_lr_ratio=training_args.min_lr_ratio,
                last_epoch=self.init_steps - 1,
            )

        self.train_dataset, self.train_dataloader = build_parquet_map_style_dataloader(
            training_args.data_path,
            training_args.train_batch_size,
            parquet_schema=self.train_dataset_schema,
            num_data_workers=training_args.dataloader_num_workers,
            cfg_rate=training_args.training_cfg_rate,
            drop_last=True,
            text_padding_length=training_args.pipeline_config.text_encoder_configs[0].arch_config.text_len,  # type: ignore[attr-defined]
            seed=self.seed)

        self.noise_scheduler = noise_scheduler
        if self.training_args.boundary_ratio is not None:
            self.boundary_timestep = self.training_args.boundary_ratio * self.noise_scheduler.num_train_timesteps
        else:
            self.boundary_timestep = None

        logger.info("train_dataloader length: %s", len(self.train_dataloader))
        logger.info("train_sp_batch_size: %s", training_args.train_sp_batch_size)
        logger.info("gradient_accumulation_steps: %s", training_args.gradient_accumulation_steps)
        logger.info("sp_size: %s", training_args.sp_size)

        self.num_update_steps_per_epoch = math.ceil(
            len(self.train_dataloader) / training_args.gradient_accumulation_steps * training_args.sp_size /
            training_args.train_sp_batch_size)
        self.num_train_epochs = math.ceil(training_args.max_train_steps / self.num_update_steps_per_epoch)

        # TODO(will): is there a cleaner way to track epochs?
        self.current_epoch = 0

        trackers = list(training_args.trackers)
        if not trackers and training_args.tracker_project_name:
            trackers.append(Trackers.WANDB.value)
        if self.global_rank != 0:
            trackers = []

        tracker_log_dir = training_args.output_dir or os.getcwd()
        if trackers:
            tracker_log_dir = os.path.join(tracker_log_dir, "tracker")

        tracker_config = asdict(training_args) if trackers else None
        tracker_run_name = training_args.wandb_run_name or None
        project = training_args.tracker_project_name or "fastvideo"
        self.tracker = initialize_trackers(
            trackers,
            experiment_name=project,
            config=tracker_config,
            log_dir=tracker_log_dir,
            run_name=tracker_run_name,
        )

    @abstractmethod
    def initialize_validation_pipeline(self, training_args: TrainingArgs):
        raise NotImplementedError("Training pipelines must implement this method")

    def _prepare_training(self, training_batch: TrainingBatch) -> TrainingBatch:
        self.optimizer.zero_grad()
        if self.transformer_2 is not None:
            self.optimizer_2.zero_grad()
        training_batch.total_loss = 0.0
        return training_batch

    def _get_next_batch(self, training_batch: TrainingBatch) -> TrainingBatch:
        with self.tracker.timed("timing/get_next_batch"):
            batch = next(self.train_loader_iter, None)  # type: ignore
            if batch is None:
                self.current_epoch += 1
                logger.info("Starting epoch %s", self.current_epoch)
                # Reset iterator for next epoch
                self.train_loader_iter = iter(self.train_dataloader)
                # Get first batch of new epoch
                batch = next(self.train_loader_iter)

            latents = batch['vae_latent']
            latents = latents[:, :, :self.training_args.num_latent_t]
            encoder_hidden_states = batch['text_embedding']
            encoder_attention_mask = batch['text_attention_mask']
            infos = batch['info_list']

            training_batch.latents = latents.to(
                get_local_torch_device(),
                dtype=torch.bfloat16,
                non_blocking=True,
            )
            training_batch.encoder_hidden_states = encoder_hidden_states.to(
                get_local_torch_device(),
                dtype=torch.bfloat16,
                non_blocking=True,
            )
            training_batch.encoder_attention_mask = encoder_attention_mask.to(
                get_local_torch_device(),
                dtype=torch.bfloat16,
                non_blocking=True,
            )
            training_batch.infos = infos

        return training_batch

    def _normalize_dit_input(self, training_batch: TrainingBatch) -> TrainingBatch:
        # TODO(will): support other models
        with self.tracker.timed("timing/normalize_input"):
            training_batch.latents = normalize_dit_input(
                'wan',
                training_batch.latents,
                self.get_module("vae"),
            )
        return training_batch

    def _prepare_dit_inputs(self, training_batch: TrainingBatch) -> TrainingBatch:
        assert self.training_args is not None, "training_args must be set"
        with self.tracker.timed("timing/prepare_dit_inputs"):
            latents = training_batch.latents
            batch_size = latents.shape[0]
            noise = torch.randn(latents.shape,
                                generator=self.noise_gen_cuda,
                                device=latents.device,
                                dtype=latents.dtype)
            timesteps = self._sample_timesteps(batch_size, latents.device)

            if self.training_args.sp_size > 1:
                # Make sure that the timesteps are the same across all sp processes.
                sp_group = get_sp_group()
                sp_group.broadcast(timesteps, src=0)
                sp_group.broadcast(noise, src=0)
            sigmas = get_sigmas(
                self.noise_scheduler,
                latents.device,
                timesteps,
                n_dim=latents.ndim,
                dtype=latents.dtype,
            )
            noisy_model_input = (1.0 - sigmas) * training_batch.latents + sigmas * noise

            training_batch.noisy_model_input = noisy_model_input
            training_batch.timesteps = timesteps
            training_batch.sigmas = sigmas
            training_batch.noise = noise
            training_batch.raw_latent_shape = training_batch.latents.shape

        return training_batch

    def _sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
        # Determine which model to train based on the boundary timestep
        if (self.transformer_2 is not None and self.boundary_timestep is not None
                and torch.rand(1, generator=self.noise_random_generator).item() <= self.training_args.boundary_ratio):
            self.train_transformer_2 = True
        else:
            self.train_transformer_2 = False

        # Broadcast the decision to all processes
        decision = torch.tensor(1.0 if self.train_transformer_2 else 0.0, device=self.device)
        dist.broadcast(decision, src=0)
        self.train_transformer_2 = decision.item() == 1.0

        # Sample u from the appropriate range
        u = compute_density_for_timestep_sampling(
            weighting_scheme=self.training_args.weighting_scheme,
            batch_size=batch_size,
            generator=self.noise_random_generator,
            logit_mean=self.training_args.logit_mean,
            logit_std=self.training_args.logit_std,
            mode_scale=self.training_args.mode_scale,
        )

        boundary_ratio = self.training_args.boundary_ratio
        if self.train_transformer_2:
            u = (1 - boundary_ratio) + u * boundary_ratio  # min: 1 - boundary_ratio, max: 1
        # elif self.transformer_2 is not None:
        #     u = u * (1 - boundary_ratio)  # min: 0, max: 1 - boundary_ratio
        # else:  # patch for now to align with non-MoE timestep logic
        #     pass

        indices = (u * self.noise_scheduler.config.num_train_timesteps).long()
        return self.noise_scheduler.timesteps[indices].to(device=device)
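        # Illustrative routing (boundary_ratio assumed, e.g. 0.875): when the
        # draw selects transformer_2, u is remapped to 0.125 + 0.875 * u,
        # confining it to the upper portion [1 - boundary_ratio, 1) of the
        # sampling range; the complementary remap for the base transformer is
        # intentionally left commented out above.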

    def _build_attention_metadata(self, training_batch: TrainingBatch) -> TrainingBatch:
        latents_shape = training_batch.raw_latent_shape
        patch_size = self.training_args.pipeline_config.dit_config.patch_size
        current_vsa_sparsity = training_batch.current_vsa_sparsity
        assert latents_shape is not None
        assert isinstance(patch_size, tuple), f"Expected tuple patch_size, got {patch_size!r}"
        assert training_batch.timesteps is not None
        if envs.FASTVIDEO_ATTENTION_BACKEND in (
                "VIDEO_SPARSE_ATTN",
                "SPARSE_FP4_ATTN",
                "SPARSE_FP4_OURS_P_ATTN",
        ):
            if not vsa_available:
                raise ImportError("FASTVIDEO_ATTENTION_BACKEND is set to VIDEO_SPARSE_ATTN, "
                                  "but fastvideo_kernel is not correctly installed or detected. "
                                  "Please ensure fastvideo-kernel is installed.")
            training_batch.attn_metadata = VideoSparseAttentionMetadataBuilder(  # type: ignore
            ).build(  # type: ignore
                raw_latent_shape=latents_shape[2:5],
                current_timestep=training_batch.timesteps,
                patch_size=patch_size,
                VSA_sparsity=current_vsa_sparsity,
                device=get_local_torch_device())
        elif envs.FASTVIDEO_ATTENTION_BACKEND == "VMOBA_ATTN":
            if not vmoba_available:
                raise ImportError("FASTVIDEO_ATTENTION_BACKEND is set to VMOBA_ATTN, "
                                  "but fastvideo_kernel (or flash_attn>=2.7.4) is not correctly installed.")
            moba_params = self.training_args.moba_config.copy()
            moba_params.update({
                "current_timestep": training_batch.timesteps,
                "raw_latent_shape": latents_shape[2:5],
                "patch_size": self.training_args.pipeline_config.dit_config.patch_size,
                "device": get_local_torch_device(),
            })
            training_batch.attn_metadata = VideoMobaAttentionMetadataBuilder().build(**moba_params)
        else:
            training_batch.attn_metadata = None

        return training_batch

    def _build_input_kwargs(self, training_batch: TrainingBatch) -> TrainingBatch:
        training_batch.input_kwargs = {
            "hidden_states": training_batch.noisy_model_input,
            "encoder_hidden_states": training_batch.encoder_hidden_states,
            "timestep": training_batch.timesteps.to(get_local_torch_device(), dtype=torch.bfloat16),
            "encoder_attention_mask": training_batch.encoder_attention_mask,
            "return_dict": False,
        }
        return training_batch

    def _transformer_forward_and_compute_loss(self, training_batch: TrainingBatch) -> TrainingBatch:
        if vsa_available and envs.FASTVIDEO_ATTENTION_BACKEND in (
                "VIDEO_SPARSE_ATTN",
                "SPARSE_FP4_ATTN",
                "SPARSE_FP4_OURS_P_ATTN",
        ) or vmoba_available and envs.FASTVIDEO_ATTENTION_BACKEND == "VMOBA_ATTN":
            assert training_batch.attn_metadata is not None
        else:
            assert training_batch.attn_metadata is None
        input_kwargs = training_batch.input_kwargs

        # if 'hunyuan' in self.training_args.model_type:
        #     input_kwargs["guidance"] = torch.tensor(
        #         [1000.0],
        #         device=training_batch.noisy_model_input.device,
        #         dtype=torch.bfloat16)
        current_model = self.transformer_2 if self.train_transformer_2 else self.transformer

        with self.tracker.timed("timing/forward_backward"), set_forward_context(
                current_timestep=training_batch.current_timestep, attn_metadata=training_batch.attn_metadata):
            model_pred = current_model(**input_kwargs)
            if self.training_args.precondition_outputs:
                assert training_batch.sigmas is not None
                model_pred = training_batch.noisy_model_input - model_pred * training_batch.sigmas
            assert training_batch.latents is not None
            assert training_batch.noise is not None
            target = training_batch.latents if self.training_args.precondition_outputs else training_batch.noise - training_batch.latents

            # make sure no implicit broadcasting happens
            assert model_pred.shape == target.shape, f"model_pred.shape: {model_pred.shape}, target.shape: {target.shape}"

            loss = (torch.mean(
                (model_pred.float() - target.float())**2) / self.training_args.gradient_accumulation_steps)

            loss.backward()

        avg_loss = loss.detach().clone()

        # Reduce across ranks without forcing a CPU sync
        with self.tracker.timed("timing/reduce_loss"):
            world_group = get_world_group()
            avg_loss = world_group.all_reduce(avg_loss, op=dist.ReduceOp.AVG)
        # Accumulate on GPU; materialize to CPU only once after
        # all gradient-accumulation iterations (see train_one_step).
        training_batch.total_loss += avg_loss

        return training_batch
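        # Target selection: with precondition_outputs the network output is
        # converted to an x_0 estimate (noisy_input - pred * sigma) and
        # regressed against the clean latents; otherwise the raw output is
        # regressed against the flow-matching velocity target
        # (noise - latents). The per-micro-batch loss is divided by
        # gradient_accumulation_steps so summed gradients match a
        # full-batch average.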

    def _clip_grad_norm(self, training_batch: TrainingBatch) -> TrainingBatch:
        max_grad_norm = self.training_args.max_grad_norm

        # TODO(will): perhaps move this into transformer api so that we can do
        # the following:
        # grad_norm = transformer.clip_grad_norm_(max_grad_norm)
        if max_grad_norm is not None:
            with self.tracker.timed("timing/clip_grad_norm"):
                # Only clip gradients for the model that is currently training
                if self.train_transformer_2 and self.transformer_2 is not None:
                    model_parts = [self.transformer_2]
                else:
                    model_parts = [self.transformer]

                grad_norm = clip_grad_norm_while_handling_failing_dtensor_cases(
                    [p for m in model_parts for p in m.parameters()],
                    max_grad_norm,
                    foreach=None,
                )
                if grad_norm is not None:
                    # A NaN/Inf norm here indicates divergence.
                    assert torch.isfinite(grad_norm).all(), f"non-finite grad norm: {grad_norm}"
                    grad_norm = grad_norm.item()
                else:
                    grad_norm = 0.0
        else:
            grad_norm = 0.0
        training_batch.grad_norm = grad_norm
        return training_batch

    @profile_region("profiler_region_training_train_one_step")
    def train_one_step(self, training_batch: TrainingBatch) -> TrainingBatch:
        training_batch = self._prepare_training(training_batch)

        for _ in range(self.training_args.gradient_accumulation_steps):
            training_batch = self._get_next_batch(training_batch)

            # Normalize DIT input
            training_batch = self._normalize_dit_input(training_batch)
            # Create noisy model input
            training_batch = self._prepare_dit_inputs(training_batch)
            assert training_batch.latents is not None
            assert training_batch.noisy_model_input is not None
            assert training_batch.noise is not None

            # old sharding code, need to shard latents and noise but not input
            # Shard latents across sp groups
            training_batch.latents = training_batch.latents[:, :, :self.training_args.num_latent_t]
            # shard noisy_model_input to match
            training_batch.noisy_model_input = training_batch.noisy_model_input[:, :, :self.training_args.num_latent_t]
            # shard noise to match latents
            training_batch.noise = training_batch.noise[:, :, :self.training_args.num_latent_t]

            training_batch = self._build_attention_metadata(training_batch)
            training_batch = self._build_input_kwargs(training_batch)

            training_batch = self._transformer_forward_and_compute_loss(training_batch)

        training_batch = self._clip_grad_norm(training_batch)

        # Only step the optimizer and scheduler for the model that is currently training
        with self.tracker.timed("timing/optimizer_step"):
            if self.train_transformer_2 and self.transformer_2 is not None:
                self.optimizer_2.step()
                self.lr_scheduler_2.step()
            else:
                self.optimizer.step()
                self.lr_scheduler.step()

        return training_batch
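        # One optimizer step consumes gradient_accumulation_steps
        # micro-batches: each inner iteration runs forward + backward
        # (gradients accumulate in .grad since zero_grad happens once in
        # _prepare_training), and clipping plus the optimizer/scheduler step
        # happen once per outer step, on whichever transformer is active.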

    def _compute_current_sparsity(self, step: int) -> float:
        """Compute the VSA sparsity for a given step using the decay schedule."""
        vsa_sparsity = self.training_args.VSA_sparsity
        vsa_decay_rate = self.training_args.VSA_decay_rate
        vsa_decay_interval = self.training_args.VSA_decay_interval_steps
        vsa_init = getattr(self.training_args, 'VSA_init_sparsity', 0.0)
        vsa_warmup = getattr(self.training_args, 'VSA_warmup_steps', 0)
        if step <= vsa_warmup:
            return vsa_init
        ramp_step = step - vsa_warmup
        max_times = int((vsa_sparsity - vsa_init) / vsa_decay_rate) if vsa_decay_rate > 0 else 0
        times = min(ramp_step // vsa_decay_interval, max_times)
        return vsa_init + times * vsa_decay_rate
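        # Illustrative schedule (values assumed): VSA_init_sparsity=0.0,
        # VSA_sparsity=1.0, VSA_decay_rate=0.5, VSA_decay_interval_steps=100,
        # VSA_warmup_steps=0 gives 0.0 for steps 1-99, 0.5 for steps 100-199,
        # and 1.0 from step 200 on. Despite the "decay" naming, the schedule
        # ramps sparsity up from the init value toward the target.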

    def _resolve_checkpoint_path(self, path: str) -> str | None:
        """Resolve 'latest' to the most recent checkpoint in output_dir."""
        import glob
        if path == "latest":
            output_dir = self.training_args.output_dir
            ckpt_dirs = sorted(
                glob.glob(os.path.join(output_dir, "checkpoint-*")),
                key=lambda d: int(d.split("-")[-1]) if d.split("-")[-1].isdigit() else 0,
            )
            if ckpt_dirs:
                latest = ckpt_dirs[-1]
                logger.info("Auto-resolved 'latest' to %s", latest)
                return latest
            logger.info("No checkpoints found in %s, starting from scratch", output_dir)
            return None
        return path
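        # Example (hypothetical directory listing): with output_dir containing
        # checkpoint-250, checkpoint-500, and checkpoint-750, passing "latest"
        # resolves to checkpoint-750. Any other string is returned unchanged,
        # and non-numeric suffixes sort to the front via the key's 0 fallback.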

    def _resume_from_checkpoint(self) -> None:
        ckpt_path = self._resolve_checkpoint_path(self.training_args.resume_from_checkpoint)
        if ckpt_path is None:
            logger.info("No checkpoint to resume from, starting from step 0")
            return

        safetensors_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model.safetensors")
        step = int(os.path.basename(os.path.normpath(ckpt_path)).split('-')[-1])

        resumed_step = load_checkpoint(self.transformer, self.global_rank, ckpt_path,
                                       self.optimizer, self.train_dataloader,
                                       self.lr_scheduler, self.noise_random_generator)
        if resumed_step > 0 or step == 0:
            self.init_steps = resumed_step
            logger.info("Successfully resumed full training state from step %s", resumed_step)
            return

        if os.path.exists(safetensors_path):
            self.init_steps = step
            logger.warning("Distributed checkpoint resume failed; falling back to safetensors weights at step %s",
                           step)
            return

        logger.warning("No usable checkpoint state found at %s; starting from step 0", ckpt_path)
        self.init_steps = 0

    @profile_region("profiler_region_training_train")
    def train(self) -> None:
        assert self.seed is not None, "seed must be set"
        assert self.training_args is not None, "training_args must be set"
        set_random_seed(self.seed + self.global_rank)
        logger.info('rank: %s: start training', self.global_rank, local_main_process_only=False)
        if not self.post_init_called:
            self.post_init()
        num_trainable_params = count_trainable(self.transformer)
        logger.info("Starting training with %s B trainable parameters", round(num_trainable_params / 1e9, 3))

        if getattr(self, "transformer_2", None) is not None:
            num_trainable_params = count_trainable(self.transformer_2)
            logger.info("Transformer 2: Starting training with %s B trainable parameters",
                        round(num_trainable_params / 1e9, 3))

        # Set random seeds for deterministic training
        self.noise_random_generator = torch.Generator(device="cpu").manual_seed(self.seed + self.global_rank)
        self.noise_gen_cuda = torch.Generator(device=current_platform.device_name).manual_seed(self.seed +
                                                                                               self.global_rank)
        self.validation_random_generator = torch.Generator(device="cpu").manual_seed(self.seed + self.global_rank)
        logger.info("Initialized random seeds with seed: %s", self.seed + self.global_rank)
        self.noise_scheduler = FlowMatchEulerDiscreteScheduler()

        if self.training_args.resume_from_checkpoint:
            self._resume_from_checkpoint()

        self.train_loader_iter = iter(self.train_dataloader)

        step_times: deque[float] = deque(maxlen=100)

        self._log_training_info()

        # Validation at init uses the sparsity corresponding to init_steps
        saved_sparsity = self.training_args.VSA_sparsity
        self.training_args.VSA_sparsity = self._compute_current_sparsity(self.init_steps)
        self._log_validation(self.transformer, self.training_args, self.init_steps)
        self.training_args.VSA_sparsity = saved_sparsity

        # Train!
        progress_bar = tqdm(
            range(0, self.training_args.max_train_steps),
            initial=self.init_steps,
            desc="Steps",
            # Only show the progress bar once on each machine.
            disable=self.local_rank > 0,
        )
        for step in range(self.init_steps + 1, self.training_args.max_train_steps + 1):
            start_time = time.perf_counter()
            if vsa_available:
                vsa_sparsity = self.training_args.VSA_sparsity
                vsa_decay_rate = self.training_args.VSA_decay_rate
                vsa_decay_interval_steps = self.training_args.VSA_decay_interval_steps
                vsa_init_sparsity = getattr(self.training_args, 'VSA_init_sparsity', 0.0)
                vsa_warmup_steps = getattr(self.training_args, 'VSA_warmup_steps', 0)
                if step <= vsa_warmup_steps:
                    current_vsa_sparsity = vsa_init_sparsity
                else:
                    ramp_step = step - vsa_warmup_steps
                    # Guard against a zero decay rate, matching
                    # _compute_current_sparsity above.
                    max_decay_times = (int((vsa_sparsity - vsa_init_sparsity) / vsa_decay_rate)
                                       if vsa_decay_rate > 0 else 0)
                    current_decay_times = min(ramp_step // vsa_decay_interval_steps, max_decay_times)
                    current_vsa_sparsity = vsa_init_sparsity + current_decay_times * vsa_decay_rate
            elif vmoba_available:
                # TODO: add vmoba sparsity scheduling here
                current_vsa_sparsity = 0.0
            else:
                current_vsa_sparsity = 0.0

            training_batch = TrainingBatch()
            training_batch.current_timestep = step
            training_batch.current_vsa_sparsity = current_vsa_sparsity
            training_batch = self.train_one_step(training_batch)

            loss = float(training_batch.total_loss)
            grad_norm = training_batch.grad_norm

            step_time = time.perf_counter() - start_time
            step_times.append(step_time)
            avg_step_time = sum(step_times) / len(step_times)

            progress_bar.set_postfix({
                "loss": f"{loss:.4f}",
                "step_time": f"{step_time:.2f}s",
                "grad_norm": grad_norm,
            })
            progress_bar.update(1)
            if self.global_rank == 0:
                metrics = {
                    "train_loss": loss,
                    "learning_rate": self.lr_scheduler.get_last_lr()[0],
                    "step_time": step_time,
                    "avg_step_time": avg_step_time,
                    "grad_norm": grad_norm,
                    "vsa_sparsity": current_vsa_sparsity,
                }
                try:
                    assert training_batch.raw_latent_shape is not None
                    metrics["batch_size"] = int(training_batch.raw_latent_shape[0])

                    patch_size = self.training_args.pipeline_config.dit_config.patch_size
                    assert isinstance(patch_size, tuple), f"Expected tuple patch_size, got {patch_size!r}"
                    patch_t, patch_h, patch_w = patch_size
                    seq_len = (training_batch.raw_latent_shape[2] // patch_t) * (
                        training_batch.raw_latent_shape[3] // patch_h) * (training_batch.raw_latent_shape[4] // patch_w)
                    if training_batch.encoder_hidden_states is not None:
                        context_len = int(training_batch.encoder_hidden_states.shape[1])
                    else:
                        context_len = 0

                    metrics["dit_seq_len"] = int(seq_len)
                    metrics["context_len"] = context_len

                    arch_config = self.training_args.pipeline_config.dit_config.arch_config

                    metrics["hidden_dim"] = arch_config.hidden_size
                    metrics["num_layers"] = arch_config.num_layers
                    metrics["ffn_dim"] = arch_config.ffn_dim
                except Exception:
                    pass

                self.tracker.log(metrics, step)
            if step % self.training_args.training_state_checkpointing_steps == 0:
                with self.profiler_controller.region("profiler_region_training_save_checkpoint"):
                    save_checkpoint(self.transformer, self.global_rank, self.training_args.output_dir, step,
                                    self.optimizer, self.train_dataloader, self.lr_scheduler,
                                    self.noise_random_generator,
                                    self.training_args.checkpoints_total_limit)
                    self.transformer.train()
                self.sp_group.barrier()

            if self.training_args.log_visualization and step % self.training_args.visualization_steps == 0:
                self.visualize_intermediate_latents(training_batch, self.training_args, step)

            if self.training_args.log_validation and step % self.training_args.validation_steps == 0:
                with self.profiler_controller.region("profiler_region_training_validation"):
                    saved_sparsity = self.training_args.VSA_sparsity
                    self.training_args.VSA_sparsity = current_vsa_sparsity
                    self._log_validation(self.transformer, self.training_args, step)
                    self.training_args.VSA_sparsity = saved_sparsity
                gpu_memory_usage = current_platform.get_torch_device().memory_allocated() / 1024**2
                trainable_params = round(count_trainable(self.transformer) / 1e9, 3)
                logger.info("GPU memory usage after validation: %s MB, trainable params: %sB", gpu_memory_usage,
                            trainable_params)

        self.tracker.finish()
        save_checkpoint(self.transformer, self.global_rank, self.training_args.output_dir,
                        self.training_args.max_train_steps, self.optimizer, self.train_dataloader, self.lr_scheduler,
                        self.noise_random_generator, self.training_args.checkpoints_total_limit)

        if envs.FASTVIDEO_TORCH_PROFILER_DIR:
            logger.info("Stopping profiler...")
            self.profiler_controller.stop()
            logger.info("Profiler stopped.")

        if get_sp_group():
            cleanup_dist_env_and_memory()

    def _log_training_info(self) -> None:
        assert self.training_args is not None, "training_args must be set"
        total_batch_size = (self.world_size * self.training_args.gradient_accumulation_steps /
                            self.training_args.sp_size * self.training_args.train_sp_batch_size)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %s", len(self.train_dataset))
        logger.info("  Dataloader size = %s", len(self.train_dataloader))
        logger.info("  Num Epochs = %s", self.num_train_epochs)
        logger.info("  Resume training from step %s", self.init_steps)  # type: ignore
        logger.info("  Instantaneous batch size per device = %s", self.training_args.train_batch_size)
        logger.info("  Total train batch size (w. data & sequence parallel, accumulation) = %s", total_batch_size)
        logger.info("  Gradient Accumulation steps = %s", self.training_args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %s", self.training_args.max_train_steps)
        logger.info("  Total training parameters per FSDP shard = %s B",
                    round(count_trainable(self.transformer) / 1e9, 3))
        # print dtype
        logger.info("  Master weight dtype: %s", next(self.transformer.parameters()).dtype)

        gpu_memory_usage = current_platform.get_torch_device().memory_allocated() / 1024**2
        logger.info("GPU memory usage before train_one_step: %s MB", gpu_memory_usage)
|
| 766 |
+
logger.info("VSA validation sparsity: %s", self.training_args.VSA_sparsity)
|
| 767 |
+
|
| 768 |
+
def _prepare_validation_batch(self, sampling_param: SamplingParam, training_args: TrainingArgs,
|
| 769 |
+
validation_batch: dict[str, Any], num_inference_steps: int) -> ForwardBatch:
|
| 770 |
+
sampling_param.prompt = validation_batch['prompt']
|
| 771 |
+
sampling_param.height = training_args.num_height
|
| 772 |
+
sampling_param.width = training_args.num_width
|
| 773 |
+
sampling_param.num_inference_steps = num_inference_steps
|
| 774 |
+
sampling_param.data_type = "video"
|
| 775 |
+
if training_args.validation_guidance_scale:
|
| 776 |
+
sampling_param.guidance_scale = float(training_args.validation_guidance_scale)
|
| 777 |
+
assert self.seed is not None
|
| 778 |
+
sampling_param.seed = self.seed
|
| 779 |
+
|
| 780 |
+
latents_size = [(sampling_param.num_frames - 1) // 4 + 1, sampling_param.height // 8, sampling_param.width // 8]
|
| 781 |
+
n_tokens = latents_size[0] * latents_size[1] * latents_size[2]
|
| 782 |
+
temporal_compression_factor = training_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio
|
| 783 |
+
num_frames = (training_args.num_latent_t - 1) * temporal_compression_factor + 1
|
| 784 |
+
sampling_param.num_frames = num_frames
|
| 785 |
+
batch = ForwardBatch(
|
| 786 |
+
**shallow_asdict(sampling_param),
|
| 787 |
+
latents=None,
|
| 788 |
+
generator=self.validation_random_generator,
|
| 789 |
+
n_tokens=n_tokens,
|
| 790 |
+
eta=0.0,
|
| 791 |
+
VSA_sparsity=training_args.VSA_sparsity,
|
| 792 |
+
)
|
| 793 |
+
|
| 794 |
+
return batch
|
| 795 |
+
|
| 796 |
+
@torch.no_grad()
|
| 797 |
+
def _log_validation(self, transformer, training_args, global_step) -> None:
|
| 798 |
+
"""
|
| 799 |
+
Generate a validation video and log it to the configured tracker to check the quality during training.
|
| 800 |
+
"""
|
| 801 |
+
training_args.inference_mode = True
|
| 802 |
+
training_args.dit_cpu_offload = False
|
| 803 |
+
if not training_args.log_validation:
|
| 804 |
+
return
|
| 805 |
+
if self.validation_pipeline is None:
|
| 806 |
+
raise ValueError("Validation pipeline is not set")
|
| 807 |
+
|
| 808 |
+
logger.info("Starting validation")
|
| 809 |
+
|
| 810 |
+
# Create sampling parameters if not provided
|
| 811 |
+
sampling_param = SamplingParam.from_pretrained(training_args.model_path)
|
| 812 |
+
|
| 813 |
+
# Prepare validation prompts
|
| 814 |
+
logger.info('rank: %s: fastvideo_args.validation_dataset_file: %s',
|
| 815 |
+
self.global_rank,
|
| 816 |
+
training_args.validation_dataset_file,
|
| 817 |
+
local_main_process_only=False)
|
| 818 |
+
validation_dataset = ValidationDataset(training_args.validation_dataset_file)
|
| 819 |
+
validation_dataloader = DataLoader(validation_dataset, batch_size=None, num_workers=0)
|
| 820 |
+
|
| 821 |
+
self.transformer.eval()
|
| 822 |
+
if getattr(self, "transformer_2", None) is not None:
|
| 823 |
+
self.transformer_2.eval()
|
| 824 |
+
|
| 825 |
+
validation_steps = training_args.validation_sampling_steps.split(",")
|
| 826 |
+
validation_steps = [int(step) for step in validation_steps]
|
| 827 |
+
validation_steps = [step for step in validation_steps if step > 0]
|
| 828 |
+
# Log validation results for this step
|
| 829 |
+
world_group = get_world_group()
|
| 830 |
+
num_sp_groups = world_group.world_size // self.sp_group.world_size
|
| 831 |
+
one_prompt_per_rank = os.environ.get(
|
| 832 |
+
"FASTVIDEO_VALIDATION_ONE_PROMPT_PER_RANK",
|
| 833 |
+
"",
|
| 834 |
+
).lower() in {"1", "true", "yes", "on"}
|
| 835 |
+
|
| 836 |
+
# Process each validation prompt for each validation step
|
| 837 |
+
for num_inference_steps in validation_steps:
|
| 838 |
+
logger.info("rank: %s: num_inference_steps: %s",
|
| 839 |
+
self.global_rank,
|
| 840 |
+
num_inference_steps,
|
| 841 |
+
local_main_process_only=False)
|
| 842 |
+
step_videos: list[np.ndarray] = []
|
| 843 |
+
step_captions: list[str] = []
|
| 844 |
+
|
| 845 |
+
step_audio: list[np.ndarray | None] = []
|
| 846 |
+
step_sample_rates: list[int | None] = []
|
| 847 |
+
|
| 848 |
+
for prompt_idx, validation_batch in enumerate(validation_dataloader):
|
| 849 |
+
if one_prompt_per_rank and prompt_idx > 0:
|
| 850 |
+
continue
|
| 851 |
+
|
| 852 |
+
batch = self._prepare_validation_batch(sampling_param, training_args, validation_batch,
|
| 853 |
+
num_inference_steps)
|
| 854 |
+
logger.info("rank: %s: rank_in_sp_group: %s, batch.prompt: %s",
|
| 855 |
+
self.global_rank,
|
| 856 |
+
self.rank_in_sp_group,
|
| 857 |
+
batch.prompt,
|
| 858 |
+
local_main_process_only=False)
|
| 859 |
+
|
| 860 |
+
assert batch.prompt is not None and isinstance(batch.prompt, str)
|
| 861 |
+
step_captions.append(batch.prompt)
|
| 862 |
+
|
| 863 |
+
# Run validation inference
|
| 864 |
+
output_batch = self.validation_pipeline.forward(batch, training_args)
|
| 865 |
+
samples = output_batch.output.cpu()
|
| 866 |
+
|
| 867 |
+
# Capture audio if available
|
| 868 |
+
audio = output_batch.extra.get("audio")
|
| 869 |
+
sample_rate = output_batch.extra.get("audio_sample_rate")
|
| 870 |
+
|
| 871 |
+
if audio is not None and torch.is_tensor(audio):
|
| 872 |
+
audio = audio.detach().cpu().float().numpy()
|
| 873 |
+
|
| 874 |
+
step_audio.append(audio)
|
| 875 |
+
step_sample_rates.append(sample_rate)
|
| 876 |
+
|
| 877 |
+
if self.rank_in_sp_group != 0:
|
| 878 |
+
continue
|
| 879 |
+
|
| 880 |
+
# Process outputs
|
| 881 |
+
video = rearrange(samples, "b c t h w -> t b c h w")
|
| 882 |
+
frames = []
|
| 883 |
+
for x in video:
|
| 884 |
+
x = torchvision.utils.make_grid(x, nrow=6)
|
| 885 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
| 886 |
+
frames.append((x * 255).numpy().astype(np.uint8))
|
| 887 |
+
step_videos.append(frames)
|
| 888 |
+
|
| 889 |
+
# Only sp_group leaders (rank_in_sp_group == 0) need to send their
|
| 890 |
+
# results to global rank 0
|
| 891 |
+
if self.rank_in_sp_group == 0 and self.global_rank == 0:
|
| 892 |
+
# Global rank 0 collects results from all sp_group leaders
|
| 893 |
+
all_videos = step_videos # Start with own results
|
| 894 |
+
all_captions = step_captions
|
| 895 |
+
all_audios = step_audio
|
| 896 |
+
all_sample_rates = step_sample_rates
|
| 897 |
+
|
| 898 |
+
# Receive from other sp_group leaders
|
| 899 |
+
for sp_group_idx in range(1, num_sp_groups):
|
| 900 |
+
src_rank = sp_group_idx * self.sp_world_size # Global rank of other sp_group leaders
|
| 901 |
+
recv_videos = world_group.recv_object(src=src_rank)
|
| 902 |
+
recv_captions = world_group.recv_object(src=src_rank)
|
| 903 |
+
recv_audios = world_group.recv_object(src=src_rank)
|
| 904 |
+
recv_sample_rates = world_group.recv_object(src=src_rank)
|
| 905 |
+
|
| 906 |
+
all_videos.extend(recv_videos)
|
| 907 |
+
all_captions.extend(recv_captions)
|
| 908 |
+
all_audios.extend(recv_audios)
|
| 909 |
+
all_sample_rates.extend(recv_sample_rates)
|
| 910 |
+
|
| 911 |
+
video_filenames = []
|
| 912 |
+
for i, (video, caption, audio, sample_rate) in enumerate(
|
| 913 |
+
zip(all_videos, all_captions, all_audios, all_sample_rates, strict=True)):
|
| 914 |
+
os.makedirs(training_args.output_dir, exist_ok=True)
|
| 915 |
+
filename = os.path.join(
|
| 916 |
+
training_args.output_dir,
|
| 917 |
+
f"validation_step_{global_step}_inference_steps_{num_inference_steps}_video_{i}.mp4")
|
| 918 |
+
imageio.mimsave(filename, video, fps=sampling_param.fps)
|
| 919 |
+
# Mux audio if available
|
| 920 |
+
if (audio is not None and sample_rate is not None and not self._mux_audio(
|
| 921 |
+
filename,
|
| 922 |
+
audio,
|
| 923 |
+
sample_rate,
|
| 924 |
+
)):
|
| 925 |
+
logger.warning("Audio mux failed for validation video %s; saved video without audio.", filename)
|
| 926 |
+
video_filenames.append(filename)
|
| 927 |
+
|
| 928 |
+
artifacts = []
|
| 929 |
+
for filename, caption in zip(video_filenames, all_captions, strict=True):
|
| 930 |
+
video_artifact = self.tracker.video(filename, caption=caption)
|
| 931 |
+
if video_artifact is not None:
|
| 932 |
+
artifacts.append(video_artifact)
|
| 933 |
+
if artifacts:
|
| 934 |
+
logs = {f"validation_videos_{num_inference_steps}_steps": artifacts}
|
| 935 |
+
self.tracker.log_artifacts(logs, global_step)
|
| 936 |
+
elif self.rank_in_sp_group == 0:
|
| 937 |
+
# Other sp_group leaders send their results to global rank 0
|
| 938 |
+
world_group.send_object(step_videos, dst=0)
|
| 939 |
+
world_group.send_object(step_captions, dst=0)
|
| 940 |
+
world_group.send_object(step_audio, dst=0)
|
| 941 |
+
world_group.send_object(step_sample_rates, dst=0)
|
| 942 |
+
|
| 943 |
+
world_group.barrier()
|
| 944 |
+
|
| 945 |
+
# Re-enable gradients for training
|
| 946 |
+
training_args.inference_mode = False
|
| 947 |
+
self.transformer.train()
|
| 948 |
+
if getattr(self, "transformer_2", None) is not None:
|
| 949 |
+
self.transformer_2.train()
|
| 950 |
+
|
| 951 |
+
@staticmethod
|
| 952 |
+
def _mux_audio(
|
| 953 |
+
video_path: str,
|
| 954 |
+
audio: torch.Tensor | np.ndarray,
|
| 955 |
+
sample_rate: int,
|
| 956 |
+
) -> bool:
|
| 957 |
+
"""Mux audio into video using PyAV."""
|
| 958 |
+
try:
|
| 959 |
+
import av
|
| 960 |
+
except ImportError:
|
| 961 |
+
logger.warning("PyAV not installed; cannot mux audio. "
|
| 962 |
+
"Install with: pip install av")
|
| 963 |
+
return False
|
| 964 |
+
|
| 965 |
+
if torch.is_tensor(audio):
|
| 966 |
+
audio_np = audio.detach().cpu().float().numpy()
|
| 967 |
+
else:
|
| 968 |
+
audio_np = np.asarray(audio, dtype=np.float32)
|
| 969 |
+
|
| 970 |
+
if audio_np.ndim == 1:
|
| 971 |
+
audio_np = audio_np[:, None]
|
| 972 |
+
elif audio_np.ndim == 2:
|
| 973 |
+
if audio_np.shape[0] <= 8 and audio_np.shape[1] > audio_np.shape[0]:
|
| 974 |
+
audio_np = audio_np.T
|
| 975 |
+
else:
|
| 976 |
+
logger.warning("Unexpected audio shape %s; skipping mux.", audio_np.shape)
|
| 977 |
+
return False
|
| 978 |
+
|
| 979 |
+
audio_np = np.clip(audio_np, -1.0, 1.0)
|
| 980 |
+
audio_int16 = (audio_np * 32767.0).astype(np.int16)
|
| 981 |
+
num_channels = audio_int16.shape[1]
|
| 982 |
+
layout = "stereo" if num_channels == 2 else "mono"
|
| 983 |
+
|
| 984 |
+
try:
|
| 985 |
+
import wave
|
| 986 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 987 |
+
out_path = os.path.join(tmpdir, "muxed.mp4")
|
| 988 |
+
wav_path = os.path.join(tmpdir, "audio.wav")
|
| 989 |
+
|
| 990 |
+
# Write audio to WAV file
|
| 991 |
+
with wave.open(wav_path, "wb") as wav_file:
|
| 992 |
+
wav_file.setnchannels(num_channels)
|
| 993 |
+
wav_file.setsampwidth(2)
|
| 994 |
+
wav_file.setframerate(sample_rate)
|
| 995 |
+
wav_file.writeframes(audio_int16.tobytes())
|
| 996 |
+
|
| 997 |
+
# Open input video and audio
|
| 998 |
+
input_video = av.open(video_path)
|
| 999 |
+
input_audio = av.open(wav_path)
|
| 1000 |
+
|
| 1001 |
+
# Create output with both streams
|
| 1002 |
+
output = av.open(out_path, mode="w")
|
| 1003 |
+
|
| 1004 |
+
# Add video stream (copy codec from input)
|
| 1005 |
+
in_video_stream = input_video.streams.video[0]
|
| 1006 |
+
out_video_stream = output.add_stream(
|
| 1007 |
+
codec_name=in_video_stream.codec_context.name,
|
| 1008 |
+
rate=in_video_stream.average_rate,
|
| 1009 |
+
)
|
| 1010 |
+
out_video_stream.width = in_video_stream.width
|
| 1011 |
+
out_video_stream.height = in_video_stream.height
|
| 1012 |
+
out_video_stream.pix_fmt = in_video_stream.pix_fmt
|
| 1013 |
+
|
| 1014 |
+
# Add audio stream (AAC)
|
| 1015 |
+
out_audio_stream = output.add_stream("aac", rate=sample_rate)
|
| 1016 |
+
out_audio_stream.layout = layout
|
| 1017 |
+
|
| 1018 |
+
# Remux video (decode and re-encode to be safe)
|
| 1019 |
+
for frame in input_video.decode(video=0):
|
| 1020 |
+
for packet in out_video_stream.encode(frame):
|
| 1021 |
+
output.mux(packet)
|
| 1022 |
+
for packet in out_video_stream.encode():
|
| 1023 |
+
output.mux(packet)
|
| 1024 |
+
|
| 1025 |
+
# Encode audio
|
| 1026 |
+
for frame in input_audio.decode(audio=0):
|
| 1027 |
+
frame.pts = None # Let encoder assign PTS
|
| 1028 |
+
for packet in out_audio_stream.encode(frame):
|
| 1029 |
+
output.mux(packet)
|
| 1030 |
+
for packet in out_audio_stream.encode():
|
| 1031 |
+
output.mux(packet)
|
| 1032 |
+
|
| 1033 |
+
input_video.close()
|
| 1034 |
+
input_audio.close()
|
| 1035 |
+
output.close()
|
| 1036 |
+
shutil.move(out_path, video_path)
|
| 1037 |
+
return True
|
| 1038 |
+
except Exception as e:
|
| 1039 |
+
logger.warning("Audio mux failed: %s", e)
|
| 1040 |
+
return False
|
| 1041 |
+
|
| 1042 |
+
def visualize_intermediate_latents(self, training_batch: TrainingBatch, training_args: TrainingArgs, step: int):
|
| 1043 |
+
"""Add visualization data to tracker logging and save frames to disk."""
|
| 1044 |
+
raise NotImplementedError("Visualize intermediate latents is not implemented for training pipeline")
|
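For reference, the VSA sparsity schedule computed at the top of the train loop above reduces to a simple step function. A minimal standalone sketch (not part of the snapshot; names are illustrative) that mirrors the loop's logic, with this run's settings as defaults:

def vsa_sparsity_at(step: int,
                    target: float = 0.9,        # VSA_SPARSITY
                    init: float = 0.9,          # VSA_INIT_SPARSITY
                    decay_rate: float = 0.03,   # VSA_DECAY_RATE
                    interval: int = 50,         # VSA_DECAY_INTERVAL_STEPS
                    warmup: int = 0) -> float:  # VSA_WARMUP_STEPS
    """Mirrors the per-step schedule in the train loop above."""
    if step <= warmup:
        return init
    ramp_step = step - warmup
    max_decay_times = int((target - init) / decay_rate)
    decay_times = min(ramp_step // interval, max_decay_times)
    return init + decay_times * decay_rate

# With init == target == 0.9 and warmup == 0, max_decay_times is 0,
# so every step of this run trains at a constant 0.9 sparsity.
assert vsa_sparsity_at(1) == 0.9 and vsa_sparsity_at(750) == 0.9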
backend_snapshot/fastvideo/training/wan_training_pipeline.py
ADDED
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
import sys
from copy import deepcopy

from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs
from fastvideo.logger import init_logger
from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import (FlowUniPCMultistepScheduler)
from fastvideo.pipelines.basic.wan.wan_pipeline import WanPipeline
from fastvideo.training.training_pipeline import TrainingPipeline
from fastvideo.utils import is_vsa_available

try:
    vsa_available = is_vsa_available()
except Exception:
    vsa_available = False

logger = init_logger(__name__)


class WanTrainingPipeline(TrainingPipeline):
    """
    A training pipeline for Wan.
    """
    _required_config_modules = ["scheduler", "transformer", "vae"]

    def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
        self.modules["scheduler"] = FlowUniPCMultistepScheduler(shift=fastvideo_args.pipeline_config.flow_shift)

    def create_training_stages(self, training_args: TrainingArgs):
        """
        May be used in future refactors.
        """
        pass

    def initialize_validation_pipeline(self, training_args: TrainingArgs):
        logger.info("Initializing validation pipeline...")
        args_copy = deepcopy(training_args)

        args_copy.inference_mode = True
        validation_pipeline = WanPipeline.from_pretrained(
            training_args.model_path,
            args=args_copy,  # type: ignore
            inference_mode=True,
            loaded_modules={
                "transformer": self.get_module("transformer"),
            },
            tp_size=training_args.tp_size,
            sp_size=training_args.sp_size,
            num_gpus=training_args.num_gpus,
            pin_cpu_memory=training_args.pin_cpu_memory,
            dit_cpu_offload=True)

        self.validation_pipeline = validation_pipeline


def main(args) -> None:
    logger.info("Starting training pipeline...")

    pipeline = WanTrainingPipeline.from_pretrained(args.pretrained_model_name_or_path, args=args)
    args = pipeline.training_args
    pipeline.train()
    logger.info("Training pipeline done")


if __name__ == "__main__":
    argv = sys.argv
    from fastvideo.fastvideo_args import TrainingArgs
    from fastvideo.utils import FlexibleArgumentParser
    parser = FlexibleArgumentParser()
    parser = TrainingArgs.add_cli_args(parser)
    parser = FastVideoArgs.add_cli_args(parser)
    args = parser.parse_args()
    args.dit_cpu_offload = False
    main(args)
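A note on effective batch size: with the launch settings in run_sparse_fp4_train_v4_common.sh below (train_batch_size 1, train_sp_batch_size 1, gradient_accumulation_steps 1, sp_size 1 on 8 GPUs), the total train batch size reported by _log_training_info works out to 8 * 1 / 1 * 1 = 8 samples per optimizer step.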
backend_snapshot/manifest.sha256
ADDED
@@ -0,0 +1,17 @@
45ff4b677a84fad92bd2ff596bf432cb1b9386c5923b6c0824f896074e7cfbc6  ./README.md
9d1d8dc58aab529270fe31eb1735d6a1382c0c6d36fccca122a8dbffa1b714fd  ./fastvideo-kernel/python/fastvideo_kernel/block_sparse_attn_ours_p.py
211c7f0445fbe9488250f01fa83457c6620e83bd6f3877db791fd155de93c08b  ./fastvideo-kernel/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton_ours_p.py
3f3a407a88612ea17ad65e1b6b9cf6b7b02df56956d8301c4b13bffa92095016  ./fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py
56f17c602dede53c7c3677058f81274681530f1b83c086d9d1d44c6b51feefbb  ./fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py
2b821b0e2e7bdb3581be6312ebbece42380a6ee28a7a982f0cf2dc71fab849c8  ./fastvideo/attention/backends/sparse_fp4_ours_p_attn.py
a97adcc52d7558c49f418c09395fd1665e988ad290d2276b95f21dfca0f8eb7d  ./fastvideo/attention/backends/video_sparse_attn.py
79ef6f38ec0f5bfe16b2b98327ad2ccd15f3c863dd87fd03affc5dbdaa0a8224  ./fastvideo/configs/models/dits/base.py
ddcab6f4fd33c9813840571b6bf83bbbcea164b564166951ed4301297db6cef0  ./fastvideo/forward_context.py
6cfd128e782b7787a27ddd28a5e2d50cb4b0e2e9425d51d9780f14c91e8206f0  ./fastvideo/pipelines/stages/denoising.py
489388dbdd9e5e3ad24db3012bd9b108794509a9729891d7dd315a102abba828  ./fastvideo/platforms/cuda.py
c046b1914041b59254bcdfe577aed20d6f007a72632ea1fe1ae92fa678eca760  ./fastvideo/platforms/interface.py
2456d39ca28019e12bb7ab007774e86348f0582a017bf0e6c91e2a01d654a1a0  ./fastvideo/train/models/wan/wan.py
bc46e84b732567de6c0325223405daecd1226c623e303be33c7be9b5b7fdec08  ./fastvideo/training/training_pipeline.py
1d3898fa37e21029df6c37e05dc34ed7805a211c2f87de6642db890e5a8c6f2e  ./fastvideo/training/wan_training_pipeline.py
5c982b64653fae83ebfdeb43fda8f29b3e2cb581fb4daee38cd3cf56aa9d73f5  ./scripts/training/run_sparse_fp4_train_v4_1n_sparse09_hpo_on_ours_p_init2050_interactive.sh
5c1d5ce9ecc8b90e59ddfc2ddb3e2dae500bcd3acb90429c901444b1630f05fb  ./scripts/training/run_sparse_fp4_train_v4_common.sh
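The manifest uses the standard sha256sum format (digest, two spaces, path), so the snapshot can be checked with `sha256sum -c manifest.sha256` from inside `backend_snapshot/`. A minimal pure-Python equivalent (a sketch, for environments without coreutils):

import hashlib
import sys

# Run from inside backend_snapshot/; exits non-zero on any mismatch.
failed = False
with open("manifest.sha256") as manifest:
    for line in manifest:
        digest, path = line.split(maxsplit=1)
        path = path.strip()
        with open(path, "rb") as f:
            actual = hashlib.sha256(f.read()).hexdigest()
        if actual != digest:
            failed = True
            print(f"MISMATCH: {path}")
sys.exit(1 if failed else 0)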
backend_snapshot/scripts/training/run_sparse_fp4_train_v4_1n_sparse09_hpo_on_ours_p_init2050_interactive.sh
ADDED
@@ -0,0 +1,32 @@
#!/bin/bash
#SBATCH --job-name=sfp4-s09-oursp-i2050
#SBATCH --account=nvr_elm_llm
#SBATCH --partition=interactive
#SBATCH --nodes=1
#SBATCH --gres=gpu:8
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=128
#SBATCH --mem=1440G
#SBATCH --time=02:00:00
#SBATCH --output=slurm_logs/sfp4_sparse09_ours_p_init2050_1n_interactive_%j.out
#SBATCH --error=slurm_logs/sfp4_sparse09_ours_p_init2050_1n_interactive_%j.err

export RUN_NAME="sfp4_v4_sparse09_hpo_on_ours_p_init2050_1n_interactive"
export WANDB_RUN_ID="sfp4v4-sparse09-hpo-on-ours-p-init2050-1n-interactive"
export FASTVIDEO_ATTENTION_BACKEND=SPARSE_FP4_OURS_P_ATTN
export FASTVIDEO_SPARSE_FP4_USE_HIGH_PREC_O=1
export CHECKPOINT_LIMIT=5
export SAVE_STEPS=50
export EVAL_STEPS=50
export VALIDATION_SAMPLING_STEPS=50
export USE_SRUN=0

export VSA_SPARSITY=0.9
export VSA_INIT_SPARSITY=0.9
export VSA_WARMUP_STEPS=0
export VSA_DECAY_RATE=0.03
export VSA_DECAY_INTERVAL_STEPS=50

export INIT_WEIGHTS_FROM_SAFETENSORS="checkpoints/init/sfp4_v4_sparse06_hpo_on_ours_p_1n_interactive_v2_ckpt2050/transformer/diffusion_pytorch_model.safetensors"

exec bash scripts/training/run_sparse_fp4_train_v4_common.sh
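Note that this wrapper pins VSA_INIT_SPARSITY equal to VSA_SPARSITY (0.9) with VSA_WARMUP_STEPS=0, so the decay settings are inert for this run: int((0.9 - 0.9) / 0.03) = 0 decay steps, meaning training runs at a fixed 0.9 sparsity from step 1 through checkpoint-750. VSA_DECAY_RATE and VSA_DECAY_INTERVAL_STEPS are exported explicitly even though they have no effect here.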
backend_snapshot/scripts/training/run_sparse_fp4_train_v4_common.sh
ADDED
@@ -0,0 +1,199 @@
#!/bin/bash

set -euo pipefail
set -x

cd /lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/users/yitongl/code/FastVideo
source .venv/bin/activate

: "${RUN_NAME:?RUN_NAME must be set by the Slurm wrapper}"
: "${WANDB_RUN_ID:?WANDB_RUN_ID must be set by the Slurm wrapper}"

export PYTHONPATH=fastvideo-kernel/python:fastvideo-kernel:${PYTHONPATH:-}
export FASTVIDEO_ATTENTION_BACKEND="${FASTVIDEO_ATTENTION_BACKEND:-SPARSE_FP4_ATTN}"
export FASTVIDEO_VALIDATION_ONE_PROMPT_PER_RANK="${FASTVIDEO_VALIDATION_ONE_PROMPT_PER_RANK:-1}"
export WANDB_MODE=online
export WANDB_BASE_URL="https://api.wandb.ai"
export WANDB_RESUME=allow
export WANDB_NAME="${RUN_NAME}"
export TOKENIZERS_PARALLELISM=false
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export TRITON_CACHE_DIR="/tmp/triton_cache_${SLURM_JOB_ID:-manual}"

if [[ -n "${SLURM_JOB_NODELIST:-}" ]]; then
    MASTER_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1)
else
    MASTER_ADDR=127.0.0.1
fi
MASTER_PORT=$((20000 + (${SLURM_JOB_ID:-0} % 20000)))
export MASTER_ADDR MASTER_PORT

NUM_GPUS_PER_NODE=8
NNODES=${SLURM_NNODES:-1}
TOTAL_GPUS=$((NNODES * NUM_GPUS_PER_NODE))
OUTPUT_DIR="checkpoints/${RUN_NAME}"
CHECKPOINT_LIMIT="${CHECKPOINT_LIMIT:-5}"
MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
DATA_DIR="data/Wan-Syn_77x448x832_600k"
VALIDATION_DATASET_FILE="examples/training/finetune/Wan2.1-VSA/Wan-Syn-Data/validation_64.json"
SAVE_STEPS="${SAVE_STEPS:-50}"
EVAL_STEPS="${EVAL_STEPS:-50}"
VALIDATION_SAMPLING_STEPS="${VALIDATION_SAMPLING_STEPS:-50}"
MAX_TRAIN_STEPS="${MAX_TRAIN_STEPS:-100000}"
VSA_SPARSITY="${VSA_SPARSITY:-0.9}"
VSA_INIT_SPARSITY="${VSA_INIT_SPARSITY:-0.6}"
VSA_WARMUP_STEPS="${VSA_WARMUP_STEPS:-0}"
VSA_DECAY_RATE="${VSA_DECAY_RATE:-0.03}"
VSA_DECAY_INTERVAL_STEPS="${VSA_DECAY_INTERVAL_STEPS:-50}"
INIT_WEIGHTS_FROM_SAFETENSORS="${INIT_WEIGHTS_FROM_SAFETENSORS:-}"

mkdir -p slurm_logs "${OUTPUT_DIR}"

find_latest_checkpoint() {
    if [[ ! -d "${OUTPUT_DIR}" ]]; then
        return 1
    fi

    mapfile -t checkpoint_steps < <(find "${OUTPUT_DIR}" -maxdepth 1 -type d -name 'checkpoint-*' -printf '%f\n' \
        | sed 's/checkpoint-//' \
        | sort -nr)

    local step
    for step in "${checkpoint_steps[@]}"; do
        if [[ -f "${OUTPUT_DIR}/checkpoint-${step}/transformer/diffusion_pytorch_model.safetensors" ]]; then
            echo "${OUTPUT_DIR}/checkpoint-${step}"
            return 0
        fi
    done

    return 1
}

prune_checkpoints() {
    if [[ ! -d "${OUTPUT_DIR}" ]]; then
        return 0
    fi

    mapfile -t checkpoint_steps < <(find "${OUTPUT_DIR}" -maxdepth 1 -type d -name 'checkpoint-*' -printf '%f\n' \
        | sed 's/checkpoint-//' \
        | sort -n)

    local count=${#checkpoint_steps[@]}
    if (( count <= CHECKPOINT_LIMIT )); then
        return 0
    fi

    local remove_count=$((count - CHECKPOINT_LIMIT))
    local step
    for step in "${checkpoint_steps[@]:0:remove_count}"; do
        rm -rf "${OUTPUT_DIR}/checkpoint-${step}"
    done
}

RESUME_ARGS=()
refresh_resume_args() {
    RESUME_ARGS=()

    local latest_ckpt
    if latest_ckpt=$(find_latest_checkpoint); then
        RESUME_ARGS=(
            --resume_from_checkpoint latest
            --init_weights_from_safetensors "${latest_ckpt}/transformer/diffusion_pytorch_model.safetensors"
        )
        echo "=== Resuming from ${latest_ckpt} ==="
    fi
}

COMMON_ARGS=(
    --tracker_project_name "wan_t2v_sparse_fp4"
    --wandb_run_name "${RUN_NAME}"
    --output_dir "${OUTPUT_DIR}"
    --train_batch_size 1
    --train_sp_batch_size 1
    --gradient_accumulation_steps 1
    --num_latent_t 20
    --num_height 448
    --num_width 832
    --num_frames 77
    --enable_gradient_checkpointing_type "full"
    --num_gpus "${TOTAL_GPUS}"
    --sp_size 1
    --tp_size 1
    --hsdp_replicate_dim "${TOTAL_GPUS}"
    --hsdp_shard_dim 1
    --model_path "${MODEL_PATH}"
    --pretrained_model_name_or_path "${MODEL_PATH}"
    --data_path "${DATA_DIR}"
    --dataloader_num_workers 4
    --log_validation
    --validation_dataset_file "${VALIDATION_DATASET_FILE}"
    --validation_steps "${EVAL_STEPS}"
    --validation_sampling_steps "${VALIDATION_SAMPLING_STEPS}"
    --validation_guidance_scale "5.0"
    --learning_rate 1e-6
    --mixed_precision "bf16"
    --weight_only_checkpointing_steps "${SAVE_STEPS}"
    --training_state_checkpointing_steps "${SAVE_STEPS}"
    --weight_decay 0.01
    --max_grad_norm 1.0
    --inference_mode False
    --checkpoints_total_limit "${CHECKPOINT_LIMIT}"
    --training_cfg_rate 0.1
    --dit_precision "fp32"
    --ema_start_step 0
    --flow_shift 1
    --seed 1000
    --VSA-sparsity "${VSA_SPARSITY}"
    --VSA-init-sparsity "${VSA_INIT_SPARSITY}"
    --VSA-warmup-steps "${VSA_WARMUP_STEPS}"
    --VSA-decay-rate "${VSA_DECAY_RATE}"
    --VSA-decay-interval-steps "${VSA_DECAY_INTERVAL_STEPS}"
)

if [[ -n "${INIT_WEIGHTS_FROM_SAFETENSORS}" ]]; then
    COMMON_ARGS+=(
        --init_weights_from_safetensors "${INIT_WEIGHTS_FROM_SAFETENSORS}"
    )
fi

run_training() {
    local max_steps=$1

    local torchrun_cmd=(
        torchrun
        --nnodes="${NNODES}" \
        --nproc_per_node="${NUM_GPUS_PER_NODE}" \
        --rdzv_backend=c10d \
        --rdzv_endpoint="${MASTER_ADDR}:${MASTER_PORT}" \
        fastvideo/training/wan_training_pipeline.py \
        --max_train_steps "${max_steps}" \
        "${COMMON_ARGS[@]}" \
        "${RESUME_ARGS[@]}"
    )

    if [[ "${USE_SRUN:-1}" == "1" ]]; then
        srun "${torchrun_cmd[@]}"
    else
        "${torchrun_cmd[@]}"
    fi
}

echo "=== ${RUN_NAME} array=${SLURM_ARRAY_TASK_ID:-0} nodes=${NNODES} gpus=${TOTAL_GPUS} ==="
echo "=== save_steps=${SAVE_STEPS} eval_steps=${EVAL_STEPS} checkpoint_limit=${CHECKPOINT_LIMIT} ==="
echo "=== master=${MASTER_ADDR}:${MASTER_PORT} validation_one_prompt_per_rank=1 ==="

if [[ ! -f "${OUTPUT_DIR}/checkpoint-0/transformer/diffusion_pytorch_model.safetensors" ]]; then
    if ! find_latest_checkpoint >/dev/null; then
        echo "=== Creating step-0 validation and checkpoint ==="
        run_training 0
        prune_checkpoints
    fi
fi

refresh_resume_args
run_training "${MAX_TRAIN_STEPS}"
prune_checkpoints

echo "=== Done arr=${SLURM_ARRAY_TASK_ID:-0} ==="
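For a sense of scale, the sequence length that the training loop logs as dit_seq_len can be worked out from the arguments above. A minimal sketch (the (1, 2, 2) patch size and the 8x spatial / 4x temporal VAE compression are assumptions based on Wan2.1-T2V-1.3B, not read from this snapshot):

# --num_frames 77, --num_height 448, --num_width 832 from COMMON_ARGS above
num_frames, height, width = 77, 448, 832
latent_t = (num_frames - 1) // 4 + 1          # 20, matches --num_latent_t 20
latent_h, latent_w = height // 8, width // 8  # 56 x 104
patch_t, patch_h, patch_w = 1, 2, 2           # assumed Wan2.1 DiT patch size
dit_seq_len = (latent_t // patch_t) * (latent_h // patch_h) * (latent_w // patch_w)
print(dit_seq_len)        # 20 * 28 * 52 = 29120 video tokens
print(dit_seq_len // 64)  # 455 tiles at the 4 x 4 x 4 = 64-token VSA tile size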