from __future__ import annotations

import contextlib
import functools
import inspect
import math
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable

import torch
import torch.distributed as dist
import torch.nn.functional as F

if torch.distributed.is_available():
    import torch.distributed._functional_collectives as funcol

from ..utils import (
    get_logger,
    is_aiter_available,
    is_aiter_version,
    is_flash_attn_3_available,
    is_flash_attn_available,
    is_flash_attn_version,
    is_kernels_available,
    is_kernels_version,
    is_sageattention_available,
    is_sageattention_version,
    is_torch_npu_available,
    is_torch_version,
    is_torch_xla_available,
    is_torch_xla_version,
    is_xformers_available,
    is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
from ..utils.torch_utils import maybe_allow_in_graph
from ._modeling_parallel import gather_size_by_comm


if TYPE_CHECKING:
    from ._modeling_parallel import ParallelConfig


_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_AITER_VERSION = "0.1.5"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
_REQUIRED_XFORMERS_VERSION = "0.0.29"

logger = get_logger(__name__)

_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
_CAN_USE_NPU_ATTN = is_torch_npu_available()
_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)


if _CAN_USE_FLASH_ATTN:
    try:
        from flash_attn import flash_attn_func, flash_attn_varlen_func
        from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
    except (ImportError, OSError, RuntimeError) as e:
        logger.warning(f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention.")
        _CAN_USE_FLASH_ATTN = False
        flash_attn_func = None
        flash_attn_varlen_func = None
        _wrapped_flash_attn_backward = None
        _wrapped_flash_attn_forward = None
else:
    flash_attn_func = None
    flash_attn_varlen_func = None
    _wrapped_flash_attn_backward = None
    _wrapped_flash_attn_forward = None


if _CAN_USE_FLASH_ATTN_3:
    try:
        from flash_attn_interface import flash_attn_func as flash_attn_3_func
        from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
    except (ImportError, OSError, RuntimeError) as e:
        logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
        _CAN_USE_FLASH_ATTN_3 = False
        flash_attn_3_func = None
        flash_attn_3_varlen_func = None
else:
    flash_attn_3_func = None
    flash_attn_3_varlen_func = None


if _CAN_USE_AITER_ATTN:
    try:
        from aiter import flash_attn_func as aiter_flash_attn_func
    except (ImportError, OSError, RuntimeError) as e:
        logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
        _CAN_USE_AITER_ATTN = False
        aiter_flash_attn_func = None
else:
    aiter_flash_attn_func = None


if _CAN_USE_SAGE_ATTN:
    try:
        from sageattention import (
            sageattn,
            sageattn_qk_int8_pv_fp8_cuda,
            sageattn_qk_int8_pv_fp8_cuda_sm90,
            sageattn_qk_int8_pv_fp16_cuda,
            sageattn_qk_int8_pv_fp16_triton,
            sageattn_varlen,
        )
    except (ImportError, OSError, RuntimeError) as e:
        logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
        _CAN_USE_SAGE_ATTN = False
        sageattn = None
        sageattn_qk_int8_pv_fp8_cuda = None
        sageattn_qk_int8_pv_fp8_cuda_sm90 = None
        sageattn_qk_int8_pv_fp16_cuda = None
        sageattn_qk_int8_pv_fp16_triton = None
        sageattn_varlen = None
else:
    sageattn = None
    sageattn_qk_int8_pv_fp8_cuda = None
    sageattn_qk_int8_pv_fp8_cuda_sm90 = None
    sageattn_qk_int8_pv_fp16_cuda = None
    sageattn_qk_int8_pv_fp16_triton = None
    sageattn_varlen = None


if _CAN_USE_FLEX_ATTN:
    try:
        import torch.nn.attention.flex_attention as flex_attention
    except (ImportError, OSError, RuntimeError) as e:
        logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
        _CAN_USE_FLEX_ATTN = False
        flex_attention = None
else:
    flex_attention = None


if _CAN_USE_NPU_ATTN:
    try:
        from torch_npu import npu_fusion_attention
    except (ImportError, OSError, RuntimeError) as e:
        logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
        _CAN_USE_NPU_ATTN = False
        npu_fusion_attention = None
else:
    npu_fusion_attention = None


if _CAN_USE_XLA_ATTN:
    try:
        from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
    except (ImportError, OSError, RuntimeError) as e:
        logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
        _CAN_USE_XLA_ATTN = False
        xla_flash_attention = None
else:
    xla_flash_attention = None


if _CAN_USE_XFORMERS_ATTN:
    try:
        import xformers.ops as xops
    except (ImportError, OSError, RuntimeError) as e:
        logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
        _CAN_USE_XFORMERS_ATTN = False
        xops = None
else:
    xops = None


# `torch.library.custom_op` and `torch.library.register_fake` were added in torch 2.4; fall back to
# no-op decorators on older versions. Use `is_torch_version` rather than comparing
# `torch.__version__` as a string, since lexicographic comparison misorders versions
# (e.g. "2.10" < "2.4").
if is_torch_version(">=", "2.4.0"):
    _custom_op = torch.library.custom_op
    _register_fake = torch.library.register_fake
else:

    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    _custom_op = custom_op_no_op
    _register_fake = register_fake_no_op


class AttentionBackendName(str, Enum):
    # `flash-attn`
    FLASH = "flash"
    FLASH_HUB = "flash_hub"
    FLASH_VARLEN = "flash_varlen"
    FLASH_VARLEN_HUB = "flash_varlen_hub"
    _FLASH_3 = "_flash_3"
    _FLASH_VARLEN_3 = "_flash_varlen_3"
    _FLASH_3_HUB = "_flash_3_hub"
    _FLASH_3_VARLEN_HUB = "_flash_3_varlen_hub"

    # `aiter`
    AITER = "aiter"

    # PyTorch native
    FLEX = "flex"
    NATIVE = "native"
    _NATIVE_CUDNN = "_native_cudnn"
    _NATIVE_EFFICIENT = "_native_efficient"
    _NATIVE_FLASH = "_native_flash"
    _NATIVE_MATH = "_native_math"
    _NATIVE_NPU = "_native_npu"
    _NATIVE_XLA = "_native_xla"

    # `sageattention`
    SAGE = "sage"
    SAGE_HUB = "sage_hub"
    SAGE_VARLEN = "sage_varlen"
    _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
    _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
    _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda"
    _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton"

    # `xformers`
    XFORMERS = "xformers"

class _AttentionBackendRegistry:
    _backends = {}
    _constraints = {}
    _supported_arg_names = {}
    _supports_context_parallel = set()
    _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
    _checks_enabled = DIFFUSERS_ATTN_CHECKS

    @classmethod
    def register(
        cls,
        backend: AttentionBackendName,
        constraints: list[Callable] | None = None,
        supports_context_parallel: bool = False,
    ):
        logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")

        def decorator(func):
            cls._backends[backend] = func
            cls._constraints[backend] = constraints or []
            cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
            if supports_context_parallel:
                cls._supports_context_parallel.add(backend.value)
            return func

        return decorator

    @classmethod
    def get_active_backend(cls):
        return cls._active_backend, cls._backends[cls._active_backend]

    @classmethod
    def set_active_backend(cls, backend: AttentionBackendName):
        cls._active_backend = backend

    @classmethod
    def list_backends(cls):
        return list(cls._backends.keys())

    @classmethod
    def _is_context_parallel_available(
        cls,
        backend: AttentionBackendName,
    ) -> bool:
        return backend.value in cls._supports_context_parallel
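

# Backends register themselves with the decorator above. A minimal sketch of a registration
# (illustrative only; the real per-backend registrations follow later in this module):
#
#     @_AttentionBackendRegistry.register(
#         AttentionBackendName.NATIVE,
#         constraints=[_check_device, _check_shape],
#     )
#     def native_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
#         ...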


@dataclass
class _HubKernelConfig:
    """Configuration for downloading and using a hub-based attention kernel."""

    repo_id: str
    function_attr: str
    revision: str | None = None
    version: int | None = None
    kernel_fn: Callable | None = None
    wrapped_forward_attr: str | None = None
    wrapped_backward_attr: str | None = None
    wrapped_forward_fn: Callable | None = None
    wrapped_backward_fn: Callable | None = None


_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
    AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
        repo_id="kernels-community/flash-attn3",
        function_attr="flash_attn_func",
        wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
        wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
        version=1,
    ),
    AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
        repo_id="kernels-community/flash-attn3",
        function_attr="flash_attn_varlen_func",
        version=1,
    ),
    AttentionBackendName.FLASH_HUB: _HubKernelConfig(
        repo_id="kernels-community/flash-attn2",
        function_attr="flash_attn_func",
        wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
        wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
        version=1,
    ),
    AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
        repo_id="kernels-community/flash-attn2",
        function_attr="flash_attn_varlen_func",
        version=1,
    ),
    AttentionBackendName.SAGE_HUB: _HubKernelConfig(
        repo_id="kernels-community/sage-attention",
        function_attr="sageattn",
        version=1,
    ),
}


@contextlib.contextmanager
def attention_backend(backend: str | AttentionBackendName = AttentionBackendName.NATIVE):
    """
    Context manager to set the active attention backend.
    """
    if backend not in _AttentionBackendRegistry._backends:
        raise ValueError(f"Backend {backend} is not registered.")

    backend = AttentionBackendName(backend)
    _check_attention_backend_requirements(backend)
    _maybe_download_kernel_for_backend(backend)

    old_backend = _AttentionBackendRegistry._active_backend
    _AttentionBackendRegistry.set_active_backend(backend)

    try:
        yield
    finally:
        _AttentionBackendRegistry.set_active_backend(old_backend)
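

# Example usage (illustrative; assumes flash-attn 2 is installed and that the model's attention
# layers route through `dispatch_attention_fn`):
#
#     with attention_backend("flash"):
#         output = model(**inputs)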


def dispatch_attention_fn(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    attention_kwargs: dict[str, Any] | None = None,
    *,
    backend: AttentionBackendName | None = None,
    parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    attention_kwargs = attention_kwargs or {}

    if backend is None:
        # With no explicit backend, use the active one: either the default (set via the
        # DIFFUSERS_ATTN_BACKEND environment variable) or whatever the `attention_backend`
        # context manager has made active.
        backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend()
    else:
        backend_name = AttentionBackendName(backend)
        backend_fn = _AttentionBackendRegistry._backends.get(backend_name)

    kwargs = {
        "query": query,
        "key": key,
        "value": value,
        "attn_mask": attn_mask,
        "dropout_p": dropout_p,
        "is_causal": is_causal,
        "scale": scale,
        **attention_kwargs,
        "_parallel_config": parallel_config,
    }
    if is_torch_version(">=", "2.5.0"):
        kwargs["enable_gqa"] = enable_gqa

    if _AttentionBackendRegistry._checks_enabled:
        removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
        if removed_kwargs:
            logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.")
        for check in _AttentionBackendRegistry._constraints.get(backend_name, []):
            check(**kwargs)

    kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}

    return backend_fn(**kwargs)
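

# A minimal direct call (illustrative; the backends in this module expect tensors in the
# [batch, seq_len, num_heads, head_dim] layout):
#
#     q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
#     k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
#     v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
#     out = dispatch_attention_fn(q, k, v, backend=AttentionBackendName.NATIVE)  # (2, 128, 8, 64)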


# ===== Constraint checks =====


def _check_attn_mask_or_causal(attn_mask: torch.Tensor | None, is_causal: bool, **kwargs) -> None:
    if attn_mask is not None and is_causal:
        raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.")


def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    if query.device != key.device or query.device != value.device:
        raise ValueError("Query, key, and value must be on the same device.")
    if query.dtype != key.dtype or query.dtype != value.dtype:
        raise ValueError("Query, key, and value must have the same dtype.")


def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    _check_device(query, key, value)
    if query.device.type != "cuda":
        raise ValueError("Query, key, and value must be on a CUDA device.")


def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable:
    def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
        _check_device_cuda(query, key, value)
        if torch.cuda.get_device_capability(query.device) < (major, minor):
            raise ValueError(
                f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}."
            )

    return check_device_cuda


def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    if query.dtype != key.dtype:
        raise ValueError("Query and key must have the same dtype.")
    if query.dtype != value.dtype:
        raise ValueError("Query and value must have the same dtype.")


def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    _check_qkv_dtype_match(query, key, value)
    if query.dtype not in (torch.bfloat16, torch.float16):
        raise ValueError("Query, key, and value must be either bfloat16 or float16.")


def _check_shape(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    **kwargs,
) -> None:
    # Expected layout is [batch_size, seq_len, num_heads, head_dim]: the sequence length lives at
    # dim -3 and the head dimension at dim -1.
    if query.shape[-1] != key.shape[-1]:
        raise ValueError("Query and key must have the same head dimension.")
    if key.shape[-3] != value.shape[-3]:
        raise ValueError("Key and value must have the same sequence length.")
    if attn_mask is not None and attn_mask.shape[-1] != key.shape[-3]:
        raise ValueError("Attention mask must match the key's sequence length.")


# ===== Backend requirement checks =====


def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
    if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
        if not _CAN_USE_FLASH_ATTN:
            raise RuntimeError(
                f"Flash Attention backend '{backend.value}' is not usable because the package is missing or its version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
            )

    elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
        if not _CAN_USE_FLASH_ATTN_3:
            raise RuntimeError(
                f"Flash Attention 3 backend '{backend.value}' is not usable because the package is missing or its version is too old. Please build the FA3 beta release from source."
            )

    elif backend in [
        AttentionBackendName.FLASH_HUB,
        AttentionBackendName.FLASH_VARLEN_HUB,
        AttentionBackendName._FLASH_3_HUB,
        AttentionBackendName._FLASH_3_VARLEN_HUB,
        AttentionBackendName.SAGE_HUB,
    ]:
        if not is_kernels_available():
            raise RuntimeError(
                f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
            )
        if not is_kernels_version(">=", "0.12"):
            raise RuntimeError(
                f"Backend '{backend.value}' requires `kernels` version 0.12 or newer. Please update with `pip install -U kernels`."
            )

    elif backend == AttentionBackendName.AITER:
        if not _CAN_USE_AITER_ATTN:
            raise RuntimeError(
                f"Aiter Attention backend '{backend.value}' is not usable because the package is missing or its version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
            )

    elif backend in [
        AttentionBackendName.SAGE,
        AttentionBackendName.SAGE_VARLEN,
        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
        AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
        AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
    ]:
        if not _CAN_USE_SAGE_ATTN:
            raise RuntimeError(
                f"Sage Attention backend '{backend.value}' is not usable because the package is missing or its version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
            )

    elif backend == AttentionBackendName.FLEX:
        if not _CAN_USE_FLEX_ATTN:
            raise RuntimeError(
                f"Flex Attention backend '{backend.value}' is not usable because the PyTorch version is too old. Please install `torch>={_REQUIRED_FLEX_VERSION}`."
            )

    elif backend == AttentionBackendName._NATIVE_NPU:
        if not _CAN_USE_NPU_ATTN:
            raise RuntimeError(
                f"NPU Attention backend '{backend.value}' is not usable because the package is missing. Please install `torch_npu`."
            )

    elif backend == AttentionBackendName._NATIVE_XLA:
        if not _CAN_USE_XLA_ATTN:
            raise RuntimeError(
                f"XLA Attention backend '{backend.value}' is not usable because the package is missing or its version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`."
            )

    elif backend == AttentionBackendName.XFORMERS:
        if not _CAN_USE_XFORMERS_ATTN:
            raise RuntimeError(
                f"Xformers Attention backend '{backend.value}' is not usable because the package is missing or its version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
            )


@functools.lru_cache(maxsize=128)
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
    batch_size: int,
    seq_len_q: int,
    seq_len_kv: int,
    device: torch.device | None = None,
):
    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
    seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
    max_seqlen_q = seqlens_q.max().item()
    max_seqlen_k = seqlens_k.max().item()
    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
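

# For intuition (illustrative values): batch_size=2 and seq_len_q=3 give seqlens_q = [3, 3] and
# cu_seqlens_q = [0, 3, 6] -- the cumulative offsets that varlen kernels use to locate each
# sequence inside a packed (total_tokens, num_heads, head_dim) buffer.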


def _prepare_for_flash_attn_or_sage_varlen_with_mask(
    batch_size: int,
    seq_len_q: int,
    attn_mask: torch.Tensor,
    device: torch.device | None = None,
):
    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
    seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
    max_seqlen_q = seqlens_q.max().item()
    max_seqlen_k = seqlens_k.max().item()
    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)


def _prepare_for_flash_attn_or_sage_varlen(
    batch_size: int,
    seq_len_q: int,
    seq_len_kv: int,
    attn_mask: torch.Tensor | None = None,
    device: torch.device | None = None,
) -> tuple:
    if attn_mask is None:
        return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)
    return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)


def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
    """
    Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in
    FlashAttention/Sage varlen.

    Supports 1D to 4D shapes and common broadcasting patterns.
    """
    if attn_mask.dtype != torch.bool:
        raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")

    if attn_mask.ndim == 1:
        # [seq_len_k] -> [batch_size, seq_len_k]
        attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)

    elif attn_mask.ndim == 2:
        # [batch_size or 1, seq_len_k] -> [batch_size, seq_len_k]
        if attn_mask.size(0) not in [1, batch_size]:
            raise ValueError(
                f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask."
            )
        attn_mask = attn_mask.expand(batch_size, seq_len_k)

    elif attn_mask.ndim == 3:
        # [batch_size or 1, seq_len_q or 1, seq_len_k]: a key position is kept if any query
        # attends to it -> [batch_size, seq_len_k]
        if attn_mask.size(0) not in [1, batch_size]:
            raise ValueError(
                f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask."
            )
        attn_mask = attn_mask.any(dim=1)
        attn_mask = attn_mask.expand(batch_size, seq_len_k)

    elif attn_mask.ndim == 4:
        # [batch_size or 1, num_heads or 1, seq_len_q or 1, seq_len_k]: reduce over heads and
        # query positions -> [batch_size, seq_len_k]
        if attn_mask.size(0) not in [1, batch_size]:
            raise ValueError(
                f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask."
            )
        attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k)
        attn_mask = attn_mask.any(dim=(1, 2))

    else:
        raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}")

    if attn_mask.shape != (batch_size, seq_len_k):
        raise ValueError(
            f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})"
        )

    return attn_mask
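

# For example (illustrative): a bool mask of shape [1, 8, 128, 128] is broadcast to the batch size
# and reduced to a [batch_size, 128] mask whose entry [b, k] says whether any (head, query) pair
# in sample `b` attends to key position `k`.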


def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
    return q_idx >= kv_idx


def _resolve_kernel_attr(module, attr_path: str):
    target = module
    for attr in attr_path.split("."):
        if not hasattr(target, attr):
            raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
        target = getattr(target, attr)
    return target


def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
    if backend not in _HUB_KERNELS_REGISTRY:
        return
    config = _HUB_KERNELS_REGISTRY[backend]

    needs_kernel = config.kernel_fn is None
    needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
    needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None

    if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
        return

    try:
        from kernels import get_kernel

        kernel_module = get_kernel(config.repo_id, revision=config.revision, version=config.version)
        if needs_kernel:
            config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
        if needs_wrapped_forward:
            config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
        if needs_wrapped_backward:
            config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
    except Exception as e:
        logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
        raise


# ===== torch custom op wrapping flash-attn 3 =====


@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: float | None = None,
    causal: bool = False,
    qv: torch.Tensor | None = None,
    q_descale: torch.Tensor | None = None,
    k_descale: torch.Tensor | None = None,
    v_descale: torch.Tensor | None = None,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: bool | None = None,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    window_size = (-1, -1)
    result = flash_attn_3_func(
        q=q,
        k=k,
        v=v,
        softmax_scale=softmax_scale,
        causal=causal,
        qv=qv,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
        window_size=window_size,
        attention_chunk=attention_chunk,
        softcap=softcap,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        deterministic=deterministic,
        sm_margin=sm_margin,
        return_attn_probs=True,
    )
    out, lse, *_ = result
    # FA3 returns the LSE as [batch, num_heads, seq_len]; permute to [batch, seq_len, num_heads].
    lse = lse.permute(0, 2, 1)
    return out, lse


@_register_fake("_diffusers_flash_attn_3::_flash_attn_forward")
def _(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: float | None = None,
    causal: bool = False,
    qv: torch.Tensor | None = None,
    q_descale: torch.Tensor | None = None,
    k_descale: torch.Tensor | None = None,
    v_descale: torch.Tensor | None = None,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: bool | None = None,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    # The LSE is returned already permuted to [batch, seq_len, num_heads], matching the real op.
    batch_size, seq_len, num_heads, head_dim = q.shape
    lse_shape = (batch_size, seq_len, num_heads)
    return torch.empty_like(q), q.new_empty(lse_shape)


# ===== Attention ops used by the templated context-parallel wrappers =====


def _native_attention_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
):
    if return_lse:
        raise ValueError("Native attention does not support return_lse=True")

    if _save_ctx:
        ctx.save_for_backward(query, key, value)
        ctx.attn_mask = attn_mask
        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.enable_gqa = enable_gqa

    # SDPA expects [batch, heads, seq_len, head_dim]; inputs here are [batch, seq_len, heads, head_dim].
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    out = torch.nn.functional.scaled_dot_product_attention(
        query=query,
        key=key,
        value=value,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        enable_gqa=enable_gqa,
    )
    out = out.permute(0, 2, 1, 3)

    return out


def _native_attention_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    **kwargs,
):
    # SDPA has no directly exposed backward entry point, so recompute the forward pass with
    # autograd enabled and differentiate through it. Note that `out` must stay in the permuted
    # [batch, heads, seq_len, head_dim] layout so its shape matches `grad_out_t` below.
    query, key, value = ctx.saved_tensors

    query.requires_grad_(True)
    key.requires_grad_(True)
    value.requires_grad_(True)

    query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    out = torch.nn.functional.scaled_dot_product_attention(
        query=query_t,
        key=key_t,
        value=value_t,
        attn_mask=ctx.attn_mask,
        dropout_p=ctx.dropout_p,
        is_causal=ctx.is_causal,
        scale=ctx.scale,
        enable_gqa=ctx.enable_gqa,
    )

    grad_out_t = grad_out.permute(0, 2, 1, 3)
    grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
        outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
    )

    grad_query = grad_query_t.permute(0, 2, 1, 3)
    grad_key = grad_key_t.permute(0, 2, 1, 3)
    grad_value = grad_value_t.permute(0, 2, 1, 3)

    return grad_query, grad_key, grad_value


def _cudnn_attention_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
):
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.")

    tensors_to_save = ()

    # The cuDNN op expects [batch, heads, seq_len, head_dim]; `.contiguous()` matters here since
    # calling the aten op on non-contiguous tensors can produce incorrect results.
    query = query.transpose(1, 2).contiguous()
    key = key.transpose(1, 2).contiguous()
    value = value.transpose(1, 2).contiguous()
    tensors_to_save += (query, key, value)

    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=attn_mask,
            compute_log_sumexp=return_lse,
            dropout_p=dropout_p,
            is_causal=is_causal,
            return_debug_mask=False,
            scale=scale,
        )
    )

    tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
    if _save_ctx:
        ctx.save_for_backward(*tensors_to_save)
        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.attn_mask = attn_mask
        ctx.max_q = max_q
        ctx.max_k = max_k

    out = out.transpose(1, 2).contiguous()
    if lse is not None:
        lse = lse.transpose(1, 2).contiguous()
    return (out, lse) if return_lse else out


def _cudnn_attention_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    **kwargs,
):
    # The saved query/key/value (and out/lse) are already in the [batch, heads, seq_len, head_dim]
    # layout from the forward op; only `grad_out` arrives in [batch, seq_len, heads, head_dim]
    # and needs transposing. Transposing key/value again here would undo the saved layout.
    query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors

    grad_out = grad_out.transpose(1, 2).contiguous()

    grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward(
        grad_out,
        query,
        key,
        value,
        out,
        logsumexp=lse,
        philox_seed=philox_seed,
        philox_offset=philox_offset,
        attn_bias=ctx.attn_mask,
        cum_seq_q=cum_seq_q,
        cum_seq_k=cum_seq_k,
        max_q=ctx.max_q,
        max_k=ctx.max_k,
        dropout_p=ctx.dropout_p,
        is_causal=ctx.is_causal,
        scale=ctx.scale,
    )
    grad_query, grad_key, grad_value = (x.transpose(1, 2).contiguous() for x in (grad_query, grad_key, grad_value))

    return grad_query, grad_key, grad_value


def _native_flash_attention_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
):
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for native flash attention.")
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for native flash attention.")

    tensors_to_save = ()

    # The aten flash attention op expects [batch, heads, seq_len, head_dim].
    query = query.transpose(1, 2).contiguous()
    key = key.transpose(1, 2).contiguous()
    value = value.transpose(1, 2).contiguous()
    tensors_to_save += (query, key, value)

    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_flash_attention(
            query=query,
            key=key,
            value=value,
            dropout_p=dropout_p,
            is_causal=is_causal,
            return_debug_mask=False,
            scale=scale,
        )
    )

    tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
    if _save_ctx:
        ctx.save_for_backward(*tensors_to_save)
        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.max_q = max_q
        ctx.max_k = max_k

    out = out.transpose(1, 2).contiguous()
    if lse is not None:
        lse = lse.transpose(1, 2).contiguous()
    return (out, lse) if return_lse else out


def _native_flash_attention_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    **kwargs,
):
    # As in the cuDNN backward op, the saved tensors are already in [batch, heads, seq_len,
    # head_dim]; only `grad_out` needs to be transposed into that layout.
    query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors

    grad_out = grad_out.transpose(1, 2).contiguous()

    grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_flash_attention_backward(
        grad_out,
        query,
        key,
        value,
        out,
        logsumexp=lse,
        philox_seed=philox_seed,
        philox_offset=philox_offset,
        cum_seq_q=cum_seq_q,
        cum_seq_k=cum_seq_k,
        max_q=ctx.max_q,
        max_k=ctx.max_k,
        dropout_p=ctx.dropout_p,
        is_causal=ctx.is_causal,
        scale=ctx.scale,
    )
    grad_query, grad_key, grad_value = (x.transpose(1, 2).contiguous() for x in (grad_query, grad_key, grad_value))

    return grad_query, grad_key, grad_value


def _flash_attention_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
):
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.")

    window_size = (-1, -1)
    softcap = 0.0
    alibi_slopes = None
    deterministic = False
    grad_enabled = any(x.requires_grad for x in (query, key, value))

    if scale is None:
        scale = query.shape[-1] ** (-0.5)

    # Use a negligible non-zero dropout whenever a backward pass (or context parallelism, which
    # reuses the saved state) is expected, so the kernel records a valid rng_state.
    if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
        dropout_p = dropout_p if dropout_p > 0 else 1e-30

    with torch.set_grad_enabled(grad_enabled):
        out, lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
            query,
            key,
            value,
            dropout_p,
            scale,
            is_causal,
            window_size[0],
            window_size[1],
            softcap,
            alibi_slopes,
            return_lse,
        )
        lse = lse.permute(0, 2, 1)

    if _save_ctx:
        ctx.save_for_backward(query, key, value, out, lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.scale = scale
        ctx.is_causal = is_causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic

    return (out, lse) if return_lse else out


def _flash_attention_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    **kwargs,
):
    query, key, value, out, lse, rng_state = ctx.saved_tensors
    grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)

    lse_d = _wrapped_flash_attn_backward(
        grad_out,
        query,
        key,
        value,
        out,
        lse,
        grad_query,
        grad_key,
        grad_value,
        ctx.dropout_p,
        ctx.scale,
        ctx.is_causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        ctx.alibi_slopes,
        ctx.deterministic,
        rng_state,
    )

    # The head dimension may have been padded by the kernel; strip any padding back off.
    grad_query = grad_query[..., : grad_out.shape[-1]]
    grad_key = grad_key[..., : grad_out.shape[-1]]
    grad_value = grad_value[..., : grad_out.shape[-1]]

    return grad_query, grad_key, grad_value


def _flash_attention_hub_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
):
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")

    config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
    wrapped_forward_fn = config.wrapped_forward_fn
    wrapped_backward_fn = config.wrapped_backward_fn
    if wrapped_forward_fn is None or wrapped_backward_fn is None:
        raise RuntimeError(
            "Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
            "for context parallel execution."
        )

    if scale is None:
        scale = query.shape[-1] ** (-0.5)

    window_size = (-1, -1)
    softcap = 0.0
    alibi_slopes = None
    deterministic = False
    grad_enabled = any(x.requires_grad for x in (query, key, value))

    if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
        dropout_p = dropout_p if dropout_p > 0 else 1e-30

    with torch.set_grad_enabled(grad_enabled):
        out, lse, S_dmask, rng_state = wrapped_forward_fn(
            query,
            key,
            value,
            dropout_p,
            scale,
            is_causal,
            window_size[0],
            window_size[1],
            softcap,
            alibi_slopes,
            return_lse,
        )
        lse = lse.permute(0, 2, 1).contiguous()

    if _save_ctx:
        ctx.save_for_backward(query, key, value, out, lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.scale = scale
        ctx.is_causal = is_causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic

    return (out, lse) if return_lse else out


def _flash_attention_hub_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    **kwargs,
):
    config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
    wrapped_backward_fn = config.wrapped_backward_fn
    if wrapped_backward_fn is None:
        raise RuntimeError(
            "Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
        )

    query, key, value, out, lse, rng_state = ctx.saved_tensors
    grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)

    _ = wrapped_backward_fn(
        grad_out,
        query,
        key,
        value,
        out,
        lse,
        grad_query,
        grad_key,
        grad_value,
        ctx.dropout_p,
        ctx.scale,
        ctx.is_causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        ctx.alibi_slopes,
        ctx.deterministic,
        rng_state,
    )

    grad_query = grad_query[..., : grad_out.shape[-1]]
    grad_key = grad_key[..., : grad_out.shape[-1]]
    grad_value = grad_value[..., : grad_out.shape[-1]]

    return grad_query, grad_key, grad_value


def _flash_attention_3_hub_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
    *,
    window_size: tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: bool | None = None,
    deterministic: bool = False,
    sm_margin: int = 0,
):
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
    if dropout_p != 0.0:
        raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")

    config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
    wrapped_forward_fn = config.wrapped_forward_fn
    if wrapped_forward_fn is None:
        raise RuntimeError(
            "Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_forward` "
            "for context parallel execution."
        )

    if scale is None:
        scale = query.shape[-1] ** (-0.5)

    # The positional `None`s cover the kernel's optional inputs (varlen/paging/rotary arguments
    # and q/k/v descales), which are unused here.
    out, softmax_lse, *_ = wrapped_forward_fn(
        query,
        key,
        value,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        scale,
        causal=is_causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        attention_chunk=0,
        softcap=softcap,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        sm_margin=sm_margin,
    )

    lse = softmax_lse.permute(0, 2, 1).contiguous() if return_lse else None

    if _save_ctx:
        ctx.save_for_backward(query, key, value, out, softmax_lse)
        ctx.scale = scale
        ctx.is_causal = is_causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin

    return (out, lse) if return_lse else out


def _flash_attention_3_hub_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    **kwargs,
):
    config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
    wrapped_backward_fn = config.wrapped_backward_fn
    if wrapped_backward_fn is None:
        raise RuntimeError(
            "Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_backward` "
            "for context parallel execution."
        )

    query, key, value, out, softmax_lse = ctx.saved_tensors
    grad_query = torch.empty_like(query)
    grad_key = torch.empty_like(key)
    grad_value = torch.empty_like(value)

    wrapped_backward_fn(
        grad_out,
        query,
        key,
        value,
        out,
        softmax_lse,
        None,
        None,
        None,
        None,
        None,
        None,
        grad_query,
        grad_key,
        grad_value,
        ctx.scale,
        ctx.is_causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        ctx.deterministic,
        ctx.sm_margin,
    )

    grad_query = grad_query[..., : grad_out.shape[-1]]
    grad_key = grad_key[..., : grad_out.shape[-1]]
    grad_value = grad_value[..., : grad_out.shape[-1]]

    return grad_query, grad_key, grad_value


def _sage_attention_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
):
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not yet supported for Sage attention.")
    if dropout_p > 0.0:
        raise ValueError("`dropout_p` is not yet supported for Sage attention.")
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for Sage attention.")

    out = sageattn(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        is_causal=is_causal,
        sm_scale=scale,
        return_lse=return_lse,
    )
    lse = None
    if return_lse:
        out, lse, *_ = out
        lse = lse.permute(0, 2, 1)

    return (out, lse) if return_lse else out


def _sage_attention_hub_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
):
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not yet supported for Sage attention.")
    if dropout_p > 0.0:
        raise ValueError("`dropout_p` is not yet supported for Sage attention.")
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for Sage attention.")

    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
    out = func(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        is_causal=is_causal,
        sm_scale=scale,
        return_lse=return_lse,
    )

    lse = None
    if return_lse:
        out, lse, *_ = out
        lse = lse.permute(0, 2, 1).contiguous()

    return (out, lse) if return_lse else out


def _sage_attention_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
):
    raise NotImplementedError("Backward pass is not implemented for Sage attention.")


def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mask: torch.Tensor | None = None):
    # An all-true mask is a no-op; drop it so the fused kernel can take its unmasked path.
    if attn_mask is not None and torch.all(attn_mask != 0):
        attn_mask = None

    # `npu_fusion_attention` expects an inverted bool mask (True = masked out) broadcast to
    # [batch, 1, seq_len_q, seq_len_kv]; convert a 2D [batch, seq_len_kv] padding mask accordingly.
    if (
        attn_mask is not None
        and attn_mask.ndim == 2
        and attn_mask.shape[0] == query.shape[0]
        and attn_mask.shape[1] == key.shape[1]
    ):
        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
        attn_mask = ~attn_mask.to(torch.bool)
        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()

    return attn_mask


def _npu_attention_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
):
    if return_lse:
        raise ValueError("NPU attention backend does not support setting `return_lse=True`.")

    attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)

    out = npu_fusion_attention(
        query,
        key,
        value,
        query.size(2),
        atten_mask=attn_mask,
        input_layout="BSND",
        pse=None,
        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
        pre_tockens=65536,
        next_tockens=65536,
        keep_prob=1.0 - dropout_p,
        sync=False,
        inner_precise=0,
    )[0]

    return out


def _npu_attention_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    **kwargs,
):
    raise NotImplementedError("Backward pass is not implemented for NPU fusion attention.")


# ===== Context-parallel collectives =====


def _wait_tensor(tensor):
    if isinstance(tensor, funcol.AsyncCollectiveTensor):
        tensor = tensor.wait()
    return tensor


def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
    shape = x.shape
    # With no split sizes, `funcol.all_to_all_single` exchanges equal chunks along dim 0 and
    # requires contiguous input; callers arrange dim 0 to be divisible by the group size, so
    # flatten first and restore the original shape afterwards.
    x = x.flatten()
    x = funcol.all_to_all_single(x, None, None, group)
    x = x.reshape(shape)
    x = _wait_tensor(x)
    return x
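

# Illustrative semantics (assuming a group of two ranks): if each rank holds `x` with two equal
# dim-0 chunks, rank 0 ends up with [rank0_chunk0, rank1_chunk0] and rank 1 with
# [rank0_chunk1, rank1_chunk1] -- the chunks are exchanged pairwise between ranks.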


def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
    """
    Perform dimension sharding / reassembly across processes using `_all_to_all_single`.

    This utility reshapes and redistributes tensor `x` across the given process group, along either the sequence
    dimension or the head dimension, depending on `scatter_idx` and `gather_idx`.

    Args:
        x (torch.Tensor):
            Input tensor. Expected shapes:
            - When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
            - When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
        scatter_idx (int):
            Dimension along which the tensor is partitioned before the all-to-all.
        gather_idx (int):
            Dimension along which the output is reassembled after the all-to-all.
        group:
            Distributed process group for the Ulysses group.

    Returns:
        torch.Tensor: Tensor with globally exchanged dimensions.
            - For (scatter_idx=2 -> gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
            - For (scatter_idx=1 -> gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)
    """
    group_world_size = torch.distributed.get_world_size(group)

    if scatter_idx == 2 and gather_idx == 1:
        # Scatter heads, gather sequence: (B, S_local, H, D) -> (B, S, H_local, D)
        batch_size, seq_len_local, num_heads, head_dim = x.shape
        seq_len = seq_len_local * group_world_size
        num_heads_local = num_heads // group_world_size

        # (B, S_local, H, D) -> (world_size, S_local, B, H_local, D)
        x_temp = (
            x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
            .transpose(0, 2)
            .contiguous()
        )

        if group_world_size > 1:
            out = _all_to_all_single(x_temp, group=group)
        else:
            out = x_temp

        # (world_size, S_local, B, H_local, D) -> (B, S, H_local, D)
        out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
        out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
        return out
    elif scatter_idx == 1 and gather_idx == 2:
        # Scatter sequence, gather heads: (B, S, H_local, D) -> (B, S_local, H, D)
        batch_size, seq_len, num_heads_local, head_dim = x.shape
        num_heads = num_heads_local * group_world_size
        seq_len_local = seq_len // group_world_size

        # (B, S, H_local, D) -> (world_size, H_local, S_local, B, D)
        x_temp = (
            x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
            .permute(1, 3, 2, 0, 4)
            .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
        )

        if group_world_size > 1:
            output = _all_to_all_single(x_temp, group)
        else:
            output = x_temp

        # (world_size, H_local, S_local, B, D) -> (B, S_local, H, D)
        output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
        output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
        return output
    else:
        raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")


class SeqAllToAllDim(torch.autograd.Function):
    """
    All-to-all operation for unified sequence parallelism. Uses `_all_to_all_dim_exchange`; see that function for
    details.
    """

    @staticmethod
    def forward(ctx, group, input, scatter_id=2, gather_id=1):
        ctx.group = group
        ctx.scatter_id = scatter_id
        ctx.gather_id = gather_id
        return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)

    @staticmethod
    def backward(ctx, grad_outputs):
        # The backward of an all-to-all is the all-to-all with scatter/gather swapped.
        grad_input = SeqAllToAllDim.apply(
            ctx.group,
            grad_outputs,
            ctx.gather_id,
            ctx.scatter_id,
        )
        return (None, grad_input, None, None)


def _maybe_pad_qkv_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> tuple[torch.Tensor, int]:
    r"""Maybe pad the head dimension so it is divisible by the world size.

    x: torch.Tensor of shape (B, S_LOCAL, H, D); H: original global head count.
    Returns the padded tensor of shape (B, S_LOCAL, H + H_PAD, D) and H_PAD.
    """
    world_size = dist.get_world_size(group=group)
    H_PAD = 0
    if H % world_size != 0:
        H_PAD = world_size - (H % world_size)
        NEW_H_LOCAL = (H + H_PAD) // world_size
        assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
        # F.pad pads trailing dimensions first: (0, 0) leaves D alone, (0, H_PAD) pads H.
        x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
    return x, H_PAD
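

# Example (illustrative): H=10 heads on world_size=4 ranks gives H_PAD=2, so the padded tensor
# carries 12 heads and each rank owns NEW_H_LOCAL=3 after the exchange; the zero-filled heads are
# stripped again by `_maybe_unpad_qkv_head` on the last rank.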


def _maybe_unpad_qkv_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor:
    r"""Maybe unpad the head dimension.

    x: torch.Tensor of shape (B, S_GLOBAL, H_LOCAL + H_PAD, D); H_PAD: number of padded heads.
    Returns the unpadded tensor of shape (B, S_GLOBAL, H_LOCAL, D).
    """
    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)
    # The padded heads are the trailing ones, which land on the last rank after the exchange.
    if H_PAD > 0 and rank == world_size - 1:
        x = x[:, :, :-H_PAD, :]
    return x.contiguous()


def _maybe_pad_o_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> tuple[torch.Tensor, int]:
    r"""Maybe pad the head dimension so it is divisible by the world size.

    x: torch.Tensor of shape (B, S_GLOBAL, H_LOCAL, D); H: original global head count.
    Returns the padded tensor of shape (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD.
    """
    if H is None:
        return x, 0

    rank = dist.get_rank(group=group)
    world_size = dist.get_world_size(group=group)
    H_PAD = 0
    if H % world_size != 0:
        H_PAD = world_size - (H % world_size)
        NEW_H_LOCAL = (H + H_PAD) // world_size
        assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
        # Only the last rank holds the padded heads (mirroring `_maybe_unpad_qkv_head`).
        if rank == world_size - 1:
            x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
    return x, H_PAD


def _maybe_unpad_o_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor:
    r"""Maybe unpad the head dimension.

    x: torch.Tensor of shape (B, S_LOCAL, H_GLOBAL + H_PAD, D); H_PAD: number of padded heads.
    Returns the unpadded tensor of shape (B, S_LOCAL, H_GLOBAL, D).
    """
    if H_PAD > 0:
        x = x[:, :, :-H_PAD, :]
    return x.contiguous()


def ulysses_anything_metadata(query: torch.Tensor, **kwargs) -> dict:
    assert len(query.shape) == 4, "Query tensor must be 4-dimensional of shape (B, S_LOCAL, H_GLOBAL, D)"
    extra_kwargs = {}
    extra_kwargs["NUM_QO_HEAD"] = query.shape[2]
    extra_kwargs["Q_S_LOCAL"] = query.shape[1]
    return extra_kwargs


@maybe_allow_in_graph
def all_to_all_single_any_qkv_async(
    x: torch.Tensor, group: dist.ProcessGroup, **kwargs
) -> Callable[..., torch.Tensor]:
    r"""
    x: torch.Tensor of shape (B, S_LOCAL, H, D). Returns a callable that yields (B, S_GLOBAL, H_LOCAL, D).
    """
    world_size = dist.get_world_size(group=group)
    B, S_LOCAL, H, D = x.shape
    x, H_PAD = _maybe_pad_qkv_head(x, H, group)
    H_LOCAL = (H + H_PAD) // world_size
    # (B, S_LOCAL, H + H_PAD, D) -> (world_size, S_LOCAL, B, H_LOCAL, D)
    x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()

    input_split_sizes = [S_LOCAL] * world_size
    # Ranks may hold different local sequence lengths, so gather every rank's S_LOCAL to build the
    # output split sizes.
    output_split_sizes = gather_size_by_comm(S_LOCAL, group)
    x = x.flatten(0, 1)
    x = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group)

    def wait() -> torch.Tensor:
        nonlocal x, H_PAD
        x = _wait_tensor(x)
        # (S_GLOBAL, B, H_LOCAL, D) -> (B, S_GLOBAL, H_LOCAL, D)
        x = x.permute(1, 0, 2, 3).contiguous()
        x = _maybe_unpad_qkv_head(x, H_PAD, group)
        return x

    return wait
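

# The deferred `wait()` lets callers overlap communication with other work, e.g. (illustrative):
#
#     wait_q = all_to_all_single_any_qkv_async(query, group)
#     wait_k = all_to_all_single_any_qkv_async(key, group)
#     wait_v = all_to_all_single_any_qkv_async(value, group)
#     query, key, value = wait_q(), wait_k(), wait_v()  # all three exchanges were in flight together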


@maybe_allow_in_graph
def all_to_all_single_any_o_async(x: torch.Tensor, group: dist.ProcessGroup, **kwargs) -> Callable[..., torch.Tensor]:
    r"""
    x: torch.Tensor of shape (B, S_GLOBAL, H_LOCAL, D). Returns a callable that yields (B, S_LOCAL, H_GLOBAL, D).
    """
    # NUM_QO_HEAD / Q_S_LOCAL are provided by `ulysses_anything_metadata`.
    H = kwargs.get("NUM_QO_HEAD", None)
    world_size = dist.get_world_size(group=group)

    x, H_PAD = _maybe_pad_o_head(x, H, group)
    shape = x.shape
    (B, S_GLOBAL, H_LOCAL, D) = shape

    # Inverse of the qkv exchange: scatter the sequence back to its source ranks and gather the
    # heads. Sequence lengths can differ per rank, so the input split sizes are gathered, while
    # each rank receives exactly S_LOCAL rows from every peer.
    S_LOCAL = kwargs.get("Q_S_LOCAL")
    input_split_sizes = gather_size_by_comm(S_LOCAL, group)
    x = x.permute(1, 0, 2, 3).contiguous()
    output_split_sizes = [S_LOCAL] * world_size
    x = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group)

    def wait() -> torch.Tensor:
        nonlocal x, H_PAD
        x = _wait_tensor(x)
        x = x.reshape(world_size, S_LOCAL, B, H_LOCAL, D)
        x = x.permute(2, 1, 0, 3, 4).contiguous()
        x = x.reshape(B, S_LOCAL, world_size * H_LOCAL, D)
        x = _maybe_unpad_o_head(x, H_PAD, group)
        return x

    return wait
|
|
|
|
| class TemplatedRingAttention(torch.autograd.Function): |
| @staticmethod |
| def forward( |
| ctx: torch.autograd.function.FunctionCtx, |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None, |
| dropout_p: float, |
| is_causal: bool, |
| scale: float | None, |
| enable_gqa: bool, |
| return_lse: bool, |
| forward_op, |
| backward_op, |
| _parallel_config: "ParallelConfig" | None = None, |
| ): |
| ring_mesh = _parallel_config.context_parallel_config._ring_mesh |
| rank = _parallel_config.context_parallel_config._ring_local_rank |
| world_size = _parallel_config.context_parallel_config.ring_degree |
| next_rank = (rank + 1) % world_size |
| prev_out = prev_lse = None |
|
|
| ctx.forward_op = forward_op |
| ctx.backward_op = backward_op |
| ctx.q_shape = query.shape |
| ctx.kv_shape = key.shape |
| ctx._parallel_config = _parallel_config |
|
|
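| # All-gather every rank's K/V once up front; each ring step then reads the next |
| # rank's block out of the local buffer instead of issuing point-to-point sends. |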
| kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous() |
| kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group()) |
| kv_buffer = kv_buffer.chunk(world_size) |
|
|
| for i in range(world_size): |
| if i > 0: |
| kv = kv_buffer[next_rank] |
| key_numel = key.numel() |
| key = kv[:key_numel].reshape_as(key) |
| value = kv[key_numel:].reshape_as(value) |
| next_rank = (next_rank + 1) % world_size |
|
|
| out, lse = forward_op( |
| ctx, |
| query, |
| key, |
| value, |
| attn_mask, |
| dropout_p, |
| is_causal, |
| scale, |
| enable_gqa, |
| True, |
| _save_ctx=i == 0, |
| _parallel_config=_parallel_config, |
| ) |
|
|
| if _parallel_config.context_parallel_config.convert_to_fp32: |
| out = out.to(torch.float32) |
| lse = lse.to(torch.float32) |
|
|
| |
| |
| if is_torch_version("<", "2.9.0"): |
| lse = lse.unsqueeze(-1) |
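| # Numerically stable online-softmax merge of the partial results: with |
| # s = sigmoid(lse - prev_lse) = exp(lse) / (exp(prev_lse) + exp(lse)), |
| # out <- (1 - s) * prev_out + s * out |
| # lse <- log(exp(prev_lse) + exp(lse)) |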
| if prev_out is not None: |
| out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out) |
| lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse) |
| prev_out = out |
| prev_lse = lse |
|
|
| out = out.to(query.dtype) |
| lse = lse.squeeze(-1) |
|
|
| return (out, lse) if return_lse else out |
|
|
| @staticmethod |
| def backward( |
| ctx: torch.autograd.function.FunctionCtx, |
| grad_out: torch.Tensor, |
| *args, |
| ): |
| ring_mesh = ctx._parallel_config.context_parallel_config._ring_mesh |
| rank = ctx._parallel_config.context_parallel_config._ring_local_rank |
| world_size = ctx._parallel_config.context_parallel_config.ring_degree |
| next_rank = (rank + 1) % world_size |
| next_ranks = list(range(1, world_size)) + [0] |
|
|
| accum_dtype = torch.float32 if ctx._parallel_config.context_parallel_config.convert_to_fp32 else grad_out.dtype |
| grad_query = torch.zeros(ctx.q_shape, dtype=accum_dtype, device=grad_out.device) |
| grad_key = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device) |
| grad_value = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device) |
| next_grad_kv = None |
|
|
| query, key, value, *_ = ctx.saved_tensors |
| kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous() |
| kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group()) |
| kv_buffer = kv_buffer.chunk(world_size) |
|
|
| for i in range(world_size): |
| if i > 0: |
| kv = kv_buffer[next_rank] |
| key_numel = key.numel() |
| key = kv[:key_numel].reshape_as(key) |
| value = kv[key_numel:].reshape_as(value) |
| next_rank = (next_rank + 1) % world_size |
|
|
| grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out) |
|
|
| if i > 0: |
| grad_kv_buffer = _wait_tensor(next_grad_kv) |
| grad_key_numel = grad_key.numel() |
| grad_key = grad_kv_buffer[:grad_key_numel].reshape_as(grad_key) |
| grad_value = grad_kv_buffer[grad_key_numel:].reshape_as(grad_value) |
|
|
| grad_query += grad_query_op |
| grad_key += grad_key_op |
| grad_value += grad_value_op |
|
|
| if i < world_size - 1: |
| grad_kv_buffer = torch.cat([grad_key.flatten(), grad_value.flatten()]).contiguous() |
| next_grad_kv = funcol.permute_tensor(grad_kv_buffer, next_ranks, group=ring_mesh.get_group()) |
|
|
| grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value)) |
|
|
| return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None |
|
|
|
|
| class TemplatedUlyssesAttention(torch.autograd.Function): |
| @staticmethod |
| def forward( |
| ctx: torch.autograd.function.FunctionCtx, |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None, |
| dropout_p: float, |
| is_causal: bool, |
| scale: float | None, |
| enable_gqa: bool, |
| return_lse: bool, |
| forward_op, |
| backward_op, |
| _parallel_config: "ParallelConfig" | None = None, |
| ): |
| ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh |
| world_size = _parallel_config.context_parallel_config.ulysses_degree |
| group = ulysses_mesh.get_group() |
|
|
| ctx.forward_op = forward_op |
| ctx.backward_op = backward_op |
| ctx._parallel_config = _parallel_config |
|
|
| B, S_Q_LOCAL, H, D = query.shape |
| _, S_KV_LOCAL, _, _ = key.shape |
| H_LOCAL = H // world_size |
| query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() |
| key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() |
| value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() |
| query, key, value = (_all_to_all_single(x, group) for x in (query, key, value)) |
| query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value)) |
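| # Each rank now holds the full sequence for its head shard: |
| # (B, S_LOCAL, H, D) -> (B, world_size * S_LOCAL, H // world_size, D). |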
|
|
| out = forward_op( |
| ctx, |
| query, |
| key, |
| value, |
| attn_mask, |
| dropout_p, |
| is_causal, |
| scale, |
| enable_gqa, |
| return_lse, |
| _save_ctx=True, |
| _parallel_config=_parallel_config, |
| ) |
| if return_lse: |
| out, lse, *_ = out |
|
|
| out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() |
| out = _all_to_all_single(out, group) |
| out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() |
|
|
| if return_lse: |
| lse = lse.reshape(B, world_size, S_Q_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous() |
| lse = _all_to_all_single(lse, group) |
| lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous() |
| else: |
| lse = None |
|
|
| return (out, lse) if return_lse else out |
|
|
| @staticmethod |
| def backward( |
| ctx: torch.autograd.function.FunctionCtx, |
| grad_out: torch.Tensor, |
| *args, |
| ): |
| ulysses_mesh = ctx._parallel_config.context_parallel_config._ulysses_mesh |
| world_size = ctx._parallel_config.context_parallel_config.ulysses_degree |
| group = ulysses_mesh.get_group() |
|
|
| B, S_LOCAL, H, D = grad_out.shape |
| H_LOCAL = H // world_size |
|
|
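| # Redistribute grad_out exactly as the forward redistributed q/k/v: gather the |
| # sequence dimension, shard the head dimension. |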
| grad_out = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() |
| grad_out = _all_to_all_single(grad_out, group) |
| grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous() |
|
|
| grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out) |
|
|
| grad_query, grad_key, grad_value = ( |
| x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() |
| for x in (grad_query_op, grad_key_op, grad_value_op) |
| ) |
| grad_query, grad_key, grad_value = (_all_to_all_single(x, group) for x in (grad_query, grad_key, grad_value)) |
| grad_query, grad_key, grad_value = ( |
| x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value) |
| ) |
|
|
| return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None |
|
|
|
|
| class TemplatedUlyssesAnythingAttention(torch.autograd.Function): |
| @staticmethod |
| def forward( |
| ctx: torch.autograd.function.FunctionCtx, |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor, |
| dropout_p: float, |
| is_causal: bool, |
| scale: float, |
| enable_gqa: bool, |
| return_lse: bool, |
| forward_op, |
| backward_op, |
| _parallel_config: "ParallelConfig" | None = None, |
| **kwargs, |
| ): |
| ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh |
| group = ulysses_mesh.get_group() |
|
|
| ctx.forward_op = forward_op |
| ctx.backward_op = backward_op |
| ctx._parallel_config = _parallel_config |
|
|
| metadata = ulysses_anything_metadata(query) |
| query_wait = all_to_all_single_any_qkv_async(query, group, **metadata) |
| key_wait = all_to_all_single_any_qkv_async(key, group, **metadata) |
| value_wait = all_to_all_single_any_qkv_async(value, group, **metadata) |
|
|
| query = query_wait() |
| key = key_wait() |
| value = value_wait() |
|
|
| out = forward_op( |
| ctx, |
| query, |
| key, |
| value, |
| attn_mask, |
| dropout_p, |
| is_causal, |
| scale, |
| enable_gqa, |
| return_lse, |
| _save_ctx=False, |
| _parallel_config=_parallel_config, |
| ) |
| if return_lse: |
| out, lse, *_ = out |
|
|
| |
| out_wait = all_to_all_single_any_o_async(out, group, **metadata) |
|
|
| if return_lse: |
| |
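| # The output all-to-all helper operates on 4D tensors, so lift the 3D lse |
| # (B, S, H) to (B, S, H, 1) here and squeeze the trailing dim back off below. |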
| lse = lse.unsqueeze(-1) |
| lse_wait = all_to_all_single_any_o_async(lse, group, **metadata) |
| out = out_wait() |
| lse = lse_wait() |
| lse = lse.squeeze(-1).contiguous() |
| else: |
| out = out_wait() |
| lse = None |
|
|
| return (out, lse) if return_lse else out |
|
|
| @staticmethod |
| def backward( |
| ctx: torch.autograd.function.FunctionCtx, |
| grad_out: torch.Tensor, |
| *args, |
| ): |
| raise NotImplementedError("Backward pass for Ulysses Anything Attention in diffusers is not implemented yet.") |
|
|
|
|
| def _templated_unified_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor, |
| dropout_p: float, |
| is_causal: bool, |
| scale: float, |
| enable_gqa: bool, |
| return_lse: bool, |
| forward_op, |
| backward_op, |
| _parallel_config: "ParallelConfig" | None = None, |
| scatter_idx: int = 2, |
| gather_idx: int = 1, |
| ): |
| """ |
| Unified Sequence Parallelism attention combining Ulysses and ring attention. See: https://arxiv.org/abs/2405.07719 |
| """ |
| ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh |
| ulysses_group = ulysses_mesh.get_group() |
|
|
| query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx) |
| key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx) |
| value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx) |
| out = TemplatedRingAttention.apply( |
| query, |
| key, |
| value, |
| attn_mask, |
| dropout_p, |
| is_causal, |
| scale, |
| enable_gqa, |
| return_lse, |
| forward_op, |
| backward_op, |
| _parallel_config, |
| ) |
| if return_lse: |
| context_layer, lse, *_ = out |
| else: |
| context_layer = out |
| |
| output = SeqAllToAllDim.apply( |
| ulysses_group, |
| context_layer, |
| gather_idx, |
| scatter_idx, |
| ) |
| if return_lse: |
| |
| |
| |
| if is_torch_version("<", "2.9.0"): |
| lse = lse.unsqueeze(-1) |
| lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx) |
| lse = lse.squeeze(-1) |
| return (output, lse) |
| return output |
|
|
|
|
| def _templated_context_parallel_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| dropout_p: float = 0.0, |
| is_causal: bool = False, |
| scale: float | None = None, |
| enable_gqa: bool = False, |
| return_lse: bool = False, |
| *, |
| forward_op, |
| backward_op, |
| _parallel_config: "ParallelConfig" | None = None, |
| ): |
| if is_causal: |
| raise ValueError("Causal attention is not yet supported for templated attention.") |
| if enable_gqa: |
| raise ValueError("GQA is not yet supported for templated attention.") |
|
|
| |
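| # Dispatch: if both the ring and Ulysses degrees exceed 1, use the unified (USP) |
| # path; otherwise use whichever single scheme is enabled. For example, 8 GPUs |
| # configured as ulysses_degree=4 and ring_degree=2 take the unified path. |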
| if ( |
| _parallel_config.context_parallel_config.ring_degree > 1 |
| and _parallel_config.context_parallel_config.ulysses_degree > 1 |
| ): |
| return _templated_unified_attention( |
| query, |
| key, |
| value, |
| attn_mask, |
| dropout_p, |
| is_causal, |
| scale, |
| enable_gqa, |
| return_lse, |
| forward_op, |
| backward_op, |
| _parallel_config, |
| ) |
| elif _parallel_config.context_parallel_config.ring_degree > 1: |
| return TemplatedRingAttention.apply( |
| query, |
| key, |
| value, |
| attn_mask, |
| dropout_p, |
| is_causal, |
| scale, |
| enable_gqa, |
| return_lse, |
| forward_op, |
| backward_op, |
| _parallel_config, |
| ) |
| elif _parallel_config.context_parallel_config.ulysses_degree > 1: |
| if _parallel_config.context_parallel_config.ulysses_anything: |
| |
| return TemplatedUlyssesAnythingAttention.apply( |
| query, |
| key, |
| value, |
| attn_mask, |
| dropout_p, |
| is_causal, |
| scale, |
| enable_gqa, |
| return_lse, |
| forward_op, |
| backward_op, |
| _parallel_config, |
| ) |
| else: |
| return TemplatedUlyssesAttention.apply( |
| query, |
| key, |
| value, |
| attn_mask, |
| dropout_p, |
| is_causal, |
| scale, |
| enable_gqa, |
| return_lse, |
| forward_op, |
| backward_op, |
| _parallel_config, |
| ) |
| else: |
| raise ValueError("Reaching this branch of code is unexpected. Please report a bug.") |
|
|
|
|
| |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName.FLASH, |
| constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| supports_context_parallel=True, |
| ) |
| def _flash_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| dropout_p: float = 0.0, |
| is_causal: bool = False, |
| scale: float | None = None, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| lse = None |
| if attn_mask is not None: |
| raise ValueError("`attn_mask` is not supported for flash-attn 2.") |
|
|
| if _parallel_config is None: |
| out = flash_attn_func( |
| q=query, |
| k=key, |
| v=value, |
| dropout_p=dropout_p, |
| softmax_scale=scale, |
| causal=is_causal, |
| return_attn_probs=return_lse, |
| ) |
| if return_lse: |
| out, lse, *_ = out |
| else: |
| out = _templated_context_parallel_attention( |
| query, |
| key, |
| value, |
| None, |
| dropout_p, |
| is_causal, |
| scale, |
| False, |
| return_lse, |
| forward_op=_flash_attention_forward_op, |
| backward_op=_flash_attention_backward_op, |
| _parallel_config=_parallel_config, |
| ) |
| if return_lse: |
| out, lse = out |
|
|
| return (out, lse) if return_lse else out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName.FLASH_HUB, |
| constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| supports_context_parallel=True, |
| ) |
| def _flash_attention_hub( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| dropout_p: float = 0.0, |
| is_causal: bool = False, |
| scale: float | None = None, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| lse = None |
| if attn_mask is not None: |
| raise ValueError("`attn_mask` is not supported for flash-attn 2.") |
|
|
| func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn |
| if _parallel_config is None: |
| out = func( |
| q=query, |
| k=key, |
| v=value, |
| dropout_p=dropout_p, |
| softmax_scale=scale, |
| causal=is_causal, |
| return_attn_probs=return_lse, |
| ) |
| if return_lse: |
| out, lse, *_ = out |
| else: |
| out = _templated_context_parallel_attention( |
| query, |
| key, |
| value, |
| None, |
| dropout_p, |
| is_causal, |
| scale, |
| False, |
| return_lse, |
| forward_op=_flash_attention_hub_forward_op, |
| backward_op=_flash_attention_hub_backward_op, |
| _parallel_config=_parallel_config, |
| ) |
| if return_lse: |
| out, lse = out |
|
|
| return (out, lse) if return_lse else out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName.FLASH_VARLEN_HUB, |
| constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| supports_context_parallel=False, |
| ) |
| def _flash_varlen_attention_hub( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| dropout_p: float = 0.0, |
| scale: float | None = None, |
| is_causal: bool = False, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| batch_size, seq_len_q, _, _ = query.shape |
| _, seq_len_kv, _, _ = key.shape |
|
|
| if attn_mask is not None: |
| attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) |
|
|
| (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( |
| _prepare_for_flash_attn_or_sage_varlen( |
| batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device |
| ) |
| ) |
|
|
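| # Pack only the valid (unmasked) K/V tokens of each batch entry into one ragged |
| # sequence; cu_seqlens_q / cu_seqlens_k give the varlen kernel the per-entry offsets. |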
| key_valid, value_valid = [], [] |
| for b in range(batch_size): |
| valid_len = seqlens_k[b] |
| key_valid.append(key[b, :valid_len]) |
| value_valid.append(value[b, :valid_len]) |
|
|
| query_packed = query.flatten(0, 1) |
| key_packed = torch.cat(key_valid, dim=0) |
| value_packed = torch.cat(value_valid, dim=0) |
|
|
| func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn |
| out = func( |
| q=query_packed, |
| k=key_packed, |
| v=value_packed, |
| cu_seqlens_q=cu_seqlens_q, |
| cu_seqlens_k=cu_seqlens_k, |
| max_seqlen_q=max_seqlen_q, |
| max_seqlen_k=max_seqlen_k, |
| dropout_p=dropout_p, |
| softmax_scale=scale, |
| causal=is_causal, |
| return_attn_probs=return_lse, |
| ) |
| if return_lse: |
| out, lse, *_ = out |
| else: |
| lse = None |
| out = out.unflatten(0, (batch_size, -1)) |
|
|
| return (out, lse) if return_lse else out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName.FLASH_VARLEN, |
| constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| ) |
| def _flash_varlen_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| dropout_p: float = 0.0, |
| scale: float | None = None, |
| is_causal: bool = False, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| batch_size, seq_len_q, _, _ = query.shape |
| _, seq_len_kv, _, _ = key.shape |
|
|
| if attn_mask is not None: |
| attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) |
|
|
| (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( |
| _prepare_for_flash_attn_or_sage_varlen( |
| batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device |
| ) |
| ) |
|
|
| key_valid, value_valid = [], [] |
| for b in range(batch_size): |
| valid_len = seqlens_k[b] |
| key_valid.append(key[b, :valid_len]) |
| value_valid.append(value[b, :valid_len]) |
|
|
| query_packed = query.flatten(0, 1) |
| key_packed = torch.cat(key_valid, dim=0) |
| value_packed = torch.cat(value_valid, dim=0) |
|
|
| out = flash_attn_varlen_func( |
| q=query_packed, |
| k=key_packed, |
| v=value_packed, |
| cu_seqlens_q=cu_seqlens_q, |
| cu_seqlens_k=cu_seqlens_k, |
| max_seqlen_q=max_seqlen_q, |
| max_seqlen_k=max_seqlen_k, |
| dropout_p=dropout_p, |
| softmax_scale=scale, |
| causal=is_causal, |
| return_attn_probs=return_lse, |
| ) |
| if return_lse: |
| out, lse, *_ = out |
| else: |
| lse = None |
| out = out.unflatten(0, (batch_size, -1)) |
|
|
| return (out, lse) if return_lse else out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName._FLASH_3, |
| constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| ) |
| def _flash_attention_3( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| scale: float | None = None, |
| is_causal: bool = False, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if attn_mask is not None: |
| raise ValueError("`attn_mask` is not supported for flash-attn 3.") |
|
|
| out, lse = _wrapped_flash_attn_3( |
| q=query, |
| k=key, |
| v=value, |
| softmax_scale=scale, |
| causal=is_causal, |
| ) |
| return (out, lse) if return_lse else out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName._FLASH_3_HUB, |
| constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| supports_context_parallel=True, |
| ) |
| def _flash_attention_3_hub( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| scale: float | None = None, |
| is_causal: bool = False, |
| window_size: tuple[int, int] = (-1, -1), |
| softcap: float = 0.0, |
| deterministic: bool = False, |
| return_attn_probs: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if attn_mask is not None: |
| raise ValueError("`attn_mask` is not supported for flash-attn 3.") |
|
|
| func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn |
| if _parallel_config is None: |
| out = func( |
| q=query, |
| k=key, |
| v=value, |
| softmax_scale=scale, |
| causal=is_causal, |
| qv=None, |
| q_descale=None, |
| k_descale=None, |
| v_descale=None, |
| window_size=window_size, |
| softcap=softcap, |
| num_splits=1, |
| pack_gqa=None, |
| deterministic=deterministic, |
| sm_margin=0, |
| return_attn_probs=return_attn_probs, |
| ) |
| return (out[0], out[1]) if return_attn_probs else out |
|
|
| forward_op = functools.partial( |
| _flash_attention_3_hub_forward_op, |
| window_size=window_size, |
| softcap=softcap, |
| num_splits=1, |
| pack_gqa=None, |
| deterministic=deterministic, |
| sm_margin=0, |
| ) |
| backward_op = functools.partial( |
| _flash_attention_3_hub_backward_op, |
| window_size=window_size, |
| softcap=softcap, |
| num_splits=1, |
| pack_gqa=None, |
| deterministic=deterministic, |
| sm_margin=0, |
| ) |
| out = _templated_context_parallel_attention( |
| query, |
| key, |
| value, |
| None, |
| 0.0, |
| is_causal, |
| scale, |
| False, |
| return_attn_probs, |
| forward_op=forward_op, |
| backward_op=backward_op, |
| _parallel_config=_parallel_config, |
| ) |
| if return_attn_probs: |
| out, lse = out |
| return out, lse |
|
|
| return out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName._FLASH_3_VARLEN_HUB, |
| constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| supports_context_parallel=False, |
| ) |
| def _flash_attention_3_varlen_hub( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| scale: float | None = None, |
| is_causal: bool = False, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| batch_size, seq_len_q, _, _ = query.shape |
| _, seq_len_kv, _, _ = key.shape |
|
|
| if attn_mask is not None: |
| attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) |
|
|
| (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( |
| _prepare_for_flash_attn_or_sage_varlen( |
| batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device |
| ) |
| ) |
|
|
| key_valid, value_valid = [], [] |
| for b in range(batch_size): |
| valid_len = seqlens_k[b] |
| key_valid.append(key[b, :valid_len]) |
| value_valid.append(value[b, :valid_len]) |
|
|
| query_packed = query.flatten(0, 1) |
| key_packed = torch.cat(key_valid, dim=0) |
| value_packed = torch.cat(value_valid, dim=0) |
|
|
| func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn |
| out, lse, *_ = func( |
| q=query_packed, |
| k=key_packed, |
| v=value_packed, |
| cu_seqlens_q=cu_seqlens_q, |
| cu_seqlens_k=cu_seqlens_k, |
| max_seqlen_q=max_seqlen_q, |
| max_seqlen_k=max_seqlen_k, |
| softmax_scale=scale, |
| causal=is_causal, |
| ) |
| out = out.unflatten(0, (batch_size, -1)) |
|
|
| return (out, lse) if return_lse else out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName._FLASH_VARLEN_3, |
| constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| ) |
| def _flash_varlen_attention_3( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| scale: float | None = None, |
| is_causal: bool = False, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| batch_size, seq_len_q, _, _ = query.shape |
| _, seq_len_kv, _, _ = key.shape |
|
|
| if attn_mask is not None: |
| attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) |
|
|
| (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( |
| _prepare_for_flash_attn_or_sage_varlen( |
| batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device |
| ) |
| ) |
|
|
| key_valid, value_valid = [], [] |
| for b in range(batch_size): |
| valid_len = seqlens_k[b] |
| key_valid.append(key[b, :valid_len]) |
| value_valid.append(value[b, :valid_len]) |
|
|
| query_packed = query.flatten(0, 1) |
| key_packed = torch.cat(key_valid, dim=0) |
| value_packed = torch.cat(value_valid, dim=0) |
|
|
| result = flash_attn_3_varlen_func( |
| q=query_packed, |
| k=key_packed, |
| v=value_packed, |
| cu_seqlens_q=cu_seqlens_q, |
| cu_seqlens_k=cu_seqlens_k, |
| max_seqlen_q=max_seqlen_q, |
| max_seqlen_k=max_seqlen_k, |
| softmax_scale=scale, |
| causal=is_causal, |
| return_attn_probs=return_lse, |
| ) |
| if isinstance(result, tuple): |
| out, lse, *_ = result |
| else: |
| out = result |
| lse = None |
| out = out.unflatten(0, (batch_size, -1)) |
|
|
| return (out, lse) if return_lse else out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName.AITER, |
| constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| ) |
| def _aiter_flash_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| dropout_p: float = 0.0, |
| is_causal: bool = False, |
| scale: float | None = None, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if attn_mask is not None: |
| raise ValueError("`attn_mask` is not supported for aiter attention") |
|
|
| if not return_lse and torch.is_grad_enabled(): |
| |
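| # aiter needs the LSE for its backward pass, so request it whenever gradients are |
| # being tracked even if the caller did not ask for it, then drop it on return. |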
| out, lse, *_ = aiter_flash_attn_func( |
| q=query, |
| k=key, |
| v=value, |
| dropout_p=dropout_p, |
| softmax_scale=scale, |
| causal=is_causal, |
| return_lse=True, |
| ) |
| else: |
| out = aiter_flash_attn_func( |
| q=query, |
| k=key, |
| v=value, |
| dropout_p=dropout_p, |
| softmax_scale=scale, |
| causal=is_causal, |
| return_lse=return_lse, |
| ) |
| if return_lse: |
| out, lse, *_ = out |
|
|
| return (out, lse) if return_lse else out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName.FLEX, |
| constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], |
| ) |
| def _native_flex_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | "flex_attention.BlockMask" | None = None, |
| is_causal: bool = False, |
| scale: float | None = None, |
| enable_gqa: bool = False, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| |
| score_mod = None |
| block_mask = None |
| batch_size, seq_len_q, num_heads, _ = query.shape |
| _, seq_len_kv, _, _ = key.shape |
|
|
| if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask): |
| block_mask = attn_mask |
| elif is_causal: |
| block_mask = flex_attention.create_block_mask( |
| _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device |
| ) |
| elif torch.is_tensor(attn_mask): |
| if attn_mask.ndim == 2: |
| # A 2D mask is interpreted as (batch, seq_len_kv) and lifted to (batch, 1, 1, seq_len_kv). |
| attn_mask = attn_mask.view(attn_mask.size(0), 1, 1, attn_mask.size(1)) |
|
|
| attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv) |
|
|
| if attn_mask.dtype == torch.bool: |
| |
| def mask_mod(batch_idx, head_idx, q_idx, kv_idx): |
| return attn_mask[batch_idx, head_idx, q_idx, kv_idx] |
|
|
| block_mask = flex_attention.create_block_mask( |
| mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device |
| ) |
| else: |
|
|
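| # Float (additive) masks cannot be represented as a BlockMask, so apply them to |
| # the raw attention scores through score_mod instead. |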
| def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): |
| return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx] |
| else: |
| raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.") |
|
|
| query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) |
| out = flex_attention.flex_attention( |
| query=query, |
| key=key, |
| value=value, |
| score_mod=score_mod, |
| block_mask=block_mask, |
| scale=scale, |
| enable_gqa=enable_gqa, |
| return_lse=return_lse, |
| ) |
| lse = None |
| if return_lse: |
| out, lse = out |
| out = out.permute(0, 2, 1, 3) |
| return (out, lse) if return_lse else out |
|
|
|
|
| def _prepare_additive_attn_mask( |
| attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True |
| ) -> torch.Tensor: |
| """ |
| Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA. |
| |
| This helper is used by both native SDPA and xformers backends to handle both boolean and additive masks. |
| |
| Args: |
| attn_mask: 2D tensor [batch_size, seq_len_k] |
| - Boolean: True means attend, False means mask out |
| - Additive: 0.0 means attend, -inf means mask out |
| target_dtype: The dtype to convert the mask to (usually query.dtype) |
| reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting |
| |
| Returns: |
| Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if |
| reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True. |
| """ |
| |
| if attn_mask.dtype == torch.bool: |
| |
| attn_mask = torch.where(attn_mask, 0.0, float("-inf")) |
| |
| attn_mask = attn_mask.to(dtype=target_dtype) |
| else: |
| |
| attn_mask = attn_mask.to(dtype=target_dtype) |
|
|
| |
| if reshape_4d: |
| batch_size, seq_len_k = attn_mask.shape |
| attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k) |
|
|
| return attn_mask |
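|
| # Sketch with hypothetical values: a boolean mask [[True, False]] becomes the additive |
| # mask [[[[0., -inf]]]] under reshape_4d=True, ready to broadcast over heads and query |
| # positions in scaled_dot_product_attention. |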
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName.NATIVE, |
| constraints=[_check_device, _check_shape], |
| supports_context_parallel=True, |
| ) |
| def _native_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| dropout_p: float = 0.0, |
| is_causal: bool = False, |
| scale: float | None = None, |
| enable_gqa: bool = False, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if return_lse: |
| raise ValueError("Native attention backend does not support setting `return_lse=True`.") |
|
|
| |
| |
| if ( |
| attn_mask is not None |
| and attn_mask.ndim == 2 |
| and attn_mask.shape[0] == query.shape[0] |
| and attn_mask.shape[1] == key.shape[1] |
| ): |
| |
| |
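| # Lift a (batch, seq_len_kv) padding mask to (batch, 1, 1, seq_len_kv) so SDPA can |
| # broadcast it over heads and query positions. |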
| attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) |
|
|
| if _parallel_config is None: |
| query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) |
| out = torch.nn.functional.scaled_dot_product_attention( |
| query=query, |
| key=key, |
| value=value, |
| attn_mask=attn_mask, |
| dropout_p=dropout_p, |
| is_causal=is_causal, |
| scale=scale, |
| enable_gqa=enable_gqa, |
| ) |
| out = out.permute(0, 2, 1, 3) |
| else: |
| out = _templated_context_parallel_attention( |
| query, |
| key, |
| value, |
| attn_mask, |
| dropout_p, |
| is_causal, |
| scale, |
| enable_gqa, |
| return_lse, |
| forward_op=_native_attention_forward_op, |
| backward_op=_native_attention_backward_op, |
| _parallel_config=_parallel_config, |
| ) |
|
|
| return out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName._NATIVE_CUDNN, |
| constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| supports_context_parallel=True, |
| ) |
| def _native_cudnn_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| dropout_p: float = 0.0, |
| is_causal: bool = False, |
| scale: float | None = None, |
| enable_gqa: bool = False, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| lse = None |
| if _parallel_config is None and not return_lse: |
| query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value)) |
| with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): |
| out = torch.nn.functional.scaled_dot_product_attention( |
| query=query, |
| key=key, |
| value=value, |
| attn_mask=attn_mask, |
| dropout_p=dropout_p, |
| is_causal=is_causal, |
| scale=scale, |
| enable_gqa=enable_gqa, |
| ) |
| out = out.permute(0, 2, 1, 3) |
| else: |
| out = _templated_context_parallel_attention( |
| query, |
| key, |
| value, |
| attn_mask, |
| dropout_p, |
| is_causal, |
| scale, |
| enable_gqa, |
| return_lse, |
| forward_op=_cudnn_attention_forward_op, |
| backward_op=_cudnn_attention_backward_op, |
| _parallel_config=_parallel_config, |
| ) |
| if return_lse: |
| out, lse = out |
|
|
| return (out, lse) if return_lse else out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName._NATIVE_EFFICIENT, |
| constraints=[_check_device, _check_shape], |
| ) |
| def _native_efficient_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| dropout_p: float = 0.0, |
| is_causal: bool = False, |
| scale: float | None = None, |
| enable_gqa: bool = False, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if return_lse: |
| raise ValueError("Native efficient attention backend does not support setting `return_lse=True`.") |
| query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) |
| with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): |
| out = torch.nn.functional.scaled_dot_product_attention( |
| query=query, |
| key=key, |
| value=value, |
| attn_mask=attn_mask, |
| dropout_p=dropout_p, |
| is_causal=is_causal, |
| scale=scale, |
| enable_gqa=enable_gqa, |
| ) |
| out = out.permute(0, 2, 1, 3) |
| return out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName._NATIVE_FLASH, |
| constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| supports_context_parallel=True, |
| ) |
| def _native_flash_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| dropout_p: float = 0.0, |
| is_causal: bool = False, |
| scale: float | None = None, |
| enable_gqa: bool = False, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if attn_mask is not None: |
| raise ValueError("`attn_mask` is not supported for native flash attention") |
|
|
| lse = None |
| if _parallel_config is None and not return_lse: |
| query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) |
| with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): |
| out = torch.nn.functional.scaled_dot_product_attention( |
| query=query, |
| key=key, |
| value=value, |
| attn_mask=None, |
| dropout_p=dropout_p, |
| is_causal=is_causal, |
| scale=scale, |
| enable_gqa=enable_gqa, |
| ) |
| out = out.permute(0, 2, 1, 3) |
| else: |
| out = _templated_context_parallel_attention( |
| query, |
| key, |
| value, |
| None, |
| dropout_p, |
| is_causal, |
| scale, |
| enable_gqa, |
| return_lse, |
| forward_op=_native_flash_attention_forward_op, |
| backward_op=_native_flash_attention_backward_op, |
| _parallel_config=_parallel_config, |
| ) |
| if return_lse: |
| out, lse = out |
|
|
| return (out, lse) if return_lse else out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName._NATIVE_MATH, |
| constraints=[_check_device, _check_shape], |
| ) |
| def _native_math_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| dropout_p: float = 0.0, |
| is_causal: bool = False, |
| scale: float | None = None, |
| enable_gqa: bool = False, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if return_lse: |
| raise ValueError("Native math attention backend does not support setting `return_lse=True`.") |
| query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) |
| with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): |
| out = torch.nn.functional.scaled_dot_product_attention( |
| query=query, |
| key=key, |
| value=value, |
| attn_mask=attn_mask, |
| dropout_p=dropout_p, |
| is_causal=is_causal, |
| scale=scale, |
| enable_gqa=enable_gqa, |
| ) |
| out = out.permute(0, 2, 1, 3) |
| return out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName._NATIVE_NPU, |
| constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| supports_context_parallel=True, |
| ) |
| def _native_npu_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| dropout_p: float = 0.0, |
| scale: float | None = None, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if return_lse: |
| raise ValueError("NPU attention backend does not support setting `return_lse=True`.") |
| if _parallel_config is None: |
| attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask) |
|
|
| out = npu_fusion_attention( |
| query, |
| key, |
| value, |
| query.size(2), |
| atten_mask=attn_mask, |
| input_layout="BSND", |
| pse=None, |
| scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, |
| pre_tockens=65536, |
| next_tockens=65536, |
| keep_prob=1.0 - dropout_p, |
| sync=False, |
| inner_precise=0, |
| )[0] |
| else: |
| out = _templated_context_parallel_attention( |
| query, |
| key, |
| value, |
| attn_mask, |
| dropout_p, |
| False, |
| scale, |
| False, |
| return_lse, |
| forward_op=_npu_attention_forward_op, |
| backward_op=_npu_attention_backward_op, |
| _parallel_config=_parallel_config, |
| ) |
| return out |
|
|
|
|
| |
| @_AttentionBackendRegistry.register( |
| AttentionBackendName._NATIVE_XLA, |
| constraints=[_check_device, _check_shape], |
| ) |
| def _native_xla_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| is_causal: bool = False, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if attn_mask is not None: |
| raise ValueError("`attn_mask` is not supported for XLA attention") |
| if return_lse: |
| raise ValueError("XLA attention backend does not support setting `return_lse=True`.") |
| query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) |
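| # The XLA kernel is invoked without a softmax scale here, so fold 1 / sqrt(head_dim) |
| # into the query up front. |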
| query = query / math.sqrt(query.shape[-1]) |
| out = xla_flash_attention( |
| q=query, |
| k=key, |
| v=value, |
| causal=is_causal, |
| ) |
| out = out.permute(0, 2, 1, 3) |
| return out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName.SAGE, |
| constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| supports_context_parallel=True, |
| ) |
| def _sage_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| is_causal: bool = False, |
| scale: float | None = None, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if attn_mask is not None: |
| raise ValueError("`attn_mask` is not supported for sage attention") |
| lse = None |
| if _parallel_config is None: |
| out = sageattn( |
| q=query, |
| k=key, |
| v=value, |
| tensor_layout="NHD", |
| is_causal=is_causal, |
| sm_scale=scale, |
| return_lse=return_lse, |
| ) |
| if return_lse: |
| out, lse, *_ = out |
| else: |
| out = _templated_context_parallel_attention( |
| query, |
| key, |
| value, |
| None, |
| 0.0, |
| is_causal, |
| scale, |
| False, |
| return_lse, |
| forward_op=_sage_attention_forward_op, |
| backward_op=_sage_attention_backward_op, |
| _parallel_config=_parallel_config, |
| ) |
| if return_lse: |
| out, lse = out |
|
|
| return (out, lse) if return_lse else out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName.SAGE_HUB, |
| constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| supports_context_parallel=True, |
| ) |
| def _sage_attention_hub( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| is_causal: bool = False, |
| scale: float | None = None, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if attn_mask is not None: |
| raise ValueError("`attn_mask` is not supported for sage attention") |
| lse = None |
| func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn |
| if _parallel_config is None: |
| out = func( |
| q=query, |
| k=key, |
| v=value, |
| tensor_layout="NHD", |
| is_causal=is_causal, |
| sm_scale=scale, |
| return_lse=return_lse, |
| ) |
| if return_lse: |
| out, lse, *_ = out |
| else: |
| out = _templated_context_parallel_attention( |
| query, |
| key, |
| value, |
| None, |
| 0.0, |
| is_causal, |
| scale, |
| False, |
| return_lse, |
| forward_op=_sage_attention_hub_forward_op, |
| backward_op=_sage_attention_backward_op, |
| _parallel_config=_parallel_config, |
| ) |
| if return_lse: |
| out, lse = out |
|
|
| return (out, lse) if return_lse else out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName.SAGE_VARLEN, |
| constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| ) |
| def _sage_varlen_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| is_causal: bool = False, |
| scale: float | None = None, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if return_lse: |
| raise ValueError("Sage varlen backend does not support setting `return_lse=True`.") |
|
|
| batch_size, seq_len_q, _, _ = query.shape |
| _, seq_len_kv, _, _ = key.shape |
|
|
| if attn_mask is not None: |
| attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) |
|
|
| (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( |
| _prepare_for_flash_attn_or_sage_varlen( |
| batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device |
| ) |
| ) |
|
|
| key_valid, value_valid = [], [] |
| for b in range(batch_size): |
| valid_len = seqlens_k[b] |
| key_valid.append(key[b, :valid_len]) |
| value_valid.append(value[b, :valid_len]) |
|
|
| query_packed = query.flatten(0, 1) |
| key_packed = torch.cat(key_valid, dim=0) |
| value_packed = torch.cat(value_valid, dim=0) |
|
|
| out = sageattn_varlen( |
| q=query_packed, |
| k=key_packed, |
| v=value_packed, |
| cu_seqlens_q=cu_seqlens_q, |
| cu_seqlens_k=cu_seqlens_k, |
| max_seqlen_q=max_seqlen_q, |
| max_seqlen_k=max_seqlen_k, |
| is_causal=is_causal, |
| sm_scale=scale, |
| ) |
| out = out.unflatten(0, (batch_size, -1)) |
|
|
| return out |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, |
| constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], |
| ) |
| def _sage_qk_int8_pv_fp8_cuda_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| is_causal: bool = False, |
| scale: float | None = None, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if attn_mask is not None: |
| raise ValueError("`attn_mask` is not supported for sage attention") |
| return sageattn_qk_int8_pv_fp8_cuda( |
| q=query, |
| k=key, |
| v=value, |
| tensor_layout="NHD", |
| is_causal=is_causal, |
| sm_scale=scale, |
| return_lse=return_lse, |
| ) |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, |
| constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], |
| ) |
| def _sage_qk_int8_pv_fp8_cuda_sm90_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| is_causal: bool = False, |
| scale: float | None = None, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if attn_mask is not None: |
| raise ValueError("`attn_mask` is not supported for sage attention") |
| return sageattn_qk_int8_pv_fp8_cuda_sm90( |
| q=query, |
| k=key, |
| v=value, |
| tensor_layout="NHD", |
| is_causal=is_causal, |
| sm_scale=scale, |
| return_lse=return_lse, |
| ) |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, |
| constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], |
| ) |
| def _sage_qk_int8_pv_fp16_cuda_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| is_causal: bool = False, |
| scale: float | None = None, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if attn_mask is not None: |
| raise ValueError("`attn_mask` is not supported for sage attention") |
| return sageattn_qk_int8_pv_fp16_cuda( |
| q=query, |
| k=key, |
| v=value, |
| tensor_layout="NHD", |
| is_causal=is_causal, |
| sm_scale=scale, |
| return_lse=return_lse, |
| ) |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, |
| constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], |
| ) |
| def _sage_qk_int8_pv_fp16_triton_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| is_causal: bool = False, |
| scale: float | None = None, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if attn_mask is not None: |
| raise ValueError("`attn_mask` is not supported for sage attention") |
| return sageattn_qk_int8_pv_fp16_triton( |
| q=query, |
| k=key, |
| v=value, |
| tensor_layout="NHD", |
| is_causal=is_causal, |
| sm_scale=scale, |
| return_lse=return_lse, |
| ) |
|
|
|
|
| @_AttentionBackendRegistry.register( |
| AttentionBackendName.XFORMERS, |
| constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], |
| ) |
| def _xformers_attention( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: torch.Tensor | None = None, |
| dropout_p: float = 0.0, |
| is_causal: bool = False, |
| scale: float | None = None, |
| enable_gqa: bool = False, |
| return_lse: bool = False, |
| _parallel_config: "ParallelConfig" | None = None, |
| ) -> torch.Tensor: |
| if return_lse: |
| raise ValueError("xformers attention backend does not support setting `return_lse=True`.") |
|
|
| batch_size, seq_len_q, num_heads_q, _ = query.shape |
| _, seq_len_kv, num_heads_kv, _ = key.shape |
|
|
| if is_causal: |
| attn_mask = xops.LowerTriangularMask() |
| elif attn_mask is not None: |
| if attn_mask.ndim == 2: |
| |
| |
| |
| |
| original_seq_len = attn_mask.size(1) |
| aligned_seq_len = ((original_seq_len + 7) // 8) * 8 |
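| # xformers requires the bias tensor's last dimension to be backed by a multiple-of-8 |
| # stride, so allocate an aligned buffer and slice a view back to seq_len_kv below. |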
|
|
| |
| aligned_mask = torch.zeros( |
| (batch_size, num_heads_q, seq_len_q, aligned_seq_len), |
| dtype=query.dtype, |
| device=query.device, |
| ) |
| |
| mask_additive = _prepare_additive_attn_mask( |
| attn_mask, target_dtype=query.dtype |
| ) |
| |
| aligned_mask[:, :, :, :original_seq_len] = mask_additive |
| |
| aligned_mask[:, :, :, original_seq_len:] = float("-inf") |
|
|
| |
| attn_mask = aligned_mask[:, :, :, :seq_len_kv] |
| elif attn_mask.ndim != 4: |
| raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") |
| else: |
| attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) |
|
|
| if enable_gqa: |
| if num_heads_q % num_heads_kv != 0: |
| raise ValueError("Number of heads in query must be divisible by number of heads in key/value.") |
| num_heads_per_group = num_heads_q // num_heads_kv |
| query = query.unflatten(2, (num_heads_kv, -1)) |
| key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) |
| value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) |
|
|
| out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale) |
|
|
| if enable_gqa: |
| out = out.flatten(2, 3) |
|
|
| return out |
|
|