"""Save and load nested state dicts as safetensors files.

Nested dictionaries are flattened to a single level before saving. Each
flattened key records the path of dictionary keys leading to the value and
whether the value was pickled into a ``uint8`` tensor because it was not a
tensor itself.

NOTE(security): both the encoded keys and the non-tensor values are restored
with ``pickle.loads``. Unpickling can execute arbitrary code, so only load
files that come from a trusted source.
"""

import base64
import pickle
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import safetensors.torch
import torch

from .aliases import PathOrStr

__all__ = [
    "state_dict_to_safetensors_file",
    "safetensors_file_to_state_dict",
]


@dataclass(eq=True, frozen=True)
class STKey:
    """Hashable key of one entry in a flattened state dict."""

    # Path of dictionary keys from the root of the nested dict to the value.
    keys: Tuple
    # True when the stored tensor holds the pickled bytes of a non-tensor value.
    value_is_pickled: bool


def encode_key(key: STKey) -> str:
    """Serialize ``key`` to an ASCII string.

    The key's fields are pickled and the result is URL-safe base64 encoded.
    """
    b = pickle.dumps((key.keys, key.value_is_pickled))
    b = base64.urlsafe_b64encode(b)
    return str(b, "ASCII")


def decode_key(key: str) -> STKey:
    """Rebuild the ``STKey`` from a string produced by ``encode_key``.

    NOTE(security): this unpickles the decoded bytes; only use it on keys
    that come from a trusted file.
    """
    b = base64.urlsafe_b64decode(key)
    keys, value_is_pickled = pickle.loads(b)
    return STKey(keys, value_is_pickled)


def flatten_dict(d: Dict) -> Dict[STKey, torch.Tensor]:
    """Flatten a possibly nested dict into a single-level dict of tensors.

    Tensors are stored as they are. Non-empty dicts are flattened recursively
    and their key paths are prefixed with the current key. Every other value,
    including an empty dict, is pickled and stored as a ``uint8`` tensor.

    Args:
        d: The dictionary to flatten.

    Returns:
        A dict mapping ``STKey`` to tensors.
    """
    result: Dict[STKey, torch.Tensor] = {}
    for key, value in d.items():
        if isinstance(value, torch.Tensor):
            result[STKey((key,), False)] = value
        elif isinstance(value, dict) and value:
            for inner_key, inner_value in flatten_dict(value).items():
                result[STKey((key,) + inner_key.keys, inner_key.value_is_pickled)] = inner_value
        else:
            # An empty dict lands here on purpose: recursing into it would
            # produce no entries, and the key would be lost on a round trip.
            # A bytearray is used because it gives torch a writable buffer.
            pickled = bytearray(pickle.dumps(value))
            pickled_tensor = torch.frombuffer(pickled, dtype=torch.uint8)
            result[STKey((key,), True)] = pickled_tensor
    return result


def unflatten_dict(d: Dict[STKey, torch.Tensor]) -> Dict:
    """Rebuild a nested dict from the output of ``flatten_dict``.

    Args:
        d: A dict mapping ``STKey`` to tensors.

    Returns:
        The nested dictionary, with pickled values restored.
    """
    result: Dict = {}

    for key, value in d.items():
        if key.value_is_pickled:
            # The payload may have been loaded onto a non-CPU device, and
            # ``Tensor.numpy()`` only works for CPU tensors.
            # NOTE(security): unpickling file contents; trusted input only.
            value = pickle.loads(value.cpu().numpy().tobytes())

        # Walk down the key path, creating intermediate dicts as needed.
        target_dict = result
        for k in key.keys[:-1]:
            new_target_dict = target_dict.get(k)
            if new_target_dict is None:
                new_target_dict = {}
                target_dict[k] = new_target_dict
            target_dict = new_target_dict
        target_dict[key.keys[-1]] = value

    return result


def state_dict_to_safetensors_file(state_dict: Dict, filename: PathOrStr):
    """Flatten ``state_dict`` and write it to ``filename`` with safetensors.

    Args:
        state_dict: The possibly nested state dict to save.
        filename: Path of the file to write.
    """
    state_dict = flatten_dict(state_dict)
    state_dict = {encode_key(k): v for k, v in state_dict.items()}
    safetensors.torch.save_file(state_dict, filename)


def safetensors_file_to_state_dict(filename: PathOrStr, map_location: Optional[str] = None) -> Dict:
    """Load a file written by ``state_dict_to_safetensors_file``.

    NOTE(security): keys and non-tensor values are unpickled while loading;
    only load files from a trusted source.

    Args:
        filename: Path of the file to read.
        map_location: Device to load the tensors onto. Defaults to ``"cpu"``.

    Returns:
        The nested state dict.
    """
    if map_location is None:
        map_location = "cpu"
    state_dict = safetensors.torch.load_file(filename, device=map_location)
    state_dict = {decode_key(k): v for k, v in state_dict.items()}
    return unflatten_dict(state_dict)