# tinyllama-lckv-w10-ft-250b / cache_utils.py
# Uploaded by whynlp: "Upload LCKVLlamaForCausalLM" (commit 61c49cc, verified), 23.4 kB
from typing import Any, Dict, List, Optional, Tuple
import torch
from transformers.cache_utils import Cache, DynamicCache, SinkCache
from .utils import LayerTypeParser
class IndexedCache(Cache):
    """
    Similar to the `DynamicCache` class, but with the ability to index the cache by layer index. DynamicCache
    assumes that all layers compute KVs, while IndexedCache allows for a more flexible cache structure.

    `key_cache` / `value_cache` are dicts keyed by layer index, so a layer that never calls `update` simply has
    no entry. While `_update` is `False`, `update` returns the concatenated states without storing them.
    """

    build_position_ids_based_on_cache = False

    def __init__(self) -> None:
        super().__init__()
        self.key_cache: Dict[int, torch.Tensor] = {}
        self.value_cache: Dict[int, torch.Tensor] = {}
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
        self._update = True  # to prevent the cache from updating when inference with iterations

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.

        Raises:
            KeyError: if `layer_idx` has no cached states.
        """
        if layer_idx in self.key_cache:
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values. Layers are yielded in increasing layer-index order.
        """
        for layer_idx in sorted(self.key_cache.keys()):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers that compute KVs in the model.
        """
        return len(self.key_cache)

    @property
    def min_layer(self) -> Optional[int]:
        """Smallest layer index that currently holds cached states, or `None` while the cache is empty."""
        return min(self.key_cache.keys()) if len(self.key_cache) > 0 else None

    def is_min_layer(self, layer_idx: int) -> bool:
        """Whether `layer_idx` is the lowest cached layer (also `True` while the cache is still empty)."""
        return self.min_layer is None or self.min_layer == layer_idx

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `IndexedCache`.

        Return:
            A tuple containing the updated key and value states (previous states concatenated with the new ones
            along dim -2). They are stored in the cache only while `_update` is `True`.
        """
        # Only the lowest cached layer counts tokens, so each forward step is counted once.
        if self.is_min_layer(layer_idx):
            self._seen_tokens += key_states.shape[-2]

        # Retrieve the cache
        if layer_idx not in self.key_cache:
            new_key_states = key_states
            new_value_states = value_states
        else:
            new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        # Update the cache
        if self._update:
            self.key_cache[layer_idx] = new_key_states
            self.value_cache[layer_idx] = new_value_states

        return new_key_states, new_value_states

    def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
        """
        Returns the sequence length of the cached states. A layer index can be optionally passed; it defaults to
        the lowest cached layer. Returns 0 if that layer has no cached states.
        """
        if layer_idx is None:
            layer_idx = self.min_layer
        # TODO: deprecate this function in favor of `cache_position`
        is_empty_layer = (
            (len(self.key_cache) == 0)  # no cache in any layer
            or (layer_idx not in self.key_cache)  # skipped `layer_idx` and hasn't run a layer with cache after it
        )
        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
        return layer_seq_length

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. IndexedCache does not have a maximum length."""
        return None

    @classmethod
    def from_cache(cls, dynamic_cache: DynamicCache, *args, **kwargs) -> "IndexedCache":
        """
        Converts a dynamic cache into an equivalent `IndexedCache`.

        Cached tensors are copied by reference, and the seen-token counter is carried over unchanged. Layer entries
        that are not tensors, or are empty, are left out. These are presumably placeholders that some
        `transformers` versions insert for skipped layers. Leaving them out matches how `IndexedCache` represents
        layers without KVs.
        """
        cache = cls(*args, **kwargs)
        # Store tensors directly rather than via `update`: `update` would add the first layer's sequence length to
        # `_seen_tokens` again on top of the copied counter.
        for layer_idx in range(len(dynamic_cache.key_cache)):
            key_states = dynamic_cache.key_cache[layer_idx]
            value_states = dynamic_cache.value_cache[layer_idx]
            if not isinstance(key_states, torch.Tensor) or key_states.numel() == 0:
                continue
            cache.key_cache[layer_idx] = key_states
            cache.value_cache[layer_idx] = value_states
        cache._seen_tokens = dynamic_cache._seen_tokens
        return cache
class IndexedSinkCache(Cache):
    """
    This is a fix to the SinkCache class in the transformers library. It also allows for the cache to be indexed by
    layer index, similar to the `IndexedCache` class.

    Once a layer holds more than `window_length` positions, only the first `num_sink_tokens` positions and the
    most recent ones are kept, for `window_length` positions in total. On RoPE models the kept keys are re-rotated
    to their new positions.
    """

    build_position_ids_based_on_cache = True

    def __init__(self, window_length: Optional[int] = None, num_sink_tokens: Optional[int] = None) -> None:
        super().__init__()
        self.key_cache: Dict[int, torch.Tensor] = {}
        self.value_cache: Dict[int, torch.Tensor] = {}
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        self.cos_sin_rerotation_cache = {}
        self._cos_cache = None
        self._sin_cache = None
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
        self._update = True  # to prevent the cache from updating when inference with iterations

    @staticmethod
    def _rotate_half(x):
        """Split the last dim into halves (x1, x2) and return (-x2, x1), as used by RoPE."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        """Apply the rotation described by `cos`/`sin` to `key_states`."""
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, offset: int, dtype: torch.dtype, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return the cos/sin that move every post-sink key from its original position to `offset` positions earlier.

        Results are memoized by `offset` only (`dtype` is not part of the memo key). `offset` must be positive.
        """
        if offset not in self.cos_sin_rerotation_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Angle difference between the original positions and the positions shifted back by `offset`
            original_cos = cos[self.num_sink_tokens + offset :]
            shifted_cos = cos[self.num_sink_tokens : -offset]
            original_sin = sin[self.num_sink_tokens + offset :]
            shifted_sin = sin[self.num_sink_tokens : -offset]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_rerotation_cache[offset] = (
                rerotation_cos.to(dtype).unsqueeze(0),
                rerotation_sin.to(dtype).unsqueeze(0),
            )
        return self.cos_sin_rerotation_cache[offset]

    @property
    def min_layer(self) -> Optional[int]:
        """Smallest layer index that currently holds cached states, or `None` while the cache is empty."""
        return min(self.key_cache.keys()) if len(self.key_cache) > 0 else None

    def is_min_layer(self, layer_idx: int) -> bool:
        """Whether `layer_idx` is the lowest cached layer (also `True` while the cache is still empty)."""
        return self.min_layer is None or self.min_layer == layer_idx

    def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
        if layer_idx is None:
            layer_idx = self.min_layer
        if layer_idx not in self.key_cache:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted. If omitted, no re-rotation is performed.

        Return:
            A tuple containing the updated key and value states.
        """
        # `cache_kwargs` defaults to `None`; treat that as "no extra arguments" instead of crashing on `.get`.
        if cache_kwargs is None:
            cache_kwargs = {}

        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # with partially rotated position embeddings, like Phi or Persimmon.
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the number of seen tokens
        if self.is_min_layer(layer_idx):
            self._seen_tokens += key_states.shape[-2]

        # Update the sin/cos cache, which holds sin/cos values for all possible positions
        if using_rope and self.is_min_layer(layer_idx):
            # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
            # after all RoPE models have a llama-like cache utilization.
            if cos.dim() == 2:
                self._cos_cache = cos
                self._sin_cache = sin
            else:
                if self._cos_cache is None:
                    self._cos_cache = cos[0, ...]
                    self._sin_cache = sin[0, ...]
                elif self._cos_cache.shape[0] < self.window_length + key_states.shape[-2]:
                    self._cos_cache = torch.cat([self._cos_cache[: self.window_length], cos[0, ...]], dim=0)
                    self._sin_cache = torch.cat([self._sin_cache[: self.window_length], sin[0, ...]], dim=0)

        # [bsz, num_heads, seq_len, head_dim]
        if layer_idx not in self.key_cache:
            # Empty cache
            new_key_states = key_states
            new_value_states = value_states
        else:
            # Growing cache
            new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        if self._update:
            self.key_cache[layer_idx] = new_key_states
            self.value_cache[layer_idx] = new_value_states

        # If the cache is full, we need to shift the cache
        if (seq_length := self.get_seq_length(layer_idx)) > self.window_length:
            # Shifting cache
            keys_to_keep = self.key_cache[layer_idx][:, :, -self.window_length + self.num_sink_tokens :]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    seq_length - self.window_length,
                    key_states.dtype,
                    self._cos_cache[:seq_length],
                    self._sin_cache[:seq_length],
                )
                if partial_rotation_size is not None:
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep], dim=-2)

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][:, :, -self.window_length + self.num_sink_tokens :]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep], dim=-2)

        return new_key_states, new_value_states

    @classmethod
    def from_cache(cls, sink_cache: SinkCache, *args, **kwargs) -> "IndexedSinkCache":
        """
        Converts a `SinkCache` into an equivalent `IndexedSinkCache`.

        The window settings, seen-token counter and RoPE sin/cos caches are copied over, and the per-layer tensors
        are copied by reference.
        """
        cache = cls(*args, **kwargs)
        cache.window_length = sink_cache.window_length
        cache.num_sink_tokens = sink_cache.num_sink_tokens
        cache._seen_tokens = sink_cache._seen_tokens
        cache._cos_cache = sink_cache._cos_cache
        cache._sin_cache = sink_cache._sin_cache
        cache.cos_sin_rerotation_cache = sink_cache.cos_sin_rerotation_cache
        for layer_idx in range(len(sink_cache.key_cache)):
            cache.key_cache[layer_idx] = sink_cache.key_cache[layer_idx]
            cache.value_cache[layer_idx] = sink_cache.value_cache[layer_idx]
        return cache
class IndexedSlidingWindowCache(IndexedCache):
    """
    Similar to the `SlidingWindowCache` class, but with the ability to index the cache by layer index. It is no longer
    a subclass of `StaticCache` as it is dynamic.

    After each stored update, a layer keeps only its last `sliding_window` positions. The states returned by
    `update` are the full concatenation, before trimming.
    """

    build_position_ids_based_on_cache = False

    def __init__(self, sliding_window: Optional[int] = None) -> None:
        super().__init__()
        self.sliding_window = sliding_window

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Append `key_states`/`value_states` for `layer_idx` and trim the stored states to the sliding window.

        `cache_kwargs` is accepted for interface compatibility and is not used.

        Return:
            The previous states concatenated with the new ones along dim -2 (untrimmed).
        """
        # Update the number of seen tokens
        if self.is_min_layer(layer_idx):
            self._seen_tokens += key_states.shape[-2]

        # [bsz, num_heads, seq_len, head_dim]
        if layer_idx not in self.key_cache:
            # Empty cache
            new_key_states = key_states
            new_value_states = value_states
        else:
            # Growing cache
            new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        if self._update:
            self.key_cache[layer_idx] = new_key_states
            self.value_cache[layer_idx] = new_value_states

        # If the cache is full, we need to shift the cache
        if self.get_seq_length(layer_idx) > self.sliding_window:
            self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :, -self.sliding_window :]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :, -self.sliding_window :]

        return new_key_states, new_value_states

    def get_max_length(self) -> Optional[int]:
        """Returns the sliding window size, the maximum number of positions kept per layer."""
        return self.sliding_window

    @classmethod
    def from_cache(cls, sliding_window_cache: "IndexedSlidingWindowCache", *args, **kwargs) -> "IndexedSlidingWindowCache":
        """
        This is to override the `from_cache` method in the `IndexedCache` class.

        Copies the seen-token counter, the window size and the per-layer tensors (by reference).
        """
        cache = cls(*args, **kwargs)
        cache._seen_tokens = sliding_window_cache._seen_tokens
        cache.sliding_window = sliding_window_cache.sliding_window
        key_cache = sliding_window_cache.key_cache
        # For the `Indexed*` caches, `key_cache` is a dict keyed by layer index, and the indices need not be
        # contiguous: layers that compute no KVs have no entry. So iterate over the real keys instead of
        # `range(len(...))`. A list-based cache is still accepted and indexed positionally.
        layer_indices = list(key_cache.keys()) if isinstance(key_cache, dict) else range(len(key_cache))
        for layer_idx in layer_indices:
            cache.key_cache[layer_idx] = key_cache[layer_idx]
            cache.value_cache[layer_idx] = sliding_window_cache.value_cache[layer_idx]
        return cache
