# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains pytorch-specific helpers."""
import importlib
import json
import os
import re
from collections import defaultdict, namedtuple
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
from packaging import version
from .. import constants, logging
from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
logger = logging.get_logger(__file__)
if TYPE_CHECKING:
import torch
# SAVING
def save_torch_model(
model: "torch.nn.Module",
save_directory: Union[str, Path],
*,
filename_pattern: Optional[str] = None,
force_contiguous: bool = True,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
metadata: Optional[Dict[str, str]] = None,
safe_serialization: bool = True,
is_main_process: bool = True,
shared_tensors_to_discard: Optional[List[str]] = None,
):
"""
Saves a given torch model to disk, handling sharding and shared tensor issues.
See also [`save_torch_state_dict`] to save a state dict with more flexibility.
For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors).
The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are
saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard,
an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses
[`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as
safetensors (the default). Otherwise, the shards are saved as pickle.
Before saving the model, any previous shard files are removed from the `save_directory`.
<Tip warning={true}>
If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard with a size greater
than `max_shard_size`.
</Tip>
<Tip warning={true}>
If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving.
</Tip>
Args:
model (`torch.nn.Module`):
The model to save on disk.
save_directory (`str` or `Path`):
The directory in which the model will be saved.
filename_pattern (`str`, *optional*):
The pattern to generate the file names in which the model will be saved. Pattern must be a string that
can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
Defaults to `"model{suffix}.safetensors"` or `"pytorch_model{suffix}.bin"` depending on the `safe_serialization`
parameter.
force_contiguous (`boolean`, *optional*):
Forcing the state_dict to be saved as contiguous tensors. This has no effect on the correctness of the
model, but it could potentially change performance if the layout of the tensor was chosen specifically for
that reason. Defaults to `True`.
max_shard_size (`int` or `str`, *optional*):
The maximum size of each shard, in bytes. Defaults to 5GB.
metadata (`Dict[str, str]`, *optional*):
Extra information to save along with the model. Some metadata is added for each dropped tensor. This
information is not enough to recover the entire shared-tensor structure, but it might help when inspecting
the saved files.
safe_serialization (`bool`, *optional*):
Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
in a future version.
is_main_process (`bool`, *optional*):
Whether the process calling this is the main process or not. Useful when in distributed training like
TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
the main process to avoid race conditions. Defaults to True.
shared_tensors_to_discard (`List[str]`, *optional*):
List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
detected, it will drop the first name alphabetically.
Example:
```py
>>> from huggingface_hub import save_torch_model
>>> model = ... # A PyTorch model
# Save the model to "path/to/folder". It will be split into 5GB shards and saved as safetensors.
>>> save_torch_model(model, "path/to/folder")
# Load the model back
>>> from huggingface_hub import load_torch_model
>>> load_torch_model(model, "path/to/folder")
```
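If your model is a `transformers.PreTrainedModel`, shared (tied) weights can be handled as described in the tip
above. A minimal sketch (assumes `transformers` is installed and that the model defines `_tied_weights_keys`):
```py
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> save_torch_model(model, "path/to/folder", shared_tensors_to_discard=model._tied_weights_keys)
```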
"""
save_torch_state_dict(
state_dict=model.state_dict(),
filename_pattern=filename_pattern,
force_contiguous=force_contiguous,
max_shard_size=max_shard_size,
metadata=metadata,
safe_serialization=safe_serialization,
save_directory=save_directory,
is_main_process=is_main_process,
shared_tensors_to_discard=shared_tensors_to_discard,
)
def save_torch_state_dict(
state_dict: Dict[str, "torch.Tensor"],
save_directory: Union[str, Path],
*,
filename_pattern: Optional[str] = None,
force_contiguous: bool = True,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
metadata: Optional[Dict[str, str]] = None,
safe_serialization: bool = True,
is_main_process: bool = True,
shared_tensors_to_discard: Optional[List[str]] = None,
) -> None:
"""
Save a model state dictionary to disk, handling sharding and shared tensor issues.
See also [`save_torch_model`] to directly save a PyTorch model.
For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors).
The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are
saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard,
an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses
[`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as
safetensors (the default). Otherwise, the shards are saved as pickle.
Before saving the model, any previous shard files are removed from the `save_directory`.
<Tip warning={true}>
If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard with a size greater
than `max_shard_size`.
</Tip>
<Tip warning={true}>
If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving.
</Tip>
Args:
state_dict (`Dict[str, torch.Tensor]`):
The state dictionary to save.
save_directory (`str` or `Path`):
The directory in which the model will be saved.
filename_pattern (`str`, *optional*):
The pattern to generate the file names in which the model will be saved. Pattern must be a string that
can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
Defaults to `"model{suffix}.safetensors"` or `"pytorch_model{suffix}.bin"` depending on the `safe_serialization`
parameter.
force_contiguous (`boolean`, *optional*):
Forcing the state_dict to be saved as contiguous tensors. This has no effect on the correctness of the
model, but it could potentially change performance if the layout of the tensor was chosen specifically for
that reason. Defaults to `True`.
max_shard_size (`int` or `str`, *optional*):
The maximum size of each shard, in bytes. Defaults to 5GB.
metadata (`Dict[str, str]`, *optional*):
Extra information to save along with the model. Some metadata is added for each dropped tensor. This
information is not enough to recover the entire shared-tensor structure, but it might help when inspecting
the saved files.
safe_serialization (`bool`, *optional*):
Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
in a future version.
is_main_process (`bool`, *optional*):
Whether the process calling this is the main process or not. Useful when in distributed training like
TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
the main process to avoid race conditions. Defaults to True.
shared_tensors_to_discard (`List[str]`, *optional*):
List of tensor names to drop when saving shared tensors. If not provided and shared tensors are
detected, it will drop the first name alphabetically.
Example:
```py
>>> from huggingface_hub import save_torch_state_dict
>>> model = ... # A PyTorch model
>>> state_dict = model.state_dict()
# Save the state dict to "path/to/folder". It will be split into 5GB shards and saved as safetensors.
>>> save_torch_state_dict(state_dict, "path/to/folder")
```
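If the state dict contains shared (tied) tensors, you can choose which duplicate names to drop. A minimal sketch
(the tensor names below are hypothetical tied weights):
```py
>>> import torch
>>> shared_weight = torch.zeros(16, 16)
>>> state_dict = {"decoder.weight": shared_weight, "lm_head.weight": shared_weight}
>>> save_torch_state_dict(state_dict, "path/to/folder", shared_tensors_to_discard=["lm_head.weight"])
```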
"""
save_directory = str(save_directory)
if filename_pattern is None:
filename_pattern = (
constants.SAFETENSORS_WEIGHTS_FILE_PATTERN
if safe_serialization
else constants.PYTORCH_WEIGHTS_FILE_PATTERN
)
if metadata is None:
metadata = {}
if safe_serialization:
try:
from safetensors.torch import save_file as save_file_fn
except ImportError as e:
raise ImportError(
"Please install `safetensors` to use safe serialization. "
"You can install it with `pip install safetensors`."
) from e
# Clean state dict for safetensors
state_dict = _clean_state_dict_for_safetensors(
state_dict,
metadata,
force_contiguous=force_contiguous,
shared_tensors_to_discard=shared_tensors_to_discard,
)
else:
from torch import save as save_file_fn # type: ignore[assignment]
logger.warning(
"You are using unsafe serialization. Due to security reasons, it is recommended not to load "
"pickled models from untrusted sources. If you intend to share your model, we strongly recommend "
"using safe serialization by installing `safetensors` with `pip install safetensors`."
)
# Split dict
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)
# Only main process should clean up existing files to avoid race conditions in distributed environment
if is_main_process:
existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
for filename in os.listdir(save_directory):
if existing_files_regex.match(filename):
try:
logger.debug(f"Removing existing file '{filename}' from folder.")
os.remove(os.path.join(save_directory, filename))
except Exception as e:
logger.warning(
f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing..."
)
# Save each shard
per_file_metadata = {"format": "pt"}
if not state_dict_split.is_sharded:
per_file_metadata.update(metadata)
safe_file_kwargs = {"metadata": per_file_metadata} if safe_serialization else {}
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor] for tensor in tensors}
save_file_fn(shard, os.path.join(save_directory, filename), **safe_file_kwargs)
logger.debug(f"Shard saved to {filename}")
# Save the index (if any)
if state_dict_split.is_sharded:
index_path = filename_pattern.format(suffix="") + ".index.json"
index = {
"metadata": {**state_dict_split.metadata, **metadata},
"weight_map": state_dict_split.tensor_to_filename,
}
with open(os.path.join(save_directory, index_path), "w") as f:
json.dump(index, f, indent=2)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}). "
f"Model weighs have been saved in {len(state_dict_split.filename_to_tensors)} checkpoint shards. "
f"You can find where each parameters has been saved in the index located at {index_path}."
)
logger.info(f"Model weights successfully saved to {save_directory}!")
def split_torch_state_dict_into_shards(
state_dict: Dict[str, "torch.Tensor"],
*,
filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
) -> StateDictSplit:
"""
Split a model state dictionary into shards so that each shard is smaller than a given size.
The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization
made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we
have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
[6+2+2GB], [6+2GB], [6GB].
<Tip>
To save a model state dictionary to the disk, see [`save_torch_state_dict`]. This helper uses
`split_torch_state_dict_into_shards` under the hood.
</Tip>
<Tip warning={true}>
If one of the model's tensors is bigger than `max_shard_size`, it will end up in its own shard with a size greater
than `max_shard_size`.
</Tip>
Args:
state_dict (`Dict[str, torch.Tensor]`):
The state dictionary to save.
filename_pattern (`str`, *optional*):
The pattern to generate the file names in which the model will be saved. Pattern must be a string that
can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
Defaults to `"model{suffix}.safetensors"`.
max_shard_size (`int` or `str`, *optional*):
The maximum size of each shard, in bytes. Defaults to 5GB.
Returns:
[`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.
Example:
```py
>>> import json
>>> import os
>>> from typing import Dict
>>> import torch
>>> from safetensors.torch import save_file as safe_save_file
>>> from huggingface_hub import split_torch_state_dict_into_shards
>>> def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str):
... state_dict_split = split_torch_state_dict_into_shards(state_dict)
... for filename, tensors in state_dict_split.filename_to_tensors.items():
... shard = {tensor: state_dict[tensor] for tensor in tensors}
... safe_save_file(
... shard,
... os.path.join(save_directory, filename),
... metadata={"format": "pt"},
... )
... if state_dict_split.is_sharded:
... index = {
... "metadata": state_dict_split.metadata,
... "weight_map": state_dict_split.tensor_to_filename,
... }
... with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f:
... f.write(json.dumps(index, indent=2))
```
"""
return split_state_dict_into_shards_factory(
state_dict,
max_shard_size=max_shard_size,
filename_pattern=filename_pattern,
get_storage_size=get_torch_storage_size,
get_storage_id=get_torch_storage_id,
)
# LOADING
def load_torch_model(
model: "torch.nn.Module",
checkpoint_path: Union[str, os.PathLike],
*,
strict: bool = False,
safe: bool = True,
weights_only: bool = False,
map_location: Optional[Union[str, "torch.device"]] = None,
mmap: bool = False,
filename_pattern: Optional[str] = None,
) -> NamedTuple:
"""
Load a checkpoint into a model, handling both sharded and non-sharded checkpoints.
Args:
model (`torch.nn.Module`):
The model in which to load the checkpoint.
checkpoint_path (`str` or `os.PathLike`):
Path to either the checkpoint file or directory containing the checkpoint(s).
strict (`bool`, *optional*, defaults to `False`):
Whether to strictly enforce that the keys in the model state dict match the keys in the checkpoint.
safe (`bool`, *optional*, defaults to `True`):
If `True`, only safetensors files are loaded. If `False`, the function first attempts to load safetensors
files if they are available, and falls back to loading pickle files otherwise. The `filename_pattern`
parameter takes precedence over `safe`.
weights_only (`bool`, *optional*, defaults to `False`):
If `True`, unpickling is restricted to tensors, primitive types and dictionaries, which is safer when
loading untrusted checkpoints. Only supported in PyTorch >= 1.13.
map_location (`str` or `torch.device`, *optional*):
A `torch.device` object, string or a dict specifying how to remap storage locations. It
indicates the location where all tensors should be loaded.
mmap (`bool`, *optional*, defaults to `False`):
Whether to use memory-mapped file loading. Memory mapping can improve loading performance
for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints.
filename_pattern (`str`, *optional*):
The pattern to look for the index file. Pattern must be a string that
can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
Defaults to `"model{suffix}.safetensors"`.
Returns:
`NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields.
- `missing_keys` is a list of str containing the missing keys, i.e. keys that are in the model but not in the checkpoint.
- `unexpected_keys` is a list of str containing the unexpected keys, i.e. keys that are in the checkpoint but not in the model.
Raises:
[`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
If the checkpoint file or directory does not exist.
[`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively.
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
If the checkpoint path is invalid or if the checkpoint format cannot be determined.
Example:
```python
>>> from huggingface_hub import load_torch_model
>>> model = ... # A PyTorch model
>>> load_torch_model(model, "path/to/checkpoint")
```
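The returned value can be inspected to check which keys were missing or unexpected (illustrative):
```python
>>> result = load_torch_model(model, "path/to/checkpoint")
>>> missing, unexpected = result.missing_keys, result.unexpected_keys
>>> # `missing` lists keys in the model but not in the checkpoint; `unexpected` lists the opposite
```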
"""
checkpoint_path = Path(checkpoint_path)
if not checkpoint_path.exists():
raise ValueError(f"Checkpoint path {checkpoint_path} does not exist")
# 1. Check if checkpoint is a single file
if checkpoint_path.is_file():
state_dict = load_state_dict_from_file(
checkpoint_file=checkpoint_path,
map_location=map_location,
weights_only=weights_only,
)
return model.load_state_dict(state_dict, strict=strict)
# 2. If not, checkpoint_path is a directory
if filename_pattern is None:
filename_pattern = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN
index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")
# Only fallback to pickle format if safetensors index is not found and safe is False.
if not index_path.is_file() and not safe:
filename_pattern = constants.PYTORCH_WEIGHTS_FILE_PATTERN
index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json")
if index_path.is_file():
return _load_sharded_checkpoint(
model=model,
save_directory=checkpoint_path,
strict=strict,
weights_only=weights_only,
filename_pattern=filename_pattern,
)
# Look for single model file
model_files = list(checkpoint_path.glob("*.safetensors" if safe else "*.bin"))
if len(model_files) == 1:
state_dict = load_state_dict_from_file(
checkpoint_file=model_files[0],
map_location=map_location,
weights_only=weights_only,
mmap=mmap,
)
return model.load_state_dict(state_dict, strict=strict)
raise ValueError(
f"Directory '{checkpoint_path}' does not contain a valid checkpoint. "
"Expected either a sharded checkpoint with an index file, or a single model file."
)
def _load_sharded_checkpoint(
model: "torch.nn.Module",
save_directory: os.PathLike,
*,
strict: bool = False,
weights_only: bool = False,
filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN,
) -> NamedTuple:
"""
Loads a sharded checkpoint into a model. This is the same as
[`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
but for a sharded checkpoint. Each shard is loaded one by one and removed from memory after being loaded into the model.
Args:
model (`torch.nn.Module`):
The model in which to load the checkpoint.
save_directory (`str` or `os.PathLike`):
A path to a folder containing the sharded checkpoint.
strict (`bool`, *optional*, defaults to `False`):
Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
weights_only (`bool`, *optional*, defaults to `False`):
If `True`, unpickling is restricted to tensors, primitive types and dictionaries, which is safer when
loading untrusted checkpoints. Only supported in PyTorch >= 1.13.
filename_pattern (`str`, *optional*, defaults to `"model{suffix}.safetensors"`):
The pattern to look for the index file. Pattern must be a string that
can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`.
Defaults to `"model{suffix}.safetensors"`.
Returns:
`NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields:
- `missing_keys` is a list of str containing the missing keys
- `unexpected_keys` is a list of str containing the unexpected keys
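Example (illustrative; assumes `path/to/folder` contains a sharded safetensors checkpoint together with its
`model.safetensors.index.json` index file):
```py
>>> model = ...  # A PyTorch model
>>> result = _load_sharded_checkpoint(model, "path/to/folder")
>>> missing, unexpected = result.missing_keys, result.unexpected_keys
```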
"""
# 1. Load and validate index file
# The index file contains mapping of parameter names to shard files
index_path = filename_pattern.format(suffix="") + ".index.json"
index_file = os.path.join(save_directory, index_path)
with open(index_file, "r", encoding="utf-8") as f:
index = json.load(f)
# 2. Validate keys if in strict mode
# This is done before loading any shards to fail fast
if strict:
_validate_keys_for_strict_loading(model, index["weight_map"].keys())
# 3. Load each shard using `load_state_dict`
# Get unique shard files (multiple parameters can be in same shard)
shard_files = list(set(index["weight_map"].values()))
for shard_file in shard_files:
# Load shard into memory
shard_path = os.path.join(save_directory, shard_file)
state_dict = load_state_dict_from_file(
shard_path,
map_location="cpu",
weights_only=weights_only,
)
# Update model with parameters from this shard. Each shard only contains a subset of the keys, so we load
# with `strict=False` here; strict key validation (if requested) was already performed above on the full key set.
model.load_state_dict(state_dict, strict=False)
# Explicitly remove the state dict from memory
del state_dict
# 4. Return compatibility info
loaded_keys = set(index["weight_map"].keys())
model_keys = set(model.state_dict().keys())
return _IncompatibleKeys(
missing_keys=list(model_keys - loaded_keys), unexpected_keys=list(loaded_keys - model_keys)
)
def load_state_dict_from_file(
checkpoint_file: Union[str, os.PathLike],
map_location: Optional[Union[str, "torch.device"]] = None,
weights_only: bool = False,
mmap: bool = False,
) -> Union[Dict[str, "torch.Tensor"], Any]:
"""
Loads a checkpoint file, handling both safetensors and pickle checkpoint formats.
Args:
checkpoint_file (`str` or `os.PathLike`):
Path to the checkpoint file to load. Can be either a safetensors or pickle (`.bin`) checkpoint.
map_location (`str` or `torch.device`, *optional*):
A `torch.device` object, string or a dict specifying how to remap storage locations. It
indicates the location where all tensors should be loaded.
weights_only (`bool`, *optional*, defaults to `False`):
If `True`, unpickling is restricted to tensors, primitive types and dictionaries, which is safer when
loading untrusted checkpoints. Only supported for pickle (`.bin`) checkpoints with PyTorch >= 1.13. Has no
effect when loading safetensors files.
mmap (`bool`, *optional*, defaults to `False`):
Whether to use memory-mapped file loading. Memory mapping can improve loading performance
for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints. Has no effect when
loading safetensors files, as the `safetensors` library uses memory mapping by default.
Returns:
`Union[Dict[str, "torch.Tensor"], Any]`: The loaded checkpoint.
- For safetensors files: always returns a dictionary mapping parameter names to tensors.
- For pickle files: returns any Python object that was pickled (commonly a state dict, but could be
an entire model, optimizer state, or any other Python object).
Raises:
[`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError)
If the checkpoint file does not exist.
[`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively.
[`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
If the checkpoint file format is invalid or if git-lfs files are not properly downloaded.
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
If the checkpoint file path is empty or invalid.
Example:
```python
>>> from huggingface_hub import load_state_dict_from_file
# Load a PyTorch checkpoint
>>> state_dict = load_state_dict_from_file("path/to/model.bin", map_location="cpu")
>>> model.load_state_dict(state_dict)
# Load a safetensors checkpoint
>>> state_dict = load_state_dict_from_file("path/to/model.safetensors")
>>> model.load_state_dict(state_dict)
```
"""
checkpoint_path = Path(checkpoint_file)
# Check if file exists and is a regular file (not a directory)
if not checkpoint_path.is_file():
raise FileNotFoundError(
f"No checkpoint file found at '{checkpoint_path}'. Please verify the path is correct and "
"the file has been properly downloaded."
)
# Load safetensors checkpoint
if checkpoint_path.suffix == ".safetensors":
try:
from safetensors import safe_open
from safetensors.torch import load_file
except ImportError as e:
raise ImportError(
"Please install `safetensors` to load safetensors checkpoint. "
"You can install it with `pip install safetensors`."
) from e
# Check format of the archive
with safe_open(checkpoint_file, framework="pt") as f: # type: ignore[attr-defined]
metadata = f.metadata()
# see comment: https://github.com/huggingface/transformers/blob/3d213b57fe74302e5902d68ed9478c3ad1aaa713/src/transformers/modeling_utils.py#L3966
if metadata is not None and metadata.get("format") not in ["pt", "mlx"]:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_torch_model` method."
)
device = str(map_location.type) if map_location is not None and hasattr(map_location, "type") else map_location
# meta device is not supported with safetensors, falling back to CPU
if device == "meta":
logger.warning("Meta device is not supported with safetensors. Falling back to CPU device.")
device = "cpu"
return load_file(checkpoint_file, device=device) # type: ignore[arg-type]
# Otherwise, load from pickle
try:
import torch
from torch import load
except ImportError as e:
raise ImportError(
"Please install `torch` to load torch tensors. You can install it with `pip install torch`."
) from e
# Add additional kwargs, mmap is only supported in torch >= 2.1.0
additional_kwargs = {}
if version.parse(torch.__version__) >= version.parse("2.1.0"):
additional_kwargs["mmap"] = mmap
# weights_only is only supported in torch >= 1.13.0
if version.parse(torch.__version__) >= version.parse("1.13.0"):
additional_kwargs["weights_only"] = weights_only
return load(
checkpoint_file,
map_location=map_location,
**additional_kwargs,
)
# HELPERS
def _validate_keys_for_strict_loading(
model: "torch.nn.Module",
loaded_keys: Iterable[str],
) -> None:
"""
Validate that model keys match loaded keys when strict loading is enabled.
Args:
model: The PyTorch model being loaded
loaded_keys: The keys present in the checkpoint
Raises:
RuntimeError: If there are missing or unexpected keys in strict mode
"""
loaded_keys_set = set(loaded_keys)
model_keys = set(model.state_dict().keys())
missing_keys = model_keys - loaded_keys_set # Keys in model but not in checkpoint
unexpected_keys = loaded_keys_set - model_keys # Keys in checkpoint but not in model
if missing_keys or unexpected_keys:
error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
if missing_keys:
str_missing_keys = ",".join([f'"{k}"' for k in sorted(missing_keys)])
error_message += f"\nMissing key(s): {str_missing_keys}."
if unexpected_keys:
str_unexpected_keys = ",".join([f'"{k}"' for k in sorted(unexpected_keys)])
error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
raise RuntimeError(error_message)
def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
"""Returns a unique id for plain tensor
or a (potentially nested) Tuple of unique id for the flattened Tensor
if the input is a wrapper tensor subclass Tensor
"""
try:
# for torch 2.1 and above we can also handle tensor subclasses
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
if is_traceable_wrapper_subclass(tensor):
attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined]
return tuple(_get_unique_id(getattr(tensor, attr)) for attr in attrs)
except ImportError:
# for torch versions below 2.1, fall back to the original implementation
pass
if tensor.device.type == "xla" and is_torch_tpu_available():
# NOTE: XLA tensors don't have storage
# use some other unique id to distinguish.
# this is a XLA tensor, it must be created using torch_xla's
# device. So the following import is safe:
import torch_xla # type: ignore[import]
unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
else:
unique_id = storage_ptr(tensor)
return unique_id
def get_torch_storage_id(tensor: "torch.Tensor") -> Optional[Tuple["torch.device", Union[int, Tuple[Any, ...]], int]]:
"""
Return unique identifier to a tensor storage.
Multiple different tensors can share the same underlying storage. This identifier is
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id.
In the case of meta tensors, we return None since we can't tell if they share the same storage.
Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.
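Example (illustrative; assumes `torch` is installed):
```py
>>> import torch
>>> weight = torch.zeros(10)
>>> view = weight[1:]  # a view sharing the same underlying storage
>>> get_torch_storage_id(weight) == get_torch_storage_id(view)
True
>>> get_torch_storage_id(weight) == get_torch_storage_id(torch.zeros(10))
False
```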
"""
if tensor.device.type == "meta":
return None
else:
return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor)
def get_torch_storage_size(tensor: "torch.Tensor") -> int:
"""
Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59
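Example (illustrative; the function returns the number of bytes occupied by the tensor's underlying storage):
```py
>>> import torch
>>> get_torch_storage_size(torch.zeros(10, dtype=torch.float32))
40
>>> get_torch_storage_size(torch.zeros(10, dtype=torch.float64))
80
```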
"""
try:
# for torch 2.1 and above we can also handle tensor subclasses
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
if is_traceable_wrapper_subclass(tensor):
attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined]
return sum(get_torch_storage_size(getattr(tensor, attr)) for attr in attrs)
except ImportError:
# for torch versions below 2.1, fall back to the original implementation
pass
try:
return tensor.untyped_storage().nbytes()
except AttributeError:
# Fallback for torch==1.10
try:
return tensor.storage().size() * _get_dtype_size(tensor.dtype)
except NotImplementedError:
# Fallback for meta storage
# On torch >=2.0 this is the tensor size
return tensor.nelement() * _get_dtype_size(tensor.dtype)
@lru_cache()
def is_torch_tpu_available(check_device=True):
"""
Checks if `torch_xla` is installed and potentially if a TPU is in the environment
Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/utils/import_utils.py#L463.
"""
if importlib.util.find_spec("torch_xla") is not None:
if check_device:
# We need to check if `xla_device` can be found, will raise a RuntimeError if not
try:
import torch_xla.core.xla_model as xm # type: ignore[import]
_ = xm.xla_device()
return True
except RuntimeError:
return False
return True
return False
def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
"""
Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11.
"""
try:
# for torch 2.1 and above we can also handle tensor subclasses
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
if is_traceable_wrapper_subclass(tensor):
return _get_unique_id(tensor) # type: ignore
except ImportError:
# for torch versions below 2.1, fall back to the original implementation
pass
try:
return tensor.untyped_storage().data_ptr()
except Exception:
# Fallback for torch==1.10
try:
return tensor.storage().data_ptr()
except NotImplementedError:
# Fallback for meta storage
return 0
def _clean_state_dict_for_safetensors(
state_dict: Dict[str, "torch.Tensor"],
metadata: Dict[str, str],
force_contiguous: bool = True,
shared_tensors_to_discard: Optional[List[str]] = None,
):
"""Remove shared tensors from state_dict and update metadata accordingly (for reloading).
Warning: `state_dict` and `metadata` are mutated in-place!
Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L155.
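Example (illustrative; `a.weight` and `b.weight` are hypothetical names pointing to the same storage):
```py
>>> import torch
>>> weight = torch.zeros(4, 4)
>>> state_dict = {"a.weight": weight, "b.weight": weight}
>>> metadata = {}
>>> cleaned = _clean_state_dict_for_safetensors(state_dict, metadata)
>>> sorted(cleaned)  # "b.weight" has been dropped, "a.weight" is kept
['a.weight']
>>> metadata  # records which kept name replaces the dropped one
{'b.weight': 'a.weight'}
```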
"""
to_removes = _remove_duplicate_names(state_dict, discard_names=shared_tensors_to_discard)
for kept_name, to_remove_group in to_removes.items():
for to_remove in to_remove_group:
if metadata is None:
metadata = {}
if to_remove not in metadata:
# Do not override user data
metadata[to_remove] = kept_name
del state_dict[to_remove]
if force_contiguous:
state_dict = {k: v.contiguous() for k, v in state_dict.items()}
return state_dict
def _end_ptr(tensor: "torch.Tensor") -> int:
"""
Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L23.
"""
if tensor.nelement():
stop = tensor.view(-1)[-1].data_ptr() + _get_dtype_size(tensor.dtype)
else:
stop = tensor.data_ptr()
return stop
def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]:
"""
Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L44
"""
filtered_tensors = []
for shared in tensors:
if len(shared) < 2:
filtered_tensors.append(shared)
continue
areas = []
for name in shared:
tensor = state_dict[name]
areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
areas.sort()
_, last_stop, last_name = areas[0]
filtered_tensors.append({last_name})
for start, stop, name in areas[1:]:
if start >= last_stop:
filtered_tensors.append({name})
else:
filtered_tensors[-1].add(name)
last_stop = stop
return filtered_tensors
def _find_shared_tensors(state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]:
"""
Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L69.
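Example (illustrative; "a", "b" and "c" are hypothetical tensor names):
```py
>>> import torch
>>> weight = torch.zeros(4, 4)
>>> groups = _find_shared_tensors({"a": weight, "b": weight, "c": torch.zeros(2)})  # [{'a', 'b'}, {'c'}]
```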
"""
import torch
tensors_dict = defaultdict(set)
for k, v in state_dict.items():
if v.device != torch.device("meta") and storage_ptr(v) != 0 and get_torch_storage_size(v) != 0:
# Need to add device as key because of multiple GPU.
tensors_dict[(v.device, storage_ptr(v), get_torch_storage_size(v))].add(k)
tensors = list(sorted(tensors_dict.values()))
tensors = _filter_shared_not_shared(tensors, state_dict)
return tensors
def _is_complete(tensor: "torch.Tensor") -> bool:
"""
Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80
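Example (illustrative; assumes `torch` is installed):
```py
>>> import torch
>>> weight = torch.zeros(10)
>>> _is_complete(weight)
True
>>> _is_complete(weight[2:])  # a view covering only part of the storage
False
```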
"""
try:
# for torch 2.1 and above we can also handle tensor subclasses
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
if is_traceable_wrapper_subclass(tensor):
attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined]
return all(_is_complete(getattr(tensor, attr)) for attr in attrs)
except ImportError:
# for torch versions below 2.1, fall back to the original implementation
pass
return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _get_dtype_size(
tensor.dtype
) == get_torch_storage_size(tensor)
def _remove_duplicate_names(
state_dict: Dict[str, "torch.Tensor"],
*,
preferred_names: Optional[List[str]] = None,
discard_names: Optional[List[str]] = None,
) -> Dict[str, List[str]]:
"""
Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80
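Example (illustrative; `a.weight` and `b.weight` are hypothetical names pointing to the same storage):
```py
>>> import torch
>>> weight = torch.zeros(4, 4)
>>> # By default, the alphabetically first complete name is kept and the others are scheduled for removal
>>> dict(_remove_duplicate_names({"a.weight": weight, "b.weight": weight}))
{'a.weight': ['b.weight']}
>>> # `discard_names` forces specific names to be dropped instead
>>> dict(_remove_duplicate_names({"a.weight": weight, "b.weight": weight}, discard_names=["a.weight"]))
{'b.weight': ['a.weight']}
```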
"""
if preferred_names is None:
preferred_names = []
unique_preferred_names = set(preferred_names)
if discard_names is None:
discard_names = []
unique_discard_names = set(discard_names)
shareds = _find_shared_tensors(state_dict)
to_remove = defaultdict(list)
for shared in shareds:
complete_names = set([name for name in shared if _is_complete(state_dict[name])])
if not complete_names:
raise RuntimeError(
"Error while trying to find names to remove to save state dict, but found no suitable name to keep"
f" for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model"
" since you could be storing much more memory than needed. Please refer to"
" https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an"
" issue."
)
keep_name = sorted(list(complete_names))[0]
# Mechanism to preferentially select keys to keep
# coming from the on-disk file to allow
# loading models saved with a different choice
# of keep_name
preferred = complete_names.difference(unique_discard_names)
if preferred:
keep_name = sorted(list(preferred))[0]
if unique_preferred_names:
preferred = unique_preferred_names.intersection(complete_names)
if preferred:
keep_name = sorted(list(preferred))[0]
for name in sorted(shared):
if name != keep_name:
to_remove[keep_name].append(name)
return to_remove
@lru_cache()
def _get_dtype_size(dtype: "torch.dtype") -> int:
"""
Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L344
"""
import torch
# torch.float8 formats require 2.1; we do not support these dtypes on earlier versions
_float8_e4m3fn = getattr(torch, "float8_e4m3fn", None)
_float8_e5m2 = getattr(torch, "float8_e5m2", None)
_SIZE = {
torch.int64: 8,
torch.float32: 4,
torch.int32: 4,
torch.bfloat16: 2,
torch.float16: 2,
torch.int16: 2,
torch.uint8: 1,
torch.int8: 1,
torch.bool: 1,
torch.float64: 8,
_float8_e4m3fn: 1,
_float8_e5m2: 1,
}
return _SIZE[dtype]
class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])):
"""
This is used to report missing and unexpected keys in the state dict.
Taken from https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L52.
"""
def __repr__(self) -> str:
if not self.missing_keys and not self.unexpected_keys:
return "<All keys matched successfully>"
return super().__repr__()
__str__ = __repr__