""" | |
Wrapper around router cache. Meant to handle model cooldown logic | |
""" | |

import time
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, TypedDict, Union

from litellm import verbose_logger
from litellm.caching.caching import DualCache
from litellm.caching.in_memory_cache import InMemoryCache

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = Union[_Span, Any]
else:
    Span = Any


class CooldownCacheValue(TypedDict):
    exception_received: str
    status_code: str
    timestamp: float
    cooldown_time: float
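
# Illustrative example of a stored value (the concrete field values below are
# hypothetical, not taken from a real deployment):
#
#   {
#       "exception_received": "litellm.RateLimitError: rate limit exceeded",
#       "status_code": "429",
#       "timestamp": 1700000000.0,
#       "cooldown_time": 60.0,
#   }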


class CooldownCache:
    def __init__(self, cache: DualCache, default_cooldown_time: float):
        self.cache = cache
        self.default_cooldown_time = default_cooldown_time
        self.in_memory_cache = InMemoryCache()

    def _common_add_cooldown_logic(
        self, model_id: str, original_exception, exception_status, cooldown_time: float
    ) -> Tuple[str, CooldownCacheValue]:
        try:
            current_time = time.time()
            cooldown_key = CooldownCache.get_cooldown_cache_key(model_id)

            # Store the cooldown information for the deployment separately
            cooldown_data = CooldownCacheValue(
                exception_received=str(original_exception),
                status_code=str(exception_status),
                timestamp=current_time,
                cooldown_time=cooldown_time,
            )

            return cooldown_key, cooldown_data
        except Exception as e:
            verbose_logger.error(
                "CooldownCache::_common_add_cooldown_logic - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    def add_deployment_to_cooldown(
        self,
        model_id: str,
        original_exception: Exception,
        exception_status: int,
        cooldown_time: Optional[float],
    ):
        try:
            _cooldown_time = cooldown_time or self.default_cooldown_time
            cooldown_key, cooldown_data = self._common_add_cooldown_logic(
                model_id=model_id,
                original_exception=original_exception,
                exception_status=exception_status,
                cooldown_time=_cooldown_time,
            )

            # Set the cache with a TTL equal to the cooldown time
            self.cache.set_cache(
                value=cooldown_data,
                key=cooldown_key,
                ttl=_cooldown_time,
            )
        except Exception as e:
            verbose_logger.error(
                "CooldownCache::add_deployment_to_cooldown - Exception occurred - {}".format(
                    str(e)
                )
            )
            raise e

    @staticmethod
    def get_cooldown_cache_key(model_id: str) -> str:
        return f"deployment:{model_id}:cooldown"

    async def async_get_active_cooldowns(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> List[Tuple[str, CooldownCacheValue]]:
        # Generate the keys for the deployments
        keys = [
            CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids
        ]

        # Retrieve the values for the keys using a batch lookup.
        # Results are usually None when no models are rate limited, and each
        # Redis call adds ~100ms of latency, so DualCache checks its in-memory
        # layer before falling back to Redis.
        results = await self.cache.async_batch_get_cache(
            keys=keys, parent_otel_span=parent_otel_span
        )

        active_cooldowns: List[Tuple[str, CooldownCacheValue]] = []

        if results is None:
            return active_cooldowns

        # Process the results
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
                active_cooldowns.append((model_id, cooldown_cache_value))

        return active_cooldowns

    def get_active_cooldowns(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> List[Tuple[str, CooldownCacheValue]]:
        # Generate the keys for the deployments
        keys = [
            CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids
        ]

        # Retrieve the values for the keys using a batch lookup
        results = (
            self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
            or []
        )

        active_cooldowns: List[Tuple[str, CooldownCacheValue]] = []

        # Process the results
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
                active_cooldowns.append((model_id, cooldown_cache_value))

        return active_cooldowns

    def get_min_cooldown(
        self, model_ids: List[str], parent_otel_span: Optional[Span]
    ) -> float:
        """Return the minimum cooldown time required for a group of model ids."""
        # Generate the keys for the deployments
        keys = [
            CooldownCache.get_cooldown_cache_key(model_id) for model_id in model_ids
        ]

        # Retrieve the values for the keys using a batch lookup
        results = (
            self.cache.batch_get_cache(keys=keys, parent_otel_span=parent_otel_span)
            or []
        )

        min_cooldown_time: Optional[float] = None

        # Process the results, tracking the smallest cooldown seen
        for model_id, result in zip(model_ids, results):
            if result and isinstance(result, dict):
                cooldown_cache_value = CooldownCacheValue(**result)  # type: ignore
                if (
                    min_cooldown_time is None
                    or cooldown_cache_value["cooldown_time"] < min_cooldown_time
                ):
                    min_cooldown_time = cooldown_cache_value["cooldown_time"]

        return min_cooldown_time or self.default_cooldown_time


# Usage example:
# cooldown_cache = CooldownCache(cache=your_dual_cache, default_cooldown_time=60.0)
# cooldown_cache.add_deployment_to_cooldown(
#     model_id=model_id,
#     original_exception=original_exception,
#     exception_status=exception_status,
#     cooldown_time=None,
# )
# active_cooldowns = cooldown_cache.get_active_cooldowns(
#     model_ids=[model_id], parent_otel_span=None
# )
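
# A minimal end-to-end sketch (assumes an in-memory-only DualCache; the
# deployment name and exception below are hypothetical, for illustration only):
#
# import asyncio
# from litellm.caching.caching import DualCache
#
# cooldown_cache = CooldownCache(cache=DualCache(), default_cooldown_time=60.0)
# cooldown_cache.add_deployment_to_cooldown(
#     model_id="gpt-4-deployment-1",
#     original_exception=Exception("rate limit exceeded"),
#     exception_status=429,
#     cooldown_time=None,  # falls back to default_cooldown_time
# )
#
# active = asyncio.run(
#     cooldown_cache.async_get_active_cooldowns(
#         model_ids=["gpt-4-deployment-1"], parent_otel_span=None
#     )
# )
# print(active)  # -> [("gpt-4-deployment-1", {...})] while the cooldown TTL is live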