# OLMo-Bitnet-1B / safetensors_util.py
# Provenance (from the Hugging Face file viewer): author emozilla,
# commit 2b1c7b3 "update inference code", 2.44 kB, malware scan: no virus.
import base64
import pickle
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
import safetensors.torch
import torch
from .aliases import PathOrStr
# Only these two functions are exported by `from ... import *`; the other
# helpers in this module (STKey, encode/decode, flatten/unflatten) are not listed.
__all__ = [
    "state_dict_to_safetensors_file",
    "safetensors_file_to_state_dict",
]
@dataclass(frozen=True)
class STKey:
    """Immutable, hashable key for one leaf of a flattened state dict.

    Instances are used as dictionary keys by flatten_dict/unflatten_dict and
    are serialized to strings by encode_key.
    """

    # Path of dict keys from the root of the nested dict down to this leaf.
    keys: Tuple
    # True when the stored tensor holds pickled bytes of a non-tensor value.
    value_is_pickled: bool
def encode_key(key: STKey) -> str:
    """Serialize an STKey into a string usable as a safetensors tensor name.

    The ``(keys, value_is_pickled)`` pair is pickled, encoded with the URL-safe
    base64 alphabet, and returned as an ASCII ``str``. decode_key reverses this.
    """
    payload = pickle.dumps((key.keys, key.value_is_pickled))
    return base64.urlsafe_b64encode(payload).decode("ASCII")
def decode_key(key: str) -> STKey:
    """Recover an STKey from the string produced by encode_key.

    The string is URL-safe-base64 decoded and the result is unpickled into a
    ``(keys, value_is_pickled)`` pair.

    Security: this calls ``pickle.loads`` on the decoded bytes. When ``key``
    comes from a file, a crafted file can execute arbitrary code during
    decoding, so only decode keys from trusted sources.
    """
    keys, value_is_pickled = pickle.loads(base64.urlsafe_b64decode(key))
    return STKey(keys=keys, value_is_pickled=value_is_pickled)
def flatten_dict(d: Dict) -> Dict[STKey, torch.Tensor]:
    """Flatten a (possibly nested) dict into a single-level dict of tensors.

    Each leaf is keyed by an STKey whose ``keys`` is the path of dict keys from
    the root of ``d`` to that leaf.

    * ``torch.Tensor`` leaves are stored as-is (``value_is_pickled=False``).
    * Non-empty nested dicts are flattened recursively; the enclosing key is
      prepended to each inner path.
    * Any other value is pickled and stored as a 1-D ``uint8`` tensor
      (``value_is_pickled=True``). This includes *empty* nested dicts, which
      would otherwise produce no entries and vanish on a save/load round trip.

    :param d: The dict to flatten.
    :returns: A flat mapping from STKey to tensor.
    """
    result: Dict[STKey, torch.Tensor] = {}
    for key, value in d.items():
        if isinstance(value, torch.Tensor):
            result[STKey((key,), False)] = value
        elif isinstance(value, dict) and value:
            for inner_key, inner_value in flatten_dict(value).items():
                result[STKey((key,) + inner_key.keys, inner_key.value_is_pickled)] = inner_value
        else:
            # Non-tensor values and empty dicts are pickled into a byte tensor.
            # The tensor returned by torch.frombuffer shares memory with the
            # buffer. A writable bytearray is used, presumably to avoid
            # torch's warning about non-writable buffers.
            pickled = bytearray(pickle.dumps(value))
            result[STKey((key,), True)] = torch.frombuffer(pickled, dtype=torch.uint8)
    return result
def unflatten_dict(d: Dict[STKey, torch.Tensor]) -> Dict:
    """Rebuild a nested dict from the flat STKey -> tensor mapping of flatten_dict.

    Entries whose key has ``value_is_pickled`` set are unpickled back to their
    original Python value. Intermediate dicts along each key path are created
    on demand.

    Security: pickled entries are deserialized with ``pickle.loads``, which can
    execute arbitrary code. Only call this on data from a trusted source.

    :param d: Flat mapping as produced by flatten_dict.
    :returns: The reconstructed nested dict.
    """
    result: Dict = {}
    for key, value in d.items():
        if key.value_is_pickled:
            # The byte tensor may live on a non-CPU device (e.g. when loaded
            # with a GPU map_location), and Tensor.numpy() only supports CPU
            # tensors. .cpu() is a no-op for tensors already on the CPU.
            value = pickle.loads(value.cpu().numpy().data)
        target_dict = result
        for k in key.keys[:-1]:
            target_dict = target_dict.setdefault(k, {})
        target_dict[key.keys[-1]] = value
    return result
def state_dict_to_safetensors_file(state_dict: Dict, filename: PathOrStr):
    """Write a (possibly nested) state dict to ``filename`` in safetensors format.

    The dict is flattened with flatten_dict, and each flattened STKey is turned
    into a string name with encode_key before the resulting mapping is passed
    to ``safetensors.torch.save_file``.
    """
    encoded = {}
    for st_key, tensor in flatten_dict(state_dict).items():
        encoded[encode_key(st_key)] = tensor
    safetensors.torch.save_file(encoded, filename)
def safetensors_file_to_state_dict(filename: PathOrStr, map_location: Optional[str] = None) -> Dict:
    """Load a nested state dict written by state_dict_to_safetensors_file.

    :param filename: Path of the safetensors file to read.
    :param map_location: Device string passed to ``safetensors.torch.load_file``.
        ``None`` means ``"cpu"``.
    :returns: The reconstructed nested dict.

    Security: tensor names are decoded with decode_key and pickled values are
    restored by unflatten_dict, both via ``pickle.loads``. Only load files
    from trusted sources.
    """
    device = "cpu" if map_location is None else map_location
    raw = safetensors.torch.load_file(filename, device=device)
    return unflatten_dict({decode_key(name): tensor for name, tensor in raw.items()})