| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import importlib |
| | import inspect |
| | import os |
| | from collections import OrderedDict |
| | from pathlib import Path |
| | from typing import List, Optional, Union |
| |
|
| | import safetensors |
| | import torch |
| | from huggingface_hub.utils import EntryNotFoundError |
| |
|
| | from ..utils import ( |
| | SAFE_WEIGHTS_INDEX_NAME, |
| | SAFETENSORS_FILE_EXTENSION, |
| | WEIGHTS_INDEX_NAME, |
| | _add_variant, |
| | _get_model_file, |
| | is_accelerate_available, |
| | is_torch_version, |
| | logging, |
| | ) |
| |
|
| |
|
# Module-scoped logger obtained from the package's logging helpers.
logger = logging.get_logger(name=__name__)
| |
|
# Legacy class name -> {value of the config's "norm_type" -> name of the class to instantiate instead}.
_CLASS_REMAPPING_DICT = {
    "Transformer2DModel": dict(
        ada_norm_zero="DiTTransformer2DModel",
        ada_norm_single="PixArtTransformer2DModel",
    ),
}
| |
|
| |
|
| | if is_accelerate_available(): |
| | from accelerate import infer_auto_device_map |
| | from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device |
| |
|
| |
|
| | |
def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
    """
    Resolve a string device-map strategy into a concrete device map.

    Non-string ``device_map`` values are returned untouched. For a string strategy, the model's no-split module
    classes are queried and a memory budget is computed: ``get_max_memory`` for ``"sequential"``, otherwise
    ``get_balanced_memory`` (with ``low_zero`` enabled only for ``"balanced_low_0"``). The result of
    ``infer_auto_device_map`` is returned.

    Args:
        model: The model to place.
        device_map: Either a strategy name (str) or an already-resolved device map.
        max_memory: Optional user-provided memory limits, forwarded to the accelerate helpers.
        torch_dtype: Dtype forwarded to the accelerate helpers for size estimation.

    Returns:
        The resolved device map, or ``device_map`` itself when it is not a string.
    """
    if not isinstance(device_map, str):
        return device_map

    planner_kwargs = {"no_split_module_classes": model._get_no_split_modules(device_map)}

    if device_map == "sequential":
        memory_budget = get_max_memory(max_memory)
    else:
        memory_budget = get_balanced_memory(
            model,
            dtype=torch_dtype,
            low_zero=device_map == "balanced_low_0",
            max_memory=max_memory,
            **planner_kwargs,
        )

    # The planner needs both the no-split classes and the computed budget.
    planner_kwargs["max_memory"] = memory_budget
    return infer_auto_device_map(model, dtype=torch_dtype, **planner_kwargs)
| |
|
| |
|
def _fetch_remapped_cls_from_config(config, old_class):
    """
    Return the class that should actually be instantiated for ``old_class`` given ``config``.

    ``_CLASS_REMAPPING_DICT`` maps some legacy class names to replacement class names keyed by the config's
    ``"norm_type"`` value. When a replacement exists, it is fetched from the top-level package and returned.
    Otherwise ``old_class`` is returned unchanged.

    Args:
        config: Dict-like model config. Only its ``"norm_type"`` entry is consulted, and it may be absent.
        old_class: The class originally requested.

    Returns:
        The remapped class, or ``old_class`` when no remapping applies.
    """
    previous_class_name = old_class.__name__
    # Classes with no entry have nothing to remap. Indexing the result of `.get(name)` directly would
    # raise AttributeError on `None.get`.
    norm_type_to_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name, {})
    remapped_class_name = norm_type_to_class_name.get(config.get("norm_type"))

    # Details:
    # https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818
    if not remapped_class_name:
        return old_class

    # The replacement classes are looked up on the top-level package (first component of this module's name).
    diffusers_library = importlib.import_module(__name__.split(".")[0])
    remapped_class = getattr(diffusers_library, remapped_class_name)
    logger.info(
        f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type. "
        f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this"
        " DOESN'T affect the final results."
    )
    return remapped_class
