import math
import time
import warnings
from importlib.metadata import version
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers.cache_utils import Cache, DynamicCache
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.utils import logging

logger = logging.get_logger(__name__)


# https://github.com/huggingface/transformers/blob/v4.37-release/src/transformers/models/llama/modeling_llama.py
def llama_flash_attn2_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # [SnapKV] register kv_cluster
    init_snapkv(self)
    # LlamaFlashAttention2 attention does not support output_attentions
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
        )
        # overwrite attention_mask with padding_mask
        attention_mask = kwargs.pop("padding_mask")
    output_attentions = False
    bsz, q_len, _ = hidden_states.size()
    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)
    # Flash attention requires the input to have the shape
    # batch_size x seq_length x num_heads x head_dim,
    # therefore we just need to keep the original shape
    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    kv_seq_len = key_states.shape[-2]
    # if past_key_value is not None:
    #     kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        if hasattr(self, "kv_seq_len"):  # [SnapKV] add kv_seq_len
            if self.kv_seq_len != 0:
                kv_seq_len += self.kv_seq_len
            else:
                kv_seq_len += past_key_value.get_usable_length(
                    kv_seq_len, self.layer_idx
                )
        else:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )
    # [SnapKV] repeat_kv is moved ahead of the cache update so that KV selection
    # and caching are done per attention head
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # print('kv_seq_len:', kv_seq_len)
        # print('key_states.shape:', key_states.shape)
        if key_states.shape[-2] == kv_seq_len:  # [SnapKV] prefill phase: compress the prompt KV before caching
            self.kv_seq_len = kv_seq_len  # [SnapKV] register kv_seq_len
            key_states_compress, value_states_compress = self.kv_cluster.update_kv(
                key_states,
                query_states,
                value_states,
                attention_mask,
                self.num_key_value_groups,
            )
            past_key_value.update(
                key_states_compress, value_states_compress, self.layer_idx, cache_kwargs
            )
        else:
            # [SnapKV] decoding phase: append the new KV to the cache as usual
            self.kv_seq_len += q_len
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )
    # TODO: These transposes are quite inefficient, but Flash Attention requires the layout
    # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
    # to be able to avoid many of these transpose/reshape/view.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)
    dropout_rate = self.attention_dropout if self.training else 0.0
    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
    # therefore the input hidden states gets silently casted in float32. Hence, we need
    # cast them back in the correct dtype just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not cast the LayerNorms
    # in fp32. (LlamaRMSNorm handles it correctly)
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype
        logger.warning_once(
            f"The input hidden states seems to be silently casted in float32, this might be related to"
            f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
            f" {target_dtype}."
        )
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)
    attn_output = self._flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        q_len,
        dropout=dropout_rate,
    )
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)
    if not output_attentions:
        attn_weights = None
    return attn_output, attn_weights, past_key_value


def prepare_inputs_for_generation_llama(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    **kwargs,
):
    if past_key_values is None:  # [SnapKV]
        for layer in self.model.layers:
            layer.self_attn.kv_seq_len = 0
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            # cache_length = past_length = past_key_values[0][0].shape[2]
            # max_cache_length = None
            cache_length = past_length = self.model.layers[0].self_attn.kv_seq_len
            max_cache_length = None
        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
        # input)
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]
    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]
    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}
    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs


llama_flash_attn2_forward_4_37 = llama_flash_attn2_forward
prepare_inputs_for_generation_llama_4_37 = prepare_inputs_for_generation_llama


def rope_forward(self, x, seq_len):
    # x: [bs, num_attention_heads, seq_len, head_size]
    position_ids = torch.arange(seq_len, device=x.device).unsqueeze(0)
    inv_freq_expanded = (
        self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    )
    position_ids_expanded = position_ids[:, None, :].float()
    # Force float32 since bfloat16 loses precision on long contexts
    # See https://github.com/huggingface/transformers/pull/29285
    device_type = x.device.type
    device_type = (
        device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
    )
    with torch.autocast(device_type=device_type, enabled=False):
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
            1, 2
        )
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


##################
# perform qk calculation and get indices
# this version will not update in inference mode


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class SnapKVCluster:
    def __init__(
        self,
        window_size=64,
        max_capacity_prompt=256 + 64,
        kernel_size=5,
        pooling="avgpool",
    ):
        self.window_size = window_size
        self.max_capacity_prompt = max_capacity_prompt
        assert self.max_capacity_prompt - self.window_size > 0
        self.kernel_size = kernel_size
        self.pooling = pooling

    def reset(
        self,
        window_size=64,
        max_capacity_prompt=256 + 64,
        kernel_size=5,
        pooling="avgpool",
    ):
        self.window_size = window_size
        self.max_capacity_prompt = max_capacity_prompt
        assert self.max_capacity_prompt - self.window_size > 0
        self.kernel_size = kernel_size
        self.pooling = pooling

    def update_kv(
        self,
        key_states,
        query_states,
        value_states,
        attention_mask,
        num_key_value_groups,
    ):
        # check if prefix phase: queries and keys cover the same (full prompt) length
        assert key_states.shape[-2] == query_states.shape[-2]
        bsz, num_heads, q_len, head_dim = query_states.shape
        if q_len < self.max_capacity_prompt:
            return key_states, value_states
        else:
            # score every prompt position by the attention it receives from the
            # last `window_size` queries (the observation window)
            attn_weights = torch.matmul(
                query_states[..., -self.window_size :, :], key_states.transpose(2, 3)
            ) / math.sqrt(head_dim)
            # causal mask over the observation window
            mask = torch.full(
                (self.window_size, self.window_size),
                torch.finfo(attn_weights.dtype).min,
                device=attn_weights.device,
            )
            mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
            mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
            mask = mask.to(attn_weights.device)
            attention_mask = mask[None, None, :, :]
            attn_weights[
                :, :, -self.window_size :, -self.window_size :
            ] += attention_mask
            attn_weights = nn.functional.softmax(
                attn_weights, dim=-1, dtype=torch.float32
            ).to(query_states.dtype)
            # aggregate attention over the window queries, excluding the window keys themselves
            attn_weights_sum = attn_weights[
                :, :, -self.window_size :, : -self.window_size
            ].sum(dim=-2)
            # smooth the per-position scores before selection
            if self.pooling == "avgpool":
                attn_cache = F.avg_pool1d(
                    attn_weights_sum,
                    kernel_size=self.kernel_size,
                    padding=self.kernel_size // 2,
                    stride=1,
                )
            elif self.pooling == "maxpool":
                attn_cache = F.max_pool1d(
                    attn_weights_sum,
                    kernel_size=self.kernel_size,
                    padding=self.kernel_size // 2,
                    stride=1,
                )
            else:
                raise ValueError("Pooling method not supported")
            # keep the top-scoring prefix positions plus the observation window itself
            indices = attn_cache.topk(
                self.max_capacity_prompt - self.window_size, dim=-1
            ).indices
            indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            k_past_compress = key_states[:, :, : -self.window_size, :].gather(
                dim=2, index=indices
            )
            v_past_compress = value_states[:, :, : -self.window_size, :].gather(
                dim=2, index=indices
            )
            k_cur = key_states[:, :, -self.window_size :, :]
            v_cur = value_states[:, :, -self.window_size :, :]
            key_states = torch.cat([k_past_compress, k_cur], dim=2)
            value_states = torch.cat([v_past_compress, v_cur], dim=2)
            return key_states, value_states
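

# A minimal, self-contained sketch of how SnapKVCluster.update_kv behaves on
# random tensors (illustration only, not part of SnapKV; the shapes and
# hyperparameters below are assumed for the demo). When the prompt is at least
# max_capacity_prompt tokens long, the prefill KV is compressed down to
# max_capacity_prompt entries per head.
def _demo_snapkv_cluster():
    bsz, num_heads, q_len, head_dim = 1, 8, 1024, 64
    cluster = SnapKVCluster(window_size=64, max_capacity_prompt=256 + 64)
    query = torch.randn(bsz, num_heads, q_len, head_dim)
    key = torch.randn(bsz, num_heads, q_len, head_dim)
    value = torch.randn(bsz, num_heads, q_len, head_dim)
    key_compress, value_compress = cluster.update_kv(
        key, query, value, attention_mask=None, num_key_value_groups=1
    )
    # Expected shapes: (1, 8, 320, 64) -- 256 selected prefix tokens + the 64-token window.
    print(key_compress.shape, value_compress.shape)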


def init_snapkv(self):
    if not hasattr(self, "kv_cluster"):
        if not hasattr(self.config, "window_size"):
            self.config.window_size = 64
        if not hasattr(self.config, "max_capacity_prompt"):
            self.config.max_capacity_prompt = 4096
        if not hasattr(self.config, "kernel_size"):
            self.config.kernel_size = 13
        if not hasattr(self.config, "pooling"):
            self.config.pooling = "avgpool"
        self.kv_cluster = SnapKVCluster(
            window_size=self.config.window_size,
            max_capacity_prompt=self.config.max_capacity_prompt,
            kernel_size=self.config.kernel_size,
            pooling=self.config.pooling,
        )
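

# init_snapkv() reads its hyperparameters from the attention module's config, so
# they can be overridden on a loaded model's config before the first forward
# pass. A minimal sketch; the helper name and the values shown here are
# illustrative assumptions, not part of SnapKV.
def configure_snapkv(
    model, window_size=32, max_capacity_prompt=2048, kernel_size=7, pooling="maxpool"
):
    model.config.window_size = window_size
    model.config.max_capacity_prompt = max_capacity_prompt
    model.config.kernel_size = kernel_size
    model.config.pooling = pooling
    return model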


############
def check_version():
    try:
        transformers_version = version("transformers")
    except Exception as e:
        print(f"Transformers not installed: {e}")
        return None
    return transformers_version


def replace_llama():
    transformers_version = check_version()
    version_list = ["4.37"]
    warning_flag = True
    for supported_version in version_list:
        if transformers_version is not None and supported_version in transformers_version:
            warning_flag = False
            break
    if warning_flag:
        warnings.warn(
            f"Transformers version {transformers_version} might not be compatible with SnapKV. SnapKV is tested with Transformers version {version_list}."
        )
    transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation = (
        prepare_inputs_for_generation_llama_4_37
    )
    transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward = (
        llama_flash_attn2_forward_4_37
    )
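

# End-to-end usage sketch. The model id, dtype, device placement, and generation
# settings below are illustrative assumptions; the point is only that
# replace_llama() should be called before inference (here, before loading the
# model) so the patched FlashAttention-2 forward above is the one that runs.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    replace_llama()  # monkey-patch LlamaFlashAttention2.forward and prepare_inputs_for_generation
    model_id = "meta-llama/Llama-2-7b-chat-hf"  # assumed checkpoint; any Llama model should work
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
        device_map="auto",
    )
    prompt = "Summarize the following document: ..."
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=64)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))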