File size: 4,876 Bytes
c8e7ce2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import functools
import operator
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Tuple
# Both aliases are plain `str`; they only name the role of the string in the annotations below
# (a file name vs. a tensor name).
FILENAME_T = str
TENSOR_NAME_T = str

# Closed set of dtype identifiers used to annotate a tensor's data type.
DTYPE_T = Literal[
    "F64",
    "F32",
    "F16",
    "BF16",
    "I64",
    "I32",
    "I16",
    "I8",
    "U8",
    "BOOL",
]
class SafetensorsParsingError(Exception):
    """Error signalling that the metadata of a safetensors file could not be parsed.

    Typical causes: the file is not a safetensors file at all, or it does not respect the specification.
    """
class NotASafetensorsRepoError(Exception):
    """Error signalling that a repo is not a Safetensors repo.

    A repo falls in this category when it has neither a `model.safetensors` file nor a
    `model.safetensors.index.json` file.
    """
@dataclass
class TensorInfo:
    """Description of a single tensor.

    The safetensors format is documented at https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        dtype (`str`):
            Data type of the tensor, one of "F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL".
        shape (`List[int]`):
            Shape of the tensor.
        data_offsets (`Tuple[int, int]`):
            Offsets of the tensor data in the file, given as a tuple `[BEGIN, END]`.
        parameter_count (`int`):
            Number of parameters in the tensor. It is not an `__init__` argument: it is derived from `shape`
            right after initialization.
    """

    dtype: DTYPE_T
    shape: List[int]
    data_offsets: Tuple[int, int]
    parameter_count: int = field(init=False)

    def __post_init__(self) -> None:
        """Set `parameter_count` to the product of the entries of `shape`."""
        # Product-by-reduce idiom taken from https://stackoverflow.com/a/13840436
        try:
            nb_parameters = functools.reduce(operator.mul, self.shape)
        except TypeError:
            # `reduce` without an initial value raises TypeError on an empty shape:
            # a scalar value has no shape and counts as a single parameter.
            nb_parameters = 1
        self.parameter_count = nb_parameters
@dataclass
class SafetensorsFileMetadata:
    """Metadata of one Safetensors file hosted on the Hub.

    Instances of this class are returned by [`parse_safetensors_file_metadata`].

    The safetensors format is documented at https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        metadata (`Dict`):
            The metadata contained in the file.
        tensors (`Dict[str, TensorInfo]`):
            All tensors of the file, keyed by tensor name. Each value is a [`TensorInfo`] object describing the
            corresponding tensor.
        parameter_count (`Dict[str, int]`):
            Number of parameters per data type: keys are data types, values are the number of parameters of that
            data type. It is not an `__init__` argument: it is computed from `tensors` right after
            initialization.
    """

    metadata: Dict[str, str]
    tensors: Dict[TENSOR_NAME_T, TensorInfo]
    parameter_count: Dict[DTYPE_T, int] = field(init=False)

    def __post_init__(self) -> None:
        """Sum the `parameter_count` of every tensor, grouped by the tensor's dtype."""
        totals: Dict[DTYPE_T, int] = {}
        for info in self.tensors.values():
            # Start each newly seen dtype at 0, then accumulate.
            totals.setdefault(info.dtype, 0)
            totals[info.dtype] += info.parameter_count
        self.parameter_count = totals
@dataclass
class SafetensorsRepoMetadata:
    """Metadata of a Safetensors repo.

    A repo counts as a Safetensors repo when its root contains either a 'model.safetensors' weight file
    (non-sharded model) or a 'model.safetensors.index.json' index file (sharded model).

    Instances of this class are returned by [`get_safetensors_metadata`].

    The safetensors format is documented at https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        metadata (`Dict`, *optional*):
            The metadata contained in the 'model.safetensors.index.json' file, if it exists. Only populated for
            sharded models.
        sharded (`bool`):
            Whether the repo contains a sharded model or not.
        weight_map (`Dict[str, str]`):
            All weights of the repo: keys are tensor names, values are the names of the files containing the
            tensors.
        files_metadata (`Dict[str, SafetensorsFileMetadata]`):
            Metadata of all files: keys are filenames, values are the metadata of the corresponding file, as a
            [`SafetensorsFileMetadata`] object.
        parameter_count (`Dict[str, int]`):
            Number of parameters per data type: keys are data types, values are the number of parameters of that
            data type. It is not an `__init__` argument: it is computed from `files_metadata` right after
            initialization.
    """

    metadata: Optional[Dict]
    sharded: bool
    weight_map: Dict[TENSOR_NAME_T, FILENAME_T]  # tensor name -> filename
    files_metadata: Dict[FILENAME_T, SafetensorsFileMetadata]  # filename -> metadata
    parameter_count: Dict[DTYPE_T, int] = field(init=False)

    def __post_init__(self) -> None:
        """Merge the per-dtype parameter counts of every file into repo-wide totals."""
        totals: Dict[DTYPE_T, int] = {}
        for per_file in self.files_metadata.values():
            for dtype, count in per_file.parameter_count.items():
                # Start each newly seen dtype at 0, then accumulate.
                totals.setdefault(dtype, 0)
                totals[dtype] += count
        self.parameter_count = totals
|