|
""" |
|
Subquadratic attention combining sliding window and linear attentions |
|
- Using "standard" sliding windows |
|
- Didactically computes outputs with n^2 attention weights for now |
|
- Copied + adapted from linear_window_attention_tk.py for single-file reference |
|
|
|
For each layer: |
|
- We first compute (softmax) attention over sliding windows |
|
- We then compute standard linear attention to "fill in" the earlier parts |
|
- We combine to model the entire sequence |
|
""" |
|
from typing import List, Tuple, Optional, Callable, Any
|
import math |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from transformers.cache_utils import Cache |
|
|
|
from .linear_attention import ( |
|
LolcatsLinearAttention, LinearAttentionState, |
|
softmax_attention |
|
) |
|
|
|
|
|
|
|
|
|
def get_masks(window_size: int, q_len: int, k_len: int,
              device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return (window_mask, linear_mask) for the softmax and linear attention terms
    -> 1 is include, 0 is ignore
    """
|
    kwargs = {'device': device, 'dtype': int}
    # Lower-triangular causal mask; the diagonal offset handles q_len != k_len (e.g., decoding)
    causal_mask = torch.ones((q_len, k_len), **kwargs).tril(k_len - q_len)
    # Linear-attention mask: positions strictly before the sliding window
    linear_mask = torch.ones((q_len, k_len), **kwargs).tril(k_len - q_len - window_size)
    # Sliding-window mask: whatever remains of the causal mask
    window_mask = causal_mask - linear_mask
|
|
|
|
|
return window_mask[None, None, ...], linear_mask[None, None, ...] |
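

# A minimal usage sketch of `get_masks` (added for illustration; not part of the original
# module). It assumes tiny shapes and a CPU device purely to show how the causal mask is
# split between the sliding-window (softmax) and linear-attention terms.
def _example_get_masks() -> None:
    """
    For window_size=2 and q_len = k_len = 4, the two masks come out as:
        window_mask = [[1, 0, 0, 0],      linear_mask = [[0, 0, 0, 0],
                       [1, 1, 0, 0],                     [0, 0, 0, 0],
                       [0, 1, 1, 0],                     [1, 0, 0, 0],
                       [0, 0, 1, 1]]                     [1, 1, 0, 0]]
    """
    window_mask, linear_mask = get_masks(window_size=2, q_len=4, k_len=4,
                                         device=torch.device('cpu'))
    print(window_mask[0, 0])
    print(linear_mask[0, 0])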
|
|
|
|
|
def hybrid_attention_quadratic(q: torch.Tensor, k: torch.Tensor, |
|
f_q: torch.Tensor, f_k: torch.Tensor, |
|
v: torch.Tensor, |
|
window_factor: torch.Tensor, |
|
linear_factor: torch.Tensor, |
|
window_size: int, |
|
                               kv_state: Optional[torch.Tensor] = None,
                               k_state: Optional[torch.Tensor] = None,
                               eps: float = 1e-12,
                               mask_value: float = -1e8):
|
""" |
|
Hybrid attention combining sliding window and linear attentions |
|
""" |
|
|
|
mask_window, mask_linear = get_masks(window_size, q.shape[-2], k.shape[-2], q.device) |
|
|
|
|
|
    # 1. Sliding window (softmax attention) term
    a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5)
    a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
    # Equivalent to softmax over the window, but the normalizer is deferred
    # so it can be shared with the linear-attention term below
    a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
    a_sm = window_factor * torch.exp(a_sm - a_sm_max)
    sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
|
|
|
|
    # 2. Linear attention term over positions before the window
    a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float())
    a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
    sum_ln = a_ln.sum(dim=-1, keepdim=True)
|
|
|
|
|
    # 3. Combine the two terms under a shared normalizer
    a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype)  # combined attention weights

    y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float())
    if kv_state is not None:  # Fold in KV state carried over from prior chunks
        y += linear_factor * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float())
        sum_ln += linear_factor * torch.einsum(
            'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None]
    y = (y / (sum_sm + sum_ln)).to(q.dtype)
|
return y, a |
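

# Restating the computation above in equation form (no new math, just the code in symbols):
# for query position m, window size w, window_factor γ, linear_factor λ, and feature map φ,
#
#   y_m = ( Σ_{m-w < n <= m} γ · exp(q_m·k_n / √d - max_m) · v_n
#           + Σ_{n <= m-w}   λ · φ(q_m)·φ(k_n) · v_n  [+ λ · φ(q_m) · kv_state] )
#         / ( Σ_{m-w < n <= m} γ · exp(q_m·k_n / √d - max_m)
#             + Σ_{n <= m-w}   λ · φ(q_m)·φ(k_n)      [+ λ · φ(q_m) · k_state] )
#
# A hedged call sketch (shapes are assumptions for illustration): q, k, v are
# (batch, heads, seq_len, head_dim); f_q, f_k are the feature-mapped queries/keys; e.g.
#   y, a = hybrid_attention_quadratic(q, k, f_q, f_k, v,
#                                     window_factor=torch.tensor(0.5),
#                                     linear_factor=torch.tensor(0.5),
#                                     window_size=64)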
|
|
|
|
|
|
|
|
|
|
|
class LolcatsSlidingWindowAttention(LolcatsLinearAttention): |
|
""" |
|
Lolcats attention combining sliding window and linear attention |
|
""" |
|
def __init__(self, |
|
window_size: int = 64, |
|
                 decode_window_size: Optional[int] = None,
|
affine_attention_factors: bool = False, |
|
init_window_factor: float = 0, |
|
train_window_factor: bool = True, |
|
state_grad_enabled: bool = False, |
|
**kwargs): |
|
self.window_size = window_size |
|
self.decode_window_size = ( |
|
decode_window_size if decode_window_size is not None else window_size |
|
) |
|
self.window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1} |
|
super().__init__(**kwargs) |
|
self.attention_type = kwargs['attention_type'] |
|
|
|
self.quadratic_attention = hybrid_attention_quadratic |
|
|
|
|
self.affine_attention_factors = affine_attention_factors |
|
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype |
|
if train_window_factor: |
|
self.window_factors = nn.Parameter( |
|
init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)) |
|
else: |
|
self.register_buffer( |
|
"window_factors", init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype) |
|
) |
|
|
|
self.base_inference = False |
|
self.state_grad_enabled = state_grad_enabled |
|
|
|
def forward(self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Cache] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
**kwargs, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
""" |
|
Forward pass with the option to compute attention weights multiple ways |
|
if self.train_attention is True |
|
-> Consistent with HuggingFace Transformers for easy use with their pretrained models |
|
""" |
|
b, l, _ = hidden_states.size() |
|
q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask, |
|
position_ids, past_key_value) |
|
f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) |
|
|
|
if self.train_attention: |
|
|
|
            # 1. Compute "ground-truth" softmax attention outputs and weights
            with torch.no_grad():
|
_y_true, a_true = softmax_attention(q, k, v)[:2] |
|
y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) |
|
y_true = self.o_proj(y_true) |
|
|
|
|
|
|
|
            # 2. Compute "predicted" hybrid attention outputs and weights
            window_factors = F.sigmoid(self.window_factors)
|
linear_factors = 1 - window_factors if self.affine_attention_factors else 1 |
|
y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, |
|
window_factors, linear_factors, |
|
window_size=self.window_size) |
|
attn_weights = ((a_pred, a_true), (y_pred, _y_true)) |
|
else: |
|
attn_weights = None |
|
|
|
if past_key_value is None: |
|
window_factors = F.sigmoid(self.window_factors) |
|
linear_factors = 1 - window_factors if self.affine_attention_factors else 1 |
|
y_true, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, |
|
window_factors, linear_factors, |
|
window_size=self.window_size) |
|
attn_weights = a_pred |
|
else: |
|
past_key_value.window_size = self.decode_window_size |
|
if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: |
|
assert use_cache is True |
|
_kv = past_key_value.update_for_decoding(k, v, self.layer_idx, |
|
self.feature_map_k, |
|
dtype=q.dtype) |
|
k_cache, v_cache, f_kv_state, f_k_state = _kv |
|
|
|
|
|
window_factors = F.sigmoid(self.window_factors) |
|
linear_factors = 1 - window_factors if self.affine_attention_factors else 1 |
|
|
|
|
|
                    # Softmax attention over the cached sliding window
                    a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5)
|
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) |
|
a_sm = window_factors * torch.exp(a_sm - a_sm_max) |
|
sum_sm = a_sm.sum(dim=-1, keepdim=True) |
|
|
|
|
|
                    # Combine with the linear-attention contribution from the recurrent KV state
                    y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float())
                              + linear_factors * torch.einsum('bhlf,bhfd->bhld', f_q.float(), f_kv_state.float()))
                    sum_ln = linear_factors * torch.einsum(
                        'bhlf,bhnf->bhl', f_q.float(), f_k_state.float())[..., None]
                    y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
|
|
                else:  # Stateful (chunked) training: reuse states accumulated from prior chunks
|
try: |
|
kv_state = past_key_value.kv_states[self.layer_idx] |
|
k_state = past_key_value.k_states[self.layer_idx] |
|
except IndexError: |
|
kv_state, k_state = None, None |
|
window_factors = F.sigmoid(self.window_factors) |
|
linear_factors = 1 - window_factors if self.affine_attention_factors else 1 |
|
y_true, _ = self.quadratic_attention(q, k, f_q, f_k, v, |
|
window_factors, linear_factors, |
|
window_size=self.window_size, |
|
kv_state=kv_state, |
|
k_state=k_state) |
|
|
|
|
|
|
|
|
|
past_key_value.update(k, v, self.layer_idx, |
|
fmap_key_states=f_k, |
|
accumulate_in_fp32=True) |
|
|
|
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) |
|
y_true = self.o_proj(y_true) |
|
return y_true, attn_weights, past_key_value |
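

# Note on the three forward paths above (a summary of the code, not additional behavior):
#   1. self.train_attention is True  -> compute softmax-attention "ground truth" outputs and
#      weights alongside the hybrid "predicted" ones (for attention-transfer training).
#   2. past_key_value is None        -> plain hybrid quadratic attention over the full sequence.
#   3. past_key_value is not None    -> stateful (chunked) training, or single-token decoding via
#      LinearAttentionSlidingWindowCache.update_for_decoding below.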
|
|
|
|
|
class LinearAttentionSlidingWindowCache(LinearAttentionState): |
|
""" |
|
Class for `past_key_values` |
|
-> Alternative to KV cache; here we only maintain a "KV state" and "K state" |
|
-> Modified from transformers.cache_utils.DynamicCache (v4.36) |
|
""" |
|
def __init__(self, window_size: int = 64) -> None: |
|
super().__init__() |
|
self._seen_tokens = 0 |
|
self._seen_tokens_by_layer: List[int] = [] |
|
self.kv_states: List[torch.Tensor] = [] |
|
self.k_states: List[torch.Tensor] = [] |
|
|
|
|
|
self.decode_kv_states: List[torch.Tensor] = [] |
|
self.decode_k_states: List[torch.Tensor] = [] |
|
self.k_cache: List[torch.Tensor] = [] |
|
self.v_cache: List[torch.Tensor] = [] |
|
self.window_size = window_size |
|
|
|
def update(self, key_states: torch.Tensor, value_states: torch.Tensor, |
|
               layer_idx: Optional[int] = None, cache_kwargs: Optional[Any] = None,
               accumulate_in_fp32: bool = False,
               fmap_key_states: Optional[torch.Tensor] = None,
               grad_enabled: bool = False,
               **kwargs: Any,
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Update KV, K states; and KV cache during training |
|
- For decoding, use `self.decode_kv_states` to keep track of KV states |
|
up to sliding window terms |
|
- For (chunked) training, use `self.kv_states` to keep track of KV states |
|
up to end of sequence |
|
- Likewise for `self.decode_k_states` and `self.k_states` |
|
""" |
|
with torch.set_grad_enabled(grad_enabled): |
|
if layer_idx == 0: |
|
self._seen_tokens += key_states.shape[-2] |
|
|
|
dtype = key_states.dtype |
|
if accumulate_in_fp32: |
|
|
|
fmap_key_states = fmap_key_states.float() |
|
value_states = value_states.float() |
|
|
|
|
|
            # Decoding KV state: excludes the last `window_size` tokens (they stay in the KV cache)
            decode_kv_state = torch.einsum(
                'bhlf,bhld->bhfd', fmap_key_states[:, :, :-self.window_size], value_states[:, :, :-self.window_size]
            )
            # Full KV state over the whole (chunked) sequence
            kv_state = decode_kv_state + torch.einsum(
                'bhlf,bhld->bhfd', fmap_key_states[:, :, -self.window_size:], value_states[:, :, -self.window_size:]
            )
            # K (normalizer) states; shape is (b, h, 1, f)
            decode_k_state = fmap_key_states[:, :, :-self.window_size].sum(dim=-2, keepdim=True)
            k_state = (decode_k_state + fmap_key_states[:, :, -self.window_size:].sum(dim=-2, keepdim=True))
|
|
|
|
|
            if len(self.k_states) <= layer_idx:  # First update for this layer: initialize states
|
self.kv_states.append(kv_state.to(dtype)) |
|
self.k_states.append(k_state.to(dtype)) |
|
|
|
self.decode_kv_states.append(decode_kv_state.to(dtype)) |
|
self.decode_k_states.append(decode_k_state.to(dtype)) |
|
|
|
                self.k_cache.append(key_states[:, :, -self.window_size:, :])
                self.v_cache.append(value_states[:, :, -self.window_size:, :].to(dtype))
                # Track per-layer token counts (initialize to avoid indexing an empty list below)
                if len(self._seen_tokens_by_layer) <= layer_idx:
                    self._seen_tokens_by_layer.append(0)
|
|
|
            else:  # Accumulate into the existing states for this layer
|
|
|
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype) |
|
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype) |
|
self.kv_states[layer_idx] = kv_state |
|
self.k_states[layer_idx] = k_state |
|
|
|
decode_kv_state = (self.decode_kv_states[layer_idx].to(kv_state.dtype) |
|
+ decode_kv_state).to(dtype) |
|
decode_k_state = (self.decode_k_states[layer_idx].to(kv_state.dtype) |
|
+ decode_k_state).to(dtype) |
|
self.decode_kv_states[layer_idx] = decode_kv_state |
|
self.decode_k_states[layer_idx] = decode_k_state |
|
|
|
                # Refresh the sliding-window KV cache with the latest `window_size` tokens
                self.k_cache[layer_idx] = key_states[:, :, -self.window_size:, :]
                self.v_cache[layer_idx] = value_states[:, :, -self.window_size:, :].to(dtype)
|
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2] |
|
|
|
return self.kv_states[layer_idx], self.k_states[layer_idx] |
|
|
|
def update_for_decoding(self, keys: torch.Tensor, values: torch.Tensor, |
|
layer_idx: int, feature_map_k: Callable, dtype: torch.dtype): |
|
""" |
|
Update the decoding KV and K states, and KV cache, during decodeing |
|
""" |
|
with torch.no_grad(): |
|
k_cache = self.k_cache[layer_idx] |
|
v_cache = self.v_cache[layer_idx] |
|
|
|
            if k_cache.shape[-2] < self.window_size:  # Still filling the window: just append
|
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2) |
|
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2) |
|
            else:  # Window is full: evict the oldest token into the recurrent linear state
|
|
|
|
|
|
|
|
|
|
|
|
|
k_state = feature_map_k(k_cache[:, :, :1, :]) |
|
v_state = v_cache[:, :, :1, :] |
|
kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) |
|
self.decode_kv_states[layer_idx] += kv_state |
|
self.decode_k_states[layer_idx] += k_state |
|
|
|
self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2) |
|
self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], values], dim=-2) |
|
|
|
if layer_idx == 0: |
|
self._seen_tokens += keys.shape[-2] |
|
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2] |
|
return (self.k_cache[layer_idx], self.v_cache[layer_idx], |
|
self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx]) |
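

# A minimal, hedged sketch of how this cache is driven (illustrative only: the tiny shapes
# and the ReLU stand-in feature map are assumptions, not part of the original API).
def _example_sliding_window_cache() -> None:
    b, h, l, d = 1, 2, 8, 4
    cache = LinearAttentionSlidingWindowCache(window_size=4)
    k, v = torch.randn(b, h, l, d), torch.randn(b, h, l, d)
    f_k = torch.relu(k)  # stand-in for the learned feature map
    # Prefill: accumulate KV / K states and fill the sliding-window KV cache
    cache.update(k, v, layer_idx=0, fmap_key_states=f_k, accumulate_in_fp32=True)
    # Decode one token: the oldest cached token is folded into the recurrent linear state
    k_new, v_new = torch.randn(b, h, 1, d), torch.randn(b, h, 1, d)
    k_cache, v_cache, kv_state, k_state = cache.update_for_decoding(
        k_new, v_new, layer_idx=0, feature_map_k=torch.relu, dtype=k.dtype)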
|
|