'''
Utilities for storing simple model state dicts in numpy-compatible dicts
(for example the contents of an npz file) instead of pth files.
'''

import torch
from collections.abc import MutableMapping, Mapping


def _normalize_prefix(prefix):
    '''Return `prefix` with a trailing '.' added if it is non-empty and lacks one.'''
    if prefix and not prefix.endswith('.'):
        prefix = prefix + '.'
    return prefix


def load_from_numpy_dict(model, numpy_dict, prefix='', examples=None,
                         strict=True):
    '''
    Loads a model from numpy_dict using load_state_dict.

    Numpy values are converted lazily to torch values. The entries of
    `examples` (by default the model's current state_dict) determine the
    target type, dtype and device of each value.

    Args:
        model: object providing `state_dict()` and `load_state_dict()`,
            e.g. a torch.nn.Module.
        numpy_dict: mapping from parameter names to numpy values, e.g. the
            object returned by `numpy.load('file.npz')`.
        prefix: if non-empty, only keys starting with this prefix are used,
            and the prefix is stripped before loading. A '.' is appended to
            the prefix if it is missing.
        examples: mapping whose values give the types to convert to;
            defaults to `model.state_dict()`.
        strict: forwarded to `model.load_state_dict`.

    Returns:
        The return value of `model.load_state_dict`.
    '''
    prefix = _normalize_prefix(prefix)
    if prefix:
        numpy_dict = PrefixSubDict(numpy_dict, prefix)
    if examples is None:
        examples = model.state_dict()
    torch_state_dict = TorchTypeMatchingDict(numpy_dict, examples)
    return model.load_state_dict(torch_state_dict, strict=strict)


def save_to_numpy_dict(model, numpy_dict, prefix=''):
    '''
    Saves a model by copying the entries of its state_dict into numpy_dict.

    Tensors are converted using `t.detach().cpu().numpy()`; other values are
    stored unchanged. If `prefix` is non-empty, it is prepended to every key
    (a '.' is appended to it first if missing), so several models can share
    one dict.
    '''
    prefix = _normalize_prefix(prefix)
    for k, v in model.state_dict().items():
        if isinstance(v, torch.Tensor):
            # NOTE(review): .numpy() has no equivalent for some torch dtypes
            # (e.g. bfloat16); such tensors will raise here.
            v = v.detach().cpu().numpy()
        numpy_dict[prefix + k] = v


class TorchTypeMatchingDict(Mapping):
    '''
    Provides a read-only view of a dict of numpy values as torch values.

    Each value is converted to match the entry with the same key in
    `examples`:
    - If that entry is a tensor, the value becomes a tensor with the
      example's dtype and device.
    - If it is any other non-None type, the value is passed to that type's
      constructor when its own type differs.
    - Keys missing from `examples` (or whose example is None) are returned
      unconverted.

    Converted values are cached, so each key is converted at most once.
    '''

    def __init__(self, data, examples):
        self.data = data
        self.examples = examples
        self.cached_data = {}

    def __getitem__(self, key):
        if key in self.cached_data:
            return self.cached_data[key]
        val = self.data[key]
        example = self.examples.get(key, None)
        if example is None:
            # Nothing to match against: pass the raw value through.
            return val
        if isinstance(example, torch.Tensor):
            # as_tensor (unlike from_numpy) also accepts numpy scalars,
            # Python numbers and existing tensors (including Parameters).
            if not isinstance(val, torch.Tensor):
                val = torch.as_tensor(val)
            val = val.to(dtype=example.dtype, device=example.device)
        elif type(val) != type(example):
            val = type(example)(val)
        self.cached_data[key] = val
        return val

    def __contains__(self, key):
        # Mapping's default __contains__ would convert (and possibly move to
        # another device) the value just to test membership.
        return key in self.data

    def __iter__(self):
        # Must return an iterator; dict.keys() is only an iterable.
        return iter(self.data)

    def __len__(self):
        return len(self.data)


class PrefixSubDict(MutableMapping):
    '''
    Provides a view of the subset of a dict whose string keys begin with the
    given prefix.

    The prefix is stripped from all keys of the view. Writes and deletes go
    through to the underlying dict. Non-string keys of the underlying dict
    are never part of the view.

    The list of matching keys is cached once it has been computed. Adding or
    removing keys through this view resets the cache. Keys added or removed
    directly on the underlying dict after that point may not show up in
    iteration or len() until the cache is reset.
    '''

    def __init__(self, data, prefix=''):
        self.data = data
        self.prefix = prefix
        self._cached_keys = None

    def __getitem__(self, key):
        return self.data[self.prefix + key]

    def __setitem__(self, key, value):
        pkey = self.prefix + key
        if self._cached_keys is not None and pkey not in self.data:
            # A new key is being added, so the cached key list is stale.
            self._cached_keys = None
        self.data[pkey] = value

    def __delitem__(self, key):
        pkey = self.prefix + key
        if self._cached_keys is not None and pkey in self.data:
            # A key is being removed, so the cached key list is stale.
            self._cached_keys = None
        del self.data[pkey]

    def __contains__(self, key):
        # Only string keys can be part of the view; avoid a TypeError from
        # concatenating the prefix with a non-string key.
        return isinstance(key, str) and (self.prefix + key) in self.data

    def _matching_keys(self):
        '''Return the (cached) list of view keys, i.e. prefixed keys with the prefix stripped.'''
        if self._cached_keys is None:
            plen = len(self.prefix)
            self._cached_keys = [
                k[plen:] for k in self.data
                if isinstance(k, str) and k.startswith(self.prefix)
            ]
        return self._cached_keys

    def __iter__(self):
        return iter(self._matching_keys())

    def __len__(self):
        return len(self._matching_keys())