Weyaxi commited on
Commit
26e58a0
1 Parent(s): 7b15cb8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ from peft import PeftModel
3
+ import gc
4
+ import gradio as gr
5
+ import torch
6
+ from huggingface_hub import snapshot_download, HfApi, notebook_login, create_repo, whoami, login
7
+
8
+ api = HfApi()
9
+
10
+
11
+ def info_fn(text):
12
+ gr.Info(text)
13
+
14
+
15
+ def warning_fn(text):
16
+ gr.Warning(text)
17
+
18
+
19
+ def upload(hf_token, base_model_name_or_path, peft_model_path, output_dir):
20
+ try:
21
+ login(hf_token)
22
+ repo_name = output_dir
23
+
24
+ device_arg = {'device_map': "cpu"}
25
+
26
+ info_fn(f"Loading base model: {base_model_name_or_path}")
27
+
28
+ base_model = AutoModelForCausalLM.from_pretrained(base_model_name_or_path, torch_dtype=torch.float16, **device_arg)
29
+
30
+ info_fn(f"Loading PEFT: {peft_model_path}")
31
+
32
+ model = PeftModel.from_pretrained(base_model, peft_model_path, **device_arg)
33
+
34
+ info_fn(f"Running merge_and_unload")
35
+
36
+ model = model.merge_and_unload()
37
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)
38
+
39
+ info_fn("Saving model..")
40
+ model.save_pretrained(output_dir, safe_serialization=True)
41
+
42
+ info_fn("Saving tokenizer...")
43
+ tokenizer.save_pretrained(output_dir)
44
+
45
+ info_fn(f"Model saved to {output_dir}")
46
+
47
+ del model
48
+ gc.collect()
49
+
50
+ try:
51
+ info_fn("Creating Repo...")
52
+ info_fn(api.create_repo(repo_id=repo_name).__dict__['url'])
53
+ except Exception as e:
54
+ warning_fn(f"Model already exists: {e}")
55
+
56
+ info_fn("Uploading to hub...")
57
+ uploading = api.upload_folder(
58
+ folder_path=output_dir,
59
+ repo_id=output_dir,
60
+ repo_type="model")
61
+
62
+ return uploading
63
+
64
+ except Exception as e:
65
+ gc.collect()
66
+ gr.Error(e)
67
+
68
+ return e
69
+
70
+
71
+ INTRODUCTION_TEXT = f"""
72
+ 🎯 The Leaderboard allows you to merge your Lora adapters.
73
+
74
+ ## ❓ What is Lora?
75
+
76
+ LoRA: Low-Rank Adaptation of Large Language Models allows you to train LLM's with a low cost. Lora freezes the pre-trained model weights and injects trainable rank decomposition matrices into each layer of the Transformer architecture, greatly reducing the number of trainable parameters for downstream tasks.
77
+ You can learn more about LoRa here:
78
+
79
+ [📝 LoRA: Low-Rank Adaptation of Large Language Models Arxiv](https://arxiv.org/abs/2106.09685)
80
+
81
+ ## 🛠️ How does this space work?
82
+
83
+ 🛠️ The leaderboard's backend mainly runs the transformers and PEFT library.
84
+
85
+ 🤖 The code first loads your original model and then your adapter models.
86
+
87
+ 📚 The code merges your adapter weights using the `merge_and_unload` function from the PEFT library.
88
+
89
+ 📤 The code saves your resulting model temporarily and then pushes the resulting model to the hub.
90
+
91
+ ## 🧮 Required RAM
92
+
93
+ This space is loading the model to RAM without performing any quantization, so the required RAM is high.
94
+
95
+ You can merge models up to 13B. (If your adapter weights are too large, it might not work.)
96
+ """
97
+
98
+
99
+ with gr.Blocks() as demo:
100
+ gr.Markdown("""<h1 align="center" id="space-title">🚀 Lora Merge</h1>""")
101
+ gr.Markdown(INTRODUCTION_TEXT)
102
+
103
+ with gr.Row():
104
+ with gr.Column(scale=1):
105
+ hf_token = gr.Textbox(label="Huggingface Write Access Token")
106
+ base_model_name_or_path = gr.Textbox(label="Base Model")
107
+ peft_model_path = gr.Textbox(label="Adapter Model")
108
+ output_dir = gr.Textbox(label="Output Model Name")
109
+
110
+ with gr.Column(scale=1):
111
+ text = gr.Textbox(label="Output Model Name", lines=14)
112
+
113
+
114
+ submit = gr.Button("Merge lora with adapters")
115
+ submit.click(fn=upload, inputs=[hf_token, base_model_name_or_path, peft_model_path, output_dir], outputs=text)
116
+
117
+
118
+ demo.queue()
119
+ demo.launch(show_error=True)