|
""" |
|
Author: Eric Lin (xihlin) |
|
""" |
|
""" |
|
... note(bapatra):: |
|
This is written as one big file instead of being split into logical components, because splitting it into separate files ran into issues with the transformers auto-module

imports. I've tried to keep the logical partitions demarcated with comment blocks, but it is not ideal.

In the future, it would be good to revisit this and refactor it into a more readable file structure.
|
|
|
""" |
|
from typing import TypeVar |
|
from functools import lru_cache |
|
import math |
|
import pytest |
|
import torch |
|
import numpy as np |
|
|
|
import triton |
|
import triton.language as tl |
|
|
|
import os |
|
|
|
import dataclasses |
|
|
|
Phi3SmallConfig = TypeVar('Phi3SmallConfig')  # placeholder for the transformers Phi3SmallConfig type, to avoid a hard import dependency here
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass |
|
class BlockSparseParams(object): |
|
block_size: int |
|
kernel_block_size: int |
|
num_local_blocks: int |
|
vert_stride: int |
|
homo_head_pattern: bool = False |
|
|
|
@classmethod |
|
def from_config(cls, config: Phi3SmallConfig) -> "BlockSparseParams": |
|
return cls( |
|
block_size=config.blocksparse_block_size, |
|
kernel_block_size=config.blocksparse_triton_kernel_block_size, |
|
num_local_blocks=config.blocksparse_num_local_blocks, |
|
vert_stride=config.blocksparse_vert_stride, |
|
homo_head_pattern=config.blocksparse_homo_head_pattern, |
|
) |
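

# Illustrative sketch (not part of the original API): `BlockSparseParams` can also be
# constructed directly. The values below are made up for illustration, not Phi-3-Small defaults.
def _example_block_sparse_params() -> "BlockSparseParams":
    return BlockSparseParams(
        block_size=64,            # size of a sparse block, in tokens
        kernel_block_size=64,     # tile size used by the triton kernels
        num_local_blocks=16,      # attend to the nearest 16 blocks
        vert_stride=8,            # plus every 8th block (vertically strided / "global" blocks)
        homo_head_pattern=False,  # False: each head gets a shifted pattern
    )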
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dense_to_crow_col(x): |
|
    ''' Convert a 2D/3D torch tensor `x` into CSR crow/col indexing.

    NOTE: col_indices is right-padded with -1 so that every row has the same number of entries.

    TODO:

        1. Improve efficiency: would this be faster on CPU, or with a custom CUDA kernel?

    '''
|
pad = -1 |
|
dim = x.dim() |
|
assert x.dim() in (2, 3) |
|
if x.dim() == 2: |
|
x = x[None] |
|
x = [xi.to_sparse_csr() for xi in x] |
|
crows = torch.vstack([xi.crow_indices() for xi in x]) |
|
cols = [xi.col_indices() for xi in x] |
|
max_cols = max(len(xi) for xi in cols) |
|
cols = [torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) for xi in cols] |
|
cols = torch.vstack(cols) |
|
if dim == 2: |
|
crows = crows[0] |
|
cols = cols[0] |
|
return crows, cols |
|
|
|
|
|
def crow_col_to_dense(crows, cols, dtype=torch.float16): |
|
dim = crows.dim() |
|
if dim == 1: |
|
crows = crows[None] |
|
cols = cols[None] |
|
device = crows.device |
|
crows, cols = crows.cpu(), cols.cpu() |
|
shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1) |
|
x = torch.zeros(shape, dtype=dtype) |
|
for i in range(shape[0]): |
|
for j in range(shape[1]): |
|
x[i, j, cols[i, crows[i, j]:crows[i, j+1]]] = 1 |
|
if dim == 1: |
|
x = x[0] |
|
return x.to(device) |
|
|
|
|
|
def dense_to_ccol_row(x): |
|
    '''Similar to `dense_to_crow_col`, but returns CSC-style (ccol, row) indexing (the input is transposed on its last two dims first).
|
''' |
|
x = x.transpose(-2, -1) |
|
return dense_to_crow_col(x) |
|
|
|
|
|
def ccol_row_to_dense(ccol, rows, dtype=torch.float16): |
|
return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous() |
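

# Minimal sketch (illustrative, not part of the original API) of the CSR/CSC helpers above:
# round-trip a small causal block mask through dense -> CSR -> dense. Runs on CPU.
def _example_crow_col_roundtrip():
    dense = torch.tril(torch.ones(4, 4))         # causal 4x4 block mask
    crows, cols = dense_to_crow_col(dense)       # cols would be right-padded with -1 for 3D inputs
    rebuilt = crow_col_to_dense(crows, cols, dtype=dense.dtype)
    assert torch.equal(dense, rebuilt)
    return crows, cols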
|
|
|
|
|
def _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, return_dense=False): |
|
''' |
|
:return: a tuple of 3: |
|
        - a (crow_indices, col_indices) tuple giving the CSR representation of the block mask.

        - the block-level dense mask

        - the token-level dense mask if `return_dense==True` (beware: this can OOM for large contexts), otherwise None
|
''' |
|
with torch.no_grad(): |
|
N_BLOCK = triton.cdiv(N_CTX, BLOCK) |
|
q_pos = torch.arange(N_BLOCK)[:, None] |
|
k_pos = torch.arange(N_BLOCK)[None] |
|
mask_vert_strided = (torch.arange(N_BLOCK) + 1) % vert_stride == 0 |
|
block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype) |
|
N_BLOCK_Q = triton.cdiv(q_len, BLOCK) |
|
block_mask_dense_output = block_mask_dense[-N_BLOCK_Q:].contiguous().to_sparse_csr() |
|
if return_dense: |
|
mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK))) |
|
causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[-q_len:] |
|
mask_dense = mask_dense[-q_len:, :N_CTX] * causal_mask |
|
return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, mask_dense |
|
else: |
|
return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, None |
|
|
|
|
|
def _get_sparse_attn_mask(n_heads, q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, homo_head=True, return_dense=False): |
|
''' |
|
:return: a tuple of 3: |
|
        - a (crow_indices, col_indices) tuple giving the CSR representation of the block mask.

        - the block-level dense mask

        - the token-level dense mask if `return_dense==True` (beware: this can OOM for large contexts), otherwise None
|
''' |
|
if homo_head: |
|
with torch.no_grad(): |
|
(crow, col), block_mask_dense, mask_dense = _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK, local_blocks, vert_stride, return_dense) |
|
crow = crow[None].expand(n_heads, crow.shape[0]) |
|
col = col[None].expand(n_heads, col.shape[0]) |
|
if return_dense: |
|
mask_dense = mask_dense[None].expand(n_heads, *mask_dense.shape) |
|
return (crow, col), block_mask_dense, mask_dense |
|
|
|
with torch.no_grad(): |
|
N_BLOCK = triton.cdiv(N_CTX, BLOCK) |
|
q_pos = torch.arange(N_BLOCK)[None, :, None] |
|
k_pos = torch.arange(N_BLOCK)[None, None] |
|
head_sliding_step = max(1, int(vert_stride / n_heads)) |
|
mask_vert_strided = [(torch.arange(N_BLOCK) + h * head_sliding_step + 1) % vert_stride == 0 for h in range(n_heads)] |
|
mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1) |
|
block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype) |
|
N_BLOCK_Q = triton.cdiv(q_len, BLOCK) |
|
block_mask_dense_output = block_mask_dense[:, -N_BLOCK_Q:] |
|
if return_dense: |
|
mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK))) |
|
causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[-q_len:] |
|
mask_dense = mask_dense[..., -q_len:, :N_CTX] * causal_mask[None] |
|
return dense_to_crow_col(block_mask_dense_output), block_mask_dense, mask_dense |
|
else: |
|
return dense_to_crow_col(block_mask_dense_output), block_mask_dense, None |
|
|
|
|
|
def get_sparse_attn_mask(q, N_CTX, *args, **kwargs): |
|
return _get_sparse_attn_mask(q.size(1), q.size(2), N_CTX, q.dtype, q.device, *args, **kwargs) |
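

# Illustrative sketch (not part of the original API): build the block-sparse layout for a
# (batch, n_heads, seq_len, head_dim) query tensor. Only torch ops are involved here,
# so no GPU is strictly required for mask construction itself.
def _example_build_sparse_mask():
    q = torch.empty(1, 4, 1024, 64, dtype=torch.float32)  # batch=1, 4 heads, 1024 tokens
    (crow, col), block_mask, _ = get_sparse_attn_mask(
        q, 1024, BLOCK=128, local_blocks=4, vert_stride=4, homo_head=False, return_dense=False)
    # crow/col: per-head CSR indices over 1024/128 = 8 block rows; block_mask: (4, 8, 8)
    return crow, col, block_mask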
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@triton.jit |
|
def _fwd_kernel( |
|
Q, K, V, sm_scale, |
|
layout_crow_ptr, |
|
layout_col_ptr, |
|
layout_crow_stride_h, layout_crow_stride_m, |
|
layout_col_stride_h, layout_col_stride_m, |
|
TMP, L, M, |
|
Out, |
|
stride_qz, stride_qh, stride_qm, stride_qd, |
|
stride_kz, stride_kh, stride_kn, stride_kd, |
|
stride_vz, stride_vh, stride_vn, stride_vd, |
|
stride_oz, stride_oh, stride_om, stride_od, |
|
Z, H, N_CTX, |
|
PAST_LEN, |
|
Q_ROUNDED_LEN, |
|
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, |
|
BLOCK_N: tl.constexpr, |
|
EVEN_M_BLOCK: tl.constexpr, |
|
EVEN_N_BLOCK: tl.constexpr, |
|
INFERENCE: tl.constexpr, |
|
NUM_DBLOCKS: tl.constexpr, |
|
): |
|
Q_LEN = N_CTX - PAST_LEN |
|
start_m = tl.program_id(0) |
|
off_hz = tl.program_id(1) |
|
off_h = off_hz % H |
|
off_z = off_hz // H |
|
Q += off_z * stride_qz + off_h * stride_qh |
|
K += off_z * stride_kz + off_h * stride_kh |
|
V += off_z * stride_vz + off_h * stride_vh |
|
|
|
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) |
|
offs_n = tl.arange(0, BLOCK_N) |
|
offs_d = tl.arange(0, BLOCK_DMODEL) |
|
off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd |
|
|
|
off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd |
|
off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd |
|
|
|
q_ptrs = Q + off_q |
|
k_ptrs = K + off_k |
|
v_ptrs = V + off_v |
|
|
|
t_ptrs = TMP + off_hz * Q_ROUNDED_LEN + offs_m |
|
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf') |
|
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) |
|
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) |
|
if NUM_DBLOCKS >= 2: |
|
acc2 = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) |
|
|
|
|
|
if EVEN_M_BLOCK: |
|
q = tl.load(q_ptrs) |
|
if NUM_DBLOCKS >= 2: |
|
q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd) |
|
else: |
|
q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN) |
|
if NUM_DBLOCKS >= 2: |
|
q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m[:, None] < Q_LEN) |
|
|
|
layout_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + start_m * layout_crow_stride_m |
|
start_l = tl.load(layout_ptr).to(tl.int32) |
|
end_l = tl.load(layout_ptr + layout_crow_stride_m).to(tl.int32) |
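    # [start_l, end_l) is the range of non-zero column blocks for this query row block,
    # read from the CSR crow indices of the block-sparse layout.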
|
|
|
|
|
for col_idx_idx in range(start_l, end_l): |
|
col_idx = tl.load(layout_col_ptr + off_h * layout_col_stride_h + col_idx_idx * layout_col_stride_m).to(tl.int32) |
|
start_n = col_idx * BLOCK_N |
|
|
|
if EVEN_N_BLOCK: |
|
k = tl.load(k_ptrs + start_n * stride_kn) |
|
else: |
|
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_n[None, :] + start_n < N_CTX) |
|
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) |
|
qk += tl.dot(q, k) |
|
|
|
if NUM_DBLOCKS >= 2: |
|
if EVEN_N_BLOCK: |
|
k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd) |
|
else: |
|
k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd, mask=offs_n[None, :] + start_n < N_CTX) |
|
qk += tl.dot(q2, k) |
|
|
|
qk *= sm_scale |
|
qk += tl.where(offs_m[:, None] + PAST_LEN >= (start_n + offs_n[None, :]), 0, float('-inf')) |
|
|
|
m_ij = tl.max(qk, 1) |
|
p = tl.exp(qk - m_ij[:, None]) |
|
l_ij = tl.sum(p, 1) |
|
|
|
m_i_new = tl.maximum(m_i, m_ij) |
|
alpha = tl.exp(m_i - m_i_new) |
|
beta = tl.exp(m_ij - m_i_new) |
|
l_i_new = alpha * l_i + beta * l_ij |
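        # flash-attention style online softmax: rescale the running accumulator and the
        # current block's probabilities using the updated running max (m) and denominator (l).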
|
|
|
|
|
p_scale = beta / l_i_new |
|
p = p * p_scale[:, None] |
|
|
|
acc_scale = l_i / l_i_new * alpha |
|
|
|
|
|
acc = acc * acc_scale[:, None] |
|
if NUM_DBLOCKS >= 2: |
|
acc2 = acc2 * acc_scale[:, None] |
|
p = p.to(Q.dtype.element_ty) |
|
|
|
if EVEN_N_BLOCK: |
|
v = tl.load(v_ptrs + start_n * stride_vn) |
|
else: |
|
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_n[:, None] + start_n < N_CTX) |
|
acc += tl.dot(p, v) |
|
|
|
if NUM_DBLOCKS >= 2: |
|
if EVEN_N_BLOCK: |
|
v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd) |
|
else: |
|
v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] + start_n < N_CTX) |
|
acc2 += tl.dot(p, v) |
|
|
|
|
|
l_i = l_i_new |
|
m_i = m_i_new |
|
|
|
|
|
|
|
|
|
|
|
if not INFERENCE: |
|
l_ptrs = L + off_hz * N_CTX + offs_m |
|
m_ptrs = M + off_hz * N_CTX + offs_m |
|
if EVEN_M_BLOCK: |
|
tl.store(l_ptrs, l_i) |
|
tl.store(m_ptrs, m_i) |
|
else: |
|
tl.store(l_ptrs, l_i, mask=offs_m < Q_LEN) |
|
tl.store(m_ptrs, m_i, mask=offs_m < Q_LEN) |
|
|
|
|
|
off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od |
|
out_ptrs = Out + off_o |
|
tl.store(out_ptrs, acc, mask=offs_m[:, None] < Q_LEN) |
|
if NUM_DBLOCKS >= 2: |
|
tl.store(out_ptrs + BLOCK_DMODEL * stride_od, acc2, mask=offs_m[:, None] < Q_LEN) |
|
|
|
|
|
|
|
@triton.heuristics( |
|
{ |
|
'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0, |
|
} |
|
) |
|
@triton.jit |
|
def _bwd_preprocess( |
|
Out, DO, L, |
|
NewDO, Delta, |
|
N_CTX, |
|
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, |
|
EVEN_M_BLOCK: tl.constexpr, |
|
): |
|
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) |
|
off_d = tl.arange(0, D_HEAD) |
|
|
|
if EVEN_M_BLOCK: |
|
o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32) |
|
do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32) |
|
else: |
|
o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32) |
|
do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32) |
|
denom = tl.load(L + off_m).to(tl.float32) |
|
|
|
do = do / denom[:, None] |
|
delta = tl.sum(o * do, axis=1) |
|
|
|
if EVEN_M_BLOCK: |
|
tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do) |
|
else: |
|
tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do, mask=off_m[:, None] < N_CTX) |
|
tl.store(Delta + off_m, delta) |
|
|
|
|
|
|
|
@triton.heuristics( |
|
{ |
|
'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0, |
|
'EVEN_N_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_N'] == 0, |
|
} |
|
) |
|
@triton.jit |
|
def _bwd_kernel( |
|
Q, K, V, sm_scale, |
|
layout_ccol_ptr, |
|
layout_row_ptr, |
|
layout_ccol_stride_h, layout_ccol_stride_m, |
|
layout_row_stride_h, layout_row_stride_m, |
|
Out, DO, |
|
DQ, DK, DV, |
|
L, M, |
|
D, |
|
stride_qz, stride_qh, stride_qm, stride_qd, |
|
stride_kz, stride_kh, stride_kn, stride_kd, |
|
stride_vz, stride_vh, stride_vn, stride_vd, |
|
stride_oz, stride_oh, stride_om, stride_od, |
|
|
|
Z, H, N_CTX, |
|
num_block, |
|
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, |
|
BLOCK_N: tl.constexpr, |
|
EVEN_M_BLOCK: tl.constexpr, |
|
EVEN_N_BLOCK: tl.constexpr, |
|
NUM_DBLOCKS: tl.constexpr, |
|
): |
|
start_n = tl.program_id(0) |
|
off_hz = tl.program_id(1) |
|
off_z = off_hz // H |
|
off_h = off_hz % H |
|
|
|
Q += off_z * stride_qz + off_h * stride_qh |
|
K += off_z * stride_kz + off_h * stride_kh |
|
V += off_z * stride_vz + off_h * stride_vh |
|
DO += off_z * stride_oz + off_h * stride_oh |
|
DQ += off_z * stride_oz + off_h * stride_oh |
|
DK += off_z * stride_oz + off_h * stride_oh |
|
DV += off_z * stride_oz + off_h * stride_oh |
|
|
|
|
|
|
|
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) |
|
offs_m = tl.arange(0, BLOCK_M) |
|
offs_d = tl.arange(0, BLOCK_DMODEL) |
|
|
|
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd) |
|
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd) |
|
|
|
|
|
D_ptrs = D + off_hz * N_CTX |
|
m_ptrs = M + off_hz * N_CTX |
|
|
|
dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) |
|
dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) |
|
|
|
if EVEN_N_BLOCK: |
|
k = tl.load(k_ptrs) |
|
v = tl.load(v_ptrs) |
|
else: |
|
k = tl.load(k_ptrs, mask=offs_n[:, None] < N_CTX) |
|
v = tl.load(v_ptrs, mask=offs_n[:, None] < N_CTX) |
|
|
|
if NUM_DBLOCKS >= 2: |
|
dv2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) |
|
dk2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) |
|
if EVEN_N_BLOCK: |
|
k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd) |
|
v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd) |
|
else: |
|
k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd, mask=offs_n[:, None] < N_CTX) |
|
v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] < N_CTX) |
|
|
|
|
|
|
|
layout_ptr = layout_ccol_ptr + off_h * layout_ccol_stride_h + start_n * layout_ccol_stride_m |
|
start_l = tl.load(layout_ptr).to(tl.int32) |
|
end_l = tl.load(layout_ptr + layout_ccol_stride_m).to(tl.int32) |
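    # [start_l, end_l) enumerates, via the CSC (ccol/row) layout, the query row blocks
    # that attend to this key/value column block.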
|
|
|
for row_idx_idx in range(start_l, end_l): |
|
row_idx = tl.load(layout_row_ptr + off_h * layout_row_stride_h + row_idx_idx * layout_row_stride_m).to(tl.int32) |
|
start_m = row_idx * BLOCK_M |
|
|
|
|
|
offs_m_curr = start_m + offs_m |
|
q_ptrs = Q + (offs_m_curr[:, None] * stride_qm + offs_d[None, :] * stride_qd) |
|
do_ptrs = DO + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od) |
|
dq_ptrs = DQ + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od) |
|
|
|
|
|
if EVEN_M_BLOCK: |
|
q = tl.load(q_ptrs) |
|
else: |
|
q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < N_CTX) |
|
|
|
|
|
qk = tl.dot(q, tl.trans(k)) |
|
|
|
if NUM_DBLOCKS >= 2: |
|
if EVEN_M_BLOCK: |
|
q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd) |
|
else: |
|
q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m_curr[:, None] < N_CTX) |
|
qk += tl.dot(q2, tl.trans(k2)) |
|
|
|
qk += tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), 0, float('-inf')) |
|
|
|
if EVEN_M_BLOCK: |
|
m = tl.load(m_ptrs + offs_m_curr) |
|
else: |
|
m = tl.load(m_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX) |
|
p = tl.exp(qk * sm_scale - m[:, None]) |
|
|
|
|
|
if EVEN_M_BLOCK: |
|
do = tl.load(do_ptrs) |
|
else: |
|
do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < N_CTX) |
|
|
|
if NUM_DBLOCKS >= 2: |
|
if EVEN_M_BLOCK: |
|
do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od) |
|
else: |
|
do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od, mask=offs_m_curr[:, None] < N_CTX) |
|
|
|
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) |
|
|
|
if NUM_DBLOCKS >= 2: |
|
dv2 += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do2) |
|
|
|
|
|
if EVEN_M_BLOCK: |
|
Di = tl.load(D_ptrs + offs_m_curr) |
|
else: |
|
Di = tl.load(D_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX) |
|
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] |
|
dp += tl.dot(do, tl.trans(v)) |
|
|
|
if NUM_DBLOCKS >= 2: |
|
dp += tl.dot(do2, tl.trans(v2)) |
|
|
|
|
|
ds = p * dp * sm_scale |
|
|
|
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q) |
|
if NUM_DBLOCKS >= 2: |
|
dk2 += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q2) |
|
|
|
|
|
dq = tl.dot(ds.to(Q.dtype.element_ty), k) |
|
if EVEN_M_BLOCK: |
|
tl.atomic_add(dq_ptrs, dq) |
|
else: |
|
tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < N_CTX) |
|
|
|
if NUM_DBLOCKS >= 2: |
|
dq2 = tl.dot(ds.to(Q.dtype.element_ty), k2) |
|
dq_ptrs2 = dq_ptrs + BLOCK_DMODEL * stride_od |
|
if EVEN_M_BLOCK: |
|
tl.atomic_add(dq_ptrs2, dq2) |
|
else: |
|
tl.atomic_add(dq_ptrs2, dq2, mask=offs_m_curr[:, None] < N_CTX) |
|
|
|
|
|
dv_ptrs = DV + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od) |
|
dk_ptrs = DK + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od) |
|
if EVEN_N_BLOCK: |
|
tl.store(dv_ptrs, dv) |
|
tl.store(dk_ptrs, dk) |
|
else: |
|
tl.store(dv_ptrs, dv, mask=offs_n[:, None] < N_CTX) |
|
tl.store(dk_ptrs, dk, mask=offs_n[:, None] < N_CTX) |
|
|
|
if NUM_DBLOCKS >= 2: |
|
dv_ptrs2 = dv_ptrs + BLOCK_DMODEL * stride_od |
|
dk_ptrs2 = dk_ptrs + BLOCK_DMODEL * stride_od |
|
if EVEN_N_BLOCK: |
|
tl.store(dv_ptrs2, dv2) |
|
tl.store(dk_ptrs2, dk2) |
|
else: |
|
tl.store(dv_ptrs2, dv2, mask=offs_n[:, None] < N_CTX) |
|
tl.store(dk_ptrs2, dk2, mask=offs_n[:, None] < N_CTX) |
|
|
|
|
|
|
|
def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N, num_warps=None, num_stages=1, inference=None, out=None): |
|
''' |
|
    :param q, k, v: [batch, n_heads, seq_len, model_dim]. The length of q is allowed to differ from that of k/v.

    :param layout_crow_indices, layout_col_indices: same as CSR.crow_indices and CSR.col_indices, used to represent the block-sparse layout.

        Each element represents a block, i.e., all tokens in a block are either all attended to, or not attended to at all.
|
''' |
|
assert q.shape[-1] == k.shape[-1] == v.shape[-1] |
|
assert k.shape[2] == v.shape[2] |
|
o = out if out is not None else torch.empty_like(q).contiguous() |
|
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1]) |
|
|
|
q_rounded_len = grid[0] * BLOCK_M |
|
tmp = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32) |
|
|
|
if inference is None: |
|
inference = (not q.requires_grad) and (not k.requires_grad) and (not v.requires_grad) |
|
|
|
if inference: |
|
L, m = tmp, tmp |
|
else: |
|
L = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32) |
|
m = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32) |
|
|
|
if layout_col_indices.dim() == 1: |
|
layout_crow_indices = layout_crow_indices[None].expand(q.shape[1] , -1) |
|
layout_col_indices = layout_col_indices[None].expand(q.shape[1] , -1) |
|
|
|
assert q.shape[-1] in [64, 128] |
|
BLOCK_DMODEL = 64 |
|
|
|
if num_warps is None: |
|
MIN_D = min(BLOCK_M, BLOCK_N, BLOCK_DMODEL) |
|
num_warps = max(1, 2 ** int(math.log2(MIN_D / 16))) |
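        # heuristic: roughly one warp per 16 elements of the smallest tile dimension,
        # rounded to a power of 2 (e.g., 4 warps for 64, 8 warps for 128).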
|
|
|
else: |
|
        assert math.log2(num_warps) % 1 == 0, f'''"num_warps" should be a power of 2, but got {num_warps}.'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with torch.cuda.device(q.device.index): |
|
_fwd_kernel[grid]( |
|
q, k, v, sm_scale, |
|
layout_crow_indices, |
|
layout_col_indices, |
|
layout_crow_indices.stride(0), layout_crow_indices.stride(1), |
|
layout_col_indices.stride(0), layout_col_indices.stride(1), |
|
tmp, L, m, |
|
o, |
|
q.stride(0), q.stride(1), q.stride(2), q.stride(3), |
|
k.stride(0), k.stride(1), k.stride(2), k.stride(3), |
|
v.stride(0), v.stride(1), v.stride(2), v.stride(3), |
|
o.stride(0), o.stride(1), o.stride(2), o.stride(3), |
|
q.shape[0], q.shape[1], k.shape[2], |
|
k.shape[2] - q.shape[2], |
|
q_rounded_len, |
|
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, |
|
BLOCK_DMODEL=BLOCK_DMODEL, |
|
EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0, |
|
EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0 , |
|
INFERENCE=inference, |
|
NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL, |
|
num_warps=num_warps, |
|
num_stages=num_stages, |
|
) |
|
if inference: |
|
L, m = None, None |
|
|
|
ctx.save_for_backward(q, k, v, o, L, m, layout_crow_indices, layout_col_indices) |
|
ctx.BLOCK_M = BLOCK_M |
|
ctx.BLOCK_N = BLOCK_N |
|
ctx.BLOCK_DMODEL = BLOCK_DMODEL |
|
|
|
ctx.grid = grid |
|
ctx.sm_scale = sm_scale |
|
ctx.num_warps = num_warps |
|
ctx.num_stages = num_stages |
|
return o |
|
|
|
|
|
def _backward(ctx, do, layout_ccol_indices, layout_row_indices, dq=None, dk=None, dv=None): |
|
|
|
q, k, v, o, l, m, layout_crow_indices, layout_col_indices = ctx.saved_tensors |
|
|
|
|
|
|
|
|
|
|
|
if not do.is_contiguous(): |
|
do = do.contiguous() |
|
|
|
|
|
|
|
|
|
if not o.is_contiguous(): |
|
|
|
raise ValueError(f'--> output is not contiguous: {o.stride()=}. This is maybe caused by q/k/v not being contiguous.') |
|
|
|
|
|
if layout_ccol_indices.dim() == 1: |
|
layout_ccol_indices = layout_ccol_indices[None].expand(q.shape[1], -1) |
|
layout_row_indices = layout_row_indices[None].expand(q.shape[1], -1) |
|
|
|
|
|
dq = dq if dq is not None else torch.zeros_like(q, dtype=torch.float32) |
|
dk = dk if dk is not None else torch.empty_like(k) |
|
    dv = dv if dv is not None else torch.empty_like(v)
|
do_scaled = torch.empty_like(do) |
|
delta = torch.empty_like(l) |
|
|
|
assert o.stride() == dq.stride() == dk.stride() == dv.stride() == do_scaled.stride() |
|
|
|
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( |
|
o, do, l, |
|
do_scaled, delta, |
|
k.shape[2], |
|
BLOCK_M=ctx.BLOCK_M, D_HEAD=q.shape[-1], |
|
) |
|
|
|
grid = (triton.cdiv(q.shape[2], ctx.BLOCK_N), ctx.grid[1]) |
|
|
|
_bwd_kernel[grid]( |
|
q, k, v, ctx.sm_scale, |
|
layout_ccol_indices, |
|
layout_row_indices, |
|
layout_ccol_indices.stride(0), layout_ccol_indices.stride(1), |
|
layout_row_indices.stride(0), layout_row_indices.stride(1), |
|
o, do_scaled, |
|
dq, dk, dv, |
|
l, m, |
|
delta, |
|
q.stride(0), q.stride(1), q.stride(2), q.stride(3), |
|
k.stride(0), k.stride(1), k.stride(2), k.stride(3), |
|
v.stride(0), v.stride(1), v.stride(2), v.stride(3), |
|
o.stride(0), o.stride(1), o.stride(2), o.stride(3), |
|
q.shape[0], q.shape[1], q.shape[2], |
|
ctx.grid[0], |
|
BLOCK_M=ctx.BLOCK_M, |
|
BLOCK_N=ctx.BLOCK_N, |
|
BLOCK_DMODEL=ctx.BLOCK_DMODEL, |
|
NUM_DBLOCKS=q.shape[-1] // ctx.BLOCK_DMODEL, |
|
num_warps=ctx.num_warps, |
|
num_stages=1, |
|
) |
|
return dq, dk, dv, None, None, None |
|
|
|
|
|
class _sparse_attention(torch.autograd.Function): |
|
|
|
@staticmethod |
|
def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale): |
|
BLOCK = 128 |
|
|
|
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK, BLOCK) |
|
|
|
@staticmethod |
|
def backward(ctx, do): |
|
|
|
q, k, v, o, l, m, layout_crow_indices, layout_col_indices = ctx.saved_tensors |
|
|
|
|
|
layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(layout_crow_indices, layout_col_indices)) |
|
return _backward(ctx, do, layout_ccol_indices, layout_row_indices) |
|
|
|
|
|
|
|
|
|
class _sparse_attention_inference(_sparse_attention): |
|
|
|
@staticmethod |
|
def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale): |
|
BLOCK = 128 |
|
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, 1, BLOCK) |
|
|
|
|
|
|
|
def sparse_attention_factory(BLOCK_M=128, BLOCK_N=128, **kwargs): |
|
class _sparse_attention_config(_sparse_attention): |
|
@staticmethod |
|
def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale): |
|
|
|
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N, |
|
**kwargs |
|
) |
|
return _sparse_attention_config.apply |
|
|
|
|
|
@lru_cache(maxsize=8) |
|
def get_local_strided_sparse_attention_op( |
|
n_heads: int, |
|
max_seq_len:int, |
|
sparse_block_size: int=128, |
|
local_blocks: int=4, |
|
vert_stride: int=4, |
|
homo_head: bool=False, |
|
dtype=torch.bfloat16, |
|
device='cuda', |
|
active_head_range=None, |
|
verbose=True, |
|
**kwargs): |
|
''' |
|
    :param n_heads: total number of attention heads (regardless of tensor/model parallelism)

    :param max_seq_len: max sequence length. Must be greater than or equal to the length of the sequences.

    :param sparse_block_size: sparse block size. Defaults to 128.

    :param local_blocks: number of nearest blocks to attend to. Defaults to 4, i.e., attend to the previous 4 x block_size tokens.

    :param vert_stride: vertical stride of the strided ("global") blocks: every `vert_stride`-th block is attended to. Defaults to 4.

    :param homo_head: whether all heads share the same sparse pattern.

    :param active_head_range: tuple of (start, end) head indices, e.g., (8, 16). Defaults to using all heads.

        Mainly for tensor/model parallelism, where heads are split across different GPUs.
|
''' |
|
|
|
if verbose: |
|
print((f'> new block_sparse_attn op constructed with config: ' |
|
f'{n_heads=}, {max_seq_len=}, {sparse_block_size=}, {local_blocks=}, ' |
|
f'{vert_stride=}, {homo_head=}, {active_head_range=}, {kwargs=}')) |
|
|
|
_, block_sparse_pattern, _ = _get_sparse_attn_mask(n_heads, max_seq_len, max_seq_len, dtype, device, |
|
BLOCK=sparse_block_size, local_blocks=local_blocks, |
|
vert_stride=vert_stride, homo_head=homo_head, |
|
return_dense=False) |
|
if (not homo_head) and (active_head_range is not None): |
|
assert isinstance(active_head_range, tuple) |
|
assert len(active_head_range) == 2, '"active_head_range" should be a tuple of start/end index of the heads.' |
|
h_start, h_end = active_head_range |
|
block_sparse_pattern = block_sparse_pattern[h_start:h_end] |
|
|
|
return get_sparse_attn_op(block_sparse_pattern, sparse_block_size, **kwargs) |
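

# Illustrative usage sketch (not part of the original API); requires a CUDA device, since the
# returned op launches triton kernels. Head count, sequence length and dtype below are arbitrary.
def _example_local_strided_attention():
    attn = get_local_strided_sparse_attention_op(
        n_heads=8, max_seq_len=2048, sparse_block_size=128, local_blocks=4, vert_stride=4,
        homo_head=False, dtype=torch.bfloat16, device='cuda')
    q = torch.randn(2, 8, 2048, 128, dtype=torch.bfloat16, device='cuda')
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    sm_scale = 1.0 / math.sqrt(q.size(-1))
    return attn(q, k, v, sm_scale)  # same shape as q: (2, 8, 2048, 128)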
|
|
|
|
|
def get_sparse_attn_op( |
|
sparse_pattern: torch.tensor, |
|
sparse_block_size: int=128, |
|
kernel_block_size=128, |
|
qkv_format='q,k,v', |
|
**kwargs): |
|
''' |
|
    Create a block-sparse op with a fixed layout. This avoids having to create the CSR layout and convert it to CSC on every call,

    which is very inefficient (it uses Python loops on the CPU; PyTorch 1.13 supports CSR->CSC conversion, which may help).



    :param sparse_pattern: sparse pattern of the blocks. Should be `num_blocks(q) x num_blocks(k)` or `n_heads x num_blocks x num_blocks`.

        This tensor should be lower triangular on the last 2 dimensions for causal attention.

    :param sparse_block_size: sparse block size. Defaults to 128.

    :param kernel_block_size: the tile/block size used to launch a triton instance. Defaults to 128; if None, falls back to `sparse_block_size`.

    :param qkv_format: choices=['q,k,v', 'q,kv', 'qkv'], i.e., separate q/k/v, packed kv, or packed qkv. Currently only 'q,k,v' is supported.



    :param kwargs: keyword arguments passed to `_forward`
|
''' |
|
|
|
|
|
assert qkv_format == 'q,k,v' |
|
|
|
if kernel_block_size is None: |
|
kernel_block_size = sparse_block_size |
|
else: |
|
        assert sparse_block_size % kernel_block_size == 0, f"sparse_block_size ({sparse_block_size}) must be a multiple of kernel_block_size ({kernel_block_size})."

        assert kernel_block_size >= 16 and math.log2(kernel_block_size) % 1 == 0, f"kernel_block_size must be a power of 2 and at least 16, but {kernel_block_size} was given"
|
|
|
|
|
|
|
|
|
if sparse_block_size // kernel_block_size > 1: |
|
_mul = sparse_block_size // kernel_block_size |
|
|
|
sparse_pattern = torch.kron(sparse_pattern, sparse_pattern.new_ones(_mul, _mul)) |
|
num_sparse_blocks = sparse_pattern.size(-1) |
|
block_causal_mask = torch.arange(0, num_sparse_blocks)[:, None] >= torch.arange(0, num_sparse_blocks)[None] |
|
sparse_pattern *= block_causal_mask.type_as(sparse_pattern) |
|
|
|
|
|
|
|
BLOCK_N = kernel_block_size |
|
NUM_BLOCK = sparse_pattern.size(-1) |
|
MAX_SEQ_LEN = kernel_block_size * NUM_BLOCK |
|
|
|
grand_layout_crow_indices, grand_layout_col_indices = dense_to_crow_col(sparse_pattern) |
|
|
|
grand_layout_ccol_indices, grand_layout_row_indices = dense_to_ccol_row(sparse_pattern) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
max_cache_size = 1 if kwargs.get('inference', False) else 8 |
|
|
|
@lru_cache(maxsize=max_cache_size) |
|
def get_backward_layout_by_block_len(block_len): |
|
assert block_len <= NUM_BLOCK |
|
if block_len == NUM_BLOCK: |
|
return (grand_layout_ccol_indices, grand_layout_row_indices) |
|
return dense_to_ccol_row(sparse_pattern[..., :block_len, :block_len]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _q_k_v_sparse_attention(torch.autograd.Function): |
|
@staticmethod |
|
def forward(ctx, q, k, v, sm_scale): |
|
|
|
|
|
MIN_BLOCK_SIZE = 16 |
|
assert BLOCK_N >= MIN_BLOCK_SIZE |
|
BLOCK_M = 16 if q.shape[2] <= 16 else BLOCK_N |
|
|
|
|
|
K_BLOCKS = triton.cdiv(k.shape[2], kernel_block_size) |
|
|
|
Q_START_BLOCKS = K_BLOCKS - triton.cdiv(q.shape[2], BLOCK_N) |
|
|
|
|
|
layout_crow_indices = grand_layout_crow_indices[..., Q_START_BLOCKS:K_BLOCKS+1] |
|
layout_col_indices = grand_layout_col_indices |
|
|
|
|
|
return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N, |
|
**kwargs |
|
) |
|
@staticmethod |
|
def backward(ctx, do): |
|
q, k = ctx.saved_tensors[:2] |
|
assert q.shape[2] == k.shape[2], '> currently backward can only be done if q, k have same length. Contact @EricLin if you need it.' |
|
|
|
block_len = triton.cdiv(do.shape[2], kernel_block_size) |
|
backward_layout = get_backward_layout_by_block_len(block_len) |
|
return _backward(ctx, do, *backward_layout)[:4] |
|
|
|
|
|
def _q_k_v_sparse_attention_fn(*args): |
|
return _q_k_v_sparse_attention.apply(*args) |
|
|
|
_q_k_v_sparse_attention_fn.sparse_pattern = sparse_pattern |
|
_q_k_v_sparse_attention_fn.grand_layout_crow_indices = grand_layout_crow_indices |
|
_q_k_v_sparse_attention_fn.grand_layout_col_indices = grand_layout_col_indices |
|
_q_k_v_sparse_attention_fn.grand_layout_ccol_indices = grand_layout_ccol_indices |
|
_q_k_v_sparse_attention_fn.grand_layout_row_indices = grand_layout_row_indices |
|
|
|
return _q_k_v_sparse_attention_fn |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def blocksparse_flash_attn_padded_fwd( |
|
q, k, v, |
|
sm_scale, |
|
sparse_layout, |
|
*, |
|
left_paddings = None, |
|
seqlens = None, |
|
block_size = 64, |
|
max_seqlen = None |
|
): |
|
''' |
|
q, k, v: (batch, tokens, n_heads/n_kv_heads, head_size) |
|
left_paddings: (batch, ), number of left paddings for each sample. |
|
    seqlens: (batch, ), number of real (non-padded) tokens per sample, i.e., used to specify right padding. Not needed if only left padding is used (sequences are then assumed to run to the end of k).
|
''' |
|
batches, q_len, n_heads, head_size = q.shape |
|
_, k_len, n_kv_heads, _ = k.shape |
|
|
|
|
|
assert q.dim() == k.dim() == v.dim() == 4 |
|
assert q.size(2) % k.size(2) == 0 |
|
assert q.size(0) == k.size(0) and q.size(3) == k.size(3) |
|
assert k.shape == v.shape |
|
assert q_len == 1 or q_len == k_len, \ |
|
        'q length can only be 1 (decoding) or the same as k length (prefilling).'
|
|
|
q_k_ratio = q.size(2) // k.size(2) |
|
|
|
if max_seqlen: |
|
assert k.size(1) <= max_seqlen, f'k has seqlen {k.size(1)} while max sequence length is set to {max_seqlen}.' |
|
|
|
|
|
out = q.new_zeros(q.shape) |
|
|
|
layout_crow_indices, layout_col_indices = sparse_layout |
|
block_d = triton.next_power_of_2(head_size) |
|
|
|
if left_paddings is not None: |
|
assert left_paddings.shape == (batches,) |
|
k_batch_starts = left_paddings.to(q.device, dtype=torch.int32).contiguous() |
|
else: |
|
k_batch_starts = torch.zeros((batches,), dtype=torch.int32, device=q.device) |
|
|
|
if seqlens is not None: |
|
k_batch_ends = k_batch_starts + seqlens.type_as(k_batch_starts) |
|
assert k_batch_ends.max() <= k_len, f'seqlens (+left_paddings if any) exceeds seqlen.' |
|
else: |
|
k_batch_ends = torch.zeros_like(k_batch_starts) + k_len |
|
|
|
if q_len == 1: |
|
q_batch_starts = torch.zeros_like(k_batch_starts) |
|
q_batch_ends = q_batch_starts + 1 |
|
else: |
|
q_batch_starts = k_batch_starts |
|
q_batch_ends = k_batch_ends |
|
|
|
|
|
q_lens = (q_batch_ends - q_batch_starts).cpu() |
|
n_blocks = (q_lens + block_size - 1) // block_size |
|
|
|
q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)], |
|
dtype=q_batch_starts.dtype, |
|
device=q_batch_starts.device) |
|
q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)], |
|
dtype=q_batch_starts.dtype, |
|
device=q_batch_starts.device) |
|
|
|
grid = (len(q_start_sids), n_heads) |
|
|
|
with torch.cuda.device(q.device.index): |
|
_fwd_kernel_batch_inference[grid]( |
|
q, k, v, out, |
|
sm_scale, |
|
q_batch_starts, |
|
q_batch_ends, |
|
k_batch_starts, |
|
k_batch_ends, |
|
q_batch_ids, |
|
q_start_sids, |
|
|
|
*q.stride(), |
|
*k.stride(), |
|
*v.stride(), |
|
*out.stride(), |
|
|
|
layout_crow_indices, |
|
layout_col_indices, |
|
*layout_crow_indices.stride(), |
|
*layout_col_indices.stride(), |
|
|
|
q_k_ratio, |
|
HAS_BATCH_DIM = True, |
|
D_HEAD = head_size, |
|
BLOCK_M = block_size, |
|
BLOCK_N = block_size, |
|
BLOCK_D = block_d, |
|
BLOCK_M_LOADING = 16 if q_len == 1 else block_size, |
|
EVEN_D = block_d == head_size, |
|
num_warps = 1 if q_len == 1 else 4, |
|
num_stages = 1 |
|
) |
|
|
|
|
|
return out |
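

# Illustrative sketch (not part of the original API); requires a CUDA device. It builds a per-head
# block layout with the helpers above and runs a left-padded prefill batch through the padded kernel.
# All sizes below are arbitrary example values.
def _example_padded_inference():
    batch, n_heads, seq_len, head_size, block = 2, 8, 256, 64, 64
    q = torch.randn(batch, seq_len, n_heads, head_size, dtype=torch.float16, device='cuda')
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # per-head CSR block layout over seq_len/block block rows; BLOCK must match block_size below
    sparse_layout, _, _ = _get_sparse_attn_mask(
        n_heads, seq_len, seq_len, torch.float16, 'cuda',
        BLOCK=block, local_blocks=2, vert_stride=4, homo_head=False, return_dense=False)
    # second sample: its first 64 token positions in q/k/v are left padding
    left_paddings = torch.tensor([0, 64], dtype=torch.int32, device='cuda')
    return blocksparse_flash_attn_padded_fwd(
        q, k, v, 1.0 / math.sqrt(head_size), sparse_layout,
        left_paddings=left_paddings, block_size=block)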
|
|
|
|
|
def blocksparse_flash_attn_varlen_fwd( |
|
q, k, v, |
|
cu_seqlens_k, |
|
cu_seqlens_q, |
|
sm_scale, |
|
sparse_layout, |
|
*, |
|
block_size=64, |
|
max_seqlen = None |
|
): |
|
|
|
_, n_heads, head_size = q.shape |
|
batch_size = cu_seqlens_k.size(0) - 1 |
|
|
|
|
|
|
|
assert q.dim() == k.dim() == v.dim() == 3 |
|
assert q.size(1) % k.size(1) == 0 |
|
assert q.size(2) == k.size(2) |
|
assert k.shape == v.shape |
|
assert cu_seqlens_k.dim() == 1 |
|
|
|
q_k_ratio = q.size(1) // k.size(1) |
|
|
|
if cu_seqlens_q is None: |
|
if q.size(0) == batch_size: |
|
cu_seqlens_q = torch.arange(0, batch_size + 1, |
|
dtype=cu_seqlens_k.dtype, |
|
device=cu_seqlens_k.device) |
|
elif q.size(0) == k.size(0): |
|
cu_seqlens_q = cu_seqlens_k |
|
else: |
|
            raise ValueError('cu_seqlens_q must be specified if the batch mixes prefilling and decoding.')
|
else: |
|
assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0) |
|
|
|
|
|
q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu() |
|
k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu() |
|
|
|
assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), \ |
|
        'length of q should either be 1 (decoding) or the same as k (prefilling).'
|
|
|
if max_seqlen: |
|
assert k_lens.max() <= max_seqlen |
|
|
|
n_blocks = (q_lens + block_size - 1) // block_size |
|
|
|
q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)], |
|
dtype=cu_seqlens_q.dtype, |
|
device=cu_seqlens_q.device) |
|
q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)], |
|
dtype=cu_seqlens_q.dtype, |
|
device=cu_seqlens_q.device) |
|
|
|
|
|
out = q.new_empty(q.shape) |
|
cu_seqlens_q = cu_seqlens_q.contiguous() |
|
cu_seqlens_k = cu_seqlens_k.contiguous() |
|
|
|
layout_crow_indices, layout_col_indices = sparse_layout |
|
block_d = triton.next_power_of_2(head_size) |
|
|
|
decoding_only = (q_lens == 1).all() |
|
|
|
grid = (len(q_start_sids), n_heads) |
|
|
|
with torch.cuda.device(q.device.index): |
|
_fwd_kernel_batch_inference[grid]( |
|
q, k, v, out, |
|
sm_scale, |
|
cu_seqlens_q[:-1], |
|
cu_seqlens_q[1:], |
|
cu_seqlens_k[:-1], |
|
cu_seqlens_k[1:], |
|
q_batch_ids, |
|
q_start_sids, |
|
|
|
0, *q.stride(), |
|
0, *k.stride(), |
|
0, *v.stride(), |
|
0, *out.stride(), |
|
|
|
layout_crow_indices, |
|
layout_col_indices, |
|
*layout_crow_indices.stride(), |
|
*layout_col_indices.stride(), |
|
|
|
q_k_ratio, |
|
HAS_BATCH_DIM = False, |
|
D_HEAD = head_size, |
|
BLOCK_M = block_size, |
|
BLOCK_N = block_size, |
|
BLOCK_D = block_d, |
|
BLOCK_M_LOADING = 16 if decoding_only else block_size, |
|
EVEN_D = block_d == head_size, |
|
num_warps = 1 if decoding_only else 4, |
|
num_stages = 3 |
|
) |
|
|
|
return out |
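

# cu_seqlens layout (illustrative): for three sequences of lengths 5, 3 and 7 packed along the
# token dimension, cu_seqlens_k = torch.tensor([0, 5, 8, 15], dtype=torch.int32). For a pure
# decoding batch (one query token per sequence), cu_seqlens_q may be left as None and is then
# derived as arange(0, batch_size + 1) in the function above.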
|
|
|
|
|
@triton.jit |
|
def _fwd_kernel_inner( |
|
acc, l_i, m_i, |
|
q, Q, |
|
k_block_col_idx, |
|
layout_col_ptr, |
|
layout_col_stride_h, layout_col_stride_m, |
|
k_ptrs, |
|
v_ptrs, |
|
off_h, offs_m, offs_n, offs_d, |
|
stride_kt, stride_vt, |
|
sm_scale, |
|
k_seqlen, |
|
past_len, |
|
LAST_K_BLOCK: tl.constexpr, |
|
BLOCK_M_LOADING: tl.constexpr, |
|
BLOCK_N: tl.constexpr, |
|
D_HEAD: tl.constexpr, |
|
EVEN_D: tl.constexpr, |
|
M_LT_N: tl.constexpr |
|
): |
|
k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + k_block_col_idx * layout_col_stride_m).to(tl.int32) |
|
start_n = k_block_id * BLOCK_N |
|
|
|
if LAST_K_BLOCK: |
|
if EVEN_D: |
|
k = tl.load(k_ptrs + start_n * stride_kt, |
|
mask=offs_n[None, :] + start_n < k_seqlen) |
|
else: |
|
|
|
k = tl.load(k_ptrs + start_n * stride_kt, |
|
mask=(offs_n[None, :] + start_n < k_seqlen) & (offs_d[:, None] < D_HEAD)) |
|
else: |
|
if EVEN_D: |
|
k = tl.load(k_ptrs + start_n * stride_kt) |
|
else: |
|
k = tl.load(k_ptrs + start_n * stride_kt, |
|
mask=offs_d[:, None] < D_HEAD) |
|
|
|
|
|
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) |
|
qk += tl.dot(q, k) |
|
|
|
qk *= sm_scale |
|
|
|
|
|
if LAST_K_BLOCK | M_LT_N: |
|
qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf')) |
|
|
|
|
|
m_ij = tl.max(qk, 1) |
|
p = tl.exp(qk - m_ij[:, None]) |
|
|
|
l_ij = tl.sum(p, 1) |
|
|
|
m_i_new = tl.maximum(m_i, m_ij) |
|
alpha = tl.exp(m_i - m_i_new) |
|
beta = tl.exp(m_ij - m_i_new) |
|
l_i_new = alpha * l_i + beta * l_ij |
|
|
|
|
|
p_scale = beta / l_i_new |
|
p = p * p_scale[:, None] |
|
|
|
acc_scale = l_i / l_i_new * alpha |
|
acc = acc * acc_scale[:, None] |
|
|
|
p = p.to(Q.dtype.element_ty) |
|
|
|
if LAST_K_BLOCK: |
|
if EVEN_D: |
|
v = tl.load(v_ptrs + start_n * stride_vt, |
|
mask=offs_n[:, None] + start_n < k_seqlen) |
|
else: |
|
v = tl.load(v_ptrs + start_n * stride_vt, |
|
mask=(offs_n[:, None] + start_n < k_seqlen) & (offs_d[None, :] < D_HEAD)) |
|
else: |
|
if EVEN_D: |
|
v = tl.load(v_ptrs + start_n * stride_vt) |
|
else: |
|
v = tl.load(v_ptrs + start_n * stride_vt, |
|
mask=offs_d[None, :] < D_HEAD) |
|
|
|
acc += tl.dot(p, v) |
|
|
|
l_i = l_i_new |
|
m_i = m_i_new |
|
return acc, l_i, m_i |
|
|
|
|
|
@triton.heuristics( |
|
{ |
|
'M_LT_N': lambda kwargs: kwargs['BLOCK_M'] < kwargs['BLOCK_N'], |
|
} |
|
) |
|
@triton.jit |
|
def _fwd_kernel_batch_inference( |
|
Q, K, V, Out, |
|
|
|
sm_scale, |
|
q_batch_starts, |
|
q_batch_ends, |
|
k_batch_starts, |
|
k_batch_ends, |
|
q_batch_ids, |
|
q_start_sids, |
|
|
|
stride_qb, stride_qt, stride_qh, stride_qd, |
|
stride_kb, stride_kt, stride_kh, stride_kd, |
|
stride_vb, stride_vt, stride_vh, stride_vd, |
|
stride_ob, stride_ot, stride_oh, stride_od, |
|
|
|
layout_crow_ptr, |
|
layout_col_ptr, |
|
layout_crow_stride_h, layout_crow_stride_m, |
|
layout_col_stride_h, layout_col_stride_m, |
|
|
|
q_k_ratio, |
|
|
|
HAS_BATCH_DIM: tl.constexpr, |
|
D_HEAD: tl.constexpr, |
|
BLOCK_M: tl.constexpr, |
|
BLOCK_N: tl.constexpr, |
|
BLOCK_D: tl.constexpr, |
|
BLOCK_M_LOADING: tl.constexpr, |
|
EVEN_D: tl.constexpr, |
|
M_LT_N: tl.constexpr |
|
): |
|
''' |
|
NOTATION: |
|
pid: position id |
|
sid: storage id |
|
sbid: storage block id |
|
pbid: position block id |
|
offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col) |
|
|
|
    q and the blocks in KV need to be contiguous
|
|
|
Arguments: |
|
        kv_seq_lens: used to compute past_len
|
kv_storage_offsets: similar to block_tables in vllm, except it is dynamic. |
|
TODO: fix this |
|
|
|
TODO: |
|
Optimize grouped-attn |
|
|
|
CUDA graph support issue |
|
        1. the grid is dynamic: vllm sets up multiple cuda graphs in the decoding phase, with different max token sizes (16, 32, ...).

           Since we mix the prompt and decoding phases here, it can be more complex:

           we would need to set up a different cuda graph for each (off_zm, off_z).



           Indeed, q_batch_ids could be padded up to the maximum grid[0] (i.e., assuming everything is decoding);

           therefore, cu_seqlens_q and kv_seq_lens would suffice.
|
|
|
''' |
|
off_zm = tl.program_id(0) |
|
off_h = tl.program_id(1) |
|
|
|
off_h_for_kv = off_h // q_k_ratio |
|
off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) |
|
q_start_sid = tl.load(q_start_sids + off_zm) |
|
start_m = q_start_sid // BLOCK_M |
|
|
|
if HAS_BATCH_DIM: |
|
Q += off_z * stride_qb |
|
K += off_z * stride_kb |
|
V += off_z * stride_vb |
|
Out += off_z * stride_ob |
|
|
|
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING) |
|
offs_n = tl.arange(0, BLOCK_N) |
|
offs_d = tl.arange(0, BLOCK_D) |
|
|
|
q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32) |
|
q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start |
|
|
|
k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32) |
|
k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start |
|
|
|
past_len = k_seqlen - q_seqlen |
|
|
|
Q += q_cu_start * stride_qt + off_h * stride_qh |
|
K += k_cu_start * stride_kt + off_h_for_kv * stride_kh |
|
V += k_cu_start * stride_vt + off_h_for_kv * stride_vh |
|
Out += q_cu_start * stride_ot + off_h * stride_oh |
|
|
|
q_pbid = (past_len + q_start_sid) // BLOCK_M |
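    # q_pbid: position block id of this tile, i.e., the row of the sparse layout to read;
    # offset by past_len so that decoding (q_len == 1) indexes the correct block row.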
|
|
|
if EVEN_D: |
|
q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, |
|
mask=offs_m[:, None] < q_seqlen) |
|
else: |
|
q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, |
|
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), |
|
other=0) |
|
|
|
sparse_crow_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + q_pbid * layout_crow_stride_m |
|
|
|
|
|
k_block_start = tl.load(sparse_crow_ptr).to(tl.int32) |
|
k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32) |
|
|
|
m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float('inf') |
|
l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) |
|
acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32) |
|
|
|
k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd |
|
v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd |
|
|
|
for k_block_col_idx in range(k_block_start, k_block_end - 1): |
|
acc, l_i, m_i = _fwd_kernel_inner( |
|
acc, l_i, m_i, |
|
q, Q, |
|
k_block_col_idx, |
|
layout_col_ptr, |
|
layout_col_stride_h, layout_col_stride_m, |
|
k_ptrs, |
|
v_ptrs, |
|
off_h, offs_m, offs_n, offs_d, |
|
stride_kt, stride_vt, |
|
sm_scale, |
|
k_seqlen, |
|
past_len, |
|
False, |
|
BLOCK_M_LOADING, |
|
BLOCK_N, |
|
D_HEAD, |
|
EVEN_D, |
|
M_LT_N |
|
) |
|
|
|
acc, l_i, m_i = _fwd_kernel_inner( |
|
acc, l_i, m_i, |
|
q, Q, |
|
k_block_end - 1, |
|
layout_col_ptr, |
|
layout_col_stride_h, layout_col_stride_m, |
|
k_ptrs, |
|
v_ptrs, |
|
off_h, offs_m, offs_n, offs_d, |
|
stride_kt, stride_vt, |
|
sm_scale, |
|
k_seqlen, |
|
past_len, |
|
True, |
|
BLOCK_M_LOADING, |
|
BLOCK_N, |
|
D_HEAD, |
|
EVEN_D, |
|
M_LT_N |
|
) |
|
|
|
|
|
if EVEN_D: |
|
tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc, |
|
mask=offs_m[:, None] < q_seqlen) |
|
else: |
|
tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc, |
|
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def torch_attention(q, k, v, attn_mask=None, sm_scale=None, block_attn_mask=None, block_size=128, do=None): |
|
''' |
|
q, k, v: shape=(batch, n_heads, seq, dim) |
|
''' |
|
|
|
if sm_scale is None: |
|
        sm_scale = 1. / math.sqrt(float(q.size(-1)))  # standard 1/sqrt(head_dim) softmax scaling
|
|
|
if block_attn_mask is not None: |
|
assert attn_mask is None |
|
outs = [] |
|
for s in range(0, q.size(2), block_size): |
|
e = min(s + block_size, q.size(2)) |
|
q_block = q[:, :, s:e] |
|
attn = torch.einsum('bhmd,bhnd->bhmn', q_block, k[:, :, :e]).float() * sm_scale |
|
mask = block_attn_mask[..., s // block_size, : (s // block_size + 1)] |
|
mask = torch.kron(mask, torch.ones(block_size, block_size, device=mask.device)) |
|
mask[..., :, s:].masked_fill_(torch.arange(0, block_size)[:, None] <= torch.arange(0, block_size)[None, :], 0) |
|
attn = attn.masked_fill((1 - mask).bool(), float('-inf')) |
|
attn = attn.softmax(-1) |
|
out = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v[:, :, :e]) |
|
outs.append(out) |
|
torch_output = torch.cat(outs, dim=2) |
|
else: |
|
attn = torch.einsum('bhmd,bhnd->bhmn', q, k).float() * sm_scale |
|
|
|
if attn_mask is not None: |
|
attn = attn.masked_fill((1 - attn_mask).bool(), float('-inf')) |
|
|
|
|
|
attn = attn.softmax(-1) |
|
if do is not None: |
|
dv = torch.einsum('bhqk,bhqd->bhkd', attn.type_as(do), do) |
|
print(f'> torch_attn computed dv: {dv=}') |
|
torch_output = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v) |
|
return torch_output |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 8, 2048, 128), (1, 4, 4096, 64)]) |
|
def test_op(Z, H, N_CTX, D_HEAD, Q_LEN=None, dtype=torch.bfloat16, homo_head=True, kernel_block_size=None, sparse_block_size=128, backward=True, |
|
sparse_attention_fn=None, local_blocks=4, vert_stride=4, sm_scale=None, max_length=None): |
|
Q_LEN = Q_LEN or N_CTX |
|
torch.manual_seed(20) |
|
q = torch.empty((Z, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) |
|
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) |
|
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) |
|
|
|
if sm_scale is None: |
|
sm_scale = 1. / math.sqrt(D_HEAD) |
|
|
|
|
|
|
|
|
if backward: |
|
q.requires_grad_(), k.requires_grad_(), v.requires_grad_() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dout = torch.randn_like(q).contiguous() |
|
|
|
|
|
|
|
|
|
mask_csr, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=sparse_block_size, |
|
local_blocks=local_blocks, vert_stride=vert_stride, homo_head=homo_head, return_dense=True) |
|
|
|
if sparse_attention_fn is None: |
|
sparse_attention_fn = get_local_strided_sparse_attention_op(H, N_CTX, |
|
sparse_block_size=sparse_block_size, |
|
local_blocks=local_blocks, |
|
vert_stride=vert_stride, |
|
homo_head=homo_head, |
|
device=q.device, |
|
dtype=q.dtype, |
|
kernel_block_size=kernel_block_size) |
|
|
|
ref_out = torch_attention(q, k, v, mask_dense, sm_scale) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if backward: |
|
ref_out.backward(dout) |
|
ref_dv, v.grad = v.grad.clone(), None |
|
ref_dk, k.grad = k.grad.clone(), None |
|
ref_dq, q.grad = q.grad.clone(), None |
|
|
|
tri_out = sparse_attention_fn(q, k, v, sm_scale) |
|
|
|
decimal = 1 if dtype == torch.bfloat16 else 2 |
|
assert torch.allclose(ref_out.cpu(), tri_out.cpu(), atol=1e-2, rtol=0), f'>> {ref_out[0, 0, :, 0].tolist()=}\n\n{tri_out[0, 0, :, 0].tolist()=}' |
|
|
|
if backward: |
|
tri_out.backward(dout) |
|
tri_dv, v.grad = v.grad.clone(), None |
|
tri_dk, k.grad = k.grad.clone(), None |
|
tri_dq, q.grad = q.grad.clone(), None |
|
|
|
if backward: |
|
assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=1e-2) |
|
assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0) |
|
assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0) |
|
|
|
print(f'> test passed: {Z=}, {H=}, {N_CTX=}, {D_HEAD=}, {Q_LEN=}, {dtype=}, {homo_head=}, {sparse_block_size=}') |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
GPU_TYPE = os.popen('nvidia-smi --query-gpu=name --format=csv | tail -n 1').read().strip() |
|
|
|
support_backward = True |
|
|
|
|
|
|
|
|
|
HAS_DENSE_TRITON_FLASH = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_func |
|
HAS_FLASH = True |
|
except BaseException: |
|
HAS_FLASH = False |
|
print('> cannot import flash_attn') |
|
|
|
|
|
|
|
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 32, 4096, 128 |
|
|
|
|
|
BLOCK_SIZE = 64 |
|
LOCAl_BLOCKS = 8 |
|
VERT_STRIDE = 1 |
|
HOMO_HEAD = False |
|
    sparse_type = 'homo' if HOMO_HEAD else 'hetero'
|
dtype = torch.bfloat16 |
|
|
|
|
|
modes = ['fwd', 'bwd'] if support_backward else ['fwd'] |
|
|
|
configs = [triton.testing.Benchmark( |
|
x_names=['SEQ_LEN'], |
|
x_vals=[2**i for i in range(8, 16)], |
|
line_arg='provider', |
|
line_vals=(['triton'] if HAS_DENSE_TRITON_FLASH else []) + (['flash'] if HAS_FLASH else []) + ['triton_sparse'], |
|
line_names=(['Triton-Dense'] if HAS_DENSE_TRITON_FLASH else []) + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse'], |
|
styles=[('red', '-'), ('blue', '-'), ('green', '-')], |
|
ylabel='ms', |
|
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}-{dtype}-{mode}', |
|
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode} |
|
) for mode in modes] |
|
|
|
|
|
@triton.testing.perf_report(configs) |
|
def bench_flash_attention(BATCH, H, SEQ_LEN, D_HEAD, mode, provider, dtype=torch.bfloat16, device='cuda', sparse_attention_fn=None): |
|
assert mode in ['fwd', 'bwd'] |
|
warmup = 25 |
|
rep = 100 |
|
N_CTX = SEQ_LEN |
|
if provider == 'triton': |
|
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True) |
|
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True) |
|
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True) |
|
sm_scale = 1.3 |
|
fn = lambda: triton_attention(q, k, v, sm_scale) |
|
if mode == 'bwd': |
|
o = fn() |
|
do = torch.randn_like(o) |
|
fn = lambda: o.backward(do, retain_graph=True) |
|
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) |
|
return ms |
|
if provider == 'triton_sparse': |
|
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True) |
|
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True) |
|
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True) |
|
sm_scale = 1.3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if sparse_attention_fn is None: |
|
|
|
sparse_attention_fn = get_local_strided_sparse_attention_op(H, SEQ_LEN, |
|
local_blocks=LOCAl_BLOCKS, |
|
vert_stride=VERT_STRIDE, |
|
homo_head=HOMO_HEAD, |
|
sparse_block_size=BLOCK_SIZE, |
|
kernel_block_size=BLOCK_SIZE, |
|
device=q.device) |
|
|
|
|
|
|
|
fn = lambda: sparse_attention_fn(q, k, v, sm_scale) |
|
if mode == 'bwd': |
|
o = fn() |
|
do = torch.randn_like(o) |
|
fn = lambda: o.backward(do, retain_graph=True) |
|
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) |
|
return ms |
|
if provider == 'flash': |
|
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) |
|
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) |
|
cu_seqlens[1:] = lengths.cumsum(0) |
|
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) |
|
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True) |
|
if mode == 'bwd': |
|
o = fn() |
|
do = torch.randn_like(o) |
|
fn = lambda: o.backward(do, retain_graph=True) |
|
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) |
|
return ms |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
BATCH, N_HEADS, N_CTX, D_HEAD, Q_LEN = 4, 32, 4096, 128, 1 |
|
|
|
BLOCK_SIZE = 64 |
|
LOCAl_BLOCKS = 8 |
|
VERT_STRIDE = 16 |
|
HOMO_HEAD = False |
|
    sparse_type = 'homo' if HOMO_HEAD else 'hetero'
|
dtype = torch.bfloat16 |
|
MAX_N_CTX = 8192 |
|
|
|
configs = [triton.testing.Benchmark( |
|
x_names=['PAST_LEN'], |
|
x_vals=[2**i - 1 for i in range(8, 14)], |
|
line_arg='provider', |
|
line_vals=['torch'] + (['flash'] if HAS_FLASH else []) + ['triton_sparse', 'triton_dense'], |
|
line_names=['Torch'] + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse', 'Triton-Dense'], |
|
styles=[('red', '-'), ('blue', '-'), ('green', '-'), ('cyan', '-')], |
|
ylabel='ms', |
|
plot_name=f'fused-attention-inference-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}', |
|
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'Q_LEN': Q_LEN, 'dtype': torch.float16, 'mode': mode} |
|
) for mode in ['fwd']] |
|
@triton.testing.perf_report(configs) |
|
def bench_flash_attention_inference(BATCH, H, PAST_LEN, D_HEAD, Q_LEN, mode, provider, dtype=torch.bfloat16, device='cuda'): |
|
assert mode in ['fwd'] |
|
warmup = 25 |
|
rep = 100 |
|
N_CTX = PAST_LEN + Q_LEN |
|
if provider == 'torch': |
|
q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) |
|
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) |
|
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) |
|
sm_scale = 1.3 |
|
mask_csr, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK_SIZE, |
|
                                                           local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=HOMO_HEAD, return_dense=True)
|
|
|
fn = lambda: torch_attention(q, k, v, mask_dense, sm_scale=sm_scale, block_size=2048) |
|
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) |
|
return ms |
|
if provider == 'triton_sparse': |
|
q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) |
|
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) |
|
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) |
|
sm_scale = 1.3 |
|
sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX, |
|
local_blocks=LOCAl_BLOCKS, |
|
vert_stride=VERT_STRIDE, |
|
homo_head=HOMO_HEAD, |
|
sparse_block_size=BLOCK_SIZE, |
|
kernel_block_size=BLOCK_SIZE, |
|
device=q.device, |
|
inference=True) |
|
|
|
fn = lambda: sparse_attention_fn(q, k, v, sm_scale) |
|
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) |
|
return ms |
|
if provider == 'triton_dense': |
|
q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) |
|
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) |
|
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) |
|
sm_scale = 1.3 |
|
sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX, |
|
local_blocks=1, |
|
vert_stride=1, |
|
homo_head=True, |
|
sparse_block_size=BLOCK_SIZE, |
|
kernel_block_size=BLOCK_SIZE, |
|
device=q.device, |
|
inference=True) |
|
|
|
fn = lambda: sparse_attention_fn(q, k, v, sm_scale) |
|
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) |
|
return ms |
|
if provider == 'flash': |
|
assert Q_LEN == 1 |
|
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) |
|
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) |
|
cu_seqlens[1:] = lengths.cumsum(0) |
|
cu_seqlens_q = torch.arange(BATCH + 1, device=device, dtype=torch.int32) |
|
|
|
|
|
q = torch.randn((BATCH, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) |
|
k = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) |
|
v = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False) |
|
|
|
fn = lambda: flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens, 1, N_CTX, dropout_p=0, softmax_scale=1.3, causal=False) |
|
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) |
|
return ms |
|
|
|
|
|
test_op(1, 4, 512, 128, dtype=torch.float16, homo_head=False, backward=support_backward) |
|
|
|
|
|
bench_flash_attention_inference.run(save_path='.', print_data=True) |
|
    exit()  # NOTE: early exit; remove to also run the remaining correctness tests and the fwd/bwd benchmark below
|
|
|
test_op(1, 2, 1024, 64, kernel_block_size=64, sparse_block_size=64, |
|
dtype=torch.bfloat16, homo_head=False, backward=support_backward) |
|
|
|
test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, sparse_block_size=128, |
|
kernel_block_size=64, local_blocks=8, vert_stride=8) |
|
test_op(3, 2, 2047, 128, homo_head=False, backward=False) |
|
|
|
|
|
test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, kernel_block_size=64) |
|
|
|
|
|
|
|
|
|
test_op(1, 2, 1024, 128, kernel_block_size=128, sparse_block_size=128, dtype=torch.bfloat16, homo_head=False, |
|
backward=support_backward, local_blocks=1, vert_stride=1) |
|
|
|
|
|
test_op(1, 4, 512 + 256, 128, dtype=torch.float16, homo_head=False, backward=support_backward) |
|
|
|
|
|
test_op(2, 4, 8192, 64, homo_head=False, backward=support_backward) |
|
test_op(2, 4, 8192, 128, dtype=torch.bfloat16, homo_head=False, backward=support_backward) |
|
|
|
|
|
test_op(3, 2, 2048, 64, homo_head=True, dtype=torch.bfloat16, backward=False) |
|
test_op(3, 2, 2048, 64, homo_head=True, backward=support_backward) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bench_flash_attention.run(save_path='.', print_data=True) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|