|
from dataclasses import dataclass |
|
from typing import Any, Dict, List, Optional, Tuple |
|
|
|
import torch |
|
|
|
from .configuration_utils import PretrainedConfig |
|
from .utils import logging |
|
|
|
|
|
# Module-level logger; used by `Cache.seen_tokens` to emit a one-time deprecation warning.
logger = logging.get_logger(__name__)
|
|
|
|
|
@dataclass
class Cache:
    """
    Abstract base class shared by every key/value cache implementation.

    Subclasses decide how the per-layer key and value states are stored and are expected to override `update`,
    `get_seq_length` and `get_max_length` (the base versions only raise `NotImplementedError`).
    """

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Store `key_states` and `value_states` for layer `layer_idx` and return that layer's updated states.

        Parameters:
            key_states (`torch.Tensor`):
                Key states to add to the cache.
            value_states (`torch.Tensor`):
                Value states to add to the cache.
            layer_idx (`int`):
                Index of the layer whose cache is updated.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Subclass-specific extra arguments, so new cache types can be added without changing this
                signature.

        Return:
            A tuple with the updated key and value states.

        Raises:
            NotImplementedError: always in the base class; subclasses must override this method.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return how many positions are cached, optionally for a given layer. Must be overridden."""
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    def get_max_length(self) -> Optional[int]:
        """Return the cache's maximum sequence length, or `None` when it is unbounded. Must be overridden."""
        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """
        Return how many already-cached positions can be used together with `new_seq_length` incoming tokens.

        If the cache is bounded and cached + incoming tokens would overflow the bound, only
        `max_length - new_seq_length` cached positions fit; otherwise every cached position is usable.
        """
        limit = self.get_max_length()
        cached = self.get_seq_length(layer_idx)
        # Unbounded cache, or everything still fits: all cached positions are usable.
        if limit is None or cached + new_seq_length <= limit:
            return cached
        return limit - new_seq_length

    @property
    def seen_tokens(self):
        """Deprecated accessor for `_seen_tokens`; yields `None` when the subclass does not track it."""
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        return getattr(self, "_seen_tokens", None)
|
|
|
|
|
class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.

    Two storage layouts are used, chosen on each `update` call:

    * Plain mode (`key_states_fp` / `value_states_fp` not passed): `key_cache[layer_idx]` and
      `value_cache[layer_idx]` hold the layer's states, expected shape `[batch_size, num_heads, seq_len, head_dim]`,
      grown by concatenation along `dim=2`.
    * Two-tier mode (`key_states_fp` and `value_states_fp` passed): each layer owns two list entries. Entry
      `2 * layer_idx` accumulates `key_states` / `value_states` by concatenation along the LAST dimension; entry
      `2 * layer_idx + 1` keeps only the trailing `window_length` positions (along `dim=2`) of the `*_fp` states.
      NOTE(review): from the naming, the first entry presumably holds quantized/packed states and the second a
      full-precision residual window — confirm against the calling attention code.

    NOTE(review): `__len__`, `__iter__`, `__getitem__`, `get_seq_length` and `to_legacy_cache` index the lists
    directly and do not translate `layer_idx` into `2 * layer_idx`, so in two-tier mode they see twice as many
    entries as there are layers.
    """

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        # Running count of tokens passed to layer 0 (see `update`).
        self._seen_tokens = 0

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.

        Raises:
            KeyError: if `layer_idx` is not smaller than the number of stored entries.
        """
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values
        """
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. Returns the number of
        stored key entries: one per layer in plain mode, two per layer in two-tier mode.
        """
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        key_states_fp: Optional[torch.Tensor] = None,
        value_states_fp: Optional[torch.Tensor] = None,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Updates the cache with the new states for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            key_states_fp (`torch.Tensor`, `optional`):
                Additional key states for two-tier mode; only the trailing `window_length` positions (along
                `dim=2`) are retained in the cache.
            value_states_fp (`torch.Tensor`, `optional`):
                Additional value states for two-tier mode, windowed like `key_states_fp`.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments. `window_length` (default 32) sets the size of the retained window in
                two-tier mode.

        Return:
            Plain mode (either `*_fp` argument is `None`): `(keys, values)` for the layer.
            Two-tier mode: `(keys, keys_fp, values, values_fp)`, where `keys`/`values` are the accumulated
            first-tier entries and `keys_fp`/`values_fp` are the previous window concatenated with the new `*_fp`
            states before truncation (or just the new `*_fp` states on the layer's first call).
        """
        # Count tokens only at layer 0 so each step is counted once rather than once per layer.
        # NOTE(review): in two-tier mode this reads `key_states.shape[2]`, which may not be the token count if the
        # first-tier states are packed along another axis — confirm.
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[2]

        # Size of the retained `*_fp` window in two-tier mode.
        if isinstance(cache_kwargs, dict):
            window_length = cache_kwargs.get("window_length", 32)
        else:
            window_length = 32

        if key_states_fp is None or value_states_fp is None:
            # Plain mode: one key/value entry per layer, grown along dim=2.
            if len(self.key_cache) <= layer_idx:
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)

            return self.key_cache[layer_idx], self.value_cache[layer_idx]

        # Two-tier mode: this layer owns list slots `2 * layer_idx` (first tier) and `2 * layer_idx + 1` (window).
        layer_idx *= 2

        if len(self.key_cache) <= layer_idx:
            # First call for this layer: appends assume layers are visited in increasing order.
            self.key_cache.append(key_states)
            self.key_cache.append(key_states_fp[:, :, -window_length:, :])
            self.value_cache.append(value_states)
            self.value_cache.append(value_states_fp[:, :, -window_length:, :])
        else:
            # First tier grows along the last dimension.
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-1)
            # Window: append the new states, then keep only the trailing `window_length` positions.
            key_states_fp = torch.cat([self.key_cache[layer_idx + 1], key_states_fp], dim=2)
            self.key_cache[layer_idx + 1] = key_states_fp[:, :, -window_length:, :]

            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-1)
            value_states_fp = torch.cat([self.value_cache[layer_idx + 1], value_states_fp], dim=2)
            self.value_cache[layer_idx + 1] = value_states_fp[:, :, -window_length:, :]

        return self.key_cache[layer_idx], key_states_fp, self.value_cache[layer_idx], value_states_fp

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the size of the last dimension of `key_cache[layer_idx]`, or 0 if that entry does not exist yet."""
        # NOTE(review): in plain mode entries are `[batch, heads, seq, head_dim]`, so `shape[-1]` is head_dim rather
        # than the sequence length (upstream uses `shape[-2]`). In two-tier mode the first-tier entry grows along the
        # last dim, but `layer_idx` is not doubled here. Confirm which layout callers rely on.
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-1]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
        return None

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders every stored entry along the batch dimension (dim 0) for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """Converts the `DynamicCache` instance into a tuple of `(key, value)` pairs, one per stored entry."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
        return legacy_cache

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None) -> "DynamicCache":
        """Converts a legacy tuple of `(key, value)` pairs into a `DynamicCache` populated in plain mode."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache
