LoRA

You are viewing the v0.9.0 version of this documentation. A newer version, v0.13.0, is available.

LoRA is a low-rank decomposition method that reduces the number of trainable parameters, which speeds up finetuning of large models and uses less memory. In PEFT, using LoRA is as easy as setting up a LoraConfig and passing it, together with the base model, to get_peft_model() to create a trainable PeftModel.
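The savings are easy to quantify. For a weight matrix of shape d × k, LoRA trains a factor B of shape d × r and a factor A of shape r × k instead of the full matrix. The sketch below counts both; the 4096 × 4096 shape and rank 8 are illustrative assumptions, not values taken from a specific model.

```python
def full_params(d: int, k: int) -> int:
    """Trainable parameters when finetuning a d x k weight matrix directly."""
    return d * k


def lora_params(d: int, k: int, r: int) -> int:
    """Trainable parameters of the LoRA factors B (d x r) and A (r x k)."""
    return d * r + r * k


# Illustrative layer shape and rank (assumed, not taken from a real model)
d, k, r = 4096, 4096, 8
full = full_params(d, k)
lora = lora_params(d, k, r)
print(f"full: {full:,}  LoRA: {lora:,}  fraction: {lora / full:.2%}")
```

Because the LoRA count grows linearly with r instead of with d × k, a small rank keeps the trainable fraction well under 1% for a layer of this size.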

This guide explores other options and features for using LoRA in more detail.

Initialization

The initialization of LoRA weights is controlled by the parameter init_lora_weights in LoraConfig. By default, PEFT initializes weight A with Kaiming-uniform and weight B with zeros, so the LoRA layer starts out as an identity transform (same as the reference implementation).
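The default is an identity transform because the update B @ A is zero whenever B is zero, whatever values A holds. Here is a minimal pure-Python sketch of that scheme. It assumes the Kaiming-uniform bound 1/sqrt(fan_in), which is what torch.nn.init.kaiming_uniform_ produces with a=math.sqrt(5); the helper names are hypothetical.

```python
import math
import random


def matmul(x, y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def init_lora(d_out, d_in, r, seed=0):
    """Hypothetical helper: A ~ Kaiming-uniform (assumed bound), B = 0."""
    rng = random.Random(seed)
    bound = 1 / math.sqrt(d_in)  # fan_in of A (shape r x d_in) is d_in
    A = [[rng.uniform(-bound, bound) for _ in range(d_in)] for _ in range(r)]
    B = [[0.0] * r for _ in range(d_out)]
    return A, B


A, B = init_lora(d_out=4, d_in=6, r=2)
delta_W = matmul(B, A)  # the LoRA update B @ A, shape d_out x d_in
print(all(v == 0.0 for row in delta_W for v in row))  # True: the adapter starts as a no-op
```

Only B has to be zero for this to hold, which is why A can still be initialized randomly and receive useful gradients from the first step.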

It is also possible to pass init_lora_weights="gaussian". As the name suggests, this initializes weight A with a Gaussian distribution and zeros for weight B (this is how Diffusers initializes LoRA weights).

from peft import LoraConfig

config = LoraConfig(init_lora_weights="gaussian", ...)

You can also set init_lora_weights=False, which is useful for debugging and testing. This should be the only time you use this option. With this option, the LoRA weights are initialized such that they do not result in an identity transform.

from peft import LoraConfig

config = LoraConfig(init_lora_weights=False, ...)

LoftQ

If you plan to quantize the base model for QLoRA training, consider using the LoftQ initialization, which has been shown to improve performance when training quantized models. The idea is to initialize the LoRA weights such that the quantization error is minimized. If you're using LoftQ, do not quantize the base model yourself. Set up a LoftQConfig instead:

from transformers import AutoModelForCausalLM
from peft import LoftQConfig, LoraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained(...)  # don't quantize here
loftq_config = LoftQConfig(loftq_bits=4, ...)           # use 4-bit quantization
lora_config = LoraConfig(..., init_lora_weights="loftq", loftq_config=loftq_config)
peft_model = get_peft_model(base_model, lora_config)
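The alternating idea behind LoftQ can be sketched in a few lines: quantize what the low-rank part does not explain, then fit the low-rank part to the remaining quantization error. This is a simplified illustration, not PEFT's implementation. It assumes a crude fixed 2-bit grid instead of the 4-bit quantization configured above, a rank of 1, and power iteration in place of a truncated SVD; all helper names are hypothetical.

```python
import random


def quantize(M, levels):
    """Round every entry to the nearest value in a fixed grid of levels."""
    return [[min(levels, key=lambda q: abs(q - v)) for v in row] for row in M]


def sub(X, Y):
    return [[a - b for a, b in zip(rx, ry)] for rx, ry in zip(X, Y)]


def add(X, Y):
    return [[a + b for a, b in zip(rx, ry)] for rx, ry in zip(X, Y)]


def frob(M):
    """Frobenius norm of a matrix given as a list of rows."""
    return sum(v * v for row in M for v in row) ** 0.5


def rank1(R, iters=50):
    """Rank-1 approximation u v^T of R, with u found by power iteration."""
    rows, cols = len(R), len(R[0])
    u = [1.0] * rows
    for _ in range(iters):
        v = [sum(R[i][j] * u[i] for i in range(rows)) for j in range(cols)]  # R^T u
        u = [sum(R[i][j] * v[j] for j in range(cols)) for i in range(rows)]  # R v
        norm = sum(x * x for x in u) ** 0.5
        if norm == 0.0:
            return [[0.0] * cols for _ in range(rows)]
        u = [x / norm for x in u]
    v = [sum(R[i][j] * u[i] for i in range(rows)) for j in range(cols)]
    return [[ui * vj for vj in v] for ui in u]


def loftq_sketch(W, levels, steps=5):
    """Alternate Q = quantize(W - L) and L = rank1(W - Q); keep the best pair."""
    L = [[0.0] * len(W[0]) for _ in W]
    best = None
    for _ in range(steps):
        Q = quantize(sub(W, L), levels)
        L = rank1(sub(W, Q))
        err = frob(sub(W, add(Q, L)))
        if best is None or err < best[0]:
            best = (err, Q, L)
    return best[1], best[2]


rng = random.Random(0)
W = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(8)]
levels = [-1.5, -0.5, 0.5, 1.5]  # crude 2-bit grid, chosen for illustration
Q, L = loftq_sketch(W, levels)
plain_error = frob(sub(W, quantize(W, levels)))
loftq_error = frob(sub(W, add(Q, L)))
print(f"quantize only: {plain_error:.3f}  quantize + low-rank: {loftq_error:.3f}")
```

The quantized weights Q plus the low-rank term L approximate W more closely than Q alone, so training starts closer to the unquantized model. Since the sketch returns the best pair it has seen, it never does worse than plain rounding.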

Learn more about how PEFT works with quantization in the Quantization guide.

Rank-stabilized LoRA

Another option when setting up LoraConfig is the rank-stabilized LoRA (rsLoRA) method. The LoRA architecture scales each adapter during every forward pass by a fixed scalar that is set at initialization and depends on the rank r. In the original implementation the scalar is lora_alpha/r, whereas rsLoRA uses lora_alpha/math.sqrt(r), which stabilizes the adapters and increases the performance potential of using a higher r.

from peft import LoraConfig

config = LoraConfig(use_rslora=True, ...)
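Comparing the two scaling factors across ranks shows what rsLoRA changes. The value lora_alpha = 16 below is an arbitrary illustrative choice.

```python
import math


def lora_scaling(lora_alpha: float, r: int) -> float:
    """Scaling factor of the original LoRA formulation."""
    return lora_alpha / r


def rslora_scaling(lora_alpha: float, r: int) -> float:
    """Scaling factor of rank-stabilized LoRA (use_rslora=True)."""
    return lora_alpha / math.sqrt(r)


lora_alpha = 16  # illustrative value
for r in (4, 16, 64, 256):
    print(f"r={r:>3}  lora: {lora_scaling(lora_alpha, r):.4f}  "
          f"rslora: {rslora_scaling(lora_alpha, r):.4f}")
```

With the original scaling, quadrupling r divides the factor by four; with rsLoRA it is only halved, so high-rank adapters are not scaled down nearly as much.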

Weight-Decomposed Low-Rank Adaptation (DoRA)

This technique decomposes the weight updates into two parts: magnitude and direction. Direction is handled by normal LoRA, whereas the magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, especially at low ranks. Right now, DoRA only supports non-quantized linear layers. DoRA introduces a bigger overhead than pure LoRA, so it is recommended to merge the weights for inference (see LoraModel.merge_and_unload()). For more information on DoRA, see https://arxiv.org/abs/2402.09353.

from peft import LoraConfig

config = LoraConfig(use_dora=True, ...)
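The decomposition can be written as W' = m · V / ‖V‖ with V = W0 + scaling · (B @ A): the LoRA factors B and A only change the direction V, while m sets the magnitude. The pure-Python sketch below follows that formula. Normalizing each row of an out × in weight (one magnitude per output feature) is an assumption about the axis, and the helper name is hypothetical.

```python
import math


def matmul(x, y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def dora_weight(W0, A, B, m, scaling=1.0):
    """Hypothetical helper: build the DoRA-adapted weight from W0 (out x in).

    Direction: V = W0 + scaling * (B @ A), normalized per output row (assumed axis).
    Magnitude: m[i] rescales row i, so row i of the result has norm |m[i]|.
    """
    delta = matmul(B, A)
    V = [[w + scaling * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W0, delta)]
    norms = [math.sqrt(sum(v * v for v in row)) for row in V]
    return [[m_i * v / n for v in row] for m_i, n, row in zip(m, norms, V)]


W0 = [[3.0, 4.0], [1.0, 0.0]]
A = [[0.5, -0.5]]   # r x in, with r = 1
B = [[0.0], [0.0]]  # out x r, zeros as in the default LoRA init
m = [math.sqrt(sum(w * w for w in row)) for row in W0]  # start m at the row norms of W0
print(dora_weight(W0, A, B, m))  # [[3.0, 4.0], [1.0, 0.0]]: unchanged at init

# Pretend training changed B and m: the row norms now follow m
W_new = dora_weight(W0, A, [[0.2], [-0.4]], [10.0, 2.0])
print([round(math.sqrt(sum(v * v for v in row)), 6) for row in W_new])  # [10.0, 2.0]
```

The extra normalization in every forward pass is where DoRA's overhead comes from, which is why merging the result into a single weight pays off at inference time.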

QLoRA-style training

The default LoRA settings in PEFT add trainable weights to the query and value layers of each attention block. QLoRA, by contrast, adds trainable weights to all the linear layers of a transformer model and can provide performance equal to a fully finetuned model. To apply LoRA to all the linear layers, like in QLoRA, set target_modules="all-linear". This is easier than specifying individual modules by name, since module names can vary depending on the architecture.

config = LoraConfig(target_modules="all-linear", ...)
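To see how much more of the model this covers, the sketch below applies both selection rules to a toy list of module names loosely modeled on a Llama/Mistral-style decoder layer. It assumes that "all-linear" selects every linear layer except the output layer (lm_head here); the helper and the module dictionary are hypothetical and are not PEFT's matching code.

```python
def select_all_linear(modules, output_layer="lm_head"):
    """Hypothetical helper: pick every linear layer except the output layer."""
    return [
        name
        for name, layer_type in modules.items()
        if layer_type == "Linear" and name.split(".")[-1] != output_layer
    ]


# Illustrative module names and types, not read from a real model
modules = {
    "layers.0.self_attn.q_proj": "Linear",
    "layers.0.self_attn.k_proj": "Linear",
    "layers.0.self_attn.v_proj": "Linear",
    "layers.0.self_attn.o_proj": "Linear",
    "layers.0.mlp.gate_proj": "Linear",
    "layers.0.mlp.up_proj": "Linear",
    "layers.0.mlp.down_proj": "Linear",
    "layers.0.input_layernorm": "RMSNorm",
    "embed_tokens": "Embedding",
    "lm_head": "Linear",
}

default_style = [n for n in modules if n.endswith(("q_proj", "v_proj"))]  # query/value only
qlora_style = select_all_linear(modules)
print(f"query/value only: {len(default_style)} modules, all-linear: {len(qlora_style)} modules")
```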

Merge adapters

While LoRA is significantly smaller and faster to train, you may encounter latency during inference because the LoRA adapter layers run alongside the base model's layers in every forward pass. To eliminate this latency, use the merge_and_unload() function to merge the adapter weights into the base model. This allows you to use the newly merged model as a standalone model. Note that merge_and_unload() doesn't keep the adapter weights in memory, so the adapter can't be unmerged afterwards.

from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
peft_model_id = "alignment-handbook/zephyr-7b-sft-lora"
model = PeftModel.from_pretrained(base_model, peft_model_id)
model = model.merge_and_unload()  # returns the base model with the LoRA weights merged in
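Merging folds the scaled update into the base weight, W + scaling · (B @ A), so each layer needs only one matrix multiplication per forward pass. The toy check below, with small hand-picked matrices, confirms that the merged and unmerged paths give the same output. It only illustrates the arithmetic and is not how PEFT stores or merges tensors; scaling = 2.0 simply stands in for lora_alpha / r.

```python
def matmul(x, y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def add_scaled(X, Y, s):
    """Return X + s * Y for matrices of the same shape."""
    return [[a + s * b for a, b in zip(rx, ry)] for rx, ry in zip(X, Y)]


W = [[1.0, 2.0], [3.0, 4.0]]  # frozen base weight (out x in)
A = [[0.5, 0.25]]             # LoRA A (r x in), r = 1
B = [[2.0], [-1.0]]           # LoRA B (out x r)
scaling = 2.0                 # stands in for lora_alpha / r
x = [[1.0], [1.0]]            # one input column vector

# Unmerged: base path plus a separate adapter path in every forward pass
unmerged = add_scaled(matmul(W, x), matmul(B, matmul(A, x)), scaling)

# Merged: fold the update into one weight once, then a single matmul per forward pass
W_merged = add_scaled(W, matmul(B, A), scaling)
merged = matmul(W_merged, x)
print(unmerged, merged)  # [[6.0], [5.5]] [[6.0], [5.5]]
```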

If you need to keep a copy of the weights so you can unmerge the adapter later or delete and load different ones, use the merge_adapter() function instead. You can then call unmerge_adapter() to return to the base model.

from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
peft_model_id = "alignment-handbook/zephyr-7b-sft-lora"
model = PeftModel.from_pretrained(base_model, peft_model_id)
model.merge_adapter()

# unmerge the LoRA layers from the base model
model.unmerge_adapter()
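Unmerging works because the adapter's update is still available to subtract. A pure-Python sketch of the round trip follows, assuming merging adds scaling · (B @ A) and unmerging subtracts it; the helper names are hypothetical. It also shows why the restored weights can differ from the originals by floating-point rounding.

```python
import random


def matmul(x, y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def merge(W, delta, scaling):
    """Add the scaled LoRA update to the base weight."""
    return [[w + scaling * d for w, d in zip(rw, rd)] for rw, rd in zip(W, delta)]


def unmerge(W, delta, scaling):
    """Subtract the same scaled update to recover the base weight."""
    return [[w - scaling * d for w, d in zip(rw, rd)] for rw, rd in zip(W, delta)]


rng = random.Random(0)
W = [[rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(4)]
B = [[rng.gauss(0.0, 1.0)] for _ in range(4)]  # out x r, with r = 1
A = [[rng.gauss(0.0, 1.0) for _ in range(4)]]  # r x in
delta = matmul(B, A)

restored = unmerge(merge(W, delta, 0.5), delta, 0.5)
max_diff = max(abs(a - b) for ra, rb in zip(W, restored) for a, b in zip(ra, rb))
print(f"largest difference after merge + unmerge: {max_diff:.1e}")
```

In double precision the difference is zero or on the order of rounding error. In float16 or bfloat16 it can be noticeably larger, which is worth keeping in mind if you merge and unmerge repeatedly.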

The add_weighted_adapter() function is useful for merging multiple LoRAs into a new adapter based on a user-provided weighting scheme in the weights parameter. Below is an end-to-end example.

First load the base model:

from transformers import AutoModelForCausalLM
from peft import PeftModel
import torch

base_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, device_map="auto"
)

Then load the first adapter:

peft_model_id = "alignment-handbook/zephyr-7b-sft-lora"
model = PeftModel.from_pretrained(base_model, peft_model_id, adapter_name="sft")

Then load a different adapter and merge it with the first one:

weighted_adapter_name = "sft-dpo"
model.load_adapter("alignment-handbook/zephyr-7b-dpo-lora", adapter_name="dpo")
model.add_weighted_adapter(
    adapters=["sft", "dpo"],
    weights=[0.7, 0.3],
    adapter_name=weighted_adapter_name,
    combination_type="linear"
)
model.set_adapter(weighted_adapter_name)

There are several supported methods for combination_type; refer to the add_weighted_adapter() API documentation for details. Note that “svd” as the combination_type is not supported when using torch.float16 or torch.bfloat16 as the datatype.
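The general idea of weighting adapters can be sketched by combining each adapter's full update ΔW_i = B_i @ A_i as Σ w_i · ΔW_i. This is a hedged illustration of the concept only and is not necessarily how any particular combination_type computes its result inside PEFT; the toy rank-1 adapters and helper names are hypothetical, while the 0.7/0.3 weights mirror the example above.

```python
def matmul(x, y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def weighted_delta(deltas, weights):
    """Weighted sum of full updates: sum_i weights[i] * deltas[i]."""
    rows, cols = len(deltas[0]), len(deltas[0][0])
    return [
        [sum(w * d[i][j] for w, d in zip(weights, deltas)) for j in range(cols)]
        for i in range(rows)
    ]


# Two toy rank-1 adapters (B @ A) for the same 2 x 2 layer
delta_sft = matmul([[1.0], [0.0]], [[1.0, 1.0]])
delta_dpo = matmul([[0.0], [1.0]], [[2.0, 0.0]])
combined = weighted_delta([delta_sft, delta_dpo], [0.7, 0.3])
print(combined)
```

Each toy update has rank 1, but their weighted sum (approximately [[0.7, 0.7], [0.6, 0.0]]) has rank 2. Representing a combined update as a new LoRA adapter therefore needs either a larger rank or an approximation, which is one reason several combination methods exist.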

Now, perform inference:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}

with torch.no_grad():
    generate_ids = model.generate(**inputs, max_length=30)
outputs = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(outputs)

Load adapters

Adapters can be loaded onto a pretrained model with load_adapter(), which is useful for trying out different adapters whose weights aren’t merged. Set the active adapter weights with the set_adapter() function.

from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
peft_model_id = "alignment-handbook/zephyr-7b-sft-lora"
model = PeftModel.from_pretrained(base_model, peft_model_id)

# load different adapter
model.load_adapter("alignment-handbook/zephyr-7b-dpo-lora", adapter_name="dpo")

# set adapter as active
model.set_adapter("dpo")
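The toy class below models the behavior described above: several adapters can be loaded side by side, but only the active one affects the output. Everything here (the class, the scalar weights, the adapter names) is a hypothetical simplification for illustration, not PEFT's implementation.

```python
class ToyAdapterLayer:
    """Hypothetical scalar layer holding several named adapters (not PEFT's code)."""

    def __init__(self, base_weight):
        self.base_weight = base_weight  # frozen base weight
        self.adapters = {}              # adapter name -> scalar update
        self.active = None              # name of the active adapter, if any

    def load_adapter(self, name, delta):
        # Loading only registers the adapter; it does not activate it.
        self.adapters[name] = delta

    def set_adapter(self, name):
        if name not in self.adapters:
            raise ValueError(f"unknown adapter: {name}")
        self.active = name

    def forward(self, x):
        # Only the active adapter's update is added to the base weight.
        delta = self.adapters[self.active] if self.active is not None else 0.0
        return (self.base_weight + delta) * x


layer = ToyAdapterLayer(base_weight=2.0)
layer.load_adapter("default", 0.5)
layer.set_adapter("default")
layer.load_adapter("dpo", -1.0)
before = layer.forward(1.0)  # 2.5: "dpo" is loaded but not active yet
layer.set_adapter("dpo")
after = layer.forward(1.0)   # 1.0: now only "dpo" contributes
print(before, after)
```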

To return to the base model, use unload() to unload all of the LoRA modules. To remove a single adapter entirely instead, use delete_adapter().

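# Option 1: unload all LoRA modules and get back the original base model
base_model = model.unload()

# Option 2 (instead of option 1): delete only the "dpo" adapter
model.delete_adapter("dpo")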
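The difference between these calls, and how they contrast with merge_and_unload(), can be summarized with another toy model. The class is a hypothetical scalar stand-in, not PEFT's code: delete_adapter() drops one adapter, unload() drops all of them without merging, and merge_and_unload() keeps the active update inside the returned weight.

```python
class ToyLoraWrapper:
    """Hypothetical scalar stand-in for a wrapped layer; not PEFT's code."""

    def __init__(self, base_weight, adapters, active):
        self.base_weight = base_weight  # frozen base weight (a scalar here)
        self.adapters = dict(adapters)  # adapter name -> scalar update
        self.active = active

    def delete_adapter(self, name):
        # Removes one adapter; the wrapper and the other adapters stay in place.
        del self.adapters[name]

    def unload(self):
        # Removes every adapter and returns the untouched base weight (no merge).
        self.adapters.clear()
        return self.base_weight

    def merge_and_unload(self):
        # For contrast: folds the active update into the weight before unloading.
        merged = self.base_weight + self.adapters.get(self.active, 0.0)
        self.adapters.clear()
        return merged


wrapper = ToyLoraWrapper(2.0, {"default": 0.5, "dpo": -1.0}, active="default")
wrapper.delete_adapter("dpo")
print(sorted(wrapper.adapters))  # ['default']
print(wrapper.unload())          # 2.0

other = ToyLoraWrapper(2.0, {"default": 0.5}, active="default")
print(other.merge_and_unload())  # 2.5
```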