Shyamnath's picture
Push core package and essential files
469eae6
raw
history blame
6.69 kB
import asyncio
import uuid
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.caching.redis_cache import RedisCache
from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
from litellm.types.services import ServiceTypes
# ProxyLogging is aliased to Any for both static type checking and runtime,
# so one unconditional assignment covers both cases.
ProxyLogging = Any
class PodLockManager:
    """
    Manager for acquiring and releasing locks for cron jobs using Redis.

    Ensures that only one pod can run a cron job at a time. Each manager
    instance identifies itself with a random ``pod_id`` (a UUID4 string) that
    is stored as the value of the lock key, so ownership can be checked later.

    When no ``redis_cache`` is configured, ``acquire_lock`` and
    ``release_lock`` are no-ops.
    """

    # Strong references to in-flight fire-and-forget tasks. The event loop only
    # keeps weak references to tasks, so a task nobody references can be
    # garbage-collected before it finishes.
    _background_tasks: set = set()

    def __init__(self, redis_cache: Optional["RedisCache"] = None):
        # Unique identity of this pod/process; written as the lock's value.
        self.pod_id = str(uuid.uuid4())
        self.redis_cache = redis_cache

    @staticmethod
    def get_redis_lock_key(cronjob_id: str) -> str:
        """Return the Redis key used to lock the given cron job."""
        return f"cronjob_lock:{cronjob_id}"

    @staticmethod
    def _decode_lock_value(value: Any) -> Any:
        """Decode a ``bytes`` lock value as UTF-8; return anything else unchanged."""
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return value

    @staticmethod
    def _run_in_background(coro: Any) -> "asyncio.Task":
        """
        Schedule ``coro`` as a task and keep it referenced until it completes.

        Requires a running event loop (``asyncio.create_task`` raises
        ``RuntimeError`` otherwise).
        """
        task = asyncio.create_task(coro)
        PodLockManager._background_tasks.add(task)
        # Drop the reference once the task is done so the set does not grow.
        task.add_done_callback(PodLockManager._background_tasks.discard)
        return task

    async def acquire_lock(
        self,
        cronjob_id: str,
    ) -> Optional[bool]:
        """
        Attempt to acquire the lock for a specific cron job using Redis.

        Uses the SET command with NX and EX options to ensure atomicity.

        Returns:
            ``None`` if no Redis cache is configured, ``True`` if the lock was
            acquired now or is already held by this pod, ``False`` if it is
            held by someone else or an error occurred (errors are logged, not
            raised).
        """
        if self.redis_cache is None:
            verbose_proxy_logger.debug("redis_cache is None, skipping acquire_lock")
            return None
        try:
            verbose_proxy_logger.debug(
                "Pod %s attempting to acquire Redis lock for cronjob_id=%s",
                self.pod_id,
                cronjob_id,
            )
            # Try to set the lock key with the pod_id as its value, only if it
            # doesn't exist (NX) and with an expiration (EX) to avoid deadlocks.
            lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
            acquired = await self.redis_cache.async_set_cache(
                lock_key,
                self.pod_id,
                nx=True,
                ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
            )
            if acquired:
                verbose_proxy_logger.info(
                    "Pod %s successfully acquired Redis lock for cronjob_id=%s",
                    self.pod_id,
                    cronjob_id,
                )
                # NOTE(review): no acquired-lock event is emitted on a fresh
                # acquisition, only on the "already holds" path below. Behavior
                # kept as-is; confirm whether this asymmetry is intended.
                return True

            # The key already exists: check if the current pod already holds it.
            current_value = self._decode_lock_value(
                await self.redis_cache.async_get_cache(lock_key)
            )
            if current_value == self.pod_id:
                verbose_proxy_logger.info(
                    "Pod %s already holds the Redis lock for cronjob_id=%s",
                    self.pod_id,
                    cronjob_id,
                )
                self._emit_acquired_lock_event(cronjob_id, self.pod_id)
                return True
            return False
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error acquiring Redis lock for {cronjob_id}: {e}"
            )
            return False

    async def release_lock(
        self,
        cronjob_id: str,
    ):
        """
        Release the lock if the current pod holds it.

        Reads the lock value and deletes the key only when the value equals
        this pod's id. Errors are logged, not raised.

        NOTE: the get and the delete are two separate Redis calls, so the
        check is not atomic; if the key expires and is re-acquired by another
        pod between the two calls, that pod's lock would be deleted.
        """
        if self.redis_cache is None:
            verbose_proxy_logger.debug("redis_cache is None, skipping release_lock")
            return
        try:
            verbose_proxy_logger.debug(
                "Pod %s attempting to release Redis lock for cronjob_id=%s",
                self.pod_id,
                cronjob_id,
            )
            lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
            current_value = self._decode_lock_value(
                await self.redis_cache.async_get_cache(lock_key)
            )

            if current_value is None:
                verbose_proxy_logger.debug(
                    "Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found",
                    self.pod_id,
                    cronjob_id,
                )
                return

            if current_value != self.pod_id:
                verbose_proxy_logger.debug(
                    "Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s",
                    self.pod_id,
                    cronjob_id,
                    current_value,
                )
                return

            result = await self.redis_cache.async_delete_cache(lock_key)
            # A result of 1 is treated as "one key deleted", i.e. success.
            if result == 1:
                verbose_proxy_logger.info(
                    "Pod %s successfully released Redis lock for cronjob_id=%s",
                    self.pod_id,
                    cronjob_id,
                )
                self._emit_released_lock_event(
                    cronjob_id=cronjob_id,
                    pod_id=self.pod_id,
                )
            else:
                verbose_proxy_logger.debug(
                    "Pod %s failed to release Redis lock for cronjob_id=%s",
                    self.pod_id,
                    cronjob_id,
                )
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error releasing Redis lock for {cronjob_id}: {e}"
            )

    @staticmethod
    def _emit_lock_event(
        call_type: str, cronjob_id: str, pod_id: str, gauge_value: int
    ) -> None:
        """Fire-and-forget a service success hook carrying the lock gauge value."""
        PodLockManager._run_in_background(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.POD_LOCK_MANAGER,
                duration=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
                call_type=call_type,
                event_metadata={
                    "gauge_labels": f"{cronjob_id}:{pod_id}",
                    "gauge_value": gauge_value,
                },
            )
        )

    @staticmethod
    def _emit_acquired_lock_event(cronjob_id: str, pod_id: str):
        """Emit a gauge value of 1 for ``cronjob_id:pod_id`` (lock held)."""
        PodLockManager._emit_lock_event(
            "_emit_acquired_lock_event", cronjob_id, pod_id, 1
        )

    @staticmethod
    def _emit_released_lock_event(cronjob_id: str, pod_id: str):
        """Emit a gauge value of 0 for ``cronjob_id:pod_id`` (lock released)."""
        PodLockManager._emit_lock_event(
            "_emit_released_lock_event", cronjob_id, pod_id, 0
        )