Spaces:
Sleeping
Sleeping
""" | |
In-Memory Cache implementation | |
Has 4 methods: | |
- set_cache | |
- get_cache | |
- async_set_cache | |
- async_get_cache | |
""" | |
import json | |
import sys | |
import time | |
from typing import Any, List, Optional | |
from pydantic import BaseModel | |
from litellm.constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB | |
from .base_cache import BaseCache | |
class InMemoryCache(BaseCache): | |
def __init__( | |
self, | |
max_size_in_memory: Optional[int] = 200, | |
default_ttl: Optional[ | |
int | |
] = 600, # default ttl is 10 minutes. At maximum litellm rate limiting logic requires objects to be in memory for 1 minute | |
max_size_per_item: Optional[int] = 1024, # 1MB = 1024KB | |
): | |
""" | |
max_size_in_memory [int]: Maximum number of items in cache. done to prevent memory leaks. Use 200 items as a default | |
""" | |
self.max_size_in_memory = ( | |
max_size_in_memory or 200 | |
) # set an upper bound of 200 items in-memory | |
self.default_ttl = default_ttl or 600 | |
self.max_size_per_item = ( | |
max_size_per_item or MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB | |
) # 1MB = 1024KB | |
# in-memory cache | |
self.cache_dict: dict = {} | |
self.ttl_dict: dict = {} | |
def check_value_size(self, value: Any): | |
""" | |
Check if value size exceeds max_size_per_item (1MB) | |
Returns True if value size is acceptable, False otherwise | |
""" | |
try: | |
# Fast path for common primitive types that are typically small | |
if ( | |
isinstance(value, (bool, int, float, str)) | |
and len(str(value)) | |
< self.max_size_per_item * MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB | |
): # Conservative estimate | |
return True | |
# Direct size check for bytes objects | |
if isinstance(value, bytes): | |
return sys.getsizeof(value) / 1024 <= self.max_size_per_item | |
# Handle special types without full conversion when possible | |
if hasattr(value, "__sizeof__"): # Use __sizeof__ if available | |
size = value.__sizeof__() / 1024 | |
return size <= self.max_size_per_item | |
# Fallback for complex types | |
if isinstance(value, BaseModel) and hasattr( | |
value, "model_dump" | |
): # Pydantic v2 | |
value = value.model_dump() | |
elif hasattr(value, "isoformat"): # datetime objects | |
return True # datetime strings are always small | |
# Only convert to JSON if absolutely necessary | |
if not isinstance(value, (str, bytes)): | |
value = json.dumps(value, default=str) | |
return sys.getsizeof(value) / 1024 <= self.max_size_per_item | |
except Exception: | |
return False | |
def evict_cache(self): | |
""" | |
Eviction policy: | |
- check if any items in ttl_dict are expired -> remove them from ttl_dict and cache_dict | |
This guarantees the following: | |
- 1. When item ttl not set: At minimumm each item will remain in memory for 5 minutes | |
- 2. When ttl is set: the item will remain in memory for at least that amount of time | |
- 3. the size of in-memory cache is bounded | |
""" | |
for key in list(self.ttl_dict.keys()): | |
if time.time() > self.ttl_dict[key]: | |
self.cache_dict.pop(key, None) | |
self.ttl_dict.pop(key, None) | |
# de-reference the removed item | |
# https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/ | |
# One of the most common causes of memory leaks in Python is the retention of objects that are no longer being used. | |
# This can occur when an object is referenced by another object, but the reference is never removed. | |
def set_cache(self, key, value, **kwargs): | |
if len(self.cache_dict) >= self.max_size_in_memory: | |
# only evict when cache is full | |
self.evict_cache() | |
if not self.check_value_size(value): | |
return | |
self.cache_dict[key] = value | |
if "ttl" in kwargs and kwargs["ttl"] is not None: | |
self.ttl_dict[key] = time.time() + kwargs["ttl"] | |
else: | |
self.ttl_dict[key] = time.time() + self.default_ttl | |
async def async_set_cache(self, key, value, **kwargs): | |
self.set_cache(key=key, value=value, **kwargs) | |
async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs): | |
for cache_key, cache_value in cache_list: | |
if ttl is not None: | |
self.set_cache(key=cache_key, value=cache_value, ttl=ttl) | |
else: | |
self.set_cache(key=cache_key, value=cache_value) | |
async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]): | |
""" | |
Add value to set | |
""" | |
# get the value | |
init_value = self.get_cache(key=key) or set() | |
for val in value: | |
init_value.add(val) | |
self.set_cache(key, init_value, ttl=ttl) | |
return value | |
def get_cache(self, key, **kwargs): | |
if key in self.cache_dict: | |
if key in self.ttl_dict: | |
if time.time() > self.ttl_dict[key]: | |
self.cache_dict.pop(key, None) | |
return None | |
original_cached_response = self.cache_dict[key] | |
try: | |
cached_response = json.loads(original_cached_response) | |
except Exception: | |
cached_response = original_cached_response | |
return cached_response | |
return None | |
def batch_get_cache(self, keys: list, **kwargs): | |
return_val = [] | |
for k in keys: | |
val = self.get_cache(key=k, **kwargs) | |
return_val.append(val) | |
return return_val | |
def increment_cache(self, key, value: int, **kwargs) -> int: | |
# get the value | |
init_value = self.get_cache(key=key) or 0 | |
value = init_value + value | |
self.set_cache(key, value, **kwargs) | |
return value | |
async def async_get_cache(self, key, **kwargs): | |
return self.get_cache(key=key, **kwargs) | |
async def async_batch_get_cache(self, keys: list, **kwargs): | |
return_val = [] | |
for k in keys: | |
val = self.get_cache(key=k, **kwargs) | |
return_val.append(val) | |
return return_val | |
async def async_increment(self, key, value: float, **kwargs) -> float: | |
# get the value | |
init_value = await self.async_get_cache(key=key) or 0 | |
value = init_value + value | |
await self.async_set_cache(key, value, **kwargs) | |
return value | |
def flush_cache(self): | |
self.cache_dict.clear() | |
self.ttl_dict.clear() | |
async def disconnect(self): | |
pass | |
def delete_cache(self, key): | |
self.cache_dict.pop(key, None) | |
self.ttl_dict.pop(key, None) | |
async def async_get_ttl(self, key: str) -> Optional[int]: | |
""" | |
Get the remaining TTL of a key in in-memory cache | |
""" | |
return self.ttl_dict.get(key, None) | |