"""
Explicit KV Cache management for efficient inference.
This is critical for Qualcomm deployment and agent control loops.
"""

import torch
from typing import Optional, Tuple
from dataclasses import dataclass
| |
|
| |
|
@dataclass
class KVCache:
    """Key-Value cache for transformer inference.

    Layout: [num_layers, batch_size, num_heads, max_seq_len, head_dim]

    This explicit cache enables:
    - Efficient autoregressive decoding
    - Cache offloading for memory management
    - Sliding window attention (future)
    - Agent control loops with cache manipulation
    """

    key_cache: torch.Tensor    # [num_layers, batch, num_heads, max_seq_len, head_dim]
    value_cache: torch.Tensor  # same layout as key_cache
    seq_len: int               # highest position written so far (exclusive end)

    @classmethod
    def create(
        cls,
        num_layers: int,
        batch_size: int,
        num_heads: int,
        max_seq_len: int,
        head_dim: int,
        dtype: torch.dtype = torch.float16,
        device: Optional[torch.device] = None,
    ) -> "KVCache":
        """Create an empty KV cache.

        Args:
            num_layers: Number of transformer layers
            batch_size: Batch size
            num_heads: Number of attention heads
            max_seq_len: Maximum sequence length
            head_dim: Dimension per attention head
            dtype: Data type for cache tensors
            device: Device to create cache on (None lets torch pick the default)

        Returns:
            Initialized KVCache with zero tensors and seq_len == 0
        """
        shape = (num_layers, batch_size, num_heads, max_seq_len, head_dim)

        key_cache = torch.zeros(shape, dtype=dtype, device=device)
        value_cache = torch.zeros(shape, dtype=dtype, device=device)

        return cls(key_cache=key_cache, value_cache=value_cache, seq_len=0)

    def update(
        self,
        layer_idx: int,
        key: torch.Tensor,
        value: torch.Tensor,
        position: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update cache for a specific layer and return full K, V.

        Args:
            layer_idx: Index of the transformer layer
            key: New key tensor [batch, heads, seq_len, head_dim]
            value: New value tensor [batch, heads, seq_len, head_dim]
            position: Starting position for the new tokens

        Returns:
            Tuple of (full_key, full_value) including cached values,
            each sliced to [batch, heads, position + seq_len, head_dim].

        Raises:
            ValueError: If the new tokens would overflow max_seq_len.
        """
        seq_len = key.shape[2]
        end_pos = position + seq_len

        # Fail loudly on overflow instead of surfacing an opaque
        # shape-mismatch error from the slice assignment below.
        max_seq_len = self.key_cache.shape[3]
        if end_pos > max_seq_len:
            raise ValueError(
                f"KV cache overflow: positions {position}:{end_pos} exceed "
                f"max_seq_len={max_seq_len}"
            )

        self.key_cache[layer_idx, :, :, position:end_pos, :] = key
        self.value_cache[layer_idx, :, :, position:end_pos, :] = value

        # max() so an out-of-order rewrite of earlier positions never
        # shrinks the logical sequence length.
        self.seq_len = max(self.seq_len, end_pos)

        return (
            self.key_cache[layer_idx, :, :, :end_pos, :],
            self.value_cache[layer_idx, :, :, :end_pos, :],
        )

    def get(
        self,
        layer_idx: int,
        end_pos: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get cached K, V for a specific layer.

        Args:
            layer_idx: Index of the transformer layer
            end_pos: End position (defaults to current seq_len)

        Returns:
            Tuple of (key, value) tensors sliced to end_pos
        """
        if end_pos is None:
            end_pos = self.seq_len

        return (
            self.key_cache[layer_idx, :, :, :end_pos, :],
            self.value_cache[layer_idx, :, :, :end_pos, :],
        )

    def reset(self):
        """Reset the cache to empty state (zeros tensors in place)."""
        self.key_cache.zero_()
        self.value_cache.zero_()
        self.seq_len = 0

    @property
    def memory_usage_mb(self) -> float:
        """Calculate memory usage of both cache tensors in megabytes."""
        total_bytes = self.key_cache.numel() * self.key_cache.element_size()
        total_bytes += self.value_cache.numel() * self.value_cache.element_size()
        return total_bytes / (1024 * 1024)
| |
|