# Soon_Merger / app2.py
# Hosting-page residue: uploaded by AlekseyCalvin; commit f578122 (verified),
# commit message "Rename app.py to app2.py".
import gradio as gr
import torch
import os
import gc
import json
import shutil
import requests
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm
# --- Constants & Setup ---
# Scratch directory (relative to the working directory) that holds downloaded
# and intermediate files; it is created at import time if missing.
TempDir = Path("./temp_merge")
TempDir.mkdir(parents=True, exist_ok=True)

# Single Hub client instance shared by the functions below.
api = HfApi()
def info_log(msg, progress=None):
    """Print *msg* to stdout and return it unchanged.

    Args:
        msg: The message to print.
        progress: Accepted for call-site compatibility only. It never
            influenced the result (both original branches returned *msg*),
            so it is ignored.

    Returns:
        The same *msg* object that was passed in.
    """
    print(msg)
    return msg
def cleanup_temp():
    """Reset the scratch directory to an empty state.

    Deletes ``TempDir`` and everything under it when it exists, recreates it
    as an empty directory, then runs a garbage-collection pass.
    """
    scratch_exists = TempDir.exists()
    if scratch_exists:
        shutil.rmtree(TempDir)
    # Always leave an (empty) scratch directory behind for the next run.
    TempDir.mkdir(parents=True, exist_ok=True)
    gc.collect()
