# Source: MInference / minference/ops/streaming_kernel.py
# Commit 43a7079 by iofu728: "Feature(MInference): build demo" (file size 26.3 kB)
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team
Extra Credits:
- Original flash attention paper (https://arxiv.org/abs/2205.14135)
- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
"""
import math
import torch
import triton
import triton.language as tl
# Default Triton tile sizes: query rows (M) and key columns (N) handled per
# program. They are module-level and mutable -- get_score() and _forward()
# halve both, for the rest of the process, after a triton.OutOfResources error.
_BLOCK_M = 64
_BLOCK_N = 64
@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,
                    K_block_ptr, V_block_ptr,
                    start_m, qk_scale, N_CTX,
                    sliding_window_offset, sliding_window_size,
                    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, SLIDING_WINDOW: tl.constexpr,
                    IS_EVEN_M: tl.constexpr, IS_EVEN_N: tl.constexpr, COMPLEMENT_SLIDING_WINDOW: tl.constexpr
                    ):
    """Online-softmax accumulation of one BLOCK_M query tile over key blocks.

    Logits are in the base-2 domain: ``qk_scale`` already includes log2(e),
    so ``exp2`` is used throughout. ``N_CTX`` here is the key length.

    Masking (query row r of this tile, key column c):
        dist = (start_m * BLOCK_M + r) + sliding_window_offset - c
      * SLIDING_WINDOW, normal:      keep 0 <= dist < sliding_window_size
      * SLIDING_WINDOW, complement:  keep dist >= sliding_window_size
      * not SLIDING_WINDOW:          keep every key (no causal mask)

    Returns the updated ``(acc, l_i, m_i)``: the unnormalised output
    accumulator, the running row sum and the running row max (log2 domain).
    ``IS_EVEN_M`` is accepted for signature symmetry but not used here.
    """
    # Key range [lo, hi) visited by this query tile.
    if SLIDING_WINDOW and not COMPLEMENT_SLIDING_WINDOW:
        # Only key blocks that can intersect the window of some row in this
        # tile are visited; both ends are rounded out to BLOCK_N multiples and
        # clamped to [0, N_CTX]. Out-of-window keys inside those blocks are
        # removed by the per-element mask below.
        lo = ((start_m * BLOCK_M + sliding_window_offset - sliding_window_size + 1) // BLOCK_N) * BLOCK_N
        hi = ((((start_m + 1) * BLOCK_M - 1) + sliding_window_offset + BLOCK_N) // BLOCK_N) * BLOCK_N
        if lo < 0:
            lo = 0
        if hi > N_CTX:
            hi = N_CTX
        lo = tl.multiple_of(lo, BLOCK_N)
        K_block_ptr = tl.advance(K_block_ptr, (0, lo))
        V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    else:
        # No window, or the complement of a window: scan every key block.
        lo, hi = 0, N_CTX
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if IS_EVEN_N:
            k = tl.load(K_block_ptr)
        else:
            k = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option="zero")
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk = qk * qk_scale
        if SLIDING_WINDOW:
            dist = tl.arange(0, BLOCK_M)[:, None] - tl.arange(0, BLOCK_N)[None, :] \
                + start_m * BLOCK_M - start_n + sliding_window_offset
            if COMPLEMENT_SLIDING_WINDOW:
                mask = (dist >= sliding_window_size)
            else:
                mask = (dist >= 0) & (dist < sliding_window_size)
            qk = tl.where(mask, qk, float("-inf"))
        if not IS_EVEN_N:
            # Zero-padded key columns past the end must not take part.
            qk = tl.where(((tl.arange(0, BLOCK_N) + start_n) < N_CTX)[None, :], qk, float("-inf"))
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)
        # Re-apply the masks to p: for a row with no visible key so far,
        # m_ij is -inf and (-inf) - (-inf) is NaN, so exp2 alone is not enough.
        if SLIDING_WINDOW:
            p = tl.where(mask, p, 0)
        if not IS_EVEN_N:
            p = tl.where(((tl.arange(0, BLOCK_N) + start_n) < N_CTX)[None, :], p, 0)
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        tmp = m_i - m_ij
        # tmp is NaN only when m_i and m_ij are both -inf (row still fully
        # masked); keep the previous state unscaled in that case.
        alpha_mask = (tmp != tmp)  # check nan
        alpha = tl.math.exp2(tmp)
        alpha = tl.where(alpha_mask, 1., alpha)
        l_i = l_i * alpha + l_ij
        # -- update output accumulator --
        acc = acc * alpha[:, None]
        if IS_EVEN_N:
            v = tl.load(V_block_ptr)
        else:
            v = tl.load(V_block_ptr, boundary_check=(0, 1), padding_option="zero")
        acc += tl.dot(p.to(v.dtype), v)
        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    return acc, l_i, m_i
@triton.heuristics(
    {
        # True when the query / key lengths are whole multiples of the tile
        # sizes, which lets the kernel skip boundary checks and masking.
        "IS_EVEN_M": lambda args: args["N_CTX"] % args["BLOCK_M"] == 0,
        "IS_EVEN_N": lambda args: args["NKV_CTX"] % args["BLOCK_N"] == 0,
    }
)
@triton.jit
def _attn_fwd(Q, K, V, sm_scale, M, Out, L,
              stride_qz, stride_qh, stride_qm, stride_qk,
              stride_kz, stride_kh, stride_kn, stride_kk,
              stride_vz, stride_vh, stride_vk, stride_vn,
              stride_oz, stride_oh, stride_om, stride_on,
              Z, H, H_KV,
              N_CTX,
              ROUND_CTX,
              NKV_CTX,
              sliding_window_offset,
              sliding_window_size,
              IS_EVEN_M: tl.constexpr,
              IS_EVEN_N: tl.constexpr,
              BLOCK_M: tl.constexpr,
              BLOCK_DMODEL: tl.constexpr,
              BLOCK_N: tl.constexpr,
              END: tl.constexpr,
              INIT: tl.constexpr,
              SLIDING_WINDOW: tl.constexpr,
              COMPLEMENT_SLIDING_WINDOW: tl.constexpr
              ):
    """One attention stage for BLOCK_M query rows of one (batch, head) pair.

    Program id 0 selects the query tile, program id 1 the flattened
    batch * H index. Query head h reads key/value head h // (H // H_KV).

    Running state is kept across launches in three buffers that each have
    ROUND_CTX rows per (batch, head): ``Out`` (unnormalised accumulator),
    ``M`` (row max, log2 domain) and ``L`` (row sum).

    INIT=True starts from empty state; otherwise the state is loaded back from
    M/L/Out. END=True finishes the softmax: the output is divided by the row
    sum and M receives m + log2(l); L is not written. M is stored on every
    launch, so after END it holds the per-row normaliser consumed by
    _score_kernel.
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    # Grouped-query attention: several query heads share one kv head.
    off_hkv = off_h // (H//H_KV)
    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    k_offset = off_z.to(tl.int64) * stride_kz + off_hkv.to(tl.int64) * stride_kh
    v_offset = off_z.to(tl.int64) * stride_vz + off_hkv.to(tl.int64) * stride_vh
    o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh
    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(NKV_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    # K is addressed transposed, (head_dim, keys), so tl.dot(q, k) gives q @ k^T.
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(BLOCK_DMODEL, NKV_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    # The output/state buffer is ROUND_CTX rows tall, not N_CTX.
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(ROUND_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # initialize pointer to m and l
    m_ptrs = M + off_hz * ROUND_CTX + offs_m
    l_ptrs = L + off_hz * ROUND_CTX + offs_m
    if INIT:
        # l starts at 1.0 (not 0) so a row that never sees a visible key ends
        # with acc / l == 0 rather than a division by zero.
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    else:
        # don't have to check boundary for q len: assumes the state buffers
        # hold ROUND_CTX >= cdiv(N_CTX, BLOCK_M) * BLOCK_M rows.
        m_i = tl.load(m_ptrs).to(tl.float32)
        l_i = tl.load(l_ptrs).to(tl.float32)
        acc = tl.load(O_block_ptr).to(tl.float32)
    qk_scale = sm_scale
    qk_scale *= 1.4426950408889634  # 1/log(2): work in the exp2 domain
    # load q: it will stay in SRAM throughout
    if IS_EVEN_M:
        q = tl.load(Q_block_ptr)
    else:
        q = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option="zero")
    acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                                    start_m, qk_scale, NKV_CTX,
                                    sliding_window_offset, sliding_window_size,
                                    BLOCK_M, BLOCK_DMODEL, BLOCK_N, SLIDING_WINDOW, IS_EVEN_M, IS_EVEN_N,
                                    COMPLEMENT_SLIDING_WINDOW)
    # epilogue
    if (END):
        # Final stage: m + log2(l) is the log2 of the softmax denominator, and
        # the accumulator is normalised into the attention output.
        m_i += tl.math.log2(l_i)
        acc = acc / l_i[:, None]
    else:
        # Intermediate stage: keep the running row sum for the next launch.
        tl.store(l_ptrs, l_i)
    # M is written on every launch (running max, or the final normaliser).
    tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))
