|
import contextlib |
|
import functools |
|
import inspect |
|
from enum import Enum |
|
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union |
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
import torch.autograd |
|
from diffusers.utils.import_utils import OptionalDependencyNotAvailable |
|
from torch.nn.functional import scaled_dot_product_attention as native_sdpa |
|
|
|
from finetrainers.constants import FINETRAINERS_ATTN_CHECKS, FINETRAINERS_ATTN_PROVIDER |
|
from finetrainers.logging import get_logger |
|
from finetrainers.utils.import_utils import ( |
|
is_flash_attn_available, |
|
is_flash_attn_version, |
|
is_sageattention_available, |
|
is_sageattention_version, |
|
is_torch_version, |
|
is_xformers_available, |
|
is_xformers_version, |
|
) |
|
|
|
|
|
if is_flash_attn_available(): |
|
if is_flash_attn_version("<", "2.6.3"): |
|
raise OptionalDependencyNotAvailable( |
|
"The `flash-attn` library version is too old. Please update it to at least 2.6.3." |
|
) |
|
|
|
from flash_attn import flash_attn_func, flash_attn_varlen_func |
|
from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward |
|
else: |
|
flash_attn_func = None |
|
flash_attn_varlen_func = None |
|
_flash_attn_forward = None |
|
_flash_attn_backward = None |
|
|
|
|
|
if is_sageattention_available(): |
|
if is_sageattention_version("<", "2.1.1"): |
|
raise OptionalDependencyNotAvailable( |
|
"The `sageattention` library version is too old. Please update it to at least 2.1.1." |
|
) |
|
|
|
from sageattention import ( |
|
sageattn, |
|
sageattn_qk_int8_pv_fp8_cuda, |
|
sageattn_qk_int8_pv_fp8_cuda_sm90, |
|
sageattn_qk_int8_pv_fp16_cuda, |
|
sageattn_qk_int8_pv_fp16_triton, |
|
sageattn_varlen, |
|
) |
|
else: |
|
sageattn = None |
|
sageattn_qk_int8_pv_fp16_cuda = None |
|
sageattn_qk_int8_pv_fp16_triton = None |
|
sageattn_qk_int8_pv_fp8_cuda = None |
|
sageattn_qk_int8_pv_fp8_cuda_sm90 = None |
|
sageattn_varlen = None |
|
|
|
|
|
if is_torch_version(">=", "2.5.0"): |
|
import torch.nn.attention.flex_attention as flex_attention |
|
|
|
|
|
# The context-parallel (ring attention) helpers are imported from a private,
# experimental torch module, and only when PyTorch >= 2.6.0 is installed.
if is_torch_version(">=", "2.6.0"):
    from torch.distributed.tensor.experimental._attention import (
        _AttentionOp,
        _cp_options,
        _templated_ring_attention,
        _templated_ring_attention_backward,
        set_rotate_method,
    )
else:
    # Define every name imported in the branch above so that later references
    # resolve to a predictable placeholder instead of raising NameError.
    _cp_options = None
    _templated_ring_attention = None
    _templated_ring_attention_backward = None
    set_rotate_method = None

    class _AttentionOp:
        """Placeholder that raises on instantiation when the torch module is unavailable."""

        def __init__(self, *args, **kwargs):
            raise OptionalDependencyNotAvailable(
                "The `torch.distributed.tensor.experimental._attention` module is not available. Please update PyTorch to at least 2.6.0."
            )
|
|
|
|
|
if is_xformers_available(): |
|
if is_xformers_version("<", "0.0.29"): |
|
raise OptionalDependencyNotAvailable( |
|
"The `xformers` library version is too old. Please update it to at least 0.0.29." |
|
) |
|
|
|
import xformers.ops as xops |
|
else: |
|
xops = None |
|
|
|
|
|
# Module-level logger obtained from the project's logging helper.
logger = get_logger()

# Literal type aliases. Judging by their names they enumerate SageAttention
# options (accumulation dtype, QK quantization granularity, quantization
# backend) — TODO confirm against the Sage providers that consume them.
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
|
|
|
|
|
|
|
|
|
|
|
def _finetrainers_scaled_dot_product_efficient_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    compute_log_sumexp: bool = False,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Call the aten efficient-attention forward op and trim the returned log-sum-exp.

    All arguments are forwarded unchanged to
    `torch.ops.aten._scaled_dot_product_efficient_attention`.

    Returns:
        `(out, lse, philox_seed, philox_offset)` as produced by the aten op. When
        `compute_log_sumexp` is True, `lse` is sliced along its last dimension to
        `query.shape[-2]` entries.
    """
    num_queries = query.shape[-2]

    attn_out, log_sumexp, seed, offset = torch.ops.aten._scaled_dot_product_efficient_attention(
        query=query,
        key=key,
        value=value,
        attn_bias=attn_bias,
        compute_log_sumexp=compute_log_sumexp,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
    )

    if not compute_log_sumexp:
        return attn_out, log_sumexp, seed, offset

    # Keep only the first `num_queries` entries of the last LSE dimension.
    assert log_sumexp.ndim == 3
    return attn_out, log_sumexp[:, :, :num_queries], seed, offset
|
|
|
|
|
|
|
def _finetrainers_scaled_dot_product_efficient_attention_backward(
    grad_out_: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    philox_seed: torch.Tensor,
    philox_offset: torch.Tensor,
    dropout_p: float,
    grad_input_mask: List[bool],
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pad the log-sum-exp and call the aten efficient-attention backward op.

    `logsumexp` is right-padded with `inf` along its last dimension before being
    handed to `torch.ops.aten._scaled_dot_product_efficient_attention_backward`;
    every other argument is forwarded unchanged.

    Returns:
        `(grad_query, grad_key, grad_value, grad_attn_bias)` from the aten op.
    """
    assert len(grad_input_mask) == 4

    lse_alignment = 32
    # NOTE: when the last dimension is already a multiple of the alignment this
    # expression evaluates to a full extra block of `lse_alignment` entries.
    trailing_pad = lse_alignment - (logsumexp.shape[-1] % lse_alignment)
    padded_lse = torch.nn.functional.pad(logsumexp, (0, trailing_pad), value=float("inf"))

    gradients = torch.ops.aten._scaled_dot_product_efficient_attention_backward(
        grad_out_=grad_out_,
        query=query,
        key=key,
        value=value,
        attn_bias=attn_bias,
        out=out,
        logsumexp=padded_lse,
        philox_seed=philox_seed,
        philox_offset=philox_offset,
        dropout_p=dropout_p,
        grad_input_mask=grad_input_mask,
        is_causal=is_causal,
        scale=scale,
    )
    grad_query, grad_key, grad_value, grad_attn_bias = gradients

    return grad_query, grad_key, grad_value, grad_attn_bias
