Spaces:
Running
Running
| """CVAT webhook listener — triggers finish_review when a job is marked completed.""" | |
| import hmac | |
| import json | |
| import os | |
| import subprocess | |
| import sys | |
| import tempfile | |
| import threading | |
| from hashlib import sha256 | |
| from pathlib import Path | |
| from fastapi import FastAPI, Request, HTTPException | |
# FastAPI application instance for this webhook service.
# NOTE(review): no route registration referencing `app` is visible in this
# file view — confirm where `cvat_webhook` and `health` are attached.
app: FastAPI = FastAPI()
| def _require_env(name: str) -> str: | |
| val = os.environ.get(name, "").strip() | |
| if not val: | |
| raise RuntimeError(f"Missing required env var: {name}") | |
| return val | |
# Mandatory settings are resolved at import time, so a misconfigured
# deployment fails on startup instead of on the first webhook delivery.

# Dataset identifier forwarded to finish_review.py via --dataset.
DATASET: str = _require_env("HF_DATASET")
# CVAT API token forwarded to finish_review.py via --cvat-token.
CVAT_TOKEN: str = _require_env("CVAT_TOKEN")
# Shared secret used to HMAC-verify incoming webhook bodies.
CVAT_WEBHOOK_SECRET: str = _require_env("CVAT_WEBHOOK_SECRET")
# Personal access token injected into the clone URL for github.com remotes.
GITHUB_PAT: str = _require_env("GITHUB_PAT")
# Repository (and branch/tag) that contains scripts/finish_review.py.
REPO_URL: str = _require_env("REPO_URL")
REPO_REF: str = _require_env("REPO_REF")
# Optional: base URL of the CVAT server; defaults to the hosted instance.
CVAT_URL: str = os.getenv("CVAT_URL", "https://app.cvat.ai").strip()
def _clone_repo(workdir: Path) -> Path:
    """Shallow-clone ``REPO_REF`` of ``REPO_URL`` into ``workdir / "repo"``.

    When the remote is an HTTPS github.com URL, ``GITHUB_PAT`` is embedded in
    the clone URL so the clone can authenticate without prompting.

    Args:
        workdir: Directory that will receive the ``repo`` checkout.

    Returns:
        Path of the cloned working tree.

    Raises:
        RuntimeError: If ``git clone`` exits with a non-zero status.
        subprocess.TimeoutExpired: If the clone exceeds the 60 second timeout.
    """
    repo_url = REPO_URL
    if GITHUB_PAT and "github.com" in repo_url:
        # Rewrite only the leading scheme; a later "https://" inside the URL
        # must not receive a second copy of the token.
        repo_url = repo_url.replace("https://", f"https://{GITHUB_PAT}@", 1)
    repo_dir = workdir / "repo"
    print(f"Cloning {REPO_REF}...")
    result = subprocess.run(
        ["git", "clone", "--depth", "1", "-b", REPO_REF, repo_url, str(repo_dir)],
        capture_output=True,
        text=True,
        timeout=60,
        # Never block waiting for interactive credentials.
        env={**os.environ, "GIT_TERMINAL_PROMPT": "0"},
    )
    if result.returncode != 0:
        # git may echo the remote URL (which now carries the token) in its
        # error output, so scrub the secret before it reaches logs or the
        # exception message.
        stderr = result.stderr
        if GITHUB_PAT:
            stderr = stderr.replace(GITHUB_PAT, "***")
        print(f"Clone failed: {stderr}")
        raise RuntimeError(f"git clone failed: {stderr[:200]}")
    print("Clone done")
    return repo_dir
def _run_finish_review(repo_dir: Path, task_id: int) -> None:
    """Run ``scripts/finish_review.py`` from the checkout, streaming its output.

    The script is started with the current interpreter, with ``repo_dir`` as
    its working directory, and its combined stdout/stderr is echoed line by
    line to this process's stdout.

    Args:
        repo_dir: Root of the cloned repository.
        task_id: CVAT task id; also used to build the experiment name.

    Raises:
        RuntimeError: If the script exits with a non-zero status.
    """
    # NOTE(review): the CVAT token is passed on the command line and is thus
    # visible in process listings; finish_review.py's interface is not visible
    # here, so the invocation is left unchanged.
    command = [
        sys.executable, "-u", str(repo_dir / "scripts" / "finish_review.py"),
        "--task-id", str(task_id),
        "--dataset", DATASET,
        "--experiment", f"cvat_review_{task_id}",
        "--labelmap", str(repo_dir / "labelmap.txt"),
        "--cvat-url", CVAT_URL,
        "--cvat-token", CVAT_TOKEN,
    ]
    # The context manager closes the pipe and waits for the child on exit, so
    # neither the file descriptor nor the process is leaked.
    with subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave stderr with stdout in one stream
        text=True,
        cwd=str(repo_dir),
    ) as proc:
        try:
            for line in proc.stdout:
                print(line, end="", flush=True)
        except BaseException:
            # Streaming was interrupted: stop the child so the implicit wait
            # performed when leaving the ``with`` block cannot hang.
            proc.kill()
            raise
    if proc.returncode != 0:
        raise RuntimeError(f"finish_review exited with code {proc.returncode}")
