reach-vb's picture
reach-vb HF staff
87245978eac49d491b540e2a86047c183ef44b5025e4ace6bf1f58653aed56a8
c8e7ce2
raw
history blame
No virus
4.88 kB
import functools
import operator
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Tuple
# Filenames and tensor names are both represented as plain strings; the two
# aliases only exist so that mapping types read as "name -> file" etc.
FILENAME_T = TENSOR_NAME_T = str
# Closed set of dtype identifiers used in type hints (not enforced at runtime).
DTYPE_T = Literal[
    "F64",
    "F32",
    "F16",
    "BF16",
    "I64",
    "I32",
    "I16",
    "I8",
    "U8",
    "BOOL",
]
class SafetensorsParsingError(Exception):
    """Signals that the metadata of a safetensors file could not be parsed.

    Likely causes: the file is not a safetensors file at all, or its content does not follow the
    safetensors specification.
    """
class NotASafetensorsRepoError(Exception):
    """Signals that a repo does not qualify as a Safetensors repo.

    A repo qualifies only if it contains a `model.safetensors` file or a
    `model.safetensors.index.json` file; this error means neither was found.
    """
@dataclass
class TensorInfo:
    """Information about a tensor.

    For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        dtype (`str`):
            The data type of the tensor ("F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL").
        shape (`List[int]`):
            The shape of the tensor. An empty shape denotes a scalar.
        data_offsets (`Tuple[int, int]`):
            The offsets of the data in the file as a tuple `[BEGIN, END]`.
        parameter_count (`int`):
            The number of parameters in the tensor, i.e. the product of all dimensions in `shape` (1 for a
            scalar). Computed in `__post_init__`; it is not an `__init__` argument.

    Raises:
        `TypeError`: if a dimension in `shape` cannot be multiplied (e.g. it is `None`).
    """

    dtype: DTYPE_T
    shape: List[int]
    data_offsets: Tuple[int, int]
    parameter_count: int = field(init=False)

    def __post_init__(self) -> None:
        # Product of all dimensions. `self.shape or ()` treats an empty or missing shape as a scalar.
        # The initializer 1 is the product over zero dimensions, so no exception handling is needed
        # for that case. A malformed dimension therefore raises TypeError instead of being silently
        # counted as a single parameter.
        self.parameter_count = functools.reduce(operator.mul, self.shape or (), 1)
@dataclass
class SafetensorsFileMetadata:
    """Metadata of a single Safetensors file hosted on the Hub.

    Instances of this class are returned by [`parse_safetensors_file_metadata`]. The safetensors format is
    described at https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        metadata (`Dict`):
            The metadata contained in the file.
        tensors (`Dict[str, TensorInfo]`):
            Every tensor of the file, keyed by tensor name, each described by a [`TensorInfo`] object.
        parameter_count (`Dict[str, int]`):
            Total number of parameters per data type, keyed by data type. Derived from `tensors` in
            `__post_init__`; it is not an `__init__` argument.
    """

    metadata: Dict[str, str]
    tensors: Dict[TENSOR_NAME_T, TensorInfo]
    parameter_count: Dict[DTYPE_T, int] = field(init=False)

    def __post_init__(self) -> None:
        # Fold per-tensor counts into one running total per dtype; keys keep first-seen order.
        totals: Dict[DTYPE_T, int] = {}
        for info in self.tensors.values():
            totals[info.dtype] = totals.get(info.dtype, 0) + info.parameter_count
        self.parameter_count = totals
@dataclass
class SafetensorsRepoMetadata:
    """Metadata for a Safetensors repo.

    A repo is considered to be a Safetensors repo if it contains either a 'model.safetensors' weight file (non-sharded
    model) or a 'model.safetensors.index.json' index file (sharded model) at its root.

    This class is returned by [`get_safetensors_metadata`].

    For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        metadata (`Dict`, *optional*):
            The metadata contained in the 'model.safetensors.index.json' file, if it exists. Only populated for sharded
            models.
        sharded (`bool`):
            Whether the repo contains a sharded model or not.
        weight_map (`Dict[str, str]`):
            A map of all weights. Keys are tensor names and values are filenames of the files containing the tensors.
        files_metadata (`Dict[str, SafetensorsFileMetadata]`):
            A map of all files metadata. Keys are filenames and values are the metadata of the corresponding file, as
            a [`SafetensorsFileMetadata`] object.
        parameter_count (`Dict[str, int]`):
            A map of the number of parameters per data type across all files. Keys are data types and values are the
            number of parameters of that data type. Computed in `__post_init__`; it is not an `__init__` argument.
    """

    metadata: Optional[Dict]
    sharded: bool
    weight_map: Dict[TENSOR_NAME_T, FILENAME_T]  # tensor name -> filename
    files_metadata: Dict[FILENAME_T, SafetensorsFileMetadata]  # filename -> metadata
    parameter_count: Dict[DTYPE_T, int] = field(init=False)

    def __post_init__(self) -> None:
        # Each file already holds its own per-dtype totals; sum them into repo-wide totals.
        parameter_count: Dict[DTYPE_T, int] = defaultdict(int)
        for file_metadata in self.files_metadata.values():
            for dtype, nb_parameters in file_metadata.parameter_count.items():
                parameter_count[dtype] += nb_parameters
        # Expose a plain dict so missing dtypes raise KeyError instead of silently reading 0.
        self.parameter_count = dict(parameter_count)