| |
| |
|
|
| import contextlib |
| import contextvars |
| import hashlib |
| import json |
| import os |
| import threading |
| import time |
| from collections import OrderedDict |
| from dataclasses import dataclass |
| from typing import Iterable, Optional |
|
|
| import numpy as np |
| import torch |
|
|
| from kimodo.sanitize import sanitize_texts |
|
|
# Context-local handle to the currently active session (bound/unbound via
# CachedTextEncoder.session_context); None when no session is active.
_ACTIVE_SESSION = contextvars.ContextVar("kimodo_demo_active_session", default=None)
|
|
|
|
@dataclass
class CacheStats:
    """Hit/miss counters for an :class:`EmbeddingCache` instance."""

    # In-memory LRU hits.
    hits: int = 0
    # Texts that required a fresh encoder call.
    misses: int = 0
    # Entries found on disk (and promoted into the in-memory LRU).
    disk_hits: int = 0
|
|
|
|
class EmbeddingCache:
    """Disk-backed text embedding cache with a small in-memory LRU.

    Each embedding is keyed by the SHA-256 digest of
    ``model_name|encoder_id|text`` and persisted as an individual ``.npy``
    file under ``base_dir/model_name``.  A bounded ``OrderedDict`` acts as an
    in-memory LRU layer in front of the disk store, and a JSON index records
    per-entry metadata (length, dtype, timestamp).
    """

    def __init__(
        self,
        *,
        model_name: str,
        encoder_id: str,
        base_dir: Optional[str] = None,
        max_mem_entries: int = 128,
    ) -> None:
        # Cache-root resolution order: explicit argument, environment
        # variable, then the default under the user's home directory.
        cache_root = base_dir or os.environ.get(
            "kimodo_EMBED_CACHE_DIR",
            os.path.join("~", ".cache", "kimodo_demo", "embeddings"),
        )
        self.base_dir = os.path.expanduser(cache_root)
        self.model_name = model_name
        self.encoder_id = encoder_id
        self.max_mem_entries = max_mem_entries
        self.stats = CacheStats()

        # The lock guards the in-memory LRU, the index dict, and index I/O.
        self._lock = threading.Lock()
        self._mem_cache: OrderedDict[str, np.ndarray] = OrderedDict()
        self._index: dict = {}
        self._index_loaded = False

    def _model_dir(self) -> str:
        """Directory holding this model's cache entries."""
        return os.path.join(self.base_dir, self.model_name)

    def _index_path(self) -> str:
        return os.path.join(self._model_dir(), "index.json")

    def _prewarm_marker_path(self, key: str) -> str:
        return os.path.join(self._model_dir(), f"prewarm_{key}.json")

    def has_prewarm_marker(self, key: str) -> bool:
        """Return True if a prewarm marker for *key* exists on disk."""
        return os.path.exists(self._prewarm_marker_path(key))

    def write_prewarm_marker(self, key: str, *, prompt_count: int) -> None:
        """Atomically record that the prompt set identified by *key* was prewarmed."""
        os.makedirs(self._model_dir(), exist_ok=True)
        payload = {"prompt_count": prompt_count, "updated_at": time.time()}
        # Write-then-rename so readers never observe a partially written file.
        tmp_path = f"{self._prewarm_marker_path(key)}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f)
        os.replace(tmp_path, self._prewarm_marker_path(key))

    def _load_index(self) -> None:
        """Lazily load the on-disk JSON index (at most once per instance)."""
        if self._index_loaded:
            return
        index_path = self._index_path()
        if os.path.exists(index_path):
            try:
                with open(index_path, "r", encoding="utf-8") as f:
                    self._index = json.load(f)
            except json.JSONDecodeError:
                # A corrupt index is recoverable: start fresh rather than fail.
                self._index = {}
        self._index_loaded = True

    def _save_index(self) -> None:
        """Atomically persist the index via write-then-rename."""
        os.makedirs(self._model_dir(), exist_ok=True)
        tmp_path = f"{self._index_path()}.tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self._index, f)
        os.replace(tmp_path, self._index_path())

    def _make_key(self, text: str) -> str:
        """Stable content hash covering model, encoder, and text."""
        key_src = f"{self.model_name}|{self.encoder_id}|{text}"
        return hashlib.sha256(key_src.encode("utf-8")).hexdigest()

    def _entry_path(self, key: str) -> str:
        return os.path.join(self._model_dir(), f"{key}.npy")

    def _mem_get(self, key: str) -> Optional[np.ndarray]:
        """LRU lookup: a hit is refreshed to most-recently-used."""
        if key in self._mem_cache:
            self._mem_cache.move_to_end(key)
            return self._mem_cache[key]
        return None

    def _mem_put(self, key: str, value: np.ndarray) -> None:
        """Insert into the LRU, evicting least-recently-used entries as needed."""
        self._mem_cache[key] = value
        self._mem_cache.move_to_end(key)
        while len(self._mem_cache) > self.max_mem_entries:
            self._mem_cache.popitem(last=False)

    def _disk_load(self, key: str) -> Optional[np.ndarray]:
        """Load an entry from disk, treating any failure as a cache miss."""
        path = self._entry_path(key)
        if not os.path.exists(path):
            return None
        try:
            return np.load(path)
        except Exception:
            # A truncated/corrupt .npy file simply counts as a miss.
            return None

    def _disk_save(self, key: str, value: np.ndarray) -> None:
        """Write an entry to disk and record its metadata in the index.

        Callers are expected to hold ``self._lock`` and to call
        ``_save_index()`` afterwards to persist the index.
        """
        os.makedirs(self._model_dir(), exist_ok=True)
        np.save(self._entry_path(key), value)
        self._index[key] = {
            "length": int(value.shape[0]),
            "dtype": str(value.dtype),
            "updated_at": time.time(),
        }

    def _maybe_use_session_cache(self, texts: list[str]):
        """Return ``(embeddings, lengths)`` if the active session already has them."""
        session = _ACTIVE_SESSION.get()
        if session is None:
            return None
        if session.last_prompt_texts == texts and session.last_prompt_embeddings is not None:
            return session.last_prompt_embeddings, session.last_prompt_lengths
        return None

    def _update_session_cache(self, texts: list[str], tensor: torch.Tensor, lengths: list[int]) -> None:
        """Remember the most recent prompt batch on the active session, if any."""
        session = _ACTIVE_SESSION.get()
        if session is None:
            return
        session.last_prompt_texts = texts
        session.last_prompt_embeddings = tensor
        session.last_prompt_lengths = lengths

    def get_or_encode(self, texts: Iterable[str], encoder):
        """Return ``(padded_embeddings, lengths)`` for *texts*, encoding misses.

        Lookup order: session cache, in-memory LRU, disk, then *encoder*
        (invoked once with the list of missing texts).  The result is a
        zero-padded ``(batch, max_len, dim)`` float tensor plus the per-text
        unpadded lengths.
        """
        if isinstance(texts, str):
            texts = [texts]
        texts = sanitize_texts(list(texts))
        if len(texts) == 0:
            # Bug fix: torch.empty() requires a size argument and raised
            # TypeError here; return an empty tensor with the same rank as
            # the padded output instead.
            return torch.empty((0, 0, 0)), []

        session_cache = self._maybe_use_session_cache(texts)
        if session_cache is not None:
            return session_cache

        arrays: list[Optional[np.ndarray]] = [None] * len(texts)
        lengths: list[int] = [0] * len(texts)
        misses: list[tuple[int, str, str]] = []

        with self._lock:
            self._load_index()
            for idx, text in enumerate(texts):
                key = self._make_key(text)
                cached = self._mem_get(key)
                if cached is not None:
                    arrays[idx] = cached
                    lengths[idx] = cached.shape[0]
                    self.stats.hits += 1
                    continue

                cached = self._disk_load(key)
                if cached is not None:
                    arrays[idx] = cached
                    lengths[idx] = cached.shape[0]
                    # Promote the disk hit into the in-memory LRU.
                    self._mem_put(key, cached)
                    self.stats.disk_hits += 1
                    continue

                misses.append((idx, text, key))
                self.stats.misses += 1

        if misses:
            # Encode all missing texts in a single batched call, outside the
            # lock so other threads can keep reading the cache meanwhile.
            miss_texts = [text for _, text, _ in misses]
            miss_tensor, miss_lengths = encoder(miss_texts)
            miss_tensor = miss_tensor.detach().cpu()
            miss_tensor_np = miss_tensor.numpy()

            with self._lock:
                self._load_index()
                for miss_idx, length in enumerate(miss_lengths):
                    idx, _text, key = misses[miss_idx]
                    # Strip the encoder's padding before caching the entry.
                    arr = miss_tensor_np[miss_idx, :length].copy()
                    arrays[idx] = arr
                    lengths[idx] = int(length)
                    self._mem_put(key, arr)
                    self._disk_save(key, arr)
                # Persist the index once per batch, not per entry.
                self._save_index()

        # Re-pad everything into one dense (batch, max_len, dim) tensor.
        max_len = max(lengths) if lengths else 0
        feat_dim = arrays[0].shape[-1] if arrays[0] is not None else 0
        dtype = arrays[0].dtype if arrays[0] is not None else np.float32
        padded = np.zeros((len(texts), max_len, feat_dim), dtype=dtype)
        for idx, arr in enumerate(arrays):
            if arr is None:
                continue
            padded[idx, : arr.shape[0]] = arr

        result = torch.from_numpy(padded)
        self._update_session_cache(texts, result, lengths)
        return result, lengths
