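"""Gradio Space: merge two Hugging Face models with a DARE ("Super Mario") merge.

The app writes the two model repos to repo.txt, shells out to hf_merge.py to
perform the merge, and uploads the result to a new or existing repo on the
Hugging Face Hub.
"""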
import logging
import shutil
import subprocess
from pathlib import Path

import gradio as gr
import spaces  # Hugging Face Spaces SDK (ZeroGPU support)
from huggingface_hub import HfApi


def write_repo(base_model, model_to_merge):
    # hf_merge.py reads the list of repos to merge from a plain-text file.
    with open("repo.txt", "w") as repo:
        repo.write(base_model + "\n" + model_to_merge)
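

# Illustrative sketch only; the app never calls this helper. The actual merge is
# done by hf_merge.py (assumed to sit next to app.py), whose -p and -lambda flags
# correspond to the "Weight Drop Probability" and "Scaling Factor" sliders below.
# The exact formula used by hf_merge.py may differ from this simplified version.
def dare_merge_tensor(base, other, p=0.3, lam=3.0):
    """Hypothetical per-tensor DARE ("drop and rescale") merge, for p < 1."""
    import torch  # imported lazily so the app does not require torch at import time

    delta = other - base                                      # task vector
    mask = torch.bernoulli(torch.full_like(delta, 1.0 - p))   # keep each element with prob 1-p
    return base + lam * mask * delta / (1.0 - p)              # rescale survivors, scale by lambda, re-add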


def merge_and_upload(base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token):
    # Use a fixed output path and clear any previous merge result
    outpath = Path('/tmp/output')
    if outpath.exists() and outpath.is_dir():
        shutil.rmtree(outpath)

    write_repo(base_model, model_to_merge)

    # Construct the command to run hf_merge.py
    command = [
        "python3", "hf_merge.py",
        "-p", str(weight_drop_prob),
        "-lambda", str(scaling_factor),
        "repo.txt", str(outpath)
    ]

    # Set up logging
    logging.basicConfig(level=logging.INFO)

    # Run the merge and capture its output
    result = subprocess.run(command, capture_output=True, text=True)
    logging.info(result.stdout)
    if result.stderr:
        logging.error(result.stderr)

    # Check whether the merge succeeded
    if result.returncode != 0:
        yield f"Error in merging models: {result.stderr}"
        return

    # Status updates and the final result share the single Markdown output
    yield "Merging completed. Uploading to Hugging Face Hub..."

    # Upload the result to the Hugging Face Hub
    api = HfApi(token=token)
    try:
        # Get the username of the logged-in user
        user = api.whoami(token=token)["name"]

        # Autofill the repo name if none is provided
        if not repo_name:
            repo_name = f"{user}/default-repo"

        # Create a new repo or reuse an existing one
        api.create_repo(repo_id=repo_name, token=token, exist_ok=True)

        # Upload the merged model folder
        api.upload_folder(
            folder_path=str(outpath),
            repo_id=repo_name,
            repo_type="model",
            token=token
        )

        repo_url = f"https://huggingface.co/{repo_name}"
        yield f"Model merged and uploaded successfully: [{repo_url}]({repo_url})"
    except Exception as e:
        yield f"Error uploading to Hugging Face Hub: {str(e)}"


# Define the Gradio interface
with gr.Blocks(theme="Ytheme/Minecraft", fill_width=True, delete_cache=(60, 3600)) as demo:
    gr.Markdown("# Model Merger and Uploader")
    gr.Markdown(
        "Combine any two models using a Super Mario merge (DARE), as described in the paper "
        "*Language Models are Super Mario: Absorbing Abilities from Homologous Models as a Free Lunch*."
    )
    gr.Markdown("Works with:")
    gr.Markdown("* Stable Diffusion (1.5, XL/XL Turbo)")
    gr.Markdown("* LLMs (Mistral, Llama, etc.)")
    gr.Markdown("* LoRAs (must be the same size)")
    gr.Markdown("* Any two homologous models")
    with gr.Column():
        with gr.Row():
            token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1)
        with gr.Row():
            base_model = gr.Textbox(label="Base Model", placeholder=".safetensors")
        with gr.Row():
            model_to_merge = gr.Textbox(label="Merge Model", placeholder=".bin/.safetensors")
        with gr.Row():
            repo_name = gr.Textbox(label="New Model", placeholder="SDXL-", info="If empty, a default name is used", value="", max_lines=1)
        with gr.Row():
            scaling_factor = gr.Slider(minimum=0, maximum=10, value=3.0, label="Scaling Factor")
        with gr.Row():
            weight_drop_prob = gr.Slider(minimum=0, maximum=1, value=0.3, label="Weight Drop Probability")

    repo_url = gr.Markdown(label="Repository URL")

    gr.Button("Merge").click(
        merge_and_upload,
        inputs=[base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token],
        outputs=[repo_url]
    )

demo.launch()