|
|
|
|
|
|
|
def _finetrainers_flash_attn_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    return_softmax: bool = False,
):
    """Run flash-attn's `_flash_attn_forward` on inputs with dims 1 and 2 swapped.

    `query`, `key` and `value` are permuted with `(0, 2, 1, 3)` and made
    contiguous before the call; the attention output is permuted back the same
    way afterwards.

    Returns:
        `(out, softmax_lse, q, k, v, out_padded, S_dmask, rng_state)`, where every
        element except `out` is returned exactly as `_flash_attn_forward` produced it.
    """
    query = query.permute(0, 2, 1, 3).contiguous()
    key = key.permute(0, 2, 1, 3).contiguous()
    value = value.permute(0, 2, 1, 3).contiguous()

    forward_outputs = _flash_attn_forward(
        query, key, value, dropout_p, scale, is_causal, window_size, softcap, alibi_slopes, return_softmax
    )
    attn_out, q_used, k_used, v_used, padded_out, lse, dropout_mask, rng = forward_outputs

    # Undo the axis swap on the attention output only.
    attn_out = attn_out.permute(0, 2, 1, 3).contiguous()
    return attn_out, lse, q_used, k_used, v_used, padded_out, dropout_mask, rng
|
|
|
|
|
|
|
def _finetrainers_flash_attn_backward(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    dropout_p: float,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    rng_state: Optional[torch.Tensor] = None,
    _permute_outputs: bool = True,
):
    """Run flash-attn's `_flash_attn_backward` and return the q/k/v gradients.

    `grad_out` is permuted with `(0, 2, 1, 3)` before the call, while `query`,
    `key`, `value` and `out` are passed through as given. The gradients are
    sliced to `grad_out`'s last-dimension size and, when `_permute_outputs` is
    True, permuted with `(0, 2, 1, 3)` before being returned.

    Returns:
        `(dq, dk, dv)`.
    """
    # Gradient buffers handed to the flash-attn backward call.
    grad_q_buf = torch.empty_like(query)
    grad_k_buf = torch.empty_like(key)
    grad_v_buf = torch.empty_like(value)
    grad_out = grad_out.permute(0, 2, 1, 3).contiguous()

    dq, dk, dv, _softmax_d = _flash_attn_backward(
        grad_out,
        query,
        key,
        value,
        out,
        logsumexp,
        grad_q_buf,
        grad_k_buf,
        grad_v_buf,
        dropout_p,
        scale,
        is_causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        rng_state,
    )

    # Trim the last dimension of each gradient to match `grad_out`.
    last_dim = grad_out.shape[-1]
    dq = dq[..., :last_dim]
    dk = dk[..., :last_dim]
    dv = dv[..., :last_dim]

    if not _permute_outputs:
        return dq, dk, dv

    dq = dq.permute(0, 2, 1, 3).contiguous()
    dk = dk.permute(0, 2, 1, 3).contiguous()
    dv = dv.permute(0, 2, 1, 3).contiguous()
    return dq, dk, dv
|
|
|
|
|
|
|
|
|
|
|
class AttentionProvider(str, Enum):
    # String-valued identifiers for the attention implementations that can be
    # registered with `_AttentionProviderRegistry`.

    # Registered in this file by the flash-attn based providers.
    FLASH = "flash"
    FLASH_VARLEN = "flash_varlen"

    # FLEX and NATIVE are registered in this file by the flex-attention and
    # `scaled_dot_product_attention` providers. The underscore-prefixed members
    # presumably select specific native backends — TODO confirm against their
    # registrations.
    FLEX = "flex"
    NATIVE = "native"
    _NATIVE_CUDNN = "_native_cudnn"
    _NATIVE_EFFICIENT = "_native_efficient"
    _NATIVE_FLASH = "_native_flash"
    _NATIVE_MATH = "_native_math"

    # Presumably backed by the optional `sageattention` imports — TODO confirm.
    SAGE = "sage"
    SAGE_VARLEN = "sage_varlen"
    _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
    _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
    _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda"
    _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton"

    # Presumably backed by the optional `xformers` import — TODO confirm.
    XFORMERS = "xformers"
|
|
|
|
|
class _AttentionProviderRegistry:
    """Class-level registry of attention providers and the active dispatch state.

    All state lives on the class itself, so it is shared by every user of the
    registry.
    """

    # Per-provider tables, filled in by the `register` decorator.
    _providers = {}
    _constraints = {}
    _supports_cp = {}
    _supported_arg_names = {}

    # Initial provider and check toggle come from the project constants.
    _active_provider = AttentionProvider(FINETRAINERS_ATTN_PROVIDER)
    _checks_enabled = FINETRAINERS_ATTN_CHECKS

    # Context-parallel settings; a `_mesh` of None means context parallelism is off.
    _mesh: torch.distributed.device_mesh.DeviceMesh = None
    _convert_to_fp32: bool = None
    _rotate_method: Literal["allgather", "alltoall"] = None

    @classmethod
    def register(
        cls, provider: AttentionProvider, constraints: Optional[List[Callable]] = None, supports_cp: bool = False
    ):
        """Return a decorator that registers the decorated function under `provider`.

        The decorator stores the function, its constraint callables (an empty list
        when none are given), its context-parallel flag and the set of parameter
        names taken from its signature, then returns the function unchanged.
        """
        logger.debug(f"Registering attention provider: {provider}")

        def decorator(func):
            cls._providers[provider] = func
            cls._constraints[provider] = constraints if constraints else []
            cls._supports_cp[provider] = supports_cp
            accepted_names = inspect.signature(func).parameters.keys()
            cls._supported_arg_names[provider] = set(accepted_names)
            return func

        return decorator

    @classmethod
    def get_active_provider(cls):
        """Return `(active provider, registered function for it)`."""
        active = cls._active_provider
        return active, cls._providers[active]

    @classmethod
    def list_providers(cls):
        """Return the registered providers as a list."""
        return list(cls._providers.keys())

    @classmethod
    def supports_context_parallel(cls, provider: AttentionProvider):
        """Return the `supports_cp` flag of `provider`; raise ValueError if it is not registered."""
        if provider in cls._providers:
            return cls._supports_cp.get(provider, False)
        raise ValueError(f"Provider {provider} is not registered.")

    @classmethod
    def context_parallel_enabled(cls):
        """Return True when a mesh is currently set."""
        return cls._mesh is not None

    @classmethod
    def _set_context_parallel(
        cls,
        mesh: torch.distributed.device_mesh.DeviceMesh = None,
        convert_to_fp32: bool = None,
        rotate_method: str = None,
        *,
        reset: bool = False,
    ):
        """Store the context-parallel settings, or clear all three when `reset` is True."""
        if reset:
            cls._mesh = None
            cls._convert_to_fp32 = None
            cls._rotate_method = None
            return
        cls._mesh = mesh
        cls._convert_to_fp32 = convert_to_fp32
        cls._rotate_method = rotate_method

    @classmethod
    def _raise_cp_error_if_mesh_not_set(cls):
        """Raise ValueError unless a mesh is currently set."""
        if cls._mesh is not None:
            return
        raise ValueError(
            "`_AttentionProviderRegistry._mesh` is None. It must be set before calling context parallel attention methods."
        )
