PEFT documentation

Adapter injection


With PEFT, you can inject trainable adapters into any torch module, which lets you use adapter methods without relying on the modeling classes in PEFT. Currently, PEFT supports injecting LoRA, AdaLoRA, and IA3 into models because, for these adapters, modifying the model in place is sufficient for fine-tuning it.

Check the table below to see when you should inject adapters.

| Pros | Cons |
| --- | --- |
| the model is modified in place, keeping all the original attributes and methods | you must manually write your own equivalents of the from_pretrained and save_pretrained utility functions from Hugging Face to save and load adapters |
| works for any torch module and modality | doesn't work with any of the utility methods provided by PeftModel, such as disabling and merging adapters |

Creating a new PEFT model

To perform the adapter injection, use the inject_adapter_in_model() function. It takes three arguments: the PEFT config, the model, and an optional adapter name. You can also attach multiple adapters to the model by calling inject_adapter_in_model() multiple times with different adapter names.

For example, to inject LoRA adapters into the linear submodule of the DummyModel module:

import torch
from peft import inject_adapter_in_model, LoraConfig

class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 10)
        self.linear = torch.nn.Linear(10, 10)
        self.lm_head = torch.nn.Linear(10, 10)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        x = self.linear(x)
        x = self.lm_head(x)
        return x


lora_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    target_modules=["linear"],
)

model = DummyModel()
model = inject_adapter_in_model(lora_config, model)

dummy_inputs = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]])
dummy_outputs = model(dummy_inputs)

Print the model to see that the adapters have been correctly injected.

DummyModel(
  (embedding): Embedding(10, 10)
  (linear): Linear(
    in_features=10, out_features=10, bias=True
    (lora_dropout): ModuleDict(
      (default): Dropout(p=0.1, inplace=False)
    )
    (lora_A): ModuleDict(
      (default): Linear(in_features=10, out_features=64, bias=False)
    )
    (lora_B): ModuleDict(
      (default): Linear(in_features=64, out_features=10, bias=False)
    )
    (lora_embedding_A): ParameterDict()
    (lora_embedding_B): ParameterDict()
  )
  (lm_head): Linear(in_features=10, out_features=10, bias=True)
)

Saving the model

To save only the adapter, use the get_peft_model_state_dict() function:

from peft import get_peft_model_state_dict

peft_state_dict = get_peft_model_state_dict(model)
print(peft_state_dict)

Otherwise, model.state_dict() returns the full state dict of the model.

Loading the model

After loading the saved state_dict, apply it to the model with the set_peft_model_state_dict() function:

from peft import set_peft_model_state_dict

model = DummyModel()
model = inject_adapter_in_model(lora_config, model)
outcome = set_peft_model_state_dict(model, peft_state_dict)
# check that there were no wrong keys
print(outcome.unexpected_keys)

If injecting the adapter is slow or you need to load a large number of adapters, you can use an optimization that creates an "empty" adapter on the meta device and fills in the real weights only when set_peft_model_state_dict() is called. To do this, pass low_cpu_mem_usage=True to both inject_adapter_in_model() and set_peft_model_state_dict().

model = DummyModel()
model = inject_adapter_in_model(lora_config, model, low_cpu_mem_usage=True)

print(model.linear.lora_A["default"].weight.device.type == "meta")  # should be True
set_peft_model_state_dict(model, peft_state_dict, low_cpu_mem_usage=True)
print(model.linear.lora_A["default"].weight.device.type == "cpu")  # should be True