"""Memoization helpers that also accept unhashable arguments.

``lru_cache_ext`` wraps :func:`functools.lru_cache` so calls with lists, sets,
dicts, or tuples containing such containers can be cached;
``conditional_cache`` routes a call through that cache only when a
predicate accepts its arguments.
"""

from functools import _CacheInfo, lru_cache, update_wrapper, wraps
from typing import Any, Callable, TypeVar

try:
    from typing import ParamSpec
except ImportError:  # Python < 3.10 only ships ParamSpec in typing_extensions
    from typing_extensions import ParamSpec


def conditional_cache(maxsize: int, condition: Callable[..., bool]):
    """Cache calls to the decorated function only when ``condition`` allows it.

    Args:
        maxsize: Maximum number of entries kept by the underlying
            :func:`lru_cache_ext` cache.
        condition: Called with the same arguments as the decorated function.
            A truthy result routes the call through the cache; a falsy result
            calls the function directly, neither reading nor filling the cache.

    Returns:
        A decorator. The wrapped function also exposes ``cache_info`` and
        ``cache_clear`` of the underlying cache.
    """

    def decorator(func):
        @lru_cache_ext(maxsize=maxsize)
        def cached_func(*args, **kwargs):
            return func(*args, **kwargs)

        @wraps(func)
        def wrapper(*args, **kwargs):
            if condition(*args, **kwargs):
                return cached_func(*args, **kwargs)
            return func(*args, **kwargs)

        wrapper.cache_info = cached_func.cache_info
        wrapper.cache_clear = cached_func.cache_clear
        return wrapper

    return decorator


def hash_list(l: list) -> int:
    """Return an order-sensitive hash of ``l``; items may be unhashable.

    Each item is hashed with :func:`hash_item` and folded together with its
    index, so reordering the items changes the result.
    """
    result = 0
    for index, item in enumerate(l):
        result = hash((result, index, hash_item(item)))
    return result


def hash_dict(d: dict) -> int:
    """Return a hash of ``d`` whose values may be unhashable.

    The result does not depend on insertion order, so dicts that compare
    equal (e.g. ``{1: 2, 3: 4}`` and ``{3: 4, 1: 2}``) hash equal.
    """
    # Dict keys are always hashable; only the values need hash_item.
    return hash(frozenset((k, hash_item(v)) for k, v in d.items()))


def hash_item(e) -> int:
    """Hash ``e``, falling back to structural hashing for unhashable containers.

    Hashable objects use the built-in :func:`hash`. Unhashable sets are hashed
    by content regardless of iteration order, unhashable lists/tuples (e.g. a
    tuple holding a list) element-wise in order, and dicts via
    :func:`hash_dict`.

    Raises:
        TypeError: if ``e`` is unhashable and not a list, set, tuple or dict.
    """
    try:
        return hash(e)
    except TypeError:
        pass  # unhashable: try the structural fallbacks below
    if isinstance(e, set):
        # Set members are always hashable, and a frozenset hash agrees with
        # set equality regardless of iteration order.
        return hash(frozenset(e))
    if isinstance(e, (list, tuple)):
        return hash_list(list(e))
    if isinstance(e, dict):
        return hash_dict(e)
    raise TypeError(f"unhashable type: {e.__class__}")


class _HashedArgs:
    """Cache key pairing a precomputed hash with the call's arguments.

    ``lru_cache`` finds candidate entries by ``__hash__`` and confirms them
    with ``__eq__``. Comparing the real arguments means two calls whose
    hashes collide still get separate cache entries.
    """

    __slots__ = ("hashvalue", "args", "kwargs")

    def __init__(self, hashvalue: int, args: tuple, kwargs: dict[str, Any]) -> None:
        self.hashvalue = hashvalue
        self.args = args
        self.kwargs = kwargs

    def __hash__(self) -> int:
        return self.hashvalue

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _HashedArgs):
            return NotImplemented
        if self is other:
            return True
        try:
            return bool(self.args == other.args and self.kwargs == other.kwargs)
        except (TypeError, ValueError):
            # Arguments whose == cannot be reduced to a bool (e.g. objects whose
            # == returns an array) count as different: a miss is safe, a false
            # hit is not.
            return False


PT = ParamSpec("PT")
RT = TypeVar("RT")


def lru_cache_ext(
    *opts, hashfunc: Callable[..., int] = hash_item, **kwopts
) -> Callable[[Callable[PT, RT]], Callable[PT, RT]]:
    """Like :func:`functools.lru_cache`, but also for unhashable arguments.

    Args:
        *opts, **kwopts: Forwarded unchanged to :func:`functools.lru_cache`
            (e.g. ``maxsize``).
        hashfunc: Maps one value to an int. It is applied to each positional
            argument, to each keyword name and value, and then to the tuple of
            those results. Defaults to :func:`hash_item`.

    Returns:
        A decorator. The decorated object is callable like the original
        function and provides ``cache_info()`` and ``cache_clear()``, both
        usable before the first call.

    Cache entries match on the hash *and* on ``==`` between the stored and the
    new arguments, so a hash collision never returns another call's result.
    As with ``lru_cache``, entries keep references to their arguments.
    Mutating a list/dict argument after passing it in is not supported.
    """

    def decorator(func: Callable[PT, RT]) -> Callable[PT, RT]:
        @lru_cache(*opts, **kwopts)
        def cached_func(key: _HashedArgs) -> RT:
            # The key carries its own arguments, so concurrent calls cannot
            # see each other's arguments.
            return func(*key.args, **key.kwargs)

        class _lru_cache_ext_wrapper:
            def cache_info(self) -> _CacheInfo:
                return cached_func.cache_info()

            def cache_clear(self) -> None:
                cached_func.cache_clear()

            def __call__(self, *args: PT.args, **kwargs: PT.kwargs) -> RT:
                args_hash = hashfunc(
                    (
                        id(func),
                        *[hashfunc(a) for a in args],
                        *[(hashfunc(k), hashfunc(v)) for k, v in kwargs.items()],
                    )
                )
                return cached_func(_HashedArgs(args_hash, args, kwargs))

        wrapper = _lru_cache_ext_wrapper()
        # Copy __name__, __doc__, __wrapped__, ... so the wrapper introspects
        # like the decorated function.
        update_wrapper(wrapper, func)
        return wrapper

    return decorator