# lolcats/src/model/linear_attention/linear_window_attention_sw.py
"""
Subquadratic attention combining sliding window and linear attentions
- Using "standard" sliding windows
- Didactically computes outputs with n^2 attention weights for now
- Copied + adapted from linear_window_attention_tk.py for single-file reference
For each layer:
- We first compute (softmax) attention over sliding windows
- We then compute standard linear attention to "fill in" the earlier parts
- We combine the two to model the entire sequence
"""
from typing import List, Tuple, Optional, Callable, Any
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache
from .linear_attention import (
LolcatsLinearAttention, LinearAttentionState,
softmax_attention
)
# ----------------------
# Sliding window helpers
# ----------------------
def get_masks(window_size: int, q_len: int, k_len: int,
              device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Return masks for softmax and linear attention terms
-> 1 is include, 0 is ignore
"""
kwargs = {'device': device, 'dtype': int}
causal_mask = torch.ones((q_len, k_len), **kwargs).tril(k_len - q_len)
linear_mask = torch.ones((q_len, k_len), **kwargs).tril(k_len - q_len - window_size)
window_mask = causal_mask - linear_mask
# Return softmax mask (window), linear attention mask
# -> shapes broadcast over (b, h, q_len, k_len)
return window_mask[None, None, ...], linear_mask[None, None, ...]
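# For example, with window_size = 2 and q_len = k_len = 4 the returned masks
# (before the broadcast dimensions are added) look like:
#   window_mask        linear_mask
#   1 0 0 0            0 0 0 0
#   1 1 0 0            0 0 0 0
#   0 1 1 0            1 0 0 0
#   0 0 1 1            1 1 0 0
# i.e., each query attends to itself and the previous window_size - 1 positions
# via softmax attention, and to all strictly earlier positions via linear attention.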
def hybrid_attention_quadratic(q: torch.Tensor, k: torch.Tensor,
f_q: torch.Tensor, f_k: torch.Tensor,
v: torch.Tensor,
window_factor: torch.Tensor,
linear_factor: torch.Tensor,
window_size: int,
                               kv_state: Optional[torch.Tensor] = None,
                               k_state: Optional[torch.Tensor] = None,
                               eps: float = 1e-12,
                               mask_value: float = -1e8):
"""
Hybrid attention combining sliding window and linear attentions
"""
mask_window, mask_linear = get_masks(window_size, q.shape[-2], k.shape[-2], q.device)
# 1. Sliding window (softmax attention)
a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5)
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# 2. Under window (linear attention)
a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float())
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
sum_ln = a_ln.sum(dim=-1, keepdim=True)
# 3. Combine
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
# Allow outputs to also depend on prior kv_state and k_state
y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float())
if kv_state is not None: # Combine with prior kv_state and k_state
y += linear_factor * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float())
sum_ln += linear_factor * torch.einsum(
'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None]
y = (y / (sum_sm + sum_ln)).to(q.dtype)
return y, a # attention weights only for the last chunk
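# When kv_state / k_state are passed (stateful chunked training), the linear term above
# is extended with contributions from tokens in earlier chunks: phi(q) @ kv_state is added
# to the numerator and phi(q) . k_state to the normalizer, so each query effectively
# attends over the full prefix rather than only the current chunk.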
# ---------------------
# Attention layer class
# ---------------------
class LolcatsSlidingWindowAttention(LolcatsLinearAttention):
"""
Lolcats attention combining sliding window and linear attention
"""
def __init__(self,
window_size: int = 64,
                 decode_window_size: Optional[int] = None,
affine_attention_factors: bool = False,
init_window_factor: float = 0,
train_window_factor: bool = True,
state_grad_enabled: bool = False,
**kwargs):
self.window_size = window_size
self.decode_window_size = (
decode_window_size if decode_window_size is not None else window_size
)
self.window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1}
super().__init__(**kwargs)
        self.attention_type = kwargs['attention_type']  # e.g., 'hedgehog_llama_window_sw'
        # Determine how we compute attentions
        self.quadratic_attention = hybrid_attention_quadratic
# Learnable factor for combining attentions
self.affine_attention_factors = affine_attention_factors
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
if train_window_factor:
self.window_factors = nn.Parameter(
init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype))
else:
self.register_buffer(
"window_factors", init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
)
# Whether we use original flash attention 2 inference (use during attention transfer)
self.base_inference = False
self.state_grad_enabled = state_grad_enabled
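        # Note: window_factors pass through a sigmoid in forward(), so the default
        # init_window_factor = 0 starts the softmax-window term at weight 0.5; when
        # affine_attention_factors is True the linear term gets the complementary
        # 1 - sigmoid(window_factors), otherwise it is left at 1.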
def forward(self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Forward pass with the option to compute attention weights multiple ways
if self.train_attention is True
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
"""
b, l, _ = hidden_states.size()
q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask,
position_ids, past_key_value)
f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) # Have to do after repeat for grouped-query attn if we use same fmap
if self.train_attention:
# 1. Compute "ground-truth" attention output and weights
with torch.no_grad():
_y_true, a_true = softmax_attention(q, k, v)[:2]
y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
# 2. Compute "predicted" attention outputs
# compute attn weights under sliding window
window_factors = F.sigmoid(self.window_factors)
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v,
window_factors, linear_factors,
window_size=self.window_size)
attn_weights = ((a_pred, a_true), (y_pred, _y_true))
else:
attn_weights = None
# attention_mask = None # For now this is always True
if past_key_value is None: # Regular training
window_factors = F.sigmoid(self.window_factors)
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
y_true, a_pred = self.quadratic_attention(q, k, f_q, f_k, v,
window_factors, linear_factors,
window_size=self.window_size)
attn_weights = a_pred
else:
past_key_value.window_size = self.decode_window_size
if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating
assert use_cache is True
_kv = past_key_value.update_for_decoding(k, v, self.layer_idx,
self.feature_map_k,
dtype=q.dtype)
k_cache, v_cache, f_kv_state, f_k_state = _kv
# Sliding window + linear attention decode
window_factors = F.sigmoid(self.window_factors)
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
# Softmax attention terms
a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5)
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
sum_sm = a_sm.sum(dim=-1, keepdim=True)
# Combine with linear attention terms
y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float())
+ linear_factors * torch.einsum('bhlf,bhfd->bhld', f_q.float(), f_kv_state.float()))
sum_ln = linear_factors * torch.einsum(
'bhlf,bhnf->bhl', f_q.float(), f_k_state.float())[..., None]
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
else: # Stateful training
try:
kv_state = past_key_value.kv_states[self.layer_idx]
k_state = past_key_value.k_states[self.layer_idx]
except IndexError:
kv_state, k_state = None, None
window_factors = F.sigmoid(self.window_factors)
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
y_true, _ = self.quadratic_attention(q, k, f_q, f_k, v,
window_factors, linear_factors,
window_size=self.window_size,
kv_state=kv_state,
k_state=k_state)
# Save and update KV cache and states
# past_key_value.update(k, v.detach(), self.layer_idx,
# fmap_key_states=f_k.detach(),
# accumulate_in_fp32=True)
past_key_value.update(k, v, self.layer_idx,
fmap_key_states=f_k,
accumulate_in_fp32=True)
# Concatenate heads and apply output projection
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
y_true = self.o_proj(y_true)
return y_true, attn_weights, past_key_value
class LinearAttentionSlidingWindowCache(LinearAttentionState):
"""
Class for `past_key_values`
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
"""
def __init__(self, window_size: int = 64) -> None:
super().__init__()
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
self._seen_tokens_by_layer: List[int] = []
self.kv_states: List[torch.Tensor] = []
self.k_states: List[torch.Tensor] = []
# Account for sliding windows
self.decode_kv_states: List[torch.Tensor] = []
self.decode_k_states: List[torch.Tensor] = []
self.k_cache: List[torch.Tensor] = []
self.v_cache: List[torch.Tensor] = []
self.window_size = window_size
    def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
               layer_idx: Optional[int] = None, cache_kwargs: Optional[Any] = None,
               accumulate_in_fp32: bool = False,
               fmap_key_states: Optional[torch.Tensor] = None,  # should not be None
               grad_enabled: bool = False,
               **kwargs: Any,
               ) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update KV, K states; and KV cache during training
- For decoding, use `self.decode_kv_states` to keep track of KV states
up to sliding window terms
- For (chunked) training, use `self.kv_states` to keep track of KV states
up to end of sequence
- Likewise for `self.decode_k_states` and `self.k_states`
"""
with torch.set_grad_enabled(grad_enabled):
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
dtype = key_states.dtype
if accumulate_in_fp32:
# key_states = key_states.float()
fmap_key_states = fmap_key_states.float()
value_states = value_states.float()
# Decoding KV state (KV terms up to last window_size)
decode_kv_state = torch.einsum(
'bhlf,bhld->bhfd', fmap_key_states[:, :, :-self.window_size], value_states[:, :, :-self.window_size]
)
# KV state
kv_state = decode_kv_state + torch.einsum(
'bhlf,bhld->bhfd', fmap_key_states[:, :, -self.window_size:], value_states[:, :, -self.window_size:]
)
# shape is b, h, 1, f; note the 1
decode_k_state = fmap_key_states[:, :, :-self.window_size].sum(dim=-2, keepdim=True)
k_state = (decode_k_state + fmap_key_states[:, :, -self.window_size:].sum(dim=-2, keepdim=True))
# Update the cache
if len(self.k_states) <= layer_idx: # Initializing kv and k states
self.kv_states.append(kv_state.to(dtype))
self.k_states.append(k_state.to(dtype))
self.decode_kv_states.append(decode_kv_state.to(dtype))
self.decode_k_states.append(decode_k_state.to(dtype))
self.k_cache.append(key_states[:, :, -self.window_size:, :])
self.v_cache.append(value_states[:, :, -self.window_size:, :].to(dtype))
                self._seen_tokens_by_layer.append(0)  # initialize per-layer count; incremented below
else:
# Update kv and k states recurrently
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype)
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype)
self.kv_states[layer_idx] = kv_state
self.k_states[layer_idx] = k_state
decode_kv_state = (self.decode_kv_states[layer_idx].to(kv_state.dtype)
+ decode_kv_state).to(dtype)
decode_k_state = (self.decode_k_states[layer_idx].to(kv_state.dtype)
+ decode_k_state).to(dtype)
self.decode_kv_states[layer_idx] = decode_kv_state
self.decode_k_states[layer_idx] = decode_k_state
self.k_cache[layer_idx] = key_states[:, :, -self.window_size:, :]
self.v_cache[layer_idx] = value_states[:, :, -self.window_size:, :]
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
return self.kv_states[layer_idx], self.k_states[layer_idx]
def update_for_decoding(self, keys: torch.Tensor, values: torch.Tensor,
layer_idx: int, feature_map_k: Callable, dtype: torch.dtype):
"""
        Update the decoding KV and K states, and the sliding-window KV cache, during decoding
"""
with torch.no_grad():
k_cache = self.k_cache[layer_idx]
v_cache = self.v_cache[layer_idx]
if k_cache.shape[-2] < self.window_size: # build window-size cache
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
else:
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
# else:
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
k_state = feature_map_k(k_cache[:, :, :1, :])
v_state = v_cache[:, :, :1, :]
kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d
self.decode_kv_states[layer_idx] += kv_state
self.decode_k_states[layer_idx] += k_state
self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2)
self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], values], dim=-2)
if layer_idx == 0:
self._seen_tokens += keys.shape[-2]
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
return (self.k_cache[layer_idx], self.v_cache[layer_idx],
self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx])
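# --------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module's API): exercises the
# standalone mask and quadratic-attention helpers above on toy tensors. The
# feature map below (elu + 1) is only a stand-in for the learned Hedgehog
# feature maps used by LolcatsSlidingWindowAttention, and all sizes are
# arbitrary. Because of the relative import at the top of this file, run it
# as a module from the package root (via `python -m`), not as a loose script.
# --------------------------------------------------------------------------
if __name__ == '__main__':
    torch.manual_seed(0)
    b, h, l, d = 1, 2, 16, 8  # batch, heads, sequence length, head dim
    window_size = 4

    q = torch.randn(b, h, l, d)
    k = torch.randn(b, h, l, d)
    v = torch.randn(b, h, l, d)

    def feature_map(x):
        # Stand-in positive feature map (the real model uses learned Hedgehog maps)
        return F.elu(x) + 1

    f_q, f_k = feature_map(q), feature_map(k)

    window_factor = torch.sigmoid(torch.zeros(1, h, 1, 1))  # 0.5 per head
    linear_factor = 1 - window_factor

    y, a = hybrid_attention_quadratic(q, k, f_q, f_k, v,
                                      window_factor, linear_factor,
                                      window_size=window_size)
    print(tuple(y.shape))  # (b, h, l, d)
    # Combined softmax + linear weights are normalized per query position
    print(torch.allclose(a.sum(dim=-1), torch.ones(b, h, l)))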