@triton.heuristics(
    {
        "IS_EVEN_M": lambda args: args["N_CTX"] % args["BLOCK_M"] == 0,
        "IS_EVEN_N": lambda args: args["NKV_CTX"] % args["BLOCK_N"] == 0,
    }
)
@triton.jit
def _score_kernel(
    Q, K, M, sm_scale, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_oz, stride_oh, stride_on,
    Z, H, H_KV,
    N_CTX,
    ROUND_CTX,
    NKV_CTX,
    sliding_window_offset,
    sliding_window_size,
    SLIDING_WINDOW: tl.constexpr,
    COMPLEMENT_SLIDING_WINDOW: tl.constexpr,
    IS_EVEN_M: tl.constexpr,
    IS_EVEN_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Per-key sum of attention probabilities over all real query rows.

    Program (start_n, off_hz) handles keys
    [start_n * BLOCK_N, (start_n + 1) * BLOCK_N) of batch/head ``off_hz`` and
    writes, for each such key j,

        Out[z, h, j] = sum_i exp2(q_i . k_j * sm_scale * log2(e) - M[i])

    restricted to the sliding-window (or complement) mask. The mask uses the
    same ``dist`` definition as ``_attn_fwd_inner``. When M holds the
    normaliser from the final ``_attn_fwd`` launch, this is the column sum of
    the softmax matrix.

    Only the N_CTX real query rows contribute. The loop stops at N_CTX rather
    than ROUND_CTX, so with IS_EVEN_M no tile reads past the end of Q. When
    N_CTX is not a multiple of BLOCK_M, the zero-padded rows of the last tile
    are masked out.
    """
    start_n = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    off_hkv = off_h // (H // H_KV)
    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    k_offset = off_z.to(tl.int64) * stride_kz + off_hkv.to(tl.int64) * stride_kh
    m_ptrs = M + off_hz * ROUND_CTX + tl.arange(0, BLOCK_M)
    # One partial sum per key column of this block (sum over axis 0 of p).
    o = tl.zeros([BLOCK_N], dtype=tl.float32)
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(BLOCK_DMODEL, NKV_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, start_n * BLOCK_N),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    # The key block stays resident while all query tiles stream past it.
    if IS_EVEN_N:
        k = tl.load(K_block_ptr)
    else:
        k = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option="zero")
    lo = 0
    hi = N_CTX
    qk_scale = sm_scale
    qk_scale *= 1.4426950408889634  # 1/log(2)
    for start_m in range(lo, hi, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        if IS_EVEN_M:
            q = tl.load(Q_block_ptr)
        else:
            q = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option="zero")
        # M has ROUND_CTX >= cdiv(N_CTX, BLOCK_M) * BLOCK_M entries per row
        # group, so this load stays in bounds. Rows past N_CTX may hold stale
        # values; they are masked below.
        m = tl.load(m_ptrs)
        # calc qk
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk = qk * qk_scale
        if SLIDING_WINDOW:
            dist = tl.arange(0, BLOCK_M)[:, None] - tl.arange(0, BLOCK_N)[None, :] \
                + start_m - start_n * BLOCK_N + sliding_window_offset
            if COMPLEMENT_SLIDING_WINDOW:
                mask = (dist >= sliding_window_size)
            else:
                mask = (dist >= 0) & (dist < sliding_window_size)
        qk = qk - m[:, None]
        p = tl.math.exp2(qk)  # (BLOCK_M, BLOCK_N)
        if SLIDING_WINDOW:
            p = tl.where(mask, p, 0)
        if not IS_EVEN_M:
            # Drop the zero-padded query rows of a ragged final tile. This
            # depends on the *query* length, hence IS_EVEN_M. Key columns past
            # NKV_CTX need no mask because the final store is masked.
            p = tl.where(
                ((tl.arange(0, BLOCK_M) + start_m) < N_CTX)[:, None],
                p, 0
            )
        o += tl.sum(p, axis=0)
        Q_block_ptr = tl.advance(Q_block_ptr, offsets=(BLOCK_M, 0))
        m_ptrs = m_ptrs + BLOCK_M
    o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh
    o_range = tl.arange(0, BLOCK_N) + start_n * BLOCK_N
    o_ptrs = Out + o_offset + o_range
    tl.store(o_ptrs, o.to(Out.type.element_ty), mask=o_range < NKV_CTX)
def get_score(q, k, m, sliding_window, complement_sliding_window):
    """Return the per-key attention mass summed over all query rows.

    Launches ``_score_kernel`` on a ``(cdiv(k_len, BLOCK_N), batch * heads)``
    grid. ``m`` is expected to hold the per-row normaliser stored by the final
    ``_attn_fwd`` launch (``END=True``).

    Args:
        q: 4-D query tensor; its first two dims must equal ``m``'s.
        k: 4-D key tensor; dim 2 is the key length, dim 1 the kv head count.
        m: 3-D per-row statistics buffer; its last dim is passed as ROUND_CTX.
        sliding_window: ``None`` or an ``(offset, size)`` pair.
        complement_sliding_window: forwarded as COMPLEMENT_SLIDING_WINDOW.

    Returns:
        Tensor of shape ``(q.size(0), q.size(1), k.size(2))`` with ``k``'s dtype.

    On ``triton.OutOfResources`` the module-level ``_BLOCK_M``/``_BLOCK_N`` are
    halved (for all later calls too) and the launch is retried once.
    """
    global _BLOCK_N
    global _BLOCK_M

    assert q.dim() == 4
    assert k.dim() == 4
    assert m.dim() == 3
    assert q.shape[:2] == m.shape[:2]

    batch, heads = q.size(0), q.size(1)
    q_len, kv_len, round_len = q.size(-2), k.size(-2), m.size(-1)
    ret = torch.zeros((batch, heads, k.size(2)), dtype=k.dtype, device=k.device)

    if sliding_window is None:
        window_offset = window_size = None
    else:
        window_offset, window_size = sliding_window

    def grid(meta):
        return (triton.cdiv(k.shape[2], meta["BLOCK_N"]), q.shape[0] * q.shape[1])

    sm_scale = 1 / math.sqrt(q.size(-1))

    def launch():
        # Block sizes are read from the module globals at call time, so the
        # retry below picks up the halved values.
        _score_kernel[grid](
            q, k, m, sm_scale, ret,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            ret.stride(0), ret.stride(1), ret.stride(2),
            batch, heads, k.size(1),
            q_len, round_len, kv_len,
            window_offset,
            window_size,
            SLIDING_WINDOW=(sliding_window is not None),
            COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
            BLOCK_M=_BLOCK_M,
            BLOCK_N=_BLOCK_N,
            BLOCK_DMODEL=q.size(-1),
        )

    try:
        launch()
    except triton.OutOfResources as E:
        from warnings import warn
        _BLOCK_N = _BLOCK_N // 2
        _BLOCK_M = _BLOCK_M // 2
        warn(f"Triton Attention Output Resources. {E}\nUse smaller block size {_BLOCK_N}.")
        launch()
    return ret
def _forward(
    q, k, v, sm_scale,
    o = None, m = None, l = None, end = False,
    sliding_window=None, init=False,
    complement_sliding_window=False
):
    """Launch ``_attn_fwd`` for one attention stage and return ``(o, m, l)``.

    Args:
        q, k, v: 4-D tensors; dim 0 is batch, dim 1 heads, dim 2 sequence,
            dim 3 head_dim. ``k``/``v`` may have fewer heads than ``q``
            (k.shape[1] is passed as H_KV).
        sm_scale: softmax scale applied to q.k.
        o, m, l: state buffers (output accumulator, row max, row sum) with at
            least ``ceil(q_len / 64) * 64`` rows. They are required: the
            ``None`` defaults fail at ``o.stride``.
        end: final stage. The kernel normalises the output, and the returned
            ``o`` is sliced to the real query length, made contiguous and cast
            to ``q.dtype``.
        sliding_window: ``None`` or an ``(offset, size)`` pair.
        init: True for the first stage (incoming o/m/l contents are ignored).
        complement_sliding_window: keep only keys outside the window.

    Only head dims 16, 32, 64 and 128 are accepted. On
    ``triton.OutOfResources`` the module-level block sizes are halved (for all
    later calls too) and the launch is retried once.
    """
    global _BLOCK_N
    global _BLOCK_M

    dim_q, dim_k, dim_v = q.shape[-1], k.shape[-1], v.shape[-1]
    assert dim_q == dim_k and dim_k == dim_v
    assert dim_k in {16, 32, 64, 128}

    q_len = q.shape[2]
    q_round_len = math.ceil(q_len / 64) * 64
    if sliding_window is None:
        window_offset = window_size = None
    else:
        window_offset, window_size = sliding_window

    def grid(meta):
        return (triton.cdiv(q_len, meta["BLOCK_M"]), q.shape[0] * q.shape[1])

    def launch():
        # Block sizes are read from the module globals at call time, so the
        # retry below picks up the halved values.
        _attn_fwd[grid](
            q, k, v, sm_scale, m, o, l,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], k.shape[1],
            q_len,
            q_round_len,
            k.shape[2],
            window_offset,
            window_size,
            BLOCK_DMODEL=dim_k,
            END=end,
            INIT=init,
            BLOCK_M=_BLOCK_M,
            BLOCK_N=_BLOCK_N,
            SLIDING_WINDOW=(sliding_window is not None),
            COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
            num_warps=4,
            num_stages=4,
        )

    try:
        launch()
    except triton.OutOfResources as E:
        _BLOCK_N = _BLOCK_N // 2
        _BLOCK_M = _BLOCK_M // 2
        from warnings import warn
        warn(f"Triton Attention Output Resources. {E}\nUse smaller block size {_BLOCK_N}.")
        launch()

    if end:
        # Drop the padding rows and return the output in the caller's dtype.
        o = o[:, :, :q_len, :].contiguous().to(q.dtype)
    return o, m, l
class MultiStageDotProductionAttention:
    """Interface for attention computed in several stages.

    Callers feed successive (q, k, v) stages through :meth:`append`, which
    subclasses implement, and read the output and any per-stage scores back
    with :meth:`get_result`.
    """

    def __init__(
        self,
        q_shape,
        dtype,
        device,
    ):
        # Output buffer, zero-filled, with the query's shape/dtype/device.
        self.ret = torch.zeros(q_shape, dtype=dtype, device=device)
        self.score_list = []
        self.end = False
        self.q_shape = q_shape
        self.dtype = dtype
        self.device = device

    def append(
        self,
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
        sliding_window=None, complement_sliding_window: bool = False,
        end=False, get_score=False,
        *args, **kwargs
    ):
        """Add one attention stage; the base class does not implement this."""
        raise NotImplementedError

    def get_result(self):
        """Return the pair ``(output, score_list)``."""
        return self.ret, self.score_list
class TritonMultiStageDotProductionAttention(MultiStageDotProductionAttention):
    """Multi-stage attention backed by the Triton ``_attn_fwd`` kernel.

    Each :meth:`append` runs one attention pass of ``q`` against a (k, v)
    stage and merges it into running float32 state (``o``, ``m``, ``l``). The
    append with ``end=True`` normalises the output and calls
    :meth:`finalize`. After that, ``get_result()`` returns
    ``(output, score_list)``. ``self.ret`` does not exist before finalize().
    """

    def __init__(self, q_shape, dtype, device):
        # The base __init__ is not called, so no zero-filled ``ret`` is
        # allocated here; finalize() assigns ``self.ret``.
        self.q_shape = q_shape
        self.dtype = dtype
        self.device = device
        # Kernel state is padded to a multiple of 64 query rows so that
        # _attn_fwd can read and write whole BLOCK_M tiles without bounds checks.
        q_round_len = math.ceil(q_shape[2] / 64) * 64
        o_shape = (q_shape[0], q_shape[1], q_round_len, q_shape[3])
        m_shape = (q_shape[0], q_shape[1], q_round_len)
        l_shape = (q_shape[0], q_shape[1], q_round_len)
        # Uninitialised is fine: the first append launches with INIT=True.
        self.o = torch.empty(o_shape, device=device, dtype=torch.float32)
        self.m = torch.empty(m_shape, device=device, dtype=torch.float32)
        self.l = torch.empty(l_shape, device=device, dtype=torch.float32)
        # Per-stage inputs kept for score computation in finalize(); entries
        # are None for stages appended with get_score=False.
        self.q_list = []
        self.k_list = []
        self.sliding_window_list = []
        self.complement_sliding_window_list = []
        self.score_list = []
        self.end = False   # set once the final stage has been appended
        self.init = False  # True after the first kernel launch

    def finalize(self):
        """Compute requested per-key scores and publish the final output.

        For each stage appended with ``get_score=True``, :func:`get_score` is
        called with the final row statistics ``self.m``. Other stages
        contribute ``None`` to ``score_list``. ``self.ret`` becomes ``self.o``.
        """
        self.end = True
        for q, k, sliding_window, comp in zip(self.q_list, self.k_list, self.sliding_window_list, self.complement_sliding_window_list):
            if q is not None:
                score = get_score(q, k, self.m, sliding_window, comp)
                self.score_list.append(score)
            else:
                self.score_list.append(None)
        self.ret = self.o

    def append(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, end=False, get_score=False, sliding_window = None, complement_sliding_window: bool = False):
        """Run one attention stage of ``q`` against ``(k, v)`` and merge it in.

        Args:
            q: queries; must have exactly ``self.q_shape``.
            k, v: this stage's keys/values. Their head count may divide q's:
                query head h reads kv head h // (H // H_KV).
            end: this is the last stage. Normalise the output and finalize().
            get_score: keep q/k so finalize() computes per-key scores for
                this stage.
            sliding_window: ``None`` (attend to every key, no causal mask), an
                ``(offset, size)`` pair, or an int ``size``. The int form is
                shorthand for ``(k_len - q_len, size)``, which aligns the last
                query with the last key.
            complement_sliding_window: attend only to keys outside the window.

        Raises:
            AssertionError: if called after the final stage, or if ``q`` has
                the wrong shape.
        """
        # Check before launching anything. After finalize(), self.o is the
        # sliced, q.dtype output rather than the q_round_len-row float32 state
        # the kernel reads and writes. Another launch would treat a possibly
        # shorter buffer as ROUND_CTX rows of running state.
        assert not self.end
        assert q.shape == self.q_shape
        if isinstance(sliding_window, int):
            sliding_window = (
                k.shape[2] - q.shape[2], sliding_window
            )
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        sm_scale = 1 / math.sqrt(q.shape[-1])
        o, m, l = _forward(
            q, k, v, sm_scale, self.o, self.m, self.l,
            sliding_window=sliding_window, end=end, init=not self.init,
            complement_sliding_window=complement_sliding_window
        )
        self.init = True
        self.o = o
        self.m = m
        self.l = l
        if get_score:
            self.q_list.append(q)
            self.k_list.append(k)
            self.sliding_window_list.append(sliding_window)
            self.complement_sliding_window_list.append(complement_sliding_window)
        else:
            self.q_list.append(None)
            self.k_list.append(None)
            self.sliding_window_list.append(None)
            self.complement_sliding_window_list.append(None)
        if end:
            self.finalize()
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    (batch, num_key_value_heads, seqlen, head_dim) ->
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    With ``n_rep == 1`` the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Broadcast a new repeat axis next to the head axis, then fold it in.
        tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
def streaming_forward(
    q, k, v,
    n_init, n_local,
):
    """Attention where each query sees a local window plus the first keys.

    Query row r, at absolute key position r + (k_len - q_len), attends to the
    ``n_local`` most recent keys up to and including that position. When
    ``k_len > n_local``, it also attends to those of the first ``n_init`` keys
    that fall outside the window. The two sets are computed as two stages of
    TritonMultiStageDotProductionAttention: a sliding window, then its
    complement over the first ``n_init`` keys.

    Args:
        q, k, v: (bsz, num_heads, seqlen, head_dim) tensors of identical shape.
            RoPE must already be applied, and k/v already repeated to q's
            head count.
        n_init: number of leading keys that remain visible to every query.
        n_local: sliding-window size.

    Returns:
        Attention output with the same shape as ``q``.

    Head dims other than 16/32/64/128/256/512 are zero-padded to the next power
    of two (at least 16), and the padding is sliced off the result. The kernel
    launcher (_forward) accepts only 16/32/64/128, so larger head dims fail
    there.
    """
    # q,k,v should be tensors already equipped with RoPE
    # k,v should already repeated to align with q.shape
    assert q.dim() == 4  # (bsz, num_heads, seqlen, head_dim)
    assert q.shape == k.shape == v.shape

    head_dim = q.shape[-1]
    if head_dim not in [16, 32, 64, 128, 256, 512]:
        padded_dim = max(16, 2 ** math.ceil(math.log2(head_dim)))
        pad = [0, padded_dim - head_dim, 0, 0, 0, 0, 0, 0]
        # append() scales logits by 1/sqrt(q.shape[-1]), i.e. by the *padded*
        # width. Zero padding leaves q.k unchanged, so pre-scale q by
        # sqrt(padded_dim / head_dim) to keep the softmax temperature at
        # 1/sqrt(head_dim).
        q = torch.nn.functional.pad(q, pad) * math.sqrt(padded_dim / head_dim)
        k = torch.nn.functional.pad(k, pad)
        v = torch.nn.functional.pad(v, pad)

    q_len = q.size(2)
    k_len = k.size(2)
    attn = TritonMultiStageDotProductionAttention(q.shape, q.dtype, q.device)
    if k_len > n_local:
        init_k = k[:, :, :n_init, :].contiguous()
        init_v = v[:, :, :n_init, :].contiguous()
        # Stage 1: sliding window of n_local keys ending at each query.
        attn.append(q, k, v, sliding_window=n_local)
        # Stage 2: the first n_init keys, restricted to those outside the window.
        attn.append(
            q, init_k, init_v, end=True,
            sliding_window=(k_len - q_len, n_local), complement_sliding_window=True
        )
    else:
        attn.append(q, k, v, sliding_window=n_local, end=True)
    score, _ = attn.get_result()
    return score[..., :head_dim]
def streaming_forward2(
    q, k, v,
    n_init, n_local,
):
    """Two-stage local-window + initial-keys attention, always run as two stages.

    Stage 1 is a sliding window of ``n_local`` keys. Stage 2 covers the first
    ``n_init`` keys with the complement window. When ``k_len <= n_local``,
    stage 2 still runs, but on zero-length key/value tensors. No head-dim
    padding is done here.

    Returns:
        The attention output from ``TritonMultiStageDotProductionAttention``.
    """
    k_len, q_len = k.size(2), q.size(2)
    attn = TritonMultiStageDotProductionAttention(q.shape, q.dtype, q.device)

    if k_len <= n_local:
        init_k = torch.empty(
            (k.size(0), k.size(1), 0, k.size(3)),
            dtype=k.dtype, device=k.device
        )
        init_v = torch.empty(
            (v.size(0), v.size(1), 0, v.size(3)),
            dtype=v.dtype, device=v.device
        )
    else:
        init_k = k[:, :, :n_init, :].contiguous()
        init_v = v[:, :, :n_init, :].contiguous()

    attn.append(q, k, v, sliding_window=n_local)
    attn.append(
        q, init_k, init_v, end=True,
        sliding_window=(k_len - q_len, n_local), complement_sliding_window=True
    )
    output, _ = attn.get_result()
    return output
def stream_llm_forward(n_local, n_init, *args, **kwargs):
    """Build an attention ``forward`` that keeps the first and most recent keys.

    Args:
        n_local: sliding-window size (most recent keys).
        n_init: number of leading keys kept visible to every query.
        *args, **kwargs: accepted but not used.

    Returns:
        The ``forward`` closure defined below.
    """
    Attn = TritonMultiStageDotProductionAttention
    def forward(self, query : torch.Tensor,
                key_value : torch.Tensor,
                # NOTE(review): annotated as a Tensor, but it is called and
                # exposes RoPE helper methods/caches below -- presumably a
                # rotary-embedding module.
                position_bias : torch.Tensor,
                use_cache: bool,
                past_key_value,
                project_q, project_k, project_v, attention_out,
                dim_head, num_heads, num_heads_kv
    ):
        """Project, apply RoPE, run two-stage attention and project out.

        ``self`` is not used in the body (presumably this is bound as a
        module method -- confirm with the caller).

        ``past_key_value``, when given, is ``(k_cache, v_cache, total_len)``.
        With ``use_cache`` the new cache keeps all keys while
        ``len_k <= n_local + n_init``. Beyond that it keeps the first
        ``n_init`` and the last ``n_local`` key/value positions, plus the
        running total length.

        Returns:
            ``attention_out`` applied to the (batch, len_q, num_heads * dim_head)
            attention output, paired with the new cache when ``use_cache``.
        """
        batch_size = query.size(0)
        len_q = query.size(1)
        len_k = key_value.size(1)
        h_q = project_q(query)  # (batch, len_q, num_heads * dim_head)
        h_k = project_k(key_value)  # (batch, len_k, num_heads_kv * dim_head)
        h_v = project_v(key_value)  # (batch, len_k, num_heads_kv * dim_head)
        h_q = h_q.view(batch_size, len_q, num_heads, dim_head).permute(0, 2, 1, 3)  # (batch, num_heads, len_q, dim_head)
        h_k = h_k.view(batch_size, len_k, num_heads_kv, dim_head).permute(0, 2, 1, 3)  # (batch, num_heads_kv, len_k, dim_head)
        h_v = h_v.view(batch_size, len_k, num_heads_kv, dim_head).permute(0, 2, 1, 3)  # (batch, num_heads_kv, len_k, dim_head)
        h_q = h_q.contiguous()  # (batch, num_heads, len_q, dim_head)
        h_k = h_k.contiguous()  # (batch, num_heads_kv, len_k, dim_head)
        h_v = h_v.contiguous()  # (batch, num_heads_kv, len_k, dim_head)
        if past_key_value is not None:
            h_k = torch.cat([past_key_value[0], h_k], dim=-2)
            h_v = torch.cat([past_key_value[1], h_v], dim=-2)
            # len_k becomes the total sequence length seen so far, which can
            # exceed h_k.size(-2) once the cache has been truncated.
            len_k += past_key_value[2]
        if use_cache:
            if len_k <= n_local + n_init:
                h_k_cache = h_k
                h_v_cache = h_v
            else:
                # Keep the first n_init and the last n_local positions.
                h_k_cache = torch.cat([h_k[:,:, :n_init, :], h_k[:, :, max(0, h_k.size(-2) - n_local):, :]], dim=2)
                h_v_cache = torch.cat([h_v[:,:, :n_init, :], h_v[:, :, max(0, h_k.size(-2) - n_local):, :]], dim=2)
            current_key_value = (h_k_cache, h_v_cache, len_k)
        else:
            current_key_value = None
        h_q_ = h_q
        h_k_ = h_k
        h_v_ = h_v
        # Local stage only needs the last len_q + n_local keys.
        if len_q + n_local < h_k_.size(-2):
            h_k_ = h_k_[:, :, h_k_.size(-2) - len_q - n_local:, :].contiguous().clone()
            h_v_ = h_v_[:, :, h_v_.size(-2) - len_q - n_local:, :].contiguous().clone()
        # Presumably applies rotary embeddings to q and the local keys and
        # returns both -- semantics of position_bias not visible here.
        local_h_q, local_h_k = position_bias(h_q_, h_k_)
        local_h_v = h_v_
        if len_k > n_local:
            # Init-token stage. Queries are rotated with
            # apply_rotary_pos_emb_one_angle(h_q, n_local + n_init) and the first
            # n_init keys with apply_rotary_pos_emb(...); presumably this places
            # the init keys at fixed positions relative to the queries -- TODO
            # confirm against the position_bias implementation.
            init_h_q = position_bias.apply_rotary_pos_emb_one_angle(
                h_q, n_local + n_init
            )
            init_h_k = position_bias.apply_rotary_pos_emb(
                h_k[:, :, :n_init, :].contiguous(),
                n_init, n_init, position_bias._cos_cached, position_bias._sin_cached
            )
            init_h_v = h_v[:, :, :n_init, :].contiguous()
        else:
            # No init stage is needed; zero-length keys/values make it a no-op.
            init_h_q = h_q
            init_h_k = torch.empty(
                (batch_size, num_heads_kv, 0, dim_head),
                device=h_k.device,
                dtype=h_k.dtype
            )
            init_h_v = torch.empty(
                (batch_size, num_heads_kv, 0, dim_head),
                device=h_v.device,
                dtype=h_v.dtype
            )
        attn = Attn(local_h_q.shape, local_h_q.dtype, local_h_q.device)
        # Stage 1: window of n_local keys ending at each query.
        attn.append(local_h_q, local_h_k, local_h_v, sliding_window=n_local)
        # Stage 2: init keys outside that window (complement mask), with the
        # offset taken from absolute lengths since init keys start at position 0.
        attn.append(
            init_h_q, init_h_k, init_h_v, end=True,
            sliding_window=(len_k - len_q, n_local),
            complement_sliding_window=True
        )
        score, _ = attn.get_result()
        score = score.view(batch_size, num_heads, len_q, dim_head).permute(0, 2, 1, 3).contiguous()  # (batch, len_q, num_heads, dim_head)
        score = score.reshape(batch_size, len_q, num_heads * dim_head)  # (batch, len_q, num_heads * dim_head)
        score = attention_out(score)
        if use_cache:
            return score, current_key_value
        else:
            return score
    return forward