from copy import deepcopy
from typing import Optional, Dict, Any, Tuple

import torch
from transformers.cache_utils import Cache

from .configuration_decilm import DeciLMConfig
from .transformers_4_44_2__cache_utils import Cache as Cache_4_44_2, SinkCache, StaticCache, SlidingWindowCache
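
# Note: the cache utilities above are vendored from transformers 4.44.2
# (Cache_4_44_2, SinkCache, StaticCache, SlidingWindowCache), presumably to pin
# their behavior regardless of the installed transformers version, while
# VariableCache also subclasses the installed library's Cache so that
# isinstance checks in generation code keep working.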

class VariableCache(Cache_4_44_2, Cache):
    """
    A Cache object that supports a different cache implementation for every layer,
    including layers without any kv-cache.
    Implemented as a list of Cache objects, each of which represents a "model" with a single layer.
    The default per-layer cache implementation is StaticCache.
    Each layer's cache is allocated on the same device as the layer itself.
    """
    def __init__(
            self,
            *,
            config: DeciLMConfig,
            batch_size: Optional[int] = None,
            max_cache_len: Optional[int] = None,
            dtype: torch.dtype = torch.float32,
            max_batch_size: Optional[int] = None,
            **kwargs,
    ) -> None:
        Cache_4_44_2.__init__(self)

        self.config = deepcopy(config)
        # Accept both argument names; "batch_size" is the older spelling of "max_batch_size".
        self.max_batch_size = batch_size or max_batch_size
        self.batch_size = self.max_batch_size  # alias kept for code that still reads .batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        self.dtype = dtype

        # One sub-cache per layer, created lazily on the first update() so that
        # each cache lands on the same device as its layer.
        self.layer_caches: list[Cache_4_44_2 | None] = [None] * config.num_hidden_layers
        self.layer_devices: list[torch.device | None] = [None] * config.num_hidden_layers

    def update(
            self,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            layer_idx: int,
            cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.layer_caches[layer_idx] is None:
            self.layer_devices[layer_idx] = key_states.device
            self._init_layer_cache(layer_idx)

        layer_cache = self.layer_caches[layer_idx]
        assert layer_cache is not None, f"Trying to update the cache of a cache-less layer: {layer_idx=}"

        # Each sub-cache is a single-layer "model", hence layer_idx=0.
        k_out, v_out = layer_cache.update(key_states=key_states,
                                          value_states=value_states,
                                          layer_idx=0,
                                          cache_kwargs=cache_kwargs)
        # Static and sliding-window caches return buffers padded to their full
        # length; trim to the sequence seen so far.
        seq_len = self.get_seq_length(layer_idx)
        k_out = k_out[:, :, :seq_len, :]
        v_out = v_out[:, :, :seq_len, :]
        return k_out, v_out
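
    # Summary of the per-layer cache selection implemented below (derived from
    # the branches in _init_layer_cache):
    #   attention is no-op or replaced with linear   -> no cache (None)
    #   window_length set, not a sink                -> SlidingWindowCache
    #   window_length set, sink, shifted window      -> SinkCache
    #   anything else (including unshifted sinks)    -> StaticCache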
    def _init_layer_cache(self, layer_idx: int) -> None:
        block_config = self.config.block_configs[layer_idx]
        attention_config = block_config.attention

        # Layers without attention (no-op, or attention replaced with a linear
        # layer) hold no kv-cache at all.
        if attention_config.no_op or attention_config.replace_with_linear:
            return

        device = self.layer_devices[layer_idx]
        assert device is not None, f"Trying to init layer cache for {layer_idx=} without device"

        # Build a single-layer config so the sub-cache behaves like a 1-layer model.
        config = deepcopy(self.config)
        config.num_hidden_layers = 1
        # GQA: n_heads_in_group query heads share one kv head
        # (e.g. 32 attention heads // 4 heads per group = 8 kv heads).
        config.num_key_value_heads = self.config.num_attention_heads // attention_config.n_heads_in_group

        if attention_config.window_length is not None:
            if not attention_config.is_sink:
                # Plain sliding-window attention.
                config.sliding_window = attention_config.window_length
                self.layer_caches[layer_idx] = SlidingWindowCache(config=config,
                                                                  max_batch_size=self.max_batch_size,
                                                                  max_cache_len=self.max_cache_len,
                                                                  device=device,
                                                                  dtype=self.dtype)
                return
            elif not attention_config.unshifted_sink:
                # Attention sinks with a shifted window.
                self.layer_caches[layer_idx] = SinkCache(window_length=attention_config.window_length,
                                                         num_sink_tokens=attention_config.num_sink_tokens)
                return

        # Default (including unshifted-sink variants): a full-length static cache.
        self.layer_caches[layer_idx] = StaticCache(config=config,
                                                   max_batch_size=self.max_batch_size,
                                                   max_cache_len=self.max_cache_len,
                                                   device=device,
                                                   dtype=self.dtype)

    def _get_first_real_cache(self) -> Cache:
        for layer_cache in self.layer_caches:
            if layer_cache is not None:
                return layer_cache
        raise ValueError("No real cache found, all layer caches are None.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        # Layer 0 may be cache-less; fall back to the first layer that actually
        # holds a cache.
        if layer_idx == 0 and self.layer_caches[0] is None:
            try:
                layer_cache = self._get_first_real_cache()
            except ValueError:
                return 0
        else:
            layer_cache = self.layer_caches[layer_idx]
        return layer_cache.get_seq_length()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reset(self):
        for layer_idx, layer_cache in enumerate(self.layer_caches):
            if layer_cache is None:
                # Never initialized (or a cache-less layer): nothing to reset.
                continue
            if hasattr(layer_cache, "reset"):
                layer_cache.reset()
            else:
                # Caches without a reset() method (e.g. SinkCache) are rebuilt
                # from scratch instead.
                self._init_layer_cache(layer_idx)
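
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how a VariableCache is typically created
# and threaded through decoding. The model class name and checkpoint path are
# assumptions for the example, not part of this module.
#
#   from .modeling_decilm import DeciLMForCausalLM  # assumed module layout
#
#   model = DeciLMForCausalLM.from_pretrained("path/to/checkpoint")  # hypothetical path
#   cache = VariableCache(config=model.config,
#                         max_batch_size=1,
#                         max_cache_len=2048,
#                         dtype=model.dtype)
#   out = model(input_ids, past_key_values=cache, use_cache=True)
#   # Subsequent decoding steps reuse the same cache object:
#   out = model(out.logits[:, -1:].argmax(-1), past_key_values=cache, use_cache=True)
# ---------------------------------------------------------------------------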