| |
|
| |
|
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.

    Files whose extension equals ``SAFETENSORS_FILE_EXTENSION`` are loaded with ``safetensors.torch.load_file``.
    Everything else goes through ``torch.load``, passing ``weights_only=True`` when torch >= 1.13. Tensors are
    loaded onto CPU in both cases.

    Args:
        checkpoint_file: Path to the checkpoint file.
        variant: Not used by this function; kept for API compatibility.

    Returns:
        The loaded state dict.

    Raises:
        OSError: If the checkpoint cannot be loaded. A dedicated message is raised when the file starts with
            ``"version"`` (a git-lfs pointer file). The original loading error is kept in the exception chain.
    """
    try:
        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
        if file_extension == SAFETENSORS_FILE_EXTENSION:
            return safetensors.torch.load_file(checkpoint_file, device="cpu")
        # `weights_only` is only passed on torch >= 1.13.
        weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
        return torch.load(
            checkpoint_file,
            map_location="cpu",
            **weights_only_kwarg,
        )
    except Exception as e:
        lfs_pointer_prefix = "version"
        try:
            with open(checkpoint_file) as f:
                # Only the prefix is needed. `f.read()` without a size would read and decode the entire
                # (potentially multi-GB) checkpoint as text.
                if f.read(len(lfs_pointer_prefix)).startswith(lfs_pointer_prefix):
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    ) from e
                else:
                    # Converted into the generic OSError by the handler below; it stays in the exception chain.
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError) as inner:
            raise OSError(
                f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
            ) from inner
| |
|
| |
|
def load_model_dict_into_meta(
    model,
    state_dict: OrderedDict,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    model_name_or_path: Optional[str] = None,
) -> List[str]:
    """
    Copy the tensors of ``state_dict`` into ``model`` using accelerate's ``set_module_tensor_to_device``.

    Args:
        model: Module whose ``state_dict()`` keys and shapes define what can be loaded.
        state_dict: Mapping of parameter/buffer names to tensors.
        device: Target device; falls back to CPU when falsy.
        dtype: Target dtype; falls back to ``torch.float32`` when falsy. Only forwarded if the available
            ``set_module_tensor_to_device`` has a ``dtype`` parameter.
        model_name_or_path: Optional identifier included in the shape-mismatch error message.

    Returns:
        Keys of ``state_dict`` absent from ``model.state_dict()``. These entries are skipped, not loaded.

    Raises:
        ValueError: If a tensor's shape differs from the shape of the matching entry in the model.
    """
    device = device or torch.device("cpu")
    dtype = dtype or torch.float32

    # The installed helper may not accept `dtype`; inspect its signature once instead of per tensor.
    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())

    unexpected_keys = []
    empty_state_dict = model.state_dict()
    for param_name, param in state_dict.items():
        if param_name not in empty_state_dict:
            unexpected_keys.append(param_name)
            continue

        expected_shape = empty_state_dict[param_name].shape
        if expected_shape != param.shape:
            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
            # Report the expected *shape*; interpolating the tensor itself would print the tensor, not its shape.
            raise ValueError(
                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {expected_shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
            )

        if accepts_dtype:
            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
        else:
            set_module_tensor_to_device(model, param_name, device, value=param)
    return unexpected_keys
| |
|
| |
|
def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
    """
    Load ``state_dict`` into ``model_to_load`` module by module and return the collected error messages.

    ``_load_from_state_dict`` only handles a module's own parameters and buffers, so every submodule is visited
    depth-first (pre-order) with its dotted name prefix. Missing and unexpected keys are collected into throwaway
    lists and are not reported. Only the ``error_msgs`` list is returned.

    Args:
        model_to_load: Root ``torch.nn.Module`` to load into.
        state_dict: Mapping from fully-qualified names to tensors. It is shallow-copied before use, so the
            caller's mapping is not handed to ``_load_from_state_dict``.

    Returns:
        Error messages appended by ``_load_from_state_dict``.
    """
    # Shallow copy so that `_load_from_state_dict` can modify the mapping it receives.
    working_copy = state_dict.copy()
    collected_errors = []

    # Explicit stack instead of recursion. Children are pushed in reverse so they pop in registration order,
    # reproducing a recursive pre-order traversal (including revisiting modules registered under several names).
    pending = [(model_to_load, "")]
    while pending:
        module, prefix = pending.pop()
        module._load_from_state_dict(working_copy, prefix, {}, True, [], [], collected_errors)
        children = [(child, f"{prefix}{name}.") for name, child in module._modules.items() if child is not None]
        pending.extend(reversed(children))

    return collected_errors
| |
|
| |
|
def _fetch_index_file(
    is_local,
    pretrained_model_name_or_path,
    subfolder,
    use_safetensors,
    cache_dir,
    variant,
    force_download,
    resume_download,
    proxies,
    local_files_only,
    token,
    revision,
    user_agent,
    commit_hash,
):
    """
    Locate the sharded-checkpoint index file for a model.

    The index name is ``SAFE_WEIGHTS_INDEX_NAME`` or ``WEIGHTS_INDEX_NAME`` (depending on ``use_safetensors``),
    adjusted by ``_add_variant``.

    For a local model, the path ``pretrained_model_name_or_path / subfolder / index_name`` is returned without
    checking that it exists. Otherwise the file is fetched through ``_get_model_file``. ``None`` is returned when
    that raises ``EntryNotFoundError`` or ``EnvironmentError``.

    Returns:
        A ``pathlib.Path`` to the index file, or ``None`` if a remote index could not be obtained.
    """
    index_name = _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant)
    relative_dir = subfolder or ""

    if is_local:
        return Path(pretrained_model_name_or_path, relative_dir, index_name)

    # Repo-relative paths always use forward slashes, regardless of the host OS.
    remote_index_path = Path(relative_dir, index_name).as_posix()
    try:
        downloaded_path = _get_model_file(
            pretrained_model_name_or_path,
            weights_name=remote_index_path,
            cache_dir=cache_dir,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            commit_hash=commit_hash,
        )
    except (EntryNotFoundError, EnvironmentError):
        return None
    return Path(downloaded_path)
| |
|