|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from copy import deepcopy |
|
from typing import Optional, Dict, Any, Tuple |
|
|
|
import torch |
|
from transformers.cache_utils import Cache |
|
|
|
from .configuration_decilm import DeciLMConfig, AttentionConfig |
|
from .transformers_4_44_2__cache_utils import Cache as Cache_4_44_2, StaticCache |
|
|
|
|
|
class VariableCache(Cache_4_44_2, Cache):
    """
    A Cache object that supports a different Cache implementation for every layer,
    including layers without any kv-cache.

    Implemented using a list of Cache objects, each represents a "model" with 1 layer.
    The default implementation for the layer caches is StaticCache.

    Layer caches are created lazily: the cache of a layer is allocated on the first
    ``update`` call for that layer, on the device of the key states passed to it.
    Layers whose attention config is ``no_op`` or ``replace_with_linear`` never get
    a cache; their slot in ``layer_caches`` stays ``None``.
    """

    def __init__(
        self,
        *,
        config: DeciLMConfig,
        batch_size: Optional[int] = None,
        max_cache_len: Optional[int] = None,
        dtype: torch.dtype = torch.float32,
        max_batch_size: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """
        Store the cache parameters; no per-layer cache is allocated here.

        Args:
            config: Model configuration. A deep copy is kept on the instance.
            batch_size: Batch size of the cache. Takes precedence over
                ``max_batch_size`` when truthy.
            max_cache_len: Maximum cache length. Defaults to
                ``config.max_position_embeddings`` when ``None``.
            dtype: dtype forwarded to every per-layer cache.
            max_batch_size: Fallback used when ``batch_size`` is falsy.
            **kwargs: Accepted and ignored.
        """
        Cache_4_44_2.__init__(self)

        self.config = deepcopy(config)
        self.max_batch_size = batch_size or max_batch_size
        self.batch_size = self.max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        self.dtype = dtype

        # One slot per hidden layer; filled lazily by `update`.
        self.layer_caches: list[Cache | None] = [None] * config.num_hidden_layers

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update the cache of layer ``layer_idx`` and return its key/value states.

        The layer cache is created on first use, on ``key_states.device``. The
        returned tensors are sliced along dim 2 to the current sequence length
        reported by ``get_seq_length(layer_idx)``.

        Raises:
            AssertionError: If the layer is configured without a kv-cache.
        """
        layer_cache = self.layer_caches[layer_idx]

        if layer_cache is None:
            block_config = self.config.block_configs[layer_idx]
            layer_cache = self._init_layer_cache(attention_config=block_config.attention, device=key_states.device)
            if layer_cache is None:
                # Raised explicitly (not via `assert`) so the check survives `python -O`;
                # the exception type and message match the previous assert.
                raise AssertionError("Trying to update the cache of a cache-less layer")
            self.layer_caches[layer_idx] = layer_cache

        # Every per-layer cache is a single-layer cache, hence layer_idx=0 here.
        k_out, v_out = layer_cache.update(key_states=key_states,
                                          value_states=value_states,
                                          layer_idx=0,
                                          cache_kwargs=cache_kwargs)
        seq_len = self.get_seq_length(layer_idx)
        k_out = k_out[:, :, :seq_len, :]
        v_out = v_out[:, :, :seq_len, :]
        return k_out, v_out

    def _init_layer_cache(self,
                          attention_config: AttentionConfig,
                          device: torch.device,
                          ) -> Cache | None:
        """
        Build the cache for one layer, or return ``None`` for a cache-less layer.

        A layer is cache-less when its attention config is ``no_op`` or
        ``replace_with_linear``. Otherwise a StaticCache is created from a copy of
        the model config whose ``num_key_value_heads`` is set to
        ``num_attention_heads // n_heads_in_group`` for this layer.
        """
        if attention_config.no_op or attention_config.replace_with_linear:
            return None
        # Work on a copy so the per-layer override does not leak into self.config.
        config = deepcopy(self.config)
        config.num_key_value_heads = self.config.num_attention_heads // attention_config.n_heads_in_group
        return StaticCache(config, self.max_batch_size, self.max_cache_len, device, self.dtype)

    def _get_first_real_cache(self) -> Cache:
        """
        Return the first layer cache that is not ``None``.

        Raises:
            ValueError: If every entry of ``layer_caches`` is ``None``.
        """
        for layer_cache in self.layer_caches:
            if layer_cache is not None:
                return layer_cache
        raise ValueError("No real cache found, all layer caches are None.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Return the sequence length reported by the cache of ``layer_idx``.

        ``None`` is treated like the default layer 0. When layer 0 has no cache,
        the first existing layer cache is queried instead, and 0 is returned if no
        layer cache exists at all. Any other layer without a cache (cache-less or
        not yet created) reports 0 instead of failing on ``None``.
        """
        if layer_idx is None:
            layer_idx = 0

        layer_cache = self.layer_caches[layer_idx]
        if layer_cache is None:
            if layer_idx != 0:
                # Nothing has been cached for this layer.
                return 0
            try:
                layer_cache = self._get_first_real_cache()
            except ValueError:
                return 0
        return layer_cache.get_seq_length()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reset(self):
        """Call ``reset()`` on every layer cache that provides it; ``None`` slots are skipped."""
        for layer_cache in self.layer_caches:
            # `None` entries have no `reset` attribute, so they are skipped here too.
            if hasattr(layer_cache, "reset"):
                layer_cache.reset()
|
|