Spaces:
Sleeping
Sleeping
Create worker.py
Browse files- backend/worker.py +50 -0
backend/worker.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import tempfile
|
| 4 |
+
import shutil
|
| 5 |
+
import traceback
|
| 6 |
+
from .merge_engines.linear import linear_merge
|
| 7 |
+
from .merge_engines.evolutionary import evolutionary_merge
|
| 8 |
+
from .uploader import upload_to_hf
|
| 9 |
+
from .memory_manager import estimate_memory, check_memory_safe
|
| 10 |
+
|
| 11 |
+
async def run_job_async(job_id: str, params: dict):
|
| 12 |
+
loop = asyncio.get_event_loop()
|
| 13 |
+
return await loop.run_in_executor(None, run_job, job_id, params)
|
| 14 |
+
|
| 15 |
+
def run_job(job_id, params):
|
| 16 |
+
try:
|
| 17 |
+
# memory pre-check
|
| 18 |
+
model_a_id = f"{params['model_a_source']}:{params['model_a_id']}"
|
| 19 |
+
model_b_id = f"{params['model_b_source']}:{params['model_b_id']}"
|
| 20 |
+
if not check_memory_safe(model_a_id, model_b_id, params.get("evo_params")):
|
| 21 |
+
raise RuntimeError("Insufficient memory for this merge. Job rejected.")
|
| 22 |
+
|
| 23 |
+
# download models if needed
|
| 24 |
+
from .model_loader import prepare_model
|
| 25 |
+
path_a = prepare_model(params["model_a_source"], params["model_a_id"], params.get("civitai_key"))
|
| 26 |
+
path_b = prepare_model(params["model_b_source"], params["model_b_id"], params.get("civitai_key"))
|
| 27 |
+
|
| 28 |
+
# output dir
|
| 29 |
+
output_dir = f"/data/output/{job_id}"
|
| 30 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 31 |
+
|
| 32 |
+
# merge
|
| 33 |
+
if params["method"] == "linear":
|
| 34 |
+
alpha = params.get("linear_alpha", 0.5)
|
| 35 |
+
linear_merge(path_a, path_b, alpha, output_dir)
|
| 36 |
+
elif params["method"] == "evolutionary":
|
| 37 |
+
evo_params = params.get("evo_params", {})
|
| 38 |
+
dataset_path = params.get("dataset")
|
| 39 |
+
evolutionary_merge(path_a, path_b, dataset_path, output_dir, evo_params)
|
| 40 |
+
else:
|
| 41 |
+
raise ValueError(f"Unknown method: {params['method']}")
|
| 42 |
+
|
| 43 |
+
# upload
|
| 44 |
+
repo_name = params.get("output_repo_name") or f"merge-{job_id[:8]}"
|
| 45 |
+
upload_to_hf(output_dir, repo_name, params["hf_token"])
|
| 46 |
+
|
| 47 |
+
return {"repo_name": repo_name, "status": "success"}
|
| 48 |
+
except Exception as e:
|
| 49 |
+
traceback.print_exc()
|
| 50 |
+
raise
|