import math
from functools import lru_cache
from typing import Optional, Tuple, TypeVar

import torch
import torch.nn as nn
import triton

from .triton_flash_blocksparse_attn import (
    get_local_strided_sparse_attention_op,
    _get_sparse_attn_mask,
    blocksparse_flash_attn_padded_fwd,
    blocksparse_flash_attn_varlen_fwd,
)


Layout = Tuple[torch.LongTensor, torch.LongTensor]


def create_sparse_attn_mask(
    n_heads: int,
    max_seq_len: int,
    max_seq_len_k: int,
    dtype: torch.dtype,
    device: torch.device,
    BLOCK: int,
    local_blocks: int,
    vert_stride: int,
    homo_head: bool,
    return_dense: bool,
) -> Tuple[Layout, torch.Tensor]:
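    """Build the block-sparse layout and block-level mask for local + vertical-stride attention.

    Thin wrapper around ``_get_sparse_attn_mask`` that keeps only the
    ``(layout, block_sparse_pattern)`` pair and drops the third return value
    (the dense mask, if requested via ``return_dense``).
    """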
    layout, block_sparse_pattern, _ = _get_sparse_attn_mask(
        n_heads=n_heads,
        q_len=max_seq_len,
        N_CTX=max_seq_len_k,
        dtype=dtype,
        device=device,
        BLOCK=BLOCK,
        local_blocks=local_blocks,
        vert_stride=vert_stride,
        homo_head=homo_head,
        return_dense=return_dense,
    )
    return layout, block_sparse_pattern


class BlockSparseAttentionLayer(nn.Module):
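    """Block-sparse (local + vertical-stride) attention layer backed by Triton kernels.

    ``forward`` dispatches to one of three fused paths: a dense-batch path
    (no padding/length arguments), a left-padded batched-inference path
    (4-dim k/v with ``left_paddings``/``seqlens``) and a variable-length
    packed-inference path (3-dim k/v with ``cu_seqlens_k``).

    Minimal usage sketch (the ``(batch, n_heads, seq_len, head_dim)`` layout
    shown here is an assumption; check the Triton kernel contracts for the
    exact expected shapes):

        attn = BlockSparseAttentionLayer(
            n_heads=32, max_seq_len=8192, sparse_block_size=64,
            local_blocks=16, vert_stride=8,
        )
        out = attn(q, k, v, sm_scale=q.size(-1) ** -0.5)
    """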

    def __init__(
        self,
        n_heads: int,
        max_seq_len: int,
        sparse_block_size: int,
        local_blocks: int,
        vert_stride: int,
        kernel_block_size: Optional[int] = None,
        homo_head: bool = False,
        active_head_range: Optional[Tuple[int, int]] = None,
    ) -> None:
        super().__init__()

        self.n_heads = n_heads
        self.max_seq_len = max_seq_len
        self.sparse_block_size = sparse_block_size
        self.kernel_block_size = kernel_block_size or sparse_block_size
        self.local_blocks = local_blocks
        self.vert_stride = vert_stride
        self.homo_head = homo_head
        self.active_head_range = active_head_range
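
        # The block-sparse layout and block mask are built lazily on the first
        # padded / variable-length forward call, once the dtype and device of
        # the incoming tensors are known (see `_initialize_internals`).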
        self._sparse_block_mask = None
        self._sparse_layout = None
        self._dtype = None
        self._device = None

    def prune_blocksparse_layout_to_heads(self, h_start: int, h_end: int) -> None:
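        """Restrict the cached block mask and layout to heads [h_start, h_end)."""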
        self._sparse_block_mask = self._sparse_block_mask[h_start:h_end]
        # ``Layout`` is a tuple, so rebuild it rather than assigning to its items.
        self._sparse_layout = (
            self._sparse_layout[0][h_start:h_end],
            self._sparse_layout[1][h_start:h_end],
        )

    def _initialize_internals(
        self,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
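        """Build and cache the sparse layout and block mask for the given dtype/device.

        Optionally prunes the layout to ``active_head_range`` and, when the sparse
        block size is larger than the kernel block size, expands the block mask to
        kernel-block granularity.
        """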
        self._dtype, self._device = dtype, device
        self._sparse_layout, self._sparse_block_mask = create_sparse_attn_mask(
            n_heads=self.n_heads,
            max_seq_len=self.max_seq_len,
            max_seq_len_k=self.max_seq_len,
            dtype=dtype,
            device=device,
            BLOCK=self.sparse_block_size,
            local_blocks=self.local_blocks,
            vert_stride=self.vert_stride,
            homo_head=self.homo_head,
            return_dense=False,
        )
        if (not self.homo_head) and (self.active_head_range is not None):
            assert len(self.active_head_range) == 2, \
                '"active_head_range" should be a tuple of start/end head indices.'
            h_start, h_end = self.active_head_range
            self.prune_blocksparse_layout_to_heads(h_start=h_start, h_end=h_end)

        assert self.sparse_block_size % self.kernel_block_size == 0, (
            f"The sparse block size must be a multiple of the kernel block size "
            f"({self.kernel_block_size}), but got {self.sparse_block_size}."
        )
        assert self.kernel_block_size >= 16 and math.log2(self.kernel_block_size) % 1 == 0, (
            f"kernel_block_size must be a power of 2 and at least 16, "
            f"but {self.kernel_block_size} was given."
        )
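
        # If the sparse pattern is defined on blocks coarser than the Triton kernel
        # block, expand each sparse block into a grid of kernel-sized blocks and
        # re-apply the block-level causal mask at the finer granularity.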
        if self.sparse_block_size // self.kernel_block_size > 1:
            _mul = self.sparse_block_size // self.kernel_block_size
            self._sparse_block_mask = torch.kron(
                self._sparse_block_mask,
                self._sparse_block_mask.new_ones(_mul, _mul),
            )
            num_sparse_blocks = self._sparse_block_mask.size(-1)
            block_causal_mask = (
                torch.arange(0, num_sparse_blocks)[:, None]
                >= torch.arange(0, num_sparse_blocks)[None]
            )
            self._sparse_block_mask *= block_causal_mask.type_as(self._sparse_block_mask)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sm_scale: float,
        *,
        left_paddings: Optional[torch.LongTensor] = None,
        seqlens: Optional[torch.LongTensor] = None,
        cu_seqlens_k: Optional[torch.LongTensor] = None,
        cu_seqlens_q: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
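        """Compute block-sparse attention over q, k and v.

        Dispatch (as implemented below):
        - no padding/length arguments: fused dense-batch kernel (the only path
          that supports training);
        - 3-dim k/v plus ``cu_seqlens_k`` (and optionally ``cu_seqlens_q``):
          variable-length packed inference;
        - 4-dim k/v plus ``left_paddings`` and/or ``seqlens``: left-padded
          batched inference.
        """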
        if left_paddings is None and seqlens is None and cu_seqlens_k is None and cu_seqlens_q is None:
            blocksparse_op = get_local_strided_sparse_attention_op(
                n_heads=self.n_heads,
                max_seq_len=self.max_seq_len,
                sparse_block_size=self.sparse_block_size,
                kernel_block_size=self.kernel_block_size,
                local_blocks=self.local_blocks,
                vert_stride=self.vert_stride,
                homo_head=self.homo_head,
                device=q.device,
                inference=not self.training,
            )
            return blocksparse_op(q, k, v, sm_scale)

        assert not torch.is_grad_enabled(), (
            "Variable-length / batched inference is not supported during training. "
            "Please run it under a torch.no_grad() context."
        )

        if self._sparse_block_mask is None or (self._dtype != q.dtype) or (self._device != q.device):
            self._initialize_internals(dtype=q.dtype, device=q.device)

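        # Variable-length path: 3-dim q/k/v with sequences packed along the token
        # dimension and delimited by the cumulative lengths in cu_seqlens_k/q.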
        if k.dim() == 3:
            assert cu_seqlens_k is not None
            return blocksparse_flash_attn_varlen_fwd(
                q=q,
                k=k,
                v=v,
                cu_seqlens_k=cu_seqlens_k,
                cu_seqlens_q=cu_seqlens_q,
                sm_scale=sm_scale,
                sparse_layout=self._sparse_layout,
                block_size=self.kernel_block_size,
                max_seqlen=self.max_seq_len,
            )

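        # Padded path: 4-dim q/k/v with per-sequence left paddings and/or lengths.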
        if k.dim() == 4:
            assert not (left_paddings is None and seqlens is None), \
                "Either left_paddings or seqlens must be provided for batched inference."
            return blocksparse_flash_attn_padded_fwd(
                q=q,
                k=k,
                v=v,
                sm_scale=sm_scale,
                sparse_layout=self._sparse_layout,
                left_paddings=left_paddings,
                seqlens=seqlens,
                block_size=self.kernel_block_size,
                max_seqlen=self.max_seq_len,
            )

        raise ValueError("q/k/v must be 3-dim for variable-length input or 4-dim for fixed-length input.")