class IndexedHybridCache(IndexedSlidingWindowCache, IndexedCache):
    """
    Hybrid Cache class to be used for models that alternate between a local sliding window attention and global
    attention in every other layer. Under the hood, Hybrid Cache leverages ["IndexedSlidingWindowCache"] for
    sliding window attention and ["IndexedCache"] for global attention.

    Which behaviour a layer gets is decided by `parser[layer_idx].use_sliding_window`.
    """

    build_position_ids_based_on_cache = False

    def __init__(self, parser: LayerTypeParser = None, sliding_window: int = None) -> None:
        super().__init__(sliding_window=sliding_window)
        self.parser = parser

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Dispatch to the sliding-window or the global (unbounded) update, depending on the layer type."""
        if self.parser[layer_idx].use_sliding_window:
            return IndexedSlidingWindowCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
        else:
            return IndexedCache.update(self, key_states, value_states, layer_idx, cache_kwargs)

    def get_max_length(self) -> Optional[int]:
        """Global-attention layers are unbounded, so (like `IndexedCache`) there is no maximum length."""
        return IndexedCache.get_max_length(self)

    @classmethod
    def from_cache(cls, hybrid_cache: "IndexedHybridCache", *args, **kwargs) -> "IndexedHybridCache":
        """
        This is to override the `from_cache` method in the `IndexedSlidingWindowCache` class.

        Copies the seen-token counter, window size, layer-type parser and per-layer tensors (by reference).
        """
        cache = cls(*args, **kwargs)
        cache._seen_tokens = hybrid_cache._seen_tokens
        cache.sliding_window = hybrid_cache.sliding_window
        cache.parser = hybrid_cache.parser
        key_cache = hybrid_cache.key_cache
        # `key_cache` is a dict keyed by (possibly non-contiguous) layer index for the `Indexed*` caches, so
        # iterate over its real keys. A list-based cache is still accepted and indexed positionally.
        layer_indices = list(key_cache.keys()) if isinstance(key_cache, dict) else range(len(key_cache))
        for layer_idx in layer_indices:
            cache.key_cache[layer_idx] = key_cache[layer_idx]
            cache.value_cache[layer_idx] = hybrid_cache.value_cache[layer_idx]
        return cache
class LayerCache(torch.nn.Module):
    """
    A cache for storing the key-value pairs for layers.

    Key and value tensors are kept in two dicts keyed by layer index. They can be overwritten (`layer_set`),
    extended along dim 2 (`layer_append`), or read back with the placeholder optionally prepended (`layer_get`).
    """

    def __init__(self) -> None:
        """
        The placeholder is used to expand the key-value pairs if the layer attends to the top layers.
        Size: (batch_size, num_key_value_heads, 1, head_dim)
        """
        super().__init__()
        self.key_layer_cache: Dict[int, torch.Tensor] = {}
        self.value_layer_cache: Dict[int, torch.Tensor] = {}
        self.layer_type = None
        self.placeholder: Optional[torch.Tensor] = None

    def setup(self, placeholder: torch.Tensor):
        """setup the cache, calling this function is necessary if there is a layer that attends to the top layers"""
        self.placeholder = placeholder

    def initialize(self, parser: "LayerTypeParser", sequence_length: int):
        """
        Initialize the cache.

        For every layer index named by `parser[i].attends_to` where `parser[i].attends_top` is set, store zero
        key/value states of shape `(b, h, sequence_length, d)`. Here `b`, `h` and `d` are taken from the
        placeholder.

        Raises:
            RuntimeError: if such layers exist but `setup` has not been called to provide the placeholder.
        """
        layers_to_init = {parser[idx].attends_to for idx in range(len(parser)) if parser[idx].attends_top}
        if layers_to_init:
            if self.placeholder is None:
                raise RuntimeError(
                    f"{self.__class__.__name__}.setup(placeholder) must be called before initialize() "
                    "when some layer attends to the top layers."
                )
            b, h, _, d = self.placeholder.size()
            # One zero tensor is shared by keys and values of every such layer. That is safe because
            # `layer_set`/`layer_append` replace dict entries (with new tensors) instead of writing into them.
            init_kvs = self.placeholder.new_zeros((b, h, sequence_length, d))
            for layer_idx in layers_to_init:
                self.layer_append(layer_idx, init_kvs, init_kvs)

    def layer_get(
        self, layer_idx: int, zerofill: bool = False
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Return the cached `(key, value)` states of `layer_idx`, or `(None, None)` if nothing is cached.

        With `zerofill=True`, the placeholder is prepended along dim 2, or returned on its own for both key and
        value when nothing is cached.
        """
        key_states = self.key_layer_cache.get(layer_idx, None)
        value_states = self.value_layer_cache.get(layer_idx, None)

        if zerofill:
            if key_states is None:
                key_states = self.placeholder
                value_states = self.placeholder
            else:
                key_states = torch.cat([self.placeholder, key_states], dim=2)
                value_states = torch.cat([self.placeholder, value_states], dim=2)

        return key_states, value_states

    def layer_set(self, layer_idx: int, key: torch.Tensor, value: torch.Tensor):
        """Overwrite the cached key/value states of `layer_idx`."""
        self.key_layer_cache[layer_idx] = key
        self.value_layer_cache[layer_idx] = value

    def layer_append(self, layer_idx: int, key: torch.Tensor, value: torch.Tensor):
        """Append key/value states for `layer_idx` along dim 2, creating the entry if it does not exist."""
        if layer_idx not in self.key_layer_cache:
            self.key_layer_cache[layer_idx] = key
            self.value_layer_cache[layer_idx] = value
        else:
            self.key_layer_cache[layer_idx] = torch.cat([self.key_layer_cache[layer_idx], key], dim=2)
            self.value_layer_cache[layer_idx] = torch.cat([self.value_layer_cache[layer_idx], value], dim=2)