|
|
|
|
|
@contextlib.contextmanager
def attention_provider(
    provider: AttentionProvider = AttentionProvider.NATIVE,
    *,
    mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
    convert_to_fp32: bool = True,
    rotate_method: str = "allgather",
):
    """Context manager to set the active attention provider and possibly enable context parallelism.

    Args:
        provider: Provider to activate for the duration of the block. It must be registered.
        mesh: Device mesh for context parallelism. When given, `provider` must have been
            registered with `supports_cp=True`.
        convert_to_fp32: Stored on the registry as the `convert_to_f32` context-parallel option.
        rotate_method: Stored on the registry as the context-parallel rotate method.

    Raises:
        ValueError: If `provider` is not registered, or `mesh` is given for a provider
            that does not support context parallelism.

    On exit the registry state and, when `mesh` was given, the `_cp_options` values are
    restored to what they were on entry, so nested uses do not clobber the settings of
    an enclosing block.
    """
    registry = _AttentionProviderRegistry

    if provider not in registry._providers:
        raise ValueError(f"Provider {provider} is not registered.")
    if mesh is not None and not registry.supports_context_parallel(provider):
        raise ValueError(f"Provider {provider} does not support context parallelism.")

    # Snapshot everything before mutating anything, so that a failure while reading
    # `_cp_options` cannot leave the registry half-updated.
    previous_state = (
        registry._active_provider,
        registry._mesh,
        registry._convert_to_fp32,
        registry._rotate_method,
    )
    if mesh is not None:
        previous_cp_options = (
            _cp_options.convert_to_f32,
            _cp_options.enable_load_balance,
            _cp_options.rotate_method,
        )

    registry._active_provider = provider
    registry._mesh = mesh
    registry._convert_to_fp32 = convert_to_fp32
    registry._rotate_method = rotate_method

    try:
        yield
    finally:
        # Restore the values seen on entry rather than resetting to None.
        (
            registry._active_provider,
            registry._mesh,
            registry._convert_to_fp32,
            registry._rotate_method,
        ) = previous_state
        if mesh is not None:
            (
                _cp_options.convert_to_f32,
                _cp_options.enable_load_balance,
                _cp_options.rotate_method,
            ) = previous_cp_options
