# merge-lora / app.py
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import gc
import gradio as gr
import torch
from huggingface_hub import snapshot_download, HfApi, notebook_login, create_repo, whoami, login
api = HfApi()
# Helpers that surface status and warning messages in the Gradio UI.
def info_fn(text):
    gr.Info(text)

def warning_fn(text):
    gr.Warning(text)
def upload(hf_token, base_model_name_or_path, peft_model_path, output_dir):
    try:
        # Authenticate with the user's write-access token.
        login(hf_token)
        repo_name = output_dir
        device_arg = {'device_map': "cpu"}

        # Load the base model on CPU in bfloat16.
        info_fn(f"Loading base model: {base_model_name_or_path}")
        base_model = AutoModelForCausalLM.from_pretrained(base_model_name_or_path, torch_dtype=torch.bfloat16, **device_arg)

        # Attach the LoRA adapter to the base model.
        info_fn(f"Loading PEFT: {peft_model_path}")
        model = PeftModel.from_pretrained(base_model, peft_model_path, **device_arg)

        # Fold the adapter weights into the base weights.
        info_fn("Running merge_and_unload")
        model = model.merge_and_unload()

        tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)

        # Save the merged model and tokenizer to a temporary local directory.
        info_fn("Saving model...")
        model.save_pretrained(output_dir, safe_serialization=True)
        info_fn("Saving tokenizer...")
        tokenizer.save_pretrained(output_dir)
        info_fn(f"Model saved to {output_dir}")

        # Free the merged model before uploading.
        del model
        gc.collect()

        # Create the target repo (or warn if it already exists), then push the folder.
        try:
            info_fn("Creating repo...")
            info_fn(api.create_repo(repo_id=repo_name).__dict__['url'])
        except Exception as e:
            warning_fn(f"Model already exists: {e}")

        info_fn("Uploading to hub...")
        uploading = api.upload_folder(
            folder_path=output_dir,
            repo_id=output_dir,
            repo_type="model")
        return uploading
    except Exception as e:
        gc.collect()
        raise gr.Error(str(e))
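# Example of how `upload` could be called directly without the Gradio UI (a hypothetical
# sketch; the token, repo IDs, and output name below are placeholders, not real values):
#
# upload(
#     hf_token="hf_...",                                   # a write-access token
#     base_model_name_or_path="meta-llama/Llama-2-7b-hf",  # base model repo ID
#     peft_model_path="username/my-lora-adapter",          # LoRA adapter repo ID
#     output_dir="my-merged-model",                        # local dir / target repo name
# )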
INTRODUCTION_TEXT = """
🎯 This space allows you to merge your LoRA adapters into their base models.

## ❓ What is LoRA?

LoRA (Low-Rank Adaptation of Large Language Models) lets you fine-tune LLMs at low cost. LoRA freezes the pre-trained model weights and injects trainable rank decomposition matrices into each layer of the Transformer architecture, greatly reducing the number of trainable parameters for downstream tasks.

You can learn more about LoRA here:

[📝 LoRA: Low-Rank Adaptation of Large Language Models (arXiv)](https://arxiv.org/abs/2106.09685)
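In essence, LoRA learns a low-rank update for each frozen weight matrix, and merging simply folds that update back into the base weights. A toy sketch of the idea (illustrative only, with made-up sizes; not this space's actual code):

```python
import torch

d, r, alpha = 16, 4, 8                 # hidden size, LoRA rank, LoRA scaling alpha
W = torch.randn(d, d)                  # frozen base weight
A = torch.randn(r, d) * 0.01           # low-rank adapter factors learned during fine-tuning
B = torch.zeros(d, r)
W_merged = W + (B @ A) * (alpha / r)   # what merging does for every adapted layer
```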
## 🛠️ How does this space work?

🛠️ The backend of this space mainly runs the transformers and PEFT libraries.

🤖 The code first loads your base model and then your adapter model.

📚 The code merges your adapter weights into the base weights using the `merge_and_unload` function from the PEFT library.

📤 The code saves the resulting model temporarily and then pushes it to the Hub, as sketched below.
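Roughly, the whole flow boils down to the following (a simplified sketch of what this space runs for you; the repo IDs are placeholders):

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel
from huggingface_hub import HfApi
import torch

base = AutoModelForCausalLM.from_pretrained("base/model", torch_dtype=torch.bfloat16, device_map="cpu")
model = PeftModel.from_pretrained(base, "user/lora-adapter", device_map="cpu")
model = model.merge_and_unload()   # fold the adapter weights into the base weights
model.save_pretrained("merged-model", safe_serialization=True)
HfApi().upload_folder(folder_path="merged-model", repo_id="user/merged-model", repo_type="model")
```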
## 🧮 Required RAM

This space loads the model into RAM without any quantization, so the RAM requirement is high.

You can merge models up to about 7B parameters. (If your adapter weights are too large, it might not work.)
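As a rough estimate, a 7B-parameter model in bfloat16 needs about 7B × 2 bytes ≈ 14 GB of RAM for the merged weights alone, on top of the adapter weights and working memory.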
"""
with gr.Blocks() as demo:
    gr.Markdown("""<h1 align="center" id="space-title">🚀 LoRA Merge</h1>""")
    gr.Markdown(INTRODUCTION_TEXT)
    with gr.Row():
        with gr.Column(scale=1):
            hf_token = gr.Textbox(label="Hugging Face Write Access Token")
            base_model_name_or_path = gr.Textbox(label="Base Model")
            peft_model_path = gr.Textbox(label="Adapter Model")
            output_dir = gr.Textbox(label="Output Model Name")
        with gr.Column(scale=1):
            # Text area that displays the result returned by `upload`.
            text = gr.Textbox(label="Output", lines=14)
    submit = gr.Button("Merge LoRA adapter with base model")
    submit.click(fn=upload, inputs=[hf_token, base_model_name_or_path, peft_model_path, output_dir], outputs=text)
demo.queue()
demo.launch(show_error=True)