# SPARKNET / src/utils/cache.py
# MHamdan's picture
# Initial commit: SPARKNET framework
# a9dc537
"""
Caching utilities for SPARKNET
Provides LRU caching for LLM responses and embeddings
Following FAANG best practices for performance optimization
"""
import hashlib
import json
from typing import Any, Optional, Dict, Callable
from functools import wraps
from datetime import datetime, timedelta
from cachetools import TTLCache, LRUCache
from loguru import logger
class LLMResponseCache:
    """
    Cache for LLM responses to reduce API calls and latency.

    Features:
    - TTL-based expiration
    - LRU eviction policy
    - Content-based hashing
    - Statistics tracking

    Example:
        cache = LLMResponseCache(maxsize=1000, ttl=3600)

        # Check cache
        cached = cache.get(prompt, model)
        if cached is not None:
            return cached

        # Store result
        cache.set(prompt, model, response)
    """

    def __init__(
        self,
        maxsize: int = 1000,
        ttl: int = 3600,  # 1 hour default
        enabled: bool = True,
    ):
        """
        Initialize LLM response cache.

        Args:
            maxsize: Maximum number of cached responses
            ttl: Time-to-live in seconds
            enabled: Whether caching is enabled
        """
        self.maxsize = maxsize
        self.ttl = ttl
        self.enabled = enabled
        self._cache: TTLCache = TTLCache(maxsize=maxsize, ttl=ttl)

        # Statistics
        self._hits = 0
        self._misses = 0

        logger.info(f"Initialized LLMResponseCache (maxsize={maxsize}, ttl={ttl}s)")

    def _hash_key(self, prompt: str, model: str, **kwargs) -> str:
        """
        Generate cache key from prompt and parameters.

        The prompt, model and every extra keyword argument are serialised to
        JSON with sorted keys and hashed with SHA-256.

        Args:
            prompt: The prompt sent to the LLM
            model: Model identifier
            **kwargs: Additional parameters that distinguish the request

        Returns:
            Hex-encoded SHA-256 digest
        """
        key_data = {
            "prompt": prompt,
            "model": model,
            **kwargs,
        }
        # default=repr keeps key generation from raising TypeError when a
        # parameter is not JSON-serialisable (set, dataclass, callable, ...).
        # Values that are already serialisable produce the same key as before.
        key_str = json.dumps(key_data, sort_keys=True, default=repr)
        return hashlib.sha256(key_str.encode()).hexdigest()

    def get(self, prompt: str, model: str, **kwargs) -> Optional[str]:
        """
        Get cached response if available.

        Hit/miss counters are only updated while the cache is enabled.

        Args:
            prompt: The prompt sent to the LLM
            model: Model identifier
            **kwargs: Additional parameters

        Returns:
            Cached response or None
        """
        if not self.enabled:
            return None

        key = self._hash_key(prompt, model, **kwargs)
        result = self._cache.get(key)

        if result is not None:
            self._hits += 1
            logger.debug(f"Cache HIT for model={model}")
        else:
            self._misses += 1

        return result

    def set(self, prompt: str, model: str, response: str, **kwargs):
        """
        Store response in cache.

        Does nothing when the cache is disabled.

        Args:
            prompt: The prompt sent to the LLM
            model: Model identifier
            response: The LLM response
            **kwargs: Additional parameters
        """
        if not self.enabled:
            return

        key = self._hash_key(prompt, model, **kwargs)
        self._cache[key] = response
        logger.debug(f"Cached response for model={model}")

    def invalidate(self, prompt: str, model: str, **kwargs):
        """Invalidate a specific cache entry (no error if it is absent)."""
        key = self._hash_key(prompt, model, **kwargs)
        self._cache.pop(key, None)

    def clear(self):
        """Clear all cached entries. Hit/miss counters are left untouched."""
        self._cache.clear()
        logger.info("LLM response cache cleared")

    @property
    def stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dict with hits, misses, total lookups, hit rate formatted as a
            percentage string, current size, maxsize and the enabled flag.
        """
        total = self._hits + self._misses
        hit_rate = (self._hits / total * 100) if total > 0 else 0

        return {
            "hits": self._hits,
            "misses": self._misses,
            "total": total,
            "hit_rate": f"{hit_rate:.1f}%",
            "size": len(self._cache),
            "maxsize": self.maxsize,
            "enabled": self.enabled,
        }
class EmbeddingCache:
    """
    Cache for text embeddings to avoid recomputation.

    Uses LRU policy with configurable size.
    Embeddings are stored as lists of floats.
    """

    def __init__(self, maxsize: int = 10000, enabled: bool = True):
        """
        Initialize embedding cache.

        Args:
            maxsize: Maximum number of cached embeddings
            enabled: Whether caching is enabled
        """
        self.maxsize = maxsize
        self.enabled = enabled
        self._cache: LRUCache = LRUCache(maxsize=maxsize)

        self._hits = 0
        self._misses = 0

        logger.info(f"Initialized EmbeddingCache (maxsize={maxsize})")

    def _hash_key(self, text: str, model: str) -> str:
        """
        Generate cache key from text and model.

        The pair is JSON-encoded as a two-element list before hashing so the
        boundary between model and text is unambiguous. Plain
        "model:text" concatenation makes ("a:b", "c") and ("a", "b:c")
        collide, and model tags often contain a colon.
        """
        key_str = json.dumps([model, text], default=repr)
        return hashlib.sha256(key_str.encode()).hexdigest()

    def get(self, text: str, model: str) -> Optional[list]:
        """
        Get cached embedding if available.

        Returns None on a miss or when the cache is disabled; counters are
        only updated while the cache is enabled.
        """
        if not self.enabled:
            return None

        key = self._hash_key(text, model)
        result = self._cache.get(key)

        if result is not None:
            self._hits += 1
        else:
            self._misses += 1

        return result

    def set(self, text: str, model: str, embedding: list):
        """Store embedding in cache (no-op when the cache is disabled)."""
        if not self.enabled:
            return

        key = self._hash_key(text, model)
        self._cache[key] = embedding

    def get_batch(self, texts: list, model: str) -> tuple:
        """
        Get cached embeddings for a batch of texts.

        Args:
            texts: Texts to look up
            model: Model identifier

        Returns:
            Tuple of (cached_results, uncached_indices), where cached_results
            maps the index of each cached text to its embedding and
            uncached_indices lists the indices that were not found.
        """
        results = {}
        uncached = []

        for i, text in enumerate(texts):
            cached = self.get(text, model)
            if cached is not None:
                results[i] = cached
            else:
                uncached.append(i)

        return results, uncached

    def set_batch(self, texts: list, model: str, embeddings: list):
        """
        Store batch of embeddings.

        Texts and embeddings are paired positionally; if the two lists differ
        in length, the extra items of the longer one are ignored.
        """
        for text, embedding in zip(texts, embeddings):
            self.set(text, model, embedding)

    @property
    def stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dict with hits, misses, total lookups, hit rate formatted as a
            percentage string, current size, maxsize and the enabled flag.
        """
        total = self._hits + self._misses
        hit_rate = (self._hits / total * 100) if total > 0 else 0

        return {
            "hits": self._hits,
            "misses": self._misses,
            "total": total,
            "hit_rate": f"{hit_rate:.1f}%",
            "size": len(self._cache),
            "maxsize": self.maxsize,
            "enabled": self.enabled,
        }
