haodongli's picture
init
916b126
raw
history blame
1.74 kB
"""
Pure python version of Safetensors safe_open
From https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282
"""
import json
import mmap
import os
import torch
class SafetensorsWrapper:
    """Read-only holder for the metadata and tensors of a loaded file.

    Instances can be used directly or as a context manager
    (``with wrapper as f:``). Entering returns the wrapper itself and
    exiting does nothing, because the wrapper owns no open resources.
    """

    def __init__(self, metadata, tensors):
        """Store the metadata object and the name -> tensor mapping."""
        self._metadata = metadata
        self._tensors = tensors

    def __enter__(self):
        """Return the wrapper so it can be bound in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Do nothing on exit; exceptions from the body are not suppressed."""
        return False

    def metadata(self):
        """Return the metadata object passed at construction."""
        return self._metadata

    def keys(self):
        """Return the keys view of the stored tensor mapping."""
        return self._tensors.keys()

    def get_tensor(self, k):
        """Return the tensor stored under ``k``.

        Raises:
            KeyError: if ``k`` is not a stored tensor name.
        """
        return self._tensors[k]
# Maps the dtype tag strings found in a safetensors header to torch dtypes.
# Only tags whose torch counterpart has a fixed element width are listed,
# since the raw bytes are reinterpreted with ``Tensor.view(dtype=...)``.
DTYPES = {
    "F64": torch.float64,
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I64": torch.int64,
    "I32": torch.int32,
    "I16": torch.int16,
    "I8": torch.int8,
    "U8": torch.uint8,
    "BOOL": torch.bool,
}
def create_tensor(storage, info, offset):
    """Build one tensor from a byte range of ``storage``.

    Args:
        storage: sliceable byte storage holding the tensor data.
        info: mapping with the keys ``"dtype"`` (a tag looked up in
            ``DTYPES``), ``"shape"`` and ``"data_offsets"`` (a
            ``(start, stop)`` pair).
        offset: number of bytes added to both data offsets before slicing.

    Returns:
        The selected bytes reinterpreted as the target dtype and reshaped
        to ``info["shape"]``.

    Raises:
        KeyError: if the dtype tag is not present in ``DTYPES``.
    """
    torch_dtype = DTYPES[info["dtype"]]
    dims = info["shape"]
    begin, end = info["data_offsets"]

    # Take the raw byte window first, then reinterpret it in place.
    raw_bytes = storage[begin + offset : end + offset]
    flat_u8 = torch.asarray(raw_bytes, dtype=torch.uint8)
    typed = flat_u8.view(dtype=torch_dtype)
    return typed.reshape(dims)
def safe_open(filename, framework="pt", device="cpu"):
    """Load all tensors of a safetensors file into a ``SafetensorsWrapper``.

    The file is read as: an 8-byte little-endian integer ``n``, followed by
    ``n`` bytes of JSON header, followed by the tensor data. Each header
    entry other than ``"__metadata__"`` describes one tensor.

    Args:
        filename: path of the file to load (``str`` or path-like).
        framework: must be ``"pt"``.
        device: passed to ``Tensor.to`` for every loaded tensor.

    Returns:
        A ``SafetensorsWrapper`` holding the ``"__metadata__"`` entry
        (``{}`` when absent) and a name -> tensor mapping.

    Raises:
        ValueError: if ``framework`` is not ``"pt"``, if the file is too
            short to hold the length prefix, or if the declared header
            length exceeds the file size.
    """
    if framework != "pt":
        raise ValueError("`framework` must be 'pt'")

    # Normalise path-like objects once so the same value reaches torch.
    filename = os.fspath(filename)

    # The file is binary; only the JSON header is text, and json.loads
    # decodes it from bytes itself.
    with open(filename, mode="rb") as file_obj:
        with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
            size = m.size()
            header = m.read(8)
            if len(header) < 8:
                raise ValueError("file is too short to contain a header length")
            n = int.from_bytes(header, "little")
            if n > size - 8:
                raise ValueError("declared header length exceeds the file size")
            metadata = json.loads(m.read(n))

    storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped()
    # Tensor data offsets in the header are applied after the length
    # prefix (8 bytes) and the JSON header (n bytes).
    offset = n + 8
    return SafetensorsWrapper(
        metadata=metadata.get("__metadata__", {}),
        tensors={
            name: create_tensor(storage, info, offset).to(device)
            for name, info in metadata.items()
            if name != "__metadata__"
        },
    )