|
|
|
|
|
def attention_dispatch(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
    """Call the currently active attention provider.

    The named arguments are merged with `attention_kwargs` (which wins on key
    clashes). When checks are enabled, a rate-limited warning is logged for
    arguments the provider does not accept and each registered constraint is
    called with the full argument set. Arguments missing from the provider's
    signature are then dropped before the provider is invoked.

    Returns:
        Whatever the active provider function returns.
    """
    registry = _AttentionProviderRegistry
    extra_kwargs = attention_kwargs or {}
    provider_name, provider_fn = registry.get_active_provider()

    call_args = {
        "query": query,
        "key": key,
        "value": value,
        "attn_mask": attn_mask,
        "dropout_p": dropout_p,
        "is_causal": is_causal,
        "scale": scale,
        "enable_gqa": enable_gqa,
    }
    call_args.update(extra_kwargs)

    accepted_names = registry._supported_arg_names[provider_name]

    if registry._checks_enabled:
        removed_kwargs = set(call_args) - set(accepted_names)
        if removed_kwargs:
            log_freq = 512
            msg = (
                f"Removing unsupported arguments for attention provider {provider_name}: {removed_kwargs}. This "
                f"message will be logged every {log_freq} calls."
            )
            logger.log_freq("WARNING", "REMOVING_ATTN_UNSUPPORTED_KWARGS", msg, log_freq)
        # Constraints see the unfiltered arguments.
        for constraint in registry._constraints.get(provider_name):
            constraint(**call_args)

    filtered_args = {}
    for name, arg in call_args.items():
        if name in accepted_names:
            filtered_args[name] = arg

    if registry.context_parallel_enabled():
        _set_context_parallel_options(**filtered_args)

    return provider_fn(**filtered_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
def _set_context_parallel_options(is_causal: bool, **kwargs):
    """Copy the context-parallel settings onto torch's `_cp_options`.

    Load balancing is set to `is_causal`; the fp32 conversion flag and the rotate
    method are taken from `_AttentionProviderRegistry`. Extra keyword arguments
    are accepted and ignored.
    """
    registry = _AttentionProviderRegistry
    _cp_options.enable_load_balance = is_causal
    _cp_options.convert_to_f32 = registry._convert_to_fp32
    set_rotate_method(registry._rotate_method)
|
|
|
|
|
def _check_attn_mask_is_none(attn_mask: Optional[torch.Tensor], **kwargs) -> None: |
|
if attn_mask is not None: |
|
raise ValueError("Attention mask must be None for this provider.") |
|
|
|
|
|
def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None: |
|
if attn_mask is not None and is_causal: |
|
raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.") |
|
|
|
|
|
def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: |
|
if query.device != key.device or query.device != value.device: |
|
raise ValueError("Query, key, and value must be on the same device.") |
|
if query.dtype != key.dtype or query.dtype != value.dtype: |
|
raise ValueError("Query, key, and value must have the same dtype.") |
|
|
|
|
|
def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    """Run `_check_device`, then raise ValueError unless the query lives on a CUDA device."""
    _check_device(query, key, value)
    on_cuda = query.device.type == "cuda"
    if not on_cuda:
        raise ValueError("Query, key, and value must be on a CUDA device.")
|
|
|
|
|
def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable: |
|
def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: |
|
_check_device_cuda(query, key, value) |
|
if torch.cuda.get_device_capability(query.device) < (major, minor): |
|
raise ValueError( |
|
f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}." |
|
) |
|
|
|
return check_device_cuda |
|
|
|
|
|
def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: |
|
if query.dtype != key.dtype: |
|
raise ValueError("Query and key must have the same dtype.") |
|
if query.dtype != value.dtype: |
|
raise ValueError("Query and value must have the same dtype.") |
|
|
|
|
|
def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    """Run `_check_qkv_dtype_match`, then raise ValueError unless the dtype is bfloat16 or float16."""
    _check_qkv_dtype_match(query, key, value)
    allowed_dtypes = (torch.bfloat16, torch.float16)
    if query.dtype in allowed_dtypes:
        return
    raise ValueError("Query, key, and value must be either bfloat16 or float16.")
|
|
|
|
|
def _check_shape( |
|
query: torch.Tensor, |
|
key: torch.Tensor, |
|
value: torch.Tensor, |
|
attn_mask: Optional[torch.Tensor] = None, |
|
**kwargs, |
|
) -> None: |
|
if query.shape[-1] != key.shape[-1]: |
|
raise ValueError("Query and key must have the same last dimension.") |
|
if query.shape[-2] != value.shape[-2]: |
|
raise ValueError("Query and value must have the same second to last dimension.") |
|
if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]: |
|
raise ValueError("Attention mask must match the key's second to last dimension.") |
|
|
|
|
|
def _prepare_for_flash_attn_or_sage_varlen( |
|
batch_size: int, |
|
seq_len_q: int, |
|
seq_len_kv: int, |
|
attn_mask: Optional[torch.Tensor] = None, |
|
device: Optional[torch.device] = None, |
|
) -> None: |
|
seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) |
|
if attn_mask is None: |
|
seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device) |
|
else: |
|
seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32) |
|
cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) |
|
cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) |
|
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) |
|
cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) |
|
max_seqlen_q = seqlens_q.max().item() |
|
max_seqlen_k = seqlens_k.max().item() |
|
return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) |
|
|
|
|
|
def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor: |
|
""" |
|
Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_k in |
|
FlashAttention/Sage varlen. |
|
|
|
Supports 1D to 4D shapes and common broadcasting patterns. |
|
""" |
|
if attn_mask.dtype != torch.bool: |
|
raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.") |
|
|
|
if attn_mask.ndim == 1: |
|
|
|
attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k) |
|
|
|
elif attn_mask.ndim == 2: |
|
|
|
if attn_mask.size(0) not in [1, batch_size]: |
|
raise ValueError( |
|
f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask." |
|
) |
|
attn_mask = attn_mask.expand(batch_size, seq_len_k) |
|
|
|
elif attn_mask.ndim == 3: |
|
|
|
if attn_mask.size(0) not in [1, batch_size]: |
|
raise ValueError( |
|
f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask." |
|
) |
|
attn_mask = attn_mask.any(dim=1) |
|
attn_mask = attn_mask.expand(batch_size, seq_len_k) |
|
|
|
elif attn_mask.ndim == 4: |
|
|
|
if attn_mask.size(0) not in [1, batch_size]: |
|
raise ValueError( |
|
f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask." |
|
) |
|
attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) |
|
attn_mask = attn_mask.any(dim=(1, 2)) |
|
|
|
else: |
|
raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}") |
|
|
|
if attn_mask.shape != (batch_size, seq_len_k): |
|
raise ValueError( |
|
f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})" |
|
) |
|
|
|
return attn_mask |
|
|
|
|
|
def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): |
|
return q_idx >= kv_idx |
|
|
|
|
|
|
|
|
|
|
|
|
|
class _flash_attn_flash_attention(torch.autograd.Function):
    """Autograd function built on the `_finetrainers_flash_attn_forward` / `_backward` wrappers."""

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        dropout_p: float = 0.0,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        window_size: Tuple[int, int] = (-1, -1),
        softcap: float = 0.0,
        alibi_slopes: Optional[torch.Tensor] = None,
        deterministic: bool = False,
        return_softmax: bool = False,
    ):
        """Run the forward wrapper and stash what backward needs.

        Returns `(out, lse)` when `return_softmax` is True, otherwise just `out`.
        """
        # Default scale is 1 / sqrt(last dimension of q).
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        # Non-tensor settings are kept as plain attributes on ctx.
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic

        forward_outputs = _finetrainers_flash_attn_forward(
            query=q,
            key=k,
            value=v,
            dropout_p=dropout_p,
            scale=softmax_scale,
            is_causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax,
        )
        out, lse, q, k, v, out_padded, _dropout_mask, rng_state = forward_outputs

        # Save the q/k/v returned by the wrapper, not the ones passed in.
        ctx.save_for_backward(q, k, v, out_padded, lse, rng_state)

        if return_softmax:
            return out, lse
        return out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Compute q/k/v gradients; every non-tensor forward input gets a None gradient."""
        q, k, v, out, lse, rng_state = ctx.saved_tensors

        grad_query, grad_key, grad_value = _finetrainers_flash_attn_backward(
            grad_out=grad_out,
            query=q,
            key=k,
            value=v,
            out=out,
            logsumexp=lse,
            dropout_p=ctx.dropout_p,
            scale=ctx.softmax_scale,
            is_causal=ctx.causal,
            window_size=ctx.window_size,
            softcap=ctx.softcap,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
            rng_state=rng_state,
        )

        # One slot per forward input after ctx: three tensors plus eight settings.
        return (grad_query, grad_key, grad_value) + (None,) * 8
|
|
|
|
|
|
|
class _native_ring_flash_attn_flash_attention(torch.autograd.Function):
    """Context-parallel variant of flash attention routed through torch's templated ring attention.

    Both passes read the device mesh from `_AttentionProviderRegistry._mesh` and fail
    early with a descriptive error when it is not set.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        dropout_p: float = 0.0,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        window_size: Tuple[int, int] = (-1, -1),
        softcap: float = 0.0,
        alibi_slopes: Optional[torch.Tensor] = None,
        deterministic: bool = False,
        return_softmax: bool = False,
    ):
        """Run ring attention with `_finetrainers_flash_attn_forward` as the per-step op.

        Returns `(out, lse)` when `return_softmax` is True, otherwise just `out`.

        Raises:
            ValueError: If no mesh is set on the registry.
        """
        # Same guard the cuDNN ring function uses: give a clear error instead of
        # passing `mesh=None` down into the ring implementation.
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()

        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        # A non-positive dropout is replaced by a tiny positive value, and the op is
        # always called with `return_softmax=True` below.
        # NOTE(review): presumably needed so the underlying kernel returns the LSE — confirm.
        dropout_p = dropout_p if dropout_p > 0 else 1e-30

        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic

        out, lse, q, k, v, out_padded, S_dmask, rng_state = _templated_ring_attention(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=_finetrainers_flash_attn_forward,
            query=q,
            key=k,
            value=v,
            dropout_p=dropout_p,
            scale=softmax_scale,
            is_causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=True,
        )

        ctx.save_for_backward(q, k, v, out_padded, lse, rng_state)

        return (out, lse) if return_softmax else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Run the ring backward pass and return q/k/v gradients plus None for the other inputs.

        Raises:
            ValueError: If no mesh is set on the registry when backward runs.
        """
        # The mesh is read from the registry again here, so it must still be set.
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()

        q, k, v, out, lse, rng_state = ctx.saved_tensors
        lse = lse.permute(0, 2, 1).contiguous()

        grad_query, grad_key, grad_value = _templated_ring_attention_backward(
            mesh=_AttentionProviderRegistry._mesh,
            # The backward pass uses seq_dim=1 where forward used seq_dim=2.
            # NOTE(review): presumably because the saved q/k/v are in the layout produced
            # by the forward wrapper's (0, 2, 1, 3) permute — confirm.
            seq_dim=1,
            # The per-step op leaves its outputs unpermuted; they are permuted once below.
            op=functools.partial(_finetrainers_flash_attn_backward, _permute_outputs=False),
            grad_out=grad_out,
            grad_out_name="grad_out",
            query=q,
            key=k,
            value=v,
            out=out,
            logsumexp=lse,
            dropout_p=ctx.dropout_p,
            scale=ctx.softmax_scale,
            is_causal=ctx.causal,
            window_size=ctx.window_size,
            softcap=ctx.softcap,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
            rng_state=rng_state,
        )
        grad_query, grad_key, grad_value = (
            x.permute(0, 2, 1, 3).contiguous() for x in (grad_query, grad_key, grad_value)
        )

        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
|
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider.FLASH,
    constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_cp=True,
)
def flash_attn_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    """FLASH provider: apply the ring or the plain flash-attention autograd function.

    The ring variant is chosen when context parallelism is enabled on the
    registry, otherwise the plain variant. All arguments are passed positionally
    to the chosen function's `apply`; `return_lse` fills its `return_softmax` slot.
    """
    if _AttentionProviderRegistry.context_parallel_enabled():
        autograd_fn = _native_ring_flash_attn_flash_attention
    else:
        autograd_fn = _flash_attn_flash_attention

    return autograd_fn.apply(
        query, key, value, dropout_p, scale, is_causal, window_size, softcap, alibi_slopes, deterministic, return_lse
    )
|
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider.FLASH_VARLEN,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_cp=False,
)
def _flash_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    return_attn_probs: bool = False,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """FLASH_VARLEN provider: pack 4D query/key/value and call `flash_attn_varlen_func`.

    `query`, `key` and `value` are unpacked as 4D tensors whose third dimension is
    the sequence length. If any of `cu_seqlens_q`, `cu_seqlens_k`, `max_seqlen_q`,
    `max_seqlen_k` is missing, all four are derived from the tensor shapes and the
    (normalized) `attn_mask`. Otherwise the caller's values are used and the
    per-sample key lengths are taken from consecutive differences of `cu_seqlens_k`,
    so the packed keys/values line up with those offsets.

    Only the leading `seqlens_k[b]` key/value positions of each sample are packed.

    Returns:
        The attention output in the input layout, or `(out, extra)` when
        `return_attn_probs` is True (forced to True under context parallelism),
        where `extra` is the first additional value returned by flash-attn.
    """
    batch_size, _, seq_len_q, _ = query.shape
    _, _, seq_len_kv, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
            _prepare_for_flash_attn_or_sage_varlen(
                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
            )
        )
    else:
        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
        # Recover per-sample key lengths from the caller's offsets. Assuming
        # `max_seqlen_k` for every sample would pack more rows than the offsets
        # describe whenever the lengths differ.
        seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]

    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))

    # Read the lengths back once instead of once per sample.
    valid_lens = seqlens_k.tolist()

    query_packed = query.flatten(0, 1)
    key_packed = torch.cat([key[b, : valid_lens[b]] for b in range(batch_size)], dim=0)
    value_packed = torch.cat([value[b, : valid_lens[b]] for b in range(batch_size)], dim=0)

    if _AttentionProviderRegistry.context_parallel_enabled():
        return_attn_probs = True

    out = flash_attn_varlen_func(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=return_attn_probs,
    )

    rest = None
    if return_attn_probs:
        out, *rest = out
    # Unpack back to a batched tensor and restore the original axis order.
    out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3)
    if return_attn_probs:
        return out, *rest[:1]
    return out
|
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider.FLEX,
    constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
    supports_cp=False,
)
def _native_flex_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
    """FLEX provider: run `flex_attention` with a mask derived from `attn_mask` / `is_causal`.

    Mask handling, in priority order:
      * a `BlockMask` is used as-is;
      * `is_causal=True` builds a causal block mask (also when `attn_mask` is None);
      * no mask and not causal runs unmasked attention;
      * a boolean tensor becomes a block mask, any other tensor is added to the
        attention scores through a `score_mod`.

    `dropout_p` is accepted for signature compatibility but is not used by this function.
    `kernel_options` is forwarded to `flex_attention`.

    Raises:
        ValueError: If `attn_mask` is neither None, a BlockMask nor a tensor (and
            `is_causal` is False).
    """
    score_mod = None
    block_mask = None
    batch_size, num_heads, seq_len_q, _ = query.shape
    _, _, seq_len_kv, _ = key.shape

    if isinstance(attn_mask, flex_attention.BlockMask):
        block_mask = attn_mask
    elif is_causal:
        # Checked before the `attn_mask is None` case so that a causal request
        # without an explicit mask is honoured rather than silently dropped.
        block_mask = flex_attention.create_block_mask(
            _flex_attention_causal_mask_mod, None, None, seq_len_q, seq_len_kv, query.device
        )
    elif attn_mask is None:
        pass
    elif torch.is_tensor(attn_mask):
        if attn_mask.ndim == 2:
            attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)

        attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv)

        if attn_mask.dtype == torch.bool:

            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
                return attn_mask[batch_idx, head_idx, q_idx, kv_idx]

            block_mask = flex_attention.create_block_mask(
                mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device
            )
        else:

            def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
                return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
    else:
        raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")

    return flex_attention.flex_attention(
        query=query,
        key=key,
        value=value,
        score_mod=score_mod,
        block_mask=block_mask,
        scale=scale,
        enable_gqa=enable_gqa,
        return_lse=return_lse,
        kernel_options=kernel_options,
    )
|
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider.NATIVE,
    constraints=[_check_device, _check_shape],
    supports_cp=False,
)
def _native_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """NATIVE provider: forward every argument to torch's `scaled_dot_product_attention`."""
    sdpa_arguments = {
        "query": query,
        "key": key,
        "value": value,
        "attn_mask": attn_mask,
        "dropout_p": dropout_p,
        "is_causal": is_causal,
        "scale": scale,
        "enable_gqa": enable_gqa,
    }
    return native_sdpa(**sdpa_arguments)
