Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- MoRA/README.md +68 -0
- MoRA/config.py +4 -0
- MoRA/model.py +4 -0
- MoRA/peft_mora/__init__.py +90 -0
- MoRA/peft_mora/__pycache__/__init__.cpython-312.pyc +0 -0
- MoRA/peft_mora/__pycache__/auto.cpython-312.pyc +0 -0
- MoRA/peft_mora/__pycache__/config.cpython-312.pyc +0 -0
- MoRA/peft_mora/__pycache__/import_utils.cpython-312.pyc +0 -0
- MoRA/peft_mora/__pycache__/mapping.cpython-312.pyc +0 -0
- MoRA/peft_mora/__pycache__/mixed_model.cpython-312.pyc +0 -0
- MoRA/peft_mora/__pycache__/peft_model.cpython-312.pyc +0 -0
- MoRA/peft_mora/auto.py +170 -0
- MoRA/peft_mora/config.py +270 -0
- MoRA/peft_mora/helpers.py +113 -0
- MoRA/peft_mora/import_utils.py +73 -0
- MoRA/peft_mora/mapping.py +168 -0
- MoRA/peft_mora/mixed_model.py +402 -0
- MoRA/peft_mora/peft_model.py +1929 -0
- MoRA/peft_mora/py.typed +0 -0
- MoRA/peft_mora/tuners/__init__.py +32 -0
- MoRA/peft_mora/tuners/__pycache__/__init__.cpython-312.pyc +0 -0
- MoRA/peft_mora/tuners/__pycache__/lycoris_utils.cpython-312.pyc +0 -0
- MoRA/peft_mora/tuners/__pycache__/tuners_utils.cpython-312.pyc +0 -0
- MoRA/peft_mora/tuners/adalora/__init__.py +37 -0
- MoRA/peft_mora/tuners/adalora/__pycache__/__init__.cpython-312.pyc +0 -0
- MoRA/peft_mora/tuners/adalora/__pycache__/config.cpython-312.pyc +0 -0
- MoRA/peft_mora/tuners/adalora/__pycache__/gptq.cpython-312.pyc +0 -0
- MoRA/peft_mora/tuners/adalora/__pycache__/layer.cpython-312.pyc +0 -0
- MoRA/peft_mora/tuners/adalora/__pycache__/model.cpython-312.pyc +0 -0
- MoRA/peft_mora/tuners/adalora/bnb.py +145 -0
- MoRA/peft_mora/tuners/adalora/config.py +52 -0
- MoRA/peft_mora/tuners/adalora/gptq.py +72 -0
- MoRA/peft_mora/tuners/adalora/layer.py +346 -0
- MoRA/peft_mora/tuners/adalora/model.py +346 -0
- MoRA/peft_mora/tuners/adaption_prompt/__init__.py +19 -0
- MoRA/peft_mora/tuners/adaption_prompt/__pycache__/__init__.cpython-312.pyc +0 -0
- MoRA/peft_mora/tuners/adaption_prompt/__pycache__/config.cpython-312.pyc +0 -0
- MoRA/peft_mora/tuners/adaption_prompt/__pycache__/layer.cpython-312.pyc +0 -0
- MoRA/peft_mora/tuners/adaption_prompt/__pycache__/model.cpython-312.pyc +0 -0
- MoRA/peft_mora/tuners/adaption_prompt/__pycache__/utils.cpython-312.pyc +0 -0
- MoRA/peft_mora/tuners/adaption_prompt/config.py +73 -0
- MoRA/peft_mora/tuners/adaption_prompt/layer.py +120 -0
- MoRA/peft_mora/tuners/adaption_prompt/model.py +161 -0
- MoRA/peft_mora/tuners/adaption_prompt/utils.py +111 -0
- MoRA/peft_mora/tuners/ia3/__init__.py +36 -0
- MoRA/peft_mora/tuners/ia3/__pycache__/__init__.cpython-312.pyc +0 -0
- MoRA/peft_mora/tuners/ia3/__pycache__/config.cpython-312.pyc +0 -0
- MoRA/peft_mora/tuners/ia3/__pycache__/layer.cpython-312.pyc +0 -0
- MoRA/peft_mora/tuners/ia3/__pycache__/model.cpython-312.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
MoRA/README.md
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# [MoRA: High-Rank Updating for Parameter-Efficient Fine-Tuning](https://arxiv.org/abs/2405.12130)
|
2 |
+
|
3 |
+
## Setup
|
4 |
+
|
5 |
+
We implement MoRA in peft-mora based on HF peft in the [`apply_mora`](https://github.com/kongds/MoRA/blob/main/peft-mora/src/peft/tuners/lora/layer.py#L229) and [`get_delta_weight`](https://github.com/kongds/MoRA/blob/main/peft-mora/src/peft/tuners/lora/layer.py#L514).
|
6 |
+
``` sh
|
7 |
+
pip install -e ./peft-mora
|
8 |
+
```
|
9 |
+
|
10 |
+
After installation, it can be used like
|
11 |
+
|
12 |
+
``` python
|
13 |
+
from peft import LoraConfig, get_peft_model
|
14 |
+
config = LoraConfig(
|
15 |
+
# enable MoRA
|
16 |
+
use_mora=True,
|
17 |
+
# type 1 (Sharing) for large lora ranks, Eq. 6 in paper
|
18 |
+
# type 6 (RoPE based) for small lora ranks, Eq. 9 in paper
|
19 |
+
mora_type=6,
|
20 |
+
# lora rank here, we will calculate corresponding $\hat{r}$ in MoRA
|
21 |
+
r=lora_r,
|
22 |
+
# MoRA does not use lora_alpha
|
23 |
+
# lora_alpha=lora_alpha,
|
24 |
+
target_modules=lora_target_modules,
|
25 |
+
lora_dropout=lora_dropout,
|
26 |
+
task_type="CAUSAL_LM",
|
27 |
+
**kwargs,
|
28 |
+
)
|
29 |
+
model = get_peft_model(model, config)
|
30 |
+
|
31 |
+
# training here...
|
32 |
+
|
33 |
+
# can be merged into model via `merge_and_unload` like LoRA
|
34 |
+
model = model.merge_and_unload()
|
35 |
+
```
|
36 |
+
|
37 |
+
## Examples
|
38 |
+
### fine-tuning MetaMath with MoRA
|
39 |
+
|
40 |
+
``` sh
|
41 |
+
RANK=8
|
42 |
+
deepspeed --num_gpus=8 --num_nodes=2 train.py \
|
43 |
+
--base_model <LLAMA-2> --micro_batch_size 4\
|
44 |
+
--wandb_run_name mora_math_r8 --lora_target_modules q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj \
|
45 |
+
--num_epochs 3 --deepspeed ds.config --wandb_project lora-math --lora_r $RANK --batch_size 128 \
|
46 |
+
--data_path meta-math/MetaMath \
|
47 |
+
--save_steps 3000 \
|
48 |
+
--learning_rate 3e-4 --mora_type 6 \
|
49 |
+
--logging_steps 5 --use_bf16 --use_16bit --use_mora
|
50 |
+
```
|
51 |
+
|
52 |
+
### pretraining
|
53 |
+
|
54 |
+
``` sh
|
55 |
+
deepspeed --num_gpus=8 --num_nodes=4 train.py \
|
56 |
+
--micro_batch_size 16 --wandb_run_name mora-pretrain250m-r128 \
|
57 |
+
--num_epochs 1 --wandb_project lora-pretrain --batch_size 1024 \
|
58 |
+
--data_path <processed C4> --logging_steps 1 \
|
59 |
+
--lora_target_modules q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj \
|
60 |
+
--lora_r 128 --lora_alpha 64 --warmup_steps 1000 \
|
61 |
+
--force_tqdm_update --lr_scheduler_type cosine \
|
62 |
+
--max_steps 10000 --pretrain 250m \
|
63 |
+
--train_embhead --learning_rate 5e-4 \
|
64 |
+
--use_mora --use_relora --use_relora_step 2000 # ReMoRA merge per 2000 steps
|
65 |
+
```
|
66 |
+
|
67 |
+
## Acknowledgement
|
68 |
+
Our Code is based on peft, alpaca-lora and ReLoRA
|
MoRA/config.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from peft_mora import LoraConfig
|
2 |
+
|
3 |
+
class MoRAModelForCausalLM(LoraConfig):
|
4 |
+
pass
|
MoRA/model.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from peft_mora import PeftModelForCausalLM
|
2 |
+
|
3 |
+
class MoRAModelForCausalLM(PeftModelForCausalLM):
|
4 |
+
pass
|
MoRA/peft_mora/__init__.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# flake8: noqa
|
2 |
+
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
3 |
+
# module, but to preserve other warnings. So, don't check this module at all.
|
4 |
+
|
5 |
+
# coding=utf-8
|
6 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
7 |
+
#
|
8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
9 |
+
# you may not use this file except in compliance with the License.
|
10 |
+
# You may obtain a copy of the License at
|
11 |
+
#
|
12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
13 |
+
#
|
14 |
+
# Unless required by applicable law or agreed to in writing, software
|
15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
17 |
+
# See the License for the specific language governing permissions and
|
18 |
+
# limitations under the License.
|
19 |
+
|
20 |
+
__version__ = "0.9.0"
|
21 |
+
|
22 |
+
from .auto import (
|
23 |
+
AutoPeftModel,
|
24 |
+
AutoPeftModelForCausalLM,
|
25 |
+
AutoPeftModelForSequenceClassification,
|
26 |
+
AutoPeftModelForSeq2SeqLM,
|
27 |
+
AutoPeftModelForTokenClassification,
|
28 |
+
AutoPeftModelForQuestionAnswering,
|
29 |
+
AutoPeftModelForFeatureExtraction,
|
30 |
+
)
|
31 |
+
from .mapping import (
|
32 |
+
MODEL_TYPE_TO_PEFT_MODEL_MAPPING,
|
33 |
+
PEFT_TYPE_TO_CONFIG_MAPPING,
|
34 |
+
get_peft_config,
|
35 |
+
get_peft_model,
|
36 |
+
inject_adapter_in_model,
|
37 |
+
)
|
38 |
+
from .mixed_model import PeftMixedModel
|
39 |
+
from .peft_model import (
|
40 |
+
PeftModel,
|
41 |
+
PeftModelForCausalLM,
|
42 |
+
PeftModelForSeq2SeqLM,
|
43 |
+
PeftModelForSequenceClassification,
|
44 |
+
PeftModelForTokenClassification,
|
45 |
+
PeftModelForQuestionAnswering,
|
46 |
+
PeftModelForFeatureExtraction,
|
47 |
+
)
|
48 |
+
from .tuners import (
|
49 |
+
AdaptionPromptConfig,
|
50 |
+
AdaptionPromptModel,
|
51 |
+
LoraConfig,
|
52 |
+
LoftQConfig,
|
53 |
+
LoraModel,
|
54 |
+
LoHaConfig,
|
55 |
+
LoHaModel,
|
56 |
+
LoKrConfig,
|
57 |
+
LoKrModel,
|
58 |
+
IA3Config,
|
59 |
+
IA3Model,
|
60 |
+
AdaLoraConfig,
|
61 |
+
AdaLoraModel,
|
62 |
+
PrefixEncoder,
|
63 |
+
PrefixTuningConfig,
|
64 |
+
PromptEmbedding,
|
65 |
+
PromptEncoder,
|
66 |
+
PromptEncoderConfig,
|
67 |
+
PromptEncoderReparameterizationType,
|
68 |
+
PromptTuningConfig,
|
69 |
+
PromptTuningInit,
|
70 |
+
MultitaskPromptTuningConfig,
|
71 |
+
MultitaskPromptTuningInit,
|
72 |
+
OFTConfig,
|
73 |
+
OFTModel,
|
74 |
+
PolyConfig,
|
75 |
+
PolyModel,
|
76 |
+
)
|
77 |
+
from .utils import (
|
78 |
+
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
|
79 |
+
PeftType,
|
80 |
+
TaskType,
|
81 |
+
bloom_model_postprocess_past_key_value,
|
82 |
+
get_peft_model_state_dict,
|
83 |
+
prepare_model_for_int8_training,
|
84 |
+
prepare_model_for_kbit_training,
|
85 |
+
set_peft_model_state_dict,
|
86 |
+
shift_tokens_right,
|
87 |
+
load_peft_weights,
|
88 |
+
cast_mixed_precision_params,
|
89 |
+
)
|
90 |
+
from .config import PeftConfig, PromptLearningConfig
|
MoRA/peft_mora/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (2.31 kB). View file
|
|
MoRA/peft_mora/__pycache__/auto.cpython-312.pyc
ADDED
Binary file (6.68 kB). View file
|
|
MoRA/peft_mora/__pycache__/config.cpython-312.pyc
ADDED
Binary file (11.7 kB). View file
|
|
MoRA/peft_mora/__pycache__/import_utils.cpython-312.pyc
ADDED
Binary file (2.85 kB). View file
|
|
MoRA/peft_mora/__pycache__/mapping.cpython-312.pyc
ADDED
Binary file (5.64 kB). View file
|
|
MoRA/peft_mora/__pycache__/mixed_model.cpython-312.pyc
ADDED
Binary file (18.5 kB). View file
|
|
MoRA/peft_mora/__pycache__/peft_model.cpython-312.pyc
ADDED
Binary file (82.2 kB). View file
|
|
MoRA/peft_mora/auto.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import annotations
|
16 |
+
|
17 |
+
import importlib
|
18 |
+
import os
|
19 |
+
from typing import Optional
|
20 |
+
|
21 |
+
from transformers import (
|
22 |
+
AutoModel,
|
23 |
+
AutoModelForCausalLM,
|
24 |
+
AutoModelForQuestionAnswering,
|
25 |
+
AutoModelForSeq2SeqLM,
|
26 |
+
AutoModelForSequenceClassification,
|
27 |
+
AutoModelForTokenClassification,
|
28 |
+
AutoTokenizer,
|
29 |
+
)
|
30 |
+
|
31 |
+
from .config import PeftConfig
|
32 |
+
from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING
|
33 |
+
from .peft_model import (
|
34 |
+
PeftModel,
|
35 |
+
PeftModelForCausalLM,
|
36 |
+
PeftModelForFeatureExtraction,
|
37 |
+
PeftModelForQuestionAnswering,
|
38 |
+
PeftModelForSeq2SeqLM,
|
39 |
+
PeftModelForSequenceClassification,
|
40 |
+
PeftModelForTokenClassification,
|
41 |
+
)
|
42 |
+
from .utils.constants import TOKENIZER_CONFIG_NAME
|
43 |
+
from .utils.other import check_file_exists_on_hf_hub
|
44 |
+
|
45 |
+
|
46 |
+
class _BaseAutoPeftModel:
|
47 |
+
_target_class = None
|
48 |
+
_target_peft_class = None
|
49 |
+
|
50 |
+
def __init__(self, *args, **kwargs):
|
51 |
+
# For consistency with transformers: https://github.com/huggingface/transformers/blob/91d7df58b6537d385e90578dac40204cb550f706/src/transformers/models/auto/auto_factory.py#L400
|
52 |
+
raise EnvironmentError( # noqa: UP024
|
53 |
+
f"{self.__class__.__name__} is designed to be instantiated "
|
54 |
+
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
|
55 |
+
f"`{self.__class__.__name__}.from_config(config)` methods."
|
56 |
+
)
|
57 |
+
|
58 |
+
@classmethod
|
59 |
+
def from_pretrained(
|
60 |
+
cls,
|
61 |
+
pretrained_model_name_or_path,
|
62 |
+
adapter_name: str = "default",
|
63 |
+
is_trainable: bool = False,
|
64 |
+
config: Optional[PeftConfig] = None,
|
65 |
+
**kwargs,
|
66 |
+
):
|
67 |
+
r"""
|
68 |
+
A wrapper around all the preprocessing steps a user needs to perform in order to load a PEFT model. The kwargs
|
69 |
+
are passed along to `PeftConfig` that automatically takes care of filtering the kwargs of the Hub methods and
|
70 |
+
the config object init.
|
71 |
+
"""
|
72 |
+
peft_config = PeftConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
73 |
+
base_model_path = peft_config.base_model_name_or_path
|
74 |
+
|
75 |
+
task_type = getattr(peft_config, "task_type", None)
|
76 |
+
|
77 |
+
if cls._target_class is not None:
|
78 |
+
target_class = cls._target_class
|
79 |
+
elif cls._target_class is None and task_type is not None:
|
80 |
+
# this is only in the case where we use `AutoPeftModel`
|
81 |
+
raise ValueError(
|
82 |
+
"Cannot use `AutoPeftModel` with a task type, please use a specific class for your task type. (e.g. `AutoPeftModelForCausalLM` for `task_type='CAUSAL_LM'`)"
|
83 |
+
)
|
84 |
+
|
85 |
+
if task_type is not None:
|
86 |
+
expected_target_class = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[task_type]
|
87 |
+
if cls._target_peft_class.__name__ != expected_target_class.__name__:
|
88 |
+
raise ValueError(
|
89 |
+
f"Expected target PEFT class: {expected_target_class.__name__}, but you have asked for: {cls._target_peft_class.__name__ }"
|
90 |
+
" make sure that you are loading the correct model for your task type."
|
91 |
+
)
|
92 |
+
elif task_type is None and getattr(peft_config, "auto_mapping", None) is not None:
|
93 |
+
auto_mapping = getattr(peft_config, "auto_mapping", None)
|
94 |
+
base_model_class = auto_mapping["base_model_class"]
|
95 |
+
parent_library_name = auto_mapping["parent_library"]
|
96 |
+
|
97 |
+
parent_library = importlib.import_module(parent_library_name)
|
98 |
+
target_class = getattr(parent_library, base_model_class)
|
99 |
+
else:
|
100 |
+
raise ValueError(
|
101 |
+
"Cannot infer the auto class from the config, please make sure that you are loading the correct model for your task type."
|
102 |
+
)
|
103 |
+
|
104 |
+
base_model = target_class.from_pretrained(base_model_path, **kwargs)
|
105 |
+
|
106 |
+
tokenizer_exists = False
|
107 |
+
if os.path.exists(os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_NAME)):
|
108 |
+
tokenizer_exists = True
|
109 |
+
else:
|
110 |
+
token = kwargs.get("token", None)
|
111 |
+
if token is None:
|
112 |
+
token = kwargs.get("use_auth_token", None)
|
113 |
+
|
114 |
+
tokenizer_exists = check_file_exists_on_hf_hub(
|
115 |
+
repo_id=pretrained_model_name_or_path,
|
116 |
+
filename=TOKENIZER_CONFIG_NAME,
|
117 |
+
revision=kwargs.get("revision", None),
|
118 |
+
repo_type=kwargs.get("repo_type", None),
|
119 |
+
token=token,
|
120 |
+
)
|
121 |
+
|
122 |
+
if tokenizer_exists:
|
123 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
124 |
+
pretrained_model_name_or_path, trust_remote_code=kwargs.get("trust_remote_code", False)
|
125 |
+
)
|
126 |
+
base_model.resize_token_embeddings(len(tokenizer))
|
127 |
+
|
128 |
+
return cls._target_peft_class.from_pretrained(
|
129 |
+
base_model,
|
130 |
+
pretrained_model_name_or_path,
|
131 |
+
adapter_name=adapter_name,
|
132 |
+
is_trainable=is_trainable,
|
133 |
+
config=config,
|
134 |
+
**kwargs,
|
135 |
+
)
|
136 |
+
|
137 |
+
|
138 |
+
class AutoPeftModel(_BaseAutoPeftModel):
|
139 |
+
_target_class = None
|
140 |
+
_target_peft_class = PeftModel
|
141 |
+
|
142 |
+
|
143 |
+
class AutoPeftModelForCausalLM(_BaseAutoPeftModel):
|
144 |
+
_target_class = AutoModelForCausalLM
|
145 |
+
_target_peft_class = PeftModelForCausalLM
|
146 |
+
|
147 |
+
|
148 |
+
class AutoPeftModelForSeq2SeqLM(_BaseAutoPeftModel):
|
149 |
+
_target_class = AutoModelForSeq2SeqLM
|
150 |
+
_target_peft_class = PeftModelForSeq2SeqLM
|
151 |
+
|
152 |
+
|
153 |
+
class AutoPeftModelForSequenceClassification(_BaseAutoPeftModel):
|
154 |
+
_target_class = AutoModelForSequenceClassification
|
155 |
+
_target_peft_class = PeftModelForSequenceClassification
|
156 |
+
|
157 |
+
|
158 |
+
class AutoPeftModelForTokenClassification(_BaseAutoPeftModel):
|
159 |
+
_target_class = AutoModelForTokenClassification
|
160 |
+
_target_peft_class = PeftModelForTokenClassification
|
161 |
+
|
162 |
+
|
163 |
+
class AutoPeftModelForQuestionAnswering(_BaseAutoPeftModel):
|
164 |
+
_target_class = AutoModelForQuestionAnswering
|
165 |
+
_target_peft_class = PeftModelForQuestionAnswering
|
166 |
+
|
167 |
+
|
168 |
+
class AutoPeftModelForFeatureExtraction(_BaseAutoPeftModel):
|
169 |
+
_target_class = AutoModel
|
170 |
+
_target_peft_class = PeftModelForFeatureExtraction
|
MoRA/peft_mora/config.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import inspect
|
15 |
+
import json
|
16 |
+
import os
|
17 |
+
from dataclasses import asdict, dataclass, field
|
18 |
+
from typing import Dict, Optional, Union
|
19 |
+
|
20 |
+
from huggingface_hub import hf_hub_download
|
21 |
+
from transformers.utils import PushToHubMixin
|
22 |
+
|
23 |
+
from .utils import CONFIG_NAME, PeftType, TaskType
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class PeftConfigMixin(PushToHubMixin):
|
28 |
+
r"""
|
29 |
+
This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all
|
30 |
+
PEFT adapter models. This class inherits from [`~transformers.utils.PushToHubMixin`] which contains the methods to
|
31 |
+
push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a
|
32 |
+
directory. The method `from_pretrained` will load the configuration of your adapter model from a directory.
|
33 |
+
|
34 |
+
Args:
|
35 |
+
peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
|
36 |
+
"""
|
37 |
+
|
38 |
+
peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."})
|
39 |
+
auto_mapping: Optional[dict] = field(
|
40 |
+
default=None, metadata={"help": "An auto mapping dict to help retrieve the base model class if needed."}
|
41 |
+
)
|
42 |
+
|
43 |
+
def to_dict(self) -> Dict:
|
44 |
+
r"""
|
45 |
+
Returns the configuration for your adapter model as a dictionary.
|
46 |
+
"""
|
47 |
+
return asdict(self)
|
48 |
+
|
49 |
+
def save_pretrained(self, save_directory: str, **kwargs) -> None:
|
50 |
+
r"""
|
51 |
+
This method saves the configuration of your adapter model in a directory.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
save_directory (`str`):
|
55 |
+
The directory where the configuration will be saved.
|
56 |
+
kwargs (additional keyword arguments, *optional*):
|
57 |
+
Additional keyword arguments passed along to the [`~transformers.utils.PushToHubMixin.push_to_hub`]
|
58 |
+
method.
|
59 |
+
"""
|
60 |
+
if os.path.isfile(save_directory):
|
61 |
+
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
62 |
+
|
63 |
+
os.makedirs(save_directory, exist_ok=True)
|
64 |
+
auto_mapping_dict = kwargs.pop("auto_mapping_dict", None)
|
65 |
+
|
66 |
+
output_dict = asdict(self)
|
67 |
+
# converting set type to list
|
68 |
+
for key, value in output_dict.items():
|
69 |
+
if isinstance(value, set):
|
70 |
+
output_dict[key] = list(value)
|
71 |
+
|
72 |
+
output_path = os.path.join(save_directory, CONFIG_NAME)
|
73 |
+
|
74 |
+
# Add auto mapping details for custom models.
|
75 |
+
if auto_mapping_dict is not None:
|
76 |
+
output_dict["auto_mapping"] = auto_mapping_dict
|
77 |
+
|
78 |
+
# save it
|
79 |
+
with open(output_path, "w") as writer:
|
80 |
+
writer.write(json.dumps(output_dict, indent=2, sort_keys=True))
|
81 |
+
|
82 |
+
@classmethod
|
83 |
+
def from_peft_type(cls, **kwargs):
|
84 |
+
r"""
|
85 |
+
This method loads the configuration of your adapter model from a set of kwargs.
|
86 |
+
|
87 |
+
The appropriate configuration type is determined by the `peft_type` argument. If `peft_type` is not provided,
|
88 |
+
the calling class type is instantiated.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
kwargs (configuration keyword arguments):
|
92 |
+
Keyword arguments passed along to the configuration initialization.
|
93 |
+
"""
|
94 |
+
# Avoid circular dependency .. TODO: fix this with a larger refactor
|
95 |
+
from peft_mora.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
|
96 |
+
|
97 |
+
# TODO: this hack is needed to fix the following issue (on commit 702f937):
|
98 |
+
# if someone saves a default config and loads it back with `PeftConfig` class it yields to
|
99 |
+
# not loading the correct config class.
|
100 |
+
|
101 |
+
# from peft import AdaLoraConfig, PeftConfig
|
102 |
+
# peft_config = AdaLoraConfig()
|
103 |
+
# print(peft_config)
|
104 |
+
# >>> AdaLoraConfig(peft_type=<PeftType.ADALORA: 'ADALORA'>, auto_mapping=None, base_model_name_or_path=None,
|
105 |
+
# revision=None, task_type=None, inference_mode=False, r=8, target_modules=None, lora_alpha=8, lora_dropout=0.0, ...
|
106 |
+
#
|
107 |
+
# peft_config.save_pretrained("./test_config")
|
108 |
+
# peft_config = PeftConfig.from_pretrained("./test_config")
|
109 |
+
# print(peft_config)
|
110 |
+
# >>> PeftConfig(peft_type='ADALORA', auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=None, inference_mode=False)
|
111 |
+
|
112 |
+
if "peft_type" in kwargs:
|
113 |
+
peft_type = kwargs["peft_type"]
|
114 |
+
config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]
|
115 |
+
else:
|
116 |
+
config_cls = cls
|
117 |
+
|
118 |
+
return config_cls(**kwargs)
|
119 |
+
|
120 |
+
@classmethod
|
121 |
+
def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional[str] = None, **kwargs):
|
122 |
+
r"""
|
123 |
+
This method loads the configuration of your adapter model from a directory.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
pretrained_model_name_or_path (`str`):
|
127 |
+
The directory or the Hub repository id where the configuration is saved.
|
128 |
+
kwargs (additional keyword arguments, *optional*):
|
129 |
+
Additional keyword arguments passed along to the child class initialization.
|
130 |
+
"""
|
131 |
+
path = (
|
132 |
+
os.path.join(pretrained_model_name_or_path, subfolder)
|
133 |
+
if subfolder is not None
|
134 |
+
else pretrained_model_name_or_path
|
135 |
+
)
|
136 |
+
|
137 |
+
hf_hub_download_kwargs, class_kwargs, _ = cls._split_kwargs(kwargs)
|
138 |
+
|
139 |
+
if os.path.isfile(os.path.join(path, CONFIG_NAME)):
|
140 |
+
config_file = os.path.join(path, CONFIG_NAME)
|
141 |
+
else:
|
142 |
+
try:
|
143 |
+
config_file = hf_hub_download(
|
144 |
+
pretrained_model_name_or_path, CONFIG_NAME, subfolder=subfolder, **hf_hub_download_kwargs
|
145 |
+
)
|
146 |
+
except Exception:
|
147 |
+
raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'")
|
148 |
+
|
149 |
+
loaded_attributes = cls.from_json_file(config_file)
|
150 |
+
kwargs = {**class_kwargs, **loaded_attributes}
|
151 |
+
return cls.from_peft_type(**kwargs)
|
152 |
+
|
153 |
+
@classmethod
|
154 |
+
def from_json_file(cls, path_json_file: str, **kwargs):
|
155 |
+
r"""
|
156 |
+
Loads a configuration file from a json file.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
path_json_file (`str`):
|
160 |
+
The path to the json file.
|
161 |
+
"""
|
162 |
+
with open(path_json_file) as file:
|
163 |
+
json_object = json.load(file)
|
164 |
+
|
165 |
+
return json_object
|
166 |
+
|
167 |
+
@classmethod
|
168 |
+
def _split_kwargs(cls, kwargs):
|
169 |
+
hf_hub_download_kwargs = {}
|
170 |
+
class_kwargs = {}
|
171 |
+
other_kwargs = {}
|
172 |
+
|
173 |
+
for key, value in kwargs.items():
|
174 |
+
if key in inspect.signature(hf_hub_download).parameters:
|
175 |
+
hf_hub_download_kwargs[key] = value
|
176 |
+
elif key in list(cls.__annotations__):
|
177 |
+
class_kwargs[key] = value
|
178 |
+
else:
|
179 |
+
other_kwargs[key] = value
|
180 |
+
|
181 |
+
return hf_hub_download_kwargs, class_kwargs, other_kwargs
|
182 |
+
|
183 |
+
@classmethod
|
184 |
+
def _get_peft_type(
|
185 |
+
cls,
|
186 |
+
model_id: str,
|
187 |
+
**hf_hub_download_kwargs,
|
188 |
+
):
|
189 |
+
subfolder = hf_hub_download_kwargs.get("subfolder", None)
|
190 |
+
|
191 |
+
path = os.path.join(model_id, subfolder) if subfolder is not None else model_id
|
192 |
+
|
193 |
+
if os.path.isfile(os.path.join(path, CONFIG_NAME)):
|
194 |
+
config_file = os.path.join(path, CONFIG_NAME)
|
195 |
+
else:
|
196 |
+
try:
|
197 |
+
config_file = hf_hub_download(
|
198 |
+
model_id,
|
199 |
+
CONFIG_NAME,
|
200 |
+
**hf_hub_download_kwargs,
|
201 |
+
)
|
202 |
+
except Exception:
|
203 |
+
raise ValueError(f"Can't find '{CONFIG_NAME}' at '{model_id}'")
|
204 |
+
|
205 |
+
loaded_attributes = cls.from_json_file(config_file)
|
206 |
+
return loaded_attributes["peft_type"]
|
207 |
+
|
208 |
+
@property
|
209 |
+
def is_prompt_learning(self) -> bool:
|
210 |
+
r"""
|
211 |
+
Utility method to check if the configuration is for prompt learning.
|
212 |
+
"""
|
213 |
+
return False
|
214 |
+
|
215 |
+
@property
|
216 |
+
def is_adaption_prompt(self) -> bool:
|
217 |
+
"""Return True if this is an adaption prompt config."""
|
218 |
+
return False
|
219 |
+
|
220 |
+
|
221 |
+
@dataclass
|
222 |
+
class PeftConfig(PeftConfigMixin):
|
223 |
+
"""
|
224 |
+
This is the base configuration class to store the configuration of a [`PeftModel`].
|
225 |
+
|
226 |
+
Args:
|
227 |
+
peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
|
228 |
+
task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform.
|
229 |
+
inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode.
|
230 |
+
"""
|
231 |
+
|
232 |
+
base_model_name_or_path: Optional[str] = field(
|
233 |
+
default=None, metadata={"help": "The name of the base model to use."}
|
234 |
+
)
|
235 |
+
revision: Optional[str] = field(default=None, metadata={"help": "The specific model version to use."})
|
236 |
+
peft_type: Optional[Union[str, PeftType]] = field(default=None, metadata={"help": "Peft type"})
|
237 |
+
task_type: Optional[Union[str, TaskType]] = field(default=None, metadata={"help": "Task type"})
|
238 |
+
inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"})
|
239 |
+
|
240 |
+
|
241 |
+
@dataclass
|
242 |
+
class PromptLearningConfig(PeftConfig):
|
243 |
+
"""
|
244 |
+
This is the base configuration class to store the configuration of [`PrefixTuning`], [`PromptEncoder`], or
|
245 |
+
[`PromptTuning`].
|
246 |
+
|
247 |
+
Args:
|
248 |
+
num_virtual_tokens (`int`): The number of virtual tokens to use.
|
249 |
+
token_dim (`int`): The hidden embedding dimension of the base transformer model.
|
250 |
+
num_transformer_submodules (`int`): The number of transformer submodules in the base transformer model.
|
251 |
+
num_attention_heads (`int`): The number of attention heads in the base transformer model.
|
252 |
+
num_layers (`int`): The number of layers in the base transformer model.
|
253 |
+
"""
|
254 |
+
|
255 |
+
num_virtual_tokens: int = field(default=None, metadata={"help": "Number of virtual tokens"})
|
256 |
+
token_dim: int = field(
|
257 |
+
default=None, metadata={"help": "The hidden embedding dimension of the base transformer model"}
|
258 |
+
)
|
259 |
+
num_transformer_submodules: Optional[int] = field(
|
260 |
+
default=None, metadata={"help": "Number of transformer submodules"}
|
261 |
+
)
|
262 |
+
num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"})
|
263 |
+
num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"})
|
264 |
+
|
265 |
+
@property
|
266 |
+
def is_prompt_learning(self) -> bool:
|
267 |
+
r"""
|
268 |
+
Utility method to check if the configuration is for prompt learning.
|
269 |
+
"""
|
270 |
+
return True
|
MoRA/peft_mora/helpers.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
from copy import deepcopy
|
3 |
+
from functools import update_wrapper
|
4 |
+
from types import MethodType
|
5 |
+
|
6 |
+
from .peft_model import PeftModel
|
7 |
+
|
8 |
+
|
9 |
+
def update_forward_signature(model: PeftModel) -> None:
|
10 |
+
"""
|
11 |
+
Args:
|
12 |
+
Updates the forward signature of the PeftModel to include parents class signature
|
13 |
+
model (`PeftModel`): Peft model to update the forward signature
|
14 |
+
Example:
|
15 |
+
|
16 |
+
```python
|
17 |
+
>>> from transformers import WhisperForConditionalGeneration
|
18 |
+
>>> from peft import get_peft_model, LoraConfig, update_forward_signature
|
19 |
+
|
20 |
+
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
21 |
+
>>> peft_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["q_proj", "v_proj"])
|
22 |
+
|
23 |
+
>>> peft_model = get_peft_model(model, peft_config)
|
24 |
+
>>> update_forward_signature(peft_model)
|
25 |
+
```
|
26 |
+
"""
|
27 |
+
|
28 |
+
# Only update signature when the current forward signature only has *args and **kwargs
|
29 |
+
current_signature = inspect.signature(model.forward)
|
30 |
+
if (
|
31 |
+
len(current_signature.parameters) == 2
|
32 |
+
and "args" in current_signature.parameters
|
33 |
+
and "kwargs" in current_signature.parameters
|
34 |
+
):
|
35 |
+
forward = deepcopy(model.forward.__func__)
|
36 |
+
update_wrapper(
|
37 |
+
forward, type(model.get_base_model()).forward, assigned=("__doc__", "__name__", "__annotations__")
|
38 |
+
)
|
39 |
+
model.forward = MethodType(forward, model)
|
40 |
+
|
41 |
+
|
42 |
+
def update_generate_signature(model: PeftModel) -> None:
|
43 |
+
"""
|
44 |
+
Args:
|
45 |
+
Updates the generate signature of a PeftModel with overriding generate to include parents class signature
|
46 |
+
model (`PeftModel`): Peft model to update the generate signature
|
47 |
+
Example:
|
48 |
+
|
49 |
+
```python
|
50 |
+
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
51 |
+
>>> from peft import get_peft_model, LoraConfig, TaskType, update_generate_signature
|
52 |
+
|
53 |
+
>>> model_name_or_path = "bigscience/mt0-large"
|
54 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
55 |
+
>>> model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
|
56 |
+
|
57 |
+
>>> peft_config = LoraConfig(
|
58 |
+
... task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
|
59 |
+
... )
|
60 |
+
>>> peft_model = get_peft_model(model, peft_config)
|
61 |
+
>>> update_generate_signature(peft_model)
|
62 |
+
>>> help(peft_model.generate)
|
63 |
+
```
|
64 |
+
"""
|
65 |
+
if not hasattr(model, "generate"):
|
66 |
+
return
|
67 |
+
current_signature = inspect.signature(model.generate)
|
68 |
+
if (
|
69 |
+
len(current_signature.parameters) == 2
|
70 |
+
and "args" in current_signature.parameters
|
71 |
+
and "kwargs" in current_signature.parameters
|
72 |
+
) or (len(current_signature.parameters) == 1 and "kwargs" in current_signature.parameters):
|
73 |
+
generate = deepcopy(model.generate.__func__)
|
74 |
+
update_wrapper(
|
75 |
+
generate,
|
76 |
+
type(model.get_base_model()).generate,
|
77 |
+
assigned=("__doc__", "__name__", "__annotations__"),
|
78 |
+
)
|
79 |
+
model.generate = MethodType(generate, model)
|
80 |
+
|
81 |
+
|
82 |
+
def update_signature(model: PeftModel, method: str = "all") -> None:
|
83 |
+
"""
|
84 |
+
Args:
|
85 |
+
Updates the signature of a PeftModel include parents class signature for forward or generate method
|
86 |
+
model (`PeftModel`): Peft model to update generate or forward signature method (`str`): method to update
|
87 |
+
signature choose one of "forward", "generate", "all"
|
88 |
+
Example:
|
89 |
+
```python
|
90 |
+
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
91 |
+
>>> from peft import get_peft_model, LoraConfig, TaskType, update_signature
|
92 |
+
|
93 |
+
>>> model_name_or_path = "bigscience/mt0-large"
|
94 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
95 |
+
>>> model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
|
96 |
+
|
97 |
+
>>> peft_config = LoraConfig(
|
98 |
+
... task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
|
99 |
+
... )
|
100 |
+
>>> peft_model = get_peft_model(model, peft_config)
|
101 |
+
>>> update_signature(peft_model)
|
102 |
+
>>> help(peft_model.generate)
|
103 |
+
```
|
104 |
+
"""
|
105 |
+
if method == "forward":
|
106 |
+
update_forward_signature(model)
|
107 |
+
elif method == "generate":
|
108 |
+
update_generate_signature(model)
|
109 |
+
elif method == "all":
|
110 |
+
update_forward_signature(model)
|
111 |
+
update_generate_signature(model)
|
112 |
+
else:
|
113 |
+
raise ValueError(f"method {method} is not supported please choose one of ['forward', 'generate', 'all']")
|
MoRA/peft_mora/import_utils.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import importlib
|
15 |
+
import importlib.metadata as importlib_metadata
|
16 |
+
from functools import lru_cache
|
17 |
+
|
18 |
+
import packaging.version
|
19 |
+
|
20 |
+
|
21 |
+
def is_bnb_available() -> bool:
|
22 |
+
return importlib.util.find_spec("bitsandbytes") is not None
|
23 |
+
|
24 |
+
|
25 |
+
def is_bnb_4bit_available() -> bool:
|
26 |
+
if not is_bnb_available():
|
27 |
+
return False
|
28 |
+
|
29 |
+
import bitsandbytes as bnb
|
30 |
+
|
31 |
+
return hasattr(bnb.nn, "Linear4bit")
|
32 |
+
|
33 |
+
|
34 |
+
def is_auto_gptq_available():
|
35 |
+
if importlib.util.find_spec("auto_gptq") is not None:
|
36 |
+
AUTOGPTQ_MINIMUM_VERSION = packaging.version.parse("0.5.0")
|
37 |
+
version_autogptq = packaging.version.parse(importlib_metadata.version("auto_gptq"))
|
38 |
+
if AUTOGPTQ_MINIMUM_VERSION <= version_autogptq:
|
39 |
+
return True
|
40 |
+
else:
|
41 |
+
raise ImportError(
|
42 |
+
f"Found an incompatible version of auto-gptq. Found version {version_autogptq}, "
|
43 |
+
f"but only versions above {AUTOGPTQ_MINIMUM_VERSION} are supported"
|
44 |
+
)
|
45 |
+
|
46 |
+
|
47 |
+
def is_optimum_available() -> bool:
|
48 |
+
return importlib.util.find_spec("optimum") is not None
|
49 |
+
|
50 |
+
|
51 |
+
@lru_cache
|
52 |
+
def is_torch_tpu_available(check_device=True):
|
53 |
+
"Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
|
54 |
+
if importlib.util.find_spec("torch_xla") is not None:
|
55 |
+
if check_device:
|
56 |
+
# We need to check if `xla_device` can be found, will raise a RuntimeError if not
|
57 |
+
try:
|
58 |
+
import torch_xla.core.xla_model as xm
|
59 |
+
|
60 |
+
_ = xm.xla_device()
|
61 |
+
return True
|
62 |
+
except RuntimeError:
|
63 |
+
return False
|
64 |
+
return True
|
65 |
+
return False
|
66 |
+
|
67 |
+
|
68 |
+
def is_aqlm_available():
|
69 |
+
return importlib.util.find_spec("aqlm") is not None
|
70 |
+
|
71 |
+
|
72 |
+
def is_auto_awq_available():
|
73 |
+
return importlib.util.find_spec("awq") is not None
|
MoRA/peft_mora/mapping.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import annotations
|
16 |
+
|
17 |
+
from typing import TYPE_CHECKING, Any
|
18 |
+
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from .config import PeftConfig
|
22 |
+
from .mixed_model import PeftMixedModel
|
23 |
+
from .peft_model import (
|
24 |
+
PeftModel,
|
25 |
+
PeftModelForCausalLM,
|
26 |
+
PeftModelForFeatureExtraction,
|
27 |
+
PeftModelForQuestionAnswering,
|
28 |
+
PeftModelForSeq2SeqLM,
|
29 |
+
PeftModelForSequenceClassification,
|
30 |
+
PeftModelForTokenClassification,
|
31 |
+
)
|
32 |
+
from .tuners import (
|
33 |
+
AdaLoraConfig,
|
34 |
+
AdaLoraModel,
|
35 |
+
AdaptionPromptConfig,
|
36 |
+
IA3Config,
|
37 |
+
IA3Model,
|
38 |
+
LoHaConfig,
|
39 |
+
LoHaModel,
|
40 |
+
LoKrConfig,
|
41 |
+
LoKrModel,
|
42 |
+
LoraConfig,
|
43 |
+
LoraModel,
|
44 |
+
MultitaskPromptTuningConfig,
|
45 |
+
OFTConfig,
|
46 |
+
OFTModel,
|
47 |
+
PolyConfig,
|
48 |
+
PolyModel,
|
49 |
+
PrefixTuningConfig,
|
50 |
+
PromptEncoderConfig,
|
51 |
+
PromptTuningConfig,
|
52 |
+
)
|
53 |
+
from .utils import _prepare_prompt_learning_config
|
54 |
+
|
55 |
+
|
56 |
+
if TYPE_CHECKING:
|
57 |
+
from transformers import PreTrainedModel
|
58 |
+
|
59 |
+
|
60 |
+
MODEL_TYPE_TO_PEFT_MODEL_MAPPING: dict[str, PeftModel] = {
|
61 |
+
"SEQ_CLS": PeftModelForSequenceClassification,
|
62 |
+
"SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
|
63 |
+
"CAUSAL_LM": PeftModelForCausalLM,
|
64 |
+
"TOKEN_CLS": PeftModelForTokenClassification,
|
65 |
+
"QUESTION_ANS": PeftModelForQuestionAnswering,
|
66 |
+
"FEATURE_EXTRACTION": PeftModelForFeatureExtraction,
|
67 |
+
}
|
68 |
+
|
69 |
+
PEFT_TYPE_TO_CONFIG_MAPPING: dict[str, PeftConfig] = {
|
70 |
+
"ADAPTION_PROMPT": AdaptionPromptConfig,
|
71 |
+
"PROMPT_TUNING": PromptTuningConfig,
|
72 |
+
"PREFIX_TUNING": PrefixTuningConfig,
|
73 |
+
"P_TUNING": PromptEncoderConfig,
|
74 |
+
"LORA": LoraConfig,
|
75 |
+
"LOHA": LoHaConfig,
|
76 |
+
"LOKR": LoKrConfig,
|
77 |
+
"ADALORA": AdaLoraConfig,
|
78 |
+
"IA3": IA3Config,
|
79 |
+
"MULTITASK_PROMPT_TUNING": MultitaskPromptTuningConfig,
|
80 |
+
"OFT": OFTConfig,
|
81 |
+
"POLY": PolyConfig,
|
82 |
+
}
|
83 |
+
|
84 |
+
PEFT_TYPE_TO_TUNER_MAPPING = {
|
85 |
+
"LORA": LoraModel,
|
86 |
+
"LOHA": LoHaModel,
|
87 |
+
"LOKR": LoKrModel,
|
88 |
+
"ADALORA": AdaLoraModel,
|
89 |
+
"IA3": IA3Model,
|
90 |
+
"OFT": OFTModel,
|
91 |
+
"POLY": PolyModel,
|
92 |
+
}
|
93 |
+
|
94 |
+
|
95 |
+
def get_peft_config(config_dict: dict[str, Any]) -> PeftConfig:
|
96 |
+
"""
|
97 |
+
Returns a Peft config object from a dictionary.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters.
|
101 |
+
"""
|
102 |
+
|
103 |
+
return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)
|
104 |
+
|
105 |
+
|
106 |
+
def get_peft_model(
|
107 |
+
model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default", mixed: bool = False
|
108 |
+
) -> PeftModel | PeftMixedModel:
|
109 |
+
"""
|
110 |
+
Returns a Peft model object from a model and a config.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
model ([`transformers.PreTrainedModel`]):
|
114 |
+
Model to be wrapped.
|
115 |
+
peft_config ([`PeftConfig`]):
|
116 |
+
Configuration object containing the parameters of the Peft model.
|
117 |
+
adapter_name (`str`, `optional`, defaults to `"default"`):
|
118 |
+
The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
|
119 |
+
mixed (`bool`, `optional`, defaults to `False`):
|
120 |
+
Whether to allow mixing different (compatible) adapter types.
|
121 |
+
"""
|
122 |
+
model_config = getattr(model, "config", {"model_type": "custom"})
|
123 |
+
if hasattr(model_config, "to_dict"):
|
124 |
+
model_config = model_config.to_dict()
|
125 |
+
|
126 |
+
peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
|
127 |
+
|
128 |
+
if mixed:
|
129 |
+
return PeftMixedModel(model, peft_config, adapter_name=adapter_name)
|
130 |
+
|
131 |
+
if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning:
|
132 |
+
return PeftModel(model, peft_config, adapter_name=adapter_name)
|
133 |
+
|
134 |
+
if peft_config.is_prompt_learning:
|
135 |
+
peft_config = _prepare_prompt_learning_config(peft_config, model_config)
|
136 |
+
return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)
|
137 |
+
|
138 |
+
|
139 |
+
def inject_adapter_in_model(
|
140 |
+
peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str = "default"
|
141 |
+
) -> torch.nn.Module:
|
142 |
+
r"""
|
143 |
+
A simple API to create and inject adapter in-place into a model. Currently the API does not support prompt learning
|
144 |
+
methods and adaption prompt. Make sure to have the correct `target_names` set in the `peft_config` object. The API
|
145 |
+
calls `get_peft_model` under the hood but would be restricted only to non-prompt learning methods.
|
146 |
+
|
147 |
+
Args:
|
148 |
+
peft_config (`PeftConfig`):
|
149 |
+
Configuration object containing the parameters of the Peft model.
|
150 |
+
model (`torch.nn.Module`):
|
151 |
+
The input model where the adapter will be injected.
|
152 |
+
adapter_name (`str`, `optional`, defaults to `"default"`):
|
153 |
+
The name of the adapter to be injected, if not provided, the default adapter name is used ("default").
|
154 |
+
"""
|
155 |
+
if peft_config.is_prompt_learning or peft_config.is_adaption_prompt:
|
156 |
+
raise ValueError("`create_and_replace` does not support prompt learning and adaption prompt yet.")
|
157 |
+
|
158 |
+
if peft_config.peft_type not in PEFT_TYPE_TO_TUNER_MAPPING.keys():
|
159 |
+
raise ValueError(
|
160 |
+
f"`inject_adapter_in_model` does not support {peft_config.peft_type} yet. Please use `get_peft_model`."
|
161 |
+
)
|
162 |
+
|
163 |
+
tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type]
|
164 |
+
|
165 |
+
# By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules.
|
166 |
+
peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name)
|
167 |
+
|
168 |
+
return peft_model.model
|
MoRA/peft_mora/mixed_model.py
ADDED
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import annotations
|
16 |
+
|
17 |
+
import os
|
18 |
+
from contextlib import contextmanager
|
19 |
+
from typing import Any, Optional, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from accelerate.hooks import remove_hook_from_submodules
|
23 |
+
from torch import nn
|
24 |
+
from transformers.utils import PushToHubMixin
|
25 |
+
|
26 |
+
from peft_mora.tuners.mixed import COMPATIBLE_TUNER_TYPES
|
27 |
+
|
28 |
+
from .config import PeftConfig
|
29 |
+
from .peft_model import PeftModel
|
30 |
+
from .tuners import (
|
31 |
+
AdaLoraModel,
|
32 |
+
IA3Model,
|
33 |
+
LoHaModel,
|
34 |
+
LoKrModel,
|
35 |
+
LoraModel,
|
36 |
+
MixedModel,
|
37 |
+
OFTModel,
|
38 |
+
)
|
39 |
+
from .utils import PeftType, _set_adapter, _set_trainable
|
40 |
+
|
41 |
+
|
42 |
+
PEFT_TYPE_TO_MODEL_MAPPING = {
|
43 |
+
PeftType.LORA: LoraModel,
|
44 |
+
PeftType.LOHA: LoHaModel,
|
45 |
+
PeftType.LOKR: LoKrModel,
|
46 |
+
PeftType.ADALORA: AdaLoraModel,
|
47 |
+
PeftType.IA3: IA3Model,
|
48 |
+
PeftType.OFT: OFTModel,
|
49 |
+
}
|
50 |
+
|
51 |
+
|
52 |
+
def _prepare_model_for_gradient_checkpointing(model: nn.Module) -> None:
|
53 |
+
r"""
|
54 |
+
Prepares the model for gradient checkpointing if necessary
|
55 |
+
"""
|
56 |
+
# Note: same as PeftModel._prepare_model_for_gradient_checkpointing
|
57 |
+
if not getattr(model, "is_gradient_checkpointing", True):
|
58 |
+
return model
|
59 |
+
|
60 |
+
if not (
|
61 |
+
getattr(model, "is_loaded_in_8bit", False)
|
62 |
+
or getattr(model, "is_loaded_in_4bit", False)
|
63 |
+
or getattr(model, "is_quantized", False)
|
64 |
+
):
|
65 |
+
if hasattr(model, "enable_input_require_grads"):
|
66 |
+
model.enable_input_require_grads()
|
67 |
+
elif hasattr(model, "get_input_embeddings"):
|
68 |
+
|
69 |
+
def make_inputs_require_grad(module, input, output):
|
70 |
+
output.requires_grad_(True)
|
71 |
+
|
72 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
73 |
+
|
74 |
+
|
75 |
+
def _check_config_compatible(peft_config: PeftConfig) -> None:
|
76 |
+
if peft_config.peft_type not in COMPATIBLE_TUNER_TYPES:
|
77 |
+
raise ValueError(
|
78 |
+
f"The provided `peft_type` '{peft_config.peft_type.value}' is not compatible with the `PeftMixedModel`. "
|
79 |
+
f"Compatible types are: {COMPATIBLE_TUNER_TYPES}"
|
80 |
+
)
|
81 |
+
|
82 |
+
|
83 |
+
class PeftMixedModel(PushToHubMixin, torch.nn.Module):
|
84 |
+
"""
|
85 |
+
PeftMixedModel for loading mixing different types of adapters for inference.
|
86 |
+
|
87 |
+
This class does not support loading/saving, and it shouldn't usually be initialized directly. Instead, use
|
88 |
+
`get_peft_model` with the argument `mixed=True`.
|
89 |
+
|
90 |
+
<Tip>
|
91 |
+
|
92 |
+
Read the [Mixed adapter types](https://huggingface.co/docs/peft/en/developer_guides/mixed_models) guide to learn
|
93 |
+
more about using different adapter types.
|
94 |
+
|
95 |
+
</Tip>
|
96 |
+
|
97 |
+
Example:
|
98 |
+
|
99 |
+
```py
|
100 |
+
>>> from peft import get_peft_model
|
101 |
+
|
102 |
+
>>> base_model = ... # load the base model, e.g. from transformers
|
103 |
+
>>> peft_model = PeftMixedModel.from_pretrained(base_model, path_to_adapter1, "adapter1").eval()
|
104 |
+
>>> peft_model.load_adapter(path_to_adapter2, "adapter2")
|
105 |
+
>>> peft_model.set_adapter(["adapter1", "adapter2"]) # activate both adapters
|
106 |
+
>>> peft_model(data) # forward pass using both adapters
|
107 |
+
```
|
108 |
+
|
109 |
+
Args:
|
110 |
+
model (`torch.nn.Module`):
|
111 |
+
The model to be tuned.
|
112 |
+
config (`PeftConfig`):
|
113 |
+
The config of the model to be tuned. The adapter type must be compatible.
|
114 |
+
adapter_name (`str`, `optional`, defaults to `"default"`):
|
115 |
+
The name of the first adapter.
|
116 |
+
"""
|
117 |
+
|
118 |
+
def __init__(self, model: nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
|
119 |
+
super().__init__()
|
120 |
+
_check_config_compatible(peft_config)
|
121 |
+
_prepare_model_for_gradient_checkpointing(model)
|
122 |
+
self.modules_to_save = None
|
123 |
+
self.base_model = MixedModel(model, {adapter_name: peft_config}, adapter_name)
|
124 |
+
self.set_modules_to_save(peft_config, adapter_name)
|
125 |
+
|
126 |
+
self.config = getattr(model, "config", {"model_type": "custom"})
|
127 |
+
|
128 |
+
# the `pretraining_tp` is set for some models to simulate Tensor Parallelism during inference to avoid
|
129 |
+
# numerical differences, https://github.com/pytorch/pytorch/issues/76232 - to avoid any unexpected
|
130 |
+
# behavior we disable that in this line.
|
131 |
+
if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"):
|
132 |
+
self.base_model.config.pretraining_tp = 1
|
133 |
+
|
134 |
+
@property
|
135 |
+
def peft_config(self) -> dict[str, PeftConfig]:
|
136 |
+
return self.base_model.peft_config
|
137 |
+
|
138 |
+
@property
|
139 |
+
def active_adapter(self) -> str:
|
140 |
+
return self.base_model.active_adapter
|
141 |
+
|
142 |
+
@property
|
143 |
+
def active_adapters(self) -> list[str]:
|
144 |
+
return self.base_model.active_adapters
|
145 |
+
|
146 |
+
def get_nb_trainable_parameters(self):
|
147 |
+
r"""
|
148 |
+
Returns the number of trainable parameters and number of all parameters in the model.
|
149 |
+
"""
|
150 |
+
# note: same as PeftModel.get_nb_trainable_parameters
|
151 |
+
trainable_params = 0
|
152 |
+
all_param = 0
|
153 |
+
for _, param in self.named_parameters():
|
154 |
+
num_params = param.numel()
|
155 |
+
# if using DS Zero 3 and the weights are initialized empty
|
156 |
+
if num_params == 0 and hasattr(param, "ds_numel"):
|
157 |
+
num_params = param.ds_numel
|
158 |
+
|
159 |
+
# Due to the design of 4bit linear layers from bitsandbytes
|
160 |
+
# one needs to multiply the number of parameters by 2 to get
|
161 |
+
# the correct number of parameters
|
162 |
+
if param.__class__.__name__ == "Params4bit":
|
163 |
+
num_params = num_params * 2
|
164 |
+
|
165 |
+
all_param += num_params
|
166 |
+
if param.requires_grad:
|
167 |
+
trainable_params += num_params
|
168 |
+
|
169 |
+
return trainable_params, all_param
|
170 |
+
|
171 |
+
def print_trainable_parameters(self):
|
172 |
+
"""
|
173 |
+
Prints the number of trainable parameters in the model.
|
174 |
+
"""
|
175 |
+
# note: same as PeftModel.print_trainable_parameters
|
176 |
+
trainable_params, all_param = self.get_nb_trainable_parameters()
|
177 |
+
|
178 |
+
print(
|
179 |
+
f"trainable params: {trainable_params:,d} || "
|
180 |
+
f"all params: {all_param:,d} || "
|
181 |
+
f"trainable%: {100 * trainable_params / all_param:.4f}"
|
182 |
+
)
|
183 |
+
|
184 |
+
def __getattr__(self, name: str):
|
185 |
+
"""Forward missing attributes to the wrapped module."""
|
186 |
+
try:
|
187 |
+
return super().__getattr__(name) # defer to nn.Module's logic
|
188 |
+
except AttributeError:
|
189 |
+
return getattr(self.base_model, name)
|
190 |
+
|
191 |
+
def forward(self, *args: Any, **kwargs: Any):
|
192 |
+
"""
|
193 |
+
Forward pass of the model.
|
194 |
+
"""
|
195 |
+
return self.base_model(*args, **kwargs)
|
196 |
+
|
197 |
+
def generate(self, *args: Any, **kwargs: Any):
|
198 |
+
"""
|
199 |
+
Generate output.
|
200 |
+
"""
|
201 |
+
return self.base_model.generate(*args, **kwargs)
|
202 |
+
|
203 |
+
@contextmanager
|
204 |
+
def disable_adapter(self):
|
205 |
+
"""
|
206 |
+
Disables the adapter module.
|
207 |
+
"""
|
208 |
+
try:
|
209 |
+
self.base_model.disable_adapter_layers()
|
210 |
+
yield
|
211 |
+
finally:
|
212 |
+
self.base_model.enable_adapter_layers()
|
213 |
+
|
214 |
+
def add_adapter(self, adapter_name: str, peft_config: PeftConfig):
|
215 |
+
_check_config_compatible(peft_config)
|
216 |
+
|
217 |
+
try:
|
218 |
+
self.peft_config[adapter_name] = peft_config
|
219 |
+
self.base_model.inject_adapter(self, adapter_name)
|
220 |
+
except Exception: # something went wrong, roll back
|
221 |
+
if adapter_name in self.peft_config:
|
222 |
+
del self.peft_config[adapter_name]
|
223 |
+
raise
|
224 |
+
|
225 |
+
self.set_modules_to_save(peft_config, adapter_name)
|
226 |
+
|
227 |
+
def set_modules_to_save(self, peft_config: PeftConfig, adapter_name: str) -> None:
|
228 |
+
if (modules_to_save := getattr(peft_config, "modules_to_save", None)) is None:
|
229 |
+
return
|
230 |
+
|
231 |
+
if self.modules_to_save is None:
|
232 |
+
self.modules_to_save = set(modules_to_save)
|
233 |
+
else:
|
234 |
+
self.modules_to_save.update(modules_to_save)
|
235 |
+
_set_trainable(self, adapter_name)
|
236 |
+
|
237 |
+
def set_adapter(self, adapter_name: Union[str, list[str]]) -> None:
|
238 |
+
"""
|
239 |
+
Sets the active adapter(s) for the model.
|
240 |
+
|
241 |
+
Note that the order in which the adapters are applied during the forward pass may not be the same as the order
|
242 |
+
in which they are passed to this function. Instead, the order during the forward pass is determined by the
|
243 |
+
order in which the adapters were loaded into the model. The active adapters only determine which adapters are
|
244 |
+
active during the forward pass, but not the order in which they are applied.
|
245 |
+
|
246 |
+
Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
|
247 |
+
not desired, use the following code.
|
248 |
+
|
249 |
+
```py
|
250 |
+
>>> for name, param in model_peft.named_parameters():
|
251 |
+
... if ...: # some check on name (ex. if 'lora' in name)
|
252 |
+
... param.requires_grad = False
|
253 |
+
```
|
254 |
+
|
255 |
+
Args:
|
256 |
+
adapter_name (`str` or `List[str]`):
|
257 |
+
The name of the adapter(s) to be activated.
|
258 |
+
"""
|
259 |
+
if isinstance(adapter_name, str):
|
260 |
+
adapter_name = [adapter_name]
|
261 |
+
|
262 |
+
mismatched = set(adapter_name) - set(self.peft_config.keys())
|
263 |
+
if mismatched:
|
264 |
+
raise ValueError(
|
265 |
+
f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}"
|
266 |
+
)
|
267 |
+
|
268 |
+
self.base_model.set_adapter(adapter_name)
|
269 |
+
_set_adapter(self, adapter_name)
|
270 |
+
|
271 |
+
def delete_adapter(self, adapter_name: Union[str, list[str]]) -> None:
|
272 |
+
if isinstance(adapter_name, str):
|
273 |
+
adapter_name = [adapter_name]
|
274 |
+
|
275 |
+
mismatched = set(adapter_name) - set(self.peft_config.keys())
|
276 |
+
if mismatched:
|
277 |
+
raise ValueError(
|
278 |
+
f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}"
|
279 |
+
)
|
280 |
+
|
281 |
+
self.base_model.delete_adapter(adapter_name)
|
282 |
+
|
283 |
+
def merge_and_unload(self, *args: Any, **kwargs: Any):
|
284 |
+
r"""
|
285 |
+
This method merges the adapter layers into the base model. This is needed if someone wants to use the base
|
286 |
+
model as a standalone model.
|
287 |
+
|
288 |
+
Args:
|
289 |
+
progressbar (`bool`):
|
290 |
+
whether to show a progressbar indicating the unload and merge process
|
291 |
+
safe_merge (`bool`):
|
292 |
+
whether to activate the safe merging check to check if there is any potential Nan in the adapter
|
293 |
+
weights
|
294 |
+
adapter_names (`List[str]`, *optional*):
|
295 |
+
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
|
296 |
+
to `None`.
|
297 |
+
"""
|
298 |
+
return self.base_model.merge_and_unload(*args, **kwargs)
|
299 |
+
|
300 |
+
def unload(self, *args: Any, **kwargs: Any):
|
301 |
+
"""
|
302 |
+
Gets back the base model by removing all the adapter modules without merging. This gives back the original base
|
303 |
+
model.
|
304 |
+
"""
|
305 |
+
return self.base_model.unload(*args, **kwargs)
|
306 |
+
|
307 |
+
@classmethod
|
308 |
+
def _split_kwargs(cls, kwargs: dict[str, Any]):
|
309 |
+
return PeftModel._split_kwargs(kwargs)
|
310 |
+
|
311 |
+
def load_adapter(self, model_id: str, adapter_name: str, *args: Any, **kwargs: Any):
|
312 |
+
output = PeftModel.load_adapter(self, model_id, adapter_name, *args, **kwargs)
|
313 |
+
# TODO: not quite clear why this is necessary but tests fail without it
|
314 |
+
self.set_adapter(self.active_adapters)
|
315 |
+
return output
|
316 |
+
|
317 |
+
def create_or_update_model_card(self, output_dir: str):
|
318 |
+
raise NotImplementedError(f"Model card creation is not supported for {self.__class__.__name__} (yet).")
|
319 |
+
|
320 |
+
def save_pretrained(
|
321 |
+
self,
|
322 |
+
save_directory: str,
|
323 |
+
safe_serialization: bool = False,
|
324 |
+
selected_adapters: Optional[list[str]] = None,
|
325 |
+
**kwargs: Any,
|
326 |
+
):
|
327 |
+
raise NotImplementedError(f"Saving is not supported for {self.__class__.__name__} (yet).")
|
328 |
+
|
329 |
+
@classmethod
|
330 |
+
def from_pretrained(
|
331 |
+
cls,
|
332 |
+
model: nn.Module,
|
333 |
+
model_id: str | os.PathLike,
|
334 |
+
adapter_name: str = "default",
|
335 |
+
is_trainable: bool = False,
|
336 |
+
config: Optional[PeftConfig] = None,
|
337 |
+
**kwargs: Any,
|
338 |
+
):
|
339 |
+
r"""
|
340 |
+
Instantiate a PEFT mixed model from a pretrained model and loaded PEFT weights.
|
341 |
+
|
342 |
+
Note that the passed `model` may be modified inplace.
|
343 |
+
|
344 |
+
Args:
|
345 |
+
model (`nn.Module`):
|
346 |
+
The model to be adapted.
|
347 |
+
model_id (`str` or `os.PathLike`):
|
348 |
+
The name of the PEFT configuration to use. Can be either:
|
349 |
+
- A string, the `model id` of a PEFT configuration hosted inside a model repo on the Hugging Face
|
350 |
+
Hub.
|
351 |
+
- A path to a directory containing a PEFT configuration file saved using the `save_pretrained`
|
352 |
+
method (`./my_peft_config_directory/`).
|
353 |
+
adapter_name (`str`, *optional*, defaults to `"default"`):
|
354 |
+
The name of the adapter to be loaded. This is useful for loading multiple adapters.
|
355 |
+
is_trainable (`bool`, *optional*, defaults to `False`):
|
356 |
+
Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and use for
|
357 |
+
inference
|
358 |
+
config ([`~peft.PeftConfig`], *optional*):
|
359 |
+
The configuration object to use instead of an automatically loaded configuration. This configuration
|
360 |
+
object is mutually exclusive with `model_id` and `kwargs`. This is useful when configuration is already
|
361 |
+
loaded before calling `from_pretrained`.
|
362 |
+
kwargs: (`optional`):
|
363 |
+
Additional keyword arguments passed along to the specific PEFT configuration class.
|
364 |
+
"""
|
365 |
+
# note: adapted from PeftModel.from_pretrained
|
366 |
+
from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING
|
367 |
+
|
368 |
+
# load the config
|
369 |
+
if config is None:
|
370 |
+
config = PEFT_TYPE_TO_CONFIG_MAPPING[
|
371 |
+
PeftConfig._get_peft_type(
|
372 |
+
model_id,
|
373 |
+
subfolder=kwargs.get("subfolder", None),
|
374 |
+
revision=kwargs.get("revision", None),
|
375 |
+
cache_dir=kwargs.get("cache_dir", None),
|
376 |
+
use_auth_token=kwargs.get("use_auth_token", None),
|
377 |
+
)
|
378 |
+
].from_pretrained(model_id, **kwargs)
|
379 |
+
elif isinstance(config, PeftConfig):
|
380 |
+
config.inference_mode = not is_trainable
|
381 |
+
else:
|
382 |
+
raise ValueError(f"The input config must be a PeftConfig, got {config.__class__}")
|
383 |
+
|
384 |
+
# note: this is different from PeftModel.from_pretrained
|
385 |
+
if config.peft_type not in PEFT_TYPE_TO_MODEL_MAPPING:
|
386 |
+
raise ValueError(f"Adapter of type {config.peft_type} is not supported for mixed models.")
|
387 |
+
|
388 |
+
if (getattr(model, "hf_device_map", None) is not None) and len(
|
389 |
+
set(model.hf_device_map.values()).intersection({"cpu", "disk"})
|
390 |
+
) > 0:
|
391 |
+
remove_hook_from_submodules(model)
|
392 |
+
|
393 |
+
if config.is_prompt_learning and is_trainable:
|
394 |
+
# note: should not be possible to reach, but just in case
|
395 |
+
raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
|
396 |
+
else:
|
397 |
+
config.inference_mode = not is_trainable
|
398 |
+
|
399 |
+
# note: this is different from PeftModel.from_pretrained, we always return a PeftMixedModel
|
400 |
+
model = cls(model, config, adapter_name)
|
401 |
+
model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
|
402 |
+
return model
|
MoRA/peft_mora/peft_model.py
ADDED
@@ -0,0 +1,1929 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from __future__ import annotations
|
16 |
+
|
17 |
+
import collections
|
18 |
+
import inspect
|
19 |
+
import os
|
20 |
+
import warnings
|
21 |
+
from contextlib import contextmanager
|
22 |
+
from copy import deepcopy
|
23 |
+
from typing import Any, Optional, Union
|
24 |
+
|
25 |
+
import packaging.version
|
26 |
+
import torch
|
27 |
+
import transformers
|
28 |
+
from accelerate import dispatch_model, infer_auto_device_map
|
29 |
+
from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
|
30 |
+
from accelerate.utils import get_balanced_memory
|
31 |
+
from huggingface_hub import ModelCard, ModelCardData, hf_hub_download
|
32 |
+
from safetensors.torch import save_file as safe_save_file
|
33 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
34 |
+
from transformers import PreTrainedModel
|
35 |
+
from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
|
36 |
+
from transformers.utils import PushToHubMixin
|
37 |
+
|
38 |
+
from . import __version__
|
39 |
+
from .config import PeftConfig
|
40 |
+
from .tuners import (
|
41 |
+
AdaLoraModel,
|
42 |
+
AdaptionPromptModel,
|
43 |
+
IA3Model,
|
44 |
+
LoHaModel,
|
45 |
+
LoKrModel,
|
46 |
+
LoraModel,
|
47 |
+
MultitaskPromptEmbedding,
|
48 |
+
OFTModel,
|
49 |
+
PolyModel,
|
50 |
+
PrefixEncoder,
|
51 |
+
PromptEmbedding,
|
52 |
+
PromptEncoder,
|
53 |
+
)
|
54 |
+
from .utils import (
|
55 |
+
SAFETENSORS_WEIGHTS_NAME,
|
56 |
+
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
|
57 |
+
WEIGHTS_NAME,
|
58 |
+
PeftType,
|
59 |
+
TaskType,
|
60 |
+
_get_batch_size,
|
61 |
+
_prepare_prompt_learning_config,
|
62 |
+
_set_adapter,
|
63 |
+
_set_trainable,
|
64 |
+
get_peft_model_state_dict,
|
65 |
+
id_tensor_storage,
|
66 |
+
infer_device,
|
67 |
+
load_peft_weights,
|
68 |
+
set_peft_model_state_dict,
|
69 |
+
shift_tokens_right,
|
70 |
+
)
|
71 |
+
|
72 |
+
|
73 |
+
PEFT_TYPE_TO_MODEL_MAPPING = {
|
74 |
+
PeftType.LORA: LoraModel,
|
75 |
+
PeftType.LOHA: LoHaModel,
|
76 |
+
PeftType.LOKR: LoKrModel,
|
77 |
+
PeftType.PROMPT_TUNING: PromptEmbedding,
|
78 |
+
PeftType.P_TUNING: PromptEncoder,
|
79 |
+
PeftType.PREFIX_TUNING: PrefixEncoder,
|
80 |
+
PeftType.ADALORA: AdaLoraModel,
|
81 |
+
PeftType.ADAPTION_PROMPT: AdaptionPromptModel,
|
82 |
+
PeftType.IA3: IA3Model,
|
83 |
+
PeftType.OFT: OFTModel,
|
84 |
+
PeftType.POLY: PolyModel,
|
85 |
+
}
|
86 |
+
|
87 |
+
|
88 |
+
class PeftModel(PushToHubMixin, torch.nn.Module):
|
89 |
+
"""
|
90 |
+
Base model encompassing various Peft methods.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
model ([`~transformers.PreTrainedModel`]): The base transformer model used for Peft.
|
94 |
+
peft_config ([`PeftConfig`]): The configuration of the Peft model.
|
95 |
+
adapter_name (`str`, *optional*): The name of the adapter, defaults to `"default"`.
|
96 |
+
|
97 |
+
**Attributes**:
|
98 |
+
- **base_model** ([`torch.nn.Module`]) -- The base transformer model used for Peft.
|
99 |
+
- **peft_config** ([`PeftConfig`]) -- The configuration of the Peft model.
|
100 |
+
- **modules_to_save** (`list` of `str`) -- The list of sub-module names to save when
|
101 |
+
saving the model.
|
102 |
+
- **prompt_encoder** ([`PromptEncoder`]) -- The prompt encoder used for Peft if
|
103 |
+
using [`PromptLearningConfig`].
|
104 |
+
- **prompt_tokens** (`torch.Tensor`) -- The virtual prompt tokens used for Peft if
|
105 |
+
using [`PromptLearningConfig`].
|
106 |
+
- **transformer_backbone_name** (`str`) -- The name of the transformer
|
107 |
+
backbone in the base model if using [`PromptLearningConfig`].
|
108 |
+
- **word_embeddings** (`torch.nn.Embedding`) -- The word embeddings of the transformer backbone
|
109 |
+
in the base model if using [`PromptLearningConfig`].
|
110 |
+
"""
|
111 |
+
|
112 |
+
def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> None:
|
113 |
+
super().__init__()
|
114 |
+
self.modules_to_save = None
|
115 |
+
self.active_adapter = adapter_name
|
116 |
+
self.peft_type = peft_config.peft_type
|
117 |
+
|
118 |
+
self._is_prompt_learning = peft_config.is_prompt_learning
|
119 |
+
if self._is_prompt_learning:
|
120 |
+
self._peft_config = {adapter_name: peft_config}
|
121 |
+
self.base_model = model
|
122 |
+
self.add_adapter(adapter_name, peft_config)
|
123 |
+
else:
|
124 |
+
self._peft_config = None
|
125 |
+
cls = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type]
|
126 |
+
self.base_model = cls(model, {adapter_name: peft_config}, adapter_name)
|
127 |
+
self.set_additional_trainable_modules(peft_config, adapter_name)
|
128 |
+
|
129 |
+
if getattr(model, "is_gradient_checkpointing", True):
|
130 |
+
model = self._prepare_model_for_gradient_checkpointing(model)
|
131 |
+
|
132 |
+
# the `pretraining_tp` is set for some models to simulate Tensor Parallelism during inference to avoid
|
133 |
+
# numerical differences, https://github.com/pytorch/pytorch/issues/76232 - to avoid any unexpected
|
134 |
+
# behavior we disable that in this line.
|
135 |
+
if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"):
|
136 |
+
self.base_model.config.pretraining_tp = 1
|
137 |
+
|
138 |
+
@property
|
139 |
+
def peft_config(self) -> dict[str, PeftConfig]:
|
140 |
+
if self._is_prompt_learning:
|
141 |
+
return self._peft_config
|
142 |
+
return self.base_model.peft_config
|
143 |
+
|
144 |
+
@property
|
145 |
+
def active_adapters(self) -> list[str]:
|
146 |
+
try:
|
147 |
+
adapters = self.base_model.active_adapters
|
148 |
+
except AttributeError:
|
149 |
+
adapters = self.active_adapter
|
150 |
+
if isinstance(adapters, str):
|
151 |
+
adapters = [adapters]
|
152 |
+
return adapters
|
153 |
+
|
154 |
+
@peft_config.setter
|
155 |
+
def peft_config(self, value: dict[str, PeftConfig]):
|
156 |
+
if self._is_prompt_learning:
|
157 |
+
self._peft_config = value
|
158 |
+
else:
|
159 |
+
self.base_model.peft_config = value
|
160 |
+
|
161 |
+
def save_pretrained(
|
162 |
+
self,
|
163 |
+
save_directory: str,
|
164 |
+
safe_serialization: bool = True,
|
165 |
+
selected_adapters: Optional[list[str]] = None,
|
166 |
+
save_embedding_layers: Union[str, bool] = "auto",
|
167 |
+
is_main_process: bool = True,
|
168 |
+
**kwargs: Any,
|
169 |
+
) -> None:
|
170 |
+
r"""
|
171 |
+
This function saves the adapter model and the adapter configuration files to a directory, so that it can be
|
172 |
+
reloaded using the [`PeftModel.from_pretrained`] class method, and also used by the [`PeftModel.push_to_hub`]
|
173 |
+
method.
|
174 |
+
|
175 |
+
Args:
|
176 |
+
save_directory (`str`):
|
177 |
+
Directory where the adapter model and configuration files will be saved (will be created if it does not
|
178 |
+
exist).
|
179 |
+
safe_serialization (`bool`, *optional*):
|
180 |
+
Whether to save the adapter files in safetensors format, defaults to `True`.
|
181 |
+
selected_adapters (`List[str]`, *optional*):
|
182 |
+
A list of adapters to be saved. If `None`, will default to all adapters.
|
183 |
+
save_embedding_layers (`Union[bool, str]`, *optional*, defaults to `"auto"`):
|
184 |
+
If `True`, save the embedding layers in addition to adapter weights. If `auto`, checks the common
|
185 |
+
embedding layers `peft.utils.other.EMBEDDING_LAYER_NAMES` in config's `target_modules` when available.
|
186 |
+
and automatically sets the boolean flag. This only works for 🤗 transformers models.
|
187 |
+
is_main_process (`bool`, *optional*):
|
188 |
+
Whether the process calling this is the main process or not. Will default to `True`. Will not save the
|
189 |
+
checkpoint if not on the main process, which is important for multi device setups (e.g. DDP).
|
190 |
+
kwargs (additional keyword arguments, *optional*):
|
191 |
+
Additional keyword arguments passed along to the `push_to_hub` method.
|
192 |
+
"""
|
193 |
+
if os.path.isfile(save_directory):
|
194 |
+
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
|
195 |
+
|
196 |
+
if selected_adapters is None:
|
197 |
+
selected_adapters = list(self.peft_config.keys())
|
198 |
+
else:
|
199 |
+
if any(
|
200 |
+
selected_adapter_name not in list(self.peft_config.keys())
|
201 |
+
for selected_adapter_name in selected_adapters
|
202 |
+
):
|
203 |
+
raise ValueError(
|
204 |
+
f"You passed an invalid `selected_adapters` arguments, current supported adapter names are"
|
205 |
+
f" {list(self.peft_config.keys())} - got {selected_adapters}."
|
206 |
+
)
|
207 |
+
|
208 |
+
if is_main_process:
|
209 |
+
os.makedirs(save_directory, exist_ok=True)
|
210 |
+
self.create_or_update_model_card(save_directory)
|
211 |
+
|
212 |
+
for adapter_name in selected_adapters:
|
213 |
+
peft_config = self.peft_config[adapter_name]
|
214 |
+
# save only the trainable weights
|
215 |
+
output_state_dict = get_peft_model_state_dict(
|
216 |
+
self,
|
217 |
+
state_dict=kwargs.get("state_dict", None),
|
218 |
+
adapter_name=adapter_name,
|
219 |
+
save_embedding_layers=save_embedding_layers,
|
220 |
+
)
|
221 |
+
output_dir = os.path.join(save_directory, adapter_name) if adapter_name != "default" else save_directory
|
222 |
+
os.makedirs(output_dir, exist_ok=True)
|
223 |
+
|
224 |
+
if is_main_process and safe_serialization:
|
225 |
+
# Section copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2111-L2134
|
226 |
+
# Safetensors does not allow tensor aliasing.
|
227 |
+
# We're going to remove aliases before saving
|
228 |
+
ptrs = collections.defaultdict(list)
|
229 |
+
for name, tensor in output_state_dict.items():
|
230 |
+
# Sometimes in the state_dict we have non-tensor objects.
|
231 |
+
# e.g. in bitsandbytes we have some `str` objects in the state_dict
|
232 |
+
if isinstance(tensor, torch.Tensor):
|
233 |
+
ptrs[id_tensor_storage(tensor)].append(name)
|
234 |
+
else:
|
235 |
+
# In the non-tensor case, fall back to the pointer of the object itself
|
236 |
+
ptrs[id(tensor)].append(name)
|
237 |
+
|
238 |
+
# These are all the pointers of shared tensors.
|
239 |
+
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
|
240 |
+
|
241 |
+
for _, names in shared_ptrs.items():
|
242 |
+
# Here we just clone the shared tensors to avoid tensor aliasing which is
|
243 |
+
# not supported in safetensors.
|
244 |
+
for shared_tensor_name in names[1:]:
|
245 |
+
output_state_dict[shared_tensor_name] = output_state_dict[shared_tensor_name].clone()
|
246 |
+
|
247 |
+
safe_save_file(
|
248 |
+
output_state_dict,
|
249 |
+
os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME),
|
250 |
+
metadata={"format": "pt"},
|
251 |
+
)
|
252 |
+
elif is_main_process:
|
253 |
+
torch.save(output_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
254 |
+
|
255 |
+
# save the config and change the inference mode to `True`
|
256 |
+
if peft_config.base_model_name_or_path is None:
|
257 |
+
peft_config.base_model_name_or_path = (
|
258 |
+
self.base_model.__dict__.get("name_or_path", None)
|
259 |
+
if peft_config.is_prompt_learning
|
260 |
+
else self.base_model.model.__dict__.get("name_or_path", None)
|
261 |
+
)
|
262 |
+
inference_mode = peft_config.inference_mode
|
263 |
+
peft_config.inference_mode = True
|
264 |
+
|
265 |
+
if peft_config.task_type is None:
|
266 |
+
# deal with auto mapping
|
267 |
+
base_model_class = self._get_base_model_class(
|
268 |
+
is_prompt_tuning=peft_config.is_prompt_learning,
|
269 |
+
)
|
270 |
+
parent_library = base_model_class.__module__
|
271 |
+
|
272 |
+
auto_mapping_dict = {
|
273 |
+
"base_model_class": base_model_class.__name__,
|
274 |
+
"parent_library": parent_library,
|
275 |
+
}
|
276 |
+
else:
|
277 |
+
auto_mapping_dict = None
|
278 |
+
|
279 |
+
if is_main_process:
|
280 |
+
peft_config.save_pretrained(output_dir, auto_mapping_dict=auto_mapping_dict)
|
281 |
+
peft_config.inference_mode = inference_mode
|
282 |
+
|
283 |
+
@classmethod
|
284 |
+
def from_pretrained(
|
285 |
+
cls,
|
286 |
+
model: torch.nn.Module,
|
287 |
+
model_id: Union[str, os.PathLike],
|
288 |
+
adapter_name: str = "default",
|
289 |
+
is_trainable: bool = False,
|
290 |
+
config: Optional[PeftConfig] = None,
|
291 |
+
**kwargs: Any,
|
292 |
+
) -> PeftModel:
|
293 |
+
r"""
|
294 |
+
Instantiate a PEFT model from a pretrained model and loaded PEFT weights.
|
295 |
+
|
296 |
+
Note that the passed `model` may be modified inplace.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
model ([`torch.nn.Module`]):
|
300 |
+
The model to be adapted. For 🤗 Transformers models, the model should be initialized with the
|
301 |
+
[`~transformers.PreTrainedModel.from_pretrained`].
|
302 |
+
model_id (`str` or `os.PathLike`):
|
303 |
+
The name of the PEFT configuration to use. Can be either:
|
304 |
+
- A string, the `model id` of a PEFT configuration hosted inside a model repo on the Hugging Face
|
305 |
+
Hub.
|
306 |
+
- A path to a directory containing a PEFT configuration file saved using the `save_pretrained`
|
307 |
+
method (`./my_peft_config_directory/`).
|
308 |
+
adapter_name (`str`, *optional*, defaults to `"default"`):
|
309 |
+
The name of the adapter to be loaded. This is useful for loading multiple adapters.
|
310 |
+
is_trainable (`bool`, *optional*, defaults to `False`):
|
311 |
+
Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be
|
312 |
+
used for inference.
|
313 |
+
config ([`~peft.PeftConfig`], *optional*):
|
314 |
+
The configuration object to use instead of an automatically loaded configuration. This configuration
|
315 |
+
object is mutually exclusive with `model_id` and `kwargs`. This is useful when configuration is already
|
316 |
+
loaded before calling `from_pretrained`.
|
317 |
+
kwargs: (`optional`):
|
318 |
+
Additional keyword arguments passed along to the specific PEFT configuration class.
|
319 |
+
"""
|
320 |
+
from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING
|
321 |
+
|
322 |
+
# load the config
|
323 |
+
if config is None:
|
324 |
+
config = PEFT_TYPE_TO_CONFIG_MAPPING[
|
325 |
+
PeftConfig._get_peft_type(
|
326 |
+
model_id,
|
327 |
+
subfolder=kwargs.get("subfolder", None),
|
328 |
+
revision=kwargs.get("revision", None),
|
329 |
+
cache_dir=kwargs.get("cache_dir", None),
|
330 |
+
use_auth_token=kwargs.get("use_auth_token", None),
|
331 |
+
token=kwargs.get("token", None),
|
332 |
+
)
|
333 |
+
].from_pretrained(model_id, **kwargs)
|
334 |
+
elif isinstance(config, PeftConfig):
|
335 |
+
config.inference_mode = not is_trainable
|
336 |
+
else:
|
337 |
+
raise ValueError(f"The input config must be a PeftConfig, got {config.__class__}")
|
338 |
+
|
339 |
+
if (getattr(model, "hf_device_map", None) is not None) and len(
|
340 |
+
set(model.hf_device_map.values()).intersection({"cpu", "disk"})
|
341 |
+
) > 0:
|
342 |
+
remove_hook_from_submodules(model)
|
343 |
+
|
344 |
+
if config.is_prompt_learning and is_trainable:
|
345 |
+
raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
|
346 |
+
else:
|
347 |
+
config.inference_mode = not is_trainable
|
348 |
+
|
349 |
+
if config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys():
|
350 |
+
model = cls(model, config, adapter_name)
|
351 |
+
else:
|
352 |
+
model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config, adapter_name)
|
353 |
+
model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
|
354 |
+
return model
|
355 |
+
|
356 |
+
def _setup_prompt_encoder(self, adapter_name: str):
|
357 |
+
config = self.peft_config[adapter_name]
|
358 |
+
if not hasattr(self, "prompt_encoder"):
|
359 |
+
self.prompt_encoder = torch.nn.ModuleDict({})
|
360 |
+
self.prompt_tokens = {}
|
361 |
+
transformer_backbone = None
|
362 |
+
for name, module in self.base_model.named_children():
|
363 |
+
for param in module.parameters():
|
364 |
+
param.requires_grad = False
|
365 |
+
if isinstance(module, PreTrainedModel):
|
366 |
+
# Make sure to freeze Tranformers model
|
367 |
+
if transformer_backbone is None:
|
368 |
+
transformer_backbone = module
|
369 |
+
self.transformer_backbone_name = name
|
370 |
+
if transformer_backbone is None:
|
371 |
+
transformer_backbone = self.base_model
|
372 |
+
|
373 |
+
if config.num_transformer_submodules is None:
|
374 |
+
config.num_transformer_submodules = 2 if config.task_type == TaskType.SEQ_2_SEQ_LM else 1
|
375 |
+
|
376 |
+
for named_param, value in list(transformer_backbone.named_parameters()):
|
377 |
+
# for ZeRO-3, the tensor is sharded across accelerators and deepspeed modifies it to a tensor with shape [0]
|
378 |
+
# the actual unsharded shape is stored in "ds_shape" attribute
|
379 |
+
# special handling is needed in case the model is initialized in deepspeed.zero.Init() context or HfDeepSpeedConfig
|
380 |
+
# has been called before
|
381 |
+
# For reference refer to issue: https://github.com/huggingface/peft/issues/996
|
382 |
+
deepspeed_distributed_tensor_shape = getattr(value, "ds_shape", None)
|
383 |
+
|
384 |
+
if value.shape[0] == self.base_model.config.vocab_size or (
|
385 |
+
deepspeed_distributed_tensor_shape is not None
|
386 |
+
and deepspeed_distributed_tensor_shape[0] == self.base_model.config.vocab_size
|
387 |
+
):
|
388 |
+
self.word_embeddings = transformer_backbone.get_submodule(named_param.replace(".weight", ""))
|
389 |
+
break
|
390 |
+
|
391 |
+
if config.peft_type == PeftType.PROMPT_TUNING:
|
392 |
+
prompt_encoder = PromptEmbedding(config, self.word_embeddings)
|
393 |
+
elif config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:
|
394 |
+
prompt_encoder = MultitaskPromptEmbedding(config, self.word_embeddings)
|
395 |
+
elif config.peft_type == PeftType.P_TUNING:
|
396 |
+
prompt_encoder = PromptEncoder(config)
|
397 |
+
elif config.peft_type == PeftType.PREFIX_TUNING:
|
398 |
+
prompt_encoder = PrefixEncoder(config)
|
399 |
+
else:
|
400 |
+
raise ValueError("Not supported")
|
401 |
+
|
402 |
+
prompt_encoder = prompt_encoder.to(self.device)
|
403 |
+
self.prompt_encoder.update(torch.nn.ModuleDict({adapter_name: prompt_encoder}))
|
404 |
+
self.prompt_tokens[adapter_name] = torch.arange(
|
405 |
+
config.num_virtual_tokens * config.num_transformer_submodules
|
406 |
+
).long()
|
407 |
+
|
408 |
+
def _prepare_model_for_gradient_checkpointing(self, model: PreTrainedModel):
|
409 |
+
r"""
|
410 |
+
Prepares the model for gradient checkpointing if necessary
|
411 |
+
"""
|
412 |
+
if not (
|
413 |
+
getattr(model, "is_loaded_in_8bit", False)
|
414 |
+
or getattr(model, "is_loaded_in_4bit", False)
|
415 |
+
or getattr(model, "is_quantized", False)
|
416 |
+
):
|
417 |
+
if hasattr(model, "enable_input_require_grads"):
|
418 |
+
model.enable_input_require_grads()
|
419 |
+
elif hasattr(model, "get_input_embeddings"):
|
420 |
+
|
421 |
+
def make_inputs_require_grad(module, input, output):
|
422 |
+
output.requires_grad_(True)
|
423 |
+
|
424 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
425 |
+
return model
|
426 |
+
|
427 |
+
def get_prompt_embedding_to_save(self, adapter_name: str) -> torch.Tensor:
|
428 |
+
"""
|
429 |
+
Returns the prompt embedding to save when saving the model. Only applicable when using a prompt learning
|
430 |
+
method.
|
431 |
+
"""
|
432 |
+
prompt_encoder = self.prompt_encoder[adapter_name]
|
433 |
+
prompt_tokens = (
|
434 |
+
self.prompt_tokens[adapter_name].unsqueeze(0).expand(1, -1).to(prompt_encoder.embedding.weight.device)
|
435 |
+
)
|
436 |
+
if self.peft_config[adapter_name].peft_type == PeftType.PREFIX_TUNING:
|
437 |
+
prompt_tokens = prompt_tokens[:, : self.peft_config[adapter_name].num_virtual_tokens]
|
438 |
+
|
439 |
+
if self.peft_config[adapter_name].peft_type == PeftType.MULTITASK_PROMPT_TUNING:
|
440 |
+
prompt_embeddings = super(MultitaskPromptEmbedding, prompt_encoder).forward(prompt_tokens)
|
441 |
+
else:
|
442 |
+
prompt_embeddings = prompt_encoder(prompt_tokens)
|
443 |
+
|
444 |
+
return prompt_embeddings[0].detach().cpu()
|
445 |
+
|
446 |
+
def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
|
447 |
+
"""
|
448 |
+
Returns the virtual prompts to use for Peft. Only applicable when using a prompt learning method.
|
449 |
+
"""
|
450 |
+
peft_config = self.active_peft_config
|
451 |
+
prompt_encoder = self.prompt_encoder[self.active_adapter]
|
452 |
+
prompt_tokens = (
|
453 |
+
self.prompt_tokens[self.active_adapter]
|
454 |
+
.unsqueeze(0)
|
455 |
+
.expand(batch_size, -1)
|
456 |
+
.to(prompt_encoder.embedding.weight.device)
|
457 |
+
)
|
458 |
+
if peft_config.peft_type == PeftType.PREFIX_TUNING:
|
459 |
+
prompt_tokens = prompt_tokens[:, : peft_config.num_virtual_tokens]
|
460 |
+
if peft_config.inference_mode:
|
461 |
+
past_key_values = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
|
462 |
+
else:
|
463 |
+
past_key_values = prompt_encoder(prompt_tokens)
|
464 |
+
if self.base_model_torch_dtype is not None:
|
465 |
+
past_key_values = past_key_values.to(self.base_model_torch_dtype)
|
466 |
+
past_key_values = past_key_values.view(
|
467 |
+
batch_size,
|
468 |
+
peft_config.num_virtual_tokens,
|
469 |
+
peft_config.num_layers * 2,
|
470 |
+
peft_config.num_attention_heads,
|
471 |
+
peft_config.token_dim // peft_config.num_attention_heads,
|
472 |
+
)
|
473 |
+
if peft_config.num_transformer_submodules == 2:
|
474 |
+
past_key_values = torch.cat([past_key_values, past_key_values], dim=2)
|
475 |
+
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(
|
476 |
+
peft_config.num_transformer_submodules * 2
|
477 |
+
)
|
478 |
+
if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
|
479 |
+
post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
|
480 |
+
past_key_values = post_process_fn(past_key_values)
|
481 |
+
return past_key_values
|
482 |
+
else:
|
483 |
+
if peft_config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:
|
484 |
+
prompts = prompt_encoder(prompt_tokens, task_ids)
|
485 |
+
else:
|
486 |
+
if peft_config.inference_mode:
|
487 |
+
prompts = prompt_encoder.embedding.weight.repeat(batch_size, 1, 1)
|
488 |
+
else:
|
489 |
+
prompts = prompt_encoder(prompt_tokens)
|
490 |
+
return prompts
|
491 |
+
|
492 |
+
def get_nb_trainable_parameters(self) -> tuple[int, int]:
|
493 |
+
r"""
|
494 |
+
Returns the number of trainable parameters and the number of all parameters in the model.
|
495 |
+
"""
|
496 |
+
trainable_params = 0
|
497 |
+
all_param = 0
|
498 |
+
for _, param in self.named_parameters():
|
499 |
+
num_params = param.numel()
|
500 |
+
# if using DS Zero 3 and the weights are initialized empty
|
501 |
+
if num_params == 0 and hasattr(param, "ds_numel"):
|
502 |
+
num_params = param.ds_numel
|
503 |
+
|
504 |
+
# Due to the design of 4bit linear layers from bitsandbytes
|
505 |
+
# one needs to multiply the number of parameters by 2 to get
|
506 |
+
# the correct number of parameters
|
507 |
+
if param.__class__.__name__ == "Params4bit":
|
508 |
+
num_params = num_params * 2
|
509 |
+
|
510 |
+
all_param += num_params
|
511 |
+
if param.requires_grad:
|
512 |
+
trainable_params += num_params
|
513 |
+
|
514 |
+
return trainable_params, all_param
|
515 |
+
|
516 |
+
def print_trainable_parameters(self) -> None:
|
517 |
+
"""
|
518 |
+
Prints the number of trainable parameters in the model.
|
519 |
+
"""
|
520 |
+
trainable_params, all_param = self.get_nb_trainable_parameters()
|
521 |
+
|
522 |
+
print(
|
523 |
+
f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param}"
|
524 |
+
)
|
525 |
+
|
526 |
+
def __getattr__(self, name: str):
|
527 |
+
"""Forward missing attributes to the wrapped module."""
|
528 |
+
try:
|
529 |
+
return super().__getattr__(name) # defer to nn.Module's logic
|
530 |
+
except AttributeError:
|
531 |
+
return getattr(self.base_model, name)
|
532 |
+
|
533 |
+
def forward(self, *args: Any, **kwargs: Any):
|
534 |
+
"""
|
535 |
+
Forward pass of the model.
|
536 |
+
"""
|
537 |
+
return self.get_base_model()(*args, **kwargs)
|
538 |
+
|
539 |
+
def _get_base_model_class(self, is_prompt_tuning=False):
|
540 |
+
"""
|
541 |
+
Returns the base model class.
|
542 |
+
"""
|
543 |
+
if not is_prompt_tuning:
|
544 |
+
return self.base_model.model.__class__
|
545 |
+
return self.base_model.__class__
|
546 |
+
|
547 |
+
@contextmanager
|
548 |
+
def disable_adapter(self):
|
549 |
+
"""
|
550 |
+
Context manager that disables the adapter module. Use this to run inference on the base model.
|
551 |
+
|
552 |
+
Example:
|
553 |
+
|
554 |
+
```py
|
555 |
+
>>> with model.disable_adapter():
|
556 |
+
... model(inputs)
|
557 |
+
```
|
558 |
+
"""
|
559 |
+
try:
|
560 |
+
if self.peft_config[self.active_adapter].is_prompt_learning:
|
561 |
+
# TODO: consider replacing this patching of methods with a more robust mechanism: setting a flag and
|
562 |
+
# letting the underlying methods deal with it, same as how LoRA does it.
|
563 |
+
old_forward = self.forward
|
564 |
+
self.forward = self.base_model.forward
|
565 |
+
old_prepare_inputs_for_generation = self.prepare_inputs_for_generation
|
566 |
+
self.prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
|
567 |
+
else:
|
568 |
+
self.base_model.disable_adapter_layers()
|
569 |
+
yield
|
570 |
+
finally:
|
571 |
+
if self.peft_config[self.active_adapter].is_prompt_learning:
|
572 |
+
self.forward = old_forward
|
573 |
+
self.prepare_inputs_for_generation = old_prepare_inputs_for_generation
|
574 |
+
else:
|
575 |
+
self.base_model.enable_adapter_layers()
|
576 |
+
|
577 |
+
def get_base_model(self) -> torch.nn.Module:
|
578 |
+
"""
|
579 |
+
Returns the base model.
|
580 |
+
"""
|
581 |
+
return (
|
582 |
+
self.base_model
|
583 |
+
if (self.active_peft_config.is_prompt_learning or self.peft_type == PeftType.POLY)
|
584 |
+
else self.base_model.model
|
585 |
+
)
|
586 |
+
|
587 |
+
def add_adapter(self, adapter_name: str, peft_config: PeftConfig) -> None:
|
588 |
+
"""
|
589 |
+
Add an adapter to the model based on the passed configuration.
|
590 |
+
|
591 |
+
The name for the new adapter should be unique.
|
592 |
+
|
593 |
+
The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active
|
594 |
+
adapter.
|
595 |
+
|
596 |
+
Args:
|
597 |
+
adapter_name (`str`):
|
598 |
+
The name of the adapter to be added.
|
599 |
+
peft_config ([`PeftConfig`]):
|
600 |
+
The configuration of the adapter to be added.
|
601 |
+
"""
|
602 |
+
if peft_config.peft_type != self.peft_type:
|
603 |
+
raise ValueError(
|
604 |
+
f"Cannot combine adapters with different peft types. "
|
605 |
+
f"Found {self.peft_type} and {peft_config.peft_type}."
|
606 |
+
)
|
607 |
+
|
608 |
+
try:
|
609 |
+
if peft_config.is_prompt_learning:
|
610 |
+
self.peft_config[adapter_name] = peft_config
|
611 |
+
if hasattr(self.config, "to_dict"):
|
612 |
+
dict_config = self.config.to_dict()
|
613 |
+
else:
|
614 |
+
dict_config = self.config
|
615 |
+
|
616 |
+
peft_config = _prepare_prompt_learning_config(peft_config, dict_config)
|
617 |
+
self._setup_prompt_encoder(adapter_name)
|
618 |
+
elif peft_config.is_adaption_prompt:
|
619 |
+
self.base_model.add_adapter(adapter_name, peft_config)
|
620 |
+
else:
|
621 |
+
self.peft_config[adapter_name] = peft_config
|
622 |
+
self.base_model.inject_adapter(self.base_model.model, adapter_name)
|
623 |
+
except Exception: # something went wrong, roll back
|
624 |
+
if adapter_name in self.peft_config:
|
625 |
+
del self.peft_config[adapter_name]
|
626 |
+
raise
|
627 |
+
|
628 |
+
self.set_additional_trainable_modules(peft_config, adapter_name)
|
629 |
+
|
630 |
+
def set_additional_trainable_modules(self, peft_config, adapter_name):
|
631 |
+
if getattr(peft_config, "modules_to_save", None) is not None:
|
632 |
+
if self.modules_to_save is None:
|
633 |
+
self.modules_to_save = set(peft_config.modules_to_save)
|
634 |
+
else:
|
635 |
+
self.modules_to_save.update(peft_config.modules_to_save)
|
636 |
+
_set_trainable(self, adapter_name)
|
637 |
+
|
638 |
+
@classmethod
|
639 |
+
def _split_kwargs(cls, kwargs: dict[str, Any]):
|
640 |
+
_kwargs_not_in_hf_hub_download_signature = ("use_auth_token",)
|
641 |
+
hf_hub_download_kwargs = {}
|
642 |
+
other_kwargs = {}
|
643 |
+
|
644 |
+
for key, value in kwargs.items():
|
645 |
+
if key in inspect.signature(hf_hub_download).parameters or key in _kwargs_not_in_hf_hub_download_signature:
|
646 |
+
hf_hub_download_kwargs[key] = value
|
647 |
+
else:
|
648 |
+
other_kwargs[key] = value
|
649 |
+
|
650 |
+
return hf_hub_download_kwargs, other_kwargs
|
651 |
+
|
652 |
+
def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = False, **kwargs: Any):
|
653 |
+
"""
|
654 |
+
Load a trained adapter into the model.
|
655 |
+
|
656 |
+
The name for the new adapter should be unique.
|
657 |
+
|
658 |
+
The new adapter is not automatically set as the active adapter. Use [`PeftModel.set_adapter`] to set the active
|
659 |
+
adapter.
|
660 |
+
|
661 |
+
Args:
|
662 |
+
adapter_name (`str`):
|
663 |
+
The name of the adapter to be added.
|
664 |
+
peft_config ([`PeftConfig`]):
|
665 |
+
The configuration of the adapter to be added.
|
666 |
+
is_trainable (`bool`, *optional*, defaults to `False`):
|
667 |
+
Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be
|
668 |
+
used for inference.
|
669 |
+
kwargs: (`optional`):
|
670 |
+
Additional arguments to modify the way the adapter is loaded, e.g. the token for Hugging Face Hub.
|
671 |
+
"""
|
672 |
+
from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING
|
673 |
+
|
674 |
+
hf_hub_download_kwargs, kwargs = self._split_kwargs(kwargs)
|
675 |
+
torch_device = infer_device()
|
676 |
+
|
677 |
+
if adapter_name not in self.peft_config:
|
678 |
+
# load the config
|
679 |
+
peft_config = PEFT_TYPE_TO_CONFIG_MAPPING[
|
680 |
+
PeftConfig._get_peft_type(
|
681 |
+
model_id,
|
682 |
+
**hf_hub_download_kwargs,
|
683 |
+
)
|
684 |
+
].from_pretrained(
|
685 |
+
model_id,
|
686 |
+
**hf_hub_download_kwargs,
|
687 |
+
)
|
688 |
+
if peft_config.is_prompt_learning and is_trainable:
|
689 |
+
raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
|
690 |
+
else:
|
691 |
+
peft_config.inference_mode = not is_trainable
|
692 |
+
self.add_adapter(adapter_name, peft_config)
|
693 |
+
|
694 |
+
adapters_weights = load_peft_weights(model_id, device=torch_device, **hf_hub_download_kwargs)
|
695 |
+
|
696 |
+
# load the weights into the model
|
697 |
+
load_result = set_peft_model_state_dict(self, adapters_weights, adapter_name=adapter_name)
|
698 |
+
if (
|
699 |
+
(getattr(self, "hf_device_map", None) is not None)
|
700 |
+
and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0)
|
701 |
+
and len(self.peft_config) == 1
|
702 |
+
):
|
703 |
+
device_map = kwargs.get("device_map", "auto")
|
704 |
+
max_memory = kwargs.get("max_memory", None)
|
705 |
+
offload_dir = kwargs.get("offload_folder", None)
|
706 |
+
offload_index = kwargs.get("offload_index", None)
|
707 |
+
|
708 |
+
dispatch_model_kwargs = {}
|
709 |
+
# Safety checker for previous `accelerate` versions
|
710 |
+
# `offload_index` was introduced in https://github.com/huggingface/accelerate/pull/873/
|
711 |
+
if "offload_index" in inspect.signature(dispatch_model).parameters:
|
712 |
+
dispatch_model_kwargs["offload_index"] = offload_index
|
713 |
+
|
714 |
+
no_split_module_classes = self._no_split_modules
|
715 |
+
|
716 |
+
if device_map != "sequential":
|
717 |
+
max_memory = get_balanced_memory(
|
718 |
+
self,
|
719 |
+
max_memory=max_memory,
|
720 |
+
no_split_module_classes=no_split_module_classes,
|
721 |
+
low_zero=(device_map == "balanced_low_0"),
|
722 |
+
)
|
723 |
+
if isinstance(device_map, str):
|
724 |
+
device_map = infer_auto_device_map(
|
725 |
+
self, max_memory=max_memory, no_split_module_classes=no_split_module_classes
|
726 |
+
)
|
727 |
+
dispatch_model(
|
728 |
+
self,
|
729 |
+
device_map=device_map,
|
730 |
+
offload_dir=offload_dir,
|
731 |
+
**dispatch_model_kwargs,
|
732 |
+
)
|
733 |
+
hook = AlignDevicesHook(io_same_device=True)
|
734 |
+
if self.peft_config[adapter_name].is_prompt_learning:
|
735 |
+
remove_hook_from_submodules(self.prompt_encoder)
|
736 |
+
add_hook_to_module(self.get_base_model(), hook)
|
737 |
+
|
738 |
+
# Set model in evaluation mode to deactivate Dropout modules by default
|
739 |
+
if not is_trainable:
|
740 |
+
self.eval()
|
741 |
+
return load_result
|
742 |
+
|
743 |
+
def set_adapter(self, adapter_name: str) -> None:
|
744 |
+
"""
|
745 |
+
Sets the active adapter.
|
746 |
+
|
747 |
+
Only one adapter can be active at a time.
|
748 |
+
|
749 |
+
Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True). If this is
|
750 |
+
not desired, use the following code.
|
751 |
+
|
752 |
+
```py
|
753 |
+
>>> for name, param in model_peft.named_parameters():
|
754 |
+
... if ...: # some check on name (ex. if 'lora' in name)
|
755 |
+
... param.requires_grad = False
|
756 |
+
```
|
757 |
+
|
758 |
+
Args:
|
759 |
+
adapter_name (`str`):
|
760 |
+
The name of the adapter to be set as active. The adapter must be loaded first.
|
761 |
+
"""
|
762 |
+
if adapter_name not in self.peft_config:
|
763 |
+
raise ValueError(f"Adapter {adapter_name} not found.")
|
764 |
+
self.active_adapter = adapter_name
|
765 |
+
if not self.peft_config[adapter_name].is_prompt_learning:
|
766 |
+
self.base_model.set_adapter(adapter_name)
|
767 |
+
_set_adapter(self, adapter_name)
|
768 |
+
|
769 |
+
@property
|
770 |
+
def base_model_torch_dtype(self):
|
771 |
+
return getattr(self.base_model, "dtype", None)
|
772 |
+
|
773 |
+
@property
|
774 |
+
def active_peft_config(self):
|
775 |
+
return self.peft_config[self.active_adapter]
|
776 |
+
|
777 |
+
def create_or_update_model_card(self, output_dir: str):
|
778 |
+
"""
|
779 |
+
Updates or create model card to include information about peft:
|
780 |
+
1. Adds `peft` library tag
|
781 |
+
2. Adds peft version
|
782 |
+
3. Adds base model info
|
783 |
+
4. Adds quantization information if it was used
|
784 |
+
"""
|
785 |
+
|
786 |
+
filename = os.path.join(output_dir, "README.md")
|
787 |
+
|
788 |
+
card = ModelCard.load(filename) if os.path.exists(filename) else ModelCard.from_template(ModelCardData())
|
789 |
+
|
790 |
+
card.data["library_name"] = "peft"
|
791 |
+
|
792 |
+
model_config = getattr(self, "config", None)
|
793 |
+
if hasattr(model_config, "to_dict"):
|
794 |
+
model_config = model_config.to_dict()
|
795 |
+
if model_config is not None and "_name_or_path" in model_config:
|
796 |
+
card.data["base_model"] = model_config["_name_or_path"]
|
797 |
+
|
798 |
+
lines = card.text.splitlines()
|
799 |
+
|
800 |
+
quantization_config = None
|
801 |
+
if hasattr(model_config, "quantization_config"):
|
802 |
+
quantization_config = self.config.quantization_config.to_dict()
|
803 |
+
training_config_text = ""
|
804 |
+
quantization_prefix = "The following `bitsandbytes` quantization config was used during training:"
|
805 |
+
# Adds quantization information if it was used
|
806 |
+
if quantization_config is not None:
|
807 |
+
training_config_text += f"\n{quantization_prefix}\n"
|
808 |
+
training_config_text += "\n".join([f"- {name}: {value}" for name, value in quantization_config.items()])
|
809 |
+
training_config_text += "\n"
|
810 |
+
|
811 |
+
training_procedure_heading = "## Training procedure"
|
812 |
+
if quantization_prefix not in lines and bool(training_config_text):
|
813 |
+
if training_procedure_heading in lines:
|
814 |
+
lines.insert(lines.index(training_procedure_heading) + 2, training_config_text)
|
815 |
+
else:
|
816 |
+
lines.append(f"{training_procedure_heading}\n{training_config_text}")
|
817 |
+
|
818 |
+
# Adds peft version
|
819 |
+
framework_block_heading = "### Framework versions"
|
820 |
+
if f"- PEFT {__version__}" not in lines:
|
821 |
+
if framework_block_heading in lines:
|
822 |
+
lines.insert(lines.index(framework_block_heading) + 2, f"- PEFT {__version__}")
|
823 |
+
else:
|
824 |
+
lines.append(f"{framework_block_heading}\n\n- PEFT {__version__}")
|
825 |
+
|
826 |
+
card.text = "\n".join(lines)
|
827 |
+
card.save(filename)
|
828 |
+
|
829 |
+
|
830 |
+
class PeftModelForSequenceClassification(PeftModel):
|
831 |
+
"""
|
832 |
+
Peft model for sequence classification tasks.
|
833 |
+
|
834 |
+
Args:
|
835 |
+
model ([`~transformers.PreTrainedModel`]): Base transformer model.
|
836 |
+
peft_config ([`PeftConfig`]): Peft config.
|
837 |
+
|
838 |
+
**Attributes**:
|
839 |
+
- **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model.
|
840 |
+
- **cls_layer_name** (`str`) -- The name of the classification layer.
|
841 |
+
|
842 |
+
Example:
|
843 |
+
|
844 |
+
```py
|
845 |
+
>>> from transformers import AutoModelForSequenceClassification
|
846 |
+
>>> from peft import PeftModelForSequenceClassification, get_peft_config
|
847 |
+
|
848 |
+
>>> config = {
|
849 |
+
... "peft_type": "PREFIX_TUNING",
|
850 |
+
... "task_type": "SEQ_CLS",
|
851 |
+
... "inference_mode": False,
|
852 |
+
... "num_virtual_tokens": 20,
|
853 |
+
... "token_dim": 768,
|
854 |
+
... "num_transformer_submodules": 1,
|
855 |
+
... "num_attention_heads": 12,
|
856 |
+
... "num_layers": 12,
|
857 |
+
... "encoder_hidden_size": 768,
|
858 |
+
... "prefix_projection": False,
|
859 |
+
... "postprocess_past_key_value_function": None,
|
860 |
+
... }
|
861 |
+
|
862 |
+
>>> peft_config = get_peft_config(config)
|
863 |
+
>>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
|
864 |
+
>>> peft_model = PeftModelForSequenceClassification(model, peft_config)
|
865 |
+
>>> peft_model.print_trainable_parameters()
|
866 |
+
trainable params: 370178 || all params: 108680450 || trainable%: 0.3406113979101117
|
867 |
+
```
|
868 |
+
"""
|
869 |
+
|
870 |
+
def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
|
871 |
+
super().__init__(model, peft_config, adapter_name)
|
872 |
+
if self.modules_to_save is None:
|
873 |
+
self.modules_to_save = {"classifier", "score"}
|
874 |
+
else:
|
875 |
+
self.modules_to_save.update({"classifier", "score"})
|
876 |
+
|
877 |
+
for name, _ in self.base_model.named_children():
|
878 |
+
if any(module_name in name for module_name in self.modules_to_save):
|
879 |
+
self.cls_layer_name = name
|
880 |
+
break
|
881 |
+
|
882 |
+
# to make sure classifier layer is trainable
|
883 |
+
_set_trainable(self, adapter_name)
|
884 |
+
|
885 |
+
def forward(
|
886 |
+
self,
|
887 |
+
input_ids=None,
|
888 |
+
attention_mask=None,
|
889 |
+
inputs_embeds=None,
|
890 |
+
labels=None,
|
891 |
+
output_attentions=None,
|
892 |
+
output_hidden_states=None,
|
893 |
+
return_dict=None,
|
894 |
+
task_ids=None,
|
895 |
+
**kwargs,
|
896 |
+
):
|
897 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
898 |
+
peft_config = self.active_peft_config
|
899 |
+
if not peft_config.is_prompt_learning:
|
900 |
+
if peft_config.peft_type == PeftType.POLY:
|
901 |
+
kwargs["task_ids"] = task_ids
|
902 |
+
return self.base_model(
|
903 |
+
input_ids=input_ids,
|
904 |
+
attention_mask=attention_mask,
|
905 |
+
inputs_embeds=inputs_embeds,
|
906 |
+
labels=labels,
|
907 |
+
output_attentions=output_attentions,
|
908 |
+
output_hidden_states=output_hidden_states,
|
909 |
+
return_dict=return_dict,
|
910 |
+
**kwargs,
|
911 |
+
)
|
912 |
+
|
913 |
+
batch_size = _get_batch_size(input_ids, inputs_embeds)
|
914 |
+
if attention_mask is not None:
|
915 |
+
# concat prompt attention mask
|
916 |
+
prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device)
|
917 |
+
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
|
918 |
+
if kwargs.get("position_ids", None) is not None:
|
919 |
+
warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
|
920 |
+
kwargs["position_ids"] = None
|
921 |
+
kwargs.update(
|
922 |
+
{
|
923 |
+
"attention_mask": attention_mask,
|
924 |
+
"labels": labels,
|
925 |
+
"output_attentions": output_attentions,
|
926 |
+
"output_hidden_states": output_hidden_states,
|
927 |
+
"return_dict": return_dict,
|
928 |
+
}
|
929 |
+
)
|
930 |
+
|
931 |
+
if peft_config.peft_type == PeftType.PREFIX_TUNING:
|
932 |
+
return self._prefix_tuning_forward(input_ids=input_ids, **kwargs)
|
933 |
+
else:
|
934 |
+
if kwargs.get("token_type_ids", None) is not None:
|
935 |
+
kwargs["token_type_ids"] = torch.cat(
|
936 |
+
(
|
937 |
+
torch.zeros(batch_size, peft_config.num_virtual_tokens).to(self.word_embeddings.weight.device),
|
938 |
+
kwargs["token_type_ids"],
|
939 |
+
),
|
940 |
+
dim=1,
|
941 |
+
).long()
|
942 |
+
if inputs_embeds is None:
|
943 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
944 |
+
prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)
|
945 |
+
prompts = prompts.to(inputs_embeds.dtype)
|
946 |
+
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
|
947 |
+
return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
|
948 |
+
|
949 |
+
def _prefix_tuning_forward(
|
950 |
+
self,
|
951 |
+
input_ids=None,
|
952 |
+
attention_mask=None,
|
953 |
+
inputs_embeds=None,
|
954 |
+
labels=None,
|
955 |
+
output_attentions=None,
|
956 |
+
output_hidden_states=None,
|
957 |
+
return_dict=None,
|
958 |
+
**kwargs,
|
959 |
+
):
|
960 |
+
batch_size = _get_batch_size(input_ids, inputs_embeds)
|
961 |
+
past_key_values = self.get_prompt(batch_size)
|
962 |
+
fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys())
|
963 |
+
kwargs.update(
|
964 |
+
{
|
965 |
+
"input_ids": input_ids,
|
966 |
+
"attention_mask": attention_mask,
|
967 |
+
"inputs_embeds": inputs_embeds,
|
968 |
+
"output_attentions": output_attentions,
|
969 |
+
"output_hidden_states": output_hidden_states,
|
970 |
+
"return_dict": return_dict,
|
971 |
+
"past_key_values": past_key_values,
|
972 |
+
}
|
973 |
+
)
|
974 |
+
if "past_key_values" in fwd_params:
|
975 |
+
return self.base_model(labels=labels, **kwargs)
|
976 |
+
else:
|
977 |
+
transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name)
|
978 |
+
fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys())
|
979 |
+
if "past_key_values" not in fwd_params:
|
980 |
+
raise ValueError("Model does not support past key values which are required for prefix tuning.")
|
981 |
+
outputs = transformer_backbone_name(**kwargs)
|
982 |
+
pooled_output = outputs[1] if len(outputs) > 1 else outputs[0]
|
983 |
+
if "dropout" in [name for name, _ in list(self.base_model.named_children())]:
|
984 |
+
pooled_output = self.base_model.dropout(pooled_output)
|
985 |
+
logits = self.base_model.get_submodule(self.cls_layer_name)(pooled_output)
|
986 |
+
|
987 |
+
loss = None
|
988 |
+
if labels is not None:
|
989 |
+
if self.config.problem_type is None:
|
990 |
+
if self.base_model.num_labels == 1:
|
991 |
+
self.config.problem_type = "regression"
|
992 |
+
elif self.base_model.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
993 |
+
self.config.problem_type = "single_label_classification"
|
994 |
+
else:
|
995 |
+
self.config.problem_type = "multi_label_classification"
|
996 |
+
|
997 |
+
if self.config.problem_type == "regression":
|
998 |
+
loss_fct = MSELoss()
|
999 |
+
if self.base_model.num_labels == 1:
|
1000 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
1001 |
+
else:
|
1002 |
+
loss = loss_fct(logits, labels)
|
1003 |
+
elif self.config.problem_type == "single_label_classification":
|
1004 |
+
loss_fct = CrossEntropyLoss()
|
1005 |
+
loss = loss_fct(logits.view(-1, self.base_model.num_labels), labels.view(-1))
|
1006 |
+
elif self.config.problem_type == "multi_label_classification":
|
1007 |
+
loss_fct = BCEWithLogitsLoss()
|
1008 |
+
loss = loss_fct(logits, labels)
|
1009 |
+
if not return_dict:
|
1010 |
+
output = (logits,) + outputs[2:]
|
1011 |
+
return ((loss,) + output) if loss is not None else output
|
1012 |
+
|
1013 |
+
return SequenceClassifierOutput(
|
1014 |
+
loss=loss,
|
1015 |
+
logits=logits,
|
1016 |
+
hidden_states=outputs.hidden_states,
|
1017 |
+
attentions=outputs.attentions,
|
1018 |
+
)
|
1019 |
+
|
1020 |
+
|
1021 |
+
class PeftModelForCausalLM(PeftModel):
|
1022 |
+
"""
|
1023 |
+
Peft model for causal language modeling.
|
1024 |
+
|
1025 |
+
Args:
|
1026 |
+
model ([`~transformers.PreTrainedModel`]): Base transformer model.
|
1027 |
+
peft_config ([`PeftConfig`]): Peft config.
|
1028 |
+
|
1029 |
+
|
1030 |
+
Example:
|
1031 |
+
|
1032 |
+
```py
|
1033 |
+
>>> from transformers import AutoModelForCausalLM
|
1034 |
+
>>> from peft import PeftModelForCausalLM, get_peft_config
|
1035 |
+
|
1036 |
+
>>> config = {
|
1037 |
+
... "peft_type": "PREFIX_TUNING",
|
1038 |
+
... "task_type": "CAUSAL_LM",
|
1039 |
+
... "inference_mode": False,
|
1040 |
+
... "num_virtual_tokens": 20,
|
1041 |
+
... "token_dim": 1280,
|
1042 |
+
... "num_transformer_submodules": 1,
|
1043 |
+
... "num_attention_heads": 20,
|
1044 |
+
... "num_layers": 36,
|
1045 |
+
... "encoder_hidden_size": 1280,
|
1046 |
+
... "prefix_projection": False,
|
1047 |
+
... "postprocess_past_key_value_function": None,
|
1048 |
+
... }
|
1049 |
+
|
1050 |
+
>>> peft_config = get_peft_config(config)
|
1051 |
+
>>> model = AutoModelForCausalLM.from_pretrained("gpt2-large")
|
1052 |
+
>>> peft_model = PeftModelForCausalLM(model, peft_config)
|
1053 |
+
>>> peft_model.print_trainable_parameters()
|
1054 |
+
trainable params: 1843200 || all params: 775873280 || trainable%: 0.23756456724479544
|
1055 |
+
```
|
1056 |
+
"""
|
1057 |
+
|
1058 |
+
def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
|
1059 |
+
super().__init__(model, peft_config, adapter_name)
|
1060 |
+
self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
|
1061 |
+
|
1062 |
+
def forward(
|
1063 |
+
self,
|
1064 |
+
input_ids=None,
|
1065 |
+
attention_mask=None,
|
1066 |
+
inputs_embeds=None,
|
1067 |
+
labels=None,
|
1068 |
+
output_attentions=None,
|
1069 |
+
output_hidden_states=None,
|
1070 |
+
return_dict=None,
|
1071 |
+
task_ids=None,
|
1072 |
+
**kwargs,
|
1073 |
+
):
|
1074 |
+
peft_config = self.active_peft_config
|
1075 |
+
if not peft_config.is_prompt_learning:
|
1076 |
+
if self.base_model.config.model_type == "mpt":
|
1077 |
+
if inputs_embeds is not None:
|
1078 |
+
raise AssertionError("forward in MPTForCausalLM does not support inputs_embeds")
|
1079 |
+
return self.base_model(
|
1080 |
+
input_ids=input_ids,
|
1081 |
+
attention_mask=attention_mask,
|
1082 |
+
labels=labels,
|
1083 |
+
output_attentions=output_attentions,
|
1084 |
+
output_hidden_states=output_hidden_states,
|
1085 |
+
return_dict=return_dict,
|
1086 |
+
**kwargs,
|
1087 |
+
)
|
1088 |
+
|
1089 |
+
if peft_config.peft_type == PeftType.POLY:
|
1090 |
+
kwargs["task_ids"] = task_ids
|
1091 |
+
return self.base_model(
|
1092 |
+
input_ids=input_ids,
|
1093 |
+
attention_mask=attention_mask,
|
1094 |
+
inputs_embeds=inputs_embeds,
|
1095 |
+
labels=labels,
|
1096 |
+
output_attentions=output_attentions,
|
1097 |
+
output_hidden_states=output_hidden_states,
|
1098 |
+
return_dict=return_dict,
|
1099 |
+
**kwargs,
|
1100 |
+
)
|
1101 |
+
|
1102 |
+
batch_size = _get_batch_size(input_ids, inputs_embeds)
|
1103 |
+
if attention_mask is not None:
|
1104 |
+
# concat prompt attention mask
|
1105 |
+
prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device)
|
1106 |
+
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
|
1107 |
+
|
1108 |
+
if kwargs.get("position_ids", None) is not None:
|
1109 |
+
warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
|
1110 |
+
kwargs["position_ids"] = None
|
1111 |
+
if kwargs.get("token_type_ids", None) is not None:
|
1112 |
+
warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
|
1113 |
+
kwargs["token_type_ids"] = None
|
1114 |
+
kwargs.update(
|
1115 |
+
{
|
1116 |
+
"attention_mask": attention_mask,
|
1117 |
+
"labels": labels,
|
1118 |
+
"output_attentions": output_attentions,
|
1119 |
+
"output_hidden_states": output_hidden_states,
|
1120 |
+
"return_dict": return_dict,
|
1121 |
+
}
|
1122 |
+
)
|
1123 |
+
|
1124 |
+
if peft_config.peft_type == PeftType.PREFIX_TUNING:
|
1125 |
+
past_key_values = self.get_prompt(batch_size)
|
1126 |
+
return self.base_model(
|
1127 |
+
input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, **kwargs
|
1128 |
+
)
|
1129 |
+
else:
|
1130 |
+
if inputs_embeds is None:
|
1131 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
1132 |
+
# concat prompt labels
|
1133 |
+
if labels is not None:
|
1134 |
+
prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device)
|
1135 |
+
kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
|
1136 |
+
prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)
|
1137 |
+
prompts = prompts.to(inputs_embeds.dtype)
|
1138 |
+
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
|
1139 |
+
return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
|
1140 |
+
|
1141 |
+
def generate(self, *args, **kwargs):
|
1142 |
+
self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
|
1143 |
+
if hasattr(self.base_model, "model"):
|
1144 |
+
self.base_model.model.generation_config = self.generation_config
|
1145 |
+
else:
|
1146 |
+
self.base_model.generation_config = self.generation_config
|
1147 |
+
try:
|
1148 |
+
outputs = self.base_model.generate(*args, **kwargs)
|
1149 |
+
except:
|
1150 |
+
self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
|
1151 |
+
raise
|
1152 |
+
else:
|
1153 |
+
self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
|
1154 |
+
return outputs
|
1155 |
+
|
1156 |
+
def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor] = None, **kwargs):
|
1157 |
+
peft_config = self.active_peft_config
|
1158 |
+
model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
|
1159 |
+
|
1160 |
+
# https://github.com/huggingface/transformers/pull/26681/ introduced new cache format
|
1161 |
+
# for some architectures which requires a special fix for prompt tuning etc.
|
1162 |
+
# TODO: starting with transformers 4.38, all architectures should support caching.
|
1163 |
+
uses_transformers_4_38 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.38.0")
|
1164 |
+
uses_transformers_4_36 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.36.0")
|
1165 |
+
transformers_new_cache_archs = ["llama", "mistral", "persimmon", "phi"]
|
1166 |
+
uses_cache = uses_transformers_4_38 or (
|
1167 |
+
uses_transformers_4_36 and self.base_model.config.model_type in transformers_new_cache_archs
|
1168 |
+
)
|
1169 |
+
|
1170 |
+
if peft_config.peft_type == PeftType.POLY:
|
1171 |
+
model_kwargs["task_ids"] = task_ids
|
1172 |
+
if peft_config.is_prompt_learning:
|
1173 |
+
if uses_cache and (model_kwargs["past_key_values"] is not None):
|
1174 |
+
# change in the logic of `prepare_inputs_for_generation` makes the below code necessary
|
1175 |
+
# In prompt learning methods, past key values are longer when compared to the `input_ids`.
|
1176 |
+
# As such only consider the last input ids in the autogressive generation phase.
|
1177 |
+
if model_kwargs["past_key_values"][0][0].shape[-2] >= model_kwargs["input_ids"].shape[1]:
|
1178 |
+
model_kwargs["input_ids"] = model_kwargs["input_ids"][:, -1:]
|
1179 |
+
|
1180 |
+
if model_kwargs.get("attention_mask", None) is not None:
|
1181 |
+
size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
|
1182 |
+
prefix_attention_mask = torch.ones(size).to(model_kwargs["input_ids"].device)
|
1183 |
+
model_kwargs["attention_mask"] = torch.cat(
|
1184 |
+
(prefix_attention_mask, model_kwargs["attention_mask"]), dim=1
|
1185 |
+
)
|
1186 |
+
|
1187 |
+
if model_kwargs.get("position_ids", None) is not None:
|
1188 |
+
warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
|
1189 |
+
model_kwargs["position_ids"] = None
|
1190 |
+
|
1191 |
+
if kwargs.get("token_type_ids", None) is not None:
|
1192 |
+
warnings.warn(
|
1193 |
+
"Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
|
1194 |
+
)
|
1195 |
+
kwargs["token_type_ids"] = None
|
1196 |
+
|
1197 |
+
if model_kwargs["past_key_values"] is None and peft_config.peft_type == PeftType.PREFIX_TUNING:
|
1198 |
+
past_key_values = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0])
|
1199 |
+
model_kwargs["past_key_values"] = past_key_values
|
1200 |
+
else:
|
1201 |
+
if model_kwargs["past_key_values"] is None:
|
1202 |
+
inputs_embeds = self.word_embeddings(model_kwargs["input_ids"])
|
1203 |
+
prompts = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0], task_ids=task_ids)
|
1204 |
+
prompts = prompts.to(inputs_embeds.dtype)
|
1205 |
+
model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
|
1206 |
+
model_kwargs["input_ids"] = None
|
1207 |
+
|
1208 |
+
# For transformers>=4.38.0 - for some architectures such as Llama, `cache_position` is
|
1209 |
+
# passed in the forward pass to keep track of the position ids of the cache. We have to
|
1210 |
+
# pop that from `model_kwargs` as `cache_position` is properly created by the model, using the passed
|
1211 |
+
# `inputs_embeds`: https://github.com/huggingface/transformers/blob/593230f0a1150ea9c0477b9d859f25daf73c8c33/src/transformers/models/llama/modeling_llama.py#L956
|
1212 |
+
_ = model_kwargs.pop("cache_position", None)
|
1213 |
+
|
1214 |
+
return model_kwargs
|
1215 |
+
|
1216 |
+
|
1217 |
+
class PeftModelForSeq2SeqLM(PeftModel):
|
1218 |
+
"""
|
1219 |
+
Peft model for sequence-to-sequence language modeling.
|
1220 |
+
|
1221 |
+
Args:
|
1222 |
+
model ([`~transformers.PreTrainedModel`]): Base transformer model.
|
1223 |
+
peft_config ([`PeftConfig`]): Peft config.
|
1224 |
+
|
1225 |
+
|
1226 |
+
Example:
|
1227 |
+
|
1228 |
+
```py
|
1229 |
+
>>> from transformers import AutoModelForSeq2SeqLM
|
1230 |
+
>>> from peft import PeftModelForSeq2SeqLM, get_peft_config
|
1231 |
+
|
1232 |
+
>>> config = {
|
1233 |
+
... "peft_type": "LORA",
|
1234 |
+
... "task_type": "SEQ_2_SEQ_LM",
|
1235 |
+
... "inference_mode": False,
|
1236 |
+
... "r": 8,
|
1237 |
+
... "target_modules": ["q", "v"],
|
1238 |
+
... "lora_alpha": 32,
|
1239 |
+
... "lora_dropout": 0.1,
|
1240 |
+
... "fan_in_fan_out": False,
|
1241 |
+
... "enable_lora": None,
|
1242 |
+
... "bias": "none",
|
1243 |
+
... }
|
1244 |
+
|
1245 |
+
>>> peft_config = get_peft_config(config)
|
1246 |
+
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
|
1247 |
+
>>> peft_model = PeftModelForSeq2SeqLM(model, peft_config)
|
1248 |
+
>>> peft_model.print_trainable_parameters()
|
1249 |
+
trainable params: 884736 || all params: 223843584 || trainable%: 0.3952474242013566
|
1250 |
+
```
|
1251 |
+
"""
|
1252 |
+
|
1253 |
+
def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
|
1254 |
+
super().__init__(model, peft_config, adapter_name)
|
1255 |
+
self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation
|
1256 |
+
self.base_model_prepare_encoder_decoder_kwargs_for_generation = (
|
1257 |
+
self.base_model._prepare_encoder_decoder_kwargs_for_generation
|
1258 |
+
)
|
1259 |
+
|
1260 |
+
def forward(
|
1261 |
+
self,
|
1262 |
+
input_ids=None,
|
1263 |
+
attention_mask=None,
|
1264 |
+
inputs_embeds=None,
|
1265 |
+
decoder_input_ids=None,
|
1266 |
+
decoder_attention_mask=None,
|
1267 |
+
decoder_inputs_embeds=None,
|
1268 |
+
labels=None,
|
1269 |
+
output_attentions=None,
|
1270 |
+
output_hidden_states=None,
|
1271 |
+
return_dict=None,
|
1272 |
+
task_ids=None,
|
1273 |
+
**kwargs,
|
1274 |
+
):
|
1275 |
+
peft_config = self.active_peft_config
|
1276 |
+
if not peft_config.is_prompt_learning:
|
1277 |
+
if peft_config.peft_type == PeftType.POLY:
|
1278 |
+
kwargs["task_ids"] = task_ids
|
1279 |
+
return self.base_model(
|
1280 |
+
input_ids=input_ids,
|
1281 |
+
attention_mask=attention_mask,
|
1282 |
+
inputs_embeds=inputs_embeds,
|
1283 |
+
decoder_input_ids=decoder_input_ids,
|
1284 |
+
decoder_attention_mask=decoder_attention_mask,
|
1285 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
1286 |
+
labels=labels,
|
1287 |
+
output_attentions=output_attentions,
|
1288 |
+
output_hidden_states=output_hidden_states,
|
1289 |
+
return_dict=return_dict,
|
1290 |
+
**kwargs,
|
1291 |
+
)
|
1292 |
+
|
1293 |
+
batch_size = _get_batch_size(input_ids, inputs_embeds)
|
1294 |
+
if decoder_attention_mask is not None:
|
1295 |
+
# concat prompt attention mask
|
1296 |
+
prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
|
1297 |
+
decoder_attention_mask.device
|
1298 |
+
)
|
1299 |
+
if peft_config.peft_type not in [PeftType.PROMPT_TUNING, PeftType.P_TUNING]:
|
1300 |
+
decoder_attention_mask = torch.cat((prefix_attention_mask, decoder_attention_mask), dim=1)
|
1301 |
+
|
1302 |
+
if kwargs.get("position_ids", None) is not None:
|
1303 |
+
warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
|
1304 |
+
kwargs["position_ids"] = None
|
1305 |
+
if kwargs.get("token_type_ids", None) is not None:
|
1306 |
+
warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
|
1307 |
+
kwargs["token_type_ids"] = None
|
1308 |
+
kwargs.update(
|
1309 |
+
{
|
1310 |
+
"attention_mask": attention_mask,
|
1311 |
+
"decoder_attention_mask": decoder_attention_mask,
|
1312 |
+
"labels": labels,
|
1313 |
+
"output_attentions": output_attentions,
|
1314 |
+
"output_hidden_states": output_hidden_states,
|
1315 |
+
"return_dict": return_dict,
|
1316 |
+
}
|
1317 |
+
)
|
1318 |
+
|
1319 |
+
if peft_config.peft_type == PeftType.PREFIX_TUNING:
|
1320 |
+
past_key_values = self.get_prompt(batch_size)
|
1321 |
+
return self.base_model(
|
1322 |
+
input_ids=input_ids,
|
1323 |
+
decoder_input_ids=decoder_input_ids,
|
1324 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
1325 |
+
past_key_values=past_key_values,
|
1326 |
+
**kwargs,
|
1327 |
+
)
|
1328 |
+
elif peft_config.peft_type in [PeftType.PROMPT_TUNING, PeftType.P_TUNING]:
|
1329 |
+
if inputs_embeds is None:
|
1330 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
1331 |
+
|
1332 |
+
if attention_mask is not None:
|
1333 |
+
# concat prompt attention mask
|
1334 |
+
prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
|
1335 |
+
attention_mask.device
|
1336 |
+
)
|
1337 |
+
kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
|
1338 |
+
|
1339 |
+
prompts = self.get_prompt(batch_size=batch_size)
|
1340 |
+
prompts = prompts.to(inputs_embeds.dtype)
|
1341 |
+
inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1)
|
1342 |
+
|
1343 |
+
return self.base_model(
|
1344 |
+
inputs_embeds=inputs_embeds,
|
1345 |
+
decoder_input_ids=decoder_input_ids,
|
1346 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
1347 |
+
**kwargs,
|
1348 |
+
)
|
1349 |
+
else:
|
1350 |
+
if inputs_embeds is None:
|
1351 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
1352 |
+
if decoder_inputs_embeds is None and decoder_input_ids is None:
|
1353 |
+
decoder_input_ids = shift_tokens_right(
|
1354 |
+
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
1355 |
+
)
|
1356 |
+
decoder_inputs_embeds = self.word_embeddings(decoder_input_ids)
|
1357 |
+
|
1358 |
+
if attention_mask is not None:
|
1359 |
+
# concat prompt attention mask
|
1360 |
+
prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
|
1361 |
+
attention_mask.device
|
1362 |
+
)
|
1363 |
+
kwargs["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)
|
1364 |
+
# concat prompt labels
|
1365 |
+
if labels is not None:
|
1366 |
+
if peft_config.num_transformer_submodules == 1:
|
1367 |
+
kwargs["labels"] = labels
|
1368 |
+
elif peft_config.num_transformer_submodules == 2:
|
1369 |
+
prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device)
|
1370 |
+
kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
|
1371 |
+
prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)
|
1372 |
+
prompts = prompts.to(inputs_embeds.dtype)
|
1373 |
+
inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1)
|
1374 |
+
if peft_config.num_transformer_submodules == 1:
|
1375 |
+
return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
|
1376 |
+
elif peft_config.num_transformer_submodules == 2:
|
1377 |
+
decoder_inputs_embeds = torch.cat(
|
1378 |
+
(prompts[:, peft_config.num_virtual_tokens :], decoder_inputs_embeds), dim=1
|
1379 |
+
)
|
1380 |
+
return self.base_model(
|
1381 |
+
inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, **kwargs
|
1382 |
+
)
|
1383 |
+
|
1384 |
+
def generate(self, **kwargs):
|
1385 |
+
peft_config = self.active_peft_config
|
1386 |
+
self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
|
1387 |
+
self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
|
1388 |
+
self._prepare_encoder_decoder_kwargs_for_generation
|
1389 |
+
)
|
1390 |
+
try:
|
1391 |
+
if not peft_config.is_prompt_learning:
|
1392 |
+
outputs = self.base_model.generate(**kwargs)
|
1393 |
+
else:
|
1394 |
+
if "input_ids" not in kwargs:
|
1395 |
+
raise ValueError("input_ids must be provided for Peft model generation")
|
1396 |
+
if kwargs.get("position_ids", None) is not None:
|
1397 |
+
warnings.warn(
|
1398 |
+
"Position ids are not supported for parameter efficient tuning. Ignoring position ids."
|
1399 |
+
)
|
1400 |
+
kwargs["position_ids"] = None
|
1401 |
+
if kwargs.get("token_type_ids", None) is not None:
|
1402 |
+
warnings.warn(
|
1403 |
+
"Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
|
1404 |
+
)
|
1405 |
+
kwargs["token_type_ids"] = None
|
1406 |
+
|
1407 |
+
if peft_config.peft_type == PeftType.PREFIX_TUNING:
|
1408 |
+
outputs = self.base_model.generate(**kwargs)
|
1409 |
+
elif peft_config.peft_type in [
|
1410 |
+
PeftType.PROMPT_TUNING,
|
1411 |
+
PeftType.P_TUNING,
|
1412 |
+
PeftType.MULTITASK_PROMPT_TUNING,
|
1413 |
+
]:
|
1414 |
+
kwargs = deepcopy(kwargs)
|
1415 |
+
|
1416 |
+
if "encoder_outputs" in kwargs:
|
1417 |
+
del kwargs["encoder_outputs"]
|
1418 |
+
warnings.warn(
|
1419 |
+
"`encoder_outputs` should not be passed to `generate` when using prompt tuning. Ignoring it."
|
1420 |
+
)
|
1421 |
+
|
1422 |
+
input_ids = kwargs.pop("input_ids")
|
1423 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
1424 |
+
batch_size = inputs_embeds.shape[0]
|
1425 |
+
prompts = self.get_prompt(batch_size=batch_size, task_ids=kwargs.pop("task_ids", None))
|
1426 |
+
prompts = prompts.to(inputs_embeds.dtype)
|
1427 |
+
|
1428 |
+
inputs_embeds = torch.cat((prompts[:, : peft_config.num_virtual_tokens], inputs_embeds), dim=1)
|
1429 |
+
kwargs["inputs_embeds"] = inputs_embeds
|
1430 |
+
|
1431 |
+
if "attention_mask" in kwargs:
|
1432 |
+
prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(
|
1433 |
+
kwargs["attention_mask"].device
|
1434 |
+
)
|
1435 |
+
kwargs["attention_mask"] = torch.cat((prefix_attention_mask, kwargs["attention_mask"]), dim=1)
|
1436 |
+
|
1437 |
+
return self.base_model.generate(**kwargs)
|
1438 |
+
else:
|
1439 |
+
raise NotImplementedError
|
1440 |
+
except:
|
1441 |
+
self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
|
1442 |
+
self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
|
1443 |
+
self.base_model_prepare_encoder_decoder_kwargs_for_generation
|
1444 |
+
)
|
1445 |
+
raise
|
1446 |
+
else:
|
1447 |
+
self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation
|
1448 |
+
self.base_model._prepare_encoder_decoder_kwargs_for_generation = (
|
1449 |
+
self.base_model_prepare_encoder_decoder_kwargs_for_generation
|
1450 |
+
)
|
1451 |
+
return outputs
|
1452 |
+
|
1453 |
+
def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **kwargs):
|
1454 |
+
peft_config = self.active_peft_config
|
1455 |
+
model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
|
1456 |
+
if peft_config.peft_type == PeftType.POLY:
|
1457 |
+
model_kwargs["task_ids"] = task_ids
|
1458 |
+
if model_kwargs["past_key_values"] is None and peft_config.peft_type == PeftType.PREFIX_TUNING:
|
1459 |
+
batch_size = model_kwargs["decoder_input_ids"].shape[0]
|
1460 |
+
past_key_values = self.get_prompt(batch_size)
|
1461 |
+
model_kwargs["past_key_values"] = past_key_values
|
1462 |
+
|
1463 |
+
return model_kwargs
|
1464 |
+
|
1465 |
+
|
1466 |
+
class PeftModelForTokenClassification(PeftModel):
|
1467 |
+
"""
|
1468 |
+
Peft model for token classification tasks.
|
1469 |
+
|
1470 |
+
Args:
|
1471 |
+
model ([`~transformers.PreTrainedModel`]): Base transformer model.
|
1472 |
+
peft_config ([`PeftConfig`]): Peft config.
|
1473 |
+
|
1474 |
+
**Attributes**:
|
1475 |
+
- **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model.
|
1476 |
+
- **cls_layer_name** (`str`) -- The name of the classification layer.
|
1477 |
+
|
1478 |
+
Example:
|
1479 |
+
|
1480 |
+
```py
|
1481 |
+
>>> from transformers import AutoModelForSequenceClassification
|
1482 |
+
>>> from peft import PeftModelForTokenClassification, get_peft_config
|
1483 |
+
|
1484 |
+
>>> config = {
|
1485 |
+
... "peft_type": "PREFIX_TUNING",
|
1486 |
+
... "task_type": "TOKEN_CLS",
|
1487 |
+
... "inference_mode": False,
|
1488 |
+
... "num_virtual_tokens": 20,
|
1489 |
+
... "token_dim": 768,
|
1490 |
+
... "num_transformer_submodules": 1,
|
1491 |
+
... "num_attention_heads": 12,
|
1492 |
+
... "num_layers": 12,
|
1493 |
+
... "encoder_hidden_size": 768,
|
1494 |
+
... "prefix_projection": False,
|
1495 |
+
... "postprocess_past_key_value_function": None,
|
1496 |
+
... }
|
1497 |
+
|
1498 |
+
>>> peft_config = get_peft_config(config)
|
1499 |
+
>>> model = AutoModelForTokenClassification.from_pretrained("bert-base-cased")
|
1500 |
+
>>> peft_model = PeftModelForTokenClassification(model, peft_config)
|
1501 |
+
>>> peft_model.print_trainable_parameters()
|
1502 |
+
trainable params: 370178 || all params: 108680450 || trainable%: 0.3406113979101117
|
1503 |
+
```
|
1504 |
+
"""
|
1505 |
+
|
1506 |
+
def __init__(self, model: torch.nn.Module, peft_config: PeftConfig = None, adapter_name: str = "default") -> None:
|
1507 |
+
super().__init__(model, peft_config, adapter_name)
|
1508 |
+
if self.modules_to_save is None:
|
1509 |
+
self.modules_to_save = {"classifier", "score"}
|
1510 |
+
else:
|
1511 |
+
self.modules_to_save.update({"classifier", "score"})
|
1512 |
+
|
1513 |
+
for name, _ in self.base_model.named_children():
|
1514 |
+
if any(module_name in name for module_name in self.modules_to_save):
|
1515 |
+
self.cls_layer_name = name
|
1516 |
+
break
|
1517 |
+
|
1518 |
+
# to make sure classifier layer is trainable
|
1519 |
+
_set_trainable(self, adapter_name)
|
1520 |
+
|
1521 |
+
def forward(
|
1522 |
+
self,
|
1523 |
+
input_ids=None,
|
1524 |
+
attention_mask=None,
|
1525 |
+
inputs_embeds=None,
|
1526 |
+
labels=None,
|
1527 |
+
output_attentions=None,
|
1528 |
+
output_hidden_states=None,
|
1529 |
+
return_dict=None,
|
1530 |
+
task_ids=None,
|
1531 |
+
**kwargs,
|
1532 |
+
):
|
1533 |
+
peft_config = self.active_peft_config
|
1534 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1535 |
+
|
1536 |
+
if not peft_config.is_prompt_learning:
|
1537 |
+
if peft_config.peft_type == PeftType.POLY:
|
1538 |
+
kwargs["task_ids"] = task_ids
|
1539 |
+
return self.base_model(
|
1540 |
+
input_ids=input_ids,
|
1541 |
+
attention_mask=attention_mask,
|
1542 |
+
inputs_embeds=inputs_embeds,
|
1543 |
+
labels=labels,
|
1544 |
+
output_attentions=output_attentions,
|
1545 |
+
output_hidden_states=output_hidden_states,
|
1546 |
+
return_dict=return_dict,
|
1547 |
+
**kwargs,
|
1548 |
+
)
|
1549 |
+
|
1550 |
+
batch_size = _get_batch_size(input_ids, inputs_embeds)
|
1551 |
+
if attention_mask is not None:
|
1552 |
+
# concat prompt attention mask
|
1553 |
+
prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device)
|
1554 |
+
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
|
1555 |
+
if kwargs.get("position_ids", None) is not None:
|
1556 |
+
warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
|
1557 |
+
kwargs["position_ids"] = None
|
1558 |
+
kwargs.update(
|
1559 |
+
{
|
1560 |
+
"attention_mask": attention_mask,
|
1561 |
+
"labels": labels,
|
1562 |
+
"output_attentions": output_attentions,
|
1563 |
+
"output_hidden_states": output_hidden_states,
|
1564 |
+
"return_dict": return_dict,
|
1565 |
+
}
|
1566 |
+
)
|
1567 |
+
|
1568 |
+
if peft_config.peft_type == PeftType.PREFIX_TUNING:
|
1569 |
+
return self._prefix_tuning_forward(input_ids=input_ids, **kwargs)
|
1570 |
+
else:
|
1571 |
+
if kwargs.get("token_type_ids", None) is not None:
|
1572 |
+
kwargs["token_type_ids"] = torch.cat(
|
1573 |
+
(
|
1574 |
+
torch.zeros(batch_size, peft_config.num_virtual_tokens).to(self.word_embeddings.weight.device),
|
1575 |
+
kwargs["token_type_ids"],
|
1576 |
+
),
|
1577 |
+
dim=1,
|
1578 |
+
).long()
|
1579 |
+
if inputs_embeds is None:
|
1580 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
1581 |
+
prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)
|
1582 |
+
prompts = prompts.to(inputs_embeds.dtype)
|
1583 |
+
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
|
1584 |
+
return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
|
1585 |
+
|
1586 |
+
def _prefix_tuning_forward(
|
1587 |
+
self,
|
1588 |
+
input_ids=None,
|
1589 |
+
attention_mask=None,
|
1590 |
+
inputs_embeds=None,
|
1591 |
+
labels=None,
|
1592 |
+
output_attentions=None,
|
1593 |
+
output_hidden_states=None,
|
1594 |
+
return_dict=None,
|
1595 |
+
**kwargs,
|
1596 |
+
):
|
1597 |
+
batch_size = _get_batch_size(input_ids, inputs_embeds)
|
1598 |
+
past_key_values = self.get_prompt(batch_size)
|
1599 |
+
fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys())
|
1600 |
+
kwargs.update(
|
1601 |
+
{
|
1602 |
+
"input_ids": input_ids,
|
1603 |
+
"attention_mask": attention_mask,
|
1604 |
+
"inputs_embeds": inputs_embeds,
|
1605 |
+
"output_attentions": output_attentions,
|
1606 |
+
"output_hidden_states": output_hidden_states,
|
1607 |
+
"return_dict": return_dict,
|
1608 |
+
"past_key_values": past_key_values,
|
1609 |
+
}
|
1610 |
+
)
|
1611 |
+
if "past_key_values" in fwd_params:
|
1612 |
+
return self.base_model(labels=labels, **kwargs)
|
1613 |
+
else:
|
1614 |
+
transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name)
|
1615 |
+
fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys())
|
1616 |
+
if "past_key_values" not in fwd_params:
|
1617 |
+
raise ValueError("Model does not support past key values which are required for prefix tuning.")
|
1618 |
+
outputs = transformer_backbone_name(**kwargs)
|
1619 |
+
sequence_output = outputs[0]
|
1620 |
+
if "dropout" in [name for name, _ in list(self.base_model.named_children())]:
|
1621 |
+
sequence_output = self.base_model.dropout(sequence_output)
|
1622 |
+
logits = self.base_model.get_submodule(self.cls_layer_name)(sequence_output)
|
1623 |
+
|
1624 |
+
loss = None
|
1625 |
+
if labels is not None:
|
1626 |
+
loss_fct = CrossEntropyLoss()
|
1627 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
1628 |
+
|
1629 |
+
if not return_dict:
|
1630 |
+
output = (logits,) + outputs[2:]
|
1631 |
+
return ((loss,) + output) if loss is not None else output
|
1632 |
+
|
1633 |
+
return TokenClassifierOutput(
|
1634 |
+
loss=loss,
|
1635 |
+
logits=logits,
|
1636 |
+
hidden_states=outputs.hidden_states,
|
1637 |
+
attentions=outputs.attentions,
|
1638 |
+
)
|
1639 |
+
|
1640 |
+
|
1641 |
+
class PeftModelForQuestionAnswering(PeftModel):
|
1642 |
+
"""
|
1643 |
+
Peft model for extractive question answering.
|
1644 |
+
|
1645 |
+
Args:
|
1646 |
+
model ([`~transformers.PreTrainedModel`]): Base transformer model.
|
1647 |
+
peft_config ([`PeftConfig`]): Peft config.
|
1648 |
+
|
1649 |
+
**Attributes**:
|
1650 |
+
- **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model.
|
1651 |
+
- **cls_layer_name** (`str`) -- The name of the classification layer.
|
1652 |
+
|
1653 |
+
Example:
|
1654 |
+
|
1655 |
+
```py
|
1656 |
+
>>> from transformers import AutoModelForQuestionAnswering
|
1657 |
+
>>> from peft import PeftModelForQuestionAnswering, get_peft_config
|
1658 |
+
|
1659 |
+
>>> config = {
|
1660 |
+
... "peft_type": "LORA",
|
1661 |
+
... "task_type": "QUESTION_ANS",
|
1662 |
+
... "inference_mode": False,
|
1663 |
+
... "r": 16,
|
1664 |
+
... "target_modules": ["query", "value"],
|
1665 |
+
... "lora_alpha": 32,
|
1666 |
+
... "lora_dropout": 0.05,
|
1667 |
+
... "fan_in_fan_out": False,
|
1668 |
+
... "bias": "none",
|
1669 |
+
... }
|
1670 |
+
|
1671 |
+
>>> peft_config = get_peft_config(config)
|
1672 |
+
>>> model = AutoModelForQuestionAnswering.from_pretrained("bert-base-cased")
|
1673 |
+
>>> peft_model = PeftModelForQuestionAnswering(model, peft_config)
|
1674 |
+
>>> peft_model.print_trainable_parameters()
|
1675 |
+
trainable params: 592900 || all params: 108312580 || trainable%: 0.5473971721475013
|
1676 |
+
```
|
1677 |
+
"""
|
1678 |
+
|
1679 |
+
def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
|
1680 |
+
super().__init__(model, peft_config, adapter_name)
|
1681 |
+
if self.modules_to_save is None:
|
1682 |
+
self.modules_to_save = {"qa_outputs"}
|
1683 |
+
else:
|
1684 |
+
self.modules_to_save.update({"qa_outputs"})
|
1685 |
+
|
1686 |
+
for name, _ in self.base_model.named_children():
|
1687 |
+
if any(module_name in name for module_name in self.modules_to_save):
|
1688 |
+
self.cls_layer_name = name
|
1689 |
+
break
|
1690 |
+
|
1691 |
+
# to make sure classifier layer is trainable
|
1692 |
+
_set_trainable(self, adapter_name)
|
1693 |
+
|
1694 |
+
def forward(
|
1695 |
+
self,
|
1696 |
+
input_ids=None,
|
1697 |
+
attention_mask=None,
|
1698 |
+
token_type_ids=None,
|
1699 |
+
position_ids=None,
|
1700 |
+
inputs_embeds=None,
|
1701 |
+
start_positions=None,
|
1702 |
+
end_positions=None,
|
1703 |
+
output_attentions=None,
|
1704 |
+
output_hidden_states=None,
|
1705 |
+
return_dict=None,
|
1706 |
+
task_ids=None,
|
1707 |
+
**kwargs,
|
1708 |
+
):
|
1709 |
+
peft_config = self.active_peft_config
|
1710 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1711 |
+
|
1712 |
+
if not peft_config.is_prompt_learning:
|
1713 |
+
if peft_config.peft_type == PeftType.POLY:
|
1714 |
+
kwargs["task_ids"] = task_ids
|
1715 |
+
return self.base_model(
|
1716 |
+
input_ids=input_ids,
|
1717 |
+
attention_mask=attention_mask,
|
1718 |
+
inputs_embeds=inputs_embeds,
|
1719 |
+
start_positions=start_positions,
|
1720 |
+
end_positions=end_positions,
|
1721 |
+
output_attentions=output_attentions,
|
1722 |
+
output_hidden_states=output_hidden_states,
|
1723 |
+
return_dict=return_dict,
|
1724 |
+
**kwargs,
|
1725 |
+
)
|
1726 |
+
|
1727 |
+
batch_size = _get_batch_size(input_ids, inputs_embeds)
|
1728 |
+
if attention_mask is not None:
|
1729 |
+
# concat prompt attention mask
|
1730 |
+
prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device)
|
1731 |
+
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
|
1732 |
+
if kwargs.get("position_ids", None) is not None:
|
1733 |
+
warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
|
1734 |
+
kwargs["position_ids"] = None
|
1735 |
+
kwargs.update(
|
1736 |
+
{
|
1737 |
+
"attention_mask": attention_mask,
|
1738 |
+
"start_positions": start_positions,
|
1739 |
+
"end_positions": end_positions,
|
1740 |
+
"output_attentions": output_attentions,
|
1741 |
+
"output_hidden_states": output_hidden_states,
|
1742 |
+
"return_dict": return_dict,
|
1743 |
+
}
|
1744 |
+
)
|
1745 |
+
|
1746 |
+
if peft_config.peft_type == PeftType.PREFIX_TUNING:
|
1747 |
+
return self._prefix_tuning_forward(input_ids=input_ids, **kwargs)
|
1748 |
+
else:
|
1749 |
+
if kwargs.get("token_type_ids", None) is not None:
|
1750 |
+
kwargs["token_type_ids"] = torch.cat(
|
1751 |
+
(
|
1752 |
+
torch.zeros(batch_size, peft_config.num_virtual_tokens).to(self.word_embeddings.weight.device),
|
1753 |
+
kwargs["token_type_ids"],
|
1754 |
+
),
|
1755 |
+
dim=1,
|
1756 |
+
).long()
|
1757 |
+
if inputs_embeds is None:
|
1758 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
1759 |
+
prompts = self.get_prompt(batch_size=batch_size)
|
1760 |
+
prompts = prompts.to(inputs_embeds.dtype)
|
1761 |
+
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
|
1762 |
+
return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
|
1763 |
+
|
1764 |
+
def _prefix_tuning_forward(
|
1765 |
+
self,
|
1766 |
+
input_ids=None,
|
1767 |
+
attention_mask=None,
|
1768 |
+
inputs_embeds=None,
|
1769 |
+
start_positions=None,
|
1770 |
+
end_positions=None,
|
1771 |
+
output_attentions=None,
|
1772 |
+
output_hidden_states=None,
|
1773 |
+
return_dict=None,
|
1774 |
+
**kwargs,
|
1775 |
+
):
|
1776 |
+
batch_size = _get_batch_size(input_ids, inputs_embeds)
|
1777 |
+
past_key_values = self.get_prompt(batch_size)
|
1778 |
+
fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys())
|
1779 |
+
kwargs.update(
|
1780 |
+
{
|
1781 |
+
"input_ids": input_ids,
|
1782 |
+
"attention_mask": attention_mask,
|
1783 |
+
"inputs_embeds": inputs_embeds,
|
1784 |
+
"output_attentions": output_attentions,
|
1785 |
+
"output_hidden_states": output_hidden_states,
|
1786 |
+
"return_dict": return_dict,
|
1787 |
+
"past_key_values": past_key_values,
|
1788 |
+
}
|
1789 |
+
)
|
1790 |
+
if "past_key_values" in fwd_params:
|
1791 |
+
return self.base_model(start_positions=start_positions, end_positions=end_positions, **kwargs)
|
1792 |
+
else:
|
1793 |
+
transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name)
|
1794 |
+
fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys())
|
1795 |
+
if "past_key_values" not in fwd_params:
|
1796 |
+
raise ValueError("Model does not support past key values which are required for prefix tuning.")
|
1797 |
+
outputs = transformer_backbone_name(**kwargs)
|
1798 |
+
sequence_output = outputs[0]
|
1799 |
+
if "dropout" in [name for name, _ in list(self.base_model.named_children())]:
|
1800 |
+
sequence_output = self.base_model.dropout(sequence_output)
|
1801 |
+
logits = self.base_model.get_submodule(self.cls_layer_name)(sequence_output)
|
1802 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
1803 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
1804 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
1805 |
+
|
1806 |
+
total_loss = None
|
1807 |
+
if start_positions is not None and end_positions is not None:
|
1808 |
+
# If we are on multi-GPU, split add a dimension
|
1809 |
+
if len(start_positions.size()) > 1:
|
1810 |
+
start_positions = start_positions.squeeze(-1)
|
1811 |
+
if len(end_positions.size()) > 1:
|
1812 |
+
end_positions = end_positions.squeeze(-1)
|
1813 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
1814 |
+
ignored_index = start_logits.size(1)
|
1815 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
1816 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
1817 |
+
|
1818 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
1819 |
+
start_loss = loss_fct(start_logits, start_positions)
|
1820 |
+
end_loss = loss_fct(end_logits, end_positions)
|
1821 |
+
total_loss = (start_loss + end_loss) / 2
|
1822 |
+
|
1823 |
+
if not return_dict:
|
1824 |
+
output = (start_logits, end_logits) + outputs[2:]
|
1825 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
1826 |
+
|
1827 |
+
return QuestionAnsweringModelOutput(
|
1828 |
+
loss=total_loss,
|
1829 |
+
start_logits=start_logits,
|
1830 |
+
end_logits=end_logits,
|
1831 |
+
hidden_states=outputs.hidden_states,
|
1832 |
+
attentions=outputs.attentions,
|
1833 |
+
)
|
1834 |
+
|
1835 |
+
|
1836 |
+
class PeftModelForFeatureExtraction(PeftModel):
|
1837 |
+
"""
|
1838 |
+
Peft model for extracting features/embeddings from transformer models
|
1839 |
+
|
1840 |
+
Args:
|
1841 |
+
model ([`~transformers.PreTrainedModel`]): Base transformer model.
|
1842 |
+
peft_config ([`PeftConfig`]): Peft config.
|
1843 |
+
|
1844 |
+
**Attributes**:
|
1845 |
+
- **config** ([`~transformers.PretrainedConfig`]) -- The configuration object of the base model.
|
1846 |
+
|
1847 |
+
Example:
|
1848 |
+
|
1849 |
+
```py
|
1850 |
+
>>> from transformers import AutoModel
|
1851 |
+
>>> from peft import PeftModelForFeatureExtraction, get_peft_config
|
1852 |
+
|
1853 |
+
>>> config = {
|
1854 |
+
... "peft_type": "LORA",
|
1855 |
+
... "task_type": "FEATURE_EXTRACTION",
|
1856 |
+
... "inference_mode": False,
|
1857 |
+
... "r": 16,
|
1858 |
+
... "target_modules": ["query", "value"],
|
1859 |
+
... "lora_alpha": 32,
|
1860 |
+
... "lora_dropout": 0.05,
|
1861 |
+
... "fan_in_fan_out": False,
|
1862 |
+
... "bias": "none",
|
1863 |
+
... }
|
1864 |
+
>>> peft_config = get_peft_config(config)
|
1865 |
+
>>> model = AutoModel.from_pretrained("bert-base-cased")
|
1866 |
+
>>> peft_model = PeftModelForFeatureExtraction(model, peft_config)
|
1867 |
+
>>> peft_model.print_trainable_parameters()
|
1868 |
+
```
|
1869 |
+
"""
|
1870 |
+
|
1871 |
+
def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default"):
|
1872 |
+
super().__init__(model, peft_config, adapter_name)
|
1873 |
+
|
1874 |
+
def forward(
|
1875 |
+
self,
|
1876 |
+
input_ids=None,
|
1877 |
+
attention_mask=None,
|
1878 |
+
inputs_embeds=None,
|
1879 |
+
output_attentions=None,
|
1880 |
+
output_hidden_states=None,
|
1881 |
+
return_dict=None,
|
1882 |
+
task_ids=None,
|
1883 |
+
**kwargs,
|
1884 |
+
):
|
1885 |
+
peft_config = self.active_peft_config
|
1886 |
+
if not peft_config.is_prompt_learning:
|
1887 |
+
if peft_config.peft_type == PeftType.POLY:
|
1888 |
+
kwargs["task_ids"] = task_ids
|
1889 |
+
return self.base_model(
|
1890 |
+
input_ids=input_ids,
|
1891 |
+
attention_mask=attention_mask,
|
1892 |
+
inputs_embeds=inputs_embeds,
|
1893 |
+
output_attentions=output_attentions,
|
1894 |
+
output_hidden_states=output_hidden_states,
|
1895 |
+
return_dict=return_dict,
|
1896 |
+
**kwargs,
|
1897 |
+
)
|
1898 |
+
|
1899 |
+
batch_size = _get_batch_size(input_ids, inputs_embeds)
|
1900 |
+
if attention_mask is not None:
|
1901 |
+
# concat prompt attention mask
|
1902 |
+
prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device)
|
1903 |
+
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
|
1904 |
+
|
1905 |
+
if kwargs.get("position_ids", None) is not None:
|
1906 |
+
warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
|
1907 |
+
kwargs["position_ids"] = None
|
1908 |
+
if kwargs.get("token_type_ids", None) is not None:
|
1909 |
+
warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
|
1910 |
+
kwargs["token_type_ids"] = None
|
1911 |
+
kwargs.update(
|
1912 |
+
{
|
1913 |
+
"attention_mask": attention_mask,
|
1914 |
+
"output_attentions": output_attentions,
|
1915 |
+
"output_hidden_states": output_hidden_states,
|
1916 |
+
"return_dict": return_dict,
|
1917 |
+
}
|
1918 |
+
)
|
1919 |
+
|
1920 |
+
if peft_config.peft_type == PeftType.PREFIX_TUNING:
|
1921 |
+
past_key_values = self.get_prompt(batch_size)
|
1922 |
+
return self.base_model(input_ids=input_ids, past_key_values=past_key_values, **kwargs)
|
1923 |
+
else:
|
1924 |
+
if inputs_embeds is None:
|
1925 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
1926 |
+
prompts = self.get_prompt(batch_size=batch_size)
|
1927 |
+
prompts = prompts.to(inputs_embeds.dtype)
|
1928 |
+
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
|
1929 |
+
return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
|
MoRA/peft_mora/py.typed
ADDED
File without changes
|
MoRA/peft_mora/tuners/__init__.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# flake8: noqa
|
2 |
+
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
3 |
+
# module, but to preserve other warnings. So, don't check this module at all
|
4 |
+
|
5 |
+
# coding=utf-8
|
6 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
7 |
+
#
|
8 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
9 |
+
# you may not use this file except in compliance with the License.
|
10 |
+
# You may obtain a copy of the License at
|
11 |
+
#
|
12 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
13 |
+
#
|
14 |
+
# Unless required by applicable law or agreed to in writing, software
|
15 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
16 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
17 |
+
# See the License for the specific language governing permissions and
|
18 |
+
# limitations under the License.
|
19 |
+
|
20 |
+
from .adaption_prompt import AdaptionPromptConfig, AdaptionPromptModel
|
21 |
+
from .lora import LoraConfig, LoraModel, LoftQConfig
|
22 |
+
from .loha import LoHaConfig, LoHaModel
|
23 |
+
from .lokr import LoKrConfig, LoKrModel
|
24 |
+
from .ia3 import IA3Config, IA3Model
|
25 |
+
from .adalora import AdaLoraConfig, AdaLoraModel
|
26 |
+
from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType
|
27 |
+
from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
|
28 |
+
from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit
|
29 |
+
from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit
|
30 |
+
from .oft import OFTConfig, OFTModel
|
31 |
+
from .mixed import MixedModel
|
32 |
+
from .poly import PolyConfig, PolyModel
|
MoRA/peft_mora/tuners/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (1.24 kB). View file
|
|
MoRA/peft_mora/tuners/__pycache__/lycoris_utils.cpython-312.pyc
ADDED
Binary file (19.9 kB). View file
|
|
MoRA/peft_mora/tuners/__pycache__/tuners_utils.cpython-312.pyc
ADDED
Binary file (29.6 kB). View file
|
|
MoRA/peft_mora/tuners/adalora/__init__.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from peft_mora.import_utils import is_bnb_4bit_available, is_bnb_available
|
16 |
+
|
17 |
+
from .config import AdaLoraConfig
|
18 |
+
from .gptq import SVDQuantLinear
|
19 |
+
from .layer import AdaLoraLayer, RankAllocator, SVDLinear
|
20 |
+
from .model import AdaLoraModel
|
21 |
+
|
22 |
+
|
23 |
+
__all__ = ["AdaLoraConfig", "AdaLoraLayer", "AdaLoraModel", "SVDLinear", "RankAllocator", "SVDQuantLinear"]
|
24 |
+
|
25 |
+
|
26 |
+
def __getattr__(name):
|
27 |
+
if (name == "SVDLinear8bitLt") and is_bnb_available():
|
28 |
+
from .bnb import SVDLinear8bitLt
|
29 |
+
|
30 |
+
return SVDLinear8bitLt
|
31 |
+
|
32 |
+
if (name == "SVDLinear4bit") and is_bnb_4bit_available():
|
33 |
+
from .bnb import SVDLinear4bit
|
34 |
+
|
35 |
+
return SVDLinear4bit
|
36 |
+
|
37 |
+
raise AttributeError(f"module {__name__} has no attribute {name}")
|
MoRA/peft_mora/tuners/adalora/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (1.02 kB). View file
|
|
MoRA/peft_mora/tuners/adalora/__pycache__/config.cpython-312.pyc
ADDED
Binary file (2.82 kB). View file
|
|
MoRA/peft_mora/tuners/adalora/__pycache__/gptq.cpython-312.pyc
ADDED
Binary file (2.66 kB). View file
|
|
MoRA/peft_mora/tuners/adalora/__pycache__/layer.cpython-312.pyc
ADDED
Binary file (19.8 kB). View file
|
|
MoRA/peft_mora/tuners/adalora/__pycache__/model.cpython-312.pyc
ADDED
Binary file (15.8 kB). View file
|
|
MoRA/peft_mora/tuners/adalora/bnb.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import Any
|
16 |
+
|
17 |
+
import torch
|
18 |
+
|
19 |
+
from peft_mora.import_utils import is_bnb_4bit_available, is_bnb_available
|
20 |
+
|
21 |
+
from .layer import AdaLoraLayer
|
22 |
+
|
23 |
+
|
24 |
+
if is_bnb_available():
|
25 |
+
|
26 |
+
class SVDLinear8bitLt(torch.nn.Module, AdaLoraLayer):
|
27 |
+
# Low-rank matrix for SVD-based adaptation
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
base_layer: torch.nn.Module,
|
31 |
+
adapter_name: str,
|
32 |
+
r: int = 0,
|
33 |
+
lora_alpha: int = 1,
|
34 |
+
lora_dropout: float = 0.0,
|
35 |
+
init_lora_weights: bool = True,
|
36 |
+
**kwargs,
|
37 |
+
) -> None:
|
38 |
+
super().__init__()
|
39 |
+
AdaLoraLayer.__init__(self, base_layer)
|
40 |
+
# Freezing the pre-trained weight matrix
|
41 |
+
self.get_base_layer().weight.requires_grad = False
|
42 |
+
|
43 |
+
self._active_adapter = adapter_name
|
44 |
+
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
|
45 |
+
|
46 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
47 |
+
# note: no check for self.merged because merging is not supported (yet)
|
48 |
+
result = self.base_layer(x)
|
49 |
+
|
50 |
+
if self.disable_adapters:
|
51 |
+
return result
|
52 |
+
|
53 |
+
for active_adapter in self.active_adapters:
|
54 |
+
if active_adapter not in self.lora_A.keys():
|
55 |
+
continue
|
56 |
+
requires_conversion = not torch.is_autocast_enabled()
|
57 |
+
if requires_conversion:
|
58 |
+
expected_dtype = result.dtype
|
59 |
+
if x.dtype != torch.float32:
|
60 |
+
x = x.float()
|
61 |
+
|
62 |
+
lora_A = self.lora_A[active_adapter]
|
63 |
+
lora_B = self.lora_B[active_adapter]
|
64 |
+
lora_E = self.lora_E[active_adapter]
|
65 |
+
dropout = self.lora_dropout[active_adapter]
|
66 |
+
scaling = self.scaling[active_adapter]
|
67 |
+
ranknum = self.ranknum[active_adapter] + 1e-5
|
68 |
+
|
69 |
+
output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T
|
70 |
+
if requires_conversion:
|
71 |
+
output = output.to(expected_dtype)
|
72 |
+
output = output * scaling / ranknum
|
73 |
+
# inplace operation on view is forbidden for MatMul8bitLtBackward, so avoid it
|
74 |
+
result = result + output
|
75 |
+
return result
|
76 |
+
|
77 |
+
def __repr__(self) -> str:
|
78 |
+
rep = super().__repr__()
|
79 |
+
return "adalora." + rep
|
80 |
+
|
81 |
+
|
82 |
+
if is_bnb_4bit_available():
|
83 |
+
|
84 |
+
class SVDLinear4bit(torch.nn.Module, AdaLoraLayer):
|
85 |
+
# Low-rank matrix for SVD-based adaptation
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
base_layer: torch.nn.Module,
|
89 |
+
adapter_name: str,
|
90 |
+
r: int = 0,
|
91 |
+
lora_alpha: int = 1,
|
92 |
+
lora_dropout: float = 0.0,
|
93 |
+
init_lora_weights: bool = True,
|
94 |
+
**kwargs,
|
95 |
+
) -> None:
|
96 |
+
super().__init__()
|
97 |
+
AdaLoraLayer.__init__(self, base_layer)
|
98 |
+
# Freezing the pre-trained weight matrix
|
99 |
+
self.get_base_layer().weight.requires_grad = False
|
100 |
+
|
101 |
+
self._active_adapter = adapter_name
|
102 |
+
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
|
103 |
+
|
104 |
+
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
|
105 |
+
# note: no check for self.merged because merging is not supported (yet)
|
106 |
+
result = self.base_layer(x, *args, **kwargs)
|
107 |
+
|
108 |
+
if self.disable_adapters:
|
109 |
+
return result
|
110 |
+
|
111 |
+
# As per Tim Dettmers, for 4bit, we need to defensively clone here.
|
112 |
+
# The reason is that in some cases, an error can occur that backprop
|
113 |
+
# does not work on a manipulated view. This issue may be solved with
|
114 |
+
# newer PyTorch versions but this would need extensive testing to be
|
115 |
+
# sure.
|
116 |
+
result = result.clone()
|
117 |
+
|
118 |
+
for active_adapter in self.active_adapters:
|
119 |
+
if active_adapter not in self.lora_A.keys():
|
120 |
+
continue
|
121 |
+
|
122 |
+
lora_A = self.lora_A[active_adapter]
|
123 |
+
lora_B = self.lora_B[active_adapter]
|
124 |
+
lora_E = self.lora_E[active_adapter]
|
125 |
+
dropout = self.lora_dropout[active_adapter]
|
126 |
+
scaling = self.scaling[active_adapter]
|
127 |
+
ranknum = self.ranknum[active_adapter] + 1e-5
|
128 |
+
|
129 |
+
requires_conversion = not torch.is_autocast_enabled()
|
130 |
+
if requires_conversion:
|
131 |
+
expected_dtype = result.dtype
|
132 |
+
compute_dtype = lora_A.dtype
|
133 |
+
if x.dtype != compute_dtype:
|
134 |
+
x = x.to(compute_dtype)
|
135 |
+
|
136 |
+
output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T
|
137 |
+
if requires_conversion:
|
138 |
+
output = output.to(expected_dtype)
|
139 |
+
output = output * scaling / ranknum
|
140 |
+
result += output
|
141 |
+
return result
|
142 |
+
|
143 |
+
def __repr__(self) -> str:
|
144 |
+
rep = super().__repr__()
|
145 |
+
return "adalora." + rep
|
MoRA/peft_mora/tuners/adalora/config.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from dataclasses import dataclass, field
|
16 |
+
from typing import Optional
|
17 |
+
|
18 |
+
from peft_mora.tuners.lora import LoraConfig
|
19 |
+
from peft_mora.utils import PeftType
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass
|
23 |
+
class AdaLoraConfig(LoraConfig):
|
24 |
+
"""
|
25 |
+
This is the configuration class to store the configuration of a [`~peft.AdaLora`].
|
26 |
+
|
27 |
+
Args:
|
28 |
+
target_r (`int`): The target average rank of incremental matrix.
|
29 |
+
init_r (`int`): The initial rank for each incremental matrix.
|
30 |
+
tinit (`int`): The steps of initial fine-tuning warmup.
|
31 |
+
tfinal (`int`): The step of final fine-tuning.
|
32 |
+
deltaT (`int`): The time internval between two budget allocations.
|
33 |
+
beta1 (`float`): The hyperparameter of EMA for sensitivity smoothing.
|
34 |
+
beta2 (`float`): The hyperparameter of EMA for undertainty quantification.
|
35 |
+
orth_reg_weight (`float`): The coefficient of orthogonal regularization.
|
36 |
+
total_step (`int`): The total training steps that should be specified before training.
|
37 |
+
rank_pattern (`list`): The allocated rank for each weight matrix by RankAllocator.
|
38 |
+
"""
|
39 |
+
|
40 |
+
target_r: int = field(default=8, metadata={"help": "Target Lora matrix dimension."})
|
41 |
+
init_r: int = field(default=12, metadata={"help": "Initial Lora matrix dimension."})
|
42 |
+
tinit: int = field(default=0, metadata={"help": "The steps of initial warmup."})
|
43 |
+
tfinal: int = field(default=0, metadata={"help": "The steps of final warmup."})
|
44 |
+
deltaT: int = field(default=1, metadata={"help": "Step interval of rank allocation."})
|
45 |
+
beta1: float = field(default=0.85, metadata={"help": "Hyperparameter of EMA."})
|
46 |
+
beta2: float = field(default=0.85, metadata={"help": "Hyperparameter of EMA."})
|
47 |
+
orth_reg_weight: float = field(default=0.5, metadata={"help": "The orthogonal regularization coefficient."})
|
48 |
+
total_step: Optional[int] = field(default=None, metadata={"help": "The total training steps."})
|
49 |
+
rank_pattern: Optional[dict] = field(default=None, metadata={"help": "The saved rank pattern."})
|
50 |
+
|
51 |
+
def __post_init__(self):
|
52 |
+
self.peft_type = PeftType.ADALORA
|
MoRA/peft_mora/tuners/adalora/gptq.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
|
16 |
+
from .layer import AdaLoraLayer
|
17 |
+
|
18 |
+
|
19 |
+
class SVDQuantLinear(torch.nn.Module, AdaLoraLayer):
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
base_layer,
|
23 |
+
adapter_name,
|
24 |
+
r: int = 0,
|
25 |
+
lora_alpha: int = 1,
|
26 |
+
lora_dropout: float = 0.0,
|
27 |
+
init_lora_weights: bool = True,
|
28 |
+
**kwargs,
|
29 |
+
) -> None:
|
30 |
+
super().__init__()
|
31 |
+
AdaLoraLayer.__init__(self, base_layer)
|
32 |
+
|
33 |
+
# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
|
34 |
+
# for backwards compatibility
|
35 |
+
self.quant_linear_module = base_layer
|
36 |
+
self._active_adapter = adapter_name
|
37 |
+
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
|
38 |
+
|
39 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
40 |
+
result = self.quant_linear_module(x)
|
41 |
+
|
42 |
+
if self.disable_adapters:
|
43 |
+
return result
|
44 |
+
|
45 |
+
for active_adapter in self.active_adapters:
|
46 |
+
if active_adapter not in self.lora_A.keys():
|
47 |
+
continue
|
48 |
+
lora_A = self.lora_A[active_adapter]
|
49 |
+
lora_B = self.lora_B[active_adapter]
|
50 |
+
lora_E = self.lora_E[active_adapter]
|
51 |
+
dropout = self.lora_dropout[active_adapter]
|
52 |
+
scaling = self.scaling[active_adapter]
|
53 |
+
ranknum = self.ranknum[active_adapter] + 1e-5
|
54 |
+
|
55 |
+
requires_conversion = not torch.is_autocast_enabled()
|
56 |
+
if requires_conversion:
|
57 |
+
expected_dtype = result.dtype
|
58 |
+
if x.dtype != torch.float32:
|
59 |
+
x = x.float()
|
60 |
+
|
61 |
+
output = (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum
|
62 |
+
# TODO: here, the dtype conversion is applied on the *whole expression*,
|
63 |
+
# not the intermediate result, unlike for SVDLinear8bitLT and
|
64 |
+
# SVDLinear4bit, is that correct?
|
65 |
+
if requires_conversion:
|
66 |
+
output = output.to(expected_dtype)
|
67 |
+
result += output
|
68 |
+
return result
|
69 |
+
|
70 |
+
def __repr__(self) -> str:
|
71 |
+
rep = super().__repr__()
|
72 |
+
return "adalora." + rep
|
MoRA/peft_mora/tuners/adalora/layer.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import warnings
|
16 |
+
from typing import Any, List, Optional
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from peft_mora.tuners.lora import LoraLayer
|
22 |
+
from peft_mora.tuners.tuners_utils import check_adapters_to_merge
|
23 |
+
from peft_mora.utils import transpose
|
24 |
+
|
25 |
+
|
26 |
+
class AdaLoraLayer(LoraLayer):
|
27 |
+
# List all names of layers that may contain adapter weights
|
28 |
+
# Note: ranknum doesn't need to be included as it is not an nn.Module
|
29 |
+
adapter_layer_names = ("lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B")
|
30 |
+
# other_param_names is defined in LoraLayer
|
31 |
+
|
32 |
+
def __init__(self, base_layer: nn.Module) -> None:
|
33 |
+
super().__init__(base_layer)
|
34 |
+
self.lora_E = nn.ParameterDict({})
|
35 |
+
self.lora_A = nn.ParameterDict({})
|
36 |
+
self.lora_B = nn.ParameterDict({})
|
37 |
+
self.ranknum = nn.ParameterDict({})
|
38 |
+
|
39 |
+
def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
|
40 |
+
if r <= 0:
|
41 |
+
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
|
42 |
+
|
43 |
+
self.r[adapter_name] = r
|
44 |
+
self.lora_alpha[adapter_name] = lora_alpha
|
45 |
+
if lora_dropout > 0.0:
|
46 |
+
lora_dropout_layer = nn.Dropout(p=lora_dropout)
|
47 |
+
else:
|
48 |
+
lora_dropout_layer = nn.Identity()
|
49 |
+
|
50 |
+
self.lora_dropout[adapter_name] = lora_dropout_layer
|
51 |
+
# Actual trainable parameters
|
52 |
+
# Right singular vectors
|
53 |
+
self.lora_A[adapter_name] = nn.Parameter(torch.randn(r, self.in_features))
|
54 |
+
# Singular values
|
55 |
+
self.lora_E[adapter_name] = nn.Parameter(torch.randn(r, 1))
|
56 |
+
# Left singular vectors
|
57 |
+
self.lora_B[adapter_name] = nn.Parameter(torch.randn(self.out_features, r))
|
58 |
+
# The current rank
|
59 |
+
self.ranknum[adapter_name] = nn.Parameter(torch.randn(1), requires_grad=False)
|
60 |
+
self.ranknum[adapter_name].data.fill_(float(r))
|
61 |
+
self.ranknum[adapter_name].requires_grad = False
|
62 |
+
self.scaling[adapter_name] = lora_alpha if lora_alpha > 0 else float(r)
|
63 |
+
if init_lora_weights:
|
64 |
+
self.reset_lora_parameters(adapter_name)
|
65 |
+
|
66 |
+
if hasattr(self.get_base_layer(), "qweight"):
|
67 |
+
# QuantLinear
|
68 |
+
self.to(self.get_base_layer().qweight.device)
|
69 |
+
else:
|
70 |
+
self.to(self.get_base_layer().weight.device)
|
71 |
+
self.set_adapter(self.active_adapters)
|
72 |
+
|
73 |
+
def reset_lora_parameters(self, adapter_name):
|
74 |
+
if adapter_name in self.lora_A.keys():
|
75 |
+
nn.init.normal_(self.lora_E[adapter_name], mean=0.0, std=0.02)
|
76 |
+
nn.init.normal_(self.lora_A[adapter_name], mean=0.0, std=0.02)
|
77 |
+
nn.init.normal_(self.lora_B[adapter_name], mean=0.0, std=0.02)
|
78 |
+
|
79 |
+
|
80 |
+
class SVDLinear(nn.Module, AdaLoraLayer):
|
81 |
+
# SVD-based adaptation by a dense layer
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
base_layer: nn.Module,
|
85 |
+
adapter_name: str,
|
86 |
+
r: int = 0,
|
87 |
+
lora_alpha: int = 1,
|
88 |
+
lora_dropout: float = 0.0,
|
89 |
+
fan_in_fan_out: bool = False,
|
90 |
+
init_lora_weights: bool = True,
|
91 |
+
**kwargs,
|
92 |
+
) -> None:
|
93 |
+
super().__init__()
|
94 |
+
AdaLoraLayer.__init__(self, base_layer)
|
95 |
+
# Freezing the pre-trained weight matrix
|
96 |
+
self.get_base_layer().weight.requires_grad = False
|
97 |
+
|
98 |
+
self.fan_in_fan_out = fan_in_fan_out
|
99 |
+
self._active_adapter = adapter_name
|
100 |
+
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
|
101 |
+
|
102 |
+
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
|
103 |
+
"""
|
104 |
+
Merge the active adapter weights into the base weights
|
105 |
+
|
106 |
+
Args:
|
107 |
+
safe_merge (`bool`, *optional*):
|
108 |
+
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
|
109 |
+
before merging the weights. This is useful if you want to check if the merge operation will produce
|
110 |
+
NaNs. Defaults to `False`.
|
111 |
+
adapter_names (`List[str]`, *optional*):
|
112 |
+
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
|
113 |
+
to `None`.
|
114 |
+
"""
|
115 |
+
adapter_names = check_adapters_to_merge(self, adapter_names)
|
116 |
+
if not adapter_names:
|
117 |
+
# no adapter to merge
|
118 |
+
return
|
119 |
+
|
120 |
+
for active_adapter in adapter_names:
|
121 |
+
base_layer = self.get_base_layer()
|
122 |
+
if active_adapter in self.lora_A.keys():
|
123 |
+
if safe_merge:
|
124 |
+
# Note that safe_merge will be slower than the normal merge
|
125 |
+
# because of the copy operation.
|
126 |
+
orig_weights = base_layer.weight.data.clone()
|
127 |
+
orig_weights += self.get_delta_weight(active_adapter)
|
128 |
+
|
129 |
+
if not torch.isfinite(orig_weights).all():
|
130 |
+
raise ValueError(
|
131 |
+
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
132 |
+
)
|
133 |
+
|
134 |
+
base_layer.weight.data = orig_weights
|
135 |
+
else:
|
136 |
+
base_layer.weight.data += self.get_delta_weight(active_adapter)
|
137 |
+
self.merged_adapters.append(active_adapter)
|
138 |
+
|
139 |
+
def unmerge(self) -> None:
|
140 |
+
"""
|
141 |
+
This method unmerges all merged adapter layers from the base weights.
|
142 |
+
"""
|
143 |
+
if not self.merged:
|
144 |
+
warnings.warn("Already unmerged. Nothing to do.")
|
145 |
+
return
|
146 |
+
while len(self.merged_adapters) > 0:
|
147 |
+
active_adapter = self.merged_adapters.pop()
|
148 |
+
if active_adapter in self.lora_A.keys():
|
149 |
+
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
|
150 |
+
|
151 |
+
def get_delta_weight(self, adapter) -> torch.Tensor:
|
152 |
+
return (
|
153 |
+
transpose(self.lora_B[adapter] @ (self.lora_A[adapter] * self.lora_E[adapter]), self.fan_in_fan_out)
|
154 |
+
* self.scaling[adapter]
|
155 |
+
/ (self.ranknum[adapter] + 1e-5)
|
156 |
+
)
|
157 |
+
|
158 |
+
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
|
159 |
+
if self.disable_adapters:
|
160 |
+
if self.merged:
|
161 |
+
self.unmerge()
|
162 |
+
result = self.base_layer(x, *args, **kwargs)
|
163 |
+
elif self.merged:
|
164 |
+
result = self.base_layer(x, *args, **kwargs)
|
165 |
+
else:
|
166 |
+
result = self.base_layer(x, *args, **kwargs)
|
167 |
+
for active_adapter in self.active_adapters:
|
168 |
+
if active_adapter not in self.lora_A.keys():
|
169 |
+
continue
|
170 |
+
lora_A = self.lora_A[active_adapter]
|
171 |
+
lora_B = self.lora_B[active_adapter]
|
172 |
+
lora_E = self.lora_E[active_adapter]
|
173 |
+
dropout = self.lora_dropout[active_adapter]
|
174 |
+
scaling = self.scaling[active_adapter]
|
175 |
+
ranknum = self.ranknum[active_adapter] + 1e-5
|
176 |
+
|
177 |
+
x = x.to(lora_A.dtype)
|
178 |
+
result += (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum
|
179 |
+
|
180 |
+
return result
|
181 |
+
|
182 |
+
def __repr__(self) -> str:
|
183 |
+
rep = super().__repr__()
|
184 |
+
return "adalora." + rep
|
185 |
+
|
186 |
+
|
187 |
+
class RankAllocator:
|
188 |
+
"""
|
189 |
+
The RankAllocator for AdaLoraModel. Paper: https://openreview.net/pdf?id=lq62uWRJjiY
|
190 |
+
|
191 |
+
Args:
|
192 |
+
config ([`AdaLoraConfig`]): The configuration of the AdaLora model.
|
193 |
+
model: the model that we apply AdaLoRA to.
|
194 |
+
|
195 |
+
"""
|
196 |
+
|
197 |
+
def __init__(self, model, peft_config, adapter_name):
|
198 |
+
self.peft_config = peft_config
|
199 |
+
self.adapter_name = adapter_name
|
200 |
+
self.beta1 = peft_config.beta1
|
201 |
+
self.beta2 = peft_config.beta2
|
202 |
+
assert self.beta1 > 0 and self.beta1 < 1
|
203 |
+
assert self.beta2 > 0 and self.beta2 < 1
|
204 |
+
|
205 |
+
self.reset_ipt()
|
206 |
+
self._set_budget_scheduler(model)
|
207 |
+
|
208 |
+
def set_total_step(self, total_step):
|
209 |
+
self.peft_config.total_step = total_step
|
210 |
+
|
211 |
+
def reset_ipt(self):
|
212 |
+
self.ipt = {}
|
213 |
+
self.exp_avg_ipt = {}
|
214 |
+
self.exp_avg_unc = {}
|
215 |
+
|
216 |
+
def _set_budget_scheduler(self, model):
|
217 |
+
self.init_bgt = 0
|
218 |
+
self.name_set = set()
|
219 |
+
for n, p in model.named_parameters():
|
220 |
+
if f"lora_A.{self.adapter_name}" in n:
|
221 |
+
self.init_bgt += p.size(0)
|
222 |
+
self.name_set.add(n.replace("lora_A", "%s"))
|
223 |
+
self.name_set = sorted(self.name_set)
|
224 |
+
# The total final rank budget
|
225 |
+
self.target_bgt = self.peft_config.target_r * len(self.name_set)
|
226 |
+
|
227 |
+
def budget_schedule(self, step: int):
|
228 |
+
tinit = self.peft_config.tinit
|
229 |
+
tfinal = self.peft_config.tfinal
|
230 |
+
total_step = self.peft_config.total_step
|
231 |
+
# Initial warmup
|
232 |
+
if step <= tinit:
|
233 |
+
budget = self.init_bgt
|
234 |
+
mask_ind = False
|
235 |
+
# Final fine-tuning
|
236 |
+
elif step > total_step - tfinal:
|
237 |
+
budget = self.target_bgt
|
238 |
+
mask_ind = True
|
239 |
+
else:
|
240 |
+
# Budget decreasing with a cubic scheduler
|
241 |
+
mul_coeff = 1 - (step - tinit) / (total_step - tfinal - tinit)
|
242 |
+
budget = int((self.init_bgt - self.target_bgt) * (mul_coeff**3) + self.target_bgt)
|
243 |
+
mask_ind = True if step % self.peft_config.deltaT == 0 else False
|
244 |
+
return budget, mask_ind
|
245 |
+
|
246 |
+
def update_ipt(self, model):
|
247 |
+
# Update the sensitivity and uncertainty for every weight
|
248 |
+
for n, p in model.named_parameters():
|
249 |
+
if "lora_" in n and self.adapter_name in n:
|
250 |
+
if n not in self.ipt:
|
251 |
+
self.ipt[n] = torch.zeros_like(p)
|
252 |
+
self.exp_avg_ipt[n] = torch.zeros_like(p)
|
253 |
+
self.exp_avg_unc[n] = torch.zeros_like(p)
|
254 |
+
with torch.no_grad():
|
255 |
+
self.ipt[n] = (p * p.grad).abs().detach()
|
256 |
+
# Sensitivity smoothing
|
257 |
+
self.exp_avg_ipt[n] = self.beta1 * self.exp_avg_ipt[n] + (1 - self.beta1) * self.ipt[n]
|
258 |
+
# Uncertainty quantification
|
259 |
+
self.exp_avg_unc[n] = (
|
260 |
+
self.beta2 * self.exp_avg_unc[n] + (1 - self.beta2) * (self.ipt[n] - self.exp_avg_ipt[n]).abs()
|
261 |
+
)
|
262 |
+
|
263 |
+
def _element_score(self, n):
|
264 |
+
return self.exp_avg_ipt[n] * self.exp_avg_unc[n]
|
265 |
+
|
266 |
+
def _combine_ipt(self, ipt_E, ipt_AB):
|
267 |
+
ipt_AB = ipt_AB.sum(dim=1, keepdim=False)
|
268 |
+
sum_ipt = ipt_E.view(-1) + ipt_AB.view(-1)
|
269 |
+
return sum_ipt
|
270 |
+
|
271 |
+
def mask_to_budget(self, model, budget):
|
272 |
+
value_ipt = {}
|
273 |
+
vector_ipt = {}
|
274 |
+
triplet_ipt = {}
|
275 |
+
# Get the importance score for A, E, B
|
276 |
+
for n, p in model.named_parameters():
|
277 |
+
if f"lora_A.{self.adapter_name}" in n:
|
278 |
+
entry_ipt = self._element_score(n)
|
279 |
+
comb_ipt = torch.mean(entry_ipt, dim=1, keepdim=True)
|
280 |
+
name_m = n.replace("lora_A", "%s")
|
281 |
+
if name_m not in vector_ipt:
|
282 |
+
vector_ipt[name_m] = [comb_ipt]
|
283 |
+
else:
|
284 |
+
vector_ipt[name_m].append(comb_ipt)
|
285 |
+
if f"lora_B.{self.adapter_name}" in n:
|
286 |
+
entry_ipt = self._element_score(n)
|
287 |
+
comb_ipt = torch.mean(entry_ipt, dim=0, keepdim=False).view(-1, 1)
|
288 |
+
name_m = n.replace("lora_B", "%s")
|
289 |
+
if name_m not in vector_ipt:
|
290 |
+
vector_ipt[name_m] = [comb_ipt]
|
291 |
+
else:
|
292 |
+
vector_ipt[name_m].append(comb_ipt)
|
293 |
+
if f"lora_E.{self.adapter_name}" in n:
|
294 |
+
entry_ipt = self._element_score(n)
|
295 |
+
name_m = n.replace("lora_E", "%s")
|
296 |
+
value_ipt[name_m] = entry_ipt
|
297 |
+
|
298 |
+
all_score = []
|
299 |
+
# Calculate the score for each triplet
|
300 |
+
for name_m in vector_ipt:
|
301 |
+
ipt_E = value_ipt[name_m]
|
302 |
+
ipt_AB = torch.cat(vector_ipt[name_m], dim=1)
|
303 |
+
sum_ipt = self._combine_ipt(ipt_E, ipt_AB)
|
304 |
+
name_E = name_m % "lora_E"
|
305 |
+
triplet_ipt[name_E] = sum_ipt.view(-1, 1)
|
306 |
+
all_score.append(sum_ipt.view(-1))
|
307 |
+
|
308 |
+
# Get the threshold by ranking ipt
|
309 |
+
mask_threshold = torch.kthvalue(
|
310 |
+
torch.cat(all_score),
|
311 |
+
k=self.init_bgt - budget,
|
312 |
+
)[0].item()
|
313 |
+
|
314 |
+
rank_pattern = {}
|
315 |
+
# Mask the unimportant triplets
|
316 |
+
with torch.no_grad():
|
317 |
+
for n, p in model.named_parameters():
|
318 |
+
if f"lora_E.{self.adapter_name}" in n:
|
319 |
+
p.masked_fill_(triplet_ipt[n] <= mask_threshold, 0.0)
|
320 |
+
rank_pattern[n] = (~(triplet_ipt[n] <= mask_threshold)).view(-1).tolist()
|
321 |
+
return rank_pattern
|
322 |
+
|
323 |
+
def update_and_allocate(self, model, global_step, force_mask=False):
|
324 |
+
# # Update the importance score and allocate the budget
|
325 |
+
if global_step < self.peft_config.total_step - self.peft_config.tfinal:
|
326 |
+
self.update_ipt(model)
|
327 |
+
budget, mask_ind = self.budget_schedule(global_step)
|
328 |
+
# Allocate the budget according to importance scores
|
329 |
+
if mask_ind or force_mask:
|
330 |
+
rank_pattern = self.mask_to_budget(model, budget)
|
331 |
+
else:
|
332 |
+
rank_pattern = None
|
333 |
+
return budget, rank_pattern
|
334 |
+
|
335 |
+
def mask_using_rank_pattern(self, model, rank_pattern):
|
336 |
+
# Mask the unimportant triplets
|
337 |
+
is_adapter_name_truncated = False
|
338 |
+
if self.adapter_name not in next(iter(rank_pattern.keys())):
|
339 |
+
is_adapter_name_truncated = True
|
340 |
+
|
341 |
+
with torch.no_grad():
|
342 |
+
for n, p in model.named_parameters():
|
343 |
+
if f"lora_E.{self.adapter_name}" in n:
|
344 |
+
key = n if not is_adapter_name_truncated else n.replace(f".{self.adapter_name}", "")
|
345 |
+
mask = torch.Tensor(rank_pattern[key]).unsqueeze(-1).to(p.device)
|
346 |
+
p.masked_fill_(~mask.bool(), 0.0)
|
MoRA/peft_mora/tuners/adalora/model.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import warnings
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from transformers.pytorch_utils import Conv1D
|
19 |
+
|
20 |
+
from peft_mora.import_utils import is_bnb_4bit_available, is_bnb_available
|
21 |
+
from peft_mora.tuners.lora import LoraConfig, LoraModel
|
22 |
+
from peft_mora.tuners.tuners_utils import BaseTunerLayer
|
23 |
+
from peft_mora.utils import (
|
24 |
+
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
|
25 |
+
_freeze_adapter,
|
26 |
+
_get_submodules,
|
27 |
+
get_auto_gptq_quant_linear,
|
28 |
+
get_quantization_config,
|
29 |
+
)
|
30 |
+
|
31 |
+
from .gptq import SVDQuantLinear
|
32 |
+
from .layer import AdaLoraLayer, RankAllocator, SVDLinear
|
33 |
+
|
34 |
+
|
35 |
+
class AdaLoraModel(LoraModel):
|
36 |
+
"""
|
37 |
+
Creates AdaLoRA (Adaptive LoRA) model from a pretrained transformers model. Paper:
|
38 |
+
https://openreview.net/forum?id=lq62uWRJjiY
|
39 |
+
|
40 |
+
Args:
|
41 |
+
model ([`transformers.PreTrainedModel`]): The model to be adapted.
|
42 |
+
config ([`AdaLoraConfig`]): The configuration of the AdaLora model.
|
43 |
+
adapter_name (`str`): The name of the adapter, defaults to `"default"`.
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
`torch.nn.Module`: The AdaLora model.
|
47 |
+
|
48 |
+
Example::
|
49 |
+
|
50 |
+
>>> from transformers import AutoModelForSeq2SeqLM, LoraConfig >>> from peft import AdaLoraModel, AdaLoraConfig
|
51 |
+
>>> config = AdaLoraConfig(
|
52 |
+
peft_type="ADALORA", task_type="SEQ_2_SEQ_LM", r=8, lora_alpha=32, target_modules=["q", "v"],
|
53 |
+
lora_dropout=0.01,
|
54 |
+
)
|
55 |
+
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> model = AdaLoraModel(model, config, "default")
|
56 |
+
|
57 |
+
**Attributes**:
|
58 |
+
- **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted.
|
59 |
+
- **peft_config** ([`AdaLoraConfig`]): The configuration of the AdaLora model.
|
60 |
+
"""
|
61 |
+
|
62 |
+
# Note: don't redefine prefix here, it should be inherited from LoraModel
|
63 |
+
|
64 |
+
def __init__(self, model, config, adapter_name):
|
65 |
+
super().__init__(model, config, adapter_name)
|
66 |
+
|
67 |
+
traininable_mode_counter = 0
|
68 |
+
for config in self.peft_config.values():
|
69 |
+
if not config.inference_mode:
|
70 |
+
traininable_mode_counter += 1
|
71 |
+
|
72 |
+
if traininable_mode_counter > 1:
|
73 |
+
raise ValueError(
|
74 |
+
"AdaLoraModel supports only 1 trainable adapter. "
|
75 |
+
"When using multiple adapters, set inference_mode to True for all adapters except the one you want to train."
|
76 |
+
)
|
77 |
+
|
78 |
+
if self.peft_config[adapter_name].inference_mode:
|
79 |
+
_freeze_adapter(self.model, adapter_name)
|
80 |
+
else:
|
81 |
+
self.trainable_adapter_name = adapter_name
|
82 |
+
self.rankallocator = RankAllocator(self.model, self.peft_config[adapter_name], self.trainable_adapter_name)
|
83 |
+
|
84 |
+
def _check_new_adapter_config(self, config: LoraConfig) -> None:
|
85 |
+
"""
|
86 |
+
A helper method to check the config when a new adapter is being added.
|
87 |
+
|
88 |
+
Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.
|
89 |
+
|
90 |
+
"""
|
91 |
+
super()._check_new_adapter_config(config)
|
92 |
+
|
93 |
+
traininable_mode_counter = 0
|
94 |
+
for config_ in self.peft_config.values():
|
95 |
+
if not config_.inference_mode:
|
96 |
+
traininable_mode_counter += 1
|
97 |
+
|
98 |
+
if traininable_mode_counter > 1:
|
99 |
+
raise ValueError(
|
100 |
+
f"{self.__class__.__name__} supports only 1 trainable adapter. "
|
101 |
+
"When using multiple adapters, set inference_mode to True for all adapters except the one "
|
102 |
+
"you want to train."
|
103 |
+
)
|
104 |
+
|
105 |
+
def _create_and_replace(
|
106 |
+
self,
|
107 |
+
lora_config,
|
108 |
+
adapter_name,
|
109 |
+
target,
|
110 |
+
target_name,
|
111 |
+
parent,
|
112 |
+
current_key,
|
113 |
+
):
|
114 |
+
kwargs = {
|
115 |
+
"r": lora_config.init_r,
|
116 |
+
"lora_alpha": lora_config.lora_alpha,
|
117 |
+
"lora_dropout": lora_config.lora_dropout,
|
118 |
+
"fan_in_fan_out": lora_config.fan_in_fan_out,
|
119 |
+
"init_lora_weights": lora_config.init_lora_weights,
|
120 |
+
"loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
|
121 |
+
"loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
|
122 |
+
}
|
123 |
+
if (kwargs["loaded_in_8bit"] or kwargs["loaded_in_4bit"]) and not is_bnb_available():
|
124 |
+
raise ImportError(
|
125 |
+
"To use AdaLora with 8-bit quantization, please install the `bitsandbytes` package. "
|
126 |
+
"You can install it with `pip install bitsandbytes`."
|
127 |
+
)
|
128 |
+
|
129 |
+
quantization_config = get_quantization_config(self.model, method="gptq")
|
130 |
+
if quantization_config is not None:
|
131 |
+
kwargs["gptq_quantization_config"] = quantization_config
|
132 |
+
|
133 |
+
# If it is not an AdaLoraLayer, create a new module, else update it with new adapters
|
134 |
+
if not isinstance(target, AdaLoraLayer):
|
135 |
+
new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
|
136 |
+
if adapter_name != self.active_adapter:
|
137 |
+
# adding an additional adapter: it is not automatically trainable
|
138 |
+
new_module.requires_grad_(False)
|
139 |
+
self._replace_module(parent, target_name, new_module, target)
|
140 |
+
else:
|
141 |
+
target.update_layer(
|
142 |
+
adapter_name,
|
143 |
+
lora_config.init_r,
|
144 |
+
lora_config.lora_alpha,
|
145 |
+
lora_config.lora_dropout,
|
146 |
+
lora_config.init_lora_weights,
|
147 |
+
)
|
148 |
+
|
149 |
+
@staticmethod
|
150 |
+
def _create_new_module(lora_config, adapter_name, target, **kwargs):
|
151 |
+
# avoid eager bnb import
|
152 |
+
if is_bnb_available():
|
153 |
+
import bitsandbytes as bnb
|
154 |
+
|
155 |
+
from .bnb import SVDLinear8bitLt
|
156 |
+
if is_bnb_4bit_available():
|
157 |
+
from .bnb import SVDLinear4bit
|
158 |
+
|
159 |
+
gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
|
160 |
+
AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
|
161 |
+
|
162 |
+
loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
|
163 |
+
loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)
|
164 |
+
|
165 |
+
if isinstance(target, BaseTunerLayer):
|
166 |
+
target_base_layer = target.get_base_layer()
|
167 |
+
else:
|
168 |
+
target_base_layer = target
|
169 |
+
|
170 |
+
if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
|
171 |
+
kwargs.update(
|
172 |
+
{
|
173 |
+
"has_fp16_weights": target_base_layer.state.has_fp16_weights,
|
174 |
+
"memory_efficient_backward": target_base_layer.state.memory_efficient_backward,
|
175 |
+
"threshold": target_base_layer.state.threshold,
|
176 |
+
"index": target_base_layer.index,
|
177 |
+
}
|
178 |
+
)
|
179 |
+
new_module = SVDLinear8bitLt(target, adapter_name, **kwargs)
|
180 |
+
elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
|
181 |
+
fourbit_kwargs = kwargs.copy()
|
182 |
+
fourbit_kwargs.update(
|
183 |
+
{
|
184 |
+
"compute_dtype": target_base_layer.compute_dtype,
|
185 |
+
"compress_statistics": target_base_layer.weight.compress_statistics,
|
186 |
+
"quant_type": target_base_layer.weight.quant_type,
|
187 |
+
}
|
188 |
+
)
|
189 |
+
new_module = SVDLinear4bit(target, adapter_name, **fourbit_kwargs)
|
190 |
+
elif AutoGPTQQuantLinear is not None and isinstance(target, AutoGPTQQuantLinear):
|
191 |
+
new_module = SVDQuantLinear(target, adapter_name, **kwargs)
|
192 |
+
else:
|
193 |
+
if isinstance(target_base_layer, torch.nn.Linear):
|
194 |
+
if kwargs["fan_in_fan_out"]:
|
195 |
+
warnings.warn(
|
196 |
+
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
|
197 |
+
"Setting fan_in_fan_out to False."
|
198 |
+
)
|
199 |
+
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
|
200 |
+
elif isinstance(target_base_layer, Conv1D):
|
201 |
+
if not kwargs["fan_in_fan_out"]:
|
202 |
+
warnings.warn(
|
203 |
+
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
|
204 |
+
"Setting fan_in_fan_out to True."
|
205 |
+
)
|
206 |
+
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
|
207 |
+
else:
|
208 |
+
raise ValueError(
|
209 |
+
f"Target module {target} is not supported. "
|
210 |
+
f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
|
211 |
+
)
|
212 |
+
new_module = SVDLinear(target, adapter_name, **kwargs)
|
213 |
+
|
214 |
+
return new_module
|
215 |
+
|
216 |
+
@staticmethod
|
217 |
+
def _prepare_adapter_config(peft_config, model_config):
|
218 |
+
if peft_config.target_modules is None:
|
219 |
+
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING:
|
220 |
+
raise ValueError("Please specify `target_modules` in `peft_config`")
|
221 |
+
peft_config.target_modules = TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING[
|
222 |
+
model_config["model_type"]
|
223 |
+
]
|
224 |
+
return peft_config
|
225 |
+
|
226 |
+
def __getattr__(self, name: str):
|
227 |
+
"""Forward missing attributes to the wrapped module."""
|
228 |
+
try:
|
229 |
+
return super().__getattr__(name) # defer to nn.Module's logic
|
230 |
+
except AttributeError:
|
231 |
+
return getattr(self.model, name)
|
232 |
+
|
233 |
+
def forward(self, *args, **kwargs):
|
234 |
+
outputs = self.model.forward(*args, **kwargs)
|
235 |
+
|
236 |
+
if (getattr(outputs, "loss", None) is not None) and isinstance(outputs.loss, torch.Tensor):
|
237 |
+
# Calculate the orthogonal regularization
|
238 |
+
orth_reg_weight = self.peft_config[self.trainable_adapter_name].orth_reg_weight
|
239 |
+
|
240 |
+
if orth_reg_weight <= 0:
|
241 |
+
raise ValueError("orth_reg_weight should be greater than 0. ")
|
242 |
+
|
243 |
+
regu_loss = 0
|
244 |
+
num_param = 0
|
245 |
+
for n, p in self.model.named_parameters():
|
246 |
+
if ("lora_A" in n or "lora_B" in n) and self.trainable_adapter_name in n:
|
247 |
+
para_cov = p @ p.T if "lora_A" in n else p.T @ p
|
248 |
+
I = torch.eye(*para_cov.size(), out=torch.empty_like(para_cov)) # noqa: E741
|
249 |
+
I.requires_grad = False
|
250 |
+
num_param += 1
|
251 |
+
regu_loss += torch.norm(para_cov - I, p="fro")
|
252 |
+
if num_param > 0:
|
253 |
+
regu_loss = regu_loss / num_param
|
254 |
+
else:
|
255 |
+
regu_loss = 0
|
256 |
+
outputs.loss += orth_reg_weight * regu_loss
|
257 |
+
return outputs
|
258 |
+
|
259 |
+
def resize_modules_by_rank_pattern(self, rank_pattern, adapter_name):
|
260 |
+
lora_config = self.peft_config[adapter_name]
|
261 |
+
for name, rank_idx in rank_pattern.items():
|
262 |
+
if isinstance(rank_idx, list):
|
263 |
+
rank = sum(rank_idx)
|
264 |
+
elif isinstance(rank_idx, torch.Tensor):
|
265 |
+
rank_idx = rank_idx.view(-1)
|
266 |
+
rank = rank_idx.sum().item()
|
267 |
+
else:
|
268 |
+
raise ValueError("Unexpected type of rank_idx")
|
269 |
+
key = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1])
|
270 |
+
_, target, _ = _get_submodules(self.model, key)
|
271 |
+
lora_E_weights = target.lora_E[adapter_name][rank_idx]
|
272 |
+
lora_A_weights = target.lora_A[adapter_name][rank_idx]
|
273 |
+
lora_B_weights = target.lora_B[adapter_name][:, rank_idx]
|
274 |
+
ranknum = target.ranknum[adapter_name]
|
275 |
+
target.update_layer(
|
276 |
+
adapter_name,
|
277 |
+
rank,
|
278 |
+
lora_config.lora_alpha,
|
279 |
+
lora_config.lora_dropout,
|
280 |
+
lora_config.init_lora_weights,
|
281 |
+
)
|
282 |
+
with torch.no_grad():
|
283 |
+
if rank > 0:
|
284 |
+
target.lora_E[adapter_name].copy_(lora_E_weights)
|
285 |
+
target.lora_A[adapter_name].copy_(lora_A_weights)
|
286 |
+
target.lora_B[adapter_name].copy_(lora_B_weights)
|
287 |
+
# The scaling is exactly as the previous
|
288 |
+
target.ranknum[adapter_name].copy_(ranknum)
|
289 |
+
|
290 |
+
def resize_state_dict_by_rank_pattern(self, rank_pattern, state_dict, adapter_name):
|
291 |
+
for name, rank_idx in rank_pattern.items():
|
292 |
+
rank = sum(rank_idx)
|
293 |
+
prefix = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1])
|
294 |
+
for layer in ["lora_E", "lora_A", "lora_B"]:
|
295 |
+
key = f"base_model.model.{prefix}.{layer}.{adapter_name}"
|
296 |
+
if layer != "lora_B":
|
297 |
+
state_dict[key] = (
|
298 |
+
state_dict[key][rank_idx] if rank != state_dict[key].shape[0] else state_dict[key]
|
299 |
+
)
|
300 |
+
else:
|
301 |
+
state_dict[key] = (
|
302 |
+
state_dict[key][:, rank_idx] if rank != state_dict[key].shape[1] else state_dict[key]
|
303 |
+
)
|
304 |
+
return state_dict
|
305 |
+
|
306 |
+
def update_and_allocate(self, global_step):
|
307 |
+
"""
|
308 |
+
This method updates Adalora budget and mask.
|
309 |
+
|
310 |
+
This should be called in every training step after `loss.backward()` and before `zero_grad()`.
|
311 |
+
|
312 |
+
`tinit`, `tfinal` and `deltaT` are handled with in the method.
|
313 |
+
|
314 |
+
Args:
|
315 |
+
global_step (`int`): The current training step, it is used to calculate adalora budget.
|
316 |
+
|
317 |
+
Example:
|
318 |
+
|
319 |
+
```python
|
320 |
+
>>> loss = model(**input).loss
|
321 |
+
>>> loss.backward()
|
322 |
+
>>> optimizer.step()
|
323 |
+
>>> model.base_model.update_and_allocate(i_step)
|
324 |
+
>>> optimizer.zero_grad()
|
325 |
+
```
|
326 |
+
"""
|
327 |
+
lora_config = self.peft_config[self.trainable_adapter_name]
|
328 |
+
# Update the importance score and allocate the budget
|
329 |
+
if global_step < lora_config.total_step - lora_config.tfinal:
|
330 |
+
_, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step)
|
331 |
+
if rank_pattern:
|
332 |
+
lora_config.rank_pattern = rank_pattern
|
333 |
+
# Finalize the budget allocation
|
334 |
+
elif global_step == lora_config.total_step - lora_config.tfinal:
|
335 |
+
_, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step, force_mask=True)
|
336 |
+
# for some reason, this freezes the trainable parameters and nothing gets updates
|
337 |
+
# self.resize_modules_by_rank_pattern(rank_pattern, self.trainable_adapter_name)
|
338 |
+
lora_config.rank_pattern = rank_pattern
|
339 |
+
self.rankallocator.reset_ipt()
|
340 |
+
# Currently using inefficient way to mask the unimportant weights using the rank pattern
|
341 |
+
# due to problem mentioned above
|
342 |
+
elif global_step > lora_config.total_step - lora_config.tfinal:
|
343 |
+
self.rankallocator.mask_using_rank_pattern(self.model, lora_config.rank_pattern)
|
344 |
+
# Pass the function and do forward propagation
|
345 |
+
else:
|
346 |
+
return None
|
MoRA/peft_mora/tuners/adaption_prompt/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from .config import AdaptionPromptConfig
|
15 |
+
from .layer import AdaptedAttention
|
16 |
+
from .model import AdaptionPromptModel
|
17 |
+
|
18 |
+
|
19 |
+
__all__ = ["AdaptionPromptConfig", "AdaptedAttention", "AdaptionPromptModel"]
|
MoRA/peft_mora/tuners/adaption_prompt/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (390 Bytes). View file
|
|
MoRA/peft_mora/tuners/adaption_prompt/__pycache__/config.cpython-312.pyc
ADDED
Binary file (2.57 kB). View file
|
|
MoRA/peft_mora/tuners/adaption_prompt/__pycache__/layer.cpython-312.pyc
ADDED
Binary file (5.89 kB). View file
|
|
MoRA/peft_mora/tuners/adaption_prompt/__pycache__/model.cpython-312.pyc
ADDED
Binary file (8.35 kB). View file
|
|
MoRA/peft_mora/tuners/adaption_prompt/__pycache__/utils.cpython-312.pyc
ADDED
Binary file (5.71 kB). View file
|
|
MoRA/peft_mora/tuners/adaption_prompt/config.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from collections import namedtuple
|
16 |
+
from dataclasses import dataclass, field
|
17 |
+
|
18 |
+
from peft_mora.config import PeftConfig
|
19 |
+
from peft_mora.utils import PeftType
|
20 |
+
|
21 |
+
from .utils import llama_compute_query_states
|
22 |
+
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class AdaptionPromptConfig(PeftConfig):
|
26 |
+
"""Stores the configuration of an [`AdaptionPromptModel`]."""
|
27 |
+
|
28 |
+
target_modules: str = field(
|
29 |
+
default=None, metadata={"help": "Name of the attention submodules to insert adaption prompts into."}
|
30 |
+
)
|
31 |
+
adapter_len: int = field(default=None, metadata={"help": "Number of adapter tokens to insert"})
|
32 |
+
adapter_layers: int = field(default=None, metadata={"help": "Number of adapter layers (from the top)"})
|
33 |
+
|
34 |
+
def __post_init__(self):
|
35 |
+
self.peft_type = PeftType.ADAPTION_PROMPT
|
36 |
+
|
37 |
+
@property
|
38 |
+
def is_adaption_prompt(self) -> bool:
|
39 |
+
"""Return True if this is an adaption prompt config."""
|
40 |
+
return True
|
41 |
+
|
42 |
+
|
43 |
+
# Contains the config that is specific to a transformers model type.
|
44 |
+
ModelTypeConfig = namedtuple(
|
45 |
+
"ModelTypeConfig", ["compute_query_states", "target_modules", "k_proj_layer", "v_proj_layer", "o_proj_layer"]
|
46 |
+
)
|
47 |
+
|
48 |
+
# Mapping of transformers model types to their specific configuration.
|
49 |
+
TRANSFORMERS_MODEL_CONFIG = {
|
50 |
+
"llama": ModelTypeConfig(
|
51 |
+
compute_query_states=llama_compute_query_states,
|
52 |
+
target_modules="self_attn",
|
53 |
+
k_proj_layer="k_proj",
|
54 |
+
v_proj_layer="v_proj",
|
55 |
+
o_proj_layer="o_proj",
|
56 |
+
),
|
57 |
+
}
|
58 |
+
|
59 |
+
|
60 |
+
def prepare_config(
|
61 |
+
peft_config: AdaptionPromptConfig,
|
62 |
+
model,
|
63 |
+
) -> AdaptionPromptConfig:
|
64 |
+
"""Prepare the config based on the llama model type."""
|
65 |
+
if model.config.model_type not in TRANSFORMERS_MODEL_CONFIG:
|
66 |
+
raise ValueError("Unsupported model type for adaption prompt: '{model.config.model_type}'.")
|
67 |
+
|
68 |
+
model_config = TRANSFORMERS_MODEL_CONFIG[model.config.model_type]
|
69 |
+
|
70 |
+
if peft_config.target_modules is None:
|
71 |
+
peft_config.target_modules = model_config.target_modules
|
72 |
+
|
73 |
+
return peft_config
|
MoRA/peft_mora/tuners/adaption_prompt/layer.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.nn.functional as F
|
20 |
+
|
21 |
+
from .config import TRANSFORMERS_MODEL_CONFIG
|
22 |
+
|
23 |
+
|
24 |
+
class AdaptedAttention(nn.Module):
|
25 |
+
"""This module wraps a LLamaAttention module and injects adaption prompts."""
|
26 |
+
|
27 |
+
def __init__(self, model_type: str, adapter_len: int, model):
|
28 |
+
"""
|
29 |
+
Initialize object.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
model_type: The transformer model type. This is used to retrieve the right method to
|
33 |
+
compute query states.
|
34 |
+
adapter_len: The length of the adaption prompt to insert.
|
35 |
+
model: The original transformer attention module that is being wrapped.
|
36 |
+
"""
|
37 |
+
assert not isinstance(model, AdaptedAttention)
|
38 |
+
super().__init__()
|
39 |
+
self.model_type = model_type
|
40 |
+
self.model = model
|
41 |
+
self.adapter_len = adapter_len
|
42 |
+
# Assume all parameters of the attention model we are wrapping are on the same device.
|
43 |
+
device = next(model.parameters()).device
|
44 |
+
# Don't think this was specified in the paper, but we follow the official repo which used an Embedding
|
45 |
+
# which initializes the tokens with standard normal values.
|
46 |
+
# https://github.com/ZrrSkywalker/LLaMA-Adapter/blob/41c3546fe1997ab8a65809dc8d8f9252b19d9faf/llama/model.py#L234
|
47 |
+
# (bsz, adapter_len, hidden_size)
|
48 |
+
target_dtype = (
|
49 |
+
model.q_proj.weight.dtype if model.q_proj.weight.dtype not in [torch.int8, torch.uint8] else torch.float32
|
50 |
+
)
|
51 |
+
self.adaption_prompt = nn.Parameter(
|
52 |
+
torch.empty(1, adapter_len, self.model.hidden_size, device=device, dtype=target_dtype).normal_()
|
53 |
+
)
|
54 |
+
# Initialize the gate to 0 as this is "zero-init".
|
55 |
+
self.adaption_gate = nn.Parameter(torch.zeros(1, device=device, dtype=target_dtype))
|
56 |
+
|
57 |
+
def forward(self, **kwargs):
|
58 |
+
"""
|
59 |
+
Forward pass for the adapter which wraps the original LlamaAttention module.
|
60 |
+
|
61 |
+
"Official" paper implementation:
|
62 |
+
https://github.com/ZrrSkywalker/LLaMA-Adapter/blob/41c3546fe1997ab8a65809dc8d8f9252b19d9faf/llama/model.py#L141
|
63 |
+
|
64 |
+
Args:
|
65 |
+
kwargs: See the original LlamaAttention module.
|
66 |
+
"""
|
67 |
+
if kwargs.get("output_attention", False):
|
68 |
+
raise NotImplementedError("output_attention is not currently supported.")
|
69 |
+
|
70 |
+
output, _, past_key_value = self.model(**kwargs)
|
71 |
+
bsz = output.shape[0]
|
72 |
+
q_len = output.shape[1]
|
73 |
+
embed_dim = output.shape[2]
|
74 |
+
k_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].k_proj_layer
|
75 |
+
v_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].v_proj_layer
|
76 |
+
o_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].o_proj_layer
|
77 |
+
|
78 |
+
if k_proj_layer == v_proj_layer:
|
79 |
+
_, key, value = getattr(self.model, k_proj_layer)(self.adaption_prompt).split(embed_dim, dim=2)
|
80 |
+
else:
|
81 |
+
key = getattr(self.model, k_proj_layer)(self.adaption_prompt)
|
82 |
+
value = getattr(self.model, v_proj_layer)(self.adaption_prompt)
|
83 |
+
# (bsz, num_heads, adapter_len, head_dim)
|
84 |
+
adapter_k = (
|
85 |
+
key.view(1, self.adapter_len, self.model.num_heads, self.model.head_dim)
|
86 |
+
.repeat(bsz, 1, 1, 1)
|
87 |
+
.transpose(1, 2)
|
88 |
+
)
|
89 |
+
# (bsz, num_heads, adapter_len, head_dim)
|
90 |
+
adapter_v = (
|
91 |
+
value.view(1, self.adapter_len, self.model.num_heads, self.model.head_dim)
|
92 |
+
.repeat(bsz, 1, 1, 1)
|
93 |
+
.transpose(1, 2)
|
94 |
+
)
|
95 |
+
|
96 |
+
# Recompute query states.
|
97 |
+
compute_query_states = TRANSFORMERS_MODEL_CONFIG[self.model_type].compute_query_states
|
98 |
+
# (bsz, num_heads, q_len, head_dim)
|
99 |
+
query_states = compute_query_states(model=self.model, **kwargs)
|
100 |
+
|
101 |
+
previous_dtype = query_states.dtype
|
102 |
+
# (bsz, num_heads, q_len, adapter_len)
|
103 |
+
scores = torch.matmul(query_states, adapter_k.transpose(2, 3).to(previous_dtype)) / math.sqrt(
|
104 |
+
self.model.head_dim
|
105 |
+
)
|
106 |
+
# Upcast attention to fp32
|
107 |
+
# (bsz, num_heads, q_len, adapter_len)
|
108 |
+
scores = self.adaption_gate * F.softmax(scores, dim=-1, dtype=torch.float32).to(previous_dtype)
|
109 |
+
# (bsz, q_len, num_heads * head_dim)
|
110 |
+
adapter_output = torch.matmul(scores, adapter_v).transpose(1, 2).reshape(bsz, q_len, -1)
|
111 |
+
# (bsz, q_len, hidden_size)
|
112 |
+
if o_proj_layer is not None:
|
113 |
+
adapter_output = getattr(self.model, o_proj_layer)(adapter_output)
|
114 |
+
|
115 |
+
# Add adaption prompt output to original output.
|
116 |
+
output = output + adapter_output
|
117 |
+
|
118 |
+
# Restore original dtype.
|
119 |
+
output = output.to(previous_dtype)
|
120 |
+
return output, None, past_key_value
|
MoRA/peft_mora/tuners/adaption_prompt/model.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import Dict, List
|
16 |
+
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
from peft_mora.utils import _freeze_adapter, _get_submodules
|
20 |
+
|
21 |
+
from .config import AdaptionPromptConfig, prepare_config
|
22 |
+
from .layer import AdaptedAttention
|
23 |
+
from .utils import is_adaption_prompt_trainable
|
24 |
+
|
25 |
+
|
26 |
+
class AdaptionPromptModel(nn.Module):
|
27 |
+
"""
|
28 |
+
Implements adaption prompts as described in https://arxiv.org/pdf/2303.16199.pdf.
|
29 |
+
|
30 |
+
The top L attention modules are replaced with AdaptedAttention modules that wrap the original ones, but insert
|
31 |
+
trainable prompts with gates (for zero init).
|
32 |
+
|
33 |
+
Notes on the multi-adapter pattern:
|
34 |
+
- We store the states of different adapters by keeping a dictionary of AdaptedAttention modules indexed by adapter
|
35 |
+
name.
|
36 |
+
- Every time we switch adapters, we remove the modules of the currently active adapter from the model, store them
|
37 |
+
in the dictionary, and replace them with the modules of the new adapter.
|
38 |
+
- To avoid duplicated and potentially inconsistent state, the currently active adapter is always removed from the
|
39 |
+
dictionary.
|
40 |
+
- Disabling the adapter would also result in the modules being removed from the model.
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self, model, configs: Dict, adapter_name: str):
|
44 |
+
super().__init__()
|
45 |
+
self.model = model
|
46 |
+
# Store adapter configs by name.
|
47 |
+
self.peft_config: Dict[str, AdaptionPromptConfig] = {}
|
48 |
+
# Store lists of the parents of the affected attention modules by adapter name.
|
49 |
+
# We keep references to the parents so we can swap the adapters in-and-out of the model.
|
50 |
+
self._parents: Dict[str, List[nn.Module]] = {}
|
51 |
+
# Store lists of cached AdaptedAttention modules by name.
|
52 |
+
self._cached_adapters: Dict[str, List] = {}
|
53 |
+
# The name of the currently active adapter.
|
54 |
+
self._active_adapter = None
|
55 |
+
# Whether the adapter is enabled.
|
56 |
+
self._enabled = True
|
57 |
+
self.forward = self.model.forward
|
58 |
+
self.add_adapter(adapter_name, configs[adapter_name])
|
59 |
+
self._mark_only_adaption_prompts_as_trainable(self.model)
|
60 |
+
|
61 |
+
def add_adapter(self, adapter_name: str, config: AdaptionPromptConfig) -> None:
|
62 |
+
"""Add an adapter with the given name and config."""
|
63 |
+
config = prepare_config(config, self.model)
|
64 |
+
if adapter_name in self.peft_config:
|
65 |
+
raise ValueError(f"Adapter with name '{adapter_name}' already exists.")
|
66 |
+
|
67 |
+
parents = []
|
68 |
+
for name, _ in self.model.named_modules():
|
69 |
+
if name.endswith(config.target_modules):
|
70 |
+
par, _, _ = _get_submodules(self.model, name)
|
71 |
+
parents.append(par)
|
72 |
+
if len(parents) < config.adapter_layers:
|
73 |
+
raise ValueError(
|
74 |
+
f"Config specifies more adapter layers '{config.adapter_layers}'"
|
75 |
+
f" than the model has '{len(parents)}'."
|
76 |
+
)
|
77 |
+
# Note that if the target modules are not in Sequential, ModuleList, or
|
78 |
+
# some other PyTorch ordered container, the behavior is undefined as we
|
79 |
+
# assume here that the order of the modules is the same as the order of
|
80 |
+
# the transformer decoder layers.
|
81 |
+
parents = parents[-config.adapter_layers :]
|
82 |
+
self._parents[adapter_name] = parents
|
83 |
+
|
84 |
+
# It is only None during initialization.
|
85 |
+
# If it is disabled, we don't have to remove the modules.
|
86 |
+
if self._active_adapter is not None and self._enabled:
|
87 |
+
self._remove_adapted_attentions(self._active_adapter)
|
88 |
+
self._active_adapter = adapter_name
|
89 |
+
self.peft_config[adapter_name] = config
|
90 |
+
self._create_adapted_attentions(config, parents)
|
91 |
+
if not self._enabled:
|
92 |
+
self._remove_adapted_attentions(self._active_adapter)
|
93 |
+
|
94 |
+
if config.inference_mode:
|
95 |
+
_freeze_adapter(self.model, adapter_name)
|
96 |
+
|
97 |
+
def set_adapter(self, adapter_name: str) -> None:
|
98 |
+
"""Set the model to use the adapter with the given name."""
|
99 |
+
if self._active_adapter == adapter_name:
|
100 |
+
return
|
101 |
+
if adapter_name not in self.peft_config:
|
102 |
+
raise ValueError(f"Adapter with name '{adapter_name}' does not exist.")
|
103 |
+
|
104 |
+
if self._enabled:
|
105 |
+
self._remove_adapted_attentions(self._active_adapter)
|
106 |
+
self._set_adapted_attentions(adapter_name)
|
107 |
+
|
108 |
+
self._active_adapter = adapter_name
|
109 |
+
|
110 |
+
def enable_adapter_layers(self):
|
111 |
+
"""Enable adapter layers by swapping in cached AdaptedAttention modules."""
|
112 |
+
self._enabled = True
|
113 |
+
self._set_adapted_attentions(self._active_adapter)
|
114 |
+
|
115 |
+
def disable_adapter_layers(self):
|
116 |
+
"""Disable adapter layers by swapping out AdaptedAttention modules."""
|
117 |
+
self._enabled = False
|
118 |
+
self._remove_adapted_attentions(self._active_adapter)
|
119 |
+
|
120 |
+
def _create_adapted_attentions(self, config: AdaptionPromptConfig, parents: List[nn.Module]) -> None:
|
121 |
+
"""Wrap LlamaAttention modules with newly created AdaptedAttention modules."""
|
122 |
+
for par in parents:
|
123 |
+
attn = AdaptedAttention(
|
124 |
+
model_type=self.model.config.model_type,
|
125 |
+
adapter_len=config.adapter_len,
|
126 |
+
model=getattr(par, config.target_modules),
|
127 |
+
)
|
128 |
+
setattr(par, config.target_modules, attn)
|
129 |
+
|
130 |
+
def _set_adapted_attentions(self, adapter_name: str) -> None:
|
131 |
+
"""Replace LlamaAttention modules with cached AdaptedAttention modules."""
|
132 |
+
cached = self._cached_adapters[adapter_name]
|
133 |
+
del self._cached_adapters[adapter_name]
|
134 |
+
config = self.peft_config[adapter_name]
|
135 |
+
for i, par in enumerate(self._parents[adapter_name]):
|
136 |
+
setattr(par, config.target_modules, cached[i])
|
137 |
+
|
138 |
+
def _remove_adapted_attentions(self, adapter_name: str) -> None:
|
139 |
+
"""Remove AdaptedAttention modules from the model and store them in the cache."""
|
140 |
+
config = self.peft_config[adapter_name]
|
141 |
+
adapted_attentions = []
|
142 |
+
for par in self._parents[adapter_name]:
|
143 |
+
attn = getattr(par, config.target_modules)
|
144 |
+
adapted_attentions.append(attn)
|
145 |
+
setattr(par, config.target_modules, attn.model)
|
146 |
+
self._cached_adapters[adapter_name] = adapted_attentions
|
147 |
+
|
148 |
+
def _mark_only_adaption_prompts_as_trainable(self, model: nn.Module) -> None:
|
149 |
+
"""Freeze all parameters of the model except the adaption prompts."""
|
150 |
+
for n, p in model.named_parameters():
|
151 |
+
if not is_adaption_prompt_trainable(n):
|
152 |
+
p.requires_grad = False
|
153 |
+
|
154 |
+
def __getattr__(self, name: str):
|
155 |
+
"""Forward missing attributes to the wrapped module."""
|
156 |
+
try:
|
157 |
+
return super().__getattr__(name) # defer to nn.Module's logic
|
158 |
+
except AttributeError:
|
159 |
+
# This is necessary as e.g. causal models have various methods that we
|
160 |
+
# don't want to re-implement here.
|
161 |
+
return getattr(self.model, name)
|
MoRA/peft_mora/tuners/adaption_prompt/utils.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import inspect
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
|
20 |
+
def llama_rotate_half(x: torch.Tensor) -> torch.Tensor:
|
21 |
+
"""
|
22 |
+
Rotate half the hidden dims of the input.
|
23 |
+
|
24 |
+
This function was duplicated verbatim from:
|
25 |
+
https://github.com/huggingface/transformers/blob/1de8ce9ee1191ba761a593ac15d9ccbf5851bfc5/src/transformers/models/llama/modeling_llama.py#L126
|
26 |
+
|
27 |
+
This was done to eliminate the Llama transformers implementation as a dependency of this file. Note that some other
|
28 |
+
functions were also adapted from the transformers implementation but were modified.
|
29 |
+
"""
|
30 |
+
x1 = x[..., : x.shape[-1] // 2]
|
31 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
32 |
+
return torch.cat((-x2, x1), dim=-1)
|
33 |
+
|
34 |
+
|
35 |
+
def llama_apply_rotary_pos_emb(q, cos, sin, position_ids):
|
36 |
+
"""
|
37 |
+
Apply rotary position embedding to query states in the Llama model.
|
38 |
+
|
39 |
+
This function was adapted from:
|
40 |
+
https://github.com/huggingface/transformers/blob/1de8ce9ee1191ba761a593ac15d9ccbf5851bfc5/src/transformers/models/llama/modeling_llama.py#L133
|
41 |
+
|
42 |
+
It was modified to remove unnecessary processing of key states. The method is compatible with transformers <=
|
43 |
+
4.34.2 and also with the latest version (>=4.35).
|
44 |
+
"""
|
45 |
+
# In previous transformers version cos/sin cached had a shape of 4D
|
46 |
+
if len(cos.shape) == 4:
|
47 |
+
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
|
48 |
+
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
|
49 |
+
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
50 |
+
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
51 |
+
# In the new version, it is 2D so we fall back to the new implementation
|
52 |
+
# https://github.com/huggingface/transformers/blame/eef7ea98c31a333bacdc7ae7a2372bde772be8e4/src/transformers/models/llama/modeling_llama.py#L222-L226
|
53 |
+
else:
|
54 |
+
cos = cos[position_ids].unsqueeze(1)
|
55 |
+
sin = sin[position_ids].unsqueeze(1)
|
56 |
+
q_embed = (q * cos) + (llama_rotate_half(q) * sin)
|
57 |
+
return q_embed
|
58 |
+
|
59 |
+
|
60 |
+
def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
|
61 |
+
"""
|
62 |
+
Compute query states for Llama models specifically. They need to be recomputed as the forward() method of the
|
63 |
+
original LlamaModel in the transformers library does not return them. See the related discussion in the PR:
|
64 |
+
https://github.com/huggingface/peft/pull/268
|
65 |
+
"""
|
66 |
+
hidden_states = kwargs.get("hidden_states")
|
67 |
+
position_ids = kwargs.get("position_ids")
|
68 |
+
past_key_value = kwargs.get("past_key_value")
|
69 |
+
bsz, q_len, _ = hidden_states.size()
|
70 |
+
query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
|
71 |
+
value_states = model.v_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
|
72 |
+
seq_len = q_len
|
73 |
+
|
74 |
+
if past_key_value is not None:
|
75 |
+
if isinstance(past_key_value, tuple):
|
76 |
+
# for transformers <= 4.35
|
77 |
+
seq_len += past_key_value[0].shape[-2]
|
78 |
+
else:
|
79 |
+
# since transformers 4.36, this is a DynamicCache instance
|
80 |
+
seq_len += past_key_value.get_seq_length(model.layer_idx)
|
81 |
+
|
82 |
+
# For transformers > 4.37.2 `position_ids` became a required arguments in the rotary embedding's forward pass.
|
83 |
+
if "position_ids" not in inspect.signature(model.rotary_emb.forward).parameters:
|
84 |
+
# TODO we assume that position_ids is not None here, not sure if that is safe but the old code also did that
|
85 |
+
cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
|
86 |
+
return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)
|
87 |
+
|
88 |
+
past_seen_tokens = 0
|
89 |
+
if position_ids is None:
|
90 |
+
# Compute position_ids, since they are required for transformers > 4.37.2
|
91 |
+
if past_key_value is None:
|
92 |
+
new_cache_positions = torch.arange(q_len, q_len + q_len, device=value_states.device)
|
93 |
+
else:
|
94 |
+
past_seen_tokens = past_key_value.get_usable_length(q_len, model.layer_idx)
|
95 |
+
new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=value_states.device)
|
96 |
+
position_ids = new_cache_positions.unsqueeze(0)
|
97 |
+
|
98 |
+
cos, sin = model.rotary_emb(value_states, seq_len=q_len + past_seen_tokens, position_ids=position_ids)
|
99 |
+
|
100 |
+
# For batched inference unsqueeze it on the correct dim
|
101 |
+
# since: https://github.com/huggingface/transformers/pull/29109
|
102 |
+
if len(cos.shape) == 3:
|
103 |
+
cos = cos.unsqueeze(1)
|
104 |
+
sin = sin.unsqueeze(1)
|
105 |
+
|
106 |
+
return (query_states * cos) + (llama_rotate_half(query_states) * sin)
|
107 |
+
|
108 |
+
|
109 |
+
def is_adaption_prompt_trainable(params: str) -> bool:
|
110 |
+
"""Return True if module is trainable under adaption prompt fine-tuning."""
|
111 |
+
return params.split(".")[-1].startswith("adaption_")
|
MoRA/peft_mora/tuners/ia3/__init__.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from peft_mora.import_utils import is_bnb_4bit_available, is_bnb_available
|
16 |
+
|
17 |
+
from .config import IA3Config
|
18 |
+
from .layer import Conv2d, IA3Layer, Linear
|
19 |
+
from .model import IA3Model
|
20 |
+
|
21 |
+
|
22 |
+
__all__ = ["Conv2d", "IA3Config", "IA3Layer", "IA3Model", "Linear"]
|
23 |
+
|
24 |
+
|
25 |
+
def __getattr__(name):
|
26 |
+
if (name == "Linear8bitLt") and is_bnb_available():
|
27 |
+
from .bnb import Linear8bitLt
|
28 |
+
|
29 |
+
return Linear8bitLt
|
30 |
+
|
31 |
+
if (name == "Linear4bit") and is_bnb_4bit_available():
|
32 |
+
from .bnb import Linear4bit
|
33 |
+
|
34 |
+
return Linear4bit
|
35 |
+
|
36 |
+
raise AttributeError(f"module {__name__} has no attribute {name}")
|
MoRA/peft_mora/tuners/ia3/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (937 Bytes). View file
|
|
MoRA/peft_mora/tuners/ia3/__pycache__/config.cpython-312.pyc
ADDED
Binary file (5.09 kB). View file
|
|
MoRA/peft_mora/tuners/ia3/__pycache__/layer.cpython-312.pyc
ADDED
Binary file (15.8 kB). View file
|
|
MoRA/peft_mora/tuners/ia3/__pycache__/model.cpython-312.pyc
ADDED
Binary file (18.4 kB). View file
|
|