|
import functools |
|
import operator |
|
from collections import defaultdict |
|
from dataclasses import dataclass, field |
|
from typing import Dict, List, Literal, Optional, Tuple |
|
|
|
|
|
# Readability aliases for the mapping annotations used in this module; both are plain `str`.
FILENAME_T = str
TENSOR_NAME_T = str

# Closed set of data-type codes a tensor entry may declare.
DTYPE_T = Literal[
    "F64",
    "F32",
    "F16",
    "BF16",
    "I64",
    "I32",
    "I16",
    "I8",
    "U8",
    "BOOL",
]
|
|
|
|
|
class SafetensorsParsingError(Exception):
    """Error raised when the metadata of a safetensors file cannot be parsed.

    Typical causes are a file that is not a safetensors file at all, or one that does not follow the
    safetensors specification.
    """
|
|
|
|
|
class NotASafetensorsRepoError(Exception):
    """Error raised when a repo is not a Safetensors repo.

    A repo falls in that category when it has neither a `model.safetensors` file nor a
    `model.safetensors.index.json` file.
    """
|
|
|
|
|
@dataclass
class TensorInfo:
    """Information about a tensor.

    For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        dtype (`str`):
            The data type of the tensor ("F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL").
        shape (`List[int]`):
            The shape of the tensor. An empty shape counts as a single parameter.
        data_offsets (`Tuple[int, int]`):
            The offsets of the data in the file as a tuple `[BEGIN, END]`.
        parameter_count (`int`):
            The number of parameters in the tensor, i.e. the product of all dimensions in `shape`. It is derived in
            `__post_init__` and is not accepted by the constructor.
    """

    dtype: DTYPE_T
    shape: List[int]
    data_offsets: Tuple[int, int]
    parameter_count: int = field(init=False)

    def __post_init__(self) -> None:
        """Compute `parameter_count` as the product of the dimensions in `shape`."""
        try:
            # Seeding the product with 1 makes an empty shape evaluate to 1 on the normal path, instead of
            # relying on `reduce` raising `TypeError` for an empty sequence without an initial value.
            self.parameter_count = functools.reduce(operator.mul, self.shape, 1)
        except TypeError:
            # Kept for backward compatibility: a shape that cannot be multiplied (e.g. `None` or non-numeric
            # entries) still falls back to a count of 1 rather than raising.
            self.parameter_count = 1
|
|
|
|
|
@dataclass
class SafetensorsFileMetadata:
    """Metadata of a single Safetensors file hosted on the Hub.

    Instances of this class are returned by [`parse_safetensors_file_metadata`].

    For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        metadata (`Dict`):
            The metadata contained in the file.
        tensors (`Dict[str, TensorInfo]`):
            All tensors of the file, keyed by tensor name. Each value is a [`TensorInfo`] object describing the
            corresponding tensor.
        parameter_count (`Dict[str, int]`):
            Number of parameters per data type. Keys are data types and values are the summed parameter counts of
            the tensors of that data type. Derived in `__post_init__`; not accepted by the constructor.
    """

    metadata: Dict[str, str]
    tensors: Dict[TENSOR_NAME_T, TensorInfo]
    parameter_count: Dict[DTYPE_T, int] = field(init=False)

    def __post_init__(self) -> None:
        """Sum the parameter counts of all tensors, grouped by data type."""
        totals: Dict[DTYPE_T, int] = {}
        for info in self.tensors.values():
            # A dtype seen for the first time starts from zero.
            totals[info.dtype] = totals.get(info.dtype, 0) + info.parameter_count
        self.parameter_count = totals
|
|
|
|
|
@dataclass
class SafetensorsRepoMetadata:
    """Metadata of a Safetensors repo.

    A repo counts as a Safetensors repo when its root contains either a 'model.safetensors' weight file
    (non-sharded model) or a 'model.safetensors.index.json' index file (sharded model).

    Instances of this class are returned by [`get_safetensors_metadata`].

    For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        metadata (`Dict`, *optional*):
            The metadata contained in the 'model.safetensors.index.json' file, if it exists. Only populated for
            sharded models.
        sharded (`bool`):
            Whether the repo contains a sharded model or not.
        weight_map (`Dict[str, str]`):
            All weights of the repo. Keys are tensor names and values are the names of the files containing the
            tensors.
        files_metadata (`Dict[str, SafetensorsFileMetadata]`):
            Metadata of every file. Keys are filenames and values are the matching [`SafetensorsFileMetadata`]
            objects.
        parameter_count (`Dict[str, int]`):
            Number of parameters per data type. Keys are data types and values are the parameter counts summed
            over all files. Derived in `__post_init__`; not accepted by the constructor.
    """

    metadata: Optional[Dict]
    sharded: bool
    weight_map: Dict[TENSOR_NAME_T, FILENAME_T]
    files_metadata: Dict[FILENAME_T, SafetensorsFileMetadata]
    parameter_count: Dict[DTYPE_T, int] = field(init=False)

    def __post_init__(self) -> None:
        """Merge the per-file parameter counts into a single per-dtype total."""
        totals: Dict[DTYPE_T, int] = {}
        for per_file in self.files_metadata.values():
            for dtype, count in per_file.parameter_count.items():
                # A dtype seen for the first time starts from zero.
                totals[dtype] = totals.get(dtype, 0) + count
        self.parameter_count = totals
|
|