Spaces:
Running
on
Zero
Running
on
Zero
# Copyright 2024-present the HuggingFace Inc. team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from __future__ import annotations | |
import warnings | |
from typing import Optional | |
from torch import nn | |
from torch.nn.modules import Module | |
from tqdm import tqdm | |
from peft.config import PeftConfig | |
from peft.tuners.tuners_utils import BaseTuner, _get_submodules, check_target_module_exists | |
from peft.utils import TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, ModulesToSaveWrapper | |
from .layer import LNTuningLayer | |
class LNTuningModel(BaseTuner): | |
""" | |
Creates LayerNorm tuning from a pretrained transformer model. | |
The method is described in detail in https://arxiv.org/abs/2312.11420. | |
Args: | |
model ([`torch.nn.Module`]): The model to be adapted. | |
config ([`LNTuningConfig`]): The configuration of the Lora model. | |
adapter_name (`str`): The name of the adapter, defaults to `"default"`. | |
Returns: | |
'torch.nn.Module': The adapted model with LayerNorm tuned on. | |
Example: | |
```py | |
>>> from transformers import AutoModelForCausalLM | |
>>> from peft import get_peft_model, TaskType, LNTuningConfig | |
>>> peft_config = LNTuningConfig( | |
... task_type=TaskType.CAUSAL_LM, | |
... ) | |
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") | |
>>> model = get_peft_model(model, peft_config) | |
>>> model.print_trainable_parameters() | |
``` | |
**Attributes**: | |
- **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted. | |
- **peft_config** ([`LNTuningConfig`]): The configuration of the Lora model. | |
""" | |
prefix: str = "ln_tuning_" | |
def __init__(self, model, config, adapter_name) -> None: | |
# self.adapter_name = adapter_name | |
super().__init__(model, config, adapter_name) | |
def __getattr__(self, name: str): | |
"""Forward missing attributes to the wrapped module.""" | |
try: | |
return super().__getattr__(name) # defer to nn.Module's logic | |
except AttributeError: | |
return getattr(self.model, name) | |
# TODO: here need to handle the modules_to_save rather than the target_modules | |
def _prepare_adapter_config(peft_config: PeftConfig, model_config: dict) -> PeftConfig: | |
if peft_config.target_modules is None: | |
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING: | |
raise ValueError("Please specify `target_modules` in `peft_config`") | |
peft_config.target_modules = set( | |
TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING[model_config["model_type"]] | |
) | |
return peft_config | |
def _create_and_replace( | |
self, | |
peft_config: PeftConfig, | |
adapter_name: str, | |
target: Module, | |
target_name: str, | |
parent: Module, | |
current_key: str, | |
) -> None: | |
# replace the original module with a same new module | |
new_module = self._create_new_module(peft_config, target, adapter_name) | |
if adapter_name != self.active_adapter: | |
new_module.requires_grad_(False) | |
self._replace_module(parent, target_name, new_module, target) | |
def _create_new_module( | |
self, | |
peft_config: PeftConfig, | |
target: Module, | |
adapter_name: str, | |
) -> Module: | |
if not isinstance(target, LNTuningLayer): | |
new_module = LNTuningLayer(target, adapter_name) | |
else: | |
new_module = target | |
new_module.update_layer(target.base_layer, adapter_name) | |
return new_module | |
def _replace_module(self, parent: Module, child_name: str, new_module: Module, child: Module) -> None: | |
setattr(parent, child_name, new_module) | |
if hasattr(child, "base_layer"): | |
child = child.base_layer | |
if getattr(child, "state", None) is not None: | |
if hasattr(new_module, "base_layer"): | |
new_module.base_layer.state = child.state | |
else: | |
new_module.state = child.state | |
new_module.to(child.weight.device) | |
for name, module in new_module.named_modules(): | |
weight = child.qweight if hasattr(child, "qweight") else child.weight | |
module.to(weight.device) | |
def _mark_only_adapters_as_trainable(self, model: Module): | |
for n, p in model.named_parameters(): | |
if self.prefix not in n: | |
p.requires_grad = False | |
else: | |
p.requires_grad = True | |
def _check_target_module_exists(self, peft_config: PeftConfig, key: str) -> bool: | |
return check_target_module_exists(peft_config, key) | |
def _set_adapter_layers(self, enabled: bool) -> None: | |
for module in self.model.modules(): | |
if isinstance(module, (LNTuningLayer, ModulesToSaveWrapper)): | |
module.enable_adapters(enabled) | |
def enable_adapter_layers(self) -> None: | |
"""Enable all adapters. | |
Call this if you have previously disabled all adapters and want to re-enable them. | |
""" | |
self._set_adapter_layers(enabled=True) | |
def disable_adapter_layers(self) -> None: | |
"""Disable all adapters. | |
When disabling all adapters, the model output corresponds to the output of the base model. | |
""" | |
self._set_adapter_layers(enabled=False) | |
def set_adapter(self, adapter_name: str) -> None: | |
for module in self.model.modules(): | |
if isinstance(module, LNTuningLayer): | |
if module.merged: | |
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") | |
module.unmerge() | |
module.set_adapter(adapter_name) | |
self.active_adapter = adapter_name | |
def _unload_and_optionally_merge( | |
self, | |
merge=True, | |
progressbar: bool = False, | |
safe_merge: bool = False, | |
adapter_names: Optional[list[str]] = None, | |
): | |
self._unloading_checks(adapter_names) | |
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] | |
desc = "Unloading adapters " + ("and merging " if merge else "") + "model" | |
for key in tqdm(key_list, disable=not progressbar, desc=desc): | |
try: | |
parent, target, target_name = _get_submodules(self.model, key) | |
except AttributeError: | |
continue | |
if hasattr(target, "base_layer"): | |
if merge: | |
target.merge(adapter_names) | |
self._replace_module(parent, target_name, target.get_base_layer(), target) | |
return self.model | |
def unload(self): | |
return self._unload_and_optionally_merge(merge=False) | |
def merge_and_unload( | |
self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None | |
) -> nn.Module: | |
return self._unload_and_optionally_merge(merge=True) | |