File size: 2,467 Bytes
6680682
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import functools
import pickle
import warnings

import h5py
import numpy as np


class Cache:
    """Dict-like persistent cache stored in an HDF5 file through ``h5py``.

    NumPy arrays are written to the file as-is; every other value is pickled
    and stored as a bytes array, then unpickled again on read.  Keys are
    normalised to strings by :meth:`_key`.

    The object works as a context manager (the file is closed on exit) and as
    a decorator (see :meth:`__call__`).
    """

    def __init__(self, file: str, mode: str = 'a', overwrite: bool = False):
        """Open the backing file.

        Args:
            file: Path handed to ``h5py.File``.
            mode: File mode handed to ``h5py.File``.
            overwrite: When true, the decorator recomputes and re-stores
                results even if ``cache_id`` is already cached.
        """
        self.db_file = h5py.File(file, mode=mode)
        self.overwrite = overwrite

    @staticmethod
    def _key(key) -> str:
        """Turn ``key`` into the string used as the dataset name.

        Strings are returned unchanged, lists are converted element by element
        (recursively) and joined with single spaces, and anything else goes
        through ``str``.
        """
        if isinstance(key, str):
            return key
        if isinstance(key, list):
            return ' '.join(Cache._key(part) for part in key)
        return str(key)

    @staticmethod
    def _value(value):
        """Convert a stored entry back into the object that was cached.

        Args:
            value: An ``h5py.Dataset`` or an already-loaded NumPy value.

        Returns:
            The unpickled object when the data has a bytes dtype, otherwise
            the NumPy value itself.
        """
        if isinstance(value, h5py.Dataset):
            # Indexing with an empty tuple reads the whole dataset into memory.
            value = value[()]
        if value.dtype.name.startswith('bytes'):
            # NOTE(security): unpickling executes arbitrary code embedded in
            # the data, so only open cache files that come from a trusted
            # source.
            value = pickle.loads(value)
        return value

    def __getitem__(self, key):
        """Return the cached value for ``key``.

        Raises:
            KeyError: If ``key`` is not cached; the normalised key is attached
                to the exception.
        """
        key = self._key(key)
        if key not in self:
            raise KeyError(key)
        return self._value(self.db_file[key])

    def __setitem__(self, key, value) -> None:
        """Store ``value`` under ``key``, replacing any existing entry."""
        key = self._key(key)
        if key in self:
            # Remove the old entry first so the new value can be assigned
            # under the same name.
            del self.db_file[key]
        if not isinstance(value, np.ndarray):
            # Anything that is not an array is serialised with pickle and
            # wrapped in a bytes array; `_value` reverses this on read.
            value = np.array(pickle.dumps(value))
        self.db_file[key] = value

    def __delitem__(self, key) -> None:
        """Delete ``key`` if present; a missing key is silently ignored."""
        key = self._key(key)
        if key in self:
            del self.db_file[key]

    def __len__(self) -> int:
        """Return ``len()`` of the backing file object."""
        return len(self.db_file)

    def close(self) -> None:
        """Close the backing file."""
        self.db_file.close()

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the file when leaving a ``with`` block."""
        self.close()

    def __contains__(self, item) -> bool:
        """Return whether the normalised form of ``item`` is in the file."""
        return self._key(item) in self.db_file

    def __enter__(self) -> 'Cache':
        """Return the cache itself for use in a ``with`` block."""
        return self

    def __call__(self, function):
        """
        The object of the class could also be used as a decorator. Provide an additional
        argument `cache_id' when calling the function, and the results will be cached.

        The `cache_id' keyword is consumed by the wrapper and never forwarded
        to ``function``.  Without it the function is called directly and a
        warning is issued.  With it, a cached result is returned unless
        ``overwrite`` was set, in which case the function is called again and
        its result replaces the stored one.
        """

        @functools.wraps(function)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            if 'cache_id' not in kwargs:
                # stacklevel=2 attributes the warning to the caller's line.
                warnings.warn("`cache_id' argument not found. Cache is disabled.",
                              stacklevel=2)
                return function(*args, **kwargs)

            cache_id = kwargs.pop('cache_id')
            if cache_id in self and not self.overwrite:
                return self[cache_id]
            rst = function(*args, **kwargs)
            self[cache_id] = rst
            return rst

        return wrapper