|
|
|
|
|
class _native_cudnn_attention(torch.autograd.Function):
    """Autograd function around the aten cuDNN scaled-dot-product-attention forward/backward ops."""

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Call the cuDNN forward op, always requesting the log-sum-exp.

        Returns `(out, lse)` when `return_lse` is True, otherwise just `out`.
        """
        # Settings (and the mask) are kept as plain ctx attributes for backward.
        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.attn_mask = attn_mask

        forward_outputs = torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=attn_mask,
            compute_log_sumexp=True,
            dropout_p=dropout_p,
            is_causal=is_causal,
            return_debug_mask=False,
            scale=scale,
        )
        (
            out,
            lse,
            cum_seq_q,
            cum_seq_k,
            max_q,
            max_k,
            philox_seed,
            philox_offset,
            _debug_attn_mask,
        ) = forward_outputs

        ctx.max_q = max_q
        ctx.max_k = max_k
        ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)

        if return_lse:
            return out, lse
        return out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Call the cuDNN backward op; the five non-q/k/v forward inputs get None gradients."""
        saved = ctx.saved_tensors
        query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = saved

        grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward(
            grad_out=grad_out,
            query=query,
            key=key,
            value=value,
            out=out,
            logsumexp=lse,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
            attn_bias=ctx.attn_mask,
            cum_seq_q=cum_seq_q,
            cum_seq_k=cum_seq_k,
            max_q=ctx.max_q,
            max_k=ctx.max_k,
            dropout_p=ctx.dropout_p,
            is_causal=ctx.is_causal,
            scale=ctx.scale,
        )

        return (grad_query, grad_key, grad_value) + (None,) * 5