|
|
|
|
class CachedTextEncoder:
    """Drop-in wrapper that adds disk-backed caching to a text encoder.

    Unknown attribute access falls through to the wrapped encoder, so this
    object can stand in wherever the raw encoder is expected.
    """

    def __init__(self, encoder, *, model_name: str, base_dir: Optional[str] = None):
        self.encoder = encoder
        self.model_name = model_name
        self.cache = EmbeddingCache(
            model_name=model_name,
            encoder_id=type(encoder).__name__,
            base_dir=base_dir,
        )

    def __call__(self, texts):
        # The cache handles lookups and falls back to the wrapped encoder.
        return self.cache.get_or_encode(texts, self.encoder)

    def prewarm(self, texts) -> None:
        """Encode *texts* once and drop a marker so repeats become no-ops."""
        if isinstance(texts, str):
            texts = [texts]
        texts = sanitize_texts(list(texts))
        digest = hashlib.sha256("|".join(texts).encode("utf-8")).hexdigest()
        if not self.cache.has_prewarm_marker(digest):
            self.cache.get_or_encode(texts, self.encoder)
            self.cache.write_prewarm_marker(digest, prompt_count=len(texts))

    def to(self, device=None, dtype=None):
        """Forward device/dtype moves to the encoder when supported; fluent."""
        if hasattr(self.encoder, "to"):
            self.encoder.to(device=device, dtype=dtype)
        return self

    @contextlib.contextmanager
    def session_context(self, session):
        """Bind *session* as the active session for the enclosed scope."""
        token = _ACTIVE_SESSION.set(session)
        try:
            yield
        finally:
            # Always restore the previous session, even on error.
            _ACTIVE_SESSION.reset(token)

    def __getattr__(self, name):
        # Delegate anything we don't define to the wrapped encoder.
        return getattr(self.encoder, name)
|
|