File size: 4,876 Bytes
c8e7ce2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import functools
import operator
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Tuple


# Aliases that make the mapping signatures below self-describing.
FILENAME_T = str  # a filename, as used for `weight_map` values and `files_metadata` keys
TENSOR_NAME_T = str  # a tensor name, as used for `tensors` and `weight_map` keys

# Closed set of data type identifiers a tensor entry may declare.
DTYPE_T = Literal[
    "F64",
    "F32",
    "F16",
    "BF16",
    "I64",
    "I32",
    "I16",
    "I8",
    "U8",
    "BOOL",
]


class SafetensorsParsingError(Exception):
    """Error raised when the metadata of a safetensors file cannot be parsed.

    Typical causes are a file that is not a safetensors file at all, or one that does not follow the
    safetensors specification.
    """


class NotASafetensorsRepoError(Exception):
    """Error raised when a repo is not a Safetensors repo.

    A repo falls in this category when it has neither a `model.safetensors` file nor a
    `model.safetensors.index.json` file.
    """


@dataclass
class TensorInfo:
    """Description of a single tensor stored in a safetensors file.

    The safetensors format itself is documented at https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        dtype (`str`):
            Data type of the tensor, one of "F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL".
        shape (`List[int]`):
            Shape of the tensor.
        data_offsets (`Tuple[int, int]`):
            Offsets of the tensor data in the file, as a tuple `[BEGIN, END]`.
        parameter_count (`int`):
            Number of parameters in the tensor. Not accepted by the constructor: it is derived from `shape`
            once the instance has been initialized.
    """

    dtype: DTYPE_T
    shape: List[int]
    data_offsets: Tuple[int, int]
    parameter_count: int = field(init=False)

    def __post_init__(self) -> None:
        """Derive `parameter_count` as the product of all dimensions in `shape`."""
        # Product-of-elements idiom taken from https://stackoverflow.com/a/13840436
        try:
            nb_parameters = functools.reduce(operator.mul, self.shape)
        except TypeError:
            # `reduce` without an initial value raises TypeError on an empty shape:
            # a scalar value has no shape and counts as a single parameter.
            nb_parameters = 1
        self.parameter_count = nb_parameters


@dataclass
class SafetensorsFileMetadata:
    """Metadata of one Safetensors file hosted on the Hub.

    Instances of this class are returned by [`parse_safetensors_file_metadata`].

    The safetensors format itself is documented at https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        metadata (`Dict`):
            Metadata contained in the file.
        tensors (`Dict[str, TensorInfo]`):
            Map of every tensor in the file: keys are tensor names, values are the matching [`TensorInfo`]
            objects.
        parameter_count (`Dict[str, int]`):
            Number of parameters per data type: keys are data types, values are the total number of parameters
            of that data type. Not accepted by the constructor: it is derived from `tensors` once the instance
            has been initialized.
    """

    metadata: Dict[str, str]
    tensors: Dict[TENSOR_NAME_T, TensorInfo]
    parameter_count: Dict[DTYPE_T, int] = field(init=False)

    def __post_init__(self) -> None:
        """Sum the parameter counts of all tensors, grouped by data type."""
        totals: Dict[DTYPE_T, int] = {}
        for info in self.tensors.values():
            # A dtype seen for the first time starts from zero.
            totals[info.dtype] = totals.get(info.dtype, 0) + info.parameter_count
        self.parameter_count = totals


@dataclass
class SafetensorsRepoMetadata:
    """Metadata of a Safetensors repo.

    A repo counts as a Safetensors repo when its root contains either a 'model.safetensors' weight file
    (non-sharded model) or a 'model.safetensors.index.json' index file (sharded model).

    Instances of this class are returned by [`get_safetensors_metadata`].

    The safetensors format itself is documented at https://huggingface.co/docs/safetensors/index#format.

    Attributes:
        metadata (`Dict`, *optional*):
            Metadata contained in the 'model.safetensors.index.json' file, if it exists. Only populated for
            sharded models.
        sharded (`bool`):
            Whether the repo contains a sharded model or not.
        weight_map (`Dict[str, str]`):
            Map of every weight: keys are tensor names, values are the filenames of the files containing the
            tensors.
        files_metadata (`Dict[str, SafetensorsFileMetadata]`):
            Map of every file's metadata: keys are filenames, values are the matching
            [`SafetensorsFileMetadata`] objects.
        parameter_count (`Dict[str, int]`):
            Number of parameters per data type: keys are data types, values are the total number of parameters
            of that data type. Not accepted by the constructor: it is derived from `files_metadata` once the
            instance has been initialized.
    """

    metadata: Optional[Dict]
    sharded: bool
    weight_map: Dict[TENSOR_NAME_T, FILENAME_T]  # tensor name -> filename
    files_metadata: Dict[FILENAME_T, SafetensorsFileMetadata]  # filename -> metadata
    parameter_count: Dict[DTYPE_T, int] = field(init=False)

    def __post_init__(self) -> None:
        """Sum the per-dtype parameter counts of every file in `files_metadata`."""
        totals: Dict[DTYPE_T, int] = {}
        for per_file in self.files_metadata.values():
            for dtype, count in per_file.parameter_count.items():
                # A dtype seen for the first time starts from zero.
                totals[dtype] = totals.get(dtype, 0) + count
        self.parameter_count = totals