Spaces:
Build error
Build error
File size: 2,467 Bytes
6680682 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import functools
import pickle
import warnings

import h5py
import numpy as np
class Cache:
    """Persistent key/value cache stored in an HDF5 file opened with ``h5py``.

    Keys are normalised to strings by :meth:`_key`. NumPy arrays that HDF5 can
    hold natively are written as datasets as-is; every other value is pickled
    into a scalar byte string on write and unpickled on read.

    Instances work as context managers (the file is closed on exit) and as
    function decorators (see :meth:`__call__`).

    Warning: values are restored with ``pickle.loads``; only open cache files
    from a trusted source.
    """

    def __init__(self, file: str, mode: str = 'a', overwrite=False):
        """Open the backing file.

        :param file: path of the HDF5 file, passed to ``h5py.File``.
        :param mode: file mode, passed through unchanged to ``h5py.File``.
        :param overwrite: when true, the decorator form ignores existing cached
            results and always recomputes (and re-stores) them.
        """
        self.db_file = h5py.File(file, mode=mode)
        self.overwrite = overwrite

    @staticmethod
    def _key(key):
        """Normalise ``key`` into the string used as the dataset name.

        Strings are used unchanged, lists are flattened recursively and joined
        with single spaces, and any other object is converted with ``str()``.
        """
        if isinstance(key, str):
            return key
        if isinstance(key, list):
            return ' '.join(Cache._key(k) for k in key)
        return str(key)

    @staticmethod
    def _value(value: np.ndarray):
        """Turn a stored dataset back into the value originally assigned.

        A dataset is first read fully into memory. Scalar byte strings are
        pickle payloads written by :meth:`__setitem__` and are unpickled; any
        other result (including multi-element byte-string arrays) is returned
        unchanged.
        """
        if isinstance(value, h5py.Dataset):
            value = value[()]
        # Only *scalar* byte strings are pickle payloads: __setitem__ always
        # pickles into a 0-d array. Checking the dtype alone would also try to
        # unpickle genuine byte-string arrays like np.array([b'a', b'b']).
        # np.bytes_ is a subclass of bytes, so this covers the numpy scalar
        # that reading a scalar string dataset presumably yields.
        if isinstance(value, bytes):
            return pickle.loads(value)
        if isinstance(value, np.ndarray) and value.ndim == 0 and value.dtype.kind == 'S':
            return pickle.loads(value.item())
        return value

    @staticmethod
    def _is_native(value) -> bool:
        """Return True if ``value`` is stored directly rather than pickled."""
        if not isinstance(value, np.ndarray):
            return False
        dtype = value.dtype
        # Plain object arrays and NumPy unicode ('U') arrays have no native
        # HDF5 mapping, so pickle them. Object dtypes carrying metadata (h5py's
        # special string/reference dtypes) are still handed to h5py directly.
        if dtype.kind == 'U' or (dtype.kind == 'O' and not dtype.metadata):
            return False
        # 0-d byte strings are pickled so _value can't confuse them with a
        # pickle payload on the way back.
        return not (value.ndim == 0 and dtype.kind == 'S')

    def __getitem__(self, key):
        """Return the value stored under ``key``.

        :raises KeyError: if the normalised key is not present; the key is
            attached to the exception.
        """
        key = self._key(key)
        if key not in self:
            raise KeyError(key)
        return self._value(self.db_file[key])

    def __setitem__(self, key, value) -> None:
        """Store ``value`` under ``key``, replacing any existing entry.

        Natively storable NumPy arrays are written as-is; anything else is
        pickled into a scalar byte string.
        """
        key = self._key(key)
        if key in self:
            # Drop the old entry first; presumably h5py refuses to create a
            # dataset under a name that already exists.
            del self.db_file[key]
        if not self._is_native(value):
            value = np.array(pickle.dumps(value))
        self.db_file[key] = value

    def __delitem__(self, key) -> None:
        """Remove ``key`` if present; a missing key is silently ignored."""
        key = self._key(key)
        if key in self:
            del self.db_file[key]

    def __len__(self) -> int:
        """Return ``len()`` of the underlying file object."""
        return len(self.db_file)

    def __contains__(self, item):
        """Return whether the normalised ``item`` exists in the file."""
        item = self._key(item)
        return item in self.db_file

    def close(self) -> None:
        """Close the underlying HDF5 file."""
        self.db_file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    def __call__(self, function):
        """Use the cache as a decorator.

        Pass an extra keyword argument ``cache_id`` when calling the decorated
        function. Its result is stored under that key and returned from the
        cache on later calls. If ``overwrite`` is set, the result is instead
        recomputed and re-stored on every call. Without ``cache_id`` the
        function runs uncached and a warning is emitted.
        """
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            if 'cache_id' not in kwargs:
                # stacklevel=2 attributes the warning to the caller.
                warnings.warn("`cache_id' argument not found. Cache is disabled.", stacklevel=2)
                return function(*args, **kwargs)
            cache_id = kwargs.pop('cache_id')
            if cache_id in self and not self.overwrite:
                return self[cache_id]
            result = function(*args, **kwargs)
            self[cache_id] = result
            return result
        return wrapper
|