HairStable Bot commited on
Commit
5eb37d7
·
1 Parent(s): 43b5bad

feat: add async job API with concurrency control (/get-hairswap-async, /job/{id})

Browse files
Files changed (1) hide show
  1. Hair_stable_new_fresh/server.py +81 -0
Hair_stable_new_fresh/server.py CHANGED
@@ -8,6 +8,8 @@ from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Header
8
  from fastapi.responses import FileResponse, JSONResponse
9
  from pydantic import BaseModel
10
  import torch
 
 
11
 
12
  import numpy as np
13
  from PIL import Image
@@ -254,3 +256,82 @@ def logs(_=Depends(verify_bearer)):
254
  return JSONResponse({"logs": ["service running"], "db": "not_configured"})
255
 
256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from fastapi.responses import FileResponse, JSONResponse
9
  from pydantic import BaseModel
10
  import torch
11
+ import asyncio
12
+ import time
13
 
14
  import numpy as np
15
  from PIL import Image
 
256
  return JSONResponse({"logs": ["service running"], "db": "not_configured"})
257
 
258
 
259
+ # -------------------- Async job API --------------------
260
+ MAX_CONCURRENCY = int(os.environ.get("MAX_CONCURRENCY", "1"))
261
+ _sem = asyncio.Semaphore(MAX_CONCURRENCY)
262
+ _jobs = {}
263
+
264
+
265
+ @app.post("/get-hairswap-async")
266
+ async def get_hairswap_async(req: HairSwapRequest, _=Depends(verify_bearer)):
267
+ job_id = str(uuid.uuid4())
268
+ _jobs[job_id] = {"status": "queued", "result": None, "error": None, "started_at": None, "ended_at": None}
269
+
270
+ async def _run_job():
271
+ _jobs[job_id]["status"] = "running"
272
+ _jobs[job_id]["started_at"] = time.time()
273
+ try:
274
+ async with _sem:
275
+ # reuse the same core flow as sync endpoint
276
+ def find_file(image_id: str) -> str:
277
+ for name in os.listdir(UPLOAD_DIR):
278
+ if name.startswith(image_id):
279
+ return os.path.join(UPLOAD_DIR, name)
280
+ raise HTTPException(status_code=404, detail=f"Image id not found: {image_id}")
281
+
282
+ source_path = find_file(req.source_id)
283
+ reference_path = find_file(req.reference_id)
284
+ LOGGER.info(f"[job {job_id}] Found source: {source_path}, reference: {reference_path}")
285
+
286
+ model = get_model()
287
+ LOGGER.info(f"[job {job_id}] Model loaded successfully")
288
+
289
+ LOGGER.info(f"[job {job_id}] Starting hair transfer...")
290
+ try:
291
+ try:
292
+ sched_main = type(model.pipeline.scheduler).__name__ if hasattr(model, "pipeline") else None
293
+ sched_bald = type(model.remove_hair_pipeline.scheduler).__name__ if hasattr(model, "remove_hair_pipeline") else None
294
+ LOGGER.info(f"[job {job_id}] Schedulers -> main: {sched_main}, remove_hair: {sched_bald}")
295
+ except Exception:
296
+ pass
297
+
298
+ id_np, out_np, bald_np, ref_np = model.Hair_Transfer(
299
+ source_image=source_path,
300
+ reference_image=reference_path,
301
+ random_seed=-1,
302
+ step=20,
303
+ guidance_scale=req.guidance_scale,
304
+ scale=req.scale,
305
+ controlnet_conditioning_scale=req.controlnet_conditioning_scale,
306
+ size=448,
307
+ )
308
+ result_id = str(uuid.uuid4())
309
+ out_img = Image.fromarray((out_np * 255.).astype(np.uint8))
310
+ filename = f"{result_id}.png"
311
+ out_path = os.path.join(RESULTS_DIR, filename)
312
+ out_img.save(out_path)
313
+ _jobs[job_id]["result"] = {"filename": filename}
314
+ _jobs[job_id]["status"] = "completed"
315
+ LOGGER.info(f"[job {job_id}] Completed -> {out_path}")
316
+ except Exception as e:
317
+ LOGGER.error(f"[job {job_id}] Hair transfer failed: {str(e)}")
318
+ _jobs[job_id]["error"] = str(e)
319
+ _jobs[job_id]["status"] = "failed"
320
+ except Exception as e:
321
+ _jobs[job_id]["error"] = str(e)
322
+ _jobs[job_id]["status"] = "failed"
323
+ finally:
324
+ _jobs[job_id]["ended_at"] = time.time()
325
+
326
+ asyncio.create_task(_run_job())
327
+ return {"job_id": job_id, "status": "queued"}
328
+
329
+
330
+ @app.get("/job/{job_id}")
331
+ def job_status(job_id: str, _=Depends(verify_bearer)):
332
+ data = _jobs.get(job_id)
333
+ if not data:
334
+ raise HTTPException(status_code=404, detail="job not found")
335
+ return data
336
+
337
+