abhishek HF staff committed on
Commit
c445497
1 Parent(s): 4d0742d

adapter merger

Browse files
Files changed (2) hide show
  1. app.py +60 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from peft import PeftModel
4
+ import torch
5
+
6
+
7
def merge(base_model, trained_adapter, token):
    """Merge a trained PEFT adapter into its base model and push the result to the Hub.

    Loads the base causal LM in fp16, layers the adapter on top, folds the
    adapter weights into the base weights, then uploads the merged model and
    the base model's tokenizer to the adapter's repository.

    Args:
        base_model: Hub id of the base model (e.g. "meta-llama/Llama-2-7b-chat-hf").
        trained_adapter: Hub id of the adapter repo; also the push destination.
        token: Hugging Face write token used for both download and upload.

    Returns:
        A gr.Markdown update announcing success.
    """
    # fp16 + low_cpu_mem_usage keeps the memory footprint manageable.
    foundation = AutoModelForCausalLM.from_pretrained(
        base_model, torch_dtype=torch.float16, low_cpu_mem_usage=True, token=token
    )
    peft_model = PeftModel.from_pretrained(foundation, trained_adapter, token=token)

    # Some tokenizers hit a RecursionError unless the unk token is given
    # explicitly — retry with it set.
    try:
        tokenizer = AutoTokenizer.from_pretrained(base_model, token=token)
    except RecursionError:
        tokenizer = AutoTokenizer.from_pretrained(
            base_model, unk_token="<unk>", token=token
        )

    # Fold the adapter weights into the base model and drop the PEFT wrapper.
    merged = peft_model.merge_and_unload()

    print("Saving target model")
    merged.push_to_hub(trained_adapter, token=token)
    tokenizer.push_to_hub(trained_adapter, token=token)

    return gr.Markdown.update(
        value="Model successfully merged and pushed! Please shutdown/pause this space"
    )
27
+
28
+
29
def _plain_textbox(label):
    # Empty, single-line, editable textbox shared by the model-id fields.
    return gr.Textbox(label=label, value="", lines=1, max_lines=1, interactive=True)


# UI: three inputs (write token, base model id, adapter id), one button that
# triggers the merge, and a Markdown area for the status message.
with gr.Blocks() as demo:
    gr.Markdown("## AutoTrain Merge Adapter")
    gr.Markdown("Please duplicate this space and attach a GPU in order to use it.")
    token = gr.Textbox(
        label="Hugging Face Write Token",
        value="",
        lines=1,
        max_lines=1,
        interactive=True,
        type="password",  # masked so the token is not shown on screen
    )
    base_model = _plain_textbox("Base Model (e.g. meta-llama/Llama-2-7b-chat-hf)")
    trained_adapter = _plain_textbox(
        "Trained Adapter Model (e.g. username/autotrain-my-llama)"
    )
    submit = gr.Button(value="Merge & Push")
    op = gr.Markdown(interactive=False)
    submit.click(merge, inputs=[base_model, trained_adapter, token], outputs=[op])


if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ transformers
3
+ peft