| """ | |
| Redis Semantic Cache implementation for LiteLLM | |
| The RedisSemanticCache provides semantic caching functionality using Redis as a backend. | |
| This cache stores responses based on the semantic similarity of prompts rather than | |
| exact matching, allowing for more flexible caching of LLM responses. | |
| This implementation uses RedisVL's SemanticCache to find semantically similar prompts | |
| and their cached responses. | |
| """ | |
import ast
import asyncio
import json
import os
from typing import Any, Dict, List, Optional, Tuple, cast

import litellm
from litellm._logging import print_verbose
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    get_str_from_messages,
)
from litellm.types.utils import EmbeddingResponse

from .base_cache import BaseCache

class RedisSemanticCache(BaseCache):
    """
    Redis-backed semantic cache for LLM responses.

    This cache uses vector similarity to find semantically similar prompts that have been
    previously sent to the LLM, allowing for cache hits even when prompts are not identical
    but carry similar meaning.
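
    Example (illustrative; assumes a reachable Redis instance with vector search
    support and credentials configured for the default embedding model):

        cache = RedisSemanticCache(
            redis_url="redis://:mypassword@localhost:6379",
            similarity_threshold=0.8,  # cache hit requires cosine similarity >= 0.8
        )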
| """ | |
| DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index" | |

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[str] = None,
        password: Optional[str] = None,
        redis_url: Optional[str] = None,
        similarity_threshold: Optional[float] = None,
        embedding_model: str = "text-embedding-ada-002",
        index_name: Optional[str] = None,
        **kwargs,
    ):
| """ | |
| Initialize the Redis Semantic Cache. | |
| Args: | |
| host: Redis host address | |
| port: Redis port | |
| password: Redis password | |
| redis_url: Full Redis URL (alternative to separate host/port/password) | |
| similarity_threshold: Threshold for semantic similarity (0.0 to 1.0) | |
| where 1.0 requires exact matches and 0.0 accepts any match | |
| embedding_model: Model to use for generating embeddings | |
| index_name: Name for the Redis index | |
| ttl: Default time-to-live for cache entries in seconds | |
| **kwargs: Additional arguments passed to the Redis client | |
| Raises: | |
| Exception: If similarity_threshold is not provided or required Redis | |
| connection information is missing | |
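
        Example (illustrative; assumes REDIS_HOST, REDIS_PORT, and REDIS_PASSWORD
        are exported in the environment, so no explicit connection arguments are
        needed):

            cache = RedisSemanticCache(similarity_threshold=0.8)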
| """ | |
        from redisvl.extensions.llmcache import SemanticCache
        from redisvl.utils.vectorize import CustomTextVectorizer

        if index_name is None:
            index_name = self.DEFAULT_REDIS_INDEX_NAME

        print_verbose(f"Redis semantic-cache initializing index - {index_name}")

        # Validate similarity threshold
        if similarity_threshold is None:
            raise ValueError("similarity_threshold must be provided, got None")

        # Store configuration
        self.similarity_threshold = similarity_threshold

        # Convert the similarity threshold in [0, 1] to a cosine distance threshold
        # For cosine distance: 0 = most similar, 2 = least similar
        # For similarity: 1 = most similar, 0 = least similar
        self.distance_threshold = 1 - similarity_threshold
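        # Worked example: similarity_threshold=0.8 gives distance_threshold=0.2,
        # so a lookup only hits when the cached prompt is within cosine distance
        # 0.2 of the new prompt (i.e. similarity >= 0.8)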
        self.embedding_model = embedding_model

        # Set up Redis connection
        if redis_url is None:
            try:
                # Attempt to use provided parameters or fall back to environment variables
                host = host or os.environ["REDIS_HOST"]
                port = port or os.environ["REDIS_PORT"]
                password = password or os.environ["REDIS_PASSWORD"]
            except KeyError as e:
                # Raise a more informative exception if any required key is missing
                missing_var = e.args[0]
                raise ValueError(
                    f"Missing required Redis configuration: {missing_var}. "
                    f"Provide {missing_var} or redis_url."
                ) from e

            redis_url = f"redis://:{password}@{host}:{port}"

        print_verbose(f"Redis semantic-cache redis_url: {redis_url}")

        # Initialize the Redis vectorizer and cache
        cache_vectorizer = CustomTextVectorizer(self._get_embedding)

        self.llmcache = SemanticCache(
            name=index_name,
            redis_url=redis_url,
            vectorizer=cache_vectorizer,
            distance_threshold=self.distance_threshold,
            overwrite=False,
        )

    def _get_ttl(self, **kwargs) -> Optional[int]:
        """
        Get the TTL (time-to-live) value for cache entries.

        Args:
            **kwargs: Keyword arguments that may contain a custom TTL

        Returns:
            Optional[int]: The TTL value in seconds, or None if no TTL should be applied
        """
        ttl = kwargs.get("ttl")
        if ttl is not None:
            ttl = int(ttl)
        return ttl

    def _get_embedding(self, prompt: str) -> List[float]:
        """
        Generate an embedding vector for the given prompt using the configured embedding model.

        Args:
            prompt: The text to generate an embedding for

        Returns:
            List[float]: The embedding vector
        """
        # Create an embedding from the prompt
        embedding_response = cast(
            EmbeddingResponse,
            litellm.embedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            ),
        )

        embedding = embedding_response["data"][0]["embedding"]
        return embedding

    def _get_cache_logic(self, cached_response: Any) -> Any:
        """
        Process the cached response to prepare it for use.

        Args:
            cached_response: The raw cached response

        Returns:
            The processed cached response, or None if input was None
        """
        if cached_response is None:
            return cached_response

        # Convert bytes to string if needed
        if isinstance(cached_response, bytes):
            cached_response = cached_response.decode("utf-8")

        # Convert string representation to Python object
        try:
            cached_response = json.loads(cached_response)
        except json.JSONDecodeError:
            try:
                cached_response = ast.literal_eval(cached_response)
            except (ValueError, SyntaxError) as e:
                print_verbose(f"Error parsing cached response: {str(e)}")
                return None

        return cached_response

    def set_cache(self, key: str, value: Any, **kwargs) -> None:
        """
        Store a value in the semantic cache.

        Args:
            key: The cache key (not directly used in semantic caching)
            value: The response value to cache
            **kwargs: Additional arguments including 'messages' for the prompt
                and optional 'ttl' for time-to-live
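
        Example (illustrative; assumes `cache` is a configured RedisSemanticCache
        instance and `response` is an LLM response object to cache):

            cache.set_cache(
                key="unused-for-semantic-cache",
                value=response,
                messages=[{"role": "user", "content": "What is Redis?"}],
                ttl=3600,  # optional, seconds
            )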
| """ | |
| print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}") | |
| value_str: Optional[str] = None | |
| try: | |
| # Extract the prompt from messages | |
| messages = kwargs.get("messages", []) | |
| if not messages: | |
| print_verbose("No messages provided for semantic caching") | |
| return | |
| prompt = get_str_from_messages(messages) | |
| value_str = str(value) | |
| # Get TTL and store in Redis semantic cache | |
| ttl = self._get_ttl(**kwargs) | |
| if ttl is not None: | |
| self.llmcache.store(prompt, value_str, ttl=int(ttl)) | |
| else: | |
| self.llmcache.store(prompt, value_str) | |
| except Exception as e: | |
| print_verbose( | |
| f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}" | |
| ) | |

    def get_cache(self, key: str, **kwargs) -> Any:
        """
        Retrieve a semantically similar cached response.

        Args:
            key: The cache key (not directly used in semantic caching)
            **kwargs: Additional arguments including 'messages' for the prompt

        Returns:
            The cached response if a semantically similar prompt is found, else None
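
        Example (illustrative; assumes `cache` is a configured RedisSemanticCache
        instance populated by an earlier set_cache call):

            cached = cache.get_cache(
                key="unused-for-semantic-cache",
                messages=[{"role": "user", "content": "Explain what Redis is"}],
            )
            if cached is not None:
                print("semantic cache hit")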
| """ | |
| print_verbose(f"Redis semantic-cache get_cache, kwargs: {kwargs}") | |
| try: | |
| # Extract the prompt from messages | |
| messages = kwargs.get("messages", []) | |
| if not messages: | |
| print_verbose("No messages provided for semantic cache lookup") | |
| return None | |
| prompt = get_str_from_messages(messages) | |
| # Check the cache for semantically similar prompts | |
| results = self.llmcache.check(prompt=prompt) | |
| # Return None if no similar prompts found | |
| if not results: | |
| return None | |
| # Process the best matching result | |
| cache_hit = results[0] | |
| vector_distance = float(cache_hit["vector_distance"]) | |
| # Convert vector distance back to similarity score | |
| # For cosine distance: 0 = most similar, 2 = least similar | |
| # While similarity: 1 = most similar, 0 = least similar | |
| similarity = 1 - vector_distance | |
| cached_prompt = cache_hit["prompt"] | |
| cached_response = cache_hit["response"] | |
| print_verbose( | |
| f"Cache hit: similarity threshold: {self.similarity_threshold}, " | |
| f"actual similarity: {similarity}, " | |
| f"current prompt: {prompt}, " | |
| f"cached prompt: {cached_prompt}" | |
| ) | |
| return self._get_cache_logic(cached_response=cached_response) | |
| except Exception as e: | |
| print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}") | |

    async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
        """
        Asynchronously generate an embedding for the given prompt.

        Args:
            prompt: The text to generate an embedding for
            **kwargs: Additional arguments that may contain metadata

        Returns:
            List[float]: The embedding vector
        """
        from litellm.proxy.proxy_server import llm_model_list, llm_router

        # Route the embedding request through the proxy if appropriate
        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )

        try:
            if llm_router is not None and self.embedding_model in router_model_names:
                # Use the router for embedding generation
                user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
                embedding_response = await llm_router.aembedding(
                    model=self.embedding_model,
                    input=prompt,
                    cache={"no-store": True, "no-cache": True},
                    metadata={
                        "user_api_key": user_api_key,
                        "semantic-cache-embedding": True,
                        "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
                    },
                )
            else:
                # Generate embedding directly
                embedding_response = await litellm.aembedding(
                    model=self.embedding_model,
                    input=prompt,
                    cache={"no-store": True, "no-cache": True},
                )

            # Extract and return the embedding vector
            return embedding_response["data"][0]["embedding"]
        except Exception as e:
            print_verbose(f"Error generating async embedding: {str(e)}")
            raise ValueError(f"Failed to generate embedding: {str(e)}") from e

    async def async_set_cache(self, key: str, value: Any, **kwargs) -> None:
        """
        Asynchronously store a value in the semantic cache.

        Args:
            key: The cache key (not directly used in semantic caching)
            value: The response value to cache
            **kwargs: Additional arguments including 'messages' for the prompt
                and optional 'ttl' for time-to-live
        """
| print_verbose(f"Async Redis semantic-cache set_cache, kwargs: {kwargs}") | |
| try: | |
| # Extract the prompt from messages | |
| messages = kwargs.get("messages", []) | |
| if not messages: | |
| print_verbose("No messages provided for semantic caching") | |
| return | |
| prompt = get_str_from_messages(messages) | |
| value_str = str(value) | |
| # Generate embedding for the value (response) to cache | |
| prompt_embedding = await self._get_async_embedding(prompt, **kwargs) | |
| # Get TTL and store in Redis semantic cache | |
| ttl = self._get_ttl(**kwargs) | |
| if ttl is not None: | |
| await self.llmcache.astore( | |
| prompt, | |
| value_str, | |
| vector=prompt_embedding, # Pass through custom embedding | |
| ttl=ttl, | |
| ) | |
| else: | |
| await self.llmcache.astore( | |
| prompt, | |
| value_str, | |
| vector=prompt_embedding, # Pass through custom embedding | |
| ) | |
| except Exception as e: | |
| print_verbose(f"Error in async_set_cache: {str(e)}") | |

    async def async_get_cache(self, key: str, **kwargs) -> Any:
        """
        Asynchronously retrieve a semantically similar cached response.

        Args:
            key: The cache key (not directly used in semantic caching)
            **kwargs: Additional arguments including 'messages' for the prompt

        Returns:
            The cached response if a semantically similar prompt is found, else None
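
        Example (illustrative; assumes an async context and a configured
        RedisSemanticCache instance named `cache`; the metadata dict passed in is
        updated in place with the measured "semantic-similarity"):

            metadata: dict = {}
            cached = await cache.async_get_cache(
                key="unused-for-semantic-cache",
                messages=[{"role": "user", "content": "Explain what Redis is"}],
                metadata=metadata,
            )
            print(metadata.get("semantic-similarity"))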
| """ | |
| print_verbose(f"Async Redis semantic-cache get_cache, kwargs: {kwargs}") | |
| try: | |
| # Extract the prompt from messages | |
| messages = kwargs.get("messages", []) | |
| if not messages: | |
| print_verbose("No messages provided for semantic cache lookup") | |
| kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 | |
| return None | |
| prompt = get_str_from_messages(messages) | |
| # Generate embedding for the prompt | |
| prompt_embedding = await self._get_async_embedding(prompt, **kwargs) | |
| # Check the cache for semantically similar prompts | |
| results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding) | |
| # handle results / cache hit | |
| if not results: | |
| kwargs.setdefault("metadata", {})[ | |
| "semantic-similarity" | |
| ] = 0.0 # TODO why here but not above?? | |
| return None | |
| cache_hit = results[0] | |
| vector_distance = float(cache_hit["vector_distance"]) | |
| # Convert vector distance back to similarity | |
| # For cosine distance: 0 = most similar, 2 = least similar | |
| # While similarity: 1 = most similar, 0 = least similar | |
| similarity = 1 - vector_distance | |
| cached_prompt = cache_hit["prompt"] | |
| cached_response = cache_hit["response"] | |
| # update kwargs["metadata"] with similarity, don't rewrite the original metadata | |
| kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity | |
| print_verbose( | |
| f"Cache hit: similarity threshold: {self.similarity_threshold}, " | |
| f"actual similarity: {similarity}, " | |
| f"current prompt: {prompt}, " | |
| f"cached prompt: {cached_prompt}" | |
| ) | |
| return self._get_cache_logic(cached_response=cached_response) | |
| except Exception as e: | |
| print_verbose(f"Error in async_get_cache: {str(e)}") | |
| kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 | |

    async def _index_info(self) -> Dict[str, Any]:
        """
        Get information about the Redis index.

        Returns:
            Dict[str, Any]: Information about the Redis index
        """
        aindex = await self.llmcache._get_async_index()
        return await aindex.info()

    async def async_set_cache_pipeline(
        self, cache_list: List[Tuple[str, Any]], **kwargs
    ) -> None:
        """
        Asynchronously store multiple values in the semantic cache.

        Args:
            cache_list: List of (key, value) tuples to cache
            **kwargs: Additional arguments
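
        Example (illustrative sketch; assumes an async context, a configured
        RedisSemanticCache instance named `cache`, and `response_1`/`response_2`
        as LLM responses to cache for the given prompt):

            await cache.async_set_cache_pipeline(
                cache_list=[("key-1", response_1), ("key-2", response_2)],
                messages=[{"role": "user", "content": "What is Redis?"}],
            )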
| """ | |
| try: | |
| tasks = [] | |
| for val in cache_list: | |
| tasks.append(self.async_set_cache(val[0], val[1], **kwargs)) | |
| await asyncio.gather(*tasks) | |
| except Exception as e: | |
| print_verbose(f"Error in async_set_cache_pipeline: {str(e)}") | |