|
<!--
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with |
|
the License. You may obtain a copy of the License at |
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on |
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the |
|
specific language governing permissions and limitations under the License.
-->
|
|
|
|
|
# Quick tour |
|
|
|
Let's have a look at 🤗 PEFT's main features and learn how to set up a `PeftModel`, train it with 🤗 Accelerate's DeepSpeed integration, and use it for inference.
|
|
|
## Main use |
|
|
|
To use 🤗 PEFT in your script:
|
|
|
1. Each PEFT method is defined by a `PeftConfig` object. |
|
|
|
Create a `PeftConfig` corresponding to your PEFT method (see the [Configuration](package_reference/config) reference for more details), and set the [`TaskType`] to the type of task you're training your model for.

This example trains the [`bigscience/mt0-large`](https://huggingface.co/bigscience/mt0-large) model with the Low-Rank Adaptation of Large Language Models (LoRA) method. Create a `LoraConfig` and specify the `task_type` for sequence-to-sequence language modeling.
|
|
|
```python |
|
from peft import LoraConfig, TaskType |
|
|
|
peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1) |
|
``` |
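
Other PEFT methods are configured the same way; only the `PeftConfig` subclass and its hyperparameters change. As an illustrative sketch (the number of virtual tokens here is an example, not a recommendation), a prefix tuning configuration for the same task could look like this:

```python
from peft import PrefixTuningConfig, TaskType

# Prefix tuning trains "virtual token" key/value states that are prepended to every attention layer.
peft_config = PrefixTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM, num_virtual_tokens=20)
```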
|
|
|
2. Load the base model you want to fine-tune. |
|
|
|
```python |
|
from transformers import AutoModelForSeq2SeqLM |
|
|
|
model_name_or_path = "bigscience/mt0-large" |
|
tokenizer_name_or_path = "bigscience/mt0-large" |
|
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path) |
|
``` |
|
|
|
3. Preprocess your model if you use [`bitsandbytes`](https://github.com/TimDettmers/bitsandbytes) for `int8` quantized training; otherwise, skip this step. |
|
|
|
```python |
|
from peft import prepare_model_for_int8_training |
|
|
|
model = prepare_model_for_int8_training(model) |
|
``` |
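
Note that this step only matters if the base model was actually loaded in 8-bit. A minimal sketch of how the two pieces fit together, assuming `bitsandbytes` and 🤗 Accelerate are installed (`load_in_8bit` and `device_map` are standard `from_pretrained` options, not PEFT-specific):

```python
from transformers import AutoModelForSeq2SeqLM
from peft import prepare_model_for_int8_training

# Load the base model in 8-bit (requires bitsandbytes) and let Accelerate place the weights.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "bigscience/mt0-large", load_in_8bit=True, device_map="auto"
)
# Prepare the quantized model for training (freezes the int8 weights, casts norms to fp32, etc.).
model = prepare_model_for_int8_training(model)
```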
|
|
|
4. Wrap your base model and `peft_config` with the `get_peft_model` function to create a trainable `PeftModel`. Then check how many of the model's parameters are actually trainable.
|
|
|
```python |
|
from peft import get_peft_model |
|
|
|
model = get_peft_model(model, peft_config) |
|
model.print_trainable_parameters() |
|
# output: trainable params: 2359296 || all params: 1231940608 || trainable%: 0.19151053100118282 |
|
``` |
|
|
|
5. Voilà 🎉! Now, train the model using the 🤗 Transformers Trainer API, 🤗 Accelerate, or any custom PyTorch training loop (take a look at the end-to-end [example](https://github.com/huggingface/peft/blob/main/examples/conditional_generation/peft_lora_seq2seq.ipynb) of training [`bigscience/mt0-large`](https://huggingface.co/bigscience/mt0-large)). A minimal Trainer sketch is shown below.
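
This sketch assumes you already have a tokenized `train_dataset` with a `labels` column; the hyperparameters and the `output_dir` name are placeholders rather than recommendations:

```python
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, Trainer, TrainingArguments

# `model` and `tokenizer_name_or_path` come from the previous steps;
# `train_dataset` is a placeholder for your own tokenized dataset.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

training_args = TrainingArguments(
    output_dir="output_dir",
    per_device_train_batch_size=8,
    learning_rate=1e-3,
    num_train_epochs=3,
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)
trainer.train()
```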
|
|
|
### Saving/loading a model |
|
|
|
1. Save your model using the `save_pretrained` function. |
|
|
|
```python |
|
model.save_pretrained("output_dir") |
|
# model.push_to_hub("my_awesome_peft_model") also works |
|
``` |
|
|
|
This only saves the incremental PEFT weights that were trained. |
|
For example, [smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM](https://huggingface.co/smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM) is a `bigscience/T0_3B` model fine-tuned with LoRA on the [`twitter_complaints`](https://huggingface.co/datasets/ought/raft/viewer/twitter_complaints/train) RAFT dataset.

Notice that it only contains two files: `adapter_config.json` and `adapter_model.bin`, with the latter being just 19MB.
|
|
|
2. Load your model using the `from_pretrained` function. |
|
|
|
```diff |
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ from peft import PeftModel, PeftConfig
import torch

+ peft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM"
+ config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
+ model = PeftModel.from_pretrained(model, peft_model_id)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

device = "cuda"
model = model.to(device)
model.eval()
inputs = tokenizer("Tweet text : @HondaCustSvc Your customer service has been horrible during the recall process. I will never purchase a Honda again. Label :", return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(input_ids=inputs["input_ids"].to(device), max_new_tokens=10)
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0])
# 'complaint'
|
``` |
|
|
|
## Launching your distributed script |
|
|
|
PEFT models work with 🤗 Accelerate out of the box.
You can use 🤗 Accelerate for distributed training on various hardware such as GPUs and Apple Silicon devices, and for inference on consumer hardware with fewer resources.
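
A custom training loop for a `PeftModel` with 🤗 Accelerate looks like any other Accelerate loop. The sketch below assumes `model` is the `PeftModel` from above and `train_dataloader` is a DataLoader you have built yourself; the learning rate is illustrative:

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # only the trainable (PEFT) params receive gradients

model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

model.train()
for batch in train_dataloader:
    outputs = model(**batch)
    accelerator.backward(outputs.loss)
    optimizer.step()
    optimizer.zero_grad()
```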
|
|
|
### Train with 🤗 Accelerate's DeepSpeed integration
|
|
|
You'll need DeepSpeed version `v0.8.0` for this example. Feel free to check out the full example [script](https://github.com/huggingface/peft/blob/main/examples/conditional_generation/peft_lora_seq2seq_accelerate_ds_zero3_offload.py) for more details! |
|
|
|
1. Run `accelerate config` and answer the questionnaire to set up the DeepSpeed integration.

Below are the contents of the resulting config file:
|
```yaml |
|
compute_environment: LOCAL_MACHINE |
|
deepspeed_config: |
|
gradient_accumulation_steps: 1 |
|
gradient_clipping: 1.0 |
|
offload_optimizer_device: cpu |
|
offload_param_device: cpu |
|
zero3_init_flag: true |
|
zero3_save_16bit_model: true |
|
zero_stage: 3 |
|
distributed_type: DEEPSPEED |
|
downcast_bf16: 'no' |
|
dynamo_backend: 'NO' |
|
fsdp_config: {} |
|
machine_rank: 0 |
|
main_training_function: main |
|
megatron_lm_config: {} |
|
mixed_precision: 'no' |
|
num_machines: 1 |
|
num_processes: 1 |
|
rdzv_backend: static |
|
same_network: true |
|
use_cpu: false |
|
``` |
|
2. Run the following command to launch the example script: |
|
```bash |
|
accelerate launch examples/conditional_generation/peft_lora_seq2seq_accelerate_ds_zero3_offload.py
|
``` |
|
|
|
You'll see some output logs that look like this: |
|
```bash |
|
GPU Memory before entering the train : 1916 |
|
GPU Memory consumed at the end of the train (end-begin): 66 |
|
GPU Peak Memory consumed during the train (max-begin): 7488 |
|
GPU Total Peak Memory consumed during the train (max): 9404 |
|
CPU Memory before entering the train : 19411 |
|
CPU Memory consumed at the end of the train (end-begin): 0 |
|
CPU Peak Memory consumed during the train (max-begin): 0 |
|
CPU Total Peak Memory consumed during the train (max): 19411 |
|
epoch=4: train_ppl=tensor(1.0705, device='cuda:0') train_epoch_loss=tensor(0.0681, device='cuda:0') |
|
100%|██████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:27<00:00,  3.92s/it]
|
GPU Memory before entering the eval : 1982 |
|
GPU Memory consumed at the end of the eval (end-begin): -66 |
|
GPU Peak Memory consumed during the eval (max-begin): 672 |
|
GPU Total Peak Memory consumed during the eval (max): 2654 |
|
CPU Memory before entering the eval : 19411 |
|
CPU Memory consumed at the end of the eval (end-begin): 0 |
|
CPU Peak Memory consumed during the eval (max-begin): 0 |
|
CPU Total Peak Memory consumed during the eval (max): 19411 |
|
accuracy=100.0 |
|
eval_preds[:10]=['no complaint', 'no complaint', 'complaint', 'complaint', 'no complaint', 'no complaint', 'no complaint', 'complaint', 'complaint', 'no complaint'] |
|
dataset['train'][label_column][:10]=['no complaint', 'no complaint', 'complaint', 'complaint', 'no complaint', 'no complaint', 'no complaint', 'complaint', 'complaint', 'no complaint'] |
|
``` |
|
|
|
### Inference with 🤗 Accelerate's Big Model Inference
|
An example is provided in `~examples/causal_language_modeling/peft_lora_clm_accelerate_big_model_inference.ipynb`. |
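
The core idea is to load the base model with `device_map="auto"` so 🤗 Accelerate dispatches the weights across the available devices, and then attach the PEFT adapter on top. A minimal sketch, reusing the adapter from the earlier loading example:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel, PeftConfig

peft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM"
config = PeftConfig.from_pretrained(peft_model_id)

# device_map="auto" lets 🤗 Accelerate spread the base model over the available GPUs/CPU.
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, device_map="auto")
model = PeftModel.from_pretrained(model, peft_model_id)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
model.eval()
```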
|
|
|
## Model support matrix
|
|
|
### Causal Language Modeling |
|
| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
|---|---|---|---|---|
| GPT-2 | ✅ | ✅ | ✅ | ✅ |
| Bloom | ✅ | ✅ | ✅ | ✅ |
| OPT | ✅ | ✅ | ✅ | ✅ |
| GPT-Neo | ✅ | ✅ | ✅ | ✅ |
| GPT-J | ✅ | ✅ | ✅ | ✅ |
| GPT-NeoX-20B | ✅ | ✅ | ✅ | ✅ |
| LLaMA | ✅ | ✅ | ✅ | ✅ |
| ChatGLM | ✅ | ✅ | ✅ | ✅ |
|
|
|
### Conditional Generation |
|
| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
|---|---|---|---|---|
| T5 | ✅ | ✅ | ✅ | ✅ |
| BART | ✅ | ✅ | ✅ | ✅ |
|
|
|
### Sequence Classification |
|
| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
|---|---|---|---|---|
| BERT | ✅ | ✅ | ✅ | ✅ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ |
| GPT-2 | ✅ | ✅ | ✅ | ✅ |
| Bloom | ✅ | ✅ | ✅ | ✅ |
| OPT | ✅ | ✅ | ✅ | ✅ |
| GPT-Neo | ✅ | ✅ | ✅ | ✅ |
| GPT-J | ✅ | ✅ | ✅ | ✅ |
| Deberta | ✅ |  | ✅ | ✅ |
| Deberta-v2 | ✅ |  | ✅ | ✅ |
|
|
|
### Token Classification |
|
| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
|---|---|---|---|---|
| BERT | ✅ | ✅ |  |  |
| RoBERTa | ✅ | ✅ |  |  |
| GPT-2 | ✅ | ✅ |  |  |
| Bloom | ✅ | ✅ |  |  |
| OPT | ✅ | ✅ |  |  |
| GPT-Neo | ✅ | ✅ |  |  |
| GPT-J | ✅ | ✅ |  |  |
| Deberta | ✅ |  |  |  |
| Deberta-v2 | ✅ |  |  |  |
|
|
|
### Text-to-Image Generation |
|
|
|
| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
|---|---|---|---|---|
| Stable Diffusion | ✅ |  |  |  |
|
|
|
|
|
### Image Classification |
|
|
|
| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
|---|---|---|---|---|
| ViT | ✅ |  |  |  |
| Swin | ✅ |  |  |  |
|
|
|
___Note that we have tested LoRA for [ViT](https://huggingface.co/docs/transformers/model_doc/vit) and [Swin](https://huggingface.co/docs/transformers/model_doc/swin) for fine-tuning on image classification. However, it should be possible to use LoRA for any [compatible model](https://huggingface.co/models?pipeline_tag=image-classification&sort=downloads&search=vit) provided by 🤗 Transformers. Check out the respective examples to learn more. If you run into problems, please open an issue.___
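
As an illustration, applying LoRA to an image classification model follows the same recipe as the text examples above; only the target modules change. The sketch below uses assumptions: the checkpoint name, label count, and the `"query"`/`"value"` module names hold for ViT but may differ for other architectures.

```python
from transformers import AutoModelForImageClassification
from peft import LoraConfig, get_peft_model

# Hypothetical setup: a ViT backbone with a freshly initialized 10-class head.
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k", num_labels=10
)

lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["query", "value"],  # attention projections in ViT
    lora_dropout=0.1,
    modules_to_save=["classifier"],     # keep the new classification head trainable
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```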
|
|
|
The same principle applies to our [segmentation models](https://huggingface.co/models?pipeline_tag=image-segmentation&sort=downloads) as well. |
|
|
|
### Semantic Segmentation |
|
|
|
| Model | LoRA | Prefix Tuning | P-Tuning | Prompt Tuning |
|---|---|---|---|---|
| SegFormer | ✅ |  |  |  |
|
|
|
|
|
## Other caveats |
|
|
|
1. Below is an example of using PyTorch FSDP for training. However, it doesn't lead to any GPU memory savings; please refer to the issue [[FSDP] FSDP with CPU offload consumes 1.65X more GPU memory when training models with most of the params frozen](https://github.com/pytorch/pytorch/issues/91165).
|
|
|
```python |
|
import os

from peft.utils.other import fsdp_auto_wrap_policy

# `accelerator` is an `accelerate.Accelerator` instance and `model` is the PeftModel to train.
if os.environ.get("ACCELERATE_USE_FSDP", None) is not None:
    accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)

model = accelerator.prepare(model)
|
``` |
|
|
|
An example of parameter-efficient tuning of the [`mt0-xxl`](https://huggingface.co/bigscience/mt0-xxl) base model using 🤗 Accelerate is provided in `~examples/conditional_generation/peft_lora_seq2seq_accelerate_fsdp.py`.

a. First, run `accelerate config --config_file fsdp_config.yaml` and answer the questionnaire.

Below are the contents of the config file:
|
```yaml |
|
command_file: null |
|
commands: null |
|
compute_environment: LOCAL_MACHINE |
|
deepspeed_config: {} |
|
distributed_type: FSDP |
|
downcast_bf16: 'no' |
|
dynamo_backend: 'NO' |
|
fsdp_config: |
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP |
|
fsdp_backward_prefetch_policy: BACKWARD_PRE |
|
fsdp_offload_params: true |
|
fsdp_sharding_strategy: 1 |
|
fsdp_state_dict_type: FULL_STATE_DICT |
|
fsdp_transformer_layer_cls_to_wrap: T5Block |
|
gpu_ids: null |
|
machine_rank: 0 |
|
main_process_ip: null |
|
main_process_port: null |
|
main_training_function: main |
|
megatron_lm_config: {} |
|
mixed_precision: 'no' |
|
num_machines: 1 |
|
num_processes: 2 |
|
rdzv_backend: static |
|
same_network: true |
|
tpu_name: null |
|
tpu_zone: null |
|
use_cpu: false |
|
``` |
|
b. Run the following command to launch the example script:
|
```bash |
|
accelerate launch --config_file fsdp_config.yaml examples/peft_lora_seq2seq_accelerate_fsdp.py |
|
``` |
|
|
|
2. When using `P_TUNING` or `PROMPT_TUNING` with the `SEQ_2_SEQ` task type, remember to remove the `num_virtual_tokens` virtual prompt predictions from the left side of the model outputs during evaluation.
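
For example, a minimal sketch of the slicing described above; `outputs`, `peft_config`, and `tokenizer` stand in for the objects in your own evaluation loop:

```python
import torch

# Drop the virtual-token positions that the prompt method prepends, then decode as usual.
logits = outputs.logits[:, peft_config.num_virtual_tokens:, :]
preds = torch.argmax(logits, dim=-1)
decoded = tokenizer.batch_decode(preds, skip_special_tokens=True)
```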
|
|
|
3. For encoder-decoder models, `P_TUNING` or `PROMPT_TUNING` doesn't support the `generate` functionality of `transformers` because `generate` strictly requires `decoder_input_ids`, while `P_TUNING`/`PROMPT_TUNING` append soft prompt embeddings to `inputs_embeds` to create new `inputs_embeds` for the model. As a result, `generate` doesn't work with these methods yet.
|
|
|
4. When using ZeRO-3 with `zero3_init_flag=True`, if you find that GPU memory increases with training steps, you may need to set `zero3_init_flag=false` in the accelerate config file. See the related issue: [[BUG] memory leak under zero.Init](https://github.com/microsoft/DeepSpeed/issues/2637).