|
|
| import contextlib |
| import gc |
| import inspect |
| import json |
| import logging |
| import os |
| import re |
| import shutil |
| import tempfile |
| import warnings |
| from collections import OrderedDict, defaultdict |
| from typing import Optional, Union |
|
|
| import torch |
| from torch import distributed as dist |
| from torch import nn |
|
|
| from ..state import AcceleratorState |
| from .constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME |
| from .dataclasses import AutocastKwargs, CustomDtype, DistributedType |
| from .imports import ( |
| is_hpu_available, |
| is_mlu_available, |
| is_mps_available, |
| is_musa_available, |
| is_npu_available, |
| is_peft_available, |
| is_sdaa_available, |
| is_torch_xla_available, |
| is_xpu_available, |
| ) |
| from .memory import clear_device_cache, get_xpu_available_memory |
| from .offload import load_offloaded_weight, offload_weight, save_offload_index |
| from .tqdm import is_tqdm_available, tqdm |
| from .versions import is_torch_version |
|
|
|
|
| if is_npu_available(check_device=False): |
| import torch_npu |
|
|
| if is_mlu_available(check_device=False): |
| import torch_mlu |
|
|
| if is_sdaa_available(check_device=False): |
| import torch_sdaa |
|
|
| if is_musa_available(check_device=False): |
| import torch_musa |
|
|
| from safetensors import safe_open |
| from safetensors.torch import load_file as safe_load_file |
|
|
|
|
| WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| def is_peft_model(model): |
| from .other import extract_model_from_parallel |
|
|
| if is_peft_available(): |
| from peft import PeftModel |
|
|
| return is_peft_available() and isinstance(extract_model_from_parallel(model), PeftModel) |
|
|
|
|
| def check_device_same(first_device, second_device): |
| """ |
| Utility method to check if two `torch` devices refer to the same device. When dealing with CUDA devices, torch returns `False` |
| for `torch.device("cuda") == torch.device("cuda:0")` even though they should be considered the same device. |
| |
| Args: |
| first_device (`torch.device`): |
| First device to check |
| second_device (`torch.device`): |
| Second device to check |
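| |
| Example (illustrative): |
| |
| ```py |
| >>> import torch |
| |
| >>> check_device_same(torch.device("cuda"), torch.device("cuda:0")) |
| True |
| >>> check_device_same(torch.device("cpu"), torch.device("cuda:0")) |
| False |
| ``` |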
| """ |
| if first_device.type != second_device.type: |
| return False |
|
|
| if first_device.type != "cpu" and first_device.index is None: |
| |
| |
| first_device = torch.device(first_device.type, index=0) |
|
|
| if second_device.type != "cpu" and second_device.index is None: |
| |
| |
| second_device = torch.device(second_device.type, index=0) |
|
|
| return first_device == second_device |
|
|
|
|
| def convert_file_size_to_int(size: Union[int, str]): |
| """ |
| Converts a size expressed as a string with digits and a unit (like `"5MB"`) to an integer (in bytes). |
| |
| Args: |
| size (`int` or `str`): The size to convert. Will be directly returned if an `int`. |
| |
| Example: |
| |
| ```py |
| >>> convert_file_size_to_int("1MiB") |
| 1048576 |
| ``` |
| """ |
| mem_size = -1 |
| err_msg = ( |
| f"`size` {size} is not in a valid format. Use an integer for bytes, or a string with an unit (like '5.0GB')." |
| ) |
| try: |
| if isinstance(size, int): |
| mem_size = size |
| elif size.upper().endswith("GIB"): |
| mem_size = int(float(size[:-3]) * (2**30)) |
| elif size.upper().endswith("MIB"): |
| mem_size = int(float(size[:-3]) * (2**20)) |
| elif size.upper().endswith("KIB"): |
| mem_size = int(float(size[:-3]) * (2**10)) |
| elif size.upper().endswith("GB"): |
| int_size = int(float(size[:-2]) * (10**9)) |
| mem_size = int_size // 8 if size.endswith("b") else int_size |
| elif size.upper().endswith("MB"): |
| int_size = int(float(size[:-2]) * (10**6)) |
| mem_size = int_size // 8 if size.endswith("b") else int_size |
| elif size.upper().endswith("KB"): |
| int_size = int(float(size[:-2]) * (10**3)) |
| mem_size = int_size // 8 if size.endswith("b") else int_size |
| except ValueError: |
| raise ValueError(err_msg) |
|
|
| if mem_size < 0: |
| raise ValueError(err_msg) |
| return mem_size |
|
|
|
|
| def dtype_byte_size(dtype: torch.dtype): |
| """ |
| Returns the size (in bytes) occupied by one parameter of type `dtype`. |
| |
| Example: |
| |
| ```py |
| >>> dtype_byte_size(torch.float32) |
| 4 |
| ``` |
| """ |
| if dtype == torch.bool: |
| return 1 / 8 |
| elif dtype == CustomDtype.INT2: |
| return 1 / 4 |
| elif dtype == CustomDtype.INT4: |
| return 1 / 2 |
| elif dtype == CustomDtype.FP8: |
| return 1 |
| elif is_torch_version(">=", "2.1.0") and dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: |
| return 1 |
| bit_search = re.search(r"[^\d](\d+)$", str(dtype)) |
| if bit_search is None: |
| raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") |
| bit_size = int(bit_search.groups()[0]) |
| return bit_size // 8 |
|
|
|
|
| def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]: |
| """ |
| Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For |
| example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is |
| guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with |
| non-overlapping lifetimes may have the same id. |
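| |
| Example (illustrative): |
| |
| ```py |
| >>> import torch |
| |
| >>> t = torch.zeros(4) |
| >>> # A view shares the underlying storage, so it maps to the same identifier |
| >>> id_tensor_storage(t) == id_tensor_storage(t[:2]) |
| True |
| ``` |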
| """ |
| _SIZE = { |
| torch.int64: 8, |
| torch.float32: 4, |
| torch.int32: 4, |
| torch.bfloat16: 2, |
| torch.float16: 2, |
| torch.int16: 2, |
| torch.uint8: 1, |
| torch.int8: 1, |
| torch.bool: 1, |
| torch.float64: 8, |
| } |
| try: |
| storage_ptr = tensor.untyped_storage().data_ptr() |
| storage_size = tensor.untyped_storage().nbytes() |
| except Exception: |
| try: |
| |
| storage_ptr = tensor.storage().data_ptr() |
| storage_size = tensor.storage().size() * _SIZE[tensor.dtype] |
| except NotImplementedError: |
| |
| storage_ptr = 0 |
| |
| storage_size = tensor.nelement() * _SIZE[tensor.dtype] |
|
|
| return tensor.device, storage_ptr, storage_size |
|
|
|
|
| def set_module_tensor_to_device( |
| module: nn.Module, |
| tensor_name: str, |
| device: Union[int, str, torch.device], |
| value: Optional[torch.Tensor] = None, |
| dtype: Optional[Union[str, torch.dtype]] = None, |
| fp16_statistics: Optional[torch.HalfTensor] = None, |
| tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None, |
| non_blocking: bool = False, |
| clear_cache: bool = True, |
| ): |
| """ |
| A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing |
| `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). |
| |
| Args: |
| module (`torch.nn.Module`): |
| The module in which the tensor we want to move lives. |
| tensor_name (`str`): |
| The full name of the parameter/buffer. |
| device (`int`, `str` or `torch.device`): |
| The device on which to set the tensor. |
| value (`torch.Tensor`, *optional*): |
| The value of the tensor (useful when going from the meta device to any other device). |
| dtype (`torch.dtype`, *optional*): |
| If passed along, the value of the parameter will be cast to this `dtype`. Otherwise, `value` will be cast to |
| the dtype of the existing parameter in the model. |
| fp16_statistics (`torch.HalfTensor`, *optional*): |
| The list of fp16 statistics to set on the module, used for 8 bit model serialization. |
| tied_params_map (Dict[int, Dict[torch.device, torch.Tensor]], *optional*, defaults to `None`): |
| A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given |
| execution device, this parameter is useful to reuse the first available pointer of a shared weight on the |
| device for all others, instead of duplicating memory. |
| non_blocking (`bool`, *optional*, defaults to `False`): |
| If `True`, the device transfer will be asynchronous with respect to the host, if possible. |
| clear_cache (`bool`, *optional*, defaults to `True`): |
| Whether or not to clear the device cache after setting the tensor on the device. |
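| |
| Example (a minimal sketch on CPU, using a toy module): |
| |
| ```py |
| >>> import torch |
| >>> from torch import nn |
| |
| >>> linear = nn.Linear(4, 4) |
| >>> # Overwrite the weight in place with zeros while keeping it registered as a parameter |
| >>> set_module_tensor_to_device(linear, "weight", "cpu", value=torch.zeros(4, 4)) |
| >>> linear.weight.sum().item() |
| 0.0 |
| ``` |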
| """ |
| |
| if "." in tensor_name: |
| splits = tensor_name.split(".") |
| for split in splits[:-1]: |
| new_module = getattr(module, split) |
| if new_module is None: |
| raise ValueError(f"{module} has no attribute {split}.") |
| module = new_module |
| tensor_name = splits[-1] |
|
|
| if tensor_name not in module._parameters and tensor_name not in module._buffers: |
| raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") |
| is_buffer = tensor_name in module._buffers |
| old_value = getattr(module, tensor_name) |
|
|
| |
| |
| if ( |
| value is not None |
| and tied_params_map is not None |
| and value.data_ptr() in tied_params_map |
| and device in tied_params_map[value.data_ptr()] |
| ): |
| module._parameters[tensor_name] = tied_params_map[value.data_ptr()][device] |
| return |
| elif ( |
| tied_params_map is not None |
| and old_value.data_ptr() in tied_params_map |
| and device in tied_params_map[old_value.data_ptr()] |
| ): |
| module._parameters[tensor_name] = tied_params_map[old_value.data_ptr()][device] |
| return |
|
|
| if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None: |
| raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.") |
|
|
| param = module._parameters[tensor_name] if tensor_name in module._parameters else None |
| param_cls = type(param) |
|
|
| if value is not None: |
| |
| |
| if old_value.shape != value.shape and param_cls.__name__ != "Params4bit": |
| raise ValueError( |
| f'Trying to set a tensor of shape {value.shape} in "{tensor_name}" (which has shape {old_value.shape}), this looks incorrect.' |
| ) |
|
|
| if dtype is None: |
| |
| value = value.to(old_value.dtype, non_blocking=non_blocking) |
| elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): |
| value = value.to(dtype, non_blocking=non_blocking) |
|
|
| device_quantization = None |
| with torch.no_grad(): |
| |
| |
| if ( |
| param is not None |
| and param.device.type not in ("cuda", "xpu") |
| and torch.device(device).type in ("cuda", "xpu") |
| and param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"] |
| ): |
| device_quantization = device |
| device = "cpu" |
| |
| if isinstance(device, int): |
| if is_npu_available(): |
| device = f"npu:{device}" |
| elif is_mlu_available(): |
| device = f"mlu:{device}" |
| elif is_sdaa_available(): |
| device = f"sdaa:{device}" |
| elif is_musa_available(): |
| device = f"musa:{device}" |
| elif is_hpu_available(): |
| device = "hpu" |
| if "xpu" in str(device) and not is_xpu_available(): |
| raise ValueError(f'{device} is not available, you should use device="cpu" instead') |
| if value is None: |
| new_value = old_value.to(device, non_blocking=non_blocking) |
| if dtype is not None and device in ["meta", torch.device("meta")]: |
| if not str(old_value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): |
| new_value = new_value.to(dtype, non_blocking=non_blocking) |
|
|
| if not is_buffer: |
| module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad) |
| elif isinstance(value, torch.Tensor): |
| new_value = value.to(device, non_blocking=non_blocking) |
| else: |
| new_value = torch.tensor(value, device=device) |
| if device_quantization is not None: |
| device = device_quantization |
| if is_buffer: |
| module._buffers[tensor_name] = new_value |
| elif value is not None or not check_device_same(torch.device(device), module._parameters[tensor_name].device): |
| param_cls = type(module._parameters[tensor_name]) |
| kwargs = module._parameters[tensor_name].__dict__ |
| if param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]: |
| if param_cls.__name__ == "Int8Params" and new_value.dtype == torch.float32: |
| |
| new_value = new_value.to(torch.float16, non_blocking=non_blocking) |
| |
| if device == "cpu" and param_cls.__name__ == "Int8Params": |
| new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(0).to("cpu") |
| new_value.CB = new_value.CB.to("cpu") |
| new_value.SCB = new_value.SCB.to("cpu") |
| else: |
| new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to( |
| device, non_blocking=non_blocking |
| ) |
| elif param_cls.__name__ in ["QTensor", "QBitsTensor"]: |
| new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to( |
| device, non_blocking=non_blocking |
| ) |
| elif param_cls.__name__ in ["AffineQuantizedTensor"]: |
| new_value = new_value.to(device, non_blocking=non_blocking) |
| else: |
| new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to( |
| device, non_blocking=non_blocking |
| ) |
|
|
| module._parameters[tensor_name] = new_value |
| if fp16_statistics is not None: |
| module._parameters[tensor_name].SCB = fp16_statistics.to(device, non_blocking=non_blocking) |
| del fp16_statistics |
| |
| if ( |
| module.__class__.__name__ == "Linear8bitLt" |
| and getattr(module.weight, "SCB", None) is None |
| and str(module.weight.device) != "meta" |
| ): |
| |
| device_index = torch.device(device).index if torch.device(device).type == "cuda" else None |
| if not getattr(module.weight, "SCB", None) and device_index is not None: |
| if module.bias is not None and module.bias.device.type != "meta": |
| |
| module = module.cuda(device_index) |
| elif module.bias is None: |
| |
| module = module.cuda(device_index) |
| elif ( |
| module.__class__.__name__ == "Linear4bit" |
| and getattr(module.weight, "quant_state", None) is None |
| and str(module.weight.device) != "meta" |
| ): |
| |
| device_index = torch.device(device).index if torch.device(device).type == "cuda" else None |
| if not getattr(module.weight, "quant_state", None) and device_index is not None: |
| module.weight = module.weight.cuda(device_index) |
|
|
| |
| if clear_cache and device not in ("cpu", "meta"): |
| clear_device_cache() |
|
|
| |
| |
| if ( |
| tied_params_map is not None |
| and old_value.data_ptr() in tied_params_map |
| and device not in tied_params_map[old_value.data_ptr()] |
| ): |
| tied_params_map[old_value.data_ptr()][device] = new_value |
| elif ( |
| value is not None |
| and tied_params_map is not None |
| and value.data_ptr() in tied_params_map |
| and device not in tied_params_map[value.data_ptr()] |
| ): |
| tied_params_map[value.data_ptr()][device] = new_value |
|
|
|
|
| def named_module_tensors( |
| module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False |
| ): |
| """ |
| A helper function that gathers all the tensors (parameters + buffers) of a given module. If `include_buffers=True` |
| it's the same as doing `module.named_parameters(recurse=recurse) + module.named_buffers(recurse=recurse)`. |
| |
| Args: |
| module (`torch.nn.Module`): |
| The module we want the tensors on. |
| include_buffers (`bool`, *optional*, defaults to `True`): |
| Whether or not to include the buffers in the result. |
| recurse (`bool`, *optional*, defaults to `False`): |
| Whether or not to go look in every submodule or just return the direct parameters and buffers. |
| remove_non_persistent (`bool`, *optional*, defaults to `False`): |
| Whether or not to remove the non-persistent buffers from the buffers. Only useful when `include_buffers=True`. |
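| |
| Example (illustrative): |
| |
| ```py |
| >>> from torch import nn |
| |
| >>> linear = nn.Linear(4, 4) |
| >>> [name for name, _ in named_module_tensors(linear)] |
| ['weight', 'bias'] |
| ``` |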
| """ |
| yield from module.named_parameters(recurse=recurse) |
|
|
| if include_buffers: |
| non_persistent_buffers = set() |
| if remove_non_persistent: |
| non_persistent_buffers = get_non_persistent_buffers(module, recurse=recurse) |
| for named_buffer in module.named_buffers(recurse=recurse): |
| name, _ = named_buffer |
| if name not in non_persistent_buffers: |
| yield named_buffer |
|
|
|
|
| def get_non_persistent_buffers(module: nn.Module, recurse: bool = False, fqns: bool = False): |
| """ |
| Gather all non-persistent buffers of a given module into a set. |
| |
| Args: |
| module (`nn.Module`): |
| The module we want the non persistent buffers on. |
| recurse (`bool`, *optional*, defaults to `False`): |
| Whether or not to go look in every submodule or just return the direct non persistent buffers. |
| fqns (`bool`, *optional*, defaults to `False`): |
| Whether or not to return the fully-qualified names of the non persistent buffers. |
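| |
| Example (illustrative, using a toy module): |
| |
| ```py |
| >>> import torch |
| >>> from torch import nn |
| |
| >>> class MyModule(nn.Module): |
| ...     def __init__(self): |
| ...         super().__init__() |
| ...         self.register_buffer("scale", torch.ones(1), persistent=False) |
| |
| >>> get_non_persistent_buffers(MyModule()) |
| {'scale'} |
| ``` |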
| """ |
|
|
| non_persistent_buffers_set = module._non_persistent_buffers_set |
| if recurse: |
| for n, m in module.named_modules(): |
| if fqns: |
| non_persistent_buffers_set |= {n + "." + b for b in m._non_persistent_buffers_set} |
| else: |
| non_persistent_buffers_set |= m._non_persistent_buffers_set |
|
|
| return non_persistent_buffers_set |
|
|
|
|
| def check_tied_parameters_in_config(model: nn.Module): |
| """ |
| Check if there is any indication in the given model that some weights should be tied. |
| |
| Args: |
| model (`torch.nn.Module`): The model to inspect |
| |
| Returns: |
| bool: True if the model needs to have tied weights |
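| |
| Example (illustrative; a plain `nn.Linear` has no tying indication): |
| |
| ```py |
| >>> from torch import nn |
| |
| >>> check_tied_parameters_in_config(nn.Linear(4, 4)) |
| False |
| ``` |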
| """ |
|
|
| |
| has_tied_word_embedding = False |
| has_tied_encoder_decoder = False |
| has_tied_module = False |
|
|
| if "PreTrainedModel" in [c.__name__ for c in inspect.getmro(model.__class__)]: |
| has_tied_word_embedding = False |
| model_decoder_config = None |
| if hasattr(model, "config"): |
| model_decoder_config = ( |
| model.config.get_text_config(decoder=True) |
| if hasattr(model.config, "get_text_config") |
| else model.config |
| ) |
| has_tied_word_embedding = ( |
| model_decoder_config is not None |
| and getattr(model_decoder_config, "tie_word_embeddings", False) |
| and model.get_output_embeddings() |
| ) |
|
|
| has_tied_encoder_decoder = ( |
| hasattr(model, "config") |
| and getattr(model.config, "is_encoder_decoder", False) |
| and getattr(model.config, "tie_encoder_decoder", False) |
| ) |
| has_tied_module = any(hasattr(module, "_tie_weights") for module in model.modules()) |
| return any([has_tied_word_embedding, has_tied_encoder_decoder, has_tied_module]) |
|
|
|
|
| def _get_param_device(param, device_map): |
| if param in device_map: |
| return device_map[param] |
| parent_param = ".".join(param.split(".")[:-1]) |
| if parent_param == param: |
| raise ValueError(f"The `device_map` does not contain the module {param}.") |
| else: |
| return _get_param_device(parent_param, device_map) |
|
|
|
|
| def check_tied_parameters_on_same_device(tied_params, device_map): |
| """ |
| Check if tied parameters are on the same device |
| |
| Args: |
| tied_params (`List[List[str]]`): |
| A list of lists of parameter names being all tied together. |
| |
| device_map (`Dict[str, Union[int, str, torch.device]]`): |
| A map that specifies where each submodule should go. |
| |
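| Example (illustrative; this logs a warning because the tied weights land on different devices): |
| |
| ```py |
| >>> device_map = {"linear1": 0, "linear2": "cpu"} |
| >>> check_tied_parameters_on_same_device([["linear1.weight", "linear2.weight"]], device_map) |
| ``` |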
| """ |
| for tie_param in tied_params: |
| tie_param_devices = {} |
| for param in tie_param: |
| tie_param_devices[param] = _get_param_device(param, device_map) |
| if len(set(tie_param_devices.values())) > 1: |
| logger.warning( |
| f"Tied parameters are on different devices: {tie_param_devices}. " |
| "Please modify your custom device map or set `device_map='auto'`. " |
| ) |
|
|
|
|
| def find_tied_parameters(model: torch.nn.Module, **kwargs) -> list[list[str]]: |
| """ |
| Find the tied parameters in a given model. |
| |
| <Tip warning={true}> |
| |
| The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore |
| them. |
| |
| </Tip> |
| |
| Args: |
| model (`torch.nn.Module`): The model to inspect. |
| |
| Returns: |
| List[List[str]]: A list of lists of parameter names being all tied together. |
| |
| Example: |
| |
| ```py |
| >>> from collections import OrderedDict |
| >>> import torch.nn as nn |
| |
| >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))])) |
| >>> model.linear2.weight = model.linear1.weight |
| >>> find_tied_parameters(model) |
| [['linear1.weight', 'linear2.weight']] |
| ``` |
| """ |
|
|
| |
| all_named_parameters = {name: param for name, param in model.named_parameters(remove_duplicate=False)} |
|
|
| |
| |
| no_duplicate_named_parameters = {name: param for name, param in model.named_parameters(remove_duplicate=True)} |
|
|
| |
| tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys()) |
|
|
| |
| |
| tied_param_groups = {} |
| for tied_param_name in tied_param_names: |
| tied_param = all_named_parameters[tied_param_name] |
| for param_name, param in no_duplicate_named_parameters.items(): |
| |
| if param is tied_param: |
| if param_name not in tied_param_groups: |
| tied_param_groups[param_name] = [] |
| tied_param_groups[param_name].append(tied_param_name) |
|
|
| return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()] |
|
|
|
|
| def retie_parameters(model, tied_params): |
| """ |
| Reties tied parameters in a given model if the link was broken (for instance when adding hooks). |
| |
| Args: |
| model (`torch.nn.Module`): |
| The model in which to retie parameters. |
| tied_params (`List[List[str]]`): |
| A list of lists of parameter names that are all tied together, as obtained by `find_tied_parameters`. |
| """ |
| for tied_group in tied_params: |
| param_to_tie = None |
| |
| for param_name in tied_group: |
| module = model |
| splits = param_name.split(".") |
| for split in splits[:-1]: |
| module = getattr(module, split) |
| param = getattr(module, splits[-1]) |
| if param_to_tie is None and param.device != torch.device("meta"): |
| param_to_tie = param |
| break |
| if param_to_tie is not None: |
| for param_name in tied_group: |
| module = model |
| splits = param_name.split(".") |
| for split in splits[:-1]: |
| module = getattr(module, split) |
| setattr(module, splits[-1], param_to_tie) |
|
|
|
|
| def _get_proper_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype: |
| """ |
| Converts `dtype` to a `torch.dtype` if it was passed as a string. |
| """ |
| if isinstance(dtype, str): |
| |
| dtype = dtype.replace("torch.", "") |
| dtype = getattr(torch, dtype) |
| return dtype |
|
|
|
|
| def compute_module_sizes( |
| model: nn.Module, |
| dtype: Optional[Union[str, torch.dtype]] = None, |
| special_dtypes: Optional[dict[str, Union[str, torch.dtype]]] = None, |
| buffers_only: bool = False, |
| ): |
| """ |
| Compute the size of each submodule of a given model. |
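| |
| Example (illustrative; sizes are in bytes): |
| |
| ```py |
| >>> from collections import OrderedDict |
| >>> from torch import nn |
| |
| >>> model = nn.Sequential(OrderedDict([("linear", nn.Linear(4, 4))])) |
| >>> sizes = compute_module_sizes(model) |
| >>> sizes["linear.weight"]  # 16 float32 parameters * 4 bytes |
| 64 |
| >>> sizes[""]  # total size of the model (weight + bias) |
| 80 |
| ``` |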
| """ |
| if dtype is not None: |
| dtype = _get_proper_dtype(dtype) |
| dtype_size = dtype_byte_size(dtype) |
| if special_dtypes is not None: |
| special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()} |
| special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()} |
| module_sizes = defaultdict(int) |
|
|
| module_list = [] |
|
|
| if not buffers_only: |
| module_list = named_module_tensors(model, recurse=True) |
| else: |
| module_list = model.named_buffers(recurse=True) |
|
|
| for name, tensor in module_list: |
| if special_dtypes is not None and name in special_dtypes: |
| size = tensor.numel() * special_dtypes_size[name] |
| elif dtype is None: |
| size = tensor.numel() * dtype_byte_size(tensor.dtype) |
| elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): |
| |
| |
| size = tensor.numel() * dtype_byte_size(tensor.dtype) |
| else: |
| size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype)) |
| name_parts = name.split(".") |
| for idx in range(len(name_parts) + 1): |
| module_sizes[".".join(name_parts[:idx])] += size |
|
|
| return module_sizes |
|
|
|
|
| def compute_module_total_buffer_size( |
| model: nn.Module, |
| dtype: Optional[Union[str, torch.dtype]] = None, |
| special_dtypes: Optional[dict[str, Union[str, torch.dtype]]] = None, |
| ): |
| """ |
| Compute the total size of buffers in each submodule of a given model. |
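| |
| Example (illustrative; sizes are in bytes): |
| |
| ```py |
| >>> from torch import nn |
| |
| >>> # BatchNorm1d holds running_mean, running_var (float32) and num_batches_tracked (int64) as buffers |
| >>> compute_module_total_buffer_size(nn.BatchNorm1d(4)) |
| 40 |
| ``` |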
| """ |
| module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes, buffers_only=True) |
| return module_sizes.get("", 0) |
|
|
|
|
| def get_max_layer_size( |
| modules: list[tuple[str, torch.nn.Module]], module_sizes: dict[str, int], no_split_module_classes: list[str] |
| ): |
| """ |
| Utility function that will scan a list of named modules and return the maximum size used by one full layer. The |
| definition of a layer being: |
| - a module with no direct children (just parameters and buffers) |
| - a module whose class name is in the list `no_split_module_classes` |
| |
| Args: |
| modules (`List[Tuple[str, torch.nn.Module]]`): |
| The list of named modules where we want to determine the maximum layer size. |
| module_sizes (`Dict[str, int]`): |
| A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`). |
| no_split_module_classes (`List[str]`): |
| A list of class names for layers we don't want to be split. |
| |
| Returns: |
| `Tuple[int, List[str]]`: The maximum size of a layer with the list of layer names realizing that maximum size. |
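| |
| Example (illustrative; the `Linear` layer is the largest leaf here): |
| |
| ```py |
| >>> from torch import nn |
| |
| >>> model = nn.Sequential(nn.Linear(8, 8), nn.ReLU()) |
| >>> sizes = compute_module_sizes(model) |
| >>> get_max_layer_size(list(model.named_children()), sizes, []) |
| (288, ['0']) |
| ``` |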
| """ |
| max_size = 0 |
| layer_names = [] |
| modules_to_treat = modules.copy() |
| while len(modules_to_treat) > 0: |
| module_name, module = modules_to_treat.pop(0) |
| modules_children = list(module.named_children()) if isinstance(module, torch.nn.Module) else [] |
| if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes: |
| |
| size = module_sizes[module_name] |
| if size > max_size: |
| max_size = size |
| layer_names = [module_name] |
| elif size == max_size: |
| layer_names.append(module_name) |
| else: |
| modules_to_treat = [(f"{module_name}.{n}", v) for n, v in modules_children] + modules_to_treat |
| return max_size, layer_names |
|
|
|
|
| def get_max_memory(max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None): |
| """ |
| Get the maximum memory available if nothing is passed, converts string to int otherwise. |
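| |
| Example (illustrative; when a dictionary is passed, only the string sizes are converted): |
| |
| ```py |
| >>> get_max_memory({0: "10GiB", "cpu": "4GiB"}) |
| {0: 10737418240, 'cpu': 4294967296} |
| ``` |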
| """ |
| import psutil |
|
|
| if max_memory is None: |
| max_memory = {} |
| |
| if is_npu_available(): |
| for i in range(torch.npu.device_count()): |
| try: |
| _ = torch.tensor(0, device=torch.device("npu", i)) |
| max_memory[i] = torch.npu.mem_get_info(i)[0] |
| except Exception: |
| logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.") |
| continue |
| elif is_mlu_available(): |
| for i in range(torch.mlu.device_count()): |
| try: |
| _ = torch.tensor(0, device=torch.device("mlu", i)) |
| max_memory[i] = torch.mlu.mem_get_info(i)[0] |
| except Exception: |
| logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.") |
| continue |
| elif is_sdaa_available(): |
| for i in range(torch.sdaa.device_count()): |
| try: |
| _ = torch.tensor(0, device=torch.device("sdaa", i)) |
| max_memory[i] = torch.sdaa.mem_get_info(i)[0] |
| except Exception: |
| logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.") |
| continue |
| elif is_musa_available(): |
| for i in range(torch.musa.device_count()): |
| try: |
| _ = torch.tensor(0, device=torch.device("musa", i)) |
| max_memory[i] = torch.musa.mem_get_info(i)[0] |
| except Exception: |
| logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.") |
| continue |
| elif is_xpu_available(): |
| for i in range(torch.xpu.device_count()): |
| try: |
| _ = torch.tensor(0, device=torch.device("xpu", i)) |
| max_memory[i] = get_xpu_available_memory(i) |
| except Exception: |
| logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.") |
| continue |
| elif is_hpu_available(): |
| for i in range(torch.hpu.device_count()): |
| try: |
| _ = torch.tensor(0, device=torch.device("hpu", i)) |
| max_memory[i] = torch.hpu.mem_get_info(i)[0] |
| except Exception: |
| logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.") |
| continue |
| else: |
| for i in range(torch.cuda.device_count()): |
| try: |
| _ = torch.tensor([0], device=i) |
| max_memory[i] = torch.cuda.mem_get_info(i)[0] |
| except Exception: |
| logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.") |
| continue |
| |
| if is_mps_available(): |
| max_memory["mps"] = psutil.virtual_memory().available |
| else: |
| max_memory["cpu"] = psutil.virtual_memory().available |
| return max_memory |
|
|
| for key in max_memory: |
| if isinstance(max_memory[key], str): |
| max_memory[key] = convert_file_size_to_int(max_memory[key]) |
|
|
| |
| |
| gpu_devices = [k for k in max_memory.keys() if isinstance(k, int)] |
| gpu_devices.sort() |
| |
| if is_npu_available(): |
| num_devices = torch.npu.device_count() |
| elif is_mlu_available(): |
| num_devices = torch.mlu.device_count() |
| elif is_sdaa_available(): |
| num_devices = torch.sdaa.device_count() |
| elif is_musa_available(): |
| num_devices = torch.musa.device_count() |
| elif is_xpu_available(): |
| num_devices = torch.xpu.device_count() |
| elif is_hpu_available(): |
| num_devices = torch.hpu.device_count() |
| else: |
| num_devices = torch.cuda.device_count() |
| for device in gpu_devices: |
| if device >= num_devices or device < 0: |
| logger.warning(f"Device {device} is not available, available devices are {list(range(num_devices))}") |
| |
| all_devices = gpu_devices + [k for k in ["mps", "cpu", "disk"] if k in max_memory.keys()] |
| |
| for k in max_memory.keys(): |
| if k not in all_devices: |
| raise ValueError( |
| f"Device {k} is not recognized, available devices are integers(for GPU/XPU), 'mps', 'cpu' and 'disk'" |
| ) |
| max_memory = {k: max_memory[k] for k in all_devices} |
|
|
| return max_memory |
|
|
|
|
| def clean_device_map(device_map: dict[str, Union[int, str, torch.device]], module_name: str = ""): |
| """ |
| Cleans a device_map by grouping all submodules that go on the same device together. |
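| |
| Example (illustrative; the two `block` submodules are grouped into a single entry): |
| |
| ```py |
| >>> clean_device_map({"lm_head": 1, "block.linear1": 0, "block.linear2": 0}) |
| {'lm_head': 1, 'block': 0} |
| ``` |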
| """ |
| |
| prefix = "" if module_name == "" else f"{module_name}." |
| values = [v for k, v in device_map.items() if k.startswith(prefix)] |
| if len(set(values)) == 1 and len(values) > 1: |
| for k in [k for k in device_map if k.startswith(prefix)]: |
| del device_map[k] |
| device_map[module_name] = values[0] |
|
|
| |
| children_modules = [k for k in device_map.keys() if k.startswith(prefix) and len(k) > len(module_name)] |
| idx = len(module_name.split(".")) + 1 if len(module_name) > 0 else 1 |
| children_modules = set(".".join(k.split(".")[:idx]) for k in children_modules) |
| for child in children_modules: |
| clean_device_map(device_map, module_name=child) |
|
|
| return device_map |
|
|
|
|
| def load_offloaded_weights(model, index, offload_folder): |
| """ |
| Loads the weights from the offload folder into the model. |
| |
| Args: |
| model (`torch.nn.Module`): |
| The model to load the weights into. |
| index (`dict`): |
| A dictionary containing the parameter name and its metadata for each parameter that was offloaded from the |
| model. |
| offload_folder (`str`): |
| The folder where the offloaded weights are stored. |
| """ |
| if index is None or len(index) == 0: |
| |
| return |
| for param_name, metadata in index.items(): |
| if "SCB" in param_name: |
| continue |
| fp16_statistics = None |
| if "weight" in param_name and param_name.replace("weight", "SCB") in index.keys(): |
| weight_name = param_name.replace("weight", "SCB") |
| fp16_statistics = load_offloaded_weight( |
| os.path.join(offload_folder, f"{weight_name}.dat"), index[weight_name] |
| ) |
| tensor_file = os.path.join(offload_folder, f"{param_name}.dat") |
| weight = load_offloaded_weight(tensor_file, metadata) |
| set_module_tensor_to_device(model, param_name, "cpu", value=weight, fp16_statistics=fp16_statistics) |
|
|
|
|
| def get_module_leaves(module_sizes): |
| module_children = {} |
| for module in module_sizes: |
| if module == "" or "." not in module: |
| continue |
| parent = module.rsplit(".", 1)[0] |
| module_children[parent] = module_children.get(parent, 0) + 1 |
| leaves = [module for module in module_sizes if module_children.get(module, 0) == 0 and module != ""] |
| return leaves |
|
|
|
|
| def get_balanced_memory( |
| model: nn.Module, |
| max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None, |
| no_split_module_classes: Optional[list[str]] = None, |
| dtype: Optional[Union[str, torch.dtype]] = None, |
| special_dtypes: Optional[dict[str, Union[str, torch.dtype]]] = None, |
| low_zero: bool = False, |
| ): |
| """ |
| Compute a `max_memory` dictionary for [`infer_auto_device_map`] that will balance the use of each available GPU. |
| |
| <Tip> |
| |
| All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the |
| meta device (as it would if initialized within the `init_empty_weights` context manager). |
| |
| </Tip> |
| |
| Args: |
| model (`torch.nn.Module`): |
| The model to analyze. |
| max_memory (`Dict`, *optional*): |
| A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available if unset. |
| Example: `max_memory={0: "1GB"}`. |
| no_split_module_classes (`List[str]`, *optional*): |
| A list of layer class names that should never be split across device (for instance any layer that has a |
| residual connection). |
| dtype (`str` or `torch.dtype`, *optional*): |
| If provided, the weights will be converted to that type when loaded. |
| special_dtypes (`Dict[str, Union[str, torch.dtype]]`, *optional*): |
| If provided, special dtypes to consider for some specific weights (will override dtype used as default for |
| all weights). |
| low_zero (`bool`, *optional*): |
| Minimizes the number of weights on GPU 0, which is convenient when it's used for other operations (like the |
| Transformers generate function). |
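| |
| Example (a minimal sketch; the toy model and memory budgets below are arbitrary): |
| |
| ```py |
| >>> from torch import nn |
| |
| >>> model = nn.Sequential(*[nn.Linear(100, 100) for _ in range(6)]) |
| >>> # Balance the (hypothetical) budgets of two devices, then feed them to infer_auto_device_map |
| >>> balanced = get_balanced_memory(model, max_memory={0: "1GiB", 1: "1GiB"}) |
| >>> device_map = infer_auto_device_map(model, max_memory=balanced) |
| ``` |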
| """ |
| |
| user_not_set_max_memory = max_memory is None |
| max_memory = get_max_memory(max_memory) |
|
|
| if is_npu_available(): |
| expected_device_type = "npu" |
| elif is_mlu_available(): |
| expected_device_type = "mlu" |
| elif is_sdaa_available(): |
| expected_device_type = "sdaa" |
| elif is_musa_available(): |
| expected_device_type = "musa" |
| elif is_xpu_available(): |
| expected_device_type = "xpu" |
| elif is_hpu_available(): |
| expected_device_type = "hpu" |
| elif is_mps_available(): |
| expected_device_type = "mps" |
| else: |
| expected_device_type = "cuda" |
| num_devices = len([d for d in max_memory if torch.device(d).type == expected_device_type and max_memory[d] > 0]) |
|
|
| if num_devices == 0: |
| return max_memory |
|
|
| if num_devices == 1: |
| |
| low_zero = False |
| |
| if user_not_set_max_memory: |
| for key in max_memory.keys(): |
| if isinstance(key, int): |
| max_memory[key] *= 0.9 |
| logger.info( |
| f"We will use 90% of the memory on device {key} for storing the model, and 10% for the buffer to avoid OOM. " |
| "You can set `max_memory` in to a higher value to use more memory (at your own risk)." |
| ) |
| break |
|
|
| module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes) |
| per_gpu = module_sizes[""] // (num_devices - 1 if low_zero else num_devices) |
|
|
| |
| |
| |
| |
| |
| if no_split_module_classes is None: |
| no_split_module_classes = [] |
| elif not isinstance(no_split_module_classes, (list, tuple)): |
| no_split_module_classes = [no_split_module_classes] |
|
|
| |
| if len(no_split_module_classes) > 0: |
| no_split_children = {} |
| for name, size in module_sizes.items(): |
| if name == "": |
| continue |
| submodule = model |
| for submodule_name in name.split("."): |
| submodule = getattr(submodule, submodule_name) |
| class_name = submodule.__class__.__name__ |
| if class_name in no_split_module_classes and class_name not in no_split_children: |
| no_split_children[class_name] = size |
|
|
| if set(no_split_children.keys()) == set(no_split_module_classes): |
| break |
| buffer = max(no_split_children.values()) if len(no_split_children) > 0 else 0 |
| else: |
| buffer = 0 |
|
|
| |
| leaves = get_module_leaves(module_sizes) |
| leaves_set = set(leaves) |
| module_sizes = {n: v for n, v in module_sizes.items() if n not in leaves_set} |
| |
| leaves = get_module_leaves(module_sizes) |
| mean_leaves = int(sum([module_sizes[n] for n in leaves]) / max(len(leaves), 1)) |
| buffer = int(1.25 * max(buffer, mean_leaves)) |
| per_gpu += buffer |
|
|
| |
| gpus_idx_list = list( |
| sorted( |
| device_id for device_id, device_mem in max_memory.items() if isinstance(device_id, int) and device_mem > 0 |
| ) |
| ) |
| |
| for idx in gpus_idx_list[:-1]: |
| max_memory[idx] = min(max_memory[0] if low_zero and idx == 0 else per_gpu, max_memory[idx]) |
|
|
| if low_zero: |
| min_zero = max(0, module_sizes[""] - sum([max_memory[i] for i in range(1, num_devices)])) |
| max_memory[0] = min(min_zero, max_memory[0]) |
|
|
| return max_memory |
|
|
|
|
| def calculate_maximum_sizes(model: torch.nn.Module): |
| "Computes the total size of the model and its largest layer" |
| sizes = compute_module_sizes(model) |
| |
| no_split_modules = getattr(model, "_no_split_modules", None) |
| if no_split_modules is None: |
| no_split_modules = [] |
|
|
| modules_to_treat = ( |
| list(model.named_parameters(recurse=False)) |
| + list(model.named_children()) |
| + list(model.named_buffers(recurse=False)) |
| ) |
| largest_layer = get_max_layer_size(modules_to_treat, sizes, no_split_modules) |
| total_size = sizes[""] |
| return total_size, largest_layer |
|
|
|
|
| def _init_infer_auto_device_map( |
| model: nn.Module, |
| max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None, |
| no_split_module_classes: Optional[list[str]] = None, |
| dtype: Optional[Union[str, torch.dtype]] = None, |
| special_dtypes: Optional[dict[str, Union[str, torch.dtype]]] = None, |
| ) -> tuple[ |
| list[Union[int, str]], |
| dict[Union[int, str], Union[int, str]], |
| list[Union[int, str]], |
| list[int], |
| dict[str, int], |
| list[list[str]], |
| list[str], |
| list[tuple[str, nn.Module]], |
| ]: |
| """ |
| Initialize variables required for computing the device map for model allocation. |
| """ |
| max_memory = get_max_memory(max_memory) |
| if no_split_module_classes is None: |
| no_split_module_classes = [] |
| elif not isinstance(no_split_module_classes, (list, tuple)): |
| no_split_module_classes = [no_split_module_classes] |
|
|
| devices = list(max_memory.keys()) |
| if "disk" not in devices: |
| devices.append("disk") |
| gpus = [device for device in devices if device not in ["cpu", "disk"]] |
|
|
| |
| if "mps" in gpus: |
| main_devices = ["mps"] |
| elif len(gpus) > 0: |
| main_devices = [gpus[0], "cpu"] |
| else: |
| main_devices = ["cpu"] |
|
|
| module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes) |
| tied_parameters = find_tied_parameters(model) |
| if check_tied_parameters_in_config(model) and len(tied_parameters) == 0: |
| logger.warning( |
| "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function." |
| ) |
|
|
| |
| modules_to_treat = ( |
| list(model.named_parameters(recurse=False)) |
| + list(model.named_children()) |
| + list(model.named_buffers(recurse=False)) |
| ) |
|
|
| return ( |
| devices, |
| max_memory, |
| main_devices, |
| gpus, |
| module_sizes, |
| tied_parameters, |
| no_split_module_classes, |
| modules_to_treat, |
| ) |
|
|
|
|
| def get_module_size_with_ties( |
| tied_params, |
| module_size, |
| module_sizes, |
| modules_to_treat, |
| ) -> tuple[int, list[str], list[nn.Module]]: |
| """ |
| Calculate the total size of a module, including its tied parameters. |
| |
| Args: |
| tied_params (`List[str]`): The list of tied parameters. |
| module_size (`int`): The size of the module without tied parameters. |
| module_sizes (`Dict[str, int]`): A dictionary mapping each layer name to its size. |
| modules_to_treat (`List[Tuple[str, nn.Module]]`): The list of named modules to treat. |
| |
| Returns: |
| `Tuple[int, List[str], List[nn.Module]]`: The total size of the module, the names of the tied modules, and the |
| tied modules. |
| """ |
| if len(tied_params) < 1: |
| return module_size, [], [] |
| tied_module_names = [] |
| tied_modules = [] |
|
|
| for tied_param in tied_params: |
| tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if tied_param.startswith(n + ".")][0] |
| tied_module_names.append(modules_to_treat[tied_module_index][0]) |
| tied_modules.append(modules_to_treat[tied_module_index][1]) |
|
|
| module_size_with_ties = module_size |
| for tied_param, tied_module_name in zip(tied_params, tied_module_names): |
| module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param] |
|
|
| return module_size_with_ties, tied_module_names, tied_modules |
|
|
|
|
| def fallback_allocate( |
| modules: list[tuple[str, nn.Module]], |
| module_sizes: dict[str, int], |
| size_limit: Union[int, str], |
| no_split_module_classes: Optional[list[str]] = None, |
| tied_parameters: Optional[list[list[str]]] = None, |
| ) -> tuple[Optional[str], Optional[nn.Module], list[tuple[str, nn.Module]]]: |
| """ |
| Find a module that fits in the size limit using BFS and return it with its name and the remaining modules. |
| |
| Args: |
| modules (`List[Tuple[str, nn.Module]]`): |
| The list of named modules to search in. |
| module_sizes (`Dict[str, int]`): |
| A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`). |
| size_limit (`Union[int, str]`): |
| The maximum size a module can have. |
| no_split_module_classes (`Optional[List[str]]`, *optional*): |
| A list of class names for layers we don't want to be split. |
| tied_parameters (`Optional[List[List[str]]]`, *optional*): |
| A list of lists of parameter names being all tied together. |
| |
| Returns: |
| `Tuple[Optional[str], Optional[nn.Module], List[Tuple[str, nn.Module]]]`: A tuple containing: |
| - The name of the module that fits within the size limit. |
| - The module itself. |
| - The list of remaining modules after the found module is removed. |
| """ |
| try: |
| size_limit = convert_file_size_to_int(size_limit) |
| except ValueError: |
| return None, None, modules |
|
|
| if no_split_module_classes is None: |
| no_split_module_classes = [] |
|
|
| if tied_parameters is None: |
| tied_parameters = [] |
|
|
| modules_to_search = modules.copy() |
| module_found = False |
|
|
| while modules_to_search: |
| name, module = modules_to_search.pop(0) |
|
|
| tied_param_groups = [ |
| tied_group |
| for tied_group in tied_parameters |
| if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group) |
| ] |
|
|
| tied_params = sum( |
| [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], [] |
| ) |
|
|
| module_size_with_ties, _, _ = get_module_size_with_ties( |
| tied_params, module_sizes[name], module_sizes, modules_to_search |
| ) |
|
|
| |
| if module_size_with_ties <= size_limit: |
| module_found = True |
| break |
|
|
| |
| modules_children = ( |
| [] |
| if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor) |
| else list(module.named_children()) |
| ) |
|
|
| |
| if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes: |
| continue |
|
|
| |
| modules_children = list(module.named_parameters(recurse=False)) + modules_children |
| modules_to_search = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_search |
|
|
| if not module_found: |
| return None, None, modules |
|
|
| |
| current_names = [n for n, _ in modules] |
| dot_idx = [i for i, c in enumerate(name) if c == "."] |
|
|
| for dot_index in dot_idx: |
| parent_name = name[:dot_index] |
| if parent_name in current_names: |
| parent_module_idx = current_names.index(parent_name) |
| _, parent_module = modules[parent_module_idx] |
| module_children = list(parent_module.named_parameters(recurse=False)) + list( |
| parent_module.named_children() |
| ) |
| modules = ( |
| modules[:parent_module_idx] |
| + [(f"{parent_name}.{n}", v) for n, v in module_children] |
| + modules[parent_module_idx + 1 :] |
| ) |
| current_names = [n for n, _ in modules] |
|
|
| |
| target_idx = current_names.index(name) |
| name, module = modules.pop(target_idx) |
|
|
| return name, module, modules |
|
|
|
|
| def infer_auto_device_map( |
| model: nn.Module, |
| max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None, |
| no_split_module_classes: Optional[list[str]] = None, |
| dtype: Optional[Union[str, torch.dtype]] = None, |
| special_dtypes: Optional[dict[str, Union[str, torch.dtype]]] = None, |
| verbose: bool = False, |
| clean_result: bool = True, |
| offload_buffers: bool = False, |
| fallback_allocation: bool = False, |
| ): |
| """ |
| Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk, |
| such that: |
| - we don't exceed the memory available on any of the GPUs. |
| - if offload to the CPU is needed, there is always room left on GPU 0 to put back the layer offloaded on CPU that |
| has the largest size. |
| - if offload to the CPU is needed, we don't exceed the RAM available on the CPU. |
| - if offload to the disk is needed, there is always room left on the CPU to put back the layer offloaded on disk |
| that has the largest size. |
| |
| <Tip> |
| |
| All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the |
| meta device (as it would if initialized within the `init_empty_weights` context manager). |
| |
| </Tip> |
| |
| Args: |
| model (`torch.nn.Module`): |
| The model to analyze. |
| max_memory (`Dict`, *optional*): |
| A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset. |
| Example: `max_memory={0: "1GB"}`. |
| no_split_module_classes (`List[str]`, *optional*): |
| A list of layer class names that should never be split across device (for instance any layer that has a |
| residual connection). |
| dtype (`str` or `torch.dtype`, *optional*): |
| If provided, the weights will be converted to that type when loaded. |
| special_dtypes (`Dict[str, Union[str, torch.dtype]]`, *optional*): |
| If provided, special dtypes to consider for some specific weights (will override dtype used as default for |
| all weights). |
| verbose (`bool`, *optional*, defaults to `False`): |
| Whether or not to provide debugging statements as the function builds the device_map. |
| clean_result (`bool`, *optional*, defaults to `True`): |
| Clean the resulting device_map by grouping all submodules that go on the same device together. |
| offload_buffers (`bool`, *optional*, defaults to `False`): |
| In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as |
| well as the parameters. |
| fallback_allocation (`bool`, *optional*, defaults to `False`): |
| When regular allocation fails, try to allocate a module that fits in the size limit using BFS. |
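| |
| Example (a minimal sketch; the memory budgets below are hypothetical and everything fits on device 0): |
| |
| ```py |
| >>> from collections import OrderedDict |
| >>> from torch import nn |
| |
| >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))])) |
| >>> # Both submodules land on device 0, so the cleaned map collapses to a single entry |
| >>> dict(infer_auto_device_map(model, max_memory={0: "1KB", "cpu": "10KB"})) |
| {'': 0} |
| ``` |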
| """ |
|
|
| |
| ( |
| devices, |
| max_memory, |
| main_devices, |
| gpus, |
| module_sizes, |
| tied_parameters, |
| no_split_module_classes, |
| modules_to_treat, |
| ) = _init_infer_auto_device_map(model, max_memory, no_split_module_classes, dtype, special_dtypes) |
|
|
| device_map = OrderedDict() |
| current_device = 0 |
| device_memory_used = {device: 0 for device in devices} |
| device_buffer_sizes = {} |
| device_minimum_assignment_memory = {} |
|
|
| |
| max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes) |
|
|
| |
| while len(modules_to_treat) > 0: |
| name, module = modules_to_treat.pop(0) |
| if verbose: |
| print(f"\nTreating module {name}.") |
| |
| max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + ".")] |
| if len(max_layer_names) == 0: |
| max_layer_size, max_layer_names = get_max_layer_size( |
| [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], |
| module_sizes, |
| no_split_module_classes, |
| ) |
| |
| module_size = module_sizes[name] |
|
|
| |
| |
| |
| |
| |
| tied_param_groups = [ |
| tied_group |
| for tied_group in tied_parameters |
| if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group) |
| ] |
|
|
| if verbose and len(tied_param_groups) > 0: |
| print(f" Found the relevant tied param groups {tied_param_groups}") |
|
|
| |
| tied_params = sum( |
| [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], [] |
| ) |
|
|
| if verbose and len(tied_params) > 0: |
| print(f" So those parameters need to be taken into account {tied_params}") |
|
|
| device = devices[current_device] |
| current_max_size = max_memory[device] if device != "disk" else None |
| current_memory_reserved = 0 |
| |
| if devices[current_device] in main_devices: |
| current_max_size = current_max_size - max_layer_size |
| current_memory_reserved = max_layer_size |
|
|
| module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties( |
| tied_params, module_size, module_sizes, modules_to_treat |
| ) |
|
|
| |
| if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size: |
| if verbose: |
| output = f"Putting {name}" |
|
|
| if tied_module_names: |
| output += f" and {tied_module_names}" |
| else: |
| output += f" (size={module_size})" |
|
|
| if current_max_size is not None: |
| output += f" (available={current_max_size - device_memory_used[device]})" |
|
|
| output += f" on {device}." |
| print(output) |
|
|
| device_memory_used[device] += module_size_with_ties |
|
|
| |
| device_map[name] = device |
|
|
| |
| for tied_module_name in tied_module_names: |
| if tied_module_name in [m[0] for m in modules_to_treat]: |
| |
| tied_module_index = next(i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name) |
| |
| modules_to_treat.pop(tied_module_index) |
|
|
| |
| device_map[tied_module_name] = device |
|
|
| |
| if not offload_buffers and isinstance(module, nn.Module): |
| |
| current_buffer_size = compute_module_total_buffer_size( |
| module, dtype=dtype, special_dtypes=special_dtypes |
| ) |
| |
| device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size |
|
|
| continue |
|
|
| |
| if len(tied_params) > 0 and device_memory_used[device] + module_size <= current_max_size: |
| |
| if verbose: |
| print( |
| f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space " |
| f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})." |
| ) |
| split_happened = False |
| for tied_module_name, tied_module in zip(tied_module_names, tied_modules): |
| tied_module_children = list(tied_module.named_children()) |
| if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes: |
| |
| continue |
|
|
| if verbose: |
| print(f"Splitting {tied_module_name}.") |
| tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children |
| tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children] |
| tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0] |
|
|
| modules_to_treat = ( |
| [(name, module)] |
| + modules_to_treat[:tied_module_index] |
| + tied_module_children |
| + modules_to_treat[tied_module_index + 1 :] |
| ) |
| |
| max_layer_size, max_layer_names = get_max_layer_size( |
| [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], |
| module_sizes, |
| no_split_module_classes, |
| ) |
| split_happened = True |
| break |
|
|
| if split_happened: |
| continue |
|
|
| |
| if verbose: |
| print("None of the tied module can be split, going to the next device.") |
|
|
| |
| if device_memory_used[device] + module_size >= current_max_size: |
| |
| modules_children = ( |
| [] |
| if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor) |
| else list(module.named_children()) |
| ) |
| if verbose: |
| print( |
| f"Not enough space on {devices[current_device]} to put {name} (space available " |
| f"{current_max_size - device_memory_used[device]}, module size {module_size})." |
| ) |
| if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes: |
| |
| if verbose: |
| print("This module cannot be split, going to the next device.") |
|
|
| else: |
| |
| if verbose: |
| print(f"Splitting {name}.") |
| modules_children = list(module.named_parameters(recurse=False)) + modules_children |
| modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat |
| |
| max_layer_size, max_layer_names = get_max_layer_size( |
| [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], |
| module_sizes, |
| no_split_module_classes, |
| ) |
| continue |
|
|
| |
| |
| if device_memory_used[device] == 0 and fallback_allocation and device != "disk": |
| |
| |
| current_max_size = max_memory[device] - max(max_layer_size, module_size_with_ties) |
|
|
| fallback_module_name, fallback_module, remaining_modules = fallback_allocate( |
| modules_to_treat, |
| module_sizes, |
| current_max_size - device_memory_used[device], |
| no_split_module_classes, |
| tied_parameters, |
| ) |
| |
| if fallback_module is not None: |
| modules_to_treat = [(fallback_module_name, fallback_module)] + [(name, module)] + remaining_modules |
| continue |
|
|
| if device_memory_used[device] == 0: |
| device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved |
|
|
| |
| device_memory_used[device] = device_memory_used[device] + current_memory_reserved |
| current_device += 1 |
| modules_to_treat = [(name, module)] + modules_to_treat |
|
|
| device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0} |
|
|
| if clean_result: |
| device_map = clean_device_map(device_map) |
|
|
| non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0) |
| if non_gpu_buffer_size > 0 and not offload_buffers: |
| is_buffer_fit_any_gpu = False |
| for gpu_device, gpu_max_memory in max_memory.items(): |
| if gpu_device == "cpu" or gpu_device == "disk": |
| continue |
|
|
| if not is_buffer_fit_any_gpu: |
| gpu_memory_used = device_memory_used.get(gpu_device, 0) |
|
|
| if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used: |
| is_buffer_fit_any_gpu = True |
|
|
| if len(gpus) > 0 and not is_buffer_fit_any_gpu: |
| warnings.warn( |
| f"Current model requires {non_gpu_buffer_size} bytes of buffer for offloaded layers, which seems does " |
| f"not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using " |
| f"offload_buffers=True." |
| ) |
|
|
| if device_minimum_assignment_memory: |
| devices_info = "\n".join( |
| f" - {device}: {mem} bytes required" for device, mem in device_minimum_assignment_memory.items() |
| ) |
| logger.info( |
| f"Based on the current allocation process, no modules could be assigned to the following devices due to " |
| f"insufficient memory:\n" |
| f"{devices_info}\n" |
| f"These minimum requirements are specific to this allocation attempt and may vary. Consider increasing " |
| f"the available memory for these devices to at least the specified minimum, or adjusting the model config." |
| ) |
| return device_map |
|
|
|
|
| def check_device_map(model: nn.Module, device_map: dict[str, Union[int, str, torch.device]]): |
| """ |
| Checks a device map covers everything in a given model. |
| |
| Args: |
| model (`torch.nn.Module`): The model to check the device map against. |
| device_map (`Dict[str, Union[int, str, torch.device]]`): The device map to check. |
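| |
| Example (illustrative; an uncovered parameter would raise a `ValueError` instead): |
| |
| ```py |
| >>> from collections import OrderedDict |
| >>> from torch import nn |
| |
| >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))])) |
| >>> check_device_map(model, {"linear1": 0, "linear2": "cpu"})  # every tensor is covered, so nothing happens |
| ``` |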
| """ |
| all_module_names = dict(model.named_modules()) |
| invalid_keys = [k for k in device_map if k != "" and k not in all_module_names] |
|
|
| if invalid_keys: |
| warnings.warn( |
| f"The following device_map keys do not match any submodules in the model: {invalid_keys}", UserWarning |
| ) |
|
|
| all_model_tensors = [name for name, _ in model.state_dict().items()] |
| for module_name in device_map.keys(): |
| if module_name == "": |
| all_model_tensors.clear() |
| break |
| else: |
| all_model_tensors = [ |
| name |
| for name in all_model_tensors |
| if not name == module_name and not name.startswith(module_name + ".") |
| ] |
| if len(all_model_tensors) > 0: |
| non_covered_params = ", ".join(all_model_tensors) |
| raise ValueError( |
| f"The device_map provided does not give any device for the following parameters: {non_covered_params}" |
| ) |
|
|
|
|
| def load_state_dict(checkpoint_file, device_map=None): |
| """ |
| Load a checkpoint from a given file. If the checkpoint is in the safetensors format and a device map is passed, the |
| weights can be fast-loaded directly on the GPU. |
| |
| Args: |
| checkpoint_file (`str`): The path to the checkpoint to load. |
| device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*): |
| A map that specifies where each submodule should go. It doesn't need to be refined down to each |
| parameter/buffer name; once a given module name is in the map, every submodule of it will be sent to the same device. |
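| |
| Example (a minimal sketch; `"model.safetensors"` stands in for an existing checkpoint file and the second |
| call assumes a GPU indexed 0 is available): |
| |
| ```py |
| >>> state_dict = load_state_dict("model.safetensors")  # everything is loaded on CPU |
| >>> state_dict = load_state_dict("model.safetensors", device_map={"": 0})  # fast-loaded directly on GPU 0 |
| ``` |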
| """ |
| if checkpoint_file.endswith(".safetensors"): |
| with safe_open(checkpoint_file, framework="pt") as f: |
| metadata = f.metadata() |
| weight_names = f.keys() |
|
|
| if metadata is None: |
| logger.warning( |
| f"The safetensors archive passed at {checkpoint_file} does not contain metadata. " |
| "Make sure to save your model with the `save_pretrained` method. Defaulting to 'pt' metadata." |
| ) |
| metadata = {"format": "pt"} |
|
|
| if metadata.get("format") not in ["pt", "tf", "flax"]: |
| raise OSError( |
| f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " |
| "you save your model with the `save_pretrained` method." |
| ) |
| elif metadata["format"] != "pt": |
| raise ValueError(f"The checkpoint passed was saved with {metadata['format']}, we need a the pt format.") |
| if device_map is None: |
| return safe_load_file(checkpoint_file) |
| else: |
| # if the whole device_map points to a single device, load everything in one call |
| if len(set(device_map.values())) == 1: |
| device = list(device_map.values())[0] |
| target_device = device |
| if isinstance(device, int): |
| if is_npu_available(): |
| target_device = f"npu:{device}" |
| elif is_hpu_available(): |
| target_device = "hpu" |
|
|
| return safe_load_file(checkpoint_file, device=target_device) |
|
|
| devices = list(set(device_map.values()) - {"disk"}) |
| # "cpu" is kept as a fallback device for weights that are offloaded to disk or not covered by the map |
| if "cpu" not in devices: |
| devices.append("cpu") |
|
|
| # group the checkpoint weight names by the device they should be loaded on |
| device_weights = {device: [] for device in devices} |
| for module_name, device in device_map.items(): |
| if device in devices: |
| device_weights[device].extend( |
| [k for k in weight_names if k == module_name or k.startswith(module_name + ".")] |
| ) |
|
|
| # any weight not matched by the device_map is loaded on cpu |
| device_weights["cpu"].extend([k for k in weight_names if k not in sum(device_weights.values(), [])]) |
| tensors = {} |
| if is_tqdm_available(): |
| progress_bar = tqdm( |
| main_process_only=False, |
| total=sum([len(device_weights[device]) for device in devices]), |
| unit="w", |
| smoothing=0, |
| leave=False, |
| ) |
| else: |
| progress_bar = None |
| for device in devices: |
| target_device = device |
| if isinstance(device, int): |
| if is_npu_available(): |
| target_device = f"npu:{device}" |
| elif is_hpu_available(): |
| target_device = "hpu" |
|
|
| with safe_open(checkpoint_file, framework="pt", device=target_device) as f: |
| for key in device_weights[device]: |
| if progress_bar is not None: |
| progress_bar.set_postfix(dev=device, refresh=False) |
| progress_bar.set_description(key) |
| tensors[key] = f.get_tensor(key) |
| if progress_bar is not None: |
| progress_bar.update() |
| if progress_bar is not None: |
| progress_bar.close() |
|
|
| return tensors |
| else: |
| return torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True) |
|
|
|
|
| def get_state_dict_offloaded_model(model: nn.Module): |
| """ |
| Returns the state dictionary for an offloaded model via iterative onloading |
| |
| Args: |
| model (`torch.nn.Module`): |
| The offloaded model we want to save |
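| |
| Example (a minimal sketch; `model` is assumed to have been dispatched with cpu/disk offload beforehand): |
| |
| ```py |
| >>> state_dict = get_state_dict_offloaded_model(model) |
| >>> torch.save(state_dict, "model_state.bin")  # every tensor in `state_dict` is now on CPU |
| ``` |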
| """ |
|
|
| state_dict = {} |
| placeholders = set() |
| for name, module in model.named_modules(): |
| if name == "": |
| continue |
|
|
| try: |
| with align_module_device(module, "cpu"): |
| module_state_dict = module.state_dict() |
| except MemoryError: |
| raise MemoryError("Offloaded module must fit in CPU memory to call save_model!") from None |
|
|
| for key in module_state_dict: |
| # tensors still on the meta device have no data to save yet; remember them as placeholders |
| if module_state_dict[key].device == torch.device("meta"): |
| placeholders.add(name + f".{key}") |
| continue |
| params = module_state_dict[key] |
| state_dict[name + f".{key}"] = params.to("cpu") |
| for key in placeholders.copy(): |
| if key in state_dict: |
| placeholders.remove(key) |
| if placeholders: |
| logger.warning(f"The following tensors were not saved because they were still on meta device: {placeholders}") |
|
|
| return state_dict |
|
|
|
|
| def get_state_dict_from_offload( |
| module: nn.Module, |
| module_name: str, |
| state_dict: dict[str, Union[str, torch.Tensor]], |
| device_to_put_offload: Union[int, str, torch.device] = "cpu", |
| ): |
| """ |
| Retrieve the state dictionary (with parameters) from an offloaded module and load it onto a specified device |
| (defaults to cpu). |
| |
| Args: |
| module: (`torch.nn.Module`): |
| The module we want to retrieve a state dictionary from |
| module_name: (`str`): |
| The name of the module of interest |
| state_dict (`Dict[str, torch.Tensor]`): |
| Dictionary of {parameter names: tensors} to update with the module's onloaded parameters |
| device_to_put_offload (`Union[int, str, torch.device]`): |
| Device to load offloaded parameters into, defaults to the cpu. |
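| |
| Example (a minimal sketch; the module path and tensor names are made up, and `state_dict` is pre-filled with |
| the entries we want to refresh): |
| |
| ```py |
| >>> state_dict = {"encoder.layer.weight": None, "encoder.layer.bias": None} |
| >>> state_dict = get_state_dict_from_offload(model.encoder.layer, "encoder.layer.weight", state_dict) |
| ``` |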
| """ |
|
|
| root = module_name[: module_name.rfind(".")] |
|
|
| # if the module has no offload hook, read its parameters from wherever they already live |
| if not has_offloaded_params(module): |
| device_to_put_offload = None |
|
|
| # onload the module (if offloaded) and copy the requested entries into state_dict |
| with align_module_device(module, device_to_put_offload): |
| for m_key, params in module.state_dict().items(): |
| if (root + f".{m_key}") in state_dict: |
| state_dict[root + f".{m_key}"] = params |
|
|
| return state_dict |
|
|
|
|
| def load_checkpoint_in_model( |
| model: nn.Module, |
| checkpoint: Union[str, os.PathLike], |
| device_map: Optional[dict[str, Union[int, str, torch.device]]] = None, |
| offload_folder: Optional[Union[str, os.PathLike]] = None, |
| dtype: Optional[Union[str, torch.dtype]] = None, |
| offload_state_dict: bool = False, |
| offload_buffers: bool = False, |
| keep_in_fp32_modules: Optional[list[str]] = None, |
| offload_8bit_bnb: bool = False, |
| strict: bool = False, |
| full_state_dict: bool = True, |
| broadcast_from_rank0: bool = False, |
| ): |
| """ |
| Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are |
| loaded. |
| |
| <Tip warning={true}> |
| |
| Once loaded across devices, you still need to call [`dispatch_model`] on your model to make it able to run. To |
| group the checkpoint loading and dispatch in one single call, use [`load_checkpoint_and_dispatch`]. |
| |
| </Tip> |
| |
| Args: |
| model (`torch.nn.Module`): |
| The model in which we want to load a checkpoint. |
| checkpoint (`str` or `os.PathLike`): |
| The path to the checkpoint to load. It can be: |
| - a path to a file containing a whole model state dict |
| - a path to a `.json` file containing the index to a sharded checkpoint |
| - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint. |
| - a path to a folder containing a unique pytorch_model.bin or a model.safetensors file. |
| device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*): |
| A map that specifies where each submodule should go. It doesn't need to be refined down to each |
| parameter/buffer name; once a given module name is in the map, every submodule of it will be sent to the same device. |
| offload_folder (`str` or `os.PathLike`, *optional*): |
| If the `device_map` contains any value `"disk"`, the folder where we will offload weights. |
| dtype (`str` or `torch.dtype`, *optional*): |
| If provided, the weights will be converted to that type when loaded. |
| offload_state_dict (`bool`, *optional*, defaults to `False`): |
| If `True`, will temporarily offload the CPU state dict to the hard drive to avoid running out of CPU RAM if |
| the weight of the CPU state dict + the biggest shard does not fit. |
| offload_buffers (`bool`, *optional*, defaults to `False`): |
| Whether or not to include the buffers in the weights offloaded to disk. |
| keep_in_fp32_modules (`List[str]`, *optional*): |
| A list of the modules that we keep in `torch.float32` dtype. |
| offload_8bit_bnb (`bool`, *optional*): |
| Whether or not to enable offload of 8-bit modules on cpu/disk. |
| strict (`bool`, *optional*, defaults to `False`): |
| Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's |
| state_dict. |
| full_state_dict (`bool`, *optional*, defaults to `True`): |
| If `True`, all the tensors in the loaded state_dict will be gathered, so no `ShardedTensor` or `DTensor` |
| will remain in the loaded state_dict. |
| broadcast_from_rank0 (`bool`, *optional*, defaults to `False`): |
| When `True`, a distributed `ProcessGroup` must be initialized. Rank 0 should receive a full state_dict and |
| will broadcast its tensors one by one to the other ranks, which will shard them (if applicable) according |
| to the local shards in the model. |
| |
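| Example (a minimal sketch; `MyModel`, `config`, the checkpoint folder and the device map are placeholders): |
| |
| ```py |
| >>> from accelerate import init_empty_weights |
| >>> with init_empty_weights(): |
| ...     model = MyModel(config) |
| >>> load_checkpoint_in_model( |
| ...     model, "checkpoint_folder/", device_map={"encoder": 0, "lm_head": "cpu"} |
| ... ) |
| ``` |
| |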
| """ |
| if offload_8bit_bnb: |
| from .bnb import quantize_and_offload_8bit |
|
|
| tied_params = find_tied_parameters(model) |
|
|
| if check_tied_parameters_in_config(model) and len(tied_params) == 0: |
| logger.warning( |
| "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function." |
| ) |
| if device_map is not None: |
| check_tied_parameters_on_same_device(tied_params, device_map) |
|
|
| if offload_folder is None and device_map is not None and "disk" in device_map.values(): |
| raise ValueError( |
| "At least one of the model submodule will be offloaded to disk, please pass along an `offload_folder`." |
| ) |
| elif offload_folder is not None and device_map is not None and "disk" in device_map.values(): |
| os.makedirs(offload_folder, exist_ok=True) |
|
|
| if isinstance(dtype, str): |
| # we accept "torch.float16" or just "float16" |
| dtype = dtype.replace("torch.", "") |
| dtype = getattr(torch, dtype) |
|
|
| checkpoint_files = None |
| index_filename = None |
| if os.path.isfile(checkpoint): |
| if str(checkpoint).endswith(".json"): |
| index_filename = checkpoint |
| else: |
| checkpoint_files = [checkpoint] |
| elif os.path.isdir(checkpoint): |
| # check whether the whole state dict is present in a single file |
| potential_state_bin = [f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME] |
| potential_state_safetensor = [f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME] |
| if len(potential_state_bin) == 1: |
| checkpoint_files = [os.path.join(checkpoint, potential_state_bin[0])] |
| elif len(potential_state_safetensor) == 1: |
| checkpoint_files = [os.path.join(checkpoint, potential_state_safetensor[0])] |
| else: |
| # otherwise look for the index of a sharded checkpoint |
| potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")] |
| if len(potential_index) == 0: |
| raise ValueError( |
| f"{checkpoint} is not a folder containing a `.index.json` file or a {WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file" |
| ) |
| elif len(potential_index) == 1: |
| index_filename = os.path.join(checkpoint, potential_index[0]) |
| else: |
| raise ValueError( |
| f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones." |
| ) |
| else: |
| raise ValueError( |
| "`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded " |
| f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}." |
| ) |
|
|
| if index_filename is not None: |
| checkpoint_folder = os.path.split(index_filename)[0] |
| with open(index_filename) as f: |
| index = json.loads(f.read()) |
|
|
| if "weight_map" in index: |
| index = index["weight_map"] |
| checkpoint_files = sorted(list(set(index.values()))) |
| checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files] |
|
|
| |
|
|
| offload_index = {} |
| if offload_state_dict: |
| state_dict_folder = tempfile.mkdtemp() |
| state_dict_index = {} |
|
|
| unexpected_keys = set() |
| model_keys = set(model.state_dict().keys()) |
| buffer_names = [name for name, _ in model.named_buffers()] |
| model_devices = {t.device for t in model.state_dict().values() if isinstance(t, torch.Tensor)} |
| model_physical_devices = model_devices - {torch.device("meta")} |
| for checkpoint_file in checkpoint_files: |
| if device_map is None: |
| # without a device_map, prefer torch.distributed's `set_model_state_dict` when the installed torch version |
| # supports it (it handles sharded/DTensor parameters); otherwise fall back to a plain `load_state_dict` |
| if is_torch_version(">=", "2.2.0") and ( |
| (is_torch_version(">=", "2.7.0") and len(model_physical_devices) <= 1) or len(model_devices) <= 1 |
| ): |
| from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict |
|
|
| broadcast_from_rank0 &= is_torch_version(">=", "2.4.0") |
| loaded_checkpoint = ( |
| load_state_dict(checkpoint_file, device_map=device_map) |
| if not broadcast_from_rank0 or dist.get_rank() == 0 |
| else {} |
| ) |
| set_model_state_dict( |
| model, |
| loaded_checkpoint, |
| options=StateDictOptions( |
| full_state_dict=full_state_dict, |
| strict=strict, |
| **({"broadcast_from_rank0": broadcast_from_rank0} if is_torch_version(">=", "2.4.0") else {}), |
| ), |
| ) |
| else: |
| loaded_checkpoint = load_state_dict(checkpoint_file, device_map=device_map) |
| model.load_state_dict(loaded_checkpoint, strict=strict) |
|
|
| unexpected_keys.update(set(loaded_checkpoint.keys()) - model_keys) |
| else: |
| loaded_checkpoint = load_state_dict(checkpoint_file, device_map=device_map) |
|
|
| for param_name, param in loaded_checkpoint.items(): |
| # skip bitsandbytes SCB statistics here; they are picked up together with their matching weight below |
| if "SCB" in param_name: |
| continue |
|
|
| if param_name not in model_keys: |
| unexpected_keys.add(param_name) |
| if not strict: |
| continue |
|
|
| module_name = param_name |
|
|
| while len(module_name) > 0 and module_name not in device_map: |
| module_name = ".".join(module_name.split(".")[:-1]) |
| if module_name == "" and "" not in device_map: |
| |
| raise ValueError(f"{param_name} doesn't have any device set.") |
| param_device = device_map[module_name] |
| new_dtype = dtype |
| if dtype is not None and torch.is_floating_point(param): |
| if keep_in_fp32_modules is not None and dtype == torch.float16: |
| proceed = False |
| for key in keep_in_fp32_modules: |
| if ((key in param_name) and (key + "." in param_name)) or key == param_name: |
| proceed = True |
| break |
| if proceed: |
| new_dtype = torch.float32 |
|
|
| if "weight" in param_name and param_name.replace("weight", "SCB") in loaded_checkpoint.keys(): |
| if param.dtype == torch.int8: |
| fp16_statistics = loaded_checkpoint[param_name.replace("weight", "SCB")] |
| else: |
| fp16_statistics = None |
|
|
| if param_device == "disk": |
| if offload_buffers or param_name not in buffer_names: |
| if new_dtype is None: |
| new_dtype = param.dtype |
| if offload_8bit_bnb: |
| quantize_and_offload_8bit( |
| model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics |
| ) |
| continue |
| else: |
| set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype) |
| offload_weight(param, param_name, offload_folder, index=offload_index) |
| elif param_device == "cpu" and offload_state_dict: |
| if new_dtype is None: |
| new_dtype = param.dtype |
| if offload_8bit_bnb: |
| quantize_and_offload_8bit( |
| model, param, param_name, new_dtype, state_dict_folder, state_dict_index, fp16_statistics |
| ) |
| else: |
| set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype) |
| offload_weight(param, param_name, state_dict_folder, index=state_dict_index) |
| else: |
| set_module_tensor_to_device( |
| model, |
| param_name, |
| param_device, |
| value=param, |
| dtype=new_dtype, |
| fp16_statistics=fp16_statistics, |
| ) |
|
|
| # force Python to clean up the shard we just processed |
| del loaded_checkpoint |
| gc.collect() |
|
|
| if not strict and len(unexpected_keys) > 0: |
| logger.warning( |
| f"Some weights of the model checkpoint at {checkpoint} were not used when" |
| f" initializing {model.__class__.__name__}: {unexpected_keys}. This may or may not be an issue - make sure that the checkpoint does not have unnecessary parameters, or that the model definition correctly corresponds to the checkpoint." |
| ) |
|
|
| save_offload_index(offload_index, offload_folder) |
|
|
| # load back the temporarily offloaded CPU state dict |
| if offload_state_dict: |
| load_offloaded_weights(model, state_dict_index, state_dict_folder) |
| shutil.rmtree(state_dict_folder) |
|
|
| retie_parameters(model, tied_params) |
|
|
|
|
| def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwargs: AutocastKwargs = None): |
| """ |
| Return a context manager for autocasting mixed precision |
| |
| Args: |
| native_amp (`bool`, *optional*, defaults to `False`): |
| Whether mixed precision is actually enabled. |
| autocast_kwargs (`AutocastKwargs`, *optional*): |
| Extra arguments (such as `cache_enabled`) forwarded to `torch.autocast` when mixed precision is enabled. |
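| |
| Example (a minimal sketch; assumes an `Accelerator` was already created with `mixed_precision="fp16"`, and |
| `model`/`inputs` are placeholders): |
| |
| ```py |
| >>> autocast_context = get_mixed_precision_context_manager(native_amp=True) |
| >>> with autocast_context: |
| ...     outputs = model(inputs) |
| ``` |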
| """ |
| state = AcceleratorState() |
| if autocast_kwargs is None: |
| autocast_kwargs = {} |
| else: |
| autocast_kwargs = autocast_kwargs.to_kwargs() |
| if native_amp: |
| device_type = ( |
| "cuda" |
| if (state.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_gpu=True)) |
| else state.device.type |
| ) |
| if state.mixed_precision == "fp16": |
| return torch.autocast(device_type=device_type, dtype=torch.float16, **autocast_kwargs) |
| elif state.mixed_precision in ["bf16", "fp8"] and state.distributed_type in [ |
| DistributedType.NO, |
| DistributedType.MULTI_CPU, |
| DistributedType.MULTI_GPU, |
| DistributedType.MULTI_MLU, |
| DistributedType.MULTI_SDAA, |
| DistributedType.MULTI_MUSA, |
| DistributedType.MULTI_NPU, |
| DistributedType.MULTI_XPU, |
| DistributedType.MULTI_HPU, |
| DistributedType.FSDP, |
| DistributedType.XLA, |
| ]: |
| return torch.autocast(device_type=device_type, dtype=torch.bfloat16, **autocast_kwargs) |
| else: |
| return torch.autocast(device_type=device_type, **autocast_kwargs) |
| else: |
| return contextlib.nullcontext() |
|
|
|
|
| def get_grad_scaler(distributed_type: DistributedType = None, **kwargs): |
| """ |
| A generic helper which will initialize the correct `GradScaler` implementation based on the environment and return |
| it. |
| |
| Args: |
| distributed_type (`DistributedType`, *optional*, defaults to None): |
| The type of distributed environment. |
| kwargs: |
| Additional arguments for the utilized `GradScaler` constructor. |
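| |
| Example (a minimal sketch of the usual scaler loop; `model`, `inputs`, `targets`, `loss_fn` and `optimizer` |
| are placeholders): |
| |
| ```py |
| >>> scaler = get_grad_scaler() |
| >>> loss = loss_fn(model(inputs), targets) |
| >>> scaler.scale(loss).backward() |
| >>> scaler.step(optimizer) |
| >>> scaler.update() |
| ``` |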
| """ |
| if distributed_type == DistributedType.FSDP: |
| from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler |
|
|
| return ShardedGradScaler(**kwargs) |
| if is_torch_xla_available(check_is_gpu=True): |
| import torch_xla.amp as xamp |
|
|
| return xamp.GradScaler(**kwargs) |
| elif is_mlu_available(): |
| return torch.mlu.amp.GradScaler(**kwargs) |
| elif is_sdaa_available(): |
| return torch.sdaa.amp.GradScaler(**kwargs) |
| elif is_musa_available(): |
| return torch.musa.amp.GradScaler(**kwargs) |
| elif is_npu_available(): |
| return torch.npu.amp.GradScaler(**kwargs) |
| elif is_hpu_available(): |
| return torch.amp.GradScaler("hpu", **kwargs) |
| elif is_xpu_available(): |
| return torch.amp.GradScaler("xpu", **kwargs) |
| elif is_mps_available(): |
| if not is_torch_version(">=", "2.8.0"): |
| raise ValueError("Grad Scaler with MPS device requires a Pytorch >= 2.8.0") |
| return torch.amp.GradScaler("mps", **kwargs) |
| else: |
| if is_torch_version(">=", "2.3"): |
| return torch.amp.GradScaler("cuda", **kwargs) |
| else: |
| return torch.cuda.amp.GradScaler(**kwargs) |
|
|
|
|
| def has_offloaded_params(module: torch.nn.Module) -> bool: |
| """ |
| Checks if a module has offloaded parameters by checking whether the given module has an `AlignDevicesHook` |
| attached with offloading enabled. |
| |
| Args: |
| module (`torch.nn.Module`): The module to check for an offload hook. |
| |
| Returns: |
| bool: `True` if the module has an offload hook and offloading is enabled, `False` otherwise. |
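| |
| Example (a minimal sketch; offload is set up here with `cpu_offload` purely for illustration): |
| |
| ```py |
| >>> import torch.nn as nn |
| >>> from accelerate import cpu_offload |
| >>> model = nn.Sequential(nn.Linear(4, 4)) |
| >>> model = cpu_offload(model, execution_device="cpu") |
| >>> has_offloaded_params(model[0]) |
| True |
| ``` |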
| """ |
| from ..hooks import AlignDevicesHook |
|
|
| return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload |
|
|
|
|
| @contextlib.contextmanager |
| def align_module_device(module: torch.nn.Module, execution_device: Optional[torch.device] = None): |
| """ |
| Context manager that moves a module's parameters to the specified execution device. |
| |
| Args: |
| module (`torch.nn.Module`): |
| Module with parameters to align. |
| execution_device (`torch.device`, *optional*): |
| If provided, overrides the module's execution device within the context. Otherwise, the offload hook's |
| execution device is used; if the module has no offload hook and no device is given, this is a no-op. |
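| |
| Example (a minimal sketch; `model.layer` is a placeholder for a module that may have an offload hook): |
| |
| ```py |
| >>> with align_module_device(model.layer, execution_device=torch.device("cpu")): |
| ...     onloaded = {name: param.clone() for name, param in model.layer.named_parameters()} |
| ``` |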
| """ |
| if has_offloaded_params(module): |
| if execution_device is not None: |
| original_device = module._hf_hook.execution_device |
| module._hf_hook.execution_device = execution_device |
|
|
| try: |
| module._hf_hook.pre_forward(module) |
| yield |
| finally: |
| module._hf_hook.post_forward(module, None) |
| if execution_device is not None: |
| module._hf_hook.execution_device = original_device |
|
|
| elif execution_device is not None: |
| devices = {name: param.device for name, param in module.named_parameters(recurse=False)} |
| try: |
| for name in devices: |
| set_module_tensor_to_device(module, name, execution_device) |
| yield |
| finally: |
| for name, device in devices.items(): |
| set_module_tensor_to_device(module, name, device) |
|
|
| else: |
| yield |
|
|