class LayerIndexedCache(LayerCache, IndexedCache):
    """
    A cache for storing the key-value pairs for layers, in combination with the ability of standard KV cache.

    The layer-level storage (`layer_get`, `layer_set`, `layer_append`, ...) comes from `LayerCache`. The per-layer
    KV cache (`update`, `get_seq_length`, `from_cache`, ...) comes from `IndexedCache`.
    """

    def __init__(self) -> None:
        # Run both base initializers explicitly, so the attributes of `LayerCache` and `IndexedCache` are set
        # regardless of how `LayerCache.__init__`'s `super().__init__()` resolves under this MRO.
        # NOTE(review): the order looks deliberate — the second call re-initializes `IndexedCache`'s KV dicts and
        # counters, which is harmless on a freshly constructed object.
        LayerCache.__init__(self)
        IndexedCache.__init__(self)
class LayerIndexedSinkCache(LayerCache, IndexedSinkCache):
    """
    A cache for storing the key-value pairs for layers, in combination with the ability of sink KV cache.

    Constructed without arguments, so `window_length` and `num_sink_tokens` start as `None`. They are filled in
    when the instance is built through `from_cache` (inherited from `IndexedSinkCache`).
    """

    def __init__(self) -> None:
        # Run both base initializers explicitly, so the attributes of `LayerCache` and `IndexedSinkCache` are set
        # regardless of how `LayerCache.__init__`'s `super().__init__()` resolves under this MRO.
        LayerCache.__init__(self)
        IndexedSinkCache.__init__(self)