|
|
|
|
|
class SinkCache(Cache):
    """
    A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        # Rerotation (cos, sin) pairs keyed by the number of incoming tokens; see `_get_rerotation_cos_sin`.
        self.cos_sin_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
        self._seen_tokens = 0

    @staticmethod
    def _rotate_half(x):
        """Split the last dimension in half and return `cat(-second_half, first_half)`."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        """Rotate `key_states` with the given cos/sin using the rotate-half formulation."""
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return the cos/sin needed to shift kept (non-sink) keys back by `key_states.shape[-2]` positions.

        With `a` the original angle and `b` the shifted angle, the result is `cos(b - a)` and `sin(b - a)`
        (angle-difference identities). Results are cached per number of incoming tokens.
        """
        num_new = key_states.shape[-2]
        if num_new not in self.cos_sin_cache:
            # Compute in float32 for accuracy, then cast back to the keys' dtype.
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            original_cos = cos[self.num_sink_tokens + num_new :]
            shifted_cos = cos[self.num_sink_tokens : -num_new]
            original_sin = sin[self.num_sink_tokens + num_new :]
            shifted_sin = sin[self.num_sink_tokens : -num_new]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_cache[num_new] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_cache[num_new]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length (`shape[-2]`) of the cached keys for `layer_idx`, or 0 if none are cached."""
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, i.e. `window_length`."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted. When omitted (or `None`), no rerotation is applied.

        Return:
            A tuple containing the updated key and value states.
        """
        # `cache_kwargs` defaults to None; treat that as "no RoPE information" instead of crashing on `.get`.
        if cache_kwargs is None:
            cache_kwargs = {}
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Count tokens only at layer 0 so each step is counted once rather than once per layer.
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if len(self.key_cache) <= layer_idx:
            # First call for this layer.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Still fits in the window: plain append along the sequence axis.
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Window full: keep the sink tokens plus the most recent tokens that leave room for the new ones.
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

            # Kept keys move to earlier positions, so their RoPE rotation must be recomputed.
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, cos[: self.window_length], sin[: self.window_length]
                )
                if partial_rotation_size is not None:
                    # Only the first `partial_rotation_size` features are rotated; the rest pass through.
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders every layer along the batch dimension (dim 0) for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
|
|
|
|
|
class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)`.

    Key and value buffers are preallocated once with shape
    `(max_batch_size, num_key_value_heads, max_cache_len, head_dim)` and written in place by `update`.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
            required to initialize the static cache. `head_dim` and `num_key_value_heads` are used when present.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used. If `None`,
            `config.max_position_embeddings` is used.
        device (`torch.device`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
    """

    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len

        # Prefer an explicit `config.head_dim` when the config defines one.
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype if dtype is not None else torch.float32
        # Configs without (or with a None) `num_key_value_heads` fall back to one KV head per attention head.
        num_kv_heads = getattr(config, "num_key_value_heads", None)
        self.num_key_value_heads = config.num_attention_heads if num_kv_heads is None else num_kv_heads

        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
        self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for. Kept for backward compatibility; unused here.
            cache_kwargs (`Dict[str, Any]`):
                Additional arguments for the cache subclass. The `StaticCache` needs `cache_position`, the sequence
                positions (along dim 2 of the buffers) that the new states overwrite.

        Return:
            The full key and value buffers (the same tensor objects as `self.key_cache` / `self.value_cache`).

        Raises:
            ValueError: if `cache_kwargs` is missing or has no `cache_position`.
        """
        new_cache_positions = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
        if new_cache_positions is None:
            # Indexing with None would insert a new axis instead of selecting positions; fail loudly instead.
            raise ValueError(
                "`StaticCache.update` requires `cache_kwargs['cache_position']` to know which cache slots to write."
            )
        k_out = self.key_cache
        v_out = self.value_cache

        # In-place writes into the preallocated buffers.
        k_out[:, :, new_cache_positions] = key_states
        v_out[:, :, new_cache_positions] = value_states

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Returns the number of filled positions, counted on batch 0 / head 0 of the key buffer. `layer_idx` kept for BC.

        A position counts as filled when any of its `head_dim` values is non-zero.
        """
        # NOTE(review): this returns a 0-dim tensor (sum of a bool mask), not a Python int, despite the annotation;
        # a position whose key happens to be all zeros is not counted.
        return (self.key_cache[0, 0].any(dim=-1)).sum()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, i.e. `max_cache_len`."""
        return self.max_cache_len

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache along the batch dimension (dim 0) for beam search, given the selected beam indices."""
        # `index_select` returns new tensors, so the buffers are replaced rather than updated in place.
        device = self.key_cache.device
        self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
        device = self.value_cache.device
        self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))

    def to_legacy_cache(self):
        """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
        return None
|
|