""" | |
Wrapper around router cache. Meant to store model id when prompt caching supported prompt is called. | |
""" | |
import hashlib
import json
from typing import TYPE_CHECKING, Any, List, Optional, TypedDict, Union

from litellm.caching.caching import DualCache
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span
    from litellm.router import Router

    litellm_router = Router
    Span = Union[_Span, Any]
else:
    Span = Any
    litellm_router = Any
class PromptCachingCacheValue(TypedDict):
    model_id: str
class PromptCachingCache:
    def __init__(self, cache: DualCache):
        self.cache = cache
        self.in_memory_cache = InMemoryCache()
    @staticmethod
    def serialize_object(obj: Any) -> Any:
        """Helper function to serialize Pydantic objects, dictionaries, or fall back to string."""
        if hasattr(obj, "dict"):
            # If the object is a Pydantic model, use its `dict()` method
            return obj.dict()
        elif isinstance(obj, dict):
            # If the object is a dictionary, serialize it with sorted keys
            return json.dumps(
                obj, sort_keys=True, separators=(",", ":")
            )  # Standardize serialization
        elif isinstance(obj, list):
            # Serialize lists by ensuring each element is handled properly
            return [PromptCachingCache.serialize_object(item) for item in obj]
        elif isinstance(obj, (int, float, bool)):
            return obj  # Keep primitive types as-is
        return str(obj)
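    # Illustrative behavior of serialize_object (assuming plain dict/list inputs; Pydantic
    # models are first converted via their `.dict()` method):
    #   serialize_object({"b": 2, "a": 1})       -> '{"a":1,"b":2}'   (sorted, compact JSON)
    #   serialize_object([{"role": "user"}, 3])  -> ['{"role":"user"}', 3]
    #   serialize_object(object())               -> "<object object at 0x...>"  (str() fallback)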
    @staticmethod
    def get_prompt_caching_cache_key(
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[str]:
        if messages is None and tools is None:
            return None

        # Use serialize_object for consistent and stable serialization
        data_to_hash = {}
        if messages is not None:
            serialized_messages = PromptCachingCache.serialize_object(messages)
            data_to_hash["messages"] = serialized_messages
        if tools is not None:
            serialized_tools = PromptCachingCache.serialize_object(tools)
            data_to_hash["tools"] = serialized_tools

        # Combine serialized data into a single string
        data_to_hash_str = json.dumps(
            data_to_hash,
            sort_keys=True,
            separators=(",", ":"),
        )

        # Create a hash of the serialized data for a stable cache key
        hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest()
        return f"deployment:{hashed_data}:prompt_caching"
    def add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        self.cache.set_cache(
            cache_key, PromptCachingCacheValue(model_id=model_id), ttl=300
        )
        return None
    async def async_add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        await self.cache.async_set_cache(
            cache_key,
            PromptCachingCacheValue(model_id=model_id),
            ttl=300,  # store for 5 minutes
        )
        return None
    async def async_get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        """
        If messages is not None:
        - check the full messages
        - check messages[:-1]
        - check messages[:-2]
        - check messages[:-3]

        Uses self.cache.async_batch_get_cache(keys=potential_cache_keys).
        """
        if messages is None and tools is None:
            return None

        # Generate potential cache keys by slicing messages
        potential_cache_keys = []

        if messages is not None:
            full_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
                messages, tools
            )
            potential_cache_keys.append(full_cache_key)

            # Check progressively shorter message slices
            for i in range(1, min(4, len(messages))):
                partial_messages = messages[:-i]
                partial_cache_key = PromptCachingCache.get_prompt_caching_cache_key(
                    partial_messages, tools
                )
                potential_cache_keys.append(partial_cache_key)

        # Perform batch cache lookup
        cache_results = await self.cache.async_batch_get_cache(
            keys=potential_cache_keys
        )

        if cache_results is None:
            return None

        # Return the first non-None cache result
        for result in cache_results:
            if result is not None:
                return result

        return None
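    # Probe order for a 3-message conversation (illustrative): key(messages), then
    # key(messages[:-1]), then key(messages[:-2]) -- i.e. the full conversation first,
    # then progressively shorter prefixes that may have been stored on an earlier turn.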
    def get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        return self.cache.get_cache(cache_key)
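
# Usage sketch (illustrative caller code, not part of this module; the deployment id and
# message content below are made up). The router records which deployment served a
# prompt-caching-supported request, then looks it up on a later turn so the request can
# be routed back to that deployment's warm cache:
#
#   from litellm.caching.caching import DualCache
#
#   prompt_caching_cache = PromptCachingCache(cache=DualCache())
#   messages = [{"role": "user", "content": "long, cacheable system context ..."}]
#
#   # after the first successful call:
#   prompt_caching_cache.add_model_id(
#       model_id="my-deployment-id", messages=messages, tools=None
#   )
#
#   # on a follow-up request, reuse the same deployment if a key matches:
#   hit = prompt_caching_cache.get_model_id(messages=messages, tools=None)
#   if hit is not None:
#       deployment_id = hit["model_id"]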