|
|
|
|
|
class _native_ring_native_cudnn_attention(torch.autograd.Function):
    """Context-parallel variant of the cuDNN attention autograd function.

    Both passes hand the ATen cuDNN ops to ``_templated_ring_attention`` /
    ``_templated_ring_attention_backward`` together with the mesh stored on
    ``_AttentionProviderRegistry`` and ``seq_dim=2``.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Run ring attention; returns ``out`` or ``(out, lse)`` if ``return_lse``."""
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()

        ring_outputs = _templated_ring_attention(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=torch.ops.aten._scaled_dot_product_cudnn_attention,
            query=query,
            key=key,
            value=value,
            attn_bias=attn_mask,
            compute_log_sumexp=True,
            dropout_p=dropout_p,
            is_causal=is_causal,
            return_debug_mask=False,
            scale=scale,
        )
        # Nine values are expected; the last (debug mask) is discarded.
        (
            out,
            lse,
            cum_seq_q,
            cum_seq_k,
            max_q,
            max_k,
            philox_seed,
            philox_offset,
            _debug_attn_mask,
        ) = ring_outputs

        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.attn_mask = attn_mask
        ctx.max_q = max_q
        ctx.max_k = max_k
        ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)

        if return_lse:
            return out, lse
        return out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Return gradients for query/key/value; remaining inputs get ``None``."""
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
        (
            query,
            key,
            value,
            out,
            lse,
            cum_seq_q,
            cum_seq_k,
            philox_seed,
            philox_offset,
        ) = ctx.saved_tensors

        op_kwargs = {
            "query": query,
            "key": key,
            "value": value,
            "out": out,
            "logsumexp": lse,
            "philox_seed": philox_seed,
            "philox_offset": philox_offset,
            "attn_bias": ctx.attn_mask,
            "cum_seq_q": cum_seq_q,
            "cum_seq_k": cum_seq_k,
            "max_q": ctx.max_q,
            "max_k": ctx.max_k,
            "dropout_p": ctx.dropout_p,
            "is_causal": ctx.is_causal,
            "scale": ctx.scale,
        }
        grad_query, grad_key, grad_value = _templated_ring_attention_backward(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=torch.ops.aten._scaled_dot_product_cudnn_attention_backward,
            grad_out=grad_out,
            # Keyword under which the backward op receives the incoming gradient.
            grad_out_name="grad_out",
            **op_kwargs,
        )

        return grad_query, grad_key, grad_value, None, None, None, None, None
|
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider._NATIVE_CUDNN,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_cp=True,
)
def native_cudnn_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    """Dispatch to the cuDNN attention autograd function.

    The ring (context-parallel) implementation is used when
    ``_AttentionProviderRegistry.context_parallel_enabled()`` is true, the
    single-device one otherwise. With ``return_lse`` the underlying function
    returns ``(out, lse)`` instead of just ``out``.
    """
    if _AttentionProviderRegistry.context_parallel_enabled():
        attention_fn = _native_ring_native_cudnn_attention
    else:
        attention_fn = _native_cudnn_attention
    return attention_fn.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, return_lse)
|
|
|
|
|
class _native_efficient_attention(torch.autograd.Function):
    """Autograd function around the finetrainers efficient-attention ops.

    Forward and backward delegate to
    ``_finetrainers_scaled_dot_product_efficient_attention_forward`` and
    ``_finetrainers_scaled_dot_product_efficient_attention_backward``.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Compute attention; returns ``out`` or ``(out, lse)`` if ``return_lse``."""
        forward_outputs = _finetrainers_scaled_dot_product_efficient_attention_forward(
            query=query,
            key=key,
            value=value,
            attn_bias=attn_mask,
            compute_log_sumexp=True,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )
        out, lse, philox_seed, philox_offset = forward_outputs

        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.attn_mask = attn_mask
        ctx.save_for_backward(query, key, value, out, lse, philox_seed, philox_offset)

        if return_lse:
            return out, lse
        return out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Return gradients for query/key/value; remaining inputs get ``None``."""
        query, key, value, out, lse, philox_seed, philox_offset = ctx.saved_tensors

        backward_outputs = _finetrainers_scaled_dot_product_efficient_attention_backward(
            grad_out_=grad_out,
            query=query,
            key=key,
            value=value,
            attn_bias=ctx.attn_mask,
            out=out,
            logsumexp=lse,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
            dropout_p=ctx.dropout_p,
            # Gradients are requested for q, k and v but not for the attention bias.
            grad_input_mask=[True, True, True, False],
            is_causal=ctx.is_causal,
            scale=ctx.scale,
        )
        # Four values come back; the bias gradient slot is not propagated.
        grad_query, grad_key, grad_value, _grad_attn_bias = backward_outputs

        return grad_query, grad_key, grad_value, None, None, None, None, None
|
|
|
|
|
class _native_ring_native_efficient_attention(torch.autograd.Function):
    """Context-parallel variant of the efficient-attention autograd function.

    The finetrainers efficient-attention forward/backward ops are passed to
    ``_templated_ring_attention`` / ``_templated_ring_attention_backward`` with
    the mesh stored on ``_AttentionProviderRegistry`` and ``seq_dim=2``.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Run ring attention; returns ``out`` or ``(out, lse)`` if ``return_lse``."""
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()

        ring_outputs = _templated_ring_attention(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=_finetrainers_scaled_dot_product_efficient_attention_forward,
            query=query,
            key=key,
            value=value,
            attn_bias=attn_mask,
            compute_log_sumexp=True,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )
        out, lse, philox_seed, philox_offset = ring_outputs

        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.attn_mask = attn_mask
        ctx.save_for_backward(query, key, value, out, lse, philox_seed, philox_offset)

        if return_lse:
            return out, lse
        return out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Return gradients for query/key/value; remaining inputs get ``None``."""
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
        query, key, value, out, lse, philox_seed, philox_offset = ctx.saved_tensors

        ring_grads = _templated_ring_attention_backward(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=_finetrainers_scaled_dot_product_efficient_attention_backward,
            grad_out=grad_out,
            # The efficient backward op names its incoming gradient ``grad_out_``.
            grad_out_name="grad_out_",
            query=query,
            key=key,
            value=value,
            attn_bias=ctx.attn_mask,
            out=out,
            logsumexp=lse,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
            dropout_p=ctx.dropout_p,
            grad_input_mask=[True, True, True, False],
            is_causal=ctx.is_causal,
            scale=ctx.scale,
        )
        # Four values come back; the bias gradient slot is not propagated.
        grad_query, grad_key, grad_value, _grad_attn_bias = ring_grads

        return grad_query, grad_key, grad_value, None, None, None, None, None
|
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider._NATIVE_EFFICIENT,
    constraints=[_check_device, _check_shape],
    supports_cp=True,
)
def native_efficient_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    """Dispatch to the efficient-attention autograd function.

    The ring (context-parallel) implementation is used when
    ``_AttentionProviderRegistry.context_parallel_enabled()`` is true, the
    single-device one otherwise.

    Args:
        query, key, value: Attention inputs, forwarded unchanged.
        attn_mask: Optional mask, forwarded as the attention bias.
        dropout_p: Dropout probability forwarded to the op.
        is_causal: Whether to apply causal masking.
        scale: Optional softmax scale.
        return_lse: When true the underlying function returns ``(out, lse)``
            instead of ``out``. Defaults to ``False``, which keeps the previous
            behaviour; the flag mirrors the cuDNN and flash wrappers.
    """
    dispatch_fn = (
        _native_ring_native_efficient_attention
        if _AttentionProviderRegistry.context_parallel_enabled()
        else _native_efficient_attention
    )
    # Both autograd functions already take ``return_lse`` as their last forward input.
    return dispatch_fn.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, return_lse)
|
|
|
|
|
class _native_flash_attention(torch.autograd.Function):
    """Autograd function built on ATen's flash scaled-dot-product attention ops.

    ``forward`` calls ``aten::_scaled_dot_product_flash_attention`` and
    ``backward`` calls ``aten::_scaled_dot_product_flash_attention_backward``.
    No attention mask is accepted by this function.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Compute attention; returns ``out`` or ``(out, lse)`` if ``return_lse``."""
        op_outputs = torch.ops.aten._scaled_dot_product_flash_attention(
            query=query,
            key=key,
            value=value,
            dropout_p=dropout_p,
            is_causal=is_causal,
            return_debug_mask=False,
            scale=scale,
        )
        # The op yields nine values; the trailing debug mask is not used.
        (
            out,
            lse,
            cum_seq_q,
            cum_seq_k,
            max_q,
            max_k,
            philox_seed,
            philox_offset,
            _debug_attn_mask,
        ) = op_outputs

        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.max_q = max_q
        ctx.max_k = max_k
        ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)

        if return_lse:
            return out, lse
        return out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Return gradients for query/key/value; remaining inputs get ``None``."""
        (
            query,
            key,
            value,
            out,
            lse,
            cum_seq_q,
            cum_seq_k,
            philox_seed,
            philox_offset,
        ) = ctx.saved_tensors

        backward_kwargs = {
            "grad_out": grad_out,
            "query": query,
            "key": key,
            "value": value,
            "out": out,
            "logsumexp": lse,
            "cum_seq_q": cum_seq_q,
            "cum_seq_k": cum_seq_k,
            "max_q": ctx.max_q,
            "max_k": ctx.max_k,
            "dropout_p": ctx.dropout_p,
            "is_causal": ctx.is_causal,
            "philox_seed": philox_seed,
            "philox_offset": philox_offset,
            "scale": ctx.scale,
        }
        grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_flash_attention_backward(
            **backward_kwargs
        )

        # One slot per forward input after ctx: q, k, v, dropout_p, is_causal, scale, return_lse.
        return grad_query, grad_key, grad_value, None, None, None, None
