|
import math |
|
import torch |
|
|
|
import triton |
|
import triton.language as tl |
|
|
|
|
|
def self_extend_flash_forward_triton( |
|
model_self, |
|
query_position, |
|
group_size_2, |
|
neighbor_query_states, |
|
neighbor_key_states, |
|
group_query_states, |
|
group_key_states, |
|
value_states, |
|
attention_mask, |
|
bsz, |
|
q_len, |
|
kv_seq_len, |
|
attn_dropout, |
|
): |
|
|
|
o = _self_extend_flash_forward_triton(q=neighbor_query_states, |
|
k=neighbor_key_states, |
|
q1=group_query_states, |
|
k1=group_key_states, |
|
v=value_states, |
|
causal=(q_len == kv_seq_len), |
|
sm_scale=1. / math.sqrt(neighbor_query_states.shape[-1]), |
|
window=group_size_2) |
|
o = o.transpose(1, 2).contiguous() |
|
|
|
return o |
|
|
|
|
|
|
|
|
|
|
|
|
|
def _self_extend_flash_forward_triton(q, k, q1, k1, v, causal, sm_scale, window): |
|
|
|
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] |
|
assert Lq == Lk and Lk == Lv |
|
assert Lk in {16, 32, 64, 128} |
|
|
|
device = torch.cuda.device_of(q) |
|
with torch.cuda.device(device): |
|
o = torch.empty_like(q) |
|
BLOCK_M = 128 |
|
BLOCK_N = 32 |
|
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1]) |
|
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) |
|
_fwd_kernel[grid]( |
|
q, |
|
k, |
|
q1, |
|
k1, |
|
v, |
|
sm_scale, |
|
L, |
|
o, |
|
q.stride(0), q.stride(1), q.stride(2), q.stride(3), |
|
k.stride(0), k.stride(1), k.stride(2), k.stride(3), |
|
v.stride(0), v.stride(1), v.stride(2), v.stride(3), |
|
o.stride(0), o.stride(1), o.stride(2), o.stride(3), |
|
q.shape[0], |
|
q.shape[1], |
|
q.shape[2], |
|
k.shape[2], |
|
BLOCK_M=BLOCK_M, |
|
BLOCK_N=BLOCK_N, |
|
BLOCK_DMODEL=Lk, |
|
IS_CAUSAL=causal, |
|
WINDOW=window, |
|
num_warps=8, |
|
num_stages=2) |
|
|
|
return o |
|
|
|
|
|
|
|
|
|
@triton.heuristics( |
|
{ |
|
"EVEN_M": lambda args: args["Q_CTX"] % args["BLOCK_M"] == 0, |
|
"EVEN_N": lambda args: args["KV_CTX"] % args["BLOCK_N"] == 0, |
|
} |
|
) |
|
@triton.jit |
|
def _fwd_kernel( |
|
Q, |
|
K, |
|
Q1, |
|
K1, |
|
V, |
|
sm_scale, |
|
L, |
|
Out, |
|
stride_qz, stride_qh, stride_qm, stride_qk, |
|
stride_kz, stride_kh, stride_kn, stride_kk, |
|
stride_vz, stride_vh, stride_vn, stride_vk, |
|
stride_oz, stride_oh, stride_om, stride_on, |
|
Z, |
|
H, |
|
Q_CTX, |
|
KV_CTX, |
|
BLOCK_M: tl.constexpr, |
|
BLOCK_DMODEL: tl.constexpr, |
|
BLOCK_N: tl.constexpr, |
|
IS_CAUSAL: tl.constexpr, |
|
WINDOW: tl.constexpr, |
|
EVEN_M: tl.constexpr, |
|
EVEN_N: tl.constexpr |
|
|
|
): |
|
start_m = tl.program_id(0) |
|
off_hz = tl.program_id(1) |
|
|
|
q_offset = off_hz * stride_qh |
|
vk_offset = off_hz * stride_kh |
|
|
|
|
|
Q_block_ptr = tl.make_block_ptr( |
|
base=Q + q_offset, |
|
shape=(Q_CTX, BLOCK_DMODEL), |
|
strides=(stride_qm, stride_qk), |
|
offsets=(start_m * BLOCK_M, 0), |
|
block_shape=(BLOCK_M, BLOCK_DMODEL), |
|
order=(1, 0) |
|
) |
|
K_block_ptr = tl.make_block_ptr( |
|
base=K + vk_offset, |
|
shape=(KV_CTX, BLOCK_DMODEL), |
|
strides=(stride_kn, stride_kk), |
|
offsets=(0, 0), |
|
block_shape=(BLOCK_N, BLOCK_DMODEL), |
|
order=(1, 0) |
|
) |
|
Q1_block_ptr = tl.make_block_ptr( |
|
base=Q1 + q_offset, |
|
shape=(Q_CTX, BLOCK_DMODEL), |
|
strides=(stride_qm, stride_qk), |
|
offsets=(start_m * BLOCK_M, 0), |
|
block_shape=(BLOCK_M, BLOCK_DMODEL), |
|
order=(1, 0) |
|
) |
|
K1_block_ptr = tl.make_block_ptr( |
|
base=K1 + vk_offset, |
|
shape=(KV_CTX, BLOCK_DMODEL), |
|
strides=(stride_kn, stride_kk), |
|
offsets=(0, 0), |
|
block_shape=(BLOCK_N, BLOCK_DMODEL), |
|
order=(1, 0) |
|
) |
|
V_block_ptr = tl.make_block_ptr( |
|
base=V + vk_offset, |
|
shape=(KV_CTX, BLOCK_DMODEL), |
|
strides=(stride_vn, stride_vk), |
|
offsets=(0, 0), |
|
block_shape=(BLOCK_N, BLOCK_DMODEL), |
|
order=(1, 0) |
|
) |
|
|
|
|
|
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) |
|
offs_n = tl.arange(0, BLOCK_N) |
|
|
|
|
|
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") |
|
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) |
|
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) |
|
|
|
|
|
|
|
|
|
qk_scale = sm_scale * 1.4426950408889634 |
|
|
|
|
|
if EVEN_M: |
|
q = tl.load(Q_block_ptr) |
|
q1 = tl.load(Q1_block_ptr) |
|
else: |
|
q = tl.load(Q_block_ptr, boundary_check=(1,0)) |
|
q1 = tl.load(Q1_block_ptr, boundary_check=(1,0)) |
|
|
|
q = (q * qk_scale).to(tl.bfloat16) |
|
q1 = (q1 * qk_scale).to(tl.bfloat16) |
|
|
|
|
|
|
|
|
|
offs_k = tl.arange(0, BLOCK_DMODEL) |
|
I = tl.where(offs_k[:, None] == offs_k, |
|
tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=tl.bfloat16), |
|
tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=tl.bfloat16)) |
|
q = tl.dot(q, I).to(tl.bfloat16) |
|
q1 = tl.dot(q1, I).to(tl.bfloat16) |
|
|
|
|
|
|
|
lo = 0 |
|
if IS_CAUSAL: |
|
hi = tl.minimum(KV_CTX, (start_m + 1) * BLOCK_M) |
|
else: |
|
hi = KV_CTX |
|
|
|
for start_n in range(lo, hi, BLOCK_N): |
|
|
|
if EVEN_N: |
|
k = tl.load(K_block_ptr) |
|
k1 = tl.load(K1_block_ptr) |
|
v = tl.load(V_block_ptr) |
|
else: |
|
k = tl.load(K_block_ptr, boundary_check=(1,0)) |
|
k1 = tl.load(K1_block_ptr, boundary_check=(1,0)) |
|
v = tl.load(V_block_ptr, boundary_check=(1,0)) |
|
|
|
|
|
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) |
|
|
|
|
|
mask = ( KV_CTX - Q_CTX + offs_m[:, None]) >= (start_n + offs_n[None, :] + WINDOW) |
|
qk += tl.where(mask, tl.dot(q1, tl.trans(k1)), tl.dot(q, tl.trans(k))) |
|
|
|
|
|
|
|
|
|
|
|
if IS_CAUSAL: |
|
mask = offs_m[:, None] >= (start_n + offs_n[None, :]) |
|
qk = tl.where(mask, qk, float("-inf")) |
|
|
|
|
|
|
|
m_i_new = tl.maximum(m_i, tl.max(qk, 1)) |
|
alpha = tl.math.exp2(m_i - m_i_new) |
|
p = tl.math.exp2(qk - m_i_new[:, None]) |
|
|
|
|
|
acc_scale = l_i * 0 + alpha |
|
acc *= acc_scale[:, None] |
|
acc += tl.dot(p.to(tl.bfloat16), v) |
|
|
|
|
|
l_i = l_i * alpha + tl.sum(p, 1) |
|
m_i = m_i_new |
|
|
|
|
|
K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0)) |
|
K1_block_ptr = tl.advance(K1_block_ptr, (BLOCK_N, 0)) |
|
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) |
|
|
|
|
|
|
|
acc = acc * (1.0 / l_i[:, None]) |
|
l_ptrs = L + off_hz * Q_CTX + offs_m |
|
|
|
mask_m = offs_m < Q_CTX |
|
l_i = m_i + tl.math.log2(l_i) |
|
if EVEN_M: |
|
tl.store(l_ptrs, l_i) |
|
else: |
|
tl.store(l_ptrs, l_i, mask=mask_m) |
|
|
|
|
|
O_block_ptr = tl.make_block_ptr( |
|
base=Out + q_offset, |
|
shape=(Q_CTX, BLOCK_DMODEL), |
|
strides=(stride_om, stride_on), |
|
offsets=(start_m * BLOCK_M, 0), |
|
block_shape=(BLOCK_M, BLOCK_DMODEL), |
|
order=(1, 0) |
|
) |
|
if EVEN_M: |
|
tl.store(O_block_ptr, acc.to(tl.bfloat16)) |
|
else: |
|
tl.store(O_block_ptr, acc.to(tl.bfloat16), boundary_check=(1,0)) |
|
|