
X-LoRA: Mixture of Low-Rank Adapter Experts, a Flexible Framework for Large Language Models

X-LoRA works by learning scaling values for LoRA adapters. These learned scaling values are used to gate the LoRA experts in a dense fashion. All LoRA adapters and the base model are frozen, so only the component that produces the scalings is trained. This keeps the trainable parameter count low and makes fine-tuning efficient.
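The dense gating idea can be made concrete with a toy forward pass through a single linear layer. This is a minimal NumPy sketch, not the xlora implementation. It assumes softmax-normalized scalings and ignores per-token scalings, the LoRA alpha/rank factor, and batching. The function names `softmax` and `gated_lora_linear` are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    z = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)


def gated_lora_linear(x, w_base, lora_a, lora_b, scaling_logits):
    """Toy forward pass of one frozen linear layer with densely gated LoRA experts.

    x:              (d_in,) hidden state for a single token
    w_base:         (d_out, d_in) frozen base weight
    lora_a:         (n_experts, r, d_in) frozen LoRA A matrices
    lora_b:         (n_experts, d_out, r) frozen LoRA B matrices
    scaling_logits: (n_experts,) logits from the trainable scaling head (assumed)
    """
    # Dense gating: every expert receives a nonzero weight and the weights sum to 1.
    scalings = softmax(np.asarray(scaling_logits, dtype=np.float64))
    base_out = w_base @ x
    # Per-expert low-rank update B_i @ (A_i @ x), shape (n_experts, d_out).
    deltas = np.einsum("eor,eri,i->eo", lora_b, lora_a, x)
    # Weighted sum of the expert updates, added to the frozen base output.
    return base_out + scalings @ deltas, scalings
```

Only `scaling_logits` would come from trained weights in this picture. Everything else is fixed, which is why the trainable parameter count stays small.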

X-LoRA is easily applied to any Hugging Face Transformers model.

Features

  • Effective: dense gating of experts allows effective mixing
  • Efficient fine-tuning: low trainable parameter count
  • Hierarchical encapsulated strategy: re-use existing trained models or model sections to address complex tasks that cut across experts, following a bio-inspired strategy
  • Easy-to-use API: add_xlora_to_model, broad compatibility
  • Dynamically mix LoRA adapters: deep layer-wise combinations of adapters
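The "low trainable parameter count" point can be checked with simple arithmetic. The sketch below counts the parameters of a hypothetical scaling head. It is built from `depth` fully connected layers: `depth - 1` hidden-to-hidden layers plus a projection to one logit per (decoder layer, adapter) pair, matching the layer-wise mixing above. This layout is an assumption, not the documented xlora architecture. The inputs use Mistral-7B's hidden size (4096) and 32 decoder layers, the `xlora_depth=8` value from the conversion example below, and three adapters.

```python
def dense_stack_params(widths):
    # Weights plus biases of consecutive fully connected layers with the given widths.
    return sum(w_in * w_out + w_out for w_in, w_out in zip(widths, widths[1:]))


def scaling_head_params(hidden_size, depth, n_layers, n_adapters):
    """Parameter count of a hypothetical X-LoRA-style scaling head.

    Assumes (depth - 1) hidden->hidden layers followed by one projection to
    n_layers * n_adapters scaling logits.
    """
    if depth < 1:
        raise ValueError("depth must be at least 1")
    widths = [hidden_size] * depth + [n_layers * n_adapters]
    return dense_stack_params(widths)


# Mistral-7B dimensions with three adapters, for a shallow and a deep head.
for depth in (1, 8):
    print(depth, scaling_head_params(4096, depth, n_layers=32, n_adapters=3))
```

Under these assumptions, the depth-8 head has about 118 million parameters, under 2% of a 7B-parameter base model. A depth-1 head has only a few hundred thousand.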

X-LoRA source code

Install directly from source

pip install git+https://github.com/EricLBuehler/xlora.git -U


Further details on installation, packages with source code, API details and more examples:

https://github.com/EricLBuehler/xlora

Converting and loading a model

Example for model conversion:

import torch
import xlora
from transformers import AutoConfig, AutoModelForCausalLM # type: ignore

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)

config = AutoConfig.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="auto",
)

### Convert the model to X-LoRA
model_created = xlora.add_xlora_to_model(
    model=model,
    xlora_config=xlora.xLoRAConfig(config.hidden_size, xlora_depth=8, device=torch.device("cuda")),
    verbose=True,
    adapters={
        "adapter_1": "./path/to/the/checkpoint_adapter_1/",
        "adapter_2": "./path/to/the/checkpoint_adapter_2/",
        "adapter_n": "./path/to/the/checkpoint_adapter_n/",
    },
)
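After conversion, the base model and all adapters should be frozen, leaving only the X-LoRA scaling component trainable. You can check this by counting parameters according to their `requires_grad` flag. This is a generic sketch that assumes `model_created` exposes the usual `torch.nn.Module.parameters()` interface. The helpers `count_parameters` and `report` are hypothetical, not part of xlora.

```python
def count_parameters(module):
    """Return (trainable, total) parameter counts.

    Works with anything whose `.parameters()` yields objects exposing `.numel()`
    and `.requires_grad`, such as a torch.nn.Module.
    """
    trainable = 0
    total = 0
    for p in module.parameters():
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
    return trainable, total


def report(module):
    # Human-readable summary, e.g. for a quick sanity check after conversion.
    trainable, total = count_parameters(module)
    share = 100.0 * trainable / total if total else 0.0
    return f"trainable: {trainable:,} / total: {total:,} ({share:.2f}%)"
```

Usage: `print(report(model_created))`. A large trainable share would suggest that something other than the scaling component was left unfrozen.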

Loading a trained X-LoRA model from a saved checkpoint

import torch
import xlora
from transformers import AutoConfig, AutoModelForCausalLM # type: ignore

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)

config = AutoConfig.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    trust_remote_code=True,
    use_flash_attention_2=False,
    device_map="auto",
)

model = xlora.from_pretrained(
    "./path/to/saved/model",
    model,
    {
        "adapter_1": "./path/to/the/checkpoint/",
        "adapter_2": "./path/to/the/checkpoint/",
        "adapter_n": "./path/to/the/checkpoint/",
    },
    "cuda",
)
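A typo in an adapter path is easier to catch before loading the 7B base model than after. The sketch below builds the adapter dictionary for the call above and fails early if a checkpoint directory is missing. It assumes, hypothetically, that each adapter is stored in `<checkpoint_root>/<adapter_name>/`. `build_adapter_map` is not an xlora function. The scaling head was trained against a specific set of adapters, so it is sensible to keep the same names and order used at conversion time. Python dictionaries preserve insertion order.

```python
from pathlib import Path


def build_adapter_map(checkpoint_root, adapter_names):
    """Map adapter names to checkpoint directories, failing early on missing paths.

    Hypothetical layout: each adapter lives in <checkpoint_root>/<adapter_name>/.
    """
    root = Path(checkpoint_root)
    adapters = {}
    missing = []
    for name in adapter_names:
        path = root / name
        if path.is_dir():
            adapters[name] = str(path)
        else:
            missing.append(str(path))
    if missing:
        raise FileNotFoundError(f"Missing adapter checkpoints: {missing}")
    return adapters
```

The result can be passed in place of the literal dictionary in the `xlora.from_pretrained` call.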

Loading pre-trained X-LoRA model directly from Hugging Face Hub

import torch
from xlora.xlora_utils import load_model

XLoRa_model_name = 'lamm-mit/x-lora'

model, tokenizer = load_model(
    model_name=XLoRa_model_name,
    device='cuda:0',
    use_flash_attention_2=True,
    dtype=torch.bfloat16,
)

Inference:

def generate_response(model, tokenizer,
                      text_input="What is the best biomaterial for superior strength?",
                      num_return_sequences=1,
                      temperature=0.75,
                      max_new_tokens=127,
                      num_beams=1,
                      top_k=50,
                      top_p=0.9,
                      repetition_penalty=1.0,
                      eos_token_id=2,
                      add_special_tokens=True,
                      device='cuda:0',
                      ):
    # Tokenize to PyTorch tensors and move them to the device the model was loaded on.
    inputs = tokenizer(text_input, add_special_tokens=add_special_tokens,
                       return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(input_ids=inputs["input_ids"],
                                 attention_mask=inputs["attention_mask"],
                                 max_new_tokens=max_new_tokens,
                                 temperature=temperature,
                                 num_beams=num_beams,
                                 top_k=top_k,
                                 top_p=top_p,
                                 num_return_sequences=num_return_sequences,
                                 eos_token_id=eos_token_id,
                                 pad_token_id=eos_token_id,
                                 do_sample=True,
                                 repetition_penalty=repetition_penalty,
                                 )
    # Drop the prompt tokens so that only the newly generated text is decoded.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.batch_decode(outputs[:, prompt_length:].detach().cpu().numpy(),
                                  skip_special_tokens=True)

txt = "What is the best biomaterial for superior strength?"
eos_token = tokenizer.eos_token_id

output_text = generate_response(model, tokenizer, text_input=txt, eos_token_id=eos_token,
                                num_return_sequences=1, repetition_penalty=1.1,
                                top_p=0.9, top_k=512,
                                temperature=0.5,
                                max_new_tokens=256)

print(output_text[0])
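The `temperature`, `top_k` and `top_p` arguments shape the distribution that `do_sample=True` draws from. The sketch below shows those three steps on a single vector of logits, applied in that order. It follows the usual definitions of top-k and nucleus (top-p) sampling but is not the Transformers implementation. It also omits `repetition_penalty`. The names `filter_logits` and `sample_token` are hypothetical. With `top_k=512`, as in the call above, the top-p cut is likely to be the tighter constraint on most steps.

```python
import numpy as np


def filter_logits(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return a sampling distribution after temperature, top-k and top-p filtering.

    temperature must be > 0; top_k=0 disables top-k; top_p=1.0 disables top-p.
    """
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    if 0 < top_k < scaled.size:
        # Keep only logits at least as large as the k-th largest (ties are kept).
        kth = np.sort(scaled)[-top_k]
        scaled = np.where(scaled < kth, -np.inf, scaled)
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    if top_p < 1.0:
        # Nucleus filtering: keep the smallest set of most likely tokens whose
        # cumulative probability reaches top_p, then renormalize.
        order = np.argsort(-probs)
        cumulative = np.cumsum(probs[order])
        cutoff = int(np.searchsorted(cumulative, top_p)) + 1
        mask = np.zeros(probs.shape, dtype=bool)
        mask[order[:cutoff]] = True
        probs = np.where(mask, probs, 0.0)
        probs /= probs.sum()
    return probs


def sample_token(logits, rng, **kwargs):
    # Draw one token id from the filtered distribution.
    probs = filter_logits(logits, **kwargs)
    return int(rng.choice(probs.size, p=probs))
```

Lower temperatures sharpen the distribution before filtering, so `temperature=0.5` combined with `top_p=0.9` concentrates sampling on fewer tokens than the function's defaults.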

Dataset

See lamm-mit/x-lora-dataset for the dataset used to train the X-LoRA model. Details on the datasets used to train the original adapters are included in the paper (see reference below).

Sample results

[Figure: sample results]

Acknowledgements

This work is built on the Hugging Face PEFT library and other components in the Hugging Face ecosystem. We acknowledge the authors of this excellent library and related methods.

Original paper and citation

Cite this work as:

@article{Buehler_XLoRA_2024,
    title   = {X-LoRA: Mixture of Low-Rank Adapter Experts, a Flexible Framework for Large Language Models with Applications in Protein Mechanics and Design},
    author  = {Buehler, E. L. and Buehler, M. J.},
    journal = {arXiv preprint arXiv:2402.07148},
    year    = {2024},
    url     = {https://arxiv.org/abs/2402.07148}
}