class LayerIndexedSlidingWindowCache(LayerCache, IndexedSlidingWindowCache):
    """
    A cache for storing the key-value pairs for layers, in combination with the ability of sliding window KV cache.

    Constructed without arguments, so `sliding_window` starts as `None`. It is filled in when the instance is built
    through `from_cache` (inherited from `IndexedSlidingWindowCache`).
    """

    def __init__(self) -> None:
        # Run both base initializers explicitly, so the attributes of `LayerCache` and
        # `IndexedSlidingWindowCache` are set regardless of how `super()` chains resolve under this MRO.
        LayerCache.__init__(self)
        IndexedSlidingWindowCache.__init__(self)
class LayerIndexedHybridCache(LayerCache, IndexedHybridCache):
    """
    A cache for storing the key-value pairs for layers, in combination with the ability of hybrid KV cache.

    Constructed without arguments, so `parser` and `sliding_window` start as `None`. They are filled in when the
    instance is built through `from_cache` (inherited from `IndexedHybridCache`); `update` needs `parser` to be set.
    """

    def __init__(self) -> None:
        # Run both base initializers explicitly, so the attributes of `LayerCache` and `IndexedHybridCache` are
        # set regardless of how `super()` chains resolve under this MRO.
        LayerCache.__init__(self)
        IndexedHybridCache.__init__(self)
