File size: 3,651 Bytes
8f87579
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
'''
Utilities for dealing with simple state dicts as npz files instead of pth files.
'''

import torch
from collections.abc import MutableMapping, Mapping

def load_from_numpy_dict(model, numpy_dict, prefix='', examples=None):
    '''
    Loads a model from numpy_dict using load_state_dict.
    Converts numpy types to torch types using the current state_dict
    of the model to determine types and devices for the tensors.
    Supports loading a subdict by prepending the given prefix to all keys.

    Args:
        model: the module whose `load_state_dict` is called.
        numpy_dict: mapping from state-dict keys to (numpy) values.
        prefix: if non-empty, only keys of numpy_dict starting with
            `prefix + '.'` are used, with that prefix stripped.
        examples: mapping of key -> example value giving the target type,
            dtype and device; defaults to `model.state_dict()`.
    '''
    if prefix:
        if not prefix.endswith('.'):
            prefix = prefix + '.'
        numpy_dict = PrefixSubDict(numpy_dict, prefix)
    if examples is None:
        # The model's own tensors serve as templates for dtype/device.
        examples = model.state_dict()
    torch_state_dict = TorchTypeMatchingDict(numpy_dict, examples)
    model.load_state_dict(torch_state_dict)

def save_to_numpy_dict(model, numpy_dict, prefix=''):
    '''
    Saves a model by copying tensors to numpy_dict.
    Converts torch types to numpy types using `t.detach().cpu().numpy()`.
    Supports saving a subdict by prepending the given prefix to all keys.

    Args:
        model: the module whose `state_dict()` is exported.
        numpy_dict: mutable mapping that receives the entries (modified
            in place).
        prefix: if non-empty, every key is written as `prefix + '.' + key`
            (a trailing '.' is added to prefix if missing).

    Note: for tensors already on the CPU, `.cpu().numpy()` does not copy,
    so the stored arrays share memory with the model's tensors.
    Non-tensor state-dict entries are stored unchanged.
    '''
    if prefix and not prefix.endswith('.'):
        prefix = prefix + '.'
    for k, v in model.state_dict().items():
        if isinstance(v, torch.Tensor):
            v = v.detach().cpu().numpy()
        numpy_dict[prefix + k] = v

class TorchTypeMatchingDict(Mapping):
    '''
    Provides a view of a dict of numpy values as torch tensors, where the
    types are converted to match the types and devices in the given
    dict of examples.

    Conversion is lazy: a value is converted the first time it is looked
    up and the result is cached.  Keys that have no entry in `examples`
    are returned unconverted (and are not cached).  For tensor examples
    the value is turned into a tensor and moved to the example's dtype
    and device; for other example types the value is passed to that
    type's constructor.
    '''
    def __init__(self, data, examples):
        self.data = data
        self.examples = examples
        self.cached_data = {}

    def __getitem__(self, key):
        if key in self.cached_data:
            return self.cached_data[key]
        val = self.data[key]
        if key not in self.examples:
            return val
        example = self.examples.get(key, None)
        example_type = type(example)
        if example is not None and type(val) != example_type:
            if isinstance(example, torch.Tensor):
                # as_tensor accepts ndarrays (without copying, like
                # from_numpy) as well as numpy/Python scalars.
                val = torch.as_tensor(val)
            else:
                val = example_type(val)
        if isinstance(example, torch.Tensor):
            val = val.to(dtype=example.dtype, device=example.device)
        self.cached_data[key] = val
        return val

    def __contains__(self, key):
        # Avoid the Mapping default, which would convert the value.
        return key in self.data

    def __iter__(self):
        # Must return an iterator; a bare keys() view is only iterable.
        return iter(self.data)

    def __len__(self):
        return len(self.data)

class PrefixSubDict(MutableMapping):
    '''
    Provides a view of the subset of a dict where string keys begin with
    the given prefix.  The prefix is stripped from all keys of the view.

    Reads, writes and deletes through the view act on the underlying dict
    under the prefixed key.  Non-string keys of the underlying dict are
    never part of the view.  The list of view keys is cached once it is
    first needed; it is refreshed when a key is added or removed through
    the view, but changes made directly to the underlying dict after that
    point may not be reflected in iteration or len().
    '''
    def __init__(self, data, prefix=''):
        self.data = data
        self.prefix = prefix
        self._cached_keys = None

    def __getitem__(self, key):
        if not isinstance(key, str):
            # Only string keys can carry the prefix; report absence as a
            # KeyError so `key in view` works for any key.
            raise KeyError(key)
        return self.data[self.prefix + key]

    def __setitem__(self, key, value):
        pkey = self.prefix + key
        if self._cached_keys is not None and pkey not in self.data:
            # A new key is entering the view; drop the stale key list.
            self._cached_keys = None
        self.data[pkey] = value

    def __delitem__(self, key):
        if not isinstance(key, str):
            raise KeyError(key)
        pkey = self.prefix + key
        if self._cached_keys is not None and pkey in self.data:
            # A key is leaving the view; drop the stale key list.
            self._cached_keys = None
        del self.data[pkey]

    def __cached_keys(self):
        if self._cached_keys is None:
            plen = len(self.prefix)
            self._cached_keys = [k[plen:] for k in self.data
                    if isinstance(k, str) and k.startswith(self.prefix)]
        return self._cached_keys

    def __iter__(self):
        return iter(self.__cached_keys())

    def __len__(self):
        return len(self.__cached_keys())