Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,552 Bytes
d711508 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from typing import Optional
from torch import nn
from torch.nn.modules import Module
from tqdm import tqdm
from peft.config import PeftConfig
from peft.tuners.tuners_utils import BaseTuner, _get_submodules, check_target_module_exists
from peft.utils import TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, ModulesToSaveWrapper
from .layer import LNTuningLayer
class LNTuningModel(BaseTuner):
    """
    Creates LayerNorm tuning from a pretrained transformer model.

    The method is described in detail in https://arxiv.org/abs/2312.11420.

    Args:
        model ([`torch.nn.Module`]): The model to be adapted.
        config ([`LNTuningConfig`]): The configuration of the LNTuning model.
        adapter_name (`str`): The name of the adapter, defaults to `"default"`.

    Returns:
        `torch.nn.Module`: The adapted model with LayerNorm tuned on.

    Example:

        ```py
        >>> from transformers import AutoModelForCausalLM
        >>> from peft import get_peft_model, TaskType, LNTuningConfig

        >>> peft_config = LNTuningConfig(
        ...     task_type=TaskType.CAUSAL_LM,
        ... )

        >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> model = get_peft_model(model, peft_config)
        >>> model.print_trainable_parameters()
        ```

    **Attributes**:
        - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted.
        - **peft_config** ([`LNTuningConfig`]): The configuration of the LNTuning model.
    """

    # Substring that identifies adapter parameters/modules; used both to mark
    # trainable parameters and to skip adapter-internal modules when unloading.
    prefix: str = "ln_tuning_"

    def __init__(self, model, config, adapter_name) -> None:
        super().__init__(model, config, adapter_name)

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            # If `model` itself is missing (e.g. the instance is being unpickled or
            # deep-copied before __init__ ran), `self.model` below would re-enter this
            # method and recurse forever; surface the AttributeError instead.
            if name == "model":
                raise
            return getattr(self.model, name)

    # TODO: here need to handle the modules_to_save rather than the target_modules
    @staticmethod
    def _prepare_adapter_config(peft_config: PeftConfig, model_config: dict) -> PeftConfig:
        """Fill in default `target_modules` from the model type when the user did not set any.

        Raises:
            ValueError: if `target_modules` is unset and `model_config["model_type"]` has no
                entry in `TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING`.
        """
        if peft_config.target_modules is None:
            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING:
                raise ValueError("Please specify `target_modules` in `peft_config`")
            peft_config.target_modules = set(
                TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING[model_config["model_type"]]
            )
        return peft_config

    def _create_and_replace(
        self,
        peft_config: PeftConfig,
        adapter_name: str,
        target: Module,
        target_name: str,
        parent: Module,
        current_key: str,
    ) -> None:
        """Wrap (or extend) `target` with an `LNTuningLayer` and install it on `parent`."""
        # replace the original module with a same new module
        new_module = self._create_new_module(peft_config, target, adapter_name)
        if adapter_name != self.active_adapter:
            # Only the active adapter should receive gradients.
            new_module.requires_grad_(False)
        self._replace_module(parent, target_name, new_module, target)

    def _create_new_module(
        self,
        peft_config: PeftConfig,
        target: Module,
        adapter_name: str,
    ) -> Module:
        """Return an `LNTuningLayer` holding `adapter_name`.

        A plain module is wrapped in a new `LNTuningLayer`; an existing `LNTuningLayer`
        is reused and gets the new adapter added via `update_layer`.
        """
        if not isinstance(target, LNTuningLayer):
            new_module = LNTuningLayer(target, adapter_name)
        else:
            new_module = target
            new_module.update_layer(target.base_layer, adapter_name)
        return new_module

    def _replace_module(self, parent: Module, child_name: str, new_module: Module, child: Module) -> None:
        """Set `new_module` as `parent.<child_name>` and align its state and device with `child`.

        Any `state` attribute of the (unwrapped) child is copied over. The new module is
        then moved to the device of the child's `qweight` (if present) or `weight`.
        """
        setattr(parent, child_name, new_module)

        if hasattr(child, "base_layer"):
            child = child.base_layer

        if getattr(child, "state", None) is not None:
            if hasattr(new_module, "base_layer"):
                new_module.base_layer.state = child.state
            else:
                new_module.state = child.state

        # Computed once: the reference tensor does not change between submodules.
        weight = child.qweight if hasattr(child, "qweight") else getattr(child, "weight", None)
        if weight is None:
            # No tensor to infer a device from (e.g. a norm layer without affine
            # parameters); leave the new module where it is.
            return
        # nn.Module.to is recursive, so one call places every submodule.
        new_module.to(weight.device)

    def _mark_only_adapters_as_trainable(self, model: Module):
        """Enable gradients only for parameters whose name contains `self.prefix`."""
        for n, p in model.named_parameters():
            p.requires_grad = self.prefix in n

    def _check_target_module_exists(self, peft_config: PeftConfig, key: str) -> bool:
        return check_target_module_exists(peft_config, key)

    def _set_adapter_layers(self, enabled: bool) -> None:
        """Toggle every `LNTuningLayer` and `ModulesToSaveWrapper` in the model."""
        for module in self.model.modules():
            if isinstance(module, (LNTuningLayer, ModulesToSaveWrapper)):
                module.enable_adapters(enabled)

    def enable_adapter_layers(self) -> None:
        """Enable all adapters.

        Call this if you have previously disabled all adapters and want to re-enable them.
        """
        self._set_adapter_layers(enabled=True)

    def disable_adapter_layers(self) -> None:
        """Disable all adapters.

        When disabling all adapters, the model output corresponds to the output of the base model.
        """
        self._set_adapter_layers(enabled=False)

    def set_adapter(self, adapter_name: str) -> None:
        """Activate `adapter_name` on every `LNTuningLayer`, unmerging merged layers first (with a warning)."""
        for module in self.model.modules():
            if isinstance(module, LNTuningLayer):
                if module.merged:
                    warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
                    module.unmerge()
                module.set_adapter(adapter_name)
        self.active_adapter = adapter_name

    def _unload_and_optionally_merge(
        self,
        merge=True,
        progressbar: bool = False,
        safe_merge: bool = False,
        adapter_names: Optional[list[str]] = None,
    ):
        """Replace every wrapped layer with its base layer, merging adapters first if `merge` is True.

        Args:
            merge (`bool`): Whether to call `merge(adapter_names)` on each wrapper before unwrapping.
            progressbar (`bool`): Whether to show a tqdm progress bar.
            safe_merge (`bool`): Accepted for interface compatibility with other tuners; not used here.
            adapter_names (`list[str]`, *optional*): Adapters to merge; forwarded to `merge`.

        Returns:
            The underlying model with the LN-tuning wrappers removed.
        """
        self._unloading_checks(adapter_names)
        # Skip keys inside adapter wrappers; only the wrapper itself is replaced.
        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
        desc = "Unloading adapters " + ("and merging " if merge else "") + "model"

        for key in tqdm(key_list, disable=not progressbar, desc=desc):
            try:
                parent, target, target_name = _get_submodules(self.model, key)
            except AttributeError:
                # The module may already have been replaced earlier in this loop.
                continue

            if hasattr(target, "base_layer"):
                if merge:
                    target.merge(adapter_names)
                self._replace_module(parent, target_name, target.get_base_layer(), target)

        return self.model

    def unload(self):
        """Remove the LN-tuning wrappers without merging and return the base model."""
        return self._unload_and_optionally_merge(merge=False)

    def merge_and_unload(
        self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
    ) -> nn.Module:
        """Merge the adapters into the base layers, remove the wrappers and return the base model.

        Args:
            progressbar (`bool`): Whether to show a progress bar while unloading.
            safe_merge (`bool`): Accepted for interface compatibility; not used by LN tuning.
            adapter_names (`list[str]`, *optional*): Adapters to merge; defaults to the active ones.
        """
        # Forward the caller's options instead of silently dropping them.
        return self._unload_and_optionally_merge(
            merge=True, progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
        )
|