|
|
|
|
|
class _native_ring_native_flash_attention(torch.autograd.Function):
    """Context-parallel variant of the flash attention autograd function.

    The ATen flash forward/backward ops are passed to
    ``_templated_ring_attention`` / ``_templated_ring_attention_backward`` with
    the mesh stored on ``_AttentionProviderRegistry`` and ``seq_dim=2``.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Run ring attention; returns ``out`` or ``(out, lse)`` if ``return_lse``."""
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()

        ring_outputs = _templated_ring_attention(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=torch.ops.aten._scaled_dot_product_flash_attention,
            query=query,
            key=key,
            value=value,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )
        # Nine values are expected; the last (debug mask) is discarded.
        (
            out,
            lse,
            cum_seq_q,
            cum_seq_k,
            max_q,
            max_k,
            philox_seed,
            philox_offset,
            _debug_attn_mask,
        ) = ring_outputs

        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.max_q = max_q
        ctx.max_k = max_k
        ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)

        if return_lse:
            return out, lse
        return out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Return gradients for query/key/value; remaining inputs get ``None``."""
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
        (
            query,
            key,
            value,
            out,
            lse,
            cum_seq_q,
            cum_seq_k,
            philox_seed,
            philox_offset,
        ) = ctx.saved_tensors

        ring_grads = _templated_ring_attention_backward(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=torch.ops.aten._scaled_dot_product_flash_attention_backward,
            grad_out=grad_out,
            grad_out_name="grad_out",
            query=query,
            key=key,
            value=value,
            out=out,
            logsumexp=lse,
            dropout_p=ctx.dropout_p,
            is_causal=ctx.is_causal,
            scale=ctx.scale,
            cum_seq_q=cum_seq_q,
            cum_seq_k=cum_seq_k,
            max_q=ctx.max_q,
            max_k=ctx.max_k,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
        )
        # Only the first three results are used; any further ones are dropped.
        grad_query, grad_key, grad_value, *_ = ring_grads

        return grad_query, grad_key, grad_value, None, None, None, None
|
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider._NATIVE_FLASH,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_cp=True,
)
def native_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    """Dispatch to the flash attention autograd function.

    The ring (context-parallel) implementation is used when
    ``_AttentionProviderRegistry.context_parallel_enabled()`` is true, the
    single-device one otherwise. With ``return_lse`` the underlying function
    returns ``(out, lse)`` instead of just ``out``.
    """
    if _AttentionProviderRegistry.context_parallel_enabled():
        attention_fn = _native_ring_native_flash_attention
    else:
        attention_fn = _native_flash_attention
    return attention_fn.apply(query, key, value, dropout_p, is_causal, scale, return_lse)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider._NATIVE_MATH,
    constraints=[_check_device, _check_shape],
    supports_cp=False,
)
def native_math_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Run ``scaled_dot_product_attention`` with the MATH backend selected.

    All arguments are forwarded unchanged to the native SDPA function while the
    ``sdpa_kernel`` context manager is active.
    """
    math_backend = torch.nn.attention.SDPBackend.MATH
    with torch.nn.attention.sdpa_kernel(math_backend):
        out = native_sdpa(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=enable_gqa,
        )
    return out
|
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider.SAGE,
    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_cp=False,
)
def _sage_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    """Call ``sageattn`` with the "HND" tensor layout.

    Returns the attention output, or a tuple of the output followed by at most
    one extra value from ``sageattn`` when the LSE is requested. The LSE is
    always requested when context parallelism is enabled.
    """
    if _AttentionProviderRegistry.context_parallel_enabled():
        return_lse = True

    result = sageattn(
        q=query,
        k=key,
        v=value,
        tensor_layout="HND",
        is_causal=is_causal,
        sm_scale=scale,
        return_lse=return_lse,
    )

    if not return_lse:
        return result

    # Keep the output plus only the first of any additional returned values.
    out, *extras = result
    return (out, *extras[:1])
