Spaces:
Runtime error
Runtime error
"""
Pure-Python implementation of safetensors' ``safe_open``.

Adapted from https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282
"""
import json | |
import mmap | |
import os | |
import torch | |
class SafetensorsWrapper:
    """Read-only handle over an already-loaded set of tensors.

    Exposes the accessor methods of safetensors' ``safe_open`` handle
    (``metadata``, ``keys``, ``get_tensor``) over a metadata dict and a
    name -> tensor mapping supplied by the caller. It also implements the
    context-manager protocol so ``with safe_open(...) as f:`` works as it does
    with the real library.
    """

    def __init__(self, metadata, tensors):
        # Free-form ``__metadata__`` dict from the file header.
        self._metadata = metadata
        # Mapping of tensor name -> tensor.
        self._tensors = tensors

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Nothing to release: the handle only holds plain Python containers.
        # Returning None lets any exception from the ``with`` body propagate.
        return None

    def metadata(self):
        """Return the ``__metadata__`` dict given at construction."""
        return self._metadata

    def keys(self):
        """Return a view of the available tensor names."""
        return self._tensors.keys()

    def get_tensor(self, k):
        """Return the tensor stored under name ``k`` (KeyError if absent)."""
        return self._tensors[k]
# Maps the safetensors "dtype" code of a header entry to the torch dtype the
# raw bytes are reinterpreted as. Codes not listed here (e.g. the F8_* and
# U16/U32/U64 codes) are unsupported and fail the dictionary lookup.
DTYPES = {
    "BOOL": torch.bool,
    "U8": torch.uint8,
    "I8": torch.int8,
    "I16": torch.int16,
    "I32": torch.int32,
    "I64": torch.int64,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "F32": torch.float32,
    "F64": torch.float64,
}
def create_tensor(storage, info, offset):
    """Materialize one tensor described by a safetensors header entry.

    Args:
        storage: Untyped byte storage covering the file (as built by
            ``safe_open``).
        info: Header entry with ``"dtype"`` (a key of ``DTYPES``), ``"shape"``
            and ``"data_offsets"`` -- a ``[begin, end)`` byte range relative
            to the start of the data section.
        offset: Byte position at which the data section starts.

    Returns:
        A tensor of the entry's dtype, reshaped to the entry's shape.
    """
    target_dtype = DTYPES[info["dtype"]]
    dims = info["shape"]
    begin, end = info["data_offsets"]
    # Take the entry's raw bytes as uint8, then reinterpret them as the
    # declared dtype before giving them their final shape.
    raw_bytes = torch.asarray(
        storage[begin + offset : end + offset], dtype=torch.uint8
    )
    reinterpreted = raw_bytes.view(dtype=target_dtype)
    return reinterpreted.reshape(dims)
def safe_open(filename, framework="pt", device="cpu"):
    """Load a ``.safetensors`` file and return a handle over its tensors.

    The file is parsed as an 8-byte little-endian header length ``n``,
    followed by ``n`` bytes of JSON, followed by the raw tensor data. Every
    JSON entry except ``"__metadata__"`` describes one tensor whose
    ``data_offsets`` are relative to the start of the raw data.

    Args:
        filename: Path to the ``.safetensors`` file.
        framework: Only ``"pt"`` (PyTorch) is supported.
        device: Device every tensor is moved to with ``Tensor.to``.

    Returns:
        A ``SafetensorsWrapper`` holding the ``__metadata__`` dict (``{}`` when
        absent) and a mapping of tensor name -> tensor.

    Raises:
        ValueError: If ``framework`` is not ``"pt"``, the file is shorter than
            the 8-byte prefix, the declared header length exceeds the file,
            or the header is not a JSON object. Malformed JSON surfaces as
            ``json.JSONDecodeError``, itself a ``ValueError``.
    """
    if framework != "pt":
        raise ValueError("`framework` must be 'pt'")
    # The container is binary, so read it in "rb" mode; a plain read of the
    # prefix and header is all that is needed here.
    with open(filename, mode="rb") as file_obj:
        size = os.fstat(file_obj.fileno()).st_size
        header = file_obj.read(8)
        if len(header) < 8:
            raise ValueError(f"{filename!r} is too small to be a safetensors file")
        n = int.from_bytes(header, "little")
        # Reject a length the file cannot hold rather than parsing a
        # silently truncated header.
        if n > size - 8:
            raise ValueError(f"{filename!r}: header length {n} exceeds the file size")
        metadata = json.loads(file_obj.read(n))
    if not isinstance(metadata, dict):
        raise ValueError(f"{filename!r}: header is not a JSON object")
    # Expose the whole file as one byte storage; create_tensor slices each
    # tensor's bytes out of it.
    # NOTE(review): torch.ByteStorage is a TypedStorage, which recent PyTorch
    # deprecates; torch.UntypedStorage.from_file could replace it once the
    # minimum supported torch version is confirmed.
    storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped()
    # Tensor data begins right after the 8-byte prefix and the JSON header.
    offset = n + 8
    return SafetensorsWrapper(
        metadata=metadata.get("__metadata__", {}),
        tensors={
            name: create_tensor(storage, info, offset).to(device)
            for name, info in metadata.items()
            if name != "__metadata__"
        },
    )