def cached_llm_call(cache: "LLMResponseCache"):
    """
    Decorator for caching LLM function calls.

    The decorated function must accept ``(prompt, model, **kwargs)``. On each
    call the cache is consulted first; on a miss the function is invoked and
    its result stored under the same prompt/model/kwargs. Coroutine functions
    get an async wrapper, everything else a synchronous one.

    Args:
        cache: Cache object providing ``get(prompt, model, **kwargs)`` and
            ``set(prompt, model, response, **kwargs)``

    Returns:
        A decorator that wraps the target function with cache lookup/store.

    Example:
        @cached_llm_call(llm_cache)
        async def generate_response(prompt: str, model: str) -> str:
            ...
    """
    # inspect.iscoroutinefunction is the supported check;
    # asyncio.iscoroutinefunction is deprecated in recent Python releases.
    import inspect

    def decorator(func: Callable) -> Callable:
        # Build only the wrapper that matches the kind of callable decorated.
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(prompt: str, model: str, **kwargs):
                # Check cache
                cached = cache.get(prompt, model, **kwargs)
                if cached is not None:
                    return cached

                # Call function
                result = await func(prompt, model, **kwargs)

                # Cache result
                cache.set(prompt, model, result, **kwargs)
                return result

            return async_wrapper

        @wraps(func)
        def sync_wrapper(prompt: str, model: str, **kwargs):
            # Check cache
            cached = cache.get(prompt, model, **kwargs)
            if cached is not None:
                return cached

            # Call function
            result = func(prompt, model, **kwargs)

            # Cache result
            cache.set(prompt, model, result, **kwargs)
            return result

        return sync_wrapper

    return decorator
# Global cache instances
# Module-level singletons. Both start as None and are created lazily by
# get_llm_cache() / get_embedding_cache() on first use.
_llm_cache: Optional[LLMResponseCache] = None
_embedding_cache: Optional[EmbeddingCache] = None
def get_llm_cache() -> LLMResponseCache:
    """Return the shared LLM response cache, building it with defaults on first use."""
    global _llm_cache
    if _llm_cache is not None:
        return _llm_cache
    _llm_cache = LLMResponseCache()
    return _llm_cache
def get_embedding_cache() -> EmbeddingCache:
    """Return the shared embedding cache, building it with defaults on first use."""
    global _embedding_cache
    if _embedding_cache is not None:
        return _embedding_cache
    _embedding_cache = EmbeddingCache()
    return _embedding_cache