# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import re
import warnings
from dataclasses import asdict
from enum import Enum
from typing import Optional

import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists
from peft.utils import (
    TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING,
    TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING,
    ModulesToSaveWrapper,
    _get_submodules,
)

from .layer import Conv2d, IA3Layer, Linear

class IA3Model(BaseTuner):
    """
    Creates an Infused Adapter by Inhibiting and Amplifying Inner Activations ((IA)^3) model from a pretrained
    transformers model. The method is described in detail in https://arxiv.org/abs/2205.05638

    Args:
        model ([`~transformers.PreTrainedModel`]): The model to be adapted.
        config ([`IA3Config`]): The configuration of the (IA)^3 model.
        adapter_name (`str`): The name of the adapter, defaults to `"default"`.

    Returns:
        `torch.nn.Module`: The (IA)^3 model.

    Example:

        ```py
        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import IA3Model, IA3Config

        >>> config = IA3Config(
        ...     peft_type="IA3",
        ...     task_type="SEQ_2_SEQ_LM",
        ...     target_modules=["k", "v", "wo"],
        ...     feedforward_modules=["wo"],
        ... )

        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> ia3_model = IA3Model(model, config, "default")
        ```

    **Attributes**:
        - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted.
        - **peft_config** ([`IA3Config`]): The configuration of the (IA)^3 model.
    """
    prefix: str = "ia3_"

    def __init__(self, model, config, adapter_name):
        super().__init__(model, config, adapter_name)
    @staticmethod
    def _create_new_module(ia3_config, adapter_name, target, **kwargs):
        # avoid eager bnb import
        if is_bnb_available():
            import bitsandbytes as bnb

            from .bnb import Linear8bitLt
        if is_bnb_4bit_available():
            from .bnb import Linear4bit

        loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
        loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)
        is_feedforward = kwargs.pop("is_feedforward", False)

        if isinstance(target, BaseTunerLayer):
            target_base_layer = target.get_base_layer()
        else:
            target_base_layer = target

        if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
            eightbit_kwargs = kwargs.copy()
            eightbit_kwargs.update(
                {
                    "has_fp16_weights": target_base_layer.state.has_fp16_weights,
                    "memory_efficient_backward": target_base_layer.state.memory_efficient_backward,
                    "threshold": target_base_layer.state.threshold,
                    "index": target_base_layer.index,
                }
            )
            new_module = Linear8bitLt(target, adapter_name, is_feedforward=is_feedforward, **eightbit_kwargs)
        elif loaded_in_4bit and isinstance(target_base_layer, bnb.nn.Linear4bit):
            fourbit_kwargs = kwargs.copy()
            fourbit_kwargs.update(
                {
                    "compute_dtype": target_base_layer.compute_dtype,
                    "compress_statistics": target_base_layer.weight.compress_statistics,
                    "quant_type": target_base_layer.weight.quant_type,
                }
            )
            new_module = Linear4bit(target, adapter_name, is_feedforward=is_feedforward, **fourbit_kwargs)
        elif isinstance(target, torch.nn.Conv2d):
            new_module = Conv2d(target, adapter_name, is_feedforward=is_feedforward, **kwargs)
        elif isinstance(target_base_layer, torch.nn.Linear):
            if kwargs["fan_in_fan_out"]:
                warnings.warn(
                    "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                    "Setting fan_in_fan_out to False."
                )
                kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = False
            new_module = Linear(target, adapter_name, is_feedforward=is_feedforward, **kwargs)
        elif isinstance(target_base_layer, Conv1D):
            if not kwargs["fan_in_fan_out"]:
                warnings.warn(
                    "fan_in_fan_out is set to False but the target module is `Conv1D`. "
                    "Setting fan_in_fan_out to True."
                )
                kwargs["fan_in_fan_out"] = ia3_config.fan_in_fan_out = True
            new_module = Linear(
                target, adapter_name, is_feedforward=is_feedforward, is_target_conv_1d_layer=True, **kwargs
            )
        else:
            raise ValueError(
                f"Target module {target} is not supported. "
                f"Currently, only `torch.nn.Linear`, `torch.nn.Conv2d`, and `Conv1D` are supported."
            )
        return new_module
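
    # A minimal sketch of the (IA)^3 idea that the replacement modules implement: a learned
    # vector elementwise rescales the layer's activations. This is a standalone illustration
    # under assumed shapes, not this module's actual `Linear` implementation:
    #
    #     import torch
    #     base = torch.nn.Linear(16, 32)
    #     ia3_l = torch.nn.Parameter(torch.ones(32))  # initialized to ones => a no-op at first
    #     x = torch.randn(4, 16)
    #     out = base(x) * ia3_l  # non-feedforward module: scale the output activations
    #     # for a feedforward module, the input is scaled instead: base(x * ia3_l_in)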
    @staticmethod
    def _check_target_module_exists(ia3_config, key):
        return check_target_module_exists(ia3_config, key)
    def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
        for n, p in model.named_parameters():
            if self.prefix not in n:
                p.requires_grad = False
    def _create_and_replace(
        self,
        ia3_config,
        adapter_name,
        target,
        target_name,
        parent,
        current_key,
    ):
        # check if target module is in feedforward_modules
        is_feedforward = self._check_target_module_feedforward(ia3_config, current_key)

        kwargs = {
            "fan_in_fan_out": ia3_config.fan_in_fan_out,
            "init_ia3_weights": ia3_config.init_ia3_weights,
            "is_feedforward": is_feedforward,
            "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
            "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
        }

        if isinstance(target, IA3Layer):
            target.update_layer(
                adapter_name,
                ia3_config.init_ia3_weights,
            )
        else:
            new_module = self._create_new_module(ia3_config, adapter_name, target, **kwargs)
            if adapter_name not in self.active_adapters:
                # adding an additional adapter: it is not automatically trainable
                new_module.requires_grad_(False)
            self._replace_module(parent, target_name, new_module, target)
    @staticmethod
    def _check_target_module_feedforward(ia3_config, key) -> bool:
        """
        A helper private method that checks if the target module `key` matches with a feedforward module specified in
        `ia3_config`.
        """
        if isinstance(ia3_config.feedforward_modules, str):
            is_feedforward = bool(re.fullmatch(ia3_config.feedforward_modules, key))
        else:
            is_feedforward = any(key.endswith(target_key) for target_key in ia3_config.feedforward_modules)
        return is_feedforward
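
    # A quick illustration of the matching rule above, using hypothetical T5-style key names:
    #
    #     # as a str, feedforward_modules is used as a regex that must fully match the key:
    #     bool(re.fullmatch(r".*\.DenseReluDense\.wo", "encoder.block.0.layer.1.DenseReluDense.wo"))  # True
    #     # as a list, it matches by suffix:
    #     "encoder.block.0.layer.1.DenseReluDense.wo".endswith("wo")  # True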
    def _replace_module(self, parent, child_name, new_module, child):
        setattr(parent, child_name, new_module)

        # child layer wraps the original module, unpack it
        if hasattr(child, "base_layer"):
            child = child.base_layer

        # layers with base_layer don't need the weight to be copied, as they have a reference already
        if not hasattr(new_module, "base_layer"):
            new_module.weight = child.weight
            if hasattr(child, "bias"):
                new_module.bias = child.bias

        if getattr(child, "state", None) is not None:
            if hasattr(new_module, "base_layer"):
                new_module.base_layer.state = child.state
            else:
                new_module.state = child.state
            new_module.to(child.weight.device)

        # dispatch to correct device
        for name, module in new_module.named_modules():
            if self.prefix in name:
                module.to(child.weight.device)
    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.model, name)
    def get_peft_config_as_dict(self, inference: bool = False):
        config_dict = {}
        for key, value in self.peft_config.items():
            config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
            if inference:
                config["inference_mode"] = True
            config_dict[key] = config
        return config_dict
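
    # Illustrative shape of the returned dict, one entry per adapter name ("default" is a
    # hypothetical adapter name; the inner values depend on the actual config):
    #
    #     {"default": {"peft_type": "IA3", "task_type": "SEQ_2_SEQ_LM",
    #                  "target_modules": ["k", "v", "wo"], "inference_mode": True, ...}}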
    def _set_adapter_layers(self, enabled=True):
        for module in self.model.modules():
            if isinstance(module, (IA3Layer, ModulesToSaveWrapper)):
                module.enable_adapters(enabled)

    def enable_adapter_layers(self) -> None:
        """Enable all adapters.

        Call this if you have previously disabled all adapters and want to re-enable them.
        """
        self._set_adapter_layers(enabled=True)

    def disable_adapter_layers(self) -> None:
        """Disable all adapters.

        When disabling all adapters, the model output corresponds to the output of the base model.
        """
        self._set_adapter_layers(enabled=False)
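
    # Usage sketch (assumes `ia3_model` is an instance of this class; illustrative only):
    #
    #     ia3_model.disable_adapter_layers()  # outputs now match the base model
    #     ia3_model.enable_adapter_layers()   # re-enable the (IA)^3 scaling vectors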
    def set_adapter(self, adapter_name: str | list[str]) -> None:
        """Set the active adapter(s).

        Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this
        is not desired, use the following code.

        ```py
        >>> for name, param in model_peft.named_parameters():
        ...     if ...:  # some check on name (ex. if 'lora' in name)
        ...         param.requires_grad = False
        ```

        Args:
            adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated.
        """
        for module in self.model.modules():
            if isinstance(module, IA3Layer):
                if module.merged:
                    warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
                    module.unmerge()
                module.set_adapter(adapter_name)
        self.active_adapter = adapter_name
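
    # Usage sketch (assumes a second adapter named "other" was added beforehand; illustrative):
    #
    #     ia3_model.set_adapter("other")               # activate one adapter (it becomes trainable)
    #     ia3_model.set_adapter(["default", "other"])  # or activate several at once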
    def _prepare_adapter_config(self, peft_config, model_config):
        if peft_config.target_modules is None:
            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING:
                raise ValueError("Please specify `target_modules` in `peft_config`")
            peft_config.target_modules = TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]]
        if peft_config.feedforward_modules is None:
            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING:
                raise ValueError("Please specify `feedforward_modules` in `peft_config`")
            peft_config.feedforward_modules = TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING[
                model_config["model_type"]
            ]
        return peft_config
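
    # Illustrative fallback behavior: for a model_type covered by both mappings (e.g. "t5"),
    # an IA3Config created without explicit modules gets them filled in here:
    #
    #     config = IA3Config(task_type="SEQ_2_SEQ_LM")  # no target_modules/feedforward_modules given
    #     # -> both attributes are populated from the TRANSFORMERS_MODELS_TO_IA3_* mappings above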
    def _unload_and_optionally_merge(
        self, merge: bool = True, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
    ):
        r"""
        This method merges the (IA)^3 layers into the base model. This is needed if someone wants to use the base
        model as a standalone model.

        Args:
            safe_merge (`bool`, *optional*, defaults to `False`):
                If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs
                before merging the weights. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.
            adapter_names (`list[str]`, *optional*):
                The list of adapter names that should be merged. If `None`, all active adapters will be merged.
                Defaults to `None`.
        """
        if getattr(self.model, "is_loaded_in_8bit", False):
            raise ValueError("Cannot merge ia3 layers when the model is loaded in 8-bit mode")

        if getattr(self.model, "is_loaded_in_4bit", False):
            raise ValueError("Cannot merge ia3 layers when the model is loaded in 4-bit mode")

        self._unloading_checks(adapter_names)
        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
        for key in key_list:
            try:
                parent, target, target_name = _get_submodules(self.model, key)
            except AttributeError:
                continue

            if hasattr(target, "base_layer"):
                if merge:
                    target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
                self._replace_module(parent, target_name, target.get_base_layer(), target)
            elif isinstance(target, ModulesToSaveWrapper):
                # save any additional trainable modules part of `modules_to_save`
                new_module = target.modules_to_save[target.active_adapter]
                if hasattr(new_module, "base_layer"):
                    # check if the module is itself a tuner layer
                    if merge:
                        new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
                    new_module = new_module.get_base_layer()
                setattr(parent, target_name, new_module)

        return self.model
    def merge_and_unload(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> torch.nn.Module:
        r"""
        This method merges the IA³ layers into the base model. This is needed if someone wants to use the base model
        as a standalone model.

        Args:
            safe_merge (`bool`):
                Whether to activate the safe merging check, which verifies that the adapter weights do not contain
                any potential NaNs before merging.
            adapter_names (`list[str]`, *optional*):
                The list of adapter names that should be merged. If `None`, all active adapters will be merged.
                Defaults to `None`.

        Example:

        ```py
        >>> from transformers import AutoModelForCausalLM
        >>> from peft import PeftModel

        >>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b")
        >>> peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample"
        >>> model = PeftModel.from_pretrained(base_model, peft_model_id)
        >>> merged_model = model.merge_and_unload()
        ```
        """
        return self._unload_and_optionally_merge(safe_merge=safe_merge, adapter_names=adapter_names)
    def unload(self) -> torch.nn.Module:
        """
        Gets back the base model by removing all the IA³ modules without merging. This gives back the original base
        model.
        """
        return self._unload_and_optionally_merge(merge=False)
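
    # Sketch contrasting the two unload paths (assumes `ia3_model` is an instance of this class):
    #
    #     merged = ia3_model.merge_and_unload()  # fold the (IA)^3 vectors into the base weights
    #     plain = ia3_model.unload()             # drop them, restoring the original base model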
    def delete_adapter(self, adapter_name: str) -> None:
        """
        Deletes an existing adapter.

        Args:
            adapter_name (`str`): Name of the adapter to be deleted.
        """
        if adapter_name not in self.peft_config:
            raise ValueError(f"Adapter {adapter_name} does not exist")
        del self.peft_config[adapter_name]

        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
        new_adapter = None
        for key in key_list:
            _, target, _ = _get_submodules(self.model, key)
            if isinstance(target, IA3Layer):
                target.delete_adapter(adapter_name)
                if new_adapter is None:
                    new_adapter = target.active_adapters[:]

        self.active_adapter = new_adapter or []
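
    # Usage sketch (assumes an adapter named "extra" was added earlier; illustrative only):
    #
    #     ia3_model.delete_adapter("extra")  # the remaining adapters stay registered and active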