Fucius's picture
Upload 422 files
2eafbc4 verified
import threading
import time
from threading import Lock
from typing import Any, Optional
from inference.core.cache.base import BaseCache
from inference.core.env import MEMORY_CACHE_EXPIRE_INTERVAL
class MemoryCache(BaseCache):
"""
MemoryCache is an in-memory cache that implements the BaseCache interface.
Attributes:
cache (dict): A dictionary to store the cache values.
expires (dict): A dictionary to store the expiration times of the cache values.
zexpires (dict): A dictionary to store the expiration times of the sorted set values.
_expire_thread (threading.Thread): A thread that runs the _expire method.
"""
def __init__(self) -> None:
"""
Initializes a new instance of the MemoryCache class.
"""
self.cache = dict()
self.expires = dict()
self.zexpires = dict()
self._expire_thread = threading.Thread(target=self._expire)
self._expire_thread.daemon = True
self._expire_thread.start()
def _expire(self):
"""
Removes the expired keys from the cache and zexpires dictionaries.
This method runs in an infinite loop and sleeps for MEMORY_CACHE_EXPIRE_INTERVAL seconds between each iteration.
"""
while True:
now = time.time()
keys_to_delete = []
for k, v in self.expires.copy().items():
if v < now:
keys_to_delete.append(k)
for k in keys_to_delete:
del self.cache[k]
del self.expires[k]
keys_to_delete = []
for k, v in self.zexpires.copy().items():
if v < now:
keys_to_delete.append(k)
for k in keys_to_delete:
del self.cache[k[0]][k[1]]
del self.zexpires[k]
while time.time() - now < MEMORY_CACHE_EXPIRE_INTERVAL:
time.sleep(0.1)
def get(self, key: str):
"""
Gets the value associated with the given key.
Args:
key (str): The key to retrieve the value.
Returns:
str: The value associated with the key, or None if the key does not exist or is expired.
"""
if key in self.expires:
if self.expires[key] < time.time():
del self.cache[key]
del self.expires[key]
return None
return self.cache.get(key)
def set(self, key: str, value: str, expire: float = None):
"""
Sets a value for a given key with an optional expire time.
Args:
key (str): The key to store the value.
value (str): The value to store.
expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.
"""
self.cache[key] = value
if expire:
self.expires[key] = expire + time.time()
def zadd(self, key: str, value: Any, score: float, expire: float = None):
"""
Adds a member with the specified score to the sorted set stored at key.
Args:
key (str): The key of the sorted set.
value (str): The value to add to the sorted set.
score (float): The score associated with the value.
expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.
"""
if not key in self.cache:
self.cache[key] = dict()
self.cache[key][score] = value
if expire:
self.zexpires[(key, score)] = expire + time.time()
def zrangebyscore(
self,
key: str,
min: Optional[float] = -1,
max: Optional[float] = float("inf"),
withscores: bool = False,
):
"""
Retrieves a range of members from a sorted set.
Args:
key (str): The key of the sorted set.
start (int, optional): The starting score of the range. Defaults to -1.
stop (int, optional): The ending score of the range. Defaults to float("inf").
withscores (bool, optional): Whether to return the scores along with the values. Defaults to False.
Returns:
list: A list of values (or value-score pairs if withscores is True) in the specified score range.
"""
if not key in self.cache:
return []
keys = sorted([k for k in self.cache[key].keys() if min <= k <= max])
if withscores:
return [(self.cache[key][k], k) for k in keys]
else:
return [self.cache[key][k] for k in keys]
def zremrangebyscore(
self,
key: str,
min: Optional[float] = -1,
max: Optional[float] = float("inf"),
):
"""
Removes all members in a sorted set within the given scores.
Args:
key (str): The key of the sorted set.
start (int, optional): The minimum score of the range. Defaults to -1.
stop (int, optional): The maximum score of the range. Defaults to float("inf").
Returns:
int: The number of members removed from the sorted set.
"""
res = self.zrangebyscore(key, min=min, max=max, withscores=True)
keys_to_delete = [k[1] for k in res]
for k in keys_to_delete:
del self.cache[key][k]
return len(keys_to_delete)
def acquire_lock(self, key: str, expire=None) -> Any:
lock: Optional[Lock] = self.get(key)
if lock is None:
lock = Lock()
self.set(key, lock, expire=expire)
if expire is None:
expire = -1
acquired = lock.acquire(timeout=expire)
if not acquired:
raise TimeoutError()
# refresh the lock
self.set(key, lock, expire=expire)
return lock
def set_numpy(self, key: str, value: Any, expire: float = None):
return self.set(key, value, expire=expire)
def get_numpy(self, key: str):
return self.get(key)