| def _verify_signature(body: bytes, signature: str) -> bool: | |
| if not CVAT_WEBHOOK_SECRET: | |
| return True | |
| expected = "sha256=" + hmac.new( | |
| CVAT_WEBHOOK_SECRET.encode("utf-8"), body, digestmod=sha256 | |
| ).hexdigest() | |
| return hmac.compare_digest(signature, expected) | |
def _process_completed_task(tid: int) -> None:
    """Clone the repo into a temporary directory and run finish_review for *tid*.

    Intended to run on a background thread: every failure is logged to stdout
    and swallowed so nothing propagates out of the thread.
    """
    try:
        with tempfile.TemporaryDirectory() as workdir:
            repo_dir = _clone_repo(Path(workdir))
            print(f"Running finish_review for task {tid}...", flush=True)
            _run_finish_review(repo_dir, tid)
            print(f"finish_review completed for task {tid}", flush=True)
    except subprocess.TimeoutExpired:
        # Only the clone step runs with a timeout in this file.
        print(f"finish_review timed out for task {tid}", flush=True)
    except Exception as exc:  # top-level boundary of the worker thread
        print(f"finish_review failed for task {tid}: {exc}", flush=True)


async def cvat_webhook(request: Request):
    """Handle a CVAT webhook delivery; start finish_review when a job completes.

    The request body must carry a valid ``X-Signature-256`` HMAC. Only
    ``update:job`` events whose job state changes *to* ``completed`` trigger
    work; everything else is acknowledged with ``{"status": "ignored", ...}``.
    The actual processing runs on a daemon thread so the response is returned
    immediately.

    Returns:
        ``{"status": "accepted", "task_id": ...}`` when processing was started,
        otherwise an ``ignored`` status dict.

    Raises:
        HTTPException: 403 for a bad signature; 400 for a malformed payload or
            a missing/invalid ``task_id``.
    """
    # NOTE(review): no route decorator is visible on this coroutine in this
    # file view — confirm where it is registered on the application.
    raw_body = await request.body()
    signature = request.headers.get("X-Signature-256", "")
    if not _verify_signature(raw_body, signature):
        raise HTTPException(status_code=403, detail="Invalid signature")

    # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
    try:
        body = json.loads(raw_body)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Payload must be a JSON object")

    event = body.get("event", "")
    if event != "update:job":
        return {"status": "ignored", "event": event}

    # ``or {}`` also covers keys that are present but null.
    job = body.get("job") or {}
    before = body.get("before_update") or {}
    if not isinstance(job, dict) or not isinstance(before, dict):
        raise HTTPException(status_code=400, detail="Malformed job payload")

    state = job.get("state", "")
    prev_state = before.get("state", "")
    # React only to a transition into "completed", not to later edits of a
    # job that was already completed.
    if state != "completed" or prev_state == "completed":
        return {"status": "ignored", "reason": f"state={state}, prev={prev_state}"}

    task_id = job.get("task_id")
    if not task_id:
        raise HTTPException(status_code=400, detail="No task_id in payload")
    # The id ends up in a subprocess argument list; accept real integers only
    # (bool is an int subclass and is rejected explicitly).
    if isinstance(task_id, bool) or not isinstance(task_id, int):
        raise HTTPException(status_code=400, detail="Invalid task_id in payload")

    print(f"Job completed — task_id={task_id}, running finish_review in background...")
    # daemon=True: an in-flight run does not keep the interpreter alive on exit.
    threading.Thread(
        target=_process_completed_task, args=(task_id,), daemon=True
    ).start()
    return {"status": "accepted", "task_id": task_id}
async def health():
    """Report that the service is up, together with the configured dataset."""
    payload = {"status": "ok"}
    payload["dataset"] = DATASET
    return payload