# utils/cache_manager.py
"""
Intelligent response caching system with semantic similarity detection.
Reduces LLM costs and improves response times.
"""
import os
import json
import hashlib
import pickle
import time
from typing import Any, Dict, Optional
from datetime import datetime, timedelta

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer


class ResponseCache:
    """
    Advanced caching system with semantic similarity matching.
    Caches LLM responses so that repeated or similar queries avoid redundant API calls.
    """

    def __init__(self, cache_file: str = "./data/cache/response_cache.db",
                 similarity_threshold: float = 0.85):
        self.cache_file = cache_file
        self.similarity_threshold = similarity_threshold
        self.cache_data: Dict[str, Dict[str, Any]] = {}
        self.embedding_model = None
        self._load_cache()
        self._initialize_embedding_model()

    def _initialize_embedding_model(self):
        """Initialize the embedding model used for semantic similarity."""
        try:
            self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            print("✅ Semantic cache embedding model loaded")
        except Exception as e:
            print(f"⚠️ Could not load embedding model: {e}. Falling back to exact-match caching.")
            self.embedding_model = None

    def _load_cache(self):
        """Load cache from disk."""
        try:
            with open(self.cache_file, 'rb') as f:
                self.cache_data = pickle.load(f)
            print(f"✅ Cache loaded with {len(self.cache_data)} entries")
        except (FileNotFoundError, EOFError):
            self.cache_data = {}
            print("🆕 Starting with empty cache")

    def _save_cache(self):
        """Persist cache to disk."""
        try:
            # Fall back to the current directory when cache_file has no directory component
            os.makedirs(os.path.dirname(self.cache_file) or ".", exist_ok=True)
            with open(self.cache_file, 'wb') as f:
                pickle.dump(self.cache_data, f)
        except Exception as e:
            print(f"❌ Could not save cache: {e}")

    def _generate_cache_key(self, prompt: str, provider: str, temperature: float) -> str:
        """Generate a deterministic cache key."""
        content = f"{prompt}_{provider}_{temperature}"
        return hashlib.md5(content.encode()).hexdigest()

    def _get_semantic_embedding(self, text: str) -> Optional[np.ndarray]:
        """Get the semantic embedding for a text (None if no embedding model is available)."""
        if self.embedding_model is None:
            return None
        return self.embedding_model.encode([text])[0]

    def _find_similar_cached_response(self, prompt: str, provider: str,
                                      temperature: float) -> Optional[Dict]:
        """Find a semantically similar cached response for the same provider."""
        if self.embedding_model is None or not self.cache_data:
            return None
        prompt_embedding = self._get_semantic_embedding(prompt)
        if prompt_embedding is None:
            return None
        best_match = None
        best_similarity = 0.0
        for cache_key, cache_entry in self.cache_data.items():
            if cache_entry['provider'] != provider:
                continue
            # Backfill embeddings for entries cached without one
            if cache_entry.get('embedding') is None:
                cache_entry['embedding'] = self._get_semantic_embedding(cache_entry['prompt'])
            if cache_entry['embedding'] is None:
                continue
            # Cosine similarity between the new prompt and the cached prompt
            similarity = cosine_similarity(
                [prompt_embedding],
                [cache_entry['embedding']]
            )[0][0]
            if similarity > best_similarity and similarity >= self.similarity_threshold:
                best_similarity = similarity
                best_match = cache_entry
        if best_match:
            print(f"🎯 Semantic cache hit: similarity {best_similarity:.3f}")
            return best_match
        return None

    def get(self, prompt: str, provider: str, temperature: float = 0.1) -> Optional[str]:
        """
        Get a cached response for a prompt.
        Returns None if there is no cache hit.
        """
        # First try an exact match
        cache_key = self._generate_cache_key(prompt, provider, temperature)
        if cache_key in self.cache_data:
            entry = self.cache_data[cache_key]
            # Check whether the entry is still valid (24-hour TTL)
            if datetime.now() - entry['timestamp'] < timedelta(hours=24):
                print(f"✅ Exact cache hit for {provider}")
                return entry['response']
            else:
                # Remove the expired entry
                del self.cache_data[cache_key]
        # Fall back to a semantic similarity match
        similar_entry = self._find_similar_cached_response(prompt, provider, temperature)
        if similar_entry:
            if datetime.now() - similar_entry['timestamp'] < timedelta(hours=24):
                return similar_entry['response']
        return None

    def set(self, prompt: str, response: str, provider: str, temperature: float = 0.1):
        """Cache a response."""
        cache_key = self._generate_cache_key(prompt, provider, temperature)
        cache_entry = {
            'prompt': prompt,
            'response': response,
            'provider': provider,
            'temperature': temperature,
            'timestamp': datetime.now(),
            'embedding': self._get_semantic_embedding(prompt)
        }
        self.cache_data[cache_key] = cache_entry
        # Limit cache size (keep the most recent 1000 entries)
        if len(self.cache_data) > 1000:
            sorted_entries = sorted(self.cache_data.items(),
                                    key=lambda x: x[1]['timestamp'])
            for key, _ in sorted_entries[:-1000]:
                del self.cache_data[key]
        self._save_cache()
        print(f"💾 Cached response from {provider}")

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        total_entries = len(self.cache_data)
        if total_entries == 0:
            return {"total_entries": 0}
        now = datetime.now()
        recent_entries = sum(1 for entry in self.cache_data.values()
                             if now - entry['timestamp'] < timedelta(hours=1))
        providers = {}
        for entry in self.cache_data.values():
            provider = entry['provider']
            providers[provider] = providers.get(provider, 0) + 1
        return {
            "total_entries": total_entries,
            "recent_entries_1h": recent_entries,
            "providers_distribution": providers,
            "cache_file": self.cache_file,
            "semantic_caching": self.embedding_model is not None
        }

    def clear_expired(self, max_age_hours: int = 24):
        """Clear expired cache entries."""
        now = datetime.now()
        expired_keys = [
            key for key, entry in self.cache_data.items()
            if now - entry['timestamp'] > timedelta(hours=max_age_hours)
        ]
        for key in expired_keys:
            del self.cache_data[key]
        self._save_cache()
        print(f"🧹 Cleared {len(expired_keys)} expired cache entries")


