import threading
import time
from threading import Lock
from typing import Any, Dict, Optional, Tuple

from inference.core.cache.base import BaseCache
from inference.core.env import MEMORY_CACHE_EXPIRE_INTERVAL


class MemoryCache(BaseCache):
    """
    MemoryCache is an in-memory cache that implements the BaseCache interface.

    Plain keys are stored directly in ``cache``. Sorted sets (``zadd`` and
    friends) are stored in ``cache`` as a dict mapping score -> value, so two
    members added with the same score replace each other.

    Expired entries are removed lazily by ``get`` and periodically by a daemon
    thread that wakes every MEMORY_CACHE_EXPIRE_INTERVAL seconds. That thread
    mutates the same dicts as callers without a lock, so the methods below
    tolerate entries disappearing between a check and a delete.

    Attributes:
        cache (dict): A dictionary to store the cache values.
        expires (dict): Absolute expiry timestamps (``time.time()`` based) of plain keys.
        zexpires (dict): Absolute expiry timestamps of sorted-set members, keyed by ``(key, score)``.
        _expire_thread (threading.Thread): A daemon thread that runs the _expire method.
    """

    def __init__(self) -> None:
        """
        Initializes a new instance of the MemoryCache class and starts the
        background expiry thread.
        """
        self.cache: Dict[str, Any] = dict()
        self.expires: Dict[str, float] = dict()
        self.zexpires: Dict[Tuple[str, float], float] = dict()
        self._expire_thread = threading.Thread(target=self._expire)
        self._expire_thread.daemon = True
        self._expire_thread.start()

    def _expire(self) -> None:
        """
        Removes expired keys and sorted-set members from the cache.

        Runs forever in the background thread, sleeping for
        MEMORY_CACHE_EXPIRE_INTERVAL seconds between sweeps. Every deletion
        uses ``pop(..., None)``. ``get``, ``set`` or ``zremrangebyscore`` may
        already have removed the entry, and a KeyError here would kill the
        thread and silently stop all future expiry.
        """
        while True:
            now = time.time()
            self._evict_expired_keys(now)
            self._evict_expired_members(now)
            while time.time() - now < MEMORY_CACHE_EXPIRE_INTERVAL:
                time.sleep(0.1)

    def _evict_expired_keys(self, now: float) -> None:
        """Delete plain keys whose expiry timestamp is earlier than ``now``."""
        for key in self.expires.copy():
            # Re-read the deadline: since the snapshot the key may have been
            # re-set with a new TTL (or with none at all).
            expires_at = self.expires.get(key)
            if expires_at is not None and expires_at < now:
                self.cache.pop(key, None)
                self.expires.pop(key, None)

    def _evict_expired_members(self, now: float) -> None:
        """Delete sorted-set members whose expiry timestamp is earlier than ``now``."""
        for member_key in self.zexpires.copy():
            expires_at = self.zexpires.get(member_key)
            if expires_at is None or expires_at >= now:
                continue
            key, score = member_key
            members = self.cache.get(key)
            # The sorted set itself may already be gone (or replaced via set()).
            if isinstance(members, dict):
                members.pop(score, None)
            self.zexpires.pop(member_key, None)

    def get(self, key: str):
        """
        Gets the value associated with the given key.

        Args:
            key (str): The key to retrieve the value.

        Returns:
            Any: The value associated with the key, or None if the key does not exist or is expired.
        """
        expires_at = self.expires.get(key)
        if expires_at is not None and expires_at < time.time():
            # The background thread may be evicting this key concurrently,
            # so tolerate it already being gone.
            self.cache.pop(key, None)
            self.expires.pop(key, None)
            return None
        return self.cache.get(key)

    def set(self, key: str, value: Any, expire: Optional[float] = None):
        """
        Sets a value for a given key with an optional expire time.

        Storing a value without ``expire`` clears any TTL left over from a
        previous ``set`` of the same key. This mirrors Redis ``SET``, which
        discards an existing TTL.

        Args:
            key (str): The key to store the value.
            value (Any): The value to store.
            expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None.
        """
        # Update the deadline before the value. Otherwise the background thread
        # could evict the new value using the previous entry's deadline.
        if expire:
            self.expires[key] = expire + time.time()
        else:
            self.expires.pop(key, None)
        self.cache[key] = value

    def zadd(self, key: str, value: Any, score: float, expire: Optional[float] = None):
        """
        Adds a member with the specified score to the sorted set stored at key.

        Members are keyed by ``score``. Adding a value at a score that is already
        present replaces the previous value and its TTL.

        Args:
            key (str): The key of the sorted set.
            value (Any): The value to add to the sorted set.
            score (float): The score associated with the value.
            expire (float, optional): The time, in seconds, after which the member will expire. Defaults to None.
        """
        member_key = (key, score)
        if expire:
            self.zexpires[member_key] = expire + time.time()
        else:
            # Don't let a previous member's deadline evict this new value.
            self.zexpires.pop(member_key, None)
        self.cache.setdefault(key, dict())[score] = value

    def zrangebyscore(
        self,
        key: str,
        min: Optional[float] = -1,
        max: Optional[float] = float("inf"),
        withscores: bool = False,
    ):
        """
        Retrieves a range of members from a sorted set, ordered by ascending score.

        Args:
            key (str): The key of the sorted set.
            min (float, optional): The minimum score of the range (inclusive). Defaults to -1.
            max (float, optional): The maximum score of the range (inclusive). Defaults to float("inf").
            withscores (bool, optional): Whether to return the scores along with the values. Defaults to False.

        Returns:
            list: A list of values, or of ``(value, score)`` tuples if withscores is True, in the specified score range.
        """
        members = self.cache.get(key)
        if members is None:
            return []
        # Work on a snapshot. The expiry thread may remove members while we
        # filter and sort, which would otherwise raise RuntimeError or KeyError.
        snapshot = dict(members)
        scores = sorted(score for score in snapshot if min <= score <= max)
        if withscores:
            return [(snapshot[score], score) for score in scores]
        return [snapshot[score] for score in scores]

    def zremrangebyscore(
        self,
        key: str,
        min: Optional[float] = -1,
        max: Optional[float] = float("inf"),
    ):
        """
        Removes all members in a sorted set within the given scores.

        Args:
            key (str): The key of the sorted set.
            min (float, optional): The minimum score of the range (inclusive). Defaults to -1.
            max (float, optional): The maximum score of the range (inclusive). Defaults to float("inf").

        Returns:
            int: The number of members removed from the sorted set.
        """
        removed = self.zrangebyscore(key, min=min, max=max, withscores=True)
        members = self.cache.get(key, {})
        for _, score in removed:
            members.pop(score, None)
            # Drop the member's TTL too, so a stale deadline does not linger
            # in zexpires for a member that no longer exists.
            self.zexpires.pop((key, score), None)
        return len(removed)

    def acquire_lock(self, key: str, expire: Optional[float] = None) -> Any:
        """
        Acquires a ``threading.Lock`` stored in the cache under ``key``.

        Args:
            key (str): Cache key under which the lock is stored.
            expire (float, optional): Used both as the maximum time, in seconds,
                to wait for the lock and as the TTL of the cache entry. None
                waits indefinitely and stores the lock without a TTL.

        Returns:
            Lock: The acquired lock. This method never releases it.

        Raises:
            TimeoutError: If the lock could not be acquired within ``expire`` seconds.
        """
        lock: Optional[Lock] = self.get(key)
        if lock is None:
            # NOTE(review): two threads that both miss here create separate
            # Lock objects and can both "acquire"; confirm whether callers
            # need stronger guarantees.
            lock = Lock()
            self.set(key, lock, expire=expire)
        # Lock.acquire uses -1 for "wait forever". Keep that sentinel separate
        # from the TTL: passing -1 to set() would store an already-past expiry
        # and let the next caller create a fresh, unlocked Lock.
        timeout = -1 if expire is None else expire
        acquired = lock.acquire(timeout=timeout)
        if not acquired:
            raise TimeoutError()
        # refresh the lock entry (and its TTL, if any)
        self.set(key, lock, expire=expire)
        return lock

    def set_numpy(self, key: str, value: Any, expire: Optional[float] = None):
        """Delegates to ``set``; the object is stored as-is, without serialization."""
        return self.set(key, value, expire=expire)

    def get_numpy(self, key: str):
        """Delegates to ``get``; returns the object stored by ``set_numpy``."""
        return self.get(key)