|
|
|
""" |
|
Original Author: Eric Lin (xihlin) (https://huggingface.co/microsoft/Phi-3-small-8k-instruct/blob/main/triton_flash_blocksparse_attn.py) |
|
""" |
|
""" |
|
Modified by Yizhao Gao |
|
Use binary block mask for simplicity. Need to be updated to varlen version for batched inference. |
|
""" |
|
|
|
|
|
import math

import torch
import torch.nn.functional as F

import triton
import triton.language as tl


def is_hip():
    # True when Triton targets AMD GPUs (ROCm/HIP); used below to pick launch defaults.
    return triton.runtime.driver.active.get_current_target().backend == "hip"
|
|
|
|
|
def get_sparse_attn_mask_from_topk(x, topk, use_dense_for_last_block=False):
    """Build a causal block mask keeping, per query block, the `topk` highest-scoring key blocks."""
    bsz, num_head, downsample_len, _ = x.shape
    # Scatter the top-k key-block indices of every query block into a boolean mask.
    sparse_index = torch.topk(x, topk, dim=-1).indices
    dense_mask = torch.full([bsz, num_head, downsample_len, downsample_len], False,
                            dtype=torch.bool, device=x.device)
    dense_mask.scatter_(-1, sparse_index, True)
    if use_dense_for_last_block:
        # Optionally keep the last two query-block rows fully dense.
        dense_mask[:, :, -2:, :] = True
    # Enforce block-level causality.
    dense_mask.tril_()
    return dense_mask
|
|
|
|
|
def get_sparse_attn_mask_from_threshold(x, threshold, use_dense_for_last_block=False):
    """Build a causal block mask keeping every key block whose score exceeds `threshold`."""
    dense_mask = x > threshold
    if use_dense_for_last_block:
        # Optionally keep the last two query-block rows fully dense.
        dense_mask[:, :, -2:, :] = True
    # Enforce block-level causality.
    dense_mask.tril_()
    return dense_mask
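
# Note: both helpers return a boolean mask of shape (bsz, num_head, num_blocks, num_blocks),
# where num_blocks is the block-level (downsampled) sequence length; True marks a
# (query block, key block) pair that the kernel actually computes.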
|
|
|
|
|
|
|
|
|
@triton.jit
def _fwd_kernel_inner(
    acc, l_i, m_i,
    q,
    k_block_col_idx,
    block_mask_ptr,
    k_ptrs, v_ptrs,
    offs_m, offs_n,
    stride_kt, stride_vt, stride_bmask_n,
    sm_scale,
    seqlen_k,
    past_len,
    LAST_K_BLOCK: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Skip this key block entirely if the block mask marks it as pruned.
    mask_val = tl.load(block_mask_ptr + k_block_col_idx * stride_bmask_n)
    if mask_val == True:
        start_n = k_block_col_idx * BLOCK_N

        # Load K for this block; only the last (possibly partial) block needs a
        # bounds mask. other=0.0 keeps out-of-range lanes well-defined.
        if LAST_K_BLOCK:
            k = tl.load(k_ptrs + start_n * stride_kt,
                        mask=offs_n[None, :] + start_n < seqlen_k, other=0.0)
        else:
            k = tl.load(k_ptrs + start_n * stride_kt)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale

        # Only the last key block can cross the causal boundary, so apply the
        # causal mask there (query positions are offset by past_len).
        if LAST_K_BLOCK:
            qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf'))

        # Online-softmax update (FlashAttention-style running max / running sum).
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk -= m_ij[:, None]
        p = tl.exp(qk)
        l_ij = tl.sum(p, 1)
        alpha = tl.exp(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        acc = acc * alpha[:, None]

        # Load V for this block and accumulate the probability-weighted values.
        if LAST_K_BLOCK:
            v = tl.load(v_ptrs + start_n * stride_vt,
                        mask=offs_n[:, None] + start_n < seqlen_k, other=0.0)
        else:
            v = tl.load(v_ptrs + start_n * stride_vt)

        p = p.to(v.type.element_ty)
        acc += tl.dot(p, v)

        m_i = m_ij
    return acc, l_i, m_i
|
|
|
|
|
|
|
|
|
@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    block_mask_ptr,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qd,
    stride_kz, stride_kh, stride_kn, stride_kd,
    stride_vz, stride_vh, stride_vn, stride_vd,
    stride_bmz, stride_bmh, stride_bmm, stride_bmn,
    stride_oz, stride_oh, stride_om, stride_od,
    H, N_CTX,
    PAST_LEN,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
):
    # One program per (query block, batch * head) pair.
    Q_LEN = N_CTX - PAST_LEN
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_h = off_hz % H
    off_z = off_hz // H
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    block_mask_ptr += off_z * stride_bmz + off_h * stride_bmh

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
    off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
    off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd

    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # Row of the block mask that belongs to this query block.
    mask_ptrs = block_mask_ptr + start_m * stride_bmm

    # Running max, running sum, and output accumulator for the online softmax.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN)

    # Visit key blocks up to (and including) the diagonal block of this query slab.
    k_block_start = 0
    k_block_end = tl.cdiv((start_m + 1) * BLOCK_M, BLOCK_N)

    # All but the last key block need neither a bounds mask nor the causal mask.
    for col_idx in range(k_block_start, k_block_end - 1):
        acc, l_i, m_i = _fwd_kernel_inner(
            acc, l_i, m_i,
            q,
            col_idx,
            mask_ptrs,
            k_ptrs, v_ptrs,
            offs_m, offs_n,
            stride_kn, stride_vn, stride_bmn,
            sm_scale,
            N_CTX,
            PAST_LEN,
            False,
            BLOCK_M,
            BLOCK_N,
        )

    # The last key block applies both the bounds mask and the causal mask.
    acc, l_i, m_i = _fwd_kernel_inner(
        acc, l_i, m_i,
        q,
        k_block_end - 1,
        mask_ptrs,
        k_ptrs, v_ptrs,
        offs_m, offs_n,
        stride_kn, stride_vn, stride_bmn,
        sm_scale,
        N_CTX,
        PAST_LEN,
        True,
        BLOCK_M,
        BLOCK_N,
    )

    # Row-wise log-sum-exp (computed for parity with FlashAttention but never written
    # out, since there is no backward pass), then normalize the accumulator.
    m_i += tl.math.log(l_i)
    l_recip = 1 / l_i[:, None]
    acc = acc * l_recip
    acc = acc.to(Out.dtype.element_ty)

    off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
    out_ptrs = Out + off_o
    # Out only has Q_LEN query rows, so bound the store by Q_LEN (not N_CTX).
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < Q_LEN)
|
|
|
def _forward(
    ctx,
    q,
    k,
    v,
    block_sparse_mask,
    sm_scale,
    BLOCK_M=64,
    BLOCK_N=64,
    num_warps=None,
    num_stages=None,
    out=None
):
    # q: (batch, heads, q_len, head_dim); k/v: (batch, heads, kv_len, head_dim),
    # where kv_len >= q_len (the difference is treated as cached past tokens).
    assert q.shape[-1] == k.shape[-1] == v.shape[-1]
    assert k.shape[2] == v.shape[2]
    o = out if out is not None else torch.empty_like(q).contiguous()
    # One program per (query block, batch * head).
    grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])

    assert q.shape[-1] in [64, 128]
    BLOCK_DMODEL = q.shape[-1]

    # Pick launch defaults only when the caller did not specify them.
    if num_warps is None:
        num_warps = 8 if is_hip() else 4
    if num_stages is None:
        num_stages = 1 if is_hip() else 2

    N_CTX = k.shape[2]
    PAST_LEN = N_CTX - q.shape[2]
    H = q.shape[1]

    _fwd_kernel[grid](
        q, k, v, sm_scale,
        block_sparse_mask,
        o,
        *q.stride(),
        *k.stride(),
        *v.stride(),
        *block_sparse_mask.stride(),
        *o.stride(),
        H, N_CTX,
        PAST_LEN,
        BLOCK_M,
        BLOCK_N,
        BLOCK_DMODEL,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return o
|
|
|
|
|
|
|
|
|
class _sparse_attention(torch.autograd.Function): |
|
|
|
@staticmethod |
|
def forward(ctx, q, k, v, block_sparse_dense, sm_scale): |
|
|
|
return _forward(ctx, q, k, v, block_sparse_dense, sm_scale) |
|
|
|
    @staticmethod
    def backward(ctx, do):
        # Forward-only: the backward pass has not been implemented.
        raise NotImplementedError("Gradient propagation is not supported yet")
|
|
|
def sparse_attention_factory(BLOCK_M=64, BLOCK_N=64, **kwargs): |
|
class _sparse_attention_config(_sparse_attention): |
|
@staticmethod |
|
def forward(ctx, q, k, v, block_sparse_dense, sm_scale): |
|
|
|
return _forward(ctx, q, k, v, block_sparse_dense, sm_scale, BLOCK_M, BLOCK_N, |
|
**kwargs |
|
) |
|
return _sparse_attention_config.apply |
|
|
|
block_sparse_triton_fn = _sparse_attention.apply |
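
# A hedged sketch of the two entry points (values are illustrative, not tuned):
# `block_sparse_triton_fn` uses the default 64x64 blocks, while the factory builds a
# variant with custom block sizes and forwards extra kwargs (e.g. num_warps) to _forward:
#
#     attn_fn = sparse_attention_factory(BLOCK_M=64, BLOCK_N=64, num_warps=4)
#     out = attn_fn(q, k, v, block_mask, sm_scale)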
|
|
|
|
|
|
|
def test_topk_sparse_attention(): |
|
|
|
BATCH, N_HEADS, SEQ_LEN, D_HEAD = 2, 4, 256, 64 |
|
TOPK = 2 |
|
BLOCK = 64 |
|
torch.manual_seed(0) |
|
|
|
|
|
q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) |
|
k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) |
|
v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) |
|
sm_scale = 1.0 / (D_HEAD ** 0.5) |
|
|
|
|
|
    downsample_factor = BLOCK
    downsample_len = math.ceil(SEQ_LEN / downsample_factor)
    x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], device='cuda', dtype=torch.bfloat16)
    # Pin the first key block to the highest score so every query-block row keeps at
    # least one visible block after the causal tril.
    x_ds[:, :, :, 0] = 100
    block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK)
|
|
|
|
|
triton_output = block_sparse_triton_fn( |
|
q, k, v, |
|
block_mask, |
|
sm_scale |
|
) |
|
|
|
|
|
|
|
    # Reference: expand the block mask to token resolution and re-apply the causal mask.
    full_mask = torch.kron(block_mask.float(),
                           torch.ones(BLOCK, BLOCK, device='cuda'))
    full_mask = full_mask[..., :SEQ_LEN, :SEQ_LEN].bool()
    full_mask = full_mask & torch.tril(torch.ones_like(full_mask))
|
|
|
|
|
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale |
|
attn = attn.masked_fill(~full_mask, float('-inf')) |
|
attn = F.softmax(attn, dim=-1) |
|
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) |
|
|
|
|
|
|
|
assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ |
|
"Triton output doesn't match reference" |
|
print("Pass topk sparse attention test with qlen == klen") |
|
|
|
|
|
|
|
def test_topk_sparse_attention_qlt_kl(): |
|
BATCH, N_HEADS = 2, 4 |
|
Q_LEN, K_LEN, D_HEAD = 128, 256, 64 |
|
TOPK = 1 |
|
BLOCK = 64 |
|
torch.manual_seed(0) |
|
|
|
|
|
q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) |
|
k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) |
|
v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.bfloat16) |
|
sm_scale = 1.0 / (D_HEAD ** 0.5) |
|
|
|
downsample_factor = BLOCK |
|
downsample_len = math.ceil(K_LEN / downsample_factor) |
|
x_ds = torch.randn(BATCH, N_HEADS, downsample_len, downsample_len, |
|
device='cuda', dtype=torch.bfloat16) |
|
|
|
x_ds[:, :, :, 0] = 100 |
|
block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) |
|
|
|
|
|
triton_output = block_sparse_triton_fn(q, k, v, block_mask, sm_scale) |
|
|
|
past_len = K_LEN - Q_LEN |
|
|
|
attn = torch.einsum('bhsd,bhtd->bhst', q, k) * sm_scale |
|
|
|
    # Reference: expand the block mask to token resolution over the full K_LEN x K_LEN
    # grid, keep only the rows for the current queries (the last Q_LEN positions), and
    # combine with a causal mask expressed in global positions.
    full_mask_full = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device='cuda')).bool()
    full_mask_full = full_mask_full[..., :K_LEN, :K_LEN]

    effective_mask = full_mask_full[..., past_len:K_LEN, :]

    i_global = torch.arange(past_len, K_LEN, device=k.device).unsqueeze(1)
    j_global = torch.arange(K_LEN, device=k.device).unsqueeze(0)
    causal_mask = (j_global <= i_global)

    final_mask = effective_mask & causal_mask
|
|
|
attn = attn.masked_fill(~final_mask, float('-inf')) |
|
attn = F.softmax(attn, dim=-1) |
|
ref_output = torch.einsum('bhst,bhtd->bhsd', attn, v) |
|
|
|
|
|
assert torch.allclose(triton_output, ref_output, atol=1e-2, rtol=1e-2), \ |
|
"Triton output doesn't match reference when qlen < klen" |
|
|
|
print("Pass topk sparse attention test with qlen < klen") |
|
|
|
|
|
if __name__ == "__main__": |
|
test_topk_sparse_attention() |
|
test_topk_sparse_attention_qlt_kl() |