Spaces:
Build error
Build error
import functools
import pickle
import warnings

import h5py
import numpy as np
class Cache:
    """Persistent key/value cache backed by an HDF5 file (via h5py).

    Keys are normalized to strings (see ``_key``). NumPy arrays with a
    dtype HDF5 can store natively are written as datasets as-is; every
    other value (including arrays of bytes/unicode/object dtype) is
    pickled and written as a fixed-length byte string.

    The instance works as a context manager (closing the file on exit) and
    as a function decorator (see ``__call__``).

    Security note: stored values are restored with ``pickle.loads``, which
    can execute arbitrary code -- only open cache files you trust.
    """

    def __init__(self, file: str, mode: str = 'a', overwrite=False):
        """Open the backing HDF5 file.

        :param file: path of the HDF5 file, passed to ``h5py.File``.
        :param mode: h5py file mode (default ``'a'``).
        :param overwrite: when true, the decorator recomputes and re-stores
            a result even if its ``cache_id`` is already cached.
        """
        self.db_file = h5py.File(file, mode=mode)
        self.overwrite = overwrite

    @staticmethod
    def _key(key) -> str:
        """Normalize ``key`` to the dataset name used in the file.

        Strings are used as-is. Lists are normalized element-wise
        (recursively) and joined with a single space. Anything else goes
        through ``str()``.

        NOTE: ``['a b']`` and ``['a', 'b']`` both map to ``'a b'``.
        Tuples are *not* flattened; they use ``str(tuple)``.
        """
        if isinstance(key, str):
            return key
        if isinstance(key, list):
            return ' '.join(Cache._key(k) for k in key)
        return str(key)

    @staticmethod
    def _value(value):
        """Convert a stored entry back into the Python value it represents.

        h5py datasets are read in full with ``value[()]``. Byte-string data
        is what ``__setitem__`` writes for pickled values, so it is
        unpickled. Anything else is returned unchanged.
        """
        if isinstance(value, h5py.Dataset):
            value = value[()]
        dtype = getattr(value, 'dtype', None)
        if dtype is not None and dtype.kind == 'S':
            # ``tobytes`` works for NumPy byte scalars and 0-d byte arrays
            # alike; pickle ignores anything after its STOP opcode.
            value = pickle.loads(value.tobytes())
        return value

    def __getitem__(self, key):
        """Return the cached value for ``key``; raise ``KeyError`` if absent."""
        key = self._key(key)
        if key not in self:
            raise KeyError(key)
        return self._value(self.db_file[key])

    def __setitem__(self, key, value) -> None:
        """Store ``value`` under ``key``, replacing any existing entry."""
        key = self._key(key)
        if key in self:
            del self.db_file[key]
        # Bytes arrays would be mistaken for pickles on read, and unicode or
        # object arrays cannot be stored by h5py, so pickle those as well.
        if not isinstance(value, np.ndarray) or value.dtype.kind in ('S', 'U', 'O'):
            value = np.array(pickle.dumps(value))
        self.db_file[key] = value

    def __delitem__(self, key) -> None:
        """Remove ``key`` from the cache; missing keys are silently ignored."""
        key = self._key(key)
        if key in self:
            del self.db_file[key]

    def __len__(self) -> int:
        """Return the number of top-level entries in the backing file."""
        return len(self.db_file)

    def close(self) -> None:
        """Close the backing HDF5 file."""
        self.db_file.close()

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the file when leaving a ``with`` block; exceptions propagate."""
        self.close()

    def __contains__(self, item):
        """Return whether the normalized form of ``item`` is stored."""
        item = self._key(item)
        return item in self.db_file

    def __enter__(self):
        return self

    def __call__(self, function):
        """Use the cache as a decorator.

        Pass an additional keyword argument ``cache_id`` when calling the
        decorated function. The result is then cached under that id, and
        later calls with the same id return the cached value unless
        ``self.overwrite`` is true. Without ``cache_id`` the function is
        simply called, and a warning is emitted.
        """
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            if 'cache_id' not in kwargs:
                # stacklevel=2 attributes the warning to the caller, not here.
                warnings.warn("`cache_id' argument not found. Cache is disabled.", stacklevel=2)
                return function(*args, **kwargs)
            cache_id = kwargs.pop('cache_id')
            if cache_id in self and not self.overwrite:
                return self[cache_id]
            rst = function(*args, **kwargs)
            self[cache_id] = rst
            return rst
        return wrapper