Spaces:
Sleeping
Sleeping
import asyncio | |
import uuid | |
from typing import TYPE_CHECKING, Any, Optional | |
from litellm._logging import verbose_proxy_logger | |
from litellm.caching.redis_cache import RedisCache | |
from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS | |
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj | |
from litellm.types.services import ServiceTypes | |
if TYPE_CHECKING: | |
ProxyLogging = Any | |
else: | |
ProxyLogging = Any | |
class PodLockManager: | |
""" | |
Manager for acquiring and releasing locks for cron jobs using Redis. | |
Ensures that only one pod can run a cron job at a time. | |
""" | |
def __init__(self, redis_cache: Optional[RedisCache] = None): | |
self.pod_id = str(uuid.uuid4()) | |
self.redis_cache = redis_cache | |
def get_redis_lock_key(cronjob_id: str) -> str: | |
return f"cronjob_lock:{cronjob_id}" | |
async def acquire_lock( | |
self, | |
cronjob_id: str, | |
) -> Optional[bool]: | |
""" | |
Attempt to acquire the lock for a specific cron job using Redis. | |
Uses the SET command with NX and EX options to ensure atomicity. | |
""" | |
if self.redis_cache is None: | |
verbose_proxy_logger.debug("redis_cache is None, skipping acquire_lock") | |
return None | |
try: | |
verbose_proxy_logger.debug( | |
"Pod %s attempting to acquire Redis lock for cronjob_id=%s", | |
self.pod_id, | |
cronjob_id, | |
) | |
# Try to set the lock key with the pod_id as its value, only if it doesn't exist (NX) | |
# and with an expiration (EX) to avoid deadlocks. | |
lock_key = PodLockManager.get_redis_lock_key(cronjob_id) | |
acquired = await self.redis_cache.async_set_cache( | |
lock_key, | |
self.pod_id, | |
nx=True, | |
ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS, | |
) | |
if acquired: | |
verbose_proxy_logger.info( | |
"Pod %s successfully acquired Redis lock for cronjob_id=%s", | |
self.pod_id, | |
cronjob_id, | |
) | |
return True | |
else: | |
# Check if the current pod already holds the lock | |
current_value = await self.redis_cache.async_get_cache(lock_key) | |
if current_value is not None: | |
if isinstance(current_value, bytes): | |
current_value = current_value.decode("utf-8") | |
if current_value == self.pod_id: | |
verbose_proxy_logger.info( | |
"Pod %s already holds the Redis lock for cronjob_id=%s", | |
self.pod_id, | |
cronjob_id, | |
) | |
self._emit_acquired_lock_event(cronjob_id, self.pod_id) | |
return True | |
return False | |
except Exception as e: | |
verbose_proxy_logger.error( | |
f"Error acquiring Redis lock for {cronjob_id}: {e}" | |
) | |
return False | |
async def release_lock( | |
self, | |
cronjob_id: str, | |
): | |
""" | |
Release the lock if the current pod holds it. | |
Uses get and delete commands to ensure that only the owner can release the lock. | |
""" | |
if self.redis_cache is None: | |
verbose_proxy_logger.debug("redis_cache is None, skipping release_lock") | |
return | |
try: | |
cronjob_id = cronjob_id | |
verbose_proxy_logger.debug( | |
"Pod %s attempting to release Redis lock for cronjob_id=%s", | |
self.pod_id, | |
cronjob_id, | |
) | |
lock_key = PodLockManager.get_redis_lock_key(cronjob_id) | |
current_value = await self.redis_cache.async_get_cache(lock_key) | |
if current_value is not None: | |
if isinstance(current_value, bytes): | |
current_value = current_value.decode("utf-8") | |
if current_value == self.pod_id: | |
result = await self.redis_cache.async_delete_cache(lock_key) | |
if result == 1: | |
verbose_proxy_logger.info( | |
"Pod %s successfully released Redis lock for cronjob_id=%s", | |
self.pod_id, | |
cronjob_id, | |
) | |
self._emit_released_lock_event( | |
cronjob_id=cronjob_id, | |
pod_id=self.pod_id, | |
) | |
else: | |
verbose_proxy_logger.debug( | |
"Pod %s failed to release Redis lock for cronjob_id=%s", | |
self.pod_id, | |
cronjob_id, | |
) | |
else: | |
verbose_proxy_logger.debug( | |
"Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s", | |
self.pod_id, | |
cronjob_id, | |
current_value, | |
) | |
else: | |
verbose_proxy_logger.debug( | |
"Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found", | |
self.pod_id, | |
cronjob_id, | |
) | |
except Exception as e: | |
verbose_proxy_logger.error( | |
f"Error releasing Redis lock for {cronjob_id}: {e}" | |
) | |
def _emit_acquired_lock_event(cronjob_id: str, pod_id: str): | |
asyncio.create_task( | |
service_logger_obj.async_service_success_hook( | |
service=ServiceTypes.POD_LOCK_MANAGER, | |
duration=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS, | |
call_type="_emit_acquired_lock_event", | |
event_metadata={ | |
"gauge_labels": f"{cronjob_id}:{pod_id}", | |
"gauge_value": 1, | |
}, | |
) | |
) | |
def _emit_released_lock_event(cronjob_id: str, pod_id: str): | |
asyncio.create_task( | |
service_logger_obj.async_service_success_hook( | |
service=ServiceTypes.POD_LOCK_MANAGER, | |
duration=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS, | |
call_type="_emit_released_lock_event", | |
event_metadata={ | |
"gauge_labels": f"{cronjob_id}:{pod_id}", | |
"gauge_value": 0, | |
}, | |
) | |
) | |