mrcuddle committed on
Commit
a26bb9f
·
verified ·
1 Parent(s): dbb2714

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -35
app.py CHANGED
@@ -1,40 +1,84 @@
1
  import gradio as gr
2
  import subprocess
 
 
 
3
  import spaces
4
 
5
  @spaces.GPU
6
def merge_models(base_model, model_to_merge, p, lambda_value, token, repo, commit_message, upload):
    """Run ``model_merger.py`` on two models and return its textual output.

    The command is passed to ``subprocess.run`` as an argument list, so the
    values typed into the UI are never interpreted by a shell.

    Args:
        base_model: First positional argument for ``model_merger.py``.
        model_to_merge: Second positional argument for ``model_merger.py``.
        p: Value forwarded as ``-p``.
        lambda_value: Value forwarded as ``-lambda``.
        token: Value forwarded as ``--token`` (only when ``upload`` is truthy).
        repo: Value forwarded as ``--repo``.
        commit_message: Value forwarded as ``--commit-message``.
        upload: When truthy, ``--token`` and ``--upload`` are appended.

    Returns:
        The script's stdout. If the script exits with a non-zero status its
        stderr is appended, and if the process cannot be started a short
        error message is returned instead.
    """
    command = [
        "python", "model_merger.py",
        str(base_model),
        str(model_to_merge),
        "-p", str(p),
        "-lambda", str(lambda_value),
        "--repo", str(repo or ""),
        "--commit-message", str(commit_message or ""),
    ]
    if upload:
        command += ["--token", str(token or ""), "--upload"]

    try:
        result = subprocess.run(command, capture_output=True, text=True)
    except OSError as exc:
        # Without a shell, a missing interpreter raises instead of exiting 127.
        return f"Failed to start model_merger.py: {exc}"

    if result.returncode != 0:
        # Surface the failure reason; stdout alone is usually empty here.
        return result.stdout + result.stderr
    return result.stdout
12
-
13
# Markdown blurb shown under the title of the interface.
_description = """
- Combine any two models using a Super Mario merge(DARE).

- Based on: https://github.com/martyn/safetensors-merge-supermario.

- Works with:
- Stable Diffusion (1.5, XL/XL Turbo)
- LLMs(Mistral, Llama, etc)
- LoRas(must be same size)
- Any two homologous models
"""

# Input widgets, listed in the positional order of merge_models' parameters.
_inputs = [
    gr.Textbox(label="Base Model"),
    gr.Textbox(label="Model to Merge"),
    gr.Slider(minimum=0, maximum=1, value=0.5, label="Dropout Probability"),
    gr.Slider(minimum=0, maximum=10, value=3, label="Scaling Factor (Lambda)"),
    gr.Textbox(label="HuggingFace Token (optional)"),
    gr.Textbox(label="New Model Name (without your username) (optional)"),
    gr.Textbox(label="Commit Message (optional)", value="Upload merged model"),
    gr.Checkbox(label="Upload to HuggingFace Hub"),
]

iface = gr.Interface(
    fn=merge_models,
    inputs=_inputs,
    outputs="text",
    title="Safetensors Model Merger",
    description=_description,
)

iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import subprocess
3
+ import os
4
+ import logging
5
+ from pathlib import Path
6
  import spaces
7
 
8
  @spaces.GPU
