# Phi-3-small-128k-instruct / triton_blocksparse_attention_layer.py
import math
from typing import Optional, Tuple, TypeVar
import torch.nn as nn
import torch
import triton
from functools import lru_cache
from .triton_flash_blocksparse_attn import (
    get_local_strided_sparse_attention_op,
    _get_sparse_attn_mask,
    blocksparse_flash_attn_padded_fwd,
    blocksparse_flash_attn_varlen_fwd,
)

Layout = Tuple[torch.LongTensor, torch.LongTensor]


def create_sparse_attn_mask(
    n_heads: int,
    max_seq_len: int,
    max_seq_len_k: int,
    dtype: torch.dtype,
    device: torch.device,
    BLOCK: int,
    local_blocks: int,
    vert_stride: int,
    homo_head: bool,
    return_dense: bool,
) -> Tuple[Layout, torch.Tensor]:
    """Build the block-sparse (local + vertical-stride) attention layout and block mask."""
    layout, block_sparse_pattern, _ = _get_sparse_attn_mask(
        n_heads=n_heads,
        q_len=max_seq_len,
        N_CTX=max_seq_len_k,
        dtype=dtype,
        device=device,
        BLOCK=BLOCK,
        local_blocks=local_blocks,
        vert_stride=vert_stride,
        homo_head=homo_head,
        return_dense=return_dense,
    )
    return layout, block_sparse_pattern
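

# --- Illustrative example (not used by the model code) ----------------------
# A minimal sketch of how the helper above could be called. The parameter
# values are assumptions chosen for illustration, not the configuration
# shipped with the model; the Triton kernels themselves require a CUDA device.
def _example_build_sparse_mask() -> Tuple[Layout, torch.Tensor]:
    layout, block_mask = create_sparse_attn_mask(
        n_heads=4,                      # assumed head count for the example
        max_seq_len=1024,
        max_seq_len_k=1024,
        dtype=torch.bfloat16,
        device=torch.device("cuda"),
        BLOCK=64,                       # sparse block size, in tokens
        local_blocks=4,                 # dense local window, in blocks
        vert_stride=8,                  # keep every 8th column block
        homo_head=False,
        return_dense=False,
    )
    # `layout` is the pair of per-head index tensors consumed by the attention
    # kernels; `block_mask` is the equivalent dense 0/1 block pattern, which
    # the layer keeps so it can prune heads and re-tile it to the kernel
    # block size.
    return layout, block_mask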


class BlockSparseAttentionLayer(nn.Module):
    def __init__(
        self,
        n_heads: int,
        max_seq_len: int,
        sparse_block_size: int,
        local_blocks: int,
        vert_stride: int,
        kernel_block_size: Optional[int] = None,
        homo_head: bool = False,
        active_head_range: Optional[Tuple[int, int]] = None
    ) -> None:
        super().__init__()

        self.n_heads = n_heads
        self.max_seq_len = max_seq_len
        self.sparse_block_size = sparse_block_size
        self.kernel_block_size = kernel_block_size or sparse_block_size
        self.local_blocks = local_blocks
        self.vert_stride = vert_stride
        self.homo_head = homo_head
        self.active_head_range = active_head_range

        # Internal parameters used by the layer
        self._sparse_block_mask = None
        self._sparse_layout = None
        self._dtype = None
        self._device = None

    # TODO(bapatra): Ideally, I'd want to keep all the code for
    # forward to be handled here, and not branch for training and inference.
    # However, that refactor would need a lot of testing. For now, using the
    # training op as is, and will refactor again later.

    def prune_blocksparse_layout_to_heads(self, h_start: int, h_end: int) -> None:
        # Keep only the block mask and layout rows for heads [h_start, h_end).
        # The layout is a pair of index tensors, so it is rebuilt rather than
        # mutated in place.
        self._sparse_block_mask = self._sparse_block_mask[h_start: h_end]
        crow_indices, col_indices = self._sparse_layout
        self._sparse_layout = (crow_indices[h_start: h_end], col_indices[h_start: h_end])
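
    # Illustration (assumed scenario, not asserted by this file): if the full
    # model has 32 heads and this module only owns heads 8..15 (e.g. one shard
    # when heads are split across devices), constructing the layer with
    # homo_head=False and active_head_range=(8, 16) makes
    # _initialize_internals() call prune_blocksparse_layout_to_heads(8, 16),
    # keeping just those heads' block masks and layouts.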

    def _initialize_internals(
        self,
        dtype: torch.dtype,
        device: torch.device
    ) -> None:
        self._dtype, self._device = dtype, device
        self._sparse_layout, self._sparse_block_mask = create_sparse_attn_mask(
            n_heads=self.n_heads,
            max_seq_len=self.max_seq_len,
            max_seq_len_k=self.max_seq_len,
            dtype=dtype,
            device=device,
            BLOCK=self.sparse_block_size,
            local_blocks=self.local_blocks,
            vert_stride=self.vert_stride,
            homo_head=self.homo_head,
            return_dense=False,
        )
        if (not self.homo_head) and (self.active_head_range is not None):
            assert len(self.active_head_range) == 2, \
                "\"active_head_range\" should be a tuple of start/end index of the heads."
            h_start, h_end = self.active_head_range
            self.prune_blocksparse_layout_to_heads(h_start=h_start, h_end=h_end)

        assert self.sparse_block_size % self.kernel_block_size == 0, \
            f"The sparse block size must be a multiple of {self.kernel_block_size}. Found {self.sparse_block_size}."
        assert self.kernel_block_size >= 16 and math.log2(self.kernel_block_size) % 1 == 0, \
            f"kernel_block_size must be a power of 2 and at least 16, but {self.kernel_block_size} was given."

        if self.sparse_block_size // self.kernel_block_size > 1:
            _mul = self.sparse_block_size // self.kernel_block_size
            # Expand each sparse block into a _mul x _mul grid of kernel blocks.
            # TODO: handle the case where block_m and block_n differ.
            self._sparse_block_mask = torch.kron(
                self._sparse_block_mask,
                self._sparse_block_mask.new_ones(_mul, _mul),
            )
            num_sparse_blocks = self._sparse_block_mask.size(-1)
            block_causal_mask = torch.arange(0, num_sparse_blocks)[:, None] >= torch.arange(0, num_sparse_blocks)[None]
            # Re-apply causality at kernel-block granularity: the kron expansion
            # fills in whole sparse blocks, including the part of each diagonal
            # block that lies above the fine-grained diagonal.
            self._sparse_block_mask *= block_causal_mask.type_as(self._sparse_block_mask)
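
    # --- Illustrative note (not used at runtime) -----------------------------
    # A tiny worked example of the expansion above, assuming a single head and
    # sparse_block_size = 2 * kernel_block_size (so _mul == 2), with a 2x2
    # coarse block mask:
    #
    #   coarse = [[1, 0],        torch.kron(coarse, ones(2, 2)) =
    #             [1, 1]]            [[1, 1, 0, 0],
    #                                 [1, 1, 0, 0],
    #                                 [1, 1, 1, 1],
    #                                 [1, 1, 1, 1]]
    #
    # Multiplying by the fine-grained lower-triangular mask then zeroes the
    # entries above the diagonal, e.g. positions (0, 1) and (2, 3).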

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sm_scale: float,
        *,
        # Arguments related to block attention inference
        left_paddings: Optional[torch.LongTensor] = None,
        seqlens: Optional[torch.LongTensor] = None,
        # Arguments related to variable-length inference
        cu_seqlens_k: Optional[torch.LongTensor] = None,
        cu_seqlens_q: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # Dense-batch path (training or simple inference): build or fetch the
        # cached local + vertical-stride sparse attention op and apply it.
        if left_paddings is None and seqlens is None and cu_seqlens_k is None and cu_seqlens_q is None:
            blocksparse_op = get_local_strided_sparse_attention_op(
                n_heads=self.n_heads,
                max_seq_len=self.max_seq_len,
                sparse_block_size=self.sparse_block_size,
                kernel_block_size=self.kernel_block_size,
                local_blocks=self.local_blocks,
                vert_stride=self.vert_stride,
                homo_head=self.homo_head,
                device=q.device,
                inference=not self.training
            )
            return blocksparse_op(q, k, v, sm_scale)

        assert not torch.is_grad_enabled(), (
            "Variable-length / batched inference is not supported during training. "
            "Please run it in a torch.no_grad() context."
        )

        # Lazily (re)build the sparse layout if it has not been set yet, or if
        # the dtype/device of the inputs has changed.
        if self._sparse_block_mask is None or (self._dtype != q.dtype) or (self._device != q.device):
            self._initialize_internals(dtype=q.dtype, device=q.device)

        if k.dim() == 3:
            # Variable-length (packed) inference: sequences are concatenated
            # along the token dimension and delimited by cu_seqlens_*.
            assert cu_seqlens_k is not None
            return blocksparse_flash_attn_varlen_fwd(
                q=q,
                k=k,
                v=v,
                cu_seqlens_k=cu_seqlens_k,
                cu_seqlens_q=cu_seqlens_q,
                sm_scale=sm_scale,
                sparse_layout=self._sparse_layout,
                block_size=self.kernel_block_size,
                max_seqlen=self.max_seq_len,
            )
        if k.dim() == 4:
            # Padded batched inference: fixed-size 4-dim tensors plus
            # per-sequence padding information.
            assert not (left_paddings is None and seqlens is None), \
                "Either left_paddings or seqlens must be provided for batched inference."
            return blocksparse_flash_attn_padded_fwd(
                q=q,
                k=k,
                v=v,
                sm_scale=sm_scale,
                sparse_layout=self._sparse_layout,
                left_paddings=left_paddings,
                seqlens=seqlens,
                block_size=self.kernel_block_size,
                max_seqlen=self.max_seq_len,
            )
        raise ValueError("q/k/v must be either 3-dim for variable-length input or 4-dim for fixed-length input.")
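

# --- Usage sketch (illustrative only, not part of the model code) -----------
# A minimal, hedged example of driving the layer on the dense path. All shapes
# and hyper-parameters below are assumptions chosen for the example; the real
# values come from the model config. A CUDA device is required, since the
# underlying ops are Triton kernels.
def _example_blocksparse_forward() -> torch.Tensor:
    n_heads, head_dim, seq_len, batch = 4, 64, 1024, 2
    layer = BlockSparseAttentionLayer(
        n_heads=n_heads,
        max_seq_len=seq_len,
        sparse_block_size=64,
        local_blocks=4,
        vert_stride=8,
    )
    device = torch.device("cuda")
    # Assumed layout: (batch, heads, sequence, head_dim).
    q = torch.randn(batch, n_heads, seq_len, head_dim, device=device, dtype=torch.bfloat16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # No padding / varlen arguments are passed, so the dense op is used.
    return layer(q, k, v, sm_scale=1.0 / math.sqrt(head_dim))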