Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,40 +1,84 @@
|
|
1 |
import gradio as gr
|
2 |
import subprocess
|
|
|
|
|
|
|
3 |
import spaces
|
4 |
|
5 |
@spaces.GPU
|
6 |
-
def
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import subprocess
|
3 |
+
import os
|
4 |
+
import logging
|
5 |
+
from pathlib import Path
|
6 |
import spaces
|
7 |
|
8 |
@spaces.GPU
|
9 |
+
def merge_and_upload(base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message):
|
10 |
+
# Define a fixed output path
|
11 |
+
outpath = Path('/tmp/output')
|
12 |
+
|
13 |
+
# Construct the command to run hf_merge.py
|
14 |
+
command = [
|
15 |
+
"python3", "hf_merge.py",
|
16 |
+
base_model,
|
17 |
+
model_to_merge,
|
18 |
+
"-p", str(weight_drop_prob),
|
19 |
+
"-lambda", str(scaling_factor),
|
20 |
+
"--token", token,
|
21 |
+
"--repo", repo_name,
|
22 |
+
"--commit-message", commit_message,
|
23 |
+
"-U"
|
24 |
+
]
|
25 |
+
|
26 |
+
# Set up logging
|
27 |
+
logging.basicConfig(level=logging.INFO)
|
28 |
+
log_output = ""
|
29 |
+
|
30 |
+
# Run the command and capture the output
|
31 |
+
result = subprocess.run(command, capture_output=True, text=True)
|
32 |
+
|
33 |
+
# Log the output
|
34 |
+
log_output += result.stdout + "\n"
|
35 |
+
log_output += result.stderr + "\n"
|
36 |
+
logging.info(result.stdout)
|
37 |
+
logging.error(result.stderr)
|
38 |
+
|
39 |
+
# Check if the merge was successful
|
40 |
+
if result.returncode != 0:
|
41 |
+
return None, f"Error in merging models: {result.stderr}", log_output
|
42 |
+
|
43 |
+
# Assuming the script handles the upload and returns the repo URL
|
44 |
+
repo_url = f"https://huggingface.co/{repo_name}"
|
45 |
+
return repo_url, "Model merged and uploaded successfully!", log_output
|
46 |
+
|
47 |
+
# Define the Gradio interface
|
48 |
+
with gr.Blocks(theme="Ytheme/Minecraft", fill_width=True, delete_cache=(60, 3600)) as demo:
|
49 |
+
gr.Markdown("# SuperMario Safetensors Merger")
|
50 |
+
gr.Markdown("Combine any two models using a Super Mario merge(DARE)")
|
51 |
+
gr.Markdown("Based on: https://github.com/martyn/safetensors-merge-supermario")
|
52 |
+
gr.Markdown("Works with:")
|
53 |
+
gr.Markdown("* Stable Diffusion (1.5, XL/XL Turbo)")
|
54 |
+
gr.Markdown("* LLMs (Mistral, Llama, etc)")
|
55 |
+
gr.Markdown("* LoRas (must be same size)")
|
56 |
+
gr.Markdown("* Any two homologous models")
|
57 |
+
|
58 |
+
with gr.Column():
|
59 |
+
with gr.Row():
|
60 |
+
token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1)
|
61 |
+
with gr.Row():
|
62 |
+
base_model = gr.Textbox(label="Base Model", placeholder=".safetensors")
|
63 |
+
with gr.Row():
|
64 |
+
model_to_merge = gr.Textbox(label="Merge Model", placeholder=".bin/.safetensors")
|
65 |
+
with gr.Row():
|
66 |
+
repo_name = gr.Textbox(label="New Model", placeholder="SDXL-", info="If empty, auto-complete", value="", max_lines=1)
|
67 |
+
with gr.Row():
|
68 |
+
scaling_factor = gr.Slider(minimum=0, maximum=10, value=3.0, label="Scaling Factor")
|
69 |
+
with gr.Row():
|
70 |
+
weight_drop_prob = gr.Slider(minimum=0, maximum=1, value=0.3, label="Weight Drop Probability")
|
71 |
+
with gr.Row():
|
72 |
+
commit_message = gr.Textbox(label="Commit Message", value="Upload merged model", max_lines=1)
|
73 |
+
|
74 |
+
progress = gr.Progress()
|
75 |
+
repo_url = gr.Markdown(label="Repository URL")
|
76 |
+
output = gr.Textbox(label="Output")
|
77 |
+
|
78 |
+
gr.Button("Merge").click(
|
79 |
+
merge_and_upload,
|
80 |
+
inputs=[base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message],
|
81 |
+
outputs=[repo_url, output]
|
82 |
+
)
|
83 |
+
|
84 |
+
demo.launch()
|