class AutoLayerCache(torch.nn.Module):
    """
    Factory that turns an existing cache into its layer-aware counterpart.

    This class is never instantiated; call `AutoLayerCache.from_cache(cache)` instead.
    """

    # Exact type of a supported source cache -> layer-aware cache class to build from it.
    # Lookup is by exact type, so subclasses of these caches are rejected rather than silently converted.
    CACHE_MAPPING = {
        DynamicCache: LayerIndexedCache,
        SinkCache: LayerIndexedSinkCache,
        IndexedSlidingWindowCache: LayerIndexedSlidingWindowCache,
        IndexedHybridCache: LayerIndexedHybridCache,
    }

    def __init__(self, *args, **kwargs):
        name = self.__class__.__name__
        raise RuntimeError(
            f"{name} is designed to be instantiated "
            f"using the `{name}.from_cache(cache)` method."
        )

    @classmethod
    def from_cache(cls, cache: Cache, *args, **kwargs):
        """
        Build a layer-aware cache equivalent to `cache`.

        Extra positional/keyword arguments are forwarded to the target class.

        Raises:
            ValueError: if the exact type of `cache` has no entry in `CACHE_MAPPING`.
        """
        source_type = type(cache)
        target_class = cls.CACHE_MAPPING.get(source_type)
        if target_class is None:
            raise ValueError(f"Cache type {source_type} is not supported by {cls.__name__}.")

        if not hasattr(target_class, "from_cache"):
            # No dedicated converter: start from an empty instance and copy the source's instance attributes.
            converted = target_class(*args, **kwargs)
            converted.__dict__.update(cache.__dict__)
            return converted

        return target_class.from_cache(cache, *args, **kwargs)