|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import copy |
|
|
import functools |
|
|
import inspect |
|
|
import os |
|
|
import re |
|
|
import warnings |
|
|
from collections.abc import Sequence |
|
|
from contextlib import nullcontext |
|
|
from operator import attrgetter |
|
|
from typing import Any, Optional, Union |
|
|
|
|
|
import accelerate |
|
|
import torch |
|
|
import transformers |
|
|
from accelerate import FullyShardedDataParallelPlugin |
|
|
from accelerate.hooks import add_hook_to_module, remove_hook_from_module |
|
|
from accelerate.utils import is_npu_available, is_xpu_available |
|
|
from huggingface_hub import file_exists |
|
|
from huggingface_hub.errors import EntryNotFoundError, HFValidationError |
|
|
from packaging import version |
|
|
from safetensors.torch import storage_ptr, storage_size |
|
|
from transformers import PreTrainedModel |
|
|
|
|
|
from ..import_utils import is_auto_gptq_available, is_gptqmodel_available, is_torch_tpu_available |
|
|
from .constants import ( |
|
|
CONFIG_NAME, |
|
|
EMBEDDING_LAYER_NAMES, |
|
|
INCLUDE_LINEAR_LAYERS_SHORTHAND, |
|
|
SAFETENSORS_WEIGHTS_NAME, |
|
|
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_LOHA_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_LOKR_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_MISS_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_OFT_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_POLY_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_ROAD_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, |
|
|
TRANSFORMERS_MODELS_TO_WAVEFT_TARGET_MODULES_MAPPING, |
|
|
WEIGHTS_NAME, |
|
|
bloom_model_postprocess_past_key_value, |
|
|
starcoder_model_postprocess_past_key_value, |
|
|
) |
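
# MLU (Cambricon) detection is only available in accelerate >= 0.29.0; on older versions we assume no MLU is present.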
|
|
|
|
|
|
|
|
mlu_available = False |
|
|
if version.parse(accelerate.__version__) >= version.parse("0.29.0"): |
|
|
from accelerate.utils import is_mlu_available |
|
|
|
|
|
mlu_available = is_mlu_available() |
|
|
|
|
|
|
|
|
__all__ = [ |
|
|
"CONFIG_NAME", |
|
|
"EMBEDDING_LAYER_NAMES", |
|
|
"INCLUDE_LINEAR_LAYERS_SHORTHAND", |
|
|
"SAFETENSORS_WEIGHTS_NAME", |
|
|
"TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_LOHA_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_LOKR_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_MISS_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_OFT_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_POLY_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_ROAD_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING", |
|
|
"TRANSFORMERS_MODELS_TO_WAVEFT_TARGET_MODULES_MAPPING", |
|
|
"WEIGHTS_NAME", |
|
|
"bloom_model_postprocess_past_key_value", |
|
|
"starcoder_model_postprocess_past_key_value", |
|
|
] |
|
|
|
|
|
|
|
|
|
|
|
def infer_device() -> str: |
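    """Return the device type of the available accelerator, checking in the order "cuda", "mps", "mlu", "xpu",
    "npu", and falling back to "cpu" if none is detected.

    Example (the result depends on the local hardware):

        >>> infer_device()
        'cuda'
    """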
|
|
if torch.cuda.is_available(): |
|
|
return "cuda" |
|
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): |
|
|
return "mps" |
|
|
elif mlu_available: |
|
|
return "mlu" |
|
|
elif is_xpu_available(): |
|
|
return "xpu" |
|
|
elif is_npu_available(): |
|
|
return "npu" |
|
|
return "cpu" |
|
|
|
|
|
|
|
|
def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None): |
|
|
r""" |
|
|
Note this method only works for `transformers` models. |
|
|
|
|
|
    This method wraps the entire protocol for preparing a model before running training. This includes:

        1. Casting the layernorm layers to fp32.
        2. Making the output embedding layer require gradients.
        3. Upcasting the LM head to fp32.
        4. Freezing the base model layers so that they are not updated during training.
|
|
|
|
|
|
|
|
Args: |
|
|
model (`transformers.PreTrainedModel`): |
|
|
The loaded model from `transformers` |
|
|
use_gradient_checkpointing (`bool`, *optional*, defaults to `True`): |
|
|
If True, use gradient checkpointing to save memory at the expense of slower backward pass. |
|
|
gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`): |
|
|
            Keyword arguments to pass to the gradient checkpointing function. Please refer to the documentation of
            `torch.utils.checkpoint.checkpoint` for more details about the arguments that can be passed to that
            method. Note that this is only available in recent transformers versions (> 4.34.1).
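
    Example (a minimal sketch; the checkpoint name and quantization settings below are illustrative):

        >>> from transformers import AutoModelForCausalLM, BitsAndBytesConfig
        >>> from peft import prepare_model_for_kbit_training

        >>> quant_config = BitsAndBytesConfig(load_in_8bit=True)
        >>> model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", quantization_config=quant_config)
        >>> model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)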
|
|
""" |
|
|
loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) |
|
|
is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq" |
|
|
is_aqlm_quantized = getattr(model, "quantization_method", None) == "aqlm" |
|
|
is_eetq_quantized = getattr(model, "quantization_method", None) == "eetq" |
|
|
is_torchao_quantized = getattr(model, "quantization_method", None) == "torchao" |
|
|
is_hqq_quantized = getattr(model, "quantization_method", None) == "hqq" or getattr(model, "hqq_quantized", False) |
|
|
|
|
|
if gradient_checkpointing_kwargs is None: |
|
|
gradient_checkpointing_kwargs = {} |
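
    # freeze all parameters of the base model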
|
|
|
|
|
for name, param in model.named_parameters(): |
|
|
|
|
|
param.requires_grad = False |
|
|
|
|
|
if ( |
|
|
not is_gptq_quantized |
|
|
and not is_aqlm_quantized |
|
|
and not is_eetq_quantized |
|
|
and not is_hqq_quantized |
|
|
and not is_torchao_quantized |
|
|
): |
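        # cast all non-quantized fp16/bf16 parameters to fp32 for training stability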
|
|
|
|
|
for param in model.parameters(): |
|
|
if ( |
|
|
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16) |
|
|
) and param.__class__.__name__ != "Params4bit": |
|
|
param.data = param.data.to(torch.float32) |
|
|
|
|
|
if ( |
|
|
loaded_in_kbit |
|
|
or is_gptq_quantized |
|
|
or is_aqlm_quantized |
|
|
or is_eetq_quantized |
|
|
or is_hqq_quantized |
|
|
or is_torchao_quantized |
|
|
) and use_gradient_checkpointing: |
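        # with the (default) reentrant gradient checkpointing implementation, the inputs to the checkpointed
        # segments must require gradients, otherwise no gradients flow back to the adapter weights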
|
|
|
|
|
if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]: |
|
|
|
|
|
if hasattr(model, "enable_input_require_grads"): |
|
|
model.enable_input_require_grads() |
|
|
else: |
|
|
|
|
|
def make_inputs_require_grad(module, input, output): |
|
|
output.requires_grad_(True) |
|
|
|
|
|
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) |
|
|
|
|
|
|
|
|
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list( |
|
|
inspect.signature(model.gradient_checkpointing_enable).parameters |
|
|
) |
|
|
|
|
|
if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0: |
|
|
            warnings.warn(
                "gradient_checkpointing_kwargs is not supported in this version of transformers. The passed kwargs will be ignored. "
                "If you want to use that feature, please upgrade to the latest version of transformers.",
|
|
FutureWarning, |
|
|
) |
|
|
|
|
|
gc_enable_kwargs = ( |
|
|
{} if not _supports_gc_kwargs else {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs} |
|
|
) |
|
|
|
|
|
|
|
|
model.gradient_checkpointing_enable(**gc_enable_kwargs) |
|
|
return model |
|
|
|
|
|
|
|
|
|
|
|
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): |
|
|
""" |
|
|
Shift input ids one token to the right. |
|
|
|
|
|
Args: |
|
|
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids |
|
|
pad_token_id (`int`): The id of the `padding` token. |
|
|
decoder_start_token_id (`int`): The id of the `start` token. |
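
    Example (illustrative values):

        >>> input_ids = torch.tensor([[10, 11, 12]])
        >>> shift_tokens_right(input_ids, pad_token_id=0, decoder_start_token_id=2)
        tensor([[ 2, 10, 11]])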
|
|
""" |
|
|
shifted_input_ids = input_ids.new_zeros(input_ids.shape) |
|
|
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() |
|
|
shifted_input_ids[:, 0] = decoder_start_token_id |
|
|
|
|
|
if pad_token_id is None: |
|
|
raise ValueError("self.model.config.pad_token_id has to be defined.") |
|
|
|
|
|
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) |
|
|
|
|
|
return shifted_input_ids |
|
|
|
|
|
|
|
|
class AuxiliaryTrainingWrapper(torch.nn.Module): |
|
|
"""Wrap a specific module so that it can be trained and saved in a way that is tangential to how |
|
|
PEFT normally works, e.g. fully training a classification layer instead of using an adapter. |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
adapter_layer_names: tuple[str, ...] = () |
|
|
|
|
|
other_param_names: tuple[str, ...] = () |
|
|
|
|
|
merged_adapters: list[str] = [] |
|
|
|
|
|
def __init__(self, module_to_save, adapter_name, **kwargs): |
|
|
"""Extra kwargs will be passed to `self.init_modules` and `self.update`.""" |
|
|
super().__init__() |
|
|
self.original_module = module_to_save |
|
|
self._active_adapter = [adapter_name] |
|
|
self._disable_adapters = False |
|
|
self._adapters = set() |
|
|
|
|
|
self.init_modules(adapter_name, **kwargs) |
|
|
|
|
|
self.update(adapter_name, **kwargs) |
|
|
self.check_module() |
|
|
|
|
|
def init_modules(self, adapter_name, **kwargs): |
|
|
"""A place to initialize PyTorch modules in `__init__` before the call to `self.update()`.""" |
|
|
raise NotImplementedError |
|
|
|
|
|
def _get_available_adapters(self) -> set[str]: |
|
|
"""Return all adapter names that can be found on this module.""" |
|
|
raise NotImplementedError |
|
|
|
|
|
    def _error_message_name(self):
        """Returns a user-friendly identifier for error messages, e.g. for type compatibility error messages from
        `check_module()`, so that the user can trace back where the error comes from. A generic "training wrapper" is
        less helpful than "modules_to_save", for example.
|
|
""" |
|
|
return "training wrapper" |
|
|
|
|
|
def check_module(self): |
|
|
"""Perform some sanity checks on the module to ensure that it works""" |
|
|
|
|
|
|
|
|
|
|
|
forbidden_classes = (torch.nn.ModuleDict, torch.nn.ModuleList, torch.nn.ParameterDict, torch.nn.ParameterList) |
|
|
if isinstance(self.original_module, forbidden_classes): |
|
|
cls_name = self.original_module.__class__ |
|
|
raise TypeError(f"{self._error_message_name()} cannot be applied to modules of type {cls_name}") |
|
|
|
|
|
|
|
|
from peft.tuners.tuners_utils import BaseTunerLayer |
|
|
|
|
|
if isinstance(self.original_module, BaseTunerLayer): |
|
|
|
|
|
cls_name = self.original_module.__class__ |
|
|
raise TypeError(f"{self._error_message_name()} cannot be applied to modules of type {cls_name}") |
|
|
|
|
|
@property |
|
|
def disable_adapters(self) -> bool: |
|
|
|
|
|
return self._disable_adapters |
|
|
|
|
|
@property |
|
|
def active_adapter(self) -> Union[list[str], str]: |
|
|
|
|
|
return self._active_adapter |
|
|
|
|
|
@property |
|
|
def active_adapters(self) -> list[str]: |
|
|
if isinstance(self._active_adapter, str): |
|
|
return [self._active_adapter] |
|
|
return self._active_adapter |
|
|
|
|
|
def _hasattr_wrapped(self, name, modules): |
|
|
"""Infrastructure to enable the implementing class to delegate attributes to other modules. |
|
|
Returns True if the implementing class knows how to handle attribute `name`. |
|
|
|
|
|
Gets passed `modules` which is PyTorch's internal list of assigned modules from `nn.Module`. |
|
|
""" |
|
|
return False |
|
|
|
|
|
def _getattr_wrapped(self, name, modules): |
|
|
"""If `_hasattr_wrapped` returns True for `name`, then this function should return the corresponding |
|
|
value associated with `name`. |
|
|
""" |
|
|
return None |
|
|
|
|
|
def __getattr__(self, name: str): |
|
|
|
|
|
|
|
|
try: |
|
|
return super().__getattr__(name) |
|
|
except AttributeError: |
|
|
pass |
|
|
|
|
|
if "_modules" not in self.__dict__: |
|
|
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") |
|
|
|
|
|
|
|
|
|
|
|
modules = self.__dict__["_modules"] |
|
|
if self.disable_adapters: |
|
|
return getattr(self.original_module, name) |
|
|
elif self._hasattr_wrapped(name, modules): |
|
|
return self._getattr_wrapped(name, modules) |
|
|
|
|
|
|
|
|
|
|
|
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") |
|
|
|
|
|
def update(self, adapter_name, **kwargs): |
|
|
"""Called when this instance should be part of an adapter's training. |
|
|
Adds the given adapter to the list of adapters that this instance is training along with. |
|
|
|
|
|
Additional kwargs are expected to be the same kwargs that are also passed for initializing this class. |
|
|
""" |
|
|
if adapter_name not in self._adapters: |
|
|
self._adapters.add(adapter_name) |
|
|
|
|
|
def _create_new_hook(self, old_hook): |
|
|
r""" |
|
|
        Creates a new hook based on the old hook. Use it only if you know what you are doing!
|
|
""" |
|
|
old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__) |
|
|
old_hook_attr = old_hook.__dict__ |
|
|
filtered_old_hook_attr = {} |
|
|
old_hook_init_signature = inspect.signature(old_hook_cls.__init__) |
|
|
for k in old_hook_attr.keys(): |
|
|
if k in old_hook_init_signature.parameters: |
|
|
filtered_old_hook_attr[k] = old_hook_attr[k] |
|
|
new_hook = old_hook_cls(**filtered_old_hook_attr) |
|
|
return new_hook |
|
|
|
|
|
def _check_forward_args(self, x, *args, **kwargs): |
|
|
"""Check if the arguments are compatible with the configs and state of the model""" |
|
|
adapter_names = kwargs.get("adapter_names", None) |
|
|
if adapter_names is None: |
|
|
return |
|
|
|
|
|
if len(x) != len(adapter_names): |
|
|
msg = ( |
|
|
"Length of `adapter_names` should be the same as the number of inputs, but got " |
|
|
f"{len(adapter_names)} and {len(x)} respectively." |
|
|
) |
|
|
raise ValueError(msg) |
|
|
|
|
|
def _forward_wrapped(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: |
|
|
raise NotImplementedError |
|
|
|
|
|
def _forward_wrapped_mixed_batch( |
|
|
self, x: torch.Tensor, active_adapter: str, *args: Any, **kwargs: Any |
|
|
) -> torch.Tensor: |
|
|
raise NotImplementedError |
|
|
|
|
|
def _forward_wrapped_passthrough(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: |
|
|
"""The forward call when no adapter is involved in the forward computation, only the base model""" |
|
|
raise NotImplementedError |
|
|
|
|
|
def _mixed_batch_forward( |
|
|
self, input: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any |
|
|
) -> torch.Tensor: |
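        # Mixed-batch forward: the user passed `adapter_names`, so the input is split into sub-batches, each
        # sub-batch is routed through its corresponding adapter (or through the base module for the special
        # "__base__" name), and the outputs are reassembled in the original order.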
|
|
|
|
|
|
|
|
|
|
|
SUPPORTED_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d) |
|
|
|
|
|
module_names = ", ".join([module.__name__ for module in SUPPORTED_MODULES]) |
|
|
|
|
|
if not isinstance(self.original_module, SUPPORTED_MODULES): |
|
|
raise TypeError(f"Mixed batching is only supported for the following modules: {module_names}.") |
|
|
|
|
|
unique_adapters = set(adapter_names) |
|
|
sub_batch_indices_list = [] |
|
|
|
|
|
for adapter in unique_adapters: |
|
|
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) |
|
|
|
|
|
results = [0 for _ in range(len(input))] |
|
|
|
|
|
for i, active_adapter in enumerate(unique_adapters): |
|
|
sub_batch = input[sub_batch_indices_list[i]] |
|
|
|
|
|
if active_adapter == "__base__": |
|
|
output = self.original_module(sub_batch, *args, **kwargs) |
|
|
else: |
|
|
output = self._forward_wrapped_mixed_batch(sub_batch, active_adapter, *args, **kwargs) |
|
|
|
|
|
for index, j in enumerate(sub_batch_indices_list[i]): |
|
|
results[j] = output[index] |
|
|
|
|
|
return torch.stack(results) |
|
|
|
|
|
def forward(self, x: torch.Tensor, *args, **kwargs): |
|
|
self._check_forward_args(x, *args, **kwargs) |
|
|
adapter_names = kwargs.pop("adapter_names", None) |
|
|
|
|
|
if self.disable_adapters or any(adapter not in self._adapters for adapter in self.active_adapters): |
|
|
return self._forward_wrapped_passthrough(x, *args, **kwargs) |
|
|
|
|
|
if adapter_names is None: |
|
|
return self._forward_wrapped(x, *args, **kwargs) |
|
|
return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) |
|
|
|
|
|
def enable_adapters(self, enabled: bool): |
|
|
"""Toggle the enabling and disabling of adapters |
|
|
|
|
|
Args: |
|
|
enabled (bool): True to enable adapters, False to disable adapters |
|
|
""" |
|
|
        self._disable_adapters = not enabled
|
|
|
|
|
def check_set_adapter(self, adapter_name: str | list[str]) -> str | None: |
|
|
"""Helper function to check if the given adapter(s) can be set. |
|
|
|
|
|
Return the name of the adapter to be set or None if no adapter should be set. |
|
|
""" |
|
|
raise NotImplementedError |
|
|
|
|
|
def set_adapter(self, adapter_names: Union[str, list[str]], inference_mode: bool = False) -> None: |
|
|
"""Set the active adapter |
|
|
|
|
|
Args: |
|
|
adapter_names (str or list[str]): |
|
|
The name(s) of the adapter(s) to set as active |
|
|
inference_mode (bool, optional): |
|
|
Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. |
|
|
""" |
|
|
if isinstance(adapter_names, str): |
|
|
self._active_adapter = adapter_names |
|
|
else: |
|
|
self._active_adapter = [] |
|
|
for adapter_name in adapter_names: |
|
|
if adapter_name not in self._adapters: |
|
|
raise ValueError(f"Adapter {adapter_name} not found in {self._adapters}") |
|
|
|
|
|
self._active_adapter.append(adapter_name) |
|
|
|
|
|
def delete_adapter(self, adapter_name: str, new_active_adapters: Optional[list[str]]) -> None: |
|
|
"""Delete an adapter from the layer, set a new active adapter if necessary""" |
|
|
raise NotImplementedError |
|
|
|
|
|
def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None: |
|
|
""" |
|
|
Enable or disable gradients on the given adapter(s). |
|
|
|
|
|
Args: |
|
|
            adapter_names (`str` or `Sequence[str]`):
                The name(s) of the adapter(s) whose gradients should be enabled or disabled.
            requires_grad (`bool`, *optional*):
                Whether to enable (`True`, default) or disable (`False`) gradients.
|
|
""" |
|
|
if isinstance(adapter_names, str): |
|
|
adapter_names_set = {adapter_names} |
|
|
else: |
|
|
adapter_names_set = set(adapter_names) |
|
|
|
|
|
for layer_name in self.adapter_layer_names: |
|
|
|
|
|
module_dict = attrgetter(layer_name)(self) |
|
|
for key, layer in module_dict.items(): |
|
|
if key in adapter_names_set: |
|
|
layer.requires_grad_(requires_grad) |
|
|
|
|
|
def adapter_state_dict(self, adapter_name): |
|
|
"""Return the state dict of this module for a given adapter.""" |
|
|
raise NotImplementedError |
|
|
|
|
|
    def adapter_state_dict_load_map(self, adapter_name):
        """Return a mapping from the keys present in the state dict loaded from disk to how they should be
        represented in the loaded model's state dict.
|
|
|
|
|
The default should be a 1:1 mapping but it is important to define a mapping as it also serves as the |
|
|
ground-truth for which keys are supposed to be loaded from a saved state dict. |
|
|
""" |
|
|
raise NotImplementedError |
|
|
|
|
|
def unload_and_optionally_merge_module( |
|
|
self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]] |
|
|
) -> torch.nn.Module: |
|
|
"""Handles unloading when called from PEFT models. Returns the wrapped module |
|
|
and handles merging onto the wrapped module if requested. |
|
|
""" |
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
class ModulesToSaveWrapper(AuxiliaryTrainingWrapper): |
|
|
"""Wraps a module that is supposed to be trained (i.e. `requires_grad_(True)`) and saved after training.""" |
|
|
|
|
|
|
|
|
adapter_layer_names: tuple[str, ...] = ("modules_to_save",) |
|
|
|
|
|
def __init__(self, module_to_save, adapter_name): |
|
|
super().__init__(module_to_save, adapter_name) |
|
|
|
|
|
def init_modules(self, adapter_name): |
|
|
|
|
|
self.modules_to_save = torch.nn.ModuleDict({}) |
|
|
|
|
|
def _error_message_name(self): |
|
|
return "modules_to_save" |
|
|
|
|
|
def _forward_wrapped(self, x, *args, **kwargs): |
|
|
if not self.active_adapters: |
|
|
return self._forward_wrapped_passthrough(x, *args, **kwargs) |
|
|
return self.modules_to_save[self.active_adapters[0]](x, *args, **kwargs) |
|
|
|
|
|
def _forward_wrapped_mixed_batch(self, x, active_adapter, *args, **kwargs): |
|
|
return self.modules_to_save[active_adapter](x, *args, **kwargs) |
|
|
|
|
|
def _forward_wrapped_passthrough(self, x, *args, **kwargs): |
|
|
return self.original_module(x, *args, **kwargs) |
|
|
|
|
|
def _hasattr_wrapped(self, name, modules): |
|
|
return self.active_adapters[0] in modules["modules_to_save"] |
|
|
|
|
|
def _getattr_wrapped(self, name, modules): |
|
|
return getattr(modules["modules_to_save"][self.active_adapters[0]], name) |
|
|
|
|
|
def update(self, adapter_name, **kwargs): |
|
|
super().update(adapter_name) |
|
|
|
|
|
context_manager = nullcontext() |
|
|
for _, param in self.original_module.named_parameters(): |
|
|
num_params = param.numel() |
|
|
|
|
|
if num_params == 0 and hasattr(param, "ds_numel"): |
|
|
import deepspeed |
|
|
|
|
|
context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0) |
|
|
break |
|
|
|
|
|
if adapter_name not in self.modules_to_save: |
|
|
with context_manager: |
|
|
self.modules_to_save[adapter_name] = copy.deepcopy(self.original_module) |
|
|
|
|
|
if hasattr(self.modules_to_save[adapter_name], "_hf_hook"): |
|
|
old_hook = self.modules_to_save[adapter_name]._hf_hook |
|
|
new_hook = self._create_new_hook(old_hook) |
|
|
remove_hook_from_module(self.modules_to_save[adapter_name]) |
|
|
add_hook_to_module(self.modules_to_save[adapter_name], new_hook) |
|
|
|
|
|
self.original_module.requires_grad_(False) |
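
        # make the copied module trainable right away if the newly added adapter is the currently active one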
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if adapter_name == self.active_adapter: |
|
|
self.modules_to_save[adapter_name].requires_grad_(True) |
|
|
|
|
|
    def enable_adapters(self, enabled: bool):
        """Takes care of setting the requires_grad flag on the wrapped module.
|
|
If adapters are enabled, gradients for the module are required as well. |
|
|
""" |
|
|
super().enable_adapters(enabled) |
|
|
|
|
|
if enabled: |
|
|
self.original_module.requires_grad_(False) |
|
|
for adapter_name in self.active_adapters: |
|
|
self.modules_to_save[adapter_name].requires_grad_(True) |
|
|
else: |
|
|
self.original_module.requires_grad_(True) |
|
|
self.modules_to_save.requires_grad_(False) |
|
|
|
|
|
def check_set_adapter(self, adapter_name: str | list[str]) -> str | None: |
|
|
"""Helper function to check if the given adapter(s) can be set. |
|
|
|
|
|
Return the name of the adapter to be set or None if no adapter should be set. |
|
|
""" |
|
|
if isinstance(adapter_name, str): |
|
|
return adapter_name |
|
|
|
|
|
|
|
|
if len(adapter_name) == 0: |
|
|
raise ValueError("Please specify at least one adapter to set") |
|
|
|
|
|
adapter_names_in_module = [n for n in adapter_name if n in self.modules_to_save] |
|
|
|
|
|
if len(adapter_names_in_module) > 1: |
|
|
raise ValueError(f"Only one adapter can be set at a time for {self}, got {len(adapter_names_in_module)}") |
|
|
|
|
|
adapter_name_to_set: str | None |
|
|
if not adapter_names_in_module: |
|
|
adapter_name_to_set = None |
|
|
else: |
|
|
adapter_name_to_set = adapter_names_in_module[0] |
|
|
|
|
|
return adapter_name_to_set |
|
|
|
|
|
def set_adapter(self, adapter_names: Union[str, list[str]], inference_mode: bool = False) -> None: |
|
|
"""Set the active adapter |
|
|
|
|
|
Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True) unless |
|
|
inference_mode is True. |
|
|
|
|
|
Args: |
|
|
adapter_names (list[str], str): |
|
|
The name(s) of the adapter(s) to set as active. |
|
|
inference_mode (bool, optional): |
|
|
Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. |
|
|
""" |
|
|
if isinstance(adapter_names, str): |
|
|
adapter_names = [adapter_names] |
|
|
|
|
|
if len(adapter_names) > 1: |
|
|
raise ValueError(f"Attempted to set multiple ({adapter_names}) adapters at once for modules_to_save.") |
|
|
|
|
|
if len(adapter_names) == 0: |
|
|
|
|
|
self._active_adapter = [] |
|
|
return |
|
|
|
|
|
adapter_name = adapter_names[0] |
|
|
|
|
|
if adapter_name not in self._adapters: |
|
|
raise ValueError(f"Adapter {adapter_name} not found in {self._adapters}") |
|
|
|
|
|
for currently_active_adapter_name in self.active_adapters: |
|
|
self.modules_to_save[currently_active_adapter_name].requires_grad_(False) |
|
|
self.modules_to_save[adapter_name].requires_grad_(not inference_mode) |
|
|
self._active_adapter = adapter_name |
|
|
|
|
|
def delete_adapter(self, adapter_name: str, new_active_adapters: Optional[list[str]]) -> None: |
|
|
""" |
|
|
Delete the adapter if present. |
|
|
|
|
|
This method will also set a new active adapter if the deleted adapter was the active adapter. It is important |
|
|
that the new adapter is chosen by the caller in a deterministic way, so that the same adapter is chosen on all |
|
|
layers. |
|
|
""" |
|
|
if adapter_name not in self.modules_to_save: |
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(new_active_adapters, (list, tuple)) and len(new_active_adapters) > 1: |
|
|
name = self.__class__.__name__ |
|
|
raise ValueError( |
|
|
f"Attempted to set multiple ({new_active_adapters}) adapters at once for {name}, which is not allowed." |
|
|
) |
|
|
|
|
|
if adapter_name in self._adapters: |
|
|
self._adapters.remove(adapter_name) |
|
|
|
|
|
if not new_active_adapters: |
|
|
|
|
|
del self.modules_to_save[adapter_name] |
|
|
self._active_adapter = [] |
|
|
return |
|
|
|
|
|
new_active_adapter = new_active_adapters[0] |
|
|
if new_active_adapter not in self.modules_to_save: |
|
|
|
|
|
del self.modules_to_save[adapter_name] |
|
|
self._active_adapter = [] |
|
|
return |
|
|
|
|
|
if new_active_adapter != self.active_adapters[0]: |
|
|
self.set_adapter(new_active_adapter) |
|
|
del self.modules_to_save[adapter_name] |
|
|
|
|
|
def adapter_state_dict_load_map(self, adapter_name): |
|
|
|
|
|
|
|
|
if adapter_name not in self._adapters: |
|
|
|
|
|
|
|
|
return {} |
|
|
return {k: f"modules_to_save.{adapter_name}.{k}" for k in self.modules_to_save[adapter_name].state_dict()} |
|
|
|
|
|
def adapter_state_dict(self, adapter_name, state_dict): |
|
|
if adapter_name not in self._adapters: |
|
|
|
|
|
|
|
|
return {} |
|
|
|
|
|
return { |
|
|
k: state_dict[f"modules_to_save.{adapter_name}.{k}"] |
|
|
for k in self.modules_to_save[adapter_name].state_dict() |
|
|
} |
|
|
|
|
|
def unload_and_optionally_merge_module( |
|
|
self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]] |
|
|
) -> torch.nn.Module: |
|
|
"""Unloading in case of `ModulesToSave` means to simply return the wrapped module. |
|
|
|
|
|
        However, if the wrapped module is itself a tuner, we'll call merge on it first.
|
|
""" |
|
|
new_module = self.modules_to_save[self.active_adapter] |
|
|
|
|
|
|
|
|
|
|
|
if hasattr(new_module, "base_layer"): |
|
|
|
|
|
if merge: |
|
|
new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names) |
|
|
new_module = new_module.get_base_layer() |
|
|
|
|
|
return new_module |
|
|
|
|
|
def _get_available_adapters(self) -> set[str]: |
|
|
"""Return all adapter names that can be found on this module.""" |
|
|
return set(self.modules_to_save.keys()) |
|
|
|
|
|
|
|
|
class TrainableTokensWrapper(AuxiliaryTrainingWrapper): |
|
|
"""Wraps a module (typically an embedding layer) that is supposed to be re-trained selectively (i.e. |
|
|
solely updating a few columns) using the `TrainableTokensLayer` PEFT method. |
|
|
|
|
|
Supports weight-tying to another adapter when passed a `tied_adapter` which is expected to be a |
|
|
`TrainableTokensLayer`. |
|
|
""" |
|
|
|
|
|
|
|
|
adapter_layer_names: tuple[str, ...] = ("token_adapter.trainable_tokens_delta",) |
|
|
other_param_names: tuple[str, ...] = ("token_adapter.token_indices", "token_adapter.trainable_tokens_original") |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
module_to_save: torch.nn.Module, |
|
|
adapter_name: str, |
|
|
token_indices: list[int], |
|
|
tied_adapter=None, |
|
|
) -> None: |
|
|
super().__init__(module_to_save, adapter_name, token_indices=token_indices, tied_adapter=tied_adapter) |
|
|
|
|
|
|
|
|
self.original_module = None |
|
|
|
|
|
@property |
|
|
def original_module(self): |
|
|
|
|
|
|
|
|
return self.token_adapter.base_layer |
|
|
|
|
|
def init_modules(self, adapter_name, token_indices, tied_adapter): |
|
|
|
|
|
from peft.tuners.trainable_tokens import TrainableTokensLayer |
|
|
|
|
|
|
|
|
|
|
|
self.token_adapter = TrainableTokensLayer(self.original_module, adapter_name, token_indices, tied_adapter) |
|
|
|
|
|
def _error_message_name(self): |
|
|
return "trainable_token_indices" |
|
|
|
|
|
def _hasattr_wrapped(self, name, modules): |
|
|
return name == "weight" |
|
|
|
|
|
def _getattr_wrapped(self, name, modules): |
|
|
|
|
|
|
|
|
|
|
|
if name == "weight": |
|
|
return modules["token_adapter"].get_merged_weights(self.token_adapter.active_adapters) |
|
|
|
|
|
        raise RuntimeError(
            f"This code should never have been reached, probably due to a bad check in `_hasattr_wrapped` for {name}. "
|
|
"Please file an issue under https://github.com/huggingface/peft/issues." |
|
|
) |
|
|
|
|
|
def _forward_wrapped(self, x, *args, **kwargs): |
|
|
if not self.active_adapters: |
|
|
return self._forward_wrapped_passthrough(x, *args, **kwargs) |
|
|
return self.token_adapter(x) |
|
|
|
|
|
def _forward_wrapped_mixed_batch(self, x, active_adapter, *args, **kwargs): |
|
|
return self.token_adapter.forward_adapters(x, [active_adapter]) |
|
|
|
|
|
def _forward_wrapped_passthrough(self, x, *args, **kwargs): |
|
|
|
|
|
|
|
|
return self.token_adapter(x, *args, **kwargs) |
|
|
|
|
|
def update(self, active_adapter, **kwargs): |
|
|
|
|
|
|
|
|
if active_adapter not in self._adapters: |
|
|
self.token_adapter.update_layer(active_adapter, **kwargs) |
|
|
|
|
|
super().update(active_adapter) |
|
|
|
|
|
def adapter_state_dict_load_map(self, adapter_name): |
|
|
if self.token_adapter.tied_adapter: |
|
|
return {} |
|
|
return {"token_adapter.trainable_tokens_delta": f"token_adapter.trainable_tokens_delta.{adapter_name}"} |
|
|
|
|
|
def adapter_state_dict(self, adapter_name, state_dict): |
|
|
if self.token_adapter.tied_adapter: |
|
|
|
|
|
|
|
|
|
|
|
return {} |
|
|
|
|
|
return { |
|
|
f"token_adapter.{k}": state_dict[f"token_adapter.{k}.{adapter_name}"] for k in ["trainable_tokens_delta"] |
|
|
} |
|
|
|
|
|
def enable_adapters(self, enabled: bool): |
|
|
"""Enables/disables the underlying `TrainableTokens` adapter. |
|
|
Also handles the internal adapter disable flag. |
|
|
""" |
|
|
super().enable_adapters(enabled) |
|
|
|
|
|
self.token_adapter.enable_adapters(enabled) |
|
|
|
|
|
def check_set_adapter(self, adapter_name: str | list[str]) -> str | None: |
|
|
"""Helper function to check if the given adapter(s) can be set. |
|
|
|
|
|
Return the name of the adapter to be set or None if no adapter should be set. |
|
|
""" |
|
|
if isinstance(adapter_name, str): |
|
|
return adapter_name |
|
|
|
|
|
|
|
|
if len(adapter_name) == 0: |
|
|
raise ValueError("Please specify at least one adapter to set") |
|
|
|
|
|
|
|
|
adapter_names_in_module = [n for n in adapter_name if n in self.token_adapter.trainable_tokens_delta] |
|
|
|
|
|
if len(adapter_names_in_module) > 1: |
|
|
raise ValueError(f"Only one adapter can be set at a time for {self}, got {len(adapter_names_in_module)}") |
|
|
|
|
|
adapter_name_to_set: str | None |
|
|
if not adapter_names_in_module: |
|
|
adapter_name_to_set = None |
|
|
else: |
|
|
adapter_name_to_set = adapter_names_in_module[0] |
|
|
|
|
|
return adapter_name_to_set |
|
|
|
|
|
def set_adapter(self, adapter_names: Union[str, list[str]], inference_mode: bool = False) -> None: |
|
|
super().set_adapter(adapter_names, inference_mode=inference_mode) |
|
|
self.token_adapter.set_adapter(adapter_names, inference_mode=inference_mode) |
|
|
|
|
|
def delete_adapter(self, adapter_name: str, new_active_adapters: Optional[list[str]]) -> None: |
|
|
""" |
|
|
Delete the adapter if present. |
|
|
|
|
|
This method will also set a new active adapter if the deleted adapter was the active adapter. It is important |
|
|
that the new adapter is chosen by the caller in a deterministic way, so that the same adapter is chosen on all |
|
|
layers. |
|
|
""" |
|
|
self.token_adapter.delete_adapter(adapter_name) |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(new_active_adapters, (list, tuple)) and len(new_active_adapters) > 1: |
|
|
name = self.__class__.__name__ |
|
|
raise ValueError( |
|
|
f"Attempted to set multiple ({new_active_adapters}) adapters at once for {name}, which is not allowed." |
|
|
) |
|
|
|
|
|
if adapter_name in self._adapters: |
|
|
self._adapters.remove(adapter_name) |
|
|
|
|
|
if not new_active_adapters: |
|
|
self._active_adapter = [] |
|
|
return |
|
|
|
|
|
if new_active_adapters[0] not in self.token_adapter.trainable_tokens_delta: |
|
|
|
|
|
self._active_adapter = [] |
|
|
return |
|
|
|
|
|
new_active_adapter = new_active_adapters[0] |
|
|
self.set_adapter(new_active_adapter) |
|
|
|
|
|
    def unload_and_optionally_merge_module(
        self, merge: bool, safe_merge: bool, adapter_names: Optional[list[str]]
    ) -> torch.nn.Module:
        """Unloading for `TrainableTokensWrapper` means returning the wrapped module, e.g. the embedding layer, and,
        if requested, merging the `TrainableTokens` adapter onto the wrapped module.
|
|
""" |
|
|
if merge: |
|
|
self.token_adapter.merge(safe_merge=safe_merge, adapter_names=adapter_names) |
|
|
return self.token_adapter.get_base_layer() |
|
|
|
|
|
def _get_available_adapters(self) -> set[str]: |
|
|
"""Return all adapter names that can be found on this module.""" |
|
|
return set(self.token_adapter.trainable_tokens_delta.keys()) |
|
|
|
|
|
|
|
|
def _get_input_embeddings_name(model, default=None): |
|
|
if not hasattr(model, "get_input_embeddings"): |
|
|
return default |
|
|
|
|
|
input_embeddings = model.get_input_embeddings() |
|
|
for name, module in model.named_modules(): |
|
|
if module is input_embeddings: |
|
|
return name |
|
|
|
|
|
return default |
|
|
|
|
|
|
|
|
def _get_submodules(model, key): |
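    # split a dotted module name such as "a.b.c" into the parent module ("a.b"), the target module itself, and the
    # target's attribute name ("c")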
|
|
parent = model.get_submodule(".".join(key.split(".")[:-1])) |
|
|
target_name = key.split(".")[-1] |
|
|
target = model.get_submodule(key) |
|
|
return parent, target, target_name |
|
|
|
|
|
|
|
|
def _get_submodules_with_grandparent(model, key): |
|
|
parent = model.get_submodule(".".join(key.split(".")[:-1])) |
|
|
try: |
|
|
grandparent = model.get_submodule(".".join(key.split(".")[:-2])) |
|
|
except AttributeError: |
|
|
|
|
|
grandparent = None |
|
|
target_name = key.split(".")[-1] |
|
|
target = model.get_submodule(key) |
|
|
return parent, grandparent, target, target_name |
|
|
|
|
|
|
|
|
def _freeze_adapter(model, adapter_name): |
|
|
for n, p in model.named_parameters(): |
|
|
if adapter_name in n: |
|
|
p.requires_grad = False |
|
|
|
|
|
|
|
|
def _set_trainable( |
|
|
model, |
|
|
adapter_name, |
|
|
module_names, |
|
|
inference_mode: bool, |
|
|
strict_module_check: bool = False, |
|
|
wrapper_cls: Optional[AuxiliaryTrainingWrapper] = None, |
|
|
activate_adapter: bool = True, |
|
|
**wrapper_kwargs, |
|
|
): |
|
|
"""Wraps modules that are supposed to be re-trained either normally, i.e. marking them to require gradients and |
|
|
saving them alongside other modules, or with certain methods that go alongside PEFT methods, such as retraining |
|
|
specific token indices using selective read/write. |
|
|
|
|
|
Note that you need to validate beforehand if there are layers targeted by multiple wrappers, e.g. if the |
|
|
'embedding' layer is configured for both `ModulesToSaveWrapper` and `TrainableTokensWrapper` there would be |
|
|
conflicts down the line. |
|
|
|
|
|
The default is to wrap the module in a `ModulesToSaveWrapper` wrapper. |
|
|
|
|
|
    If `strict_module_check` is set, this method raises a ValueError, similar to BaseTuner.inject_adapter, when none
    of the requested modules in `module_names` is found in the model.

    The `activate_adapter` flag indicates whether this new adapter should be activated.
|
|
""" |
|
|
from peft.tuners.tuners_utils import BaseTunerLayer |
|
|
|
|
|
if wrapper_cls is None: |
|
|
wrapper_cls = ModulesToSaveWrapper |
|
|
|
|
|
if not module_names: |
|
|
|
|
|
|
|
|
return |
|
|
|
|
|
trainable_modules = [] |
|
|
found_modules = set() |
|
|
|
|
|
key_list = [key for key, _ in model.named_modules(remove_duplicate=False)] |
|
|
|
|
|
for key in key_list: |
|
|
target_module_found = any(key.endswith(target_key) for target_key in module_names) |
|
|
if target_module_found: |
|
|
parent, grandparent, target, target_name = _get_submodules_with_grandparent(model, key) |
|
|
if isinstance(grandparent, BaseTunerLayer): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
raise ValueError( |
|
|
f"You are trying to target a module with {wrapper_cls} that is a child of {type(grandparent)}. " |
|
|
"This is almost certainly not the intended behavior. Please ensure that the adapter name, " |
|
|
f"'{adapter_name}', does not conflict with any of the targeted modules." |
|
|
) |
|
|
|
|
|
if isinstance(target, wrapper_cls): |
|
|
target.update(adapter_name, **wrapper_kwargs) |
|
|
target.set_adapter(target.active_adapter, inference_mode=inference_mode) |
|
|
else: |
|
|
new_module = wrapper_cls(target, adapter_name, **wrapper_kwargs) |
|
|
if activate_adapter: |
|
|
new_module.set_adapter(adapter_name, inference_mode=inference_mode) |
|
|
else: |
|
|
new_module.set_adapter([], inference_mode=inference_mode) |
|
|
setattr(parent, target_name, new_module) |
|
|
trainable_modules.append(new_module) |
|
|
found_modules.add(target_name) |
|
|
|
|
|
not_found = set(module_names).difference(found_modules) |
|
|
if strict_module_check and not found_modules: |
|
|
raise ValueError( |
|
|
f"Target modules {not_found} not found in the base model. Please check the target modules and try again." |
|
|
) |
|
|
|
|
|
return trainable_modules |
|
|
|
|
|
|
|
|
def _set_adapter(model, adapter_name: str | list[str], inference_mode: bool = False): |
|
|
for module in model.modules(): |
|
|
if isinstance(module, AuxiliaryTrainingWrapper): |
|
|
|
|
|
adapter_name_to_set = module.check_set_adapter(adapter_name) |
|
|
|
|
|
|
|
|
|
|
|
if adapter_name_to_set in module._adapters: |
|
|
module.enable_adapters(True) |
|
|
module.set_adapter(adapter_name_to_set, inference_mode=inference_mode) |
|
|
else: |
|
|
module.enable_adapters(False) |
|
|
module.set_adapter([], inference_mode=inference_mode) |
|
|
|
|
|
|
|
|
def _prepare_prompt_learning_config(peft_config, model_config): |
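    # Fill in `num_layers`, `token_dim` and `num_attention_heads` on the prompt learning config from the base model
    # config when they are not set explicitly; for prefix tuning with grouped-query attention, adjust them to the
    # number of key/value heads.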
|
|
|
|
|
if "text_config" in model_config: |
|
|
model_config = model_config["text_config"] |
|
|
|
|
|
if peft_config.num_layers is None: |
|
|
if "num_hidden_layers" in model_config: |
|
|
num_layers = model_config["num_hidden_layers"] |
|
|
elif "num_layers" in model_config: |
|
|
num_layers = model_config["num_layers"] |
|
|
elif "n_layer" in model_config: |
|
|
num_layers = model_config["n_layer"] |
|
|
else: |
|
|
raise ValueError("Please specify `num_layers` in `peft_config`") |
|
|
peft_config.num_layers = num_layers |
|
|
|
|
|
if peft_config.token_dim is None: |
|
|
if "hidden_size" in model_config: |
|
|
token_dim = model_config["hidden_size"] |
|
|
elif "n_embd" in model_config: |
|
|
token_dim = model_config["n_embd"] |
|
|
elif "d_model" in model_config: |
|
|
token_dim = model_config["d_model"] |
|
|
else: |
|
|
raise ValueError("Please specify `token_dim` in `peft_config`") |
|
|
peft_config.token_dim = token_dim |
|
|
|
|
|
if peft_config.num_attention_heads is None: |
|
|
if "num_attention_heads" in model_config: |
|
|
num_attention_heads = model_config["num_attention_heads"] |
|
|
elif "n_head" in model_config: |
|
|
num_attention_heads = model_config["n_head"] |
|
|
elif "num_heads" in model_config: |
|
|
num_attention_heads = model_config["num_heads"] |
|
|
elif "encoder_attention_heads" in model_config: |
|
|
num_attention_heads = model_config["encoder_attention_heads"] |
|
|
else: |
|
|
raise ValueError("Please specify `num_attention_heads` in `peft_config`") |
|
|
peft_config.num_attention_heads = num_attention_heads |
|
|
|
|
|
|
|
|
if peft_config.peft_type == "PREFIX_TUNING" and "num_key_value_heads" in model_config: |
|
|
num_key_value_heads = model_config["num_key_value_heads"] |
|
|
peft_config.token_dim = peft_config.token_dim // peft_config.num_attention_heads * num_key_value_heads |
|
|
peft_config.num_attention_heads = num_key_value_heads |
|
|
|
|
|
if getattr(peft_config, "encoder_hidden_size", None) is None: |
|
|
setattr(peft_config, "encoder_hidden_size", peft_config.token_dim) |
|
|
|
|
|
return peft_config |
|
|
|
|
|
|
|
|
def _get_no_split_modules(model) -> set[str]: |
|
|
""" |
|
|
Get the modules of the model that should not be split when using device_map. We iterate through the modules to get |
|
|
the underlying `_no_split_modules`. |
|
|
|
|
|
Returns: |
|
|
        `set[str]`: The set of module class names that should not be split
|
|
""" |
|
|
|
|
|
|
|
|
_no_split_modules: set[str] = set() |
|
|
if not hasattr(model, "_no_split_modules"): |
|
|
return _no_split_modules |
|
|
|
|
|
modules_to_check = [model] |
|
|
while len(modules_to_check) > 0: |
|
|
module = modules_to_check.pop(-1) |
|
|
|
|
|
if module.__class__.__name__ not in _no_split_modules: |
|
|
if isinstance(module, PreTrainedModel): |
|
|
if module._no_split_modules is not None: |
|
|
_no_split_modules = _no_split_modules | set(module._no_split_modules) |
|
|
modules_to_check += list(module.children()) |
|
|
return _no_split_modules |
|
|
|
|
|
|
|
|
def fsdp_auto_wrap_policy(model): |
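    """Return an FSDP auto-wrap policy that wraps the model's transformer block classes as well as any leaf module
    with trainable weights (e.g. prompt encoders), so that trainable PEFT parameters end up in their own FSDP units.

    A minimal usage sketch with accelerate (assuming an `accelerator` that was configured with an FSDP plugin):

        >>> accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)
    """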
|
|
if hasattr(FullyShardedDataParallelPlugin, "get_module_class_from_name"): |
|
|
get_module_class_from_name = FullyShardedDataParallelPlugin.get_module_class_from_name |
|
|
else: |
|
|
from accelerate.utils.dataclasses import get_module_class_from_name |
|
|
from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy |
|
|
|
|
|
from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder |
|
|
|
|
|
default_transformer_cls_names_to_wrap = ",".join(_get_no_split_modules(model)) |
|
|
transformer_cls_names_to_wrap = os.environ.get( |
|
|
"FSDP_TRANSFORMER_CLS_TO_WRAP", default_transformer_cls_names_to_wrap |
|
|
).split(",") |
|
|
transformer_cls_to_wrap = {PrefixEncoder, PromptEncoder, PromptEmbedding} |
|
|
for layer_class in transformer_cls_names_to_wrap: |
|
|
if len(layer_class) == 0: |
|
|
continue |
|
|
transformer_cls = get_module_class_from_name(model, layer_class) |
|
|
if transformer_cls is None: |
|
|
raise Exception("Could not find the transformer layer class to wrap in the model.") |
|
|
else: |
|
|
transformer_cls_to_wrap.add(transformer_cls) |
|
|
|
|
|
def lambda_policy_fn(module): |
|
|
if ( |
|
|
len(list(module.named_children())) == 0 |
|
|
and getattr(module, "weight", None) is not None |
|
|
and module.weight.requires_grad |
|
|
): |
|
|
return True |
|
|
return False |
|
|
|
|
|
lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) |
|
|
transformer_wrap_policy = functools.partial( |
|
|
transformer_auto_wrap_policy, |
|
|
transformer_layer_cls=transformer_cls_to_wrap, |
|
|
) |
|
|
|
|
|
auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) |
|
|
return auto_wrap_policy |
|
|
|
|
|
|
|
|
def transpose(weight, fan_in_fan_out): |
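    """Transpose `weight` if `fan_in_fan_out` is True (used for layers such as GPT-2's `Conv1D`, which store their
    weight transposed compared to `nn.Linear`); otherwise return the weight unchanged."""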
|
|
if not fan_in_fan_out: |
|
|
return weight |
|
|
|
|
|
if isinstance(weight, torch.nn.Parameter): |
|
|
return torch.nn.Parameter(weight.T) |
|
|
return weight.T |
|
|
|
|
|
|
|
|
def _is_valid_match(key: str, target_key: str): |
|
|
""" |
|
|
Helper function to match module names target_key and key. Makes sure that either the key is exactly the target_key |
|
|
or the target_key is a submodule of key |
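
    For example, `_is_valid_match("model.decoder.q_proj", "q_proj")` is True, while
    `_is_valid_match("model.decoder.my_q_proj", "q_proj")` is False, since "q_proj" only matches as the name of a
    full submodule.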
|
|
""" |
|
|
if key.endswith(target_key): |
|
|
if len(key) > len(target_key): |
|
|
return key.endswith("." + target_key) |
|
|
return True |
|
|
return False |
|
|
|
|
|
|
|
|
def _get_batch_size(input_ids: Optional[torch.Tensor], inputs_embeds: Optional[torch.Tensor]) -> int:
    """Get the batch size based on either input_ids or inputs_embeds.

    Raises a ValueError if both are None.
|
|
|
|
|
""" |
|
|
if (input_ids is None) and (inputs_embeds is None): |
|
|
raise ValueError("You have to provide either input_ids or inputs_embeds") |
|
|
|
|
|
if input_ids is not None: |
|
|
batch_size = input_ids.shape[0] |
|
|
else: |
|
|
batch_size = inputs_embeds.shape[0] |
|
|
return batch_size |
|
|
|
|
|
|
|
|
def get_quantization_config(model: torch.nn.Module, method: str): |
|
|
""" |
|
|
Get the quantization config of the related quantization method |
|
|
""" |
|
|
if ( |
|
|
hasattr(model, "config") |
|
|
and hasattr(model.config, "quantization_config") |
|
|
and (getattr(model, "quantization_method", None) == method) |
|
|
): |
|
|
return model.config.quantization_config |
|
|
return None |
|
|
|
|
|
|
|
|
def get_auto_gptq_quant_linear(gptq_quantization_config): |
|
|
""" |
|
|
Get the right AutoGPTQQuantLinear class based on the quantization config file |
|
|
""" |
|
|
if gptq_quantization_config is None: |
|
|
return None |
|
|
|
|
|
if is_auto_gptq_available(): |
|
|
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear |
|
|
else: |
|
|
return None |
|
|
|
|
|
desc_act = gptq_quantization_config.desc_act |
|
|
group_size = gptq_quantization_config.group_size |
|
|
bits = gptq_quantization_config.bits |
|
|
if hasattr(gptq_quantization_config, "use_exllama"): |
|
|
use_exllama = gptq_quantization_config.use_exllama |
|
|
else: |
|
|
use_exllama = not gptq_quantization_config.disable_exllama |
|
|
if hasattr(gptq_quantization_config, "exllama_config"): |
|
|
exllama_version = gptq_quantization_config.exllama_config["version"] |
|
|
else: |
|
|
exllama_version = 1 |
|
|
|
|
|
QuantLinear = dynamically_import_QuantLinear( |
|
|
use_triton=False, |
|
|
desc_act=desc_act, |
|
|
group_size=group_size, |
|
|
bits=bits, |
|
|
disable_exllama=not (use_exllama and exllama_version == 1), |
|
|
disable_exllamav2=not (use_exllama and exllama_version == 2), |
|
|
) |
|
|
|
|
|
return QuantLinear |
|
|
|
|
|
|
|
|
def get_gptqmodel_quant_linear(gptq_quantization_config, device_map=None): |
|
|
""" |
|
|
Get the right GPTQQuantLinear class based on the quantization config file |
|
|
""" |
|
|
if gptq_quantization_config is None: |
|
|
return None |
|
|
|
|
|
if not is_gptqmodel_available(): |
|
|
return None |
|
|
|
|
|
from gptqmodel.utils.importer import hf_select_quant_linear |
|
|
|
|
|
desc_act = gptq_quantization_config.desc_act |
|
|
group_size = gptq_quantization_config.group_size |
|
|
bits = gptq_quantization_config.bits |
|
|
checkpoint_format = ( |
|
|
gptq_quantization_config.checkpoint_format |
|
|
if hasattr(gptq_quantization_config, "checkpoint_format") |
|
|
else "gptq" |
|
|
) |
|
|
sym = gptq_quantization_config.sym |
|
|
meta = gptq_quantization_config.meta if hasattr(gptq_quantization_config, "meta") else None |
|
|
|
|
|
QuantLinear = hf_select_quant_linear( |
|
|
bits=bits, |
|
|
group_size=group_size, |
|
|
desc_act=desc_act, |
|
|
sym=sym, |
|
|
device_map=device_map, |
|
|
checkpoint_format=checkpoint_format, |
|
|
meta=meta, |
|
|
backend="auto_trainable", |
|
|
) |
|
|
|
|
|
return QuantLinear |
|
|
|
|
|
|
|
|
def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]: |
|
|
""" |
|
|
Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For |
|
|
example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is |
|
|
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with |
|
|
non-overlapping lifetimes may have the same id. |
|
|
|
|
|
    This method is an exact copy of
    https://github.com/huggingface/transformers/blob/main/src/transformers/pytorch_utils.py#L282C1-L300C58 but is
    duplicated here to avoid import issues with old versions of transformers.
|
|
""" |
|
|
if tensor.device.type == "xla" and is_torch_tpu_available(): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch_xla |
|
|
|
|
|
unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor) |
|
|
else: |
|
|
unique_id = storage_ptr(tensor) |
|
|
|
|
|
return tensor.device, unique_id, storage_size(tensor) |
|
|
|
|
|
|
|
|
def cast_mixed_precision_params(model, dtype): |
|
|
""" |
|
|
Cast all non-trainable parameters of the model to the given `dtype`. The `dtype` can be `torch.float16` or |
|
|
`torch.bfloat16` as per the mixed-precision training you are performing. The trainable parameters are cast to full |
|
|
precision. This is meant to reduce the GPU memory usage when using PEFT methods by using half-precision dtype for |
|
|
non-trainable parameters. Having the trainable parameters in full-precision preserves training stability when using |
|
|
automatic mixed-precision training. |
|
|
|
|
|
Args: |
|
|
model (`torch.nn.Module`): |
|
|
The model to cast the non-trainable parameters of. |
|
|
dtype (`torch.dtype`): |
|
|
The dtype to cast the non-trainable parameters to. The `dtype` can be `torch.float16` or |
|
|
`torch.bfloat16` as per the mixed-precision training you are performing. |
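
    Example (a minimal sketch, assuming `model` is a PEFT model being trained with bf16 autocast):

        >>> cast_mixed_precision_params(model, dtype=torch.bfloat16)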
|
|
""" |
|
|
for p in model.parameters(): |
|
|
if not p.requires_grad: |
|
|
p.data = p.to(dtype) |
|
|
else: |
|
|
p.data = p.to(torch.float32) |
|
|
|
|
|
|
|
|
def str_to_bool(value: str) -> int: |
|
|
""" |
|
|
Converts a string representation of truth to `True` (1) or `False` (0). |
|
|
|
|
|
    True values are `y`, `yes`, `t`, `true`, `on`, and `1`; false values are `n`, `no`, `f`, `false`, `off`, and `0`.
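
    Example: `str_to_bool("True")` returns 1 and `str_to_bool("off")` returns 0; any other string raises a ValueError.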
|
|
""" |
|
|
|
|
|
value = value.lower() |
|
|
if value in ("y", "yes", "t", "true", "on", "1"): |
|
|
return 1 |
|
|
elif value in ("n", "no", "f", "false", "off", "0"): |
|
|
return 0 |
|
|
else: |
|
|
raise ValueError(f"invalid truth value {value}") |
|
|
|
|
|
|
|
|
def check_file_exists_on_hf_hub(repo_id: str, filename: str, **kwargs) -> Optional[bool]:
    """Check if a file exists on the HF Hub; if the check was not successful, return `None` instead of raising an
    error.
|
|
|
|
|
Respect offline mode if set. |
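
    Example (the repo and file names are illustrative):

        >>> check_file_exists_on_hf_hub("facebook/opt-350m", "adapter_config.json")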
|
|
|
|
|
""" |
|
|
exists: Optional[bool] = None |
|
|
if str_to_bool(os.environ.get("HF_HUB_OFFLINE", "0")): |
|
|
|
|
|
return exists |
|
|
|
|
|
try: |
|
|
exists = file_exists(repo_id, filename, **kwargs) |
|
|
except (HFValidationError, EntryNotFoundError): |
|
|
|
|
|
pass |
|
|
except Exception as e: |
|
|
warnings.warn( |
|
|
f"Unable to fetch remote file due to the following error {e} - silently ignoring the lookup" |
|
|
f" for the file {filename} in {repo_id}." |
|
|
) |
|
|
|
|
|
return exists |
|
|
|
|
|
|
|
|
def match_target_against_key(target_pattern: str, key: str): |
|
|
"""Backing function for `target_modules` config parameter. |
|
|
|
|
|
Having this as its own function ensures that target key matching can be implemented in the same way everywhere. |
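
    Note that `re.fullmatch` is used, so the pattern must match the full key, not just a substring of it.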
|
|
""" |
|
|
return re.fullmatch(target_pattern, key) |
|
|
|
|
|
|
|
|
def get_pattern_key(pattern_keys: Sequence[str], key_to_match: str) -> str:
    """Return the entry in `pattern_keys` that matches `key_to_match` as a (sub)module name suffix; if none matches,
    return `key_to_match` unchanged."""
|
|
for key in pattern_keys: |
|
|
match = re.match(rf"(.*\.)?({key})$", key_to_match) |
|
|
if not match: |
|
|
continue |
|
|
return key |
|
|
|
|
|
return key_to_match |
|
|
|
|
|
|
|
|
def set_additional_trainable_modules(model, peft_config, model_config, adapter_name, activate_adapter: bool = True):
    """Handle the resolution of additional trainable modules (also called AuxiliaryTrainingWrapper) by checking the
    config for whether such modules are requested and, if so, adding them to the model.
|
|
|
|
Currently trainable tokens and modules to save are considered additional trainable modules. |
|
|
|
|
|
If `activate_adapter` is set to `False`, the adapter won't be activated. This is typically the case when |
|
|
`model.add_adapter` or `model.load_adapter` are being called. |
|
|
""" |
|
|
if getattr(peft_config, "modules_to_save", None) is not None: |
|
|
|
|
|
_set_trainable( |
|
|
model, |
|
|
adapter_name, |
|
|
inference_mode=peft_config.inference_mode, |
|
|
module_names=getattr(peft_config, "modules_to_save", None), |
|
|
activate_adapter=activate_adapter, |
|
|
) |
|
|
|
|
|
if getattr(peft_config, "trainable_token_indices", None) is not None: |
|
|
if isinstance(peft_config.trainable_token_indices, dict): |
|
|
target_layers = peft_config.trainable_token_indices |
|
|
else: |
|
|
layer_name = _get_input_embeddings_name(model, "embed_tokens") |
|
|
target_layers = {layer_name: peft_config.trainable_token_indices} |
|
|
|
|
|
modules_to_save = getattr(peft_config, "modules_to_save", None) |
|
|
if modules_to_save is not None: |
|
|
for target_layer in target_layers: |
|
|
if target_layer in modules_to_save: |
|
|
raise ValueError( |
|
|
"The embedding layer is already marked to be trained fully, either specify " |
|
|
f'`modules_to_save=[..., "{target_layer}", ...]` or ' |
|
|
f"`trainable_tokens={{'{target_layer}': x}}` but not both." |
|
|
) |
|
|
|
|
|
for target_layer, token_indices in target_layers.items(): |
|
|
_set_trainable( |
|
|
model, |
|
|
adapter_name, |
|
|
inference_mode=peft_config.inference_mode, |
|
|
module_names=[target_layer], |
|
|
strict_module_check=True, |
|
|
wrapper_cls=TrainableTokensWrapper, |
|
|
token_indices=token_indices, |
|
|
activate_adapter=activate_adapter, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if ( |
|
|
model_config.get("tie_word_embeddings", False) |
|
|
|
|
|
and model._tied_weights_keys is not None |
|
|
and isinstance(model.get_input_embeddings(), TrainableTokensWrapper) |
|
|
): |
|
|
|
|
|
module_keys = [".".join(n.split(".")[:-1]) for n in model._tied_weights_keys] |
|
|
|
|
|
token_adapter = model.get_input_embeddings().token_adapter |
|
|
_set_trainable( |
|
|
model, |
|
|
adapter_name, |
|
|
inference_mode=peft_config.inference_mode, |
|
|
module_names=module_keys, |
|
|
strict_module_check=True, |
|
|
wrapper_cls=TrainableTokensWrapper, |
|
|
token_indices=token_adapter.token_indices[adapter_name], |
|
|
tied_adapter=model.get_input_embeddings().token_adapter, |
|
|
) |
|
|
|
|
|
|
|
|
def create_attention_mask( |
|
|
model, *, model_input, attention_mask, past_key_values, cache_position, batch_size, sequence_length, position_ids |
|
|
): |
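    # Build a (causal) attention mask the same way transformers' `generate` does: prefer the model's or decoder's
    # own `_prepare_4d_causal_attention_mask_with_cache_position` and fall back to `create_masks_for_generate`,
    # which requires transformers >= 4.53.1.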
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
transformers_ge_4_53_1 = version.parse(transformers.__version__) >= version.parse("4.53.1") |
|
|
if transformers_ge_4_53_1: |
|
|
|
|
|
from transformers.masking_utils import create_masks_for_generate |
|
|
else: |
|
|
raise ImportError("Your transformers version is too old, please upgrade it to >= 4.53.1") |
|
|
|
|
|
|
|
|
|
|
|
base_model = getattr(model, model.base_model_prefix, model) |
|
|
decoder = base_model.get_decoder() if hasattr(base_model, "get_decoder") else None |
|
|
causal_mask_creation_function = getattr(base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None) |
|
|
if causal_mask_creation_function is None and decoder is not None: |
|
|
causal_mask_creation_function = getattr(decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None) |
|
|
|
|
|
|
|
|
if causal_mask_creation_function is None: |
|
|
token_type_ids = getattr(model_input, "token_type_ids", None) |
|
|
|
|
|
causal_mask_creation_function = getattr(model, "create_masks_for_generate", create_masks_for_generate) |
|
|
attention_mask = causal_mask_creation_function( |
|
|
config=model.config, |
|
|
|
|
|
input_embeds=torch.empty((batch_size, sequence_length), dtype=model.dtype), |
|
|
attention_mask=attention_mask, |
|
|
cache_position=cache_position, |
|
|
past_key_values=past_key_values, |
|
|
token_type_ids=token_type_ids, |
|
|
position_ids=position_ids, |
|
|
) |
|
|
else: |
|
|
attention_mask = causal_mask_creation_function( |
|
|
attention_mask, |
|
|
sequence_length=sequence_length, |
|
|
target_length=past_key_values.get_max_cache_shape(), |
|
|
dtype=model.dtype, |
|
|
cache_position=cache_position, |
|
|
batch_size=batch_size, |
|
|
config=model.config, |
|
|
past_key_values=past_key_values, |
|
|
position_ids=position_ids, |
|
|
) |
|
|
return attention_mask |
|
|
|