Shyamnath's picture
Push core package and essential files
469eae6
raw
history blame
6.17 kB
"""
Wrapper around router cache. Meant to handle model cooldown logic
"""
import time
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict, Union
from litellm import verbose_logger
from litellm.caching.caching import DualCache
from litellm.caching.in_memory_cache import InMemoryCache
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = Union[_Span, Any]
else:
Span = Any
class CooldownCacheValue(TypedDict):
exception_received: str
status_code: str
timestamp: float
cooldown_time: float
class CooldownCache:
def __init__(self, cache: DualCache, default_cooldown_time: float):
self.cache = cache
self.default_cooldown_time = default_cooldown_time
self.in_memory_cache = InMemoryCache()
def _common_add_cooldown_logic(
self, model_id: str, original_exception, exception_status, cooldown_time: float
) -> Tuple[str, CooldownCacheValue]:
try:
current_time = time.time()
cooldown_key = f"deployment:{model_id}:cooldown"
# Store the cooldown information for the deployment separately
cooldown_data = CooldownCacheValue(
exception_received=str(original_exception),
status_code=str(exception_status),
timestamp=current_time,
cooldown_time=cooldown_time,
)
return cooldown_key, cooldown_data
except Exception as e:
verbose_logger.error(
"CooldownCache::_common_add_cooldown_logic - Exception occurred - {}".format(
str(e)
)
)
raise e
def add_deployment_to_cooldown(
self,
model_id: str,
original_exception: Exception,
exception_status: int,
cooldown_time: Optional[float],
):
try:
_cooldown_time = cooldown_time or self.default_cooldown_time
cooldown_key, cooldown_data = self._common_add_cooldown_logic(
model_id=model_id,
original_exception=original_exception,
exception_status=exception_status,
cooldown_time=_cooldown_time,
)
# Set the cache with a TTL equal to the cooldown time
self.cache.set_cache(
value=cooldown_data,
key=cooldown_key,
ttl=_cooldown_time,
)
except Exception as e:
verbose_logger.error(
"CooldownCache::add_deployment_to_cooldown - Exception occurred - {}".format(
str(e)
)
)
raise e
@staticmethod
def get_cooldown_cache_key(model_id: str) -> str:
return f"deployment:{model_id}:cooldown"
async def async_get_active_cooldowns(
self, model_ids: List[str], parent_otel_span: Optional[Span]
) -> List[Tuple[str, CooldownCacheValue]]:
# Generate the keys for the deployments
keys = [
CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids
]
# Retrieve the values for the keys using mget
## more likely to be none if no models ratelimited. So just check redis every 1s
## each redis call adds ~100ms latency.
## check in memory cache first
results = await self.cache.async_batch_get_cache(
keys=keys, parent_otel_span=parent_otel_span
)
active_cooldowns: List[Tuple[str, CooldownCacheValue]] = []
if results is None:
return active_cooldowns
# Process the results
for model_id, result in zip(model_ids, results):
if result and isinstance(result, dict):
cooldown_cache_value = CooldownCacheValue(**result) # type: ignore
active_cooldowns.append((model_id, cooldown_cache_value))
return active_cooldowns
def get_active_cooldowns(
self, model_ids: List[str], parent_otel_span: Optional[Span]
) -> List[Tuple[str, CooldownCacheValue]]:
# Generate the keys for the deployments
keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
# Retrieve the values for the keys using mget
results = (
self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
or []
)
active_cooldowns = []
# Process the results
for model_id, result in zip(model_ids, results):
if result and isinstance(result, dict):
cooldown_cache_value = CooldownCacheValue(**result) # type: ignore
active_cooldowns.append((model_id, cooldown_cache_value))
return active_cooldowns
def get_min_cooldown(
self, model_ids: List[str], parent_otel_span: Optional[Span]
) -> float:
"""Return min cooldown time required for a group of model id's."""
# Generate the keys for the deployments
keys = [f"deployment:{model_id}:cooldown" for model_id in model_ids]
# Retrieve the values for the keys using mget
results = (
self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
or []
)
min_cooldown_time: Optional[float] = None
# Process the results
for model_id, result in zip(model_ids, results):
if result and isinstance(result, dict):
cooldown_cache_value = CooldownCacheValue(**result) # type: ignore
if min_cooldown_time is None:
min_cooldown_time = cooldown_cache_value["cooldown_time"]
elif cooldown_cache_value["cooldown_time"] < min_cooldown_time:
min_cooldown_time = cooldown_cache_value["cooldown_time"]
return min_cooldown_time or self.default_cooldown_time
# Usage example:
# cooldown_cache = CooldownCache(cache=your_cache_instance, cooldown_time=your_cooldown_time)
# cooldown_cache.add_deployment_to_cooldown(deployment, original_exception, exception_status)
# active_cooldowns = cooldown_cache.get_active_cooldowns()