Spaces:
Runtime error
Runtime error
| import sys | |
| from abc import ABC, abstractmethod | |
| from typing import ( | |
| Optional, | |
| Sequence, | |
| Tuple, | |
| ) | |
| from collections import OrderedDict | |
| import diskcache | |
| import llama_cpp.llama | |
| from .llama_types import * | |
class BaseLlamaCache(ABC):
    """Base cache class for a llama.cpp model.

    Concrete caches map token sequences (stored as ``Tuple[int, ...]`` keys)
    to saved model states and expose a ``cache_size`` that is compared
    against ``capacity_bytes`` to decide when to evict entries.
    """

    def __init__(self, capacity_bytes: int = (2 << 30)):
        """Record the cache's size budget.

        Args:
            capacity_bytes: Budget that ``cache_size`` is compared against.
                Defaults to ``2 << 30`` (2 GiB).
        """
        self.capacity_bytes = capacity_bytes

    # Subclasses read this as an attribute (``self.cache_size > ...``), so it
    # must be a property; a plain method would be compared as a bound method.
    @property
    @abstractmethod
    def cache_size(self) -> int:
        """Current size of the cache, in the same units as ``capacity_bytes``."""
        raise NotImplementedError

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Return the stored key sharing the longest token prefix with ``key``.

        The base implementation stores nothing and therefore always returns
        ``None``; subclasses override it with a real search.
        """
        return None

    @abstractmethod
    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        """Return the cached state matching ``key``; raise ``KeyError`` if none."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: Sequence[int]) -> bool:
        """Return whether a state matching ``key`` is cached."""
        raise NotImplementedError

    @abstractmethod
    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState") -> None:
        """Store ``value`` under ``key``."""
        raise NotImplementedError
