Wan2GP / wan /modules /sage2_core.py
zxymimi23451's picture
Upload 258 files
78360e7 verified
"""
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn.functional as F
from sageattention.triton.quant_per_block import per_block_int8 as per_block_int8_triton
from sageattention.triton.quant_per_block_varlen import per_block_int8 as per_block_int8_varlen_triton
from sageattention.triton.attn_qk_int8_per_block import forward as attn_false
from sageattention.triton.attn_qk_int8_per_block_causal import forward as attn_true
from sageattention.triton.attn_qk_int8_block_varlen import forward as attn_false_varlen
from sageattention.triton.attn_qk_int8_per_block_causal_varlen import forward as attn_true_varlen
from sageattention.triton.quant_per_thread import per_thread_int8 as per_thread_int8_triton
# Optional compiled kernel extensions. A failed import only marks the
# corresponding kernel family as unavailable; the SMxx_ENABLED flags are
# asserted by the functions that call into these modules.
# `except Exception` (rather than a bare `except`) keeps KeyboardInterrupt and
# SystemExit propagating while still tolerating any import/load failure.
try:
    from sageattention import _qattn_sm80
    SM80_ENABLED = True
except Exception:
    SM80_ENABLED = False
try:
    from sageattention import _qattn_sm89
    SM89_ENABLED = True
except Exception:
    SM89_ENABLED = False
try:
    from sageattention import _qattn_sm90
    SM90_ENABLED = True
except Exception:
    SM90_ENABLED = False
from sageattention.quant import per_block_int8 as per_block_int8_cuda
from sageattention.quant import per_warp_int8 as per_warp_int8_cuda
from sageattention.quant import sub_mean
from sageattention.quant import per_channel_fp8
from typing import Any, List, Literal, Optional, Tuple, Union
import warnings
import os
def is_sage2_supported():
    """
    Report whether every visible CUDA device has a major compute capability of at least 8.

    Returns
    -------
    bool
        ``False`` as soon as one device reports a major capability below 8,
        ``True`` otherwise (this includes the case where no CUDA device is visible).
    """
    return all(
        torch.cuda.get_device_capability(idx)[0] >= 8
        for idx in range(torch.cuda.device_count())
    )
def get_cuda_arch_versions():
    """
    List the compute capability of every visible CUDA device as ``"sm<major><minor>"``.

    Returns
    -------
    List[str]
        One entry per device, ordered by device index (e.g. ``"sm89"`` for capability 8.9).
        Empty when no CUDA device is visible.
    """
    return [
        "sm{}{}".format(*torch.cuda.get_device_capability(idx))
        for idx in range(torch.cuda.device_count())
    ]
def sageattn(
    qkv_list,
    tensor_layout: str = "HND",
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    return_lse: bool = False,
    **kwargs: Any,
):
    """
    Automatically selects the appropriate implementation of the SageAttention kernel based on the
    compute capability of the GPU that holds the input tensors.

    Dispatch table:

    - ``sm80``  -> ``sageattn_qk_int8_pv_fp16_cuda`` (``pv_accum_dtype="fp32"``)
    - ``sm86``  -> ``sageattn_qk_int8_pv_fp16_triton``
    - ``sm89``  -> ``sageattn_qk_int8_pv_fp8_cuda`` (``pv_accum_dtype="fp32+fp32"``)
    - ``sm90``  -> ``sageattn_qk_int8_pv_fp8_cuda_sm90`` (``pv_accum_dtype="fp32+fp32"``)
    - ``sm120`` -> ``sageattn_qk_int8_pv_fp8_cuda`` (per-warp quantization, ``pv_accum_dtype="fp32"``, ``smooth_v=True``)

    Parameters
    ----------
    qkv_list : list of torch.Tensor
        The ``[q, k, v]`` tensors, passed as a list and forwarded unchanged to the selected backend.
        The backends defined alongside this function unpack the list and then clear it, so the
        caller should not expect it to still hold the tensors afterwards.
        Shapes:

        - If `tensor_layout` is "HND": q ``[batch_size, num_qo_heads, qo_len, head_dim]``,
          k/v ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": q ``[batch_size, qo_len, num_qo_heads, head_dim]``,
          k/v ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    tensor_layout : str
        The tensor layout, either "HND" or "NHD".
        Default: "HND".

    is_causal : bool
        Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
        Default: False.

    sm_scale : Optional[float]
        The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.

    return_lse : bool
        Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
        Default: False.

    **kwargs
        Accepted for call-site compatibility; not forwarded to the backend.

    Returns
    -------
    torch.Tensor
        The output tensor, in the same layout as the query.

    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling.
        Only returned if `return_lse` is True.

    Raises
    ------
    ValueError
        If the first tensor is not on a CUDA device, or if the device's architecture
        has no associated kernel.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors must have the dtype ``torch.float16`` or ``torch.bfloat16``.
    - All tensors must be on the same cuda device.
    """
    device = qkv_list[0].device
    if device.type != "cuda":
        # Without this guard a CPU tensor (device.index is None) would fail
        # later with an unrelated-looking indexing TypeError.
        raise ValueError(f"sageattn requires CUDA tensors, got a tensor on device '{device}'.")

    # Query only the device that actually holds the inputs.
    major, minor = torch.cuda.get_device_capability(device)
    arch = f"sm{major}{minor}"

    if arch == "sm80":
        return sageattn_qk_int8_pv_fp16_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32")
    elif arch == "sm86":
        return sageattn_qk_int8_pv_fp16_triton(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse)
    elif arch == "sm89":
        return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
    elif arch == "sm90":
        return sageattn_qk_int8_pv_fp8_cuda_sm90(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
    elif arch == "sm120":
        # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
        return sageattn_qk_int8_pv_fp8_cuda(qkv_list, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32", smooth_v=True)
    else:
        raise ValueError(f"Unsupported CUDA architecture: {arch}")
@torch.compiler.disable
def sageattn_qk_int8_pv_fp16_triton(
qkv_list,
# q: torch.Tensor,
# k: torch.Tensor,
# v: torch.Tensor,
tensor_layout: str = "HND",
quantization_backend: str = "triton",
is_causal: bool =False,
sm_scale: Optional[float] = None,
smooth_k: bool = True,
return_lse: bool = False,
**kwargs: Any,
) -> torch.Tensor:
"""
SageAttention with per-block INT8 quantization for Q and K, FP16 PV with FP16 accumulation, implemented using Triton.
The FP16 accumulator is added to a FP32 buffer immediately after each iteration.
Parameters
----------
q : torch.Tensor
The query tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
k : torch.Tensor
The key tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
v : torch.Tensor
The value tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
tensor_layout : str
The tensor layout, either "HND" or "NHD".
Default: "HND".
quantization_backend : str
The quantization backend, either "triton" or "cuda".
"cuda" backend offers better performance due to kernel fusion.
is_causal : bool
Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
Default: False.
sm_scale : Optional[float]
The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
smooth_k : bool
Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
Default: True.
return_lse : bool
Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
Default: False.
Returns
-------
torch.Tensor
The output tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
torch.Tensor
The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
Shape: ``[batch_size, num_qo_heads, qo_len]``.
Only returned if `return_lse` is True.
Note
----
- ``num_qo_heads`` must be divisible by ``num_kv_heads``.
- The tensors `q`, `k`, and `v` must have the dtype ``torch.float16``, ``torch.bfloat16`` or ``torch.float32``.
- All tensors must be on the same cuda device.
- `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
"""
q, k, v = qkv_list
qkv_list.clear()
dtype = q.dtype
assert q.is_cuda, "Input tensors must be on cuda."
assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
assert q.device == k.device == v.device, "All tensors must be on the same device."
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
# FIXME(DefTruth): make sage attention work compatible with distributed
# env, for example, xDiT which launch by torchrun. Without this workaround,
# sage attention will run into illegal memory access error after first
# inference step in distributed env for multi gpus inference. This small
# workaround also make sage attention work compatible with torch.compile
# through non-fullgraph compile mode.
torch.cuda.set_device(v.device)
head_dim_og = q.size(-1)
if head_dim_og < 64:
q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
elif head_dim_og > 64 and head_dim_og < 128:
q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
elif head_dim_og > 128:
raise ValueError(f"Unsupported head_dim: {head_dim_og}")
# assert last dim is contiguous
assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
seq_dim = 1 if tensor_layout == "NHD" else 2
if smooth_k:
km = k.mean(dim=seq_dim, keepdim=True)
if return_lse:
if tensor_layout == "NHD":
lse_correction = torch.matmul(q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
else:
lse_correction = torch.matmul(q, km.transpose(2, 3)).squeeze(-1).to(torch.float32)
else:
km = None
if dtype == torch.bfloat16 or dtype == torch.float32:
v = v.to(torch.float16)
if sm_scale is None:
sm_scale = 1.0 / (head_dim_og ** 0.5)
if quantization_backend == "triton":
q_int8, q_scale, k_int8, k_scale = per_block_int8_triton(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
elif quantization_backend == "cuda":
q_int8, q_scale, k_int8, k_scale = per_block_int8_cuda(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
else:
raise ValueError(f"Unsupported quantization backend: {quantization_backend}")
del q,k, km
if is_causal:
o, lse = attn_true(q_int8, k_int8, v, q_scale, k_scale, tensor_layout=tensor_layout, output_dtype=dtype, return_lse=return_lse)
else:
o, lse = attn_false(q_int8, k_int8, v, q_scale, k_scale, tensor_layout=tensor_layout, output_dtype=dtype, return_lse=return_lse)
o = o[..., :head_dim_og]
if return_lse:
return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
else:
return o
@torch.compiler.disable
def sageattn_varlen(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
is_causal: bool = False,
sm_scale: Optional[float] = None,
smooth_k: bool = True,
**kwargs: Any,
) -> torch.Tensor:
"""
Parameters
----------
q : torch.Tensor
The query tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.
k : torch.Tensor
The key tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.
v : torch.Tensor
The value tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.
cu_seqlens_q : torch.Tensor
The cumulative sequence lengths for the query sequences in the batch, used to index into `q`.
Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.
cu_seqlens_k : torch.Tensor
The cumulative sequence lengths for the key and value sequences in the batch, used to index into `k` and `v`.
Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.
max_seqlen_q : int
The maximum sequence length for the query tensor in the batch.
max_seqlen_k : int
The maximum sequence length for the key and value tensors in the batch.
is_causal : bool
Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len for each sequence.
Default: False.
sm_scale : Optional[float]
The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
smooth_k : bool
Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
Default: True.
Returns
-------
torch.Tensor
The output tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.
Note
----
- ``num_qo_heads`` must be divisible by ``num_kv_heads``.
- The tensors `q`, `k`, and `v` must have the dtype ``torch.float16``, ``torch.bfloat16`` or ``torch.float32``.
- The tensors `cu_seqlens_q` and `cu_seqlens_k` must have the dtype ``torch.int32`` or ``torch.int64``.
- All tensors must be on the same cuda device.
- `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
"""
dtype = q.dtype
assert q.is_cuda, "Input tensors must be on cuda."
assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
assert q.device == k.device == v.device, "All tensors must be on the same device."
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
# FIXME(DefTruth): make sage attention work compatible with distributed
# env, for example, xDiT which launch by torchrun. Without this workaround,
# sage attention will run into illegal memory access error after first
# inference step in distributed env for multi gpus inference. This small
# workaround also make sage attention work compatible with torch.compile
# through non-fullgraph compile mode.
torch.cuda.set_device(v.device)
head_dim_og = q.size(-1)
if head_dim_og < 64:
q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
elif head_dim_og > 64 and head_dim_og < 128:
q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
elif head_dim_og > 128:
raise ValueError(f"Unsupported head_dim: {head_dim_og}")
assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
assert cu_seqlens_q.is_contiguous() and cu_seqlens_k.is_contiguous(), "cu_seqlens_q and cu_seqlens_k must be contiguous."
if dtype == torch.bfloat16 or dtype == torch.float32:
v = v.to(torch.float16)
if smooth_k:
km = k.mean(dim=0, keepdim=True) # ! km is calculated on the all the batches. Calculate over each individual sequence requires dedicated kernel.
k = k - km
if sm_scale is None:
sm_scale = 1.0 / (head_dim_og ** 0.5)
q_int8, q_scale, k_int8, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale = per_block_int8_varlen_triton(q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=sm_scale)
if is_causal:
o = attn_true_varlen(q_int8, k_int8, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, output_dtype=dtype)
else:
o = attn_false_varlen(q_int8, k_int8, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, output_dtype=dtype)
o = o[..., :head_dim_og]
return o
@torch.compiler.disable
def sageattn_qk_int8_pv_fp16_cuda(
    qkv_list,
    tensor_layout: str = "HND",
    is_causal: bool = False,
    qk_quant_gran: str = "per_thread",
    sm_scale: Optional[float] = None,
    pv_accum_dtype: str = "fp32",
    smooth_k: bool = True,
    smooth_v: bool = False,
    return_lse: bool = False,
    **kwargs: Any,
) -> torch.Tensor:
    """
    SageAttention with INT8 quantization for Q and K and an FP16 value tensor, executed by the
    ``_qattn_sm80`` CUDA extension.

    Parameters
    ----------
    qkv_list : list of torch.Tensor
        The ``[q, k, v]`` tensors. The list is unpacked and then cleared in place.

        - If `tensor_layout` is "HND": q ``[batch_size, num_qo_heads, qo_len, head_dim]``,
          k/v ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": q ``[batch_size, qo_len, num_qo_heads, head_dim]``,
          k/v ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    tensor_layout : str
        The tensor layout, either "HND" or "NHD".
        Default: "HND".

    is_causal : bool
        Whether to apply causal mask to the attention matrix.
        Default: False.

    qk_quant_gran : str
        The granularity of quantization for Q and K, either "per_warp" or "per_thread".
        Default: "per_thread".

    sm_scale : Optional[float]
        The softmax scale. Defaults to ``head_dim ** -0.5`` using the un-padded head dimension.

    pv_accum_dtype : str
        Selects the PV accumulation kernel: "fp16", "fp16+fp32" or "fp32".
        Default: "fp32".

    smooth_k : bool
        When True, the mean of K along the sequence dimension is passed to the quantizer.
        Default: True.

    smooth_v : bool
        When True and `pv_accum_dtype` is "fp16", the mean of V is subtracted and the
        mean-fused kernel is used. For "fp32" and "fp16+fp32" the flag is ignored with a warning.
        Default: False.

    return_lse : bool
        Whether to also return the log-sum-exp of the attention logits.
        Default: False.

    Returns
    -------
    torch.Tensor
        The attention output, sliced back to the original ``head_dim``.

    torch.Tensor
        Only when `return_lse` is True: the kernel's LSE divided by 1.44269504 (approximately
        log2(e)), plus ``(q @ mean(k)^T) * sm_scale`` when `smooth_k` is enabled.

    Raises
    ------
    ValueError
        If ``head_dim`` exceeds 128 or `pv_accum_dtype` is not recognised.

    Note
    ----
    - Requires the ``_qattn_sm80`` extension to have been imported successfully.
    - Inputs must be CUDA tensors of dtype ``torch.float16`` or ``torch.bfloat16``, all on the
      same device and of the same dtype (checked with assertions).
    - The last dimension of q, k and v must be contiguous.
    - Head dimensions below 64 are zero-padded to 64; those between 64 and 128 are padded to 128.
    - The current CUDA device is switched to the device holding ``v``.
    """
    q, k, v = qkv_list
    qkv_list.clear()

    dtype = q.dtype
    assert SM80_ENABLED, "SM80 kernel is not available. make sure you GPUs with compute capability 8.0 or higher."
    assert q.is_cuda, "Input tensors must be on cuda."
    assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
    assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'."
    assert q.device == k.device == v.device, "All tensors must be on the same device."
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."

    # FIXME(DefTruth): make sage attention work compatible with distributed
    # env, for example, xDiT which launch by torchrun. Without this workaround,
    # sage attention will run into illegal memory access error after first
    # inference step in distributed env for multi gpus inference. This small
    # workaround also make sage attention work compatible with torch.compile
    # through non-fullgraph compile mode.
    torch.cuda.set_device(v.device)

    # Integer flags in the form the CUDA extension is called with.
    layout_is_nhd = tensor_layout == "NHD"
    _tensor_layout = 0 if layout_is_nhd else 1
    _is_causal = 1 if is_causal else 0
    _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
    _return_lse = 1 if return_lse else 0

    # Zero-pad the head dimension up to 64 or 128.
    head_dim_og = q.size(-1)
    if head_dim_og > 128:
        raise ValueError(f"Unsupported head_dim: {head_dim_og}")
    if head_dim_og < 64:
        padded_dim = 64
    elif 64 < head_dim_og < 128:
        padded_dim = 128
    else:
        padded_dim = head_dim_og
    if padded_dim != head_dim_og:
        pad_spec = (0, padded_dim - head_dim_og)
        q = torch.nn.functional.pad(q, pad_spec)
        k = torch.nn.functional.pad(k, pad_spec)
        v = torch.nn.functional.pad(v, pad_spec)

    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."

    if sm_scale is None:
        sm_scale = head_dim_og**-0.5

    seq_dim = 1 if layout_is_nhd else 2
    km = k.mean(dim=seq_dim, keepdim=True) if smooth_k else None
    if smooth_k and return_lse:
        # q . mean(k), computed in head-major layout; added back to the LSE below.
        q_hnd = q.transpose(1, 2) if layout_is_nhd else q
        km_hnd = km.transpose(1, 2) if layout_is_nhd else km
        lse_correction = torch.matmul(q_hnd, km_hnd.transpose(2, 3)).squeeze(-1).to(torch.float32)

    # WARPQ is 16 only for (padded) head_dim 128 combined with "fp16+fp32" accumulation.
    warp_q = 16 if (q.size(-1) == 128 and pv_accum_dtype == "fp16+fp32") else 32
    if qk_quant_gran == "per_warp":
        q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=warp_q, BLKK=64)
    elif qk_quant_gran == "per_thread":
        q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=warp_q, BLKK=64, WARPK=64)

    # Remember what is needed to allocate the output, then release the inputs.
    out_shape = q.size()
    out_device = q.device
    del q, k, km
    o = torch.empty(out_shape, dtype=dtype, device=out_device)

    if smooth_v and pv_accum_dtype in ["fp32", "fp16+fp32"]:
        warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.")
        smooth_v = False

    if pv_accum_dtype == 'fp32':
        v = v.to(torch.float16)
        lse = _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse)
    elif pv_accum_dtype == "fp16":
        if smooth_v:
            smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
            del v
            lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse)
        else:
            v = v.to(torch.float16)
            lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse)
    elif pv_accum_dtype == "fp16+fp32":
        v = v.to(torch.float16)
        lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_causal, _qk_quant_gran, sm_scale, _return_lse)
    else:
        raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}")

    o = o[..., :head_dim_og]

    if not return_lse:
        return o

    lse = lse / 1.44269504
    if smooth_k:
        lse = lse + lse_correction * sm_scale
    return o, lse
@torch.compiler.disable
def sageattn_qk_int8_pv_fp8_cuda(
qkv_list,
tensor_layout: str = "HND",
is_causal: bool = False,
qk_quant_gran: str = "per_thread",
sm_scale: Optional[float] = None,
pv_accum_dtype: str = "fp32+fp32",
smooth_k: bool = True,
smooth_v: bool = False,
return_lse: bool = False,
**kwargs: Any,
) -> torch.Tensor:
"""
SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
Parameters
----------
q : torch.Tensor
The query tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
k : torch.Tensor
The key tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
v : torch.Tensor
The value tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
tensor_layout : str
The tensor layout, either "HND" or "NHD".
Default: "HND".
is_causal : bool
Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
Default: False.
qk_quant_gran : str
The granularity of quantization for Q and K, either "per_warp" or "per_thread".
Default: "per_thread".
sm_scale : Optional[float]
The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
pv_accum_dtype : str
The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
- "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator.
- "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
Default: "fp32+fp32".
smooth_k : bool
Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
Default: True.
smooth_v : bool
Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
smooth_v will be ignored if pv_accum_dtype is "fp32+fp32".
Default: False.
return_lse : bool
Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
Default: False.
Returns
-------
torch.Tensor
The output tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
torch.Tensor
The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
Shape: ``[batch_size, num_qo_heads, qo_len]``.
Only returned if `return_lse` is True.
Note
----
- ``num_qo_heads`` must be divisible by ``num_kv_heads``.
- The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
- All tensors must be on the same cuda device.
- `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
"""
q, k, v = qkv_list
qkv_list.clear()
dtype = q.dtype
assert SM89_ENABLED, "SM89 kernel is not available. Make sure you GPUs with compute capability 8.9."
assert q.is_cuda, "Input tensors must be on cuda."
assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'."
assert q.device == k.device == v.device, "All tensors must be on the same device."
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
# FIXME(DefTruth): make sage attention work compatible with distributed
# env, for example, xDiT which launch by torchrun. Without this workaround,
# sage attention will run into illegal memory access error after first
# inference step in distributed env for multi gpus inference. This small
# workaround also make sage attention work compatible with torch.compile
# through non-fullgraph compile mode.
torch.cuda.set_device(v.device)
_tensor_layout = 0 if tensor_layout == "NHD" else 1
_is_caual = 1 if is_causal else 0
_qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
_return_lse = 1 if return_lse else 0
head_dim_og = q.size(-1)
if head_dim_og < 64:
q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
elif head_dim_og > 64 and head_dim_og < 128:
q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
elif head_dim_og > 128:
raise ValueError(f"Unsupported head_dim: {head_dim_og}")
# assert last dim is contiguous
assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
if sm_scale is None:
sm_scale = head_dim_og**-0.5
seq_dim = 1 if _tensor_layout == 0 else 2
if smooth_k:
km = k.mean(dim=seq_dim, keepdim=True)
if return_lse:
if tensor_layout == "NHD":
lse_correction = torch.matmul(q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
else:
lse_correction = torch.matmul(q, km.transpose(2, 3)).squeeze(-1).to(torch.float32)
else:
km = None
if qk_quant_gran == "per_warp":
q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64)
elif qk_quant_gran == "per_thread":
q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64)
q_size = q.size()
q_device = q.device
del q,k,km
if pv_accum_dtype == 'fp32+fp32' and smooth_v:
warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.")
smooth_v = False
v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=smooth_v)
del v
o = torch.empty(q_size, dtype=dtype, device=q_device)
if pv_accum_dtype == "fp32":
if smooth_v:
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
else:
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
elif pv_accum_dtype == "fp32+fp32":
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
o = o[..., :head_dim_og]
if return_lse:
return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
else:
return o
@torch.compiler.disable
def sageattn_qk_int8_pv_fp8_window_cuda(
qkv_list,
# q: torch.Tensor,
# k: torch.Tensor,
# v: torch.Tensor,
tensor_layout: str = "HND",
is_causal: bool = False,
qk_quant_gran: str = "per_thread",
sm_scale: Optional[float] = None,
pv_accum_dtype: str = "fp32+fp32",
smooth_k: bool = True,
smooth_v: bool = False,
return_lse: bool = False,
window = -1,
**kwargs: Any,
) -> torch.Tensor:
"""
SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
Parameters
----------
q : torch.Tensor
The query tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
k : torch.Tensor
The key tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
v : torch.Tensor
The value tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
tensor_layout : str
The tensor layout, either "HND" or "NHD".
Default: "HND".
is_causal : bool
Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
Default: False.
qk_quant_gran : str
The granularity of quantization for Q and K, either "per_warp" or "per_thread".
Default: "per_thread".
sm_scale : Optional[float]
The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
pv_accum_dtype : str
The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
- "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator.
- "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
Default: "fp32+fp32".
smooth_k : bool
Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
Default: True.
smooth_v : bool
Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
smooth_v will be ignored if pv_accum_dtype is "fp32+fp32".
Default: False.
return_lse : bool
Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
Default: False.
Returns
-------
torch.Tensor
The output tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
torch.Tensor
The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
Shape: ``[batch_size, num_qo_heads, qo_len]``.
Only returned if `return_lse` is True.
Note
----
- ``num_qo_heads`` must be divisible by ``num_kv_heads``.
- The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
- All tensors must be on the same cuda device.
- `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
"""
q,k,v = qkv_list
qkv_list.clear()
dtype = q.dtype
assert SM89_ENABLED, "SM89 kernel is not available. Make sure you GPUs with compute capability 8.9."
assert q.is_cuda, "Input tensors must be on cuda."
assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'."
assert q.device == k.device == v.device, "All tensors must be on the same device."
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
# FIXME(DefTruth): make sage attention work compatible with distributed
# env, for example, xDiT which launch by torchrun. Without this workaround,
# sage attention will run into illegal memory access error after first
# inference step in distributed env for multi gpus inference. This small
# workaround also make sage attention work compatible with torch.compile
# through non-fullgraph compile mode.
torch.cuda.set_device(v.device)
_tensor_layout = 0 if tensor_layout == "NHD" else 1
_is_caual = 1 if is_causal else 0
_qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
_return_lse = 1 if return_lse else 0
head_dim_og = q.size(-1)
if head_dim_og < 64:
q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
elif head_dim_og > 64 and head_dim_og < 128:
q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
elif head_dim_og > 128:
raise ValueError(f"Unsupported head_dim: {head_dim_og}")
# assert last dim is contiguous
assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
if sm_scale is None:
sm_scale = head_dim_og**-0.5
seq_dim = 1 if _tensor_layout == 0 else 2
if smooth_k:
km = k.mean(dim=seq_dim, keepdim=True)
if return_lse:
if tensor_layout == "NHD":
lse_correction = torch.matmul(q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
else:
lse_correction = torch.matmul(q, km.transpose(2, 3)).squeeze(-1).to(torch.float32)
else:
km = None
if qk_quant_gran == "per_warp":
q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64)
elif qk_quant_gran == "per_thread":
q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=128, WARPQ=32, BLKK=64, WARPK=64)
q_size = q.size()
q_device = q.device
del q,k
if pv_accum_dtype == 'fp32+fp32' and smooth_v:
warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.")
smooth_v = False
v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=smooth_v)
del v
o = torch.empty(q_size, dtype=dtype, device=q_device)
if pv_accum_dtype == "fp32":
if smooth_v:
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
else:
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
elif pv_accum_dtype == "fp32+fp32":
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
o = o[..., :head_dim_og]
if return_lse:
return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
else:
return o
@torch.compiler.disable
def sageattn_qk_int8_pv_fp8_cuda_sm90(
qkv_list,
# q: torch.Tensor,
# k: torch.Tensor,
# v: torch.Tensor,
tensor_layout: str = "HND",
is_causal: bool = False,
qk_quant_gran: str = "per_thread",
sm_scale: Optional[float] = None,
pv_accum_dtype: str = "fp32+fp32",
smooth_k: bool = True,
return_lse: bool = False,
**kwargs: Any,
) -> torch.Tensor:
"""
SageAttention with INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.
Parameters
----------
q : torch.Tensor
The query tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
k : torch.Tensor
The key tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
v : torch.Tensor
The value tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
tensor_layout : str
The tensor layout, either "HND" or "NHD".
Default: "HND".
is_causal : bool
Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
Default: False.
qk_quant_gran : str
The granularity of quantization for Q and K, either "per_warp" or "per_thread".
Default: "per_thread".
sm_scale : Optional[float]
The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
pv_accum_dtype : str
The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
- "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator.
- "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
Default: "fp32+fp32".
smooth_k : bool
Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
Default: True.
return_lse : bool
Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
Default: False.
Returns
-------
torch.Tensor
The output tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
torch.Tensor
The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
Shape: ``[batch_size, num_qo_heads, qo_len]``.
Only returned if `return_lse` is True.
Note
----
- ``num_qo_heads`` must be divisible by ``num_kv_heads``.
- The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
- All tensors must be on the same cuda device.
- `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
"""
q,k,v = qkv_list
qkv_list.clear()
dtype = q.dtype
assert SM90_ENABLED, "SM90 kernel is not available. Make sure you GPUs with compute capability 9.0."
assert q.is_cuda, "Input tensors must be on cuda."
assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'."
assert q.device == k.device == v.device, "All tensors must be on the same device."
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
torch.cuda.set_device(v.device)
_tensor_layout = 0 if tensor_layout == "NHD" else 1
_is_caual = 1 if is_causal else 0
_qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
_return_lse = 1 if return_lse else 0
head_dim_og = q.size(-1)
if head_dim_og < 64:
q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
elif head_dim_og > 64 and head_dim_og < 128:
q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
elif head_dim_og > 128:
raise ValueError(f"Unsupported head_dim: {head_dim_og}")
# assert last dim is contiguous
assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
if sm_scale is None:
sm_scale = head_dim_og**-0.5
seq_dim = 1 if _tensor_layout == 0 else 2
if smooth_k:
km = k.mean(dim=seq_dim, keepdim=True)
if return_lse:
if tensor_layout == "NHD":
lse_correction = torch.matmul(q.transpose(1, 2), km.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
else:
lse_correction = torch.matmul(q, km.transpose(2, 3)).squeeze(-1).to(torch.float32)
else:
km = None
if qk_quant_gran == "per_warp":
q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128)
elif qk_quant_gran == "per_thread":
q_int8, q_scale, k_int8, k_scale = per_thread_int8_triton(q, k, km, tensor_layout=tensor_layout, BLKQ=64, WARPQ=16, BLKK=128, WARPK=128)
q_size = q.size()
kv_len = k.size(seq_dim)
q_device = q.device
del q,k
# pad v to multiple of 128
# TODO: modify per_channel_fp8 kernel to handle this
v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0
if v_pad_len > 0:
if tensor_layout == "HND":
v = torch.cat([v, torch.zeros(v.size(0), v.size(1), v_pad_len, v.size(3), dtype=v.dtype, device=v.device)], dim=2)
else:
v = torch.cat([v, torch.zeros(v.size(0), v_pad_len, v.size(2), v.size(3), dtype=v.dtype, device=v.device)], dim=1)
v_fp8, v_scale, _ = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=False)
del v
o = torch.empty(q_size, dtype=dtype, device=q_device)
if pv_accum_dtype == "fp32":
raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.")
lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
elif pv_accum_dtype == "fp32+fp32":
lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
o = o[..., :head_dim_og]
if return_lse:
return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
else:
return o