""" Fused Attention =============== This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) Credits: OpenAI kernel team Extra Credits: - Original flash attention paper (https://arxiv.org/abs/2205.14135) - Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) """ import math import torch import triton import triton.language as tl _BLOCK_N=64 _BLOCK_M=64 @triton.jit def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, qk_scale, N_CTX, sliding_window_offset, sliding_window_size, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, SLIDING_WINDOW: tl.constexpr, IS_EVEN_M: tl.constexpr, IS_EVEN_N: tl.constexpr, COMPLEMENT_SLIDING_WINDOW: tl.constexpr ): # range of values handled by this stage if SLIDING_WINDOW and not COMPLEMENT_SLIDING_WINDOW: if COMPLEMENT_SLIDING_WINDOW: lo = 0 hi = (((start_m + 1) * BLOCK_M + sliding_window_offset - sliding_window_size + BLOCK_N - 1) // BLOCK_N) * BLOCK_N else: lo = ((start_m * BLOCK_M + sliding_window_offset - sliding_window_size + 1) // BLOCK_N) * BLOCK_N hi = ((((start_m + 1) * BLOCK_M - 1) + sliding_window_offset + BLOCK_N) // BLOCK_N) * BLOCK_N if lo < 0: lo = 0 if hi > N_CTX: hi = N_CTX # lo = 0 # hi = N_CTX lo = tl.multiple_of(lo, BLOCK_N) K_block_ptr = tl.advance(K_block_ptr, (0, lo)) V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) else: lo, hi = 0, N_CTX # loop over k, v and update accumulator for start_n in range(lo, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- if IS_EVEN_N: k = tl.load(K_block_ptr) else: k = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option="zero") qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk = qk * qk_scale if SLIDING_WINDOW: dist = tl.arange(0, BLOCK_M)[:, None] - tl.arange(0, BLOCK_N)[None, :] \ + start_m * BLOCK_M - start_n + sliding_window_offset if COMPLEMENT_SLIDING_WINDOW: mask = (dist >= sliding_window_size) else: mask = (dist >= 0) & (dist < sliding_window_size) qk = tl.where(mask, qk, float("-inf")) if not IS_EVEN_N: qk = tl.where(((tl.arange(0, BLOCK_N) + start_n) < N_CTX)[None, :], qk, float("-inf")) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk = qk - m_ij[:, None] p = tl.math.exp2(qk) if SLIDING_WINDOW: p = tl.where(mask, p, 0) if not IS_EVEN_N: p = tl.where(((tl.arange(0, BLOCK_N) + start_n) < N_CTX)[None, :], p, 0) l_ij = tl.sum(p, 1) # -- update m_i and l_i tmp = m_i - m_ij alpha_mask = (tmp != tmp) # check nan alpha = tl.math.exp2(tmp) alpha = tl.where(alpha_mask, 1., alpha) l_i = l_i * alpha + l_ij # -- update output accumulator -- acc = acc * alpha[:, None] # update acc if IS_EVEN_N: v = tl.load(V_block_ptr) else: v = tl.load(V_block_ptr, boundary_check=(0, 1), padding_option="zero") acc += tl.dot(p.to(v.dtype), v) # update m_i and l_i m_i = m_ij V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) return acc, l_i, m_i @triton.heuristics( { "IS_EVEN_M": lambda args: args["N_CTX"] % args["BLOCK_M"] == 0, "IS_EVEN_N": lambda args: args["NKV_CTX"] % args["BLOCK_N"] == 0, } ) @triton.jit def _attn_fwd(Q, K, V, sm_scale, M, Out, L,# stride_qz, stride_qh, stride_qm, stride_qk, # stride_kz, stride_kh, stride_kn, stride_kk, # stride_vz, stride_vh, stride_vk, stride_vn, # stride_oz, stride_oh, stride_om, stride_on, # Z, H, H_KV, # N_CTX, # ROUND_CTX, NKV_CTX, sliding_window_offset, sliding_window_size, IS_EVEN_M: tl.constexpr, IS_EVEN_N: tl.constexpr, BLOCK_M: tl.constexpr, # 
@triton.heuristics(
    {
        "IS_EVEN_M": lambda args: args["N_CTX"] % args["BLOCK_M"] == 0,
        "IS_EVEN_N": lambda args: args["NKV_CTX"] % args["BLOCK_N"] == 0,
    }
)
@triton.jit
def _attn_fwd(Q, K, V, sm_scale, M, Out, L,
              stride_qz, stride_qh, stride_qm, stride_qk,
              stride_kz, stride_kh, stride_kn, stride_kk,
              stride_vz, stride_vh, stride_vk, stride_vn,
              stride_oz, stride_oh, stride_om, stride_on,
              Z, H, H_KV,
              N_CTX,
              ROUND_CTX,
              NKV_CTX,
              sliding_window_offset, sliding_window_size,
              IS_EVEN_M: tl.constexpr, IS_EVEN_N: tl.constexpr,
              BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
              BLOCK_N: tl.constexpr,
              END: tl.constexpr, INIT: tl.constexpr,
              SLIDING_WINDOW: tl.constexpr,
              COMPLEMENT_SLIDING_WINDOW: tl.constexpr):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    # Grouped-query attention: several query heads share one KV head.
    off_hkv = off_h // (H // H_KV)

    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    k_offset = off_z.to(tl.int64) * stride_kz + off_hkv.to(tl.int64) * stride_kh
    v_offset = off_z.to(tl.int64) * stride_vz + off_hkv.to(tl.int64) * stride_vh
    o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh

    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(NKV_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(BLOCK_DMODEL, NKV_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(ROUND_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # pointers to the running softmax statistics m and l
    m_ptrs = M + off_hz * ROUND_CTX + offs_m
    l_ptrs = L + off_hz * ROUND_CTX + offs_m
    if INIT:
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    else:
        # resume from a previous stage; no boundary check needed for q len
        # because M, L and Out are allocated with ROUND_CTX (padded) rows
        m_i = tl.load(m_ptrs).to(tl.float32)
        l_i = tl.load(l_ptrs).to(tl.float32)
        acc = tl.load(O_block_ptr).to(tl.float32)

    qk_scale = sm_scale
    qk_scale *= 1.4426950408889634  # log2(e) = 1/ln(2), so exp2 can be used below
    # load q: it will stay in SRAM throughout
    if IS_EVEN_M:
        q = tl.load(Q_block_ptr)
    else:
        q = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option="zero")

    acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                                    start_m, qk_scale, NKV_CTX,
                                    sliding_window_offset, sliding_window_size,
                                    BLOCK_M, BLOCK_DMODEL, BLOCK_N,
                                    SLIDING_WINDOW, IS_EVEN_M, IS_EVEN_N,
                                    COMPLEMENT_SLIDING_WINDOW)

    # epilogue
    if END:
        # fold the normalizer into m (log-sum-exp) and normalize the output
        m_i += tl.math.log2(l_i)
        acc = acc / l_i[:, None]
    else:
        tl.store(l_ptrs, l_i)
        tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))
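
# A hedged sanity sketch of why chaining `_attn_fwd` launches with INIT/END
# works: running the recurrence above over KV split into two chunks, then
# normalizing once at the end (the END epilogue), equals ordinary softmax
# attention over the concatenated KV. Names and shapes here are assumptions
# for illustration; tensors are per-head 2-D slices.
def _reference_two_stage(q, k1, v1, k2, v2, sm_scale):
    m_rows = q.shape[0]
    acc = torch.zeros(q.shape, dtype=torch.float32)
    l_i = torch.ones(m_rows, dtype=torch.float32)   # matches the kernel's INIT
    m_i = torch.full((m_rows,), float("-inf"))
    for k_c, v_c in ((k1, v1), (k2, v2)):           # one "stage" per chunk
        acc, l_i, m_i = _reference_inner_update(
            acc, l_i, m_i, q.float(), k_c.float(), v_c.float(), sm_scale)
    return acc / l_i[:, None]                       # the END normalization
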
@triton.heuristics(
    {
        "IS_EVEN_M": lambda args: args["N_CTX"] % args["BLOCK_M"] == 0,
        "IS_EVEN_N": lambda args: args["NKV_CTX"] % args["BLOCK_N"] == 0,
    }
)
@triton.jit
def _score_kernel(Q, K, M, sm_scale, Out,
                  stride_qz, stride_qh, stride_qm, stride_qk,
                  stride_kz, stride_kh, stride_kn, stride_kk,
                  stride_oz, stride_oh, stride_on,
                  Z, H, H_KV,
                  N_CTX,
                  ROUND_CTX,
                  NKV_CTX,
                  sliding_window_offset, sliding_window_size,
                  SLIDING_WINDOW: tl.constexpr,
                  COMPLEMENT_SLIDING_WINDOW: tl.constexpr,
                  IS_EVEN_M: tl.constexpr, IS_EVEN_N: tl.constexpr,
                  BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
                  BLOCK_N: tl.constexpr):
    # One program per KV block: accumulate, over all query rows, the
    # normalized attention probability mass received by each key position.
    start_n = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    off_hkv = off_h // (H // H_KV)

    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    k_offset = off_z.to(tl.int64) * stride_kz + off_hkv.to(tl.int64) * stride_kh
    m_ptrs = M + off_hz * ROUND_CTX + tl.arange(0, BLOCK_M)
    o = tl.zeros([BLOCK_M], dtype=tl.float32)

    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(BLOCK_DMODEL, NKV_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, start_n * BLOCK_N),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )

    if IS_EVEN_N:
        k = tl.load(K_block_ptr)
    else:
        k = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option="zero")

    lo = 0
    hi = ROUND_CTX
    qk_scale = sm_scale
    qk_scale *= 1.4426950408889634  # log2(e) = 1/ln(2), so exp2 can be used below
    for start_m in range(lo, hi, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        if IS_EVEN_M:
            q = tl.load(Q_block_ptr)
        else:
            q = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option="zero")
        m = tl.load(m_ptrs)

        # compute qk
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk = qk * qk_scale

        if SLIDING_WINDOW:
            dist = tl.arange(0, BLOCK_M)[:, None] - tl.arange(0, BLOCK_N)[None, :] \
                + start_m - start_n * BLOCK_N + sliding_window_offset
            if COMPLEMENT_SLIDING_WINDOW:
                mask = (dist >= sliding_window_size)
            else:
                mask = (dist >= 0) & (dist < sliding_window_size)

        # m holds the per-row log2 normalizer written by _attn_fwd with
        # END=True, so exp2(qk - m) is the normalized probability
        qk = qk - m[:, None]
        p = tl.math.exp2(qk)  # (BLOCK_M, BLOCK_N)

        if SLIDING_WINDOW:
            p = tl.where(mask, p, 0)

        if not IS_EVEN_M:
            # zero out padded query rows beyond N_CTX (this mask runs over the
            # M dimension, so it is gated on IS_EVEN_M, not IS_EVEN_N)
            p = tl.where(
                ((tl.arange(0, BLOCK_M) + start_m) < N_CTX)[:, None],
                p, 0
            )

        o += tl.sum(p, axis=0)

        Q_block_ptr = tl.advance(Q_block_ptr, offsets=(BLOCK_M, 0))
        m_ptrs = m_ptrs + BLOCK_M

    o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh
    o_range = tl.arange(0, BLOCK_N) + start_n * BLOCK_N
    o_ptrs = Out + o_offset + o_range
    tl.store(o_ptrs, o.to(Out.type.element_ty), mask=o_range < NKV_CTX)


def get_score(q, k, m, sliding_window, complement_sliding_window):
    """Return, for each KV position, the total attention probability mass it
    receives from all query positions (shape: batch x heads x NKV_CTX)."""
    assert q.dim() == 4
    assert k.dim() == 4
    assert m.dim() == 3
    assert q.shape[:2] == m.shape[:2]

    N_CTX = q.size(-2)
    NKV_CTX = k.size(-2)
    ROUND_CTX = m.size(-1)
    ret = torch.zeros(
        (q.size(0), q.size(1), k.size(2)),
        dtype=k.dtype, device=k.device
    )
    if sliding_window is not None:
        sliding_window_offset, sliding_window_size = sliding_window
    else:
        sliding_window_offset, sliding_window_size = None, None

    grid = lambda META: (
        triton.cdiv(k.shape[2], META["BLOCK_N"]),
        q.shape[0] * q.shape[1]
    )
    sm_scale = 1 / math.sqrt(q.size(-1))

    global _BLOCK_N
    global _BLOCK_M

    try:
        _score_kernel[grid](
            q, k, m, sm_scale, ret,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            ret.stride(0), ret.stride(1), ret.stride(2),
            q.size(0), q.size(1), k.size(1),
            N_CTX,
            ROUND_CTX,
            NKV_CTX,
            sliding_window_offset,
            sliding_window_size,
            SLIDING_WINDOW=(sliding_window is not None),
            COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
            BLOCK_M=_BLOCK_M,
            BLOCK_N=_BLOCK_N,
            BLOCK_DMODEL=q.size(-1)
        )
    except triton.OutOfResources as E:
        from warnings import warn
        _BLOCK_N = _BLOCK_N // 2
        _BLOCK_M = _BLOCK_M // 2
        warn(f"Triton attention kernel ran out of resources: {E}\n"
             f"Retrying with smaller block size {_BLOCK_N}.")
        _score_kernel[grid](
            q, k, m, sm_scale, ret,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            ret.stride(0), ret.stride(1), ret.stride(2),
            q.size(0), q.size(1), k.size(1),
            N_CTX,
            ROUND_CTX,
            NKV_CTX,
            sliding_window_offset,
            sliding_window_size,
            SLIDING_WINDOW=(sliding_window is not None),
            COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
            BLOCK_M=_BLOCK_M,
            BLOCK_N=_BLOCK_N,
            BLOCK_DMODEL=q.size(-1)
        )
    return ret
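
# For testing, a pure-PyTorch sketch of what `get_score` computes, as read off
# the kernel above (an assumption, not an official reference): for every KV
# position, the total normalized attention probability it receives from all
# query rows. The sliding-window masking is omitted here for brevity.
def _reference_score(q, k, m):
    # m holds the per-row log2 normalizers written by `_attn_fwd` with END=True,
    # i.e. log2 of (2^row_max * row_sum) in the kernel's base-2 units.
    sm_scale = 1 / math.sqrt(q.size(-1))
    qk = (q.float() @ k.float().transpose(-1, -2)) * sm_scale * 1.4426950408889634
    p = torch.exp2(qk - m[..., :q.size(-2), None])  # normalized probabilities
    return p.sum(dim=-2)                            # (batch, heads, NKV_CTX)
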
def _forward(
        q, k, v, sm_scale,
        o=None, m=None, l=None,
        end=False, sliding_window=None, init=False,
        complement_sliding_window=False
):
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}
    q_round_len = math.ceil(q.shape[2] / 64) * 64

    if sliding_window is not None:
        sliding_window_offset, sliding_window_size = sliding_window
    else:
        sliding_window_offset, sliding_window_size = None, None

    grid = lambda META: (
        triton.cdiv(q.shape[2], META["BLOCK_M"]),
        q.shape[0] * q.shape[1],
    )
    global _BLOCK_N
    global _BLOCK_M

    try:
        _attn_fwd[grid](
            q, k, v, sm_scale, m, o, l,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], k.shape[1],
            q.shape[2],
            q_round_len,
            k.shape[2],
            sliding_window_offset,
            sliding_window_size,
            BLOCK_DMODEL=Lk,
            END=end,
            INIT=init,
            BLOCK_M=_BLOCK_M,
            BLOCK_N=_BLOCK_N,
            SLIDING_WINDOW=(sliding_window is not None),
            COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
            num_warps=4,
            num_stages=4
        )
    except triton.OutOfResources as E:
        _BLOCK_N = _BLOCK_N // 2
        _BLOCK_M = _BLOCK_M // 2
        from warnings import warn
        warn(f"Triton attention kernel ran out of resources: {E}\n"
             f"Retrying with smaller block size {_BLOCK_N}.")
        _attn_fwd[grid](
            q, k, v, sm_scale, m, o, l,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], k.shape[1],
            q.shape[2],
            q_round_len,
            k.shape[2],
            sliding_window_offset,
            sliding_window_size,
            BLOCK_DMODEL=Lk,
            END=end,
            INIT=init,
            BLOCK_M=_BLOCK_M,
            BLOCK_N=_BLOCK_N,
            SLIDING_WINDOW=(sliding_window is not None),
            COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,
            num_warps=4,
            num_stages=4
        )

    if end:
        o = o[:, :, :q.shape[2], :].contiguous().to(q.dtype)

    return o, m, l
class MultiStageDotProductionAttention:
    """Interface for attention computed in several stages over different KV
    sets, merged into one softmax via the running (m, l, acc) statistics."""

    def __init__(self, q_shape, dtype, device):
        self.q_shape = q_shape
        self.dtype = dtype
        self.device = device
        self.end = False
        self.ret = torch.zeros(q_shape, dtype=dtype, device=device)
        self.score_list = []

    def append(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
        sliding_window=None, complement_sliding_window: bool = False,
        end=False, get_score=False,
        *args, **kwargs
    ):
        raise NotImplementedError

    def get_result(self):
        return self.ret, self.score_list


class TritonMultiStageDotProductionAttention(MultiStageDotProductionAttention):
    def __init__(self, q_shape, dtype, device):
        self.q_shape = q_shape
        self.dtype = dtype
        self.device = device

        # Pad the query length up to a multiple of 64 for the kernels.
        q_round_len = math.ceil(q_shape[2] / 64) * 64
        o_shape = (q_shape[0], q_shape[1], q_round_len, q_shape[3])
        m_shape = (q_shape[0], q_shape[1], q_round_len)
        l_shape = (q_shape[0], q_shape[1], q_round_len)

        self.o = torch.empty(o_shape, device=device, dtype=torch.float32)
        self.m = torch.empty(m_shape, device=device, dtype=torch.float32)
        self.l = torch.empty(l_shape, device=device, dtype=torch.float32)

        self.q_list = []
        self.k_list = []
        self.sliding_window_list = []
        self.complement_sliding_window_list = []
        self.score_list = []
        self.end = False
        self.init = False

    def finalize(self):
        self.end = True
        for q, k, sliding_window, comp in zip(
            self.q_list, self.k_list,
            self.sliding_window_list, self.complement_sliding_window_list
        ):
            if q is not None:
                score = get_score(q, k, self.m, sliding_window, comp)
                self.score_list.append(score)
            else:
                self.score_list.append(None)
        self.ret = self.o

    def append(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
        end=False, get_score=False,
        sliding_window=None, complement_sliding_window: bool = False
    ):
        assert q.shape == self.q_shape
        if isinstance(sliding_window, int):
            # An int means a causal window of that size, anchored so the last
            # query lines up with the last key.
            sliding_window = (k.shape[2] - q.shape[2], sliding_window)

        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        sm_scale = 1 / math.sqrt(q.shape[-1])
        o, m, l = _forward(
            q, k, v, sm_scale, self.o, self.m, self.l,
            sliding_window=sliding_window,
            end=end, init=not self.init,
            complement_sliding_window=complement_sliding_window
        )
        self.init = True
        self.o = o
        self.m = m
        self.l = l
        if get_score:
            self.q_list.append(q)
            self.k_list.append(k)
            self.sliding_window_list.append(sliding_window)
            self.complement_sliding_window_list.append(complement_sliding_window)
        else:
            self.q_list.append(None)
            self.k_list.append(None)
            self.sliding_window_list.append(None)
            self.complement_sliding_window_list.append(None)

        if end:
            assert not self.end
            self.finalize()
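
# Minimal usage sketch of the multi-stage API (shapes and sizes are assumptions
# for illustration; a CUDA device is required since the kernels are Triton):
#
#   q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
#   k = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
#   v = torch.randn_like(k)
#   attn = TritonMultiStageDotProductionAttention(q.shape, q.dtype, q.device)
#   attn.append(q, k[:, :, :2048], v[:, :, :2048])            # stage 1
#   attn.append(q, k[:, :, 2048:], v[:, :, 2048:], end=True)  # stage 2 + epilogue
#   out, scores = attn.get_result()
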
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
    to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def streaming_forward(q, k, v, n_init, n_local):
    # q, k, v should be tensors already equipped with RoPE;
    # k, v should already be repeated to align with q.shape.
    assert q.dim() == 4  # (bsz, num_heads, seqlen, head_dim)
    assert q.shape == k.shape == v.shape
    head_dim = q.shape[-1]
    if head_dim not in [16, 32, 64, 128, 256, 512]:
        # Pad the head dimension up to the next power of two.
        target_dim = 2 ** math.ceil(math.log2(head_dim)) - head_dim
        q = torch.nn.functional.pad(q, [0, target_dim, 0, 0, 0, 0, 0, 0])
        k = torch.nn.functional.pad(k, [0, target_dim, 0, 0, 0, 0, 0, 0])
        v = torch.nn.functional.pad(v, [0, target_dim, 0, 0, 0, 0, 0, 0])

    q_len = q.size(2)
    k_len = k.size(2)

    attn = TritonMultiStageDotProductionAttention(q.shape, q.dtype, q.device)
    if k_len > n_local:
        # Stage 1: local sliding window over the full KV.
        # Stage 2: the initial "sink" tokens via the complement of that window.
        init_k = k[:, :, :n_init, :].contiguous()
        init_v = v[:, :, :n_init, :].contiguous()
        attn.append(q, k, v, sliding_window=n_local)
        attn.append(
            q, init_k, init_v, end=True,
            sliding_window=(k_len - q_len, n_local),
            complement_sliding_window=True
        )
    else:
        attn.append(q, k, v, sliding_window=n_local, end=True)

    score, _ = attn.get_result()
    return score[..., :head_dim]


def streaming_forward2(q, k, v, n_init, n_local):
    q_len = q.size(2)
    k_len = k.size(2)
    attn = TritonMultiStageDotProductionAttention(q.shape, q.dtype, q.device)
    if k_len > n_local:
        init_k = k[:, :, :n_init, :].contiguous()
        init_v = v[:, :, :n_init, :].contiguous()
    else:
        # No tokens outside the local window yet: use empty init chunks.
        init_k = torch.empty(
            (k.size(0), k.size(1), 0, k.size(3)), dtype=k.dtype, device=k.device
        )
        init_v = torch.empty(
            (v.size(0), v.size(1), 0, v.size(3)), dtype=v.dtype, device=v.device
        )

    attn.append(q, k, v, sliding_window=n_local)
    attn.append(
        q, init_k, init_v, end=True,
        sliding_window=(k_len - q_len, n_local),
        complement_sliding_window=True
    )
    score, _ = attn.get_result()
    return score
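
# A small illustration (a sketch; the helper name is an assumption) of the
# Lambda-shaped StreamingLLM mask that the two `append` calls in
# `streaming_forward` compose: the first call keeps a local sliding window over
# the full KV, the second re-adds the initial "attention sink" tokens via the
# *complement* of that window, so no (query, key) pair is counted twice.
def _lambda_mask(q_len, k_len, n_init, n_local):
    offset = k_len - q_len
    dist = (torch.arange(q_len)[:, None] + offset) - torch.arange(k_len)[None, :]
    local = (dist >= 0) & (dist < n_local)  # first append's sliding window
    sink = (dist >= n_local)                # complement window ...
    sink[:, n_init:] = False                # ... restricted to the init keys
    return local | sink
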
def stream_llm_forward(n_local, n_init, *args, **kwargs):
    Attn = TritonMultiStageDotProductionAttention

    def forward(self, query: torch.Tensor, key_value: torch.Tensor,
                position_bias, use_cache: bool, past_key_value,
                project_q, project_k, project_v, attention_out,
                dim_head, num_heads, num_heads_kv):
        batch_size = query.size(0)
        len_q = query.size(1)
        len_k = key_value.size(1)

        h_q = project_q(query)       # (batch, len_q, num_heads * dim_head)
        h_k = project_k(key_value)   # (batch, len_k, num_heads_kv * dim_head)
        h_v = project_v(key_value)   # (batch, len_k, num_heads_kv * dim_head)

        h_q = h_q.view(batch_size, len_q, num_heads, dim_head).permute(0, 2, 1, 3)     # (batch, num_heads, len_q, dim_head)
        h_k = h_k.view(batch_size, len_k, num_heads_kv, dim_head).permute(0, 2, 1, 3)  # (batch, num_heads_kv, len_k, dim_head)
        h_v = h_v.view(batch_size, len_k, num_heads_kv, dim_head).permute(0, 2, 1, 3)  # (batch, num_heads_kv, len_k, dim_head)

        h_q = h_q.contiguous()
        h_k = h_k.contiguous()
        h_v = h_v.contiguous()

        if past_key_value is not None:
            h_k = torch.cat([past_key_value[0], h_k], dim=-2)
            h_v = torch.cat([past_key_value[1], h_v], dim=-2)
            len_k += past_key_value[2]

        if use_cache:
            if len_k <= n_local + n_init:
                h_k_cache = h_k
                h_v_cache = h_v
            else:
                # Keep only the init ("sink") tokens and the local window.
                h_k_cache = torch.cat(
                    [h_k[:, :, :n_init, :],
                     h_k[:, :, max(0, h_k.size(-2) - n_local):, :]], dim=2)
                h_v_cache = torch.cat(
                    [h_v[:, :, :n_init, :],
                     h_v[:, :, max(0, h_v.size(-2) - n_local):, :]], dim=2)
            current_key_value = (h_k_cache, h_v_cache, len_k)
        else:
            current_key_value = None

        h_q_ = h_q
        h_k_ = h_k
        h_v_ = h_v
        if len_q + n_local < h_k_.size(-2):
            # Only the last len_q + n_local KV positions can fall in the window.
            h_k_ = h_k_[:, :, h_k_.size(-2) - len_q - n_local:, :].contiguous().clone()
            h_v_ = h_v_[:, :, h_v_.size(-2) - len_q - n_local:, :].contiguous().clone()

        local_h_q, local_h_k = position_bias(h_q_, h_k_)
        local_h_v = h_v_

        if len_k > n_local:
            init_h_q = position_bias.apply_rotary_pos_emb_one_angle(
                h_q, n_local + n_init
            )
            init_h_k = position_bias.apply_rotary_pos_emb(
                h_k[:, :, :n_init, :].contiguous(), n_init, n_init,
                position_bias._cos_cached, position_bias._sin_cached
            )
            init_h_v = h_v[:, :, :n_init, :].contiguous()
        else:
            init_h_q = h_q
            init_h_k = torch.empty(
                (batch_size, num_heads_kv, 0, dim_head),
                device=h_k.device, dtype=h_k.dtype
            )
            init_h_v = torch.empty(
                (batch_size, num_heads_kv, 0, dim_head),
                device=h_v.device, dtype=h_v.dtype
            )

        attn = Attn(local_h_q.shape, local_h_q.dtype, local_h_q.device)
        attn.append(local_h_q, local_h_k, local_h_v, sliding_window=n_local)
        attn.append(
            init_h_q, init_h_k, init_h_v, end=True,
            sliding_window=(len_k - len_q, n_local),
            complement_sliding_window=True
        )
        score, _ = attn.get_result()

        score = score.view(batch_size, num_heads, len_q, dim_head).permute(0, 2, 1, 3).contiguous()  # (batch, len_q, num_heads, dim_head)
        score = score.reshape(batch_size, len_q, num_heads * dim_head)  # (batch, len_q, num_heads * dim_head)

        score = attention_out(score)

        if use_cache:
            return score, current_key_value
        else:
            return score

    return forward
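
# A hedged end-to-end smoke test (assumptions: a CUDA device with Triton
# installed; the sizes are arbitrary). It compares `streaming_forward` against
# a dense PyTorch reference restricted by the Lambda mask sketched above.
if __name__ == "__main__":
    if torch.cuda.is_available():
        torch.manual_seed(0)
        B, H, T, D = 1, 4, 512, 64
        n_init, n_local = 32, 128
        q = torch.randn(B, H, T, D, device="cuda", dtype=torch.float16)
        k = torch.randn_like(q)
        v = torch.randn_like(q)
        out = streaming_forward(q, k, v, n_init, n_local)

        mask = _lambda_mask(T, T, n_init, n_local).cuda()
        scores = (q.float() @ k.float().transpose(-1, -2)) / math.sqrt(D)
        scores = scores.masked_fill(~mask, float("-inf"))
        ref = torch.softmax(scores, dim=-1) @ v.float()
        print("max abs err:", (out.float() - ref).abs().max().item())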