# --- Core Logic ---
def download_lora(lora_input, hf_token):
    """Download LoRA weights from a direct URL or a Hub repo ID.

    Args:
        lora_input: Either an ``http://`` / ``https://`` URL pointing at a
            weights file, or a Hub repo ID. Surrounding whitespace is ignored.
        hf_token: Token used for Hub requests. It is NOT sent with direct-URL
            downloads.

    Returns:
        Local path of the downloaded file: a ``Path`` inside ``TempDir`` for
        URL downloads, or whatever ``hf_hub_download`` returns for repos.

    Raises:
        requests.HTTPError: The direct URL answered with an error status.
        ValueError: The repo contains no ``.safetensors`` file.
    """
    lora_input = lora_input.strip()

    # Require a real scheme so a repo ID that merely starts with "http"
    # is not mistaken for a URL.
    if lora_input.startswith(("http://", "https://")):
        local_path = TempDir / "adapter.safetensors"
        print(f"Downloading LoRA from URL: {lora_input}")
        # The context manager releases the streamed connection even on error;
        # the timeout keeps a stalled server from hanging the job forever.
        with requests.get(lora_input, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(local_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        return local_path

    print(f"Downloading LoRA from Repo: {lora_input}")
    try:
        # Conventional PEFT adapter filename; try it first.
        return hf_hub_download(
            repo_id=lora_input,
            filename="adapter_model.safetensors",
            token=hf_token,
            local_dir=TempDir,
        )
    except Exception as err:
        # Best-effort fallback below. ``Exception`` (not a bare ``except``) so
        # KeyboardInterrupt / SystemExit still propagate.
        print(f"Default adapter file not available ({err}); scanning repo for .safetensors files.")

    files = list_repo_files(repo_id=lora_input, token=hf_token)
    # Prefer files that look like adapters, then any safetensors file.
    safe_files = [f for f in files if f.endswith(".safetensors") and "adapter" in f]
    if not safe_files:
        safe_files = [f for f in files if f.endswith(".safetensors")]
    if not safe_files:
        raise ValueError("Could not find a .safetensors file in the LoRA repo.")
    return hf_hub_download(
        repo_id=lora_input,
        filename=safe_files[0],
        token=hf_token,
        local_dir=TempDir,
    )
def load_lora_weights(path):
    """Load the safetensors file at *path* onto the CPU.

    Returns:
        The object produced by ``load_file(path, device="cpu")``.
    """
    return load_file(path, device="cpu")
# Substrings that mark the two low-rank factors in LoRA tensor names.
_LORA_A_TAGS = ("lora_A", "lora_down")
_LORA_B_TAGS = ("lora_B", "lora_up")


def _lora_role(lora_key):
    """Classify a LoRA key as the "A" factor, the "B" factor, or neither (None)."""
    if any(tag in lora_key for tag in _LORA_A_TAGS):
        return "A"
    if any(tag in lora_key for tag in _LORA_B_TAGS):
        return "B"
    return None


def _names_module(lora_key, module_path):
    """Return True if the text before the LoRA tag is *module_path* on a dot boundary."""
    if not module_path:
        return False
    for tag in _LORA_A_TAGS + _LORA_B_TAGS:
        owner, found, _ = lora_key.partition("." + tag)
        if found and (owner == module_path or owner.endswith("." + module_path)):
            return True
    return False


def match_keys(base_key, lora_keys):
    """Find the LoRA A/B tensor names that target the base tensor *base_key*.

    Matching happens in two tiers:

    1. Module match (preferred). A trailing ``.weight`` is removed from
       *base_key* to get the module path; a LoRA key matches when the text
       before its ``.lora_A`` / ``.lora_down`` / ``.lora_B`` / ``.lora_up``
       tag equals that path or ends with ``"." + path``. The dot boundary
       keeps ``layers.1`` from matching ``layers.11``.
    2. Legacy substring match. Used only when tier 1 finds nothing: any LoRA
       key that contains *base_key* verbatim (the original behaviour).

    Within the chosen tier the last matching key of each role wins.

    Args:
        base_key: Name of a tensor in the base checkpoint.
        lora_keys: Iterable of tensor names from the LoRA checkpoint.

    Returns:
        ``(pair_A, pair_B)``; each element is a LoRA key or ``None`` when no
        key of that role matched.
    """
    weight_suffix = ".weight"
    if base_key.endswith(weight_suffix):
        module_path = base_key[: -len(weight_suffix)]
    else:
        module_path = base_key

    precise = {}
    legacy = {}
    for lora_key in lora_keys:
        role = _lora_role(lora_key)
        if role is None:
            continue
        if _names_module(lora_key, module_path):
            precise[role] = lora_key
        elif base_key in lora_key:
            legacy[role] = lora_key

    # Never mix tiers: a pair must come from one consistent matching rule.
    chosen = precise or legacy
    return chosen.get("A"), chosen.get("B")
def copy_auxiliary_files(src_repo, tgt_repo, token, subfolder=""):
    """Copy every non-weight file from *src_repo* into *tgt_repo*.

    Files ending in a known weight extension are skipped. Each remaining
    file is downloaded, uploaded to the same path in the target model repo,
    and its local copy removed. A failure on one file is printed and the
    loop moves on to the next.

    Args:
        src_repo: Repo ID to read files from.
        tgt_repo: Model repo ID to upload files to.
        token: Token passed to every Hub call.
        subfolder: Currently unused by this function.
    """
    print(f"Copying infrastructure from {src_repo} to {tgt_repo}...")

    weight_extensions = (".safetensors", ".bin", ".pt", ".pth", ".msgpack", ".h5")
    files_to_copy = [
        name
        for name in list_repo_files(repo_id=src_repo, token=token)
        if not name.endswith(weight_extensions)
    ]

    for name in tqdm(files_to_copy, desc="Copying configs"):
        try:
            fetched = hf_hub_download(repo_id=src_repo, filename=name, token=token)
            api.upload_file(
                path_or_fileobj=fetched,
                path_in_repo=name,
                repo_id=tgt_repo,
                repo_type="model",
                token=token
            )
            os.remove(fetched)
        except Exception as e:
            # Best effort: report the file and continue with the rest.
            print(f"Skipped {name}: {e}")
def _select_base_shards(repo_files, base_subfolder):
    """Return the ``.safetensors`` files of *repo_files* that lie inside *base_subfolder*."""
    folder = base_subfolder.strip().strip("/")
    shards = []
    for name in repo_files:
        if not name.endswith(".safetensors"):
            continue
        # Compare on a path boundary so "transformer" does not also select
        # "transformer_2/...". An exact file path is accepted as well.
        if folder and name != folder and not name.startswith(folder + "/"):
            continue
        shards.append(name)
    return shards


def _merge_shard(base_tensors, lora_state, lora_keys, scale, logs):
    """Add scaled LoRA deltas to one shard; return ``(tensors, merged_count)``."""
    merged = 0
    result = {}
    for key, tensor in base_tensors.items():
        result[key] = tensor  # pass-through unless a delta is applied below
        pair_A, pair_B = match_keys(key, lora_keys)
        if not (pair_A and pair_B):
            continue

        w_a = lora_state[pair_A].float()
        w_b = lora_state[pair_B].float()
        # Only plain matrix factors can be multiplied; anything else would
        # make ``@`` raise and abort the entire merge.
        if w_a.dim() != 2 or w_b.dim() != 2 or w_b.shape[1] != w_a.shape[0]:
            logs.append(f"Warning: Unsupported LoRA factor shapes for {key}. Skipping.")
            continue

        delta = (w_b @ w_a) * scale
        if delta.shape != tensor.shape:
            if delta.T.shape == tensor.shape:
                delta = delta.T
            else:
                logs.append(f"Warning: Shape mismatch for {key}. Skipping.")
                continue

        # Accumulate in float32, then cast back to the stored dtype.
        result[key] = (tensor.float() + delta).to(tensor.dtype)
        merged += 1
        del delta
    return result, merged


def run_merge(
    hf_token,
    base_repo,
    base_subfolder,
    structure_repo,
    lora_input,
    scale,
    output_repo,
    is_private,
    progress=gr.Progress()
):
    """Merge a LoRA into every safetensors shard of a base repo and upload the result.

    Steps: create the output repo, optionally copy non-weight files from
    *structure_repo*, download and load the LoRA, then for each base shard
    download it, add ``(B @ A) * scale`` to every matched weight, and upload
    the shard (merged or untouched) under the same path in *output_repo*.

    Args:
        hf_token: Hub token used for all requests.
        base_repo: Repo ID holding the base weights.
        base_subfolder: Optional folder inside *base_repo*; when given, only
            shards inside it are processed.
        structure_repo: Optional repo whose non-weight files are copied first.
        lora_input: LoRA repo ID or direct URL.
        scale: Multiplier applied to each LoRA delta.
        output_repo: Repo ID that receives the results.
        is_private: Whether the output repo is created as private.
        progress: Gradio progress callback.

    Returns:
        The accumulated log as one newline-joined string. Errors are reported
        in the log rather than raised.
    """
    # Pasted values often carry stray whitespace; normalise once up front.
    hf_token = (hf_token or "").strip()
    base_repo = (base_repo or "").strip()
    base_subfolder = base_subfolder or ""
    structure_repo = (structure_repo or "").strip()
    lora_input = (lora_input or "").strip()
    output_repo = (output_repo or "").strip()

    cleanup_temp()
    logs = []
    processed_path = TempDir / "processed.safetensors"
    try:
        if not hf_token:
            raise ValueError("A Hugging Face write token is required.")
        # NOTE(review): login() appears to store the token as the process-wide
        # default credential. Every Hub call below already passes token=
        # explicitly, so on a shared Space consider dropping it - confirm.
        login(hf_token)
        logs.append(f"Logged in. Target: {output_repo}")

        # 1. Create Output Repo
        try:
            api.create_repo(repo_id=output_repo, private=is_private, exist_ok=True, token=hf_token)
            logs.append("Output repository ready.")
        except Exception as e:
            return "\n".join(logs) + f"\nError creating repo: {e}"

        # 2. Replicate Structure
        if structure_repo:
            progress(0.1, desc="Cloning Model Structure...")
            logs.append(f"Cloning configuration from {structure_repo}...")
            copy_auxiliary_files(structure_repo, output_repo, hf_token)
            logs.append("Configuration files copied.")

        # 3. Load LoRA
        progress(0.2, desc="Downloading LoRA...")
        logs.append(f"Fetching LoRA: {lora_input}")
        lora_path = download_lora(lora_input, hf_token)
        lora_state = load_lora_weights(lora_path)
        lora_keys = list(lora_state.keys())
        logs.append(f"LoRA loaded. Found {len(lora_keys)} tensors.")

        # 4. Identify Base Shards
        progress(0.3, desc="Analyzing Base Model...")
        all_files = list_repo_files(repo_id=base_repo, token=hf_token)
        target_shards = _select_base_shards(all_files, base_subfolder)
        logs.append(f"Found {len(target_shards)} matching safetensors shards in base.")
        if not target_shards:
            raise ValueError("No safetensors found in the specified base repo/subfolder.")

        # 5. Process Shards
        total_shards = len(target_shards)
        merged_count = 0
        for idx, shard_file in enumerate(target_shards):
            progress(0.3 + (0.6 * (idx / total_shards)), desc=f"Processing Shard {idx+1}/{total_shards}")
            logs.append(f"--- Processing {shard_file} ---")
            local_shard = hf_hub_download(repo_id=base_repo, filename=shard_file, token=hf_token, local_dir=TempDir)
            base_tensors = load_file(local_shard, device="cpu")

            modified_tensors, shard_merged = _merge_shard(
                base_tensors, lora_state, lora_keys, scale, logs
            )
            merged_count += shard_merged

            if shard_merged:
                logs.append(f"Merging complete for shard. Saving...")
                # Carry the original header metadata over; some loaders
                # reject safetensors files that have none.
                with safe_open(local_shard, framework="pt", device="cpu") as shard_handle:
                    shard_metadata = shard_handle.metadata()
                save_file(modified_tensors, processed_path, metadata=shard_metadata)
                api.upload_file(path_or_fileobj=processed_path, path_in_repo=shard_file, repo_id=output_repo, repo_type="model", token=hf_token)
                logs.append(f"Uploaded {shard_file}")
            else:
                logs.append(f"No LoRA matches in this shard. Copying original...")
                api.upload_file(path_or_fileobj=local_shard, path_in_repo=shard_file, repo_id=output_repo, repo_type="model", token=hf_token)

            # Free memory and disk before the next shard is downloaded.
            del base_tensors
            del modified_tensors
            gc.collect()
            os.remove(local_shard)
            if os.path.exists(processed_path):
                os.remove(processed_path)

        progress(1.0, desc="Done!")
        if merged_count == 0:
            logs.append("Warning: no LoRA tensors matched any base weight; the uploaded shards are unmodified copies.")
        logs.append(f"\nSUCCESS. Merged {merged_count} layers total.")
        logs.append(f"New model available at: https://huggingface.co/{output_repo}")
    except Exception as e:
        import traceback
        logs.append(f"\nCRITICAL ERROR: {str(e)}")
        logs.append(traceback.format_exc())
    finally:
        cleanup_temp()
    return "\n".join(logs)
# --- UI ---
# Stylesheet text; it is not given to gr.Blocks() but handed to launch() by
# the entry point at the bottom of the file.
css = """
.container { max-width: 900px; margin: auto; }
.header { text-align: center; margin-bottom: 20px; }
"""

# NOTE: 'css' and 'theme' are intentionally not passed to gr.Blocks(), for
# compatibility with recent Gradio versions.
# NOTE(review): the nesting of Rows/Groups below was reconstructed from a
# flattened source; verify it against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# ⚡ Universal LoRA Merger & Reconstructor
Merge LoRA adapters into **any** base model (LLM, Diffusion, Audio) and reconstruct the repository structure.
Optimized for CPU-only execution on Hugging Face Spaces.
"""
    )

    # Section 1: credentials and destination repository.
    with gr.Group():
        gr.Markdown("### 1. Authentication & Output")
        with gr.Row():
            hf_token = gr.Textbox(
                label="HF Write Token",
                type="password",
                placeholder="hf_...",
            )
            output_repo = gr.Textbox(
                label="Target Output Repo",
                placeholder="username/Z-Image-Turbo-Custom",
            )
            is_private = gr.Checkbox(label="Private Repo", value=True)

    # Section 2: where the base weights come from.
    with gr.Group():
        gr.Markdown("### 2. Base Weights (The Target)")
        with gr.Row():
            base_repo = gr.Textbox(
                label="Base Model Repo",
                placeholder="e.g. ostris/Z-Image-De-Turbo",
            )
            base_subfolder = gr.Textbox(
                label="Subfolder (Optional)",
                placeholder="e.g. transformer",
                info="Only merge weights found inside this folder.",
            )

    # Section 3: the adapter and its strength.
    with gr.Group():
        gr.Markdown("### 3. LoRA Configuration")
        with gr.Row():
            lora_input = gr.Textbox(
                label="LoRA Source",
                placeholder="Repo ID OR Direct URL (http...)",
                info="Accepts direct .safetensors resolve links.",
            )
            scale = gr.Slider(
                label="Scale",
                minimum=-2.0,
                maximum=2.0,
                value=1.0,
                step=0.1,
            )

    # Section 4: optional source of non-weight files.
    with gr.Group():
        gr.Markdown("### 4. Repository Reconstruction (Optional)")
        gr.Markdown("*Use this to fill in missing files (Scheduler, VAE, Tokenizer, model_index.json) from a different source repo.*")
        structure_repo = gr.Textbox(
            label="Structure Source Repo",
            placeholder="e.g. Tongyi-MAI/Z-Image-Turbo",
            info="Copies all NON-weight files from here to output.",
        )

    submit_btn = gr.Button("🚀 Start Merge & Upload", variant="primary")
    output_log = gr.Textbox(label="Process Log", lines=20, interactive=False)

    # The input order must mirror run_merge's positional parameters.
    submit_btn.click(
        fn=run_merge,
        inputs=[
            hf_token,
            base_repo,
            base_subfolder,
            structure_repo,
            lora_input,
            scale,
            output_repo,
            is_private,
        ],
        outputs=output_log,
    )
if __name__ == "__main__":
    # Enable the request queue with max_size=1, then start the app; the
    # stylesheet is passed to launch() rather than to gr.Blocks().
    queued_demo = demo.queue(max_size=1)
    queued_demo.launch(css=css)