|
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider.SAGE_VARLEN,
    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _sage_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
    smooth_k: bool = True,
    attn_mask: Optional[torch.Tensor] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Pack padded q/k/v into variable-length form and call ``sageattn_varlen``.

    Inputs are unpacked as 4D with the sequence on axis 2 (the shape reads as
    ``batch, _, seq, _``). If any of the four sequence-length arguments is
    missing, all of them are recomputed from ``attn_mask`` via
    ``_prepare_for_flash_attn_or_sage_varlen``. Otherwise the caller-provided
    cumulative lengths are used, and the per-sample key/value lengths are
    derived from ``cu_seqlens_k`` so the packed key/value tensors match the
    offsets passed to the kernel.

    Args:
        query, key, value: Padded attention inputs.
        cu_seqlens_q, cu_seqlens_k: Optional cumulative sequence lengths.
        max_seqlen_q, max_seqlen_k: Optional maximum sequence lengths.
        is_causal: Forwarded to ``sageattn_varlen``.
        scale: Forwarded as ``sm_scale``.
        smooth_k: Forwarded to ``sageattn_varlen``.
        attn_mask: Optional mask, normalized with ``_normalize_attn_mask``.
        enable_gqa: Accepted but currently has no effect in this function.

    Returns:
        The attention output, unflattened back to a per-batch layout and
        permuted with ``(0, 2, 1, 3)`` to undo the packing permutation.
    """
    batch_size, _, seq_len_q, _ = query.shape
    _, _, seq_len_kv, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    # NOTE: `enable_gqa` is intentionally a no-op here (no GQA-specific handling is implemented).

    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
            _prepare_for_flash_attn_or_sage_varlen(
                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
            )
        )
    else:
        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
        # Per-sample lengths are the consecutive differences of the cumulative
        # lengths. Assuming every sample has `max_seqlen_k` keys (the previous
        # behaviour) only agrees with `cu_seqlens_k` for uniform batches.
        seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]

    # Move the sequence axis next to the batch axis so rows can be packed.
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))

    # Convert each length once so both key and value slicing reuse the same ints.
    valid_lens = [int(seqlens_k[b]) for b in range(batch_size)]

    query_packed = query.flatten(0, 1)
    key_packed = torch.cat([key[b, :n] for b, n in enumerate(valid_lens)], dim=0)
    value_packed = torch.cat([value[b, :n] for b, n in enumerate(valid_lens)], dim=0)

    out = sageattn_varlen(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        is_causal=is_causal,
        sm_scale=scale,
        smooth_k=smooth_k,
    )
    out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3)

    return out
|
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA,
    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
    supports_cp=False,
)
def _sage_qk_int8_pv_fp8_cuda_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
    smooth_k: bool = True,
    smooth_v: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    """Thin wrapper over ``sageattn_qk_int8_pv_fp8_cuda`` using the "HND" layout.

    ``scale`` is forwarded as ``sm_scale``; every other option is passed through
    under the same name, and the kernel's return value is returned as-is.
    """
    sage_kwargs = {
        "q": query,
        "k": key,
        "v": value,
        "tensor_layout": "HND",
        "is_causal": is_causal,
        "qk_quant_gran": qk_quant_gran,
        "sm_scale": scale,
        "pv_accum_dtype": pv_accum_dtype,
        "smooth_k": smooth_k,
        "smooth_v": smooth_v,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp8_cuda(**sage_kwargs)
|
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
    supports_cp=False,
)
def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
    smooth_k: bool = True,
    return_lse: bool = False,
) -> torch.Tensor:
    """Thin wrapper over ``sageattn_qk_int8_pv_fp8_cuda_sm90`` using the "HND" layout.

    ``scale`` is forwarded as ``sm_scale``; every other option is passed through
    under the same name, and the kernel's return value is returned as-is.
    """
    sage_kwargs = {
        "q": query,
        "k": key,
        "v": value,
        "tensor_layout": "HND",
        "is_causal": is_causal,
        "qk_quant_gran": qk_quant_gran,
        "sm_scale": scale,
        "pv_accum_dtype": pv_accum_dtype,
        "smooth_k": smooth_k,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp8_cuda_sm90(**sage_kwargs)
|
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA,
    constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
    supports_cp=False,
)
def _sage_qk_int8_pv_fp16_cuda_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32",
    smooth_k: bool = True,
    smooth_v: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    """Thin wrapper over ``sageattn_qk_int8_pv_fp16_cuda`` using the "HND" layout.

    ``scale`` is forwarded as ``sm_scale``; every other option is passed through
    under the same name, and the kernel's return value is returned as-is.

    The default ``pv_accum_dtype`` is "fp32". The previous default
    ("fp32+fp32") is the accumulation mode of the FP8-PV kernels; the FP16-PV
    CUDA kernel is documented to take "fp16", "fp16+fp32" or "fp32", so calling
    this wrapper without an explicit ``pv_accum_dtype`` would be rejected.
    Callers that pass ``pv_accum_dtype`` explicitly are unaffected.
    """
    return sageattn_qk_int8_pv_fp16_cuda(
        q=query,
        k=key,
        v=value,
        tensor_layout="HND",
        is_causal=is_causal,
        qk_quant_gran=qk_quant_gran,
        sm_scale=scale,
        pv_accum_dtype=pv_accum_dtype,
        smooth_k=smooth_k,
        smooth_v=smooth_v,
        return_lse=return_lse,
    )
|
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider._SAGE_QK_INT8_PV_FP16_TRITON,
    constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
    supports_cp=False,
)
def _sage_qk_int8_pv_fp16_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
    smooth_k: bool = True,
    return_lse: bool = False,
) -> torch.Tensor:
    """Thin wrapper over ``sageattn_qk_int8_pv_fp16_triton`` using the "HND" layout.

    ``scale`` is forwarded as ``sm_scale``; every other option is passed through
    under the same name, and the kernel's return value is returned as-is.
    """
    sage_kwargs = {
        "q": query,
        "k": key,
        "v": value,
        "tensor_layout": "HND",
        "quantization_backend": quantization_backend,
        "is_causal": is_causal,
        "sm_scale": scale,
        "smooth_k": smooth_k,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp16_triton(**sage_kwargs)
|
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider.XFORMERS,
    constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
)
def _xformers_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Run ``xops.memory_efficient_attention`` on inputs with heads on axis 1.

    Inputs are permuted with ``(0, 2, 1, 3)`` before the call and the output is
    permuted back the same way. When ``is_causal`` is set a
    ``LowerTriangularMask`` replaces any supplied mask; otherwise a 2D or 4D
    ``attn_mask`` is expanded to ``(batch, heads_q, seq_q, seq_kv)`` and cast to
    the query dtype.

    Raises:
        ValueError: If ``attn_mask`` is neither 2D nor 4D, or if ``enable_gqa``
            is set and the query head count is not divisible by the key/value
            head count.
    """
    batch_size, num_heads_q, seq_len_q, _ = query.shape
    _, num_heads_kv, seq_len_kv, _ = key.shape

    if is_causal:
        attn_mask = xops.LowerTriangularMask()
    elif attn_mask is not None:
        if attn_mask.ndim not in (2, 4):
            raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
        if attn_mask.ndim == 2:
            # NOTE(review): the 2D mask's second axis is placed on axis 2 (the query
            # axis of the expanded shape) — confirm this is the intended layout.
            attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
        full_mask_shape = (batch_size, num_heads_q, seq_len_q, seq_len_kv)
        attn_mask = attn_mask.expand(*full_mask_shape).type_as(query)

    query = query.permute(0, 2, 1, 3)
    key = key.permute(0, 2, 1, 3)
    value = value.permute(0, 2, 1, 3)

    if enable_gqa:
        if num_heads_q % num_heads_kv != 0:
            raise ValueError("Number of heads in query must be divisible by number of heads in key/value.")
        group_size = num_heads_q // num_heads_kv
        # Split the head axis into (kv_heads, group) and broadcast k/v across the group.
        query = query.unflatten(2, (num_heads_kv, -1))
        key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, group_size, -1)
        value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, group_size, -1)

    out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale)

    if enable_gqa:
        # Merge (kv_heads, group) back into a single head axis.
        out = out.flatten(2, 3)

    return out.permute(0, 2, 1, 3)
|
|