from typing import List, Tuple, Optional, Any, Dict

import torch

from transformers.cache_utils import Cache


class FgateDynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated, storing per-token log forget-gate values alongside
    the usual key/value states. This is the forget-gate-aware counterpart of `DynamicCache`.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each key/value
    tensor is `[batch_size, num_heads, seq_len, head_dim]`; the log forget-gate tensor for each layer has shape
    `[batch_size, num_heads, seq_len]`.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM

    >>> # The model must be one whose attention layers write log forget-gate states into the cache.
    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

    >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

    >>> # Prepare a cache class and pass it to model's forward
    >>> past_key_values = FgateDynamicCache()
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> outputs.past_key_values  # access cache filled with key/values from generation
    FgateDynamicCache()
    ```
    """

    def __init__(self) -> None:
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        # Per-layer log forget-gate values, each of shape [batch_size, num_heads, seq_len].
        self.log_fgate_cache: List[torch.Tensor] = []

        # Per-layer key/value shift states, appended once per layer via `update_shift_cache`.
        self.key_shift_cache: List[torch.Tensor] = []
        self.value_shift_cache: List[torch.Tensor] = []

        self._seen_tokens = 0

    def update_shift_cache(
        self,
        key_shift_state: torch.Tensor,
        value_shift_state: torch.Tensor,
        layer_idx: int,
    ):
        """Appends the key/value shift states for layer `layer_idx`. Each layer's shift state is stored exactly once,
        so layers must be registered in order."""
        assert layer_idx == len(self.key_shift_cache) == len(self.value_shift_cache)
        self.key_shift_cache.append(key_shift_state)
        self.value_shift_cache.append(value_shift_state)

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length. Each item is a `(key, value, log_fgate)` tuple for one layer.
        """
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        the `(key, value, log_fgate)` tuples of each layer.
        """
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx])

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        log_fgate_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states`, `value_states` and `log_fgate_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            log_fgate_states (`torch.Tensor`):
                The new log forget-gate values to cache, of shape `[batch_size, num_heads, seq_len]`.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `FgateDynamicCache`.

        Return:
            A tuple containing the updated key, value and log forget-gate states.
        """
        assert log_fgate_states.ndim == 3, f"log_fgate must be (B, H, T), but got {log_fgate_states.size()}"

        # Update the number of seen tokens only once per forward pass (on the first layer).
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache: append on the first pass for this layer, concatenate afterwards.
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            self.log_fgate_cache.append(log_fgate_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
            # log_fgate has no head_dim axis, so its sequence dimension is the last one.
            self.log_fgate_cache[layer_idx] = torch.cat([self.log_fgate_cache[layer_idx], log_fgate_states], dim=-1)

        return self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. `FgateDynamicCache` does not have a maximum
        length."""
        return None

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
        """Converts the `FgateDynamicCache` instance into its equivalent in the legacy cache format, a tuple of
        `(key, value, log_fgate)` tuples, one per layer. Used for backward compatibility."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx], self.log_fgate_cache[layer_idx]),)
        return legacy_cache

    @classmethod
    def from_legacy_cache(
        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_layers: Optional[int] = None
    ) -> "FgateDynamicCache":
        """Converts a cache in the legacy cache format, a tuple of `(key, value, log_fgate)` tuples per layer (as
        produced by `to_legacy_cache`), into an equivalent `FgateDynamicCache`. Used for backward compatibility.
        `num_layers` is kept for signature compatibility and is not required."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states, log_fgate_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, log_fgate_states, layer_idx)
        return cache

    def crop(self, max_length: int):
        """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
        negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
        if max_length < 0:
            max_length = self.get_seq_length() - abs(max_length)

        if self.get_seq_length() <= max_length:
            return

        self._seen_tokens = max_length
        for idx in range(len(self.key_cache)):
            self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
            self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
            self.log_fgate_cache[idx] = self.log_fgate_cache[idx][..., :max_length]

    def batch_split(self, full_batch_size: int, split_size: int) -> List["FgateDynamicCache"]:
        """Split the current instance into a list of `FgateDynamicCache` by the batch size. This will be used by
        `_split_model_inputs()` in `generation.utils`"""
        out = []
        for i in range(0, full_batch_size, split_size):
            current_split = FgateDynamicCache()
            current_split._seen_tokens = self._seen_tokens
            current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
            current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
            current_split.log_fgate_cache = [tensor[i : i + split_size] for tensor in self.log_fgate_cache]
            out.append(current_split)
        return out

    @classmethod
    def from_batch_splits(cls, splits: List["FgateDynamicCache"]) -> "FgateDynamicCache":
        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
        `generation.utils`"""
        cache = cls()
        for idx in range(len(splits[0])):
            layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0)
            layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0)
            layer_log_fgates = torch.cat([current.log_fgate_cache[idx] for current in splits], dim=0)
            cache.update(layer_keys, layer_values, layer_log_fgates, idx)
        return cache

    def batch_repeat_interleave(self, repeats: int):
        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
            self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
            self.log_fgate_cache[layer_idx] = self.log_fgate_cache[layer_idx].repeat_interleave(repeats, dim=0)

    def batch_select_indices(self, indices: torch.Tensor):
        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
        for layer_idx in range(len(self)):
            self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
            self.log_fgate_cache[layer_idx] = self.log_fgate_cache[layer_idx][indices, ...]
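

# The block below is not part of the original module: it is a minimal, self-contained sketch of how the cache is
# filled and manipulated directly, using random tensors in place of real model states. Shapes follow the class
# docstring: keys/values are [batch_size, num_heads, seq_len, head_dim], log forget gates are
# [batch_size, num_heads, seq_len].
if __name__ == "__main__":
    batch, heads, head_dim, num_layers = 2, 4, 8, 2
    cache = FgateDynamicCache()

    # Prefill with a 5-token prompt, then append one generated token, for every layer.
    for seq_len in (5, 1):
        for layer_idx in range(num_layers):
            keys = torch.randn(batch, heads, seq_len, head_dim)
            values = torch.randn(batch, heads, seq_len, head_dim)
            log_fgate = torch.zeros(batch, heads, seq_len)  # log(1.0) = 0, i.e. "never forget"
            cache.update(keys, values, log_fgate, layer_idx)

    print(len(cache), cache.get_seq_length())  # 2 layers, 6 cached tokens
    print(cache[0][2].shape)  # log forget gates of layer 0: [2, 4, 6]

    legacy = cache.to_legacy_cache()  # tuple of (key, value, log_fgate) triples, one per layer
    print(len(legacy), legacy[0][0].shape)  # 2 layers; keys of layer 0: [2, 4, 6, 8]

    cache.crop(4)  # keep only the first 4 cached tokens
    print(cache.get_seq_length())  # 4

    splits = cache.batch_split(full_batch_size=batch, split_size=1)
    print(len(splits), splits[0].get_seq_length())  # two single-example caches, each still holding 4 tokens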