rikunarita-2 commited on
Commit
6d2b369
·
verified ·
1 Parent(s): 0775c37

Create worker.py

Browse files
Files changed (1) hide show
  1. backend/worker.py +50 -0
backend/worker.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import tempfile
4
+ import shutil
5
+ import traceback
6
+ from .merge_engines.linear import linear_merge
7
+ from .merge_engines.evolutionary import evolutionary_merge
8
+ from .uploader import upload_to_hf
9
+ from .memory_manager import estimate_memory, check_memory_safe
10
+
11
+ async def run_job_async(job_id: str, params: dict):
12
+ loop = asyncio.get_event_loop()
13
+ return await loop.run_in_executor(None, run_job, job_id, params)
14
+
15
+ def run_job(job_id, params):
16
+ try:
17
+ # memory pre-check
18
+ model_a_id = f"{params['model_a_source']}:{params['model_a_id']}"
19
+ model_b_id = f"{params['model_b_source']}:{params['model_b_id']}"
20
+ if not check_memory_safe(model_a_id, model_b_id, params.get("evo_params")):
21
+ raise RuntimeError("Insufficient memory for this merge. Job rejected.")
22
+
23
+ # download models if needed
24
+ from .model_loader import prepare_model
25
+ path_a = prepare_model(params["model_a_source"], params["model_a_id"], params.get("civitai_key"))
26
+ path_b = prepare_model(params["model_b_source"], params["model_b_id"], params.get("civitai_key"))
27
+
28
+ # output dir
29
+ output_dir = f"/data/output/{job_id}"
30
+ os.makedirs(output_dir, exist_ok=True)
31
+
32
+ # merge
33
+ if params["method"] == "linear":
34
+ alpha = params.get("linear_alpha", 0.5)
35
+ linear_merge(path_a, path_b, alpha, output_dir)
36
+ elif params["method"] == "evolutionary":
37
+ evo_params = params.get("evo_params", {})
38
+ dataset_path = params.get("dataset")
39
+ evolutionary_merge(path_a, path_b, dataset_path, output_dir, evo_params)
40
+ else:
41
+ raise ValueError(f"Unknown method: {params['method']}")
42
+
43
+ # upload
44
+ repo_name = params.get("output_repo_name") or f"merge-{job_id[:8]}"
45
+ upload_to_hf(output_dir, repo_name, params["hf_token"])
46
+
47
+ return {"repo_name": repo_name, "status": "success"}
48
+ except Exception as e:
49
+ traceback.print_exc()
50
+ raise