# Cached LLM Provider Wrapper
class CachedLLMProvider:
    """Wrapper that adds caching to any LLM provider."""

    def __init__(self, llm_provider, cache_manager: ResponseCache):
        self.llm_provider = llm_provider
        self.cache_manager = cache_manager

    def generate(self, prompt: str, system_message: str = None, **kwargs) -> str:
        """Generate a response, serving it from the cache when possible."""
        full_prompt = prompt
        if system_message:
            full_prompt = f"{system_message}\n\n{prompt}"
        provider_name = self.llm_provider.get_provider_name()
        temperature = kwargs.get('temperature', 0.1)
        # Try the cache first (explicit None check so empty responses still count as hits)
        cached_response = self.cache_manager.get(full_prompt, provider_name, temperature)
        if cached_response is not None:
            return cached_response
        # Generate a fresh response and cache it
        response = self.llm_provider.generate(prompt, system_message, **kwargs)
        self.cache_manager.set(full_prompt, response, provider_name, temperature)
        return response

    def get_provider_name(self) -> str:
        return f"Cached-{self.llm_provider.get_provider_name()}"


# Quick test
def test_cache_system():
    """Test the caching system."""
    print("🧪 Testing Cache System")
    print("=" * 50)
    cache = ResponseCache("./data/test_cache.db")

    # Test cache operations
    test_prompt = "Explain machine learning in simple terms."
    test_response = "Machine learning is a subset of AI that enables computers to learn from data."
    cache.set(test_prompt, test_response, "test-provider", 0.1)

    # Test retrieval
    retrieved = cache.get(test_prompt, "test-provider", 0.1)
    print(f"✅ Cache test: {retrieved == test_response}")

    # Test stats
    stats = cache.get_stats()
    print(f"📊 Cache stats: {stats}")


if __name__ == "__main__":
    test_cache_system()