class LlamaRAMCache(BaseLlamaCache):
    """Cache for a llama.cpp model using RAM.

    Entries live in an ``OrderedDict`` ordered from least- to most-recently
    used. Lookups match the cached key sharing the longest token prefix with
    the query, not just exact keys. Inserts evict least-recently-used entries
    until the summed ``llama_state_size`` is within ``capacity_bytes``.
    """

    def __init__(self, capacity_bytes: int = (2 << 30)):
        """Create an empty cache.

        Args:
            capacity_bytes: Budget for the summed ``llama_state_size`` of all
                cached states. Defaults to ``2 << 30`` (2 GiB).
        """
        # The base initializer already stores capacity_bytes.
        super().__init__(capacity_bytes)
        self.cache_state: OrderedDict[Tuple[int, ...], "llama_cpp.llama.LlamaState"] = OrderedDict()

    # Property (not a plain method): __setitem__ compares it against an int.
    @property
    def cache_size(self) -> int:
        """Sum of ``llama_state_size`` over all cached states."""
        return sum(state.llama_state_size for state in self.cache_state.values())

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Return the cached key with the longest token prefix in common with ``key``.

        Returns ``None`` when every cached key yields a prefix length of 0.
        Only a strictly longer prefix replaces the current best, so on ties
        the key met first in LRU order (least recently used) wins.
        """
        best_len = 0
        best_key: Optional[Tuple[int, ...]] = None
        for cached_key in self.cache_state:
            prefix_len = llama_cpp.llama.Llama.longest_token_prefix(cached_key, key)
            if prefix_len > best_len:
                best_len = prefix_len
                best_key = cached_key
        return best_key

    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        """Return the state whose key best prefix-matches ``key``.

        The matched entry is marked most recently used.

        Raises:
            KeyError: if no cached key has a non-zero prefix match.
        """
        key = tuple(key)
        _key = self._find_longest_prefix_key(key)
        if _key is None:
            raise KeyError("Key not found")
        value = self.cache_state[_key]
        # Move to the most-recently-used end so eviction reaches it last.
        self.cache_state.move_to_end(_key)
        return value

    def __contains__(self, key: Sequence[int]) -> bool:
        """Return True if some cached key has a non-zero prefix match with ``key``."""
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState") -> None:
        """Store ``value`` under the exact ``key``, then evict LRU entries over budget.

        If ``value`` alone exceeds ``capacity_bytes`` it is evicted as well,
        leaving the cache empty.
        """
        key = tuple(key)
        if key in self.cache_state:
            # Delete first so the re-inserted entry lands at the MRU end.
            del self.cache_state[key]
        self.cache_state[key] = value
        # Keep a running total rather than re-summing every entry per eviction.
        size = self.cache_size
        while size > self.capacity_bytes and self.cache_state:
            _, evicted = self.cache_state.popitem(last=False)
            size -= evicted.llama_state_size


# Alias for backwards compatibility
LlamaCache = LlamaRAMCache
class LlamaDiskCache(BaseLlamaCache):
    """Cache for a llama.cpp model using disk.

    States are stored in a ``diskcache.Cache`` under token-tuple keys.
    Lookups match the stored key sharing the longest token prefix with the
    query. A successful ``__getitem__`` *removes* the entry it returns.
    """

    def __init__(
        self, cache_dir: str = ".cache/llama_cache", capacity_bytes: int = (2 << 30)
    ):
        """Open (or create) the on-disk cache.

        Args:
            cache_dir: Directory handed to ``diskcache.Cache``.
            capacity_bytes: Budget compared against ``cache_size`` after each
                insert. Defaults to ``2 << 30`` (2 GiB).
        """
        super().__init__(capacity_bytes)
        self.cache = diskcache.Cache(cache_dir)

    # Property (not a plain method): __setitem__ compares it against an int.
    @property
    def cache_size(self) -> int:
        """Estimated on-disk size reported by ``diskcache.Cache.volume()``."""
        return int(self.cache.volume())  # type: ignore

    def _find_longest_prefix_key(
        self,
        key: Tuple[int, ...],
    ) -> Optional[Tuple[int, ...]]:
        """Return the stored key with the longest token prefix in common with ``key``.

        Returns ``None`` when every stored key yields a prefix length of 0.
        On ties, the first key produced by ``iterkeys()`` wins.
        """
        best_len = 0
        best_key: Optional[Tuple[int, ...]] = None
        for stored_key in self.cache.iterkeys():  # type: ignore
            prefix_len = llama_cpp.llama.Llama.longest_token_prefix(stored_key, key)
            if prefix_len > best_len:
                best_len = prefix_len
                best_key = stored_key  # type: ignore
        return best_key

    def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState":
        """Remove and return the state whose key best prefix-matches ``key``.

        Raises:
            KeyError: if no stored key has a non-zero prefix match.
        """
        key = tuple(key)
        _key = self._find_longest_prefix_key(key)
        if _key is None:
            raise KeyError("Key not found")
        # NOTE: pop() deletes the entry; the caller must re-insert it to keep it.
        value: "llama_cpp.llama.LlamaState" = self.cache.pop(_key)  # type: ignore
        # NOTE: This puts an integer as key in cache, which breaks,
        # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
        # self.cache.push(_key, side="front")  # type: ignore
        return value

    def __contains__(self, key: Sequence[int]) -> bool:
        """Return True if some stored key has a non-zero prefix match with ``key``."""
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState") -> None:
        """Store ``value`` under the exact ``key``, then trim entries while over budget.

        Progress is logged to stderr. Trimming deletes keys in the cache's
        iteration order until ``cache_size <= capacity_bytes`` or the cache
        is empty.
        """
        print("LlamaDiskCache.__setitem__: called", file=sys.stderr)
        key = tuple(key)
        if key in self.cache:
            print("LlamaDiskCache.__setitem__: delete", file=sys.stderr)
            del self.cache[key]
        self.cache[key] = value
        print("LlamaDiskCache.__setitem__: set", file=sys.stderr)
        while self.cache_size > self.capacity_bytes and len(self.cache) > 0:
            key_to_remove = next(iter(self.cache))
            del self.cache[key_to_remove]
        print("LlamaDiskCache.__setitem__: trim", file=sys.stderr)