9
def merge_and_upload(base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message):
    """Merge two models with ``hf_merge.py`` and upload the result.

    Runs ``python3 hf_merge.py`` as a subprocess (argument list, no shell)
    with the ``-U`` flag and captures its stdout and stderr.

    Args:
        base_model: First positional argument for ``hf_merge.py``.
        model_to_merge: Second positional argument for ``hf_merge.py``.
        scaling_factor: Value forwarded as ``-lambda``.
        weight_drop_prob: Value forwarded as ``-p``.
        repo_name: Value forwarded as ``--repo``; also used to build the
            returned URL.
        token: Value forwarded as ``--token``.
        commit_message: Value forwarded as ``--commit-message``.

    Returns:
        A 3-tuple ``(repo_url, status_message, log_output)``. ``repo_url`` is
        ``None`` when input validation fails, when the subprocess cannot be
        started or exits with a non-zero status, or when ``repo_name`` is
        empty (no URL can be derived from an empty name).
    """
    logging.basicConfig(level=logging.INFO)

    # UI values may arrive as None or with stray whitespace from copy/paste.
    base_model = (base_model or "").strip()
    model_to_merge = (model_to_merge or "").strip()
    repo_name = (repo_name or "").strip()

    for label, value in (("Base model", base_model), ("Model to merge", model_to_merge)):
        problem = None
        if not value:
            problem = f"{label} must not be empty."
        elif value.startswith("-"):
            # A leading dash would be parsed as an option by the script.
            problem = f"{label} must not start with '-'."
        if problem is not None:
            logging.error(problem)
            return None, f"Error in merging models: {problem}", problem + "\n"

    # NOTE(review): the token is passed on the command line, where other local
    # processes may be able to read it -- confirm whether hf_merge.py can take
    # it from the environment instead.
    command = [
        "python3", "hf_merge.py",
        base_model,
        model_to_merge,
        "-p", str(weight_drop_prob),
        "-lambda", str(scaling_factor),
        "--token", token or "",
        "--repo", repo_name,
        "--commit-message", commit_message or "",
        "-U",
    ]

    try:
        result = subprocess.run(command, capture_output=True, text=True)
    except OSError as exc:
        # e.g. the python3 executable is missing from PATH.
        message = f"Error in merging models: {exc}"
        logging.error(message)
        return None, message, message + "\n"

    log_output = result.stdout + "\n" + result.stderr + "\n"
    if result.stdout:
        logging.info(result.stdout)

    if result.returncode != 0:
        logging.error(result.stderr)
        return None, f"Error in merging models: {result.stderr}", log_output

    if result.stderr:
        # Output on stderr from a successful run is not treated as an error.
        logging.warning(result.stderr)

    # Assumes the script performed the upload to the repository named here.
    repo_url = f"https://huggingface.co/{repo_name}" if repo_name else None
    return repo_url, "Model merged and uploaded successfully!", log_output
46
+
47
+ # Define the Gradio interface
48
with gr.Blocks(theme="Ytheme/Minecraft", fill_width=True, delete_cache=(60, 3600)) as demo:
    # Header and short description of what the Space does.
    gr.Markdown("# SuperMario Safetensors Merger")
    gr.Markdown("Combine any two models using a Super Mario merge(DARE)")
    gr.Markdown("Based on: https://github.com/martyn/safetensors-merge-supermario")
    gr.Markdown("Works with:")
    gr.Markdown("* Stable Diffusion (1.5, XL/XL Turbo)")
    gr.Markdown("* LLMs (Mistral, Llama, etc)")
    gr.Markdown("* LoRas (must be same size)")
    gr.Markdown("* Any two homologous models")

    # Input widgets, one per row.
    with gr.Column():
        with gr.Row():
            token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1)
        with gr.Row():
            base_model = gr.Textbox(label="Base Model", placeholder=".safetensors")
        with gr.Row():
            model_to_merge = gr.Textbox(label="Merge Model", placeholder=".bin/.safetensors")
        with gr.Row():
            repo_name = gr.Textbox(label="New Model", placeholder="SDXL-", info="If empty, auto-complete", value="", max_lines=1)
        with gr.Row():
            scaling_factor = gr.Slider(minimum=0, maximum=10, value=3.0, label="Scaling Factor")
        with gr.Row():
            weight_drop_prob = gr.Slider(minimum=0, maximum=1, value=0.3, label="Weight Drop Probability")
        with gr.Row():
            commit_message = gr.Textbox(label="Commit Message", value="Upload merged model", max_lines=1)

    # One output component per value returned by merge_and_upload:
    # (repo_url, status_message, log_output).
    repo_url = gr.Markdown(label="Repository URL")
    output = gr.Textbox(label="Output")
    logs = gr.Textbox(label="Logs", lines=10)

    merge_button = gr.Button("Merge")
    merge_button.click(
        merge_and_upload,
        inputs=[base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message],
        outputs=[repo_url, output, logs],
    )

demo.launch()