# Spaces: Sleeping
# File size: 6,687 Bytes
# 469eae6
import asyncio
import uuid
from typing import TYPE_CHECKING, Any, Optional, Set

from litellm._logging import verbose_proxy_logger
from litellm.caching.redis_cache import RedisCache
from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
from litellm.types.services import ServiceTypes
# Type alias placeholder for the proxy logging object.
# NOTE(review): both branches currently bind the same value (``Any``); the
# conditional is kept so a real type can be imported under TYPE_CHECKING later.
if TYPE_CHECKING:
    ProxyLogging = Any
else:
    ProxyLogging = Any
class PodLockManager:
    """
    Manager for acquiring and releasing locks for cron jobs using Redis.

    Ensures that only one pod can run a cron job at a time. Every manager
    instance generates a random UUID (``pod_id``) that is written as the value
    of the lock key, which is how ownership of a lock is recognised later.
    """

    # Strong references to in-flight metric tasks. The event loop keeps only
    # weak references to tasks, so a fire-and-forget task could otherwise be
    # garbage-collected before it completes. Shared by all instances because
    # the emit helpers are static.
    _background_tasks: Set["asyncio.Task[Any]"] = set()

    def __init__(self, redis_cache: Optional[RedisCache] = None):
        """
        Args:
            redis_cache: Cache object used to store the lock keys. When it is
                ``None``, ``acquire_lock`` and ``release_lock`` do nothing.
        """
        self.pod_id = str(uuid.uuid4())
        self.redis_cache = redis_cache

    @staticmethod
    def get_redis_lock_key(cronjob_id: str) -> str:
        """Return the Redis key under which the lock for *cronjob_id* is stored."""
        return f"cronjob_lock:{cronjob_id}"

    @staticmethod
    def _decode_lock_value(value: Any) -> Any:
        """Return *value* decoded as UTF-8 if it is ``bytes``, otherwise unchanged."""
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return value

    async def acquire_lock(
        self,
        cronjob_id: str,
    ) -> Optional[bool]:
        """
        Attempt to acquire the lock for a specific cron job using Redis.

        The lock key is written with ``nx=True`` (only if it does not exist)
        and a TTL of ``DEFAULT_CRON_JOB_LOCK_TTL_SECONDS`` so that a crashed
        pod cannot hold the lock forever.

        Args:
            cronjob_id: Identifier of the cron job to lock.

        Returns:
            ``None`` when no Redis cache is configured; ``True`` when this pod
            holds the lock (newly acquired or already held); ``False`` when
            the lock is held by someone else or an error occurred.
        """
        if self.redis_cache is None:
            verbose_proxy_logger.debug("redis_cache is None, skipping acquire_lock")
            return None
        try:
            verbose_proxy_logger.debug(
                "Pod %s attempting to acquire Redis lock for cronjob_id=%s",
                self.pod_id,
                cronjob_id,
            )
            # Try to set the lock key with the pod_id as its value, only if it
            # doesn't exist (NX) and with an expiration to avoid deadlocks.
            lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
            acquired = await self.redis_cache.async_set_cache(
                lock_key,
                self.pod_id,
                nx=True,
                ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
            )
            if acquired:
                verbose_proxy_logger.info(
                    "Pod %s successfully acquired Redis lock for cronjob_id=%s",
                    self.pod_id,
                    cronjob_id,
                )
                # Report the fresh acquisition too, so the gauge is raised on
                # the normal path and not only when the lock was already held.
                self._emit_acquired_lock_event(cronjob_id, self.pod_id)
                return True

            # The key already exists: check whether this pod is the holder.
            current_value = self._decode_lock_value(
                await self.redis_cache.async_get_cache(lock_key)
            )
            if current_value is not None and current_value == self.pod_id:
                verbose_proxy_logger.info(
                    "Pod %s already holds the Redis lock for cronjob_id=%s",
                    self.pod_id,
                    cronjob_id,
                )
                self._emit_acquired_lock_event(cronjob_id, self.pod_id)
                return True
            return False
        except Exception as e:
            # Best-effort: any failure is reported as "lock not acquired".
            verbose_proxy_logger.error(
                f"Error acquiring Redis lock for {cronjob_id}: {e}"
            )
            return False

    async def release_lock(
        self,
        cronjob_id: str,
    ) -> None:
        """
        Release the lock if the current pod holds it.

        Reads the lock value and deletes the key only when the value equals
        this pod's id, so a pod does not release a lock it sees as owned by
        another pod. Errors are logged and never raised.

        NOTE(review): the ownership check and the delete are two separate
        cache calls, so the pair is not atomic; if the key expired and was
        re-acquired by another pod in between, that pod's lock could be
        deleted. Making this atomic needs a server-side compare-and-delete.

        Args:
            cronjob_id: Identifier of the cron job whose lock is released.
        """
        if self.redis_cache is None:
            verbose_proxy_logger.debug("redis_cache is None, skipping release_lock")
            return
        try:
            verbose_proxy_logger.debug(
                "Pod %s attempting to release Redis lock for cronjob_id=%s",
                self.pod_id,
                cronjob_id,
            )
            lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
            current_value = self._decode_lock_value(
                await self.redis_cache.async_get_cache(lock_key)
            )

            if current_value is None:
                verbose_proxy_logger.debug(
                    "Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found",
                    self.pod_id,
                    cronjob_id,
                )
                return

            if current_value != self.pod_id:
                verbose_proxy_logger.debug(
                    "Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s",
                    self.pod_id,
                    cronjob_id,
                    current_value,
                )
                return

            result = await self.redis_cache.async_delete_cache(lock_key)
            if result == 1:
                verbose_proxy_logger.info(
                    "Pod %s successfully released Redis lock for cronjob_id=%s",
                    self.pod_id,
                    cronjob_id,
                )
                self._emit_released_lock_event(
                    cronjob_id=cronjob_id,
                    pod_id=self.pod_id,
                )
            else:
                verbose_proxy_logger.debug(
                    "Pod %s failed to release Redis lock for cronjob_id=%s",
                    self.pod_id,
                    cronjob_id,
                )
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error releasing Redis lock for {cronjob_id}: {e}"
            )

    @staticmethod
    def _emit_lock_event(
        cronjob_id: str, pod_id: str, call_type: str, gauge_value: int
    ) -> None:
        """
        Schedule a service-logger success event carrying a lock gauge value.

        Must be called while an event loop is running. The created task is
        kept in ``_background_tasks`` until it finishes so it cannot be
        garbage-collected mid-flight.
        """
        task = asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.POD_LOCK_MANAGER,
                duration=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
                call_type=call_type,
                event_metadata={
                    "gauge_labels": f"{cronjob_id}:{pod_id}",
                    "gauge_value": gauge_value,
                },
            )
        )
        PodLockManager._background_tasks.add(task)
        task.add_done_callback(PodLockManager._background_tasks.discard)

    @staticmethod
    def _emit_acquired_lock_event(cronjob_id: str, pod_id: str):
        """Emit a gauge value of 1 for ``cronjob_id:pod_id`` (lock held)."""
        PodLockManager._emit_lock_event(
            cronjob_id, pod_id, "_emit_acquired_lock_event", 1
        )

    @staticmethod
    def _emit_released_lock_event(cronjob_id: str, pod_id: str):
        """Emit a gauge value of 0 for ``cronjob_id:pod_id`` (lock released)."""
        PodLockManager._emit_lock_event(
            cronjob_id, pod_id, "_emit_released_lock_event", 0
        )
|