| """Sapiens2 multi-task CPU: seg / normal / pointmap / pose at 0.4b/0.8b/1b plus 5B INT8 ONNX. | |
| 5B (seg, normal, pointmap) runs via INT8 ONNX from WeReCooking/sapiens2-onnx; pose-5b not shipped. | |
| Pose top-down: DETR finds people, sapiens2 estimates 308 keypoints per crop. | |
| Lazy-load with LRU cache (keeps 2 dense models + 1 pose model resident). | |
| Per-task API endpoint via Gradio's auto-API (curl-able with Bearer token). | |
| Also exposes a standalone ONNX CLI mode that does not need PyTorch or sapiens2: | |
| python app.py onnx seg 0.4b photo.jpg --output seg.png | |
| python app.py onnx pointmap 5b photo.jpg --output depth.png | |
| """ | |
# Block mmpretrain: mmdet's reid modules try to import it via try/except ImportError,
# but mmpretrain raises TypeError on import (transformers API drift) which escapes
# the except and kills the process.
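# Registering None in sys.modules makes any later `import mmpretrain` raise a plain
# ImportError, which mmdet's existing try/except handles cleanly.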
import sys

sys.modules["mmpretrain"] = None

# --- ONNX CLI (standalone, no PyTorch/sapiens2 import) ----------------------
def _onnx_cli():
    """Run a published sapiens2 ONNX model on a local image. Only needs numpy,
    onnxruntime, huggingface_hub, opencv-python-headless."""
    import argparse
    import os
    import time

    import numpy as np
    import cv2
    import onnxruntime as ort
    from huggingface_hub import hf_hub_download

    DEFAULT_REPO = "WeReCooking/sapiens2-onnx"
    PRECISIONS = {("seg", "0.4b"): "fp16"}  # only seg-0.4b ships fp16; the rest are fp32, except 5B which is int8
    INPUT_HW = (1024, 768)

    parser = argparse.ArgumentParser(prog="app.py onnx")
    parser.add_argument("task", choices=["seg", "normal", "pointmap", "pose"])
    parser.add_argument("size", choices=["0.4b", "0.8b", "1b", "5b"])
    parser.add_argument("image", help="Local image path")
    parser.add_argument("--cache-dir", default="./onnx_cache")
    parser.add_argument("--token", default=os.environ.get("HF_TOKEN"))
    parser.add_argument("--output", default=None, help="Save the visualization here")
    parser.add_argument("--repo", default=DEFAULT_REPO)
    args = parser.parse_args(sys.argv[2:])
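    # pose-5b was never published (see module docstring), so fail fast with a clear
    # message instead of surfacing a 404 from hf_hub_download.
    if args.task == "pose" and args.size == "5b":
        parser.error("pose-5b is not available in the ONNX repo")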
    precision = PRECISIONS.get((args.task, args.size), "int8" if args.size == "5b" else "fp32")
    filename = f"{args.task}/{args.task}_{args.size}_{precision}.onnx"

    print(f"[1/3] downloading {filename} from {args.repo}", flush=True)
    t0 = time.time()
    onnx_path = hf_hub_download(repo_id=args.repo, filename=filename, local_dir=args.cache_dir, token=args.token)
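    # Weights ship in an external-data sidecar (ONNX's workaround for the 2 GB
    # protobuf limit), so the .data file has to land next to the .onnx file.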
    hf_hub_download(repo_id=args.repo, filename=f"{filename}.data", local_dir=args.cache_dir, token=args.token)
    print(f"  ready in {time.time()-t0:.1f}s", flush=True)
    img = cv2.imread(args.image, cv2.IMREAD_COLOR)
    if img is None:
        raise FileNotFoundError(args.image)
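    # Letterbox to the fixed 1024x768 (H, W) model input: aspect-preserving resize,
    # then center-pad with black.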
    H, W = INPUT_HW
    h0, w0 = img.shape[:2]
    scale = min(W / w0, H / h0)
    new_w, new_h = int(round(w0 * scale)), int(round(h0 * scale))
    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    canvas = np.zeros((H, W, 3), dtype=np.uint8)
    top = (H - new_h) // 2
    left = (W - new_w) // 2
    canvas[top:top + new_h, left:left + new_w] = resized
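    # Standard ImageNet mean/std in 0-255 units (the mmlab defaults). The image is
    # left in OpenCV's BGR order here; this assumes the published ONNX export was
    # built with the same channel handling.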
    mean = (123.675, 116.28, 103.53)
    std = (58.395, 57.12, 57.375)
    x = canvas.astype(np.float32)
    for c in range(3):
        x[:, :, c] = (x[:, :, c] - mean[c]) / std[c]
    x = x.transpose(2, 0, 1)[None]
| print(f"[2/3] ORT forward (input {x.shape} {x.dtype})", flush=True) | |
| sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]) | |
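    # Defensive guard (an addition, not in the upstream snippet): the fp16 export
    # may declare a float16 input, so cast to whatever the session actually expects.
    if sess.get_inputs()[0].type == "tensor(float16)":
        x = x.astype(np.float16)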
    t0 = time.time()
    out = sess.run(None, {sess.get_inputs()[0].name: x})
    print(f"  forward {time.time()-t0:.1f}s, outputs={[o.shape for o in out]}", flush=True)
| print(f"[3/3] postprocess + preview", flush=True) | |
| if args.task == "pose": | |
        heatmaps = out[0][0]
        K, hH, hW = heatmaps.shape
        flat = heatmaps.reshape(K, -1)
        peak = flat.argmax(axis=1)
        ys, xs = np.unravel_index(peak, (hH, hW))
        scores = flat.max(axis=1)
        inp_y = ys * (INPUT_HW[0] / hH)
        inp_x = xs * (INPUT_HW[1] / hW)
        scale_y = h0 / new_h
        scale_x = w0 / new_w
        img_y = (inp_y - top) * scale_y
        img_x = (inp_x - left) * scale_x
        n_visible = int((scores > 0.3).sum())
        print(f"  {n_visible}/{K} keypoints above 0.3 confidence (range {scores.min():.3f} to {scores.max():.3f})")
        if args.output:
            for i in range(K):
                if scores[i] < 0.3:
                    continue
                cv2.circle(img, (int(img_x[i]), int(img_y[i])), 4, (0, 255, 0), -1)
            cv2.imwrite(args.output, img)
            print(f"  saved {args.output}")
        return
| if args.task == "seg": | |
| logits = out[0][0] | |
| class_map = logits.argmax(axis=0).astype(np.int32) | |
| class_map_crop = class_map[top:top + new_h, left:left + new_w] | |
| class_map_full = cv2.resize(class_map_crop, (w0, h0), interpolation=cv2.INTER_NEAREST) | |
| classes = np.unique(class_map_full).tolist() | |
| print(f" classes detected: {classes[:15]}") | |
| if args.output: | |
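            # 29 entries = the Sapiens2 DOME_CLASSES_29 label set used by the Gradio
            # path below; the fixed seed keeps colors stable across runs.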
            palette = (np.random.RandomState(42).rand(29, 3) * 255).astype(np.uint8)
            cv2.imwrite(args.output, palette[class_map_full])
            print(f"  saved {args.output}")
        return
| if args.task == "normal": | |
| normal_raw = out[0][0].transpose(1, 2, 0) | |
| norm = np.linalg.norm(normal_raw, axis=2, keepdims=True) | |
| normal_unit = normal_raw / np.maximum(norm, 1e-8) | |
| normal_crop = normal_unit[top:top + new_h, left:left + new_w] | |
| normal_full = cv2.resize(normal_crop, (w0, h0), interpolation=cv2.INTER_LINEAR) | |
| if args.output: | |
| rgb = (((normal_full + 1.0) / 2.0) * 255).clip(0, 255).astype(np.uint8) | |
| cv2.imwrite(args.output, rgb) | |
| print(f" saved {args.output}") | |
| return | |
| # pointmap | |
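    # The pointmap head emits a relative pointmap plus a scalar scale; dividing by
    # the scale recovers metric coordinates (same convention as the Gradio 5B path).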
    pointmap_rel = out[0][0].transpose(1, 2, 0)
    s = out[1][0, 0] if len(out) > 1 else 1.0
    pointmap_metric = pointmap_rel / max(float(s), 1e-8)
    z = pointmap_metric[..., 2]
    z_crop = z[top:top + new_h, left:left + new_w]
    z_full = cv2.resize(z_crop, (w0, h0), interpolation=cv2.INTER_LINEAR)
    zmin, zmax = float(z_full.min()), float(z_full.max())
    print(f"  Z range: [{zmin:.2f}, {zmax:.2f}] meters")
    if args.output:
        z_norm = ((z_full - zmin) / max(zmax - zmin, 1e-8) * 255).astype(np.uint8)
        cv2.imwrite(args.output, z_norm)
        print(f"  saved {args.output}")

if len(sys.argv) > 1 and sys.argv[1] == "onnx":
    _onnx_cli()
    sys.exit(0)

# --- Gradio path -----------------------------------------------------------
import os
import time
import traceback
from pathlib import Path

import gradio as gr
import numpy as np
from PIL import Image
# --- Catalog ----------------------------------------------------------------
VARIANTS = {
    ("seg", "0.4b"): {"repo": "facebook/sapiens2-seg-0.4b", "filename": "sapiens2_0.4b_seg.safetensors", "config_glob": "**/sapiens2_0.4b_seg*shutterstock*1024x768*.py", "kind": "seg"},
    ("seg", "0.8b"): {"repo": "facebook/sapiens2-seg-0.8b", "filename": "sapiens2_0.8b_seg.safetensors", "config_glob": "**/sapiens2_0.8b_seg*shutterstock*1024x768*.py", "kind": "seg"},
    ("seg", "1b"): {"repo": "facebook/sapiens2-seg-1b", "filename": "sapiens2_1b_seg.safetensors", "config_glob": "**/sapiens2_1b_seg*shutterstock*1024x768*.py", "kind": "seg"},
    ("normal", "0.4b"): {"repo": "facebook/sapiens2-normal-0.4b", "filename": "sapiens2_0.4b_normal.safetensors", "config_glob": "**/sapiens2_0.4b_normal*metasim*1024x768*.py", "kind": "normal"},
    ("normal", "0.8b"): {"repo": "facebook/sapiens2-normal-0.8b", "filename": "sapiens2_0.8b_normal.safetensors", "config_glob": "**/sapiens2_0.8b_normal*metasim*1024x768*.py", "kind": "normal"},
    ("normal", "1b"): {"repo": "facebook/sapiens2-normal-1b", "filename": "sapiens2_1b_normal.safetensors", "config_glob": "**/sapiens2_1b_normal*metasim*1024x768*.py", "kind": "normal"},
    ("pointmap", "0.4b"): {"repo": "facebook/sapiens2-pointmap-0.4b", "filename": "sapiens2_0.4b_pointmap.safetensors", "config_glob": "**/sapiens2_0.4b_pointmap*render_people*1024x768*.py", "kind": "pointmap"},
    ("pointmap", "0.8b"): {"repo": "facebook/sapiens2-pointmap-0.8b", "filename": "sapiens2_0.8b_pointmap.safetensors", "config_glob": "**/sapiens2_0.8b_pointmap*render_people*1024x768*.py", "kind": "pointmap"},
    ("pointmap", "1b"): {"repo": "facebook/sapiens2-pointmap-1b", "filename": "sapiens2_1b_pointmap.safetensors", "config_glob": "**/sapiens2_1b_pointmap*render_people*1024x768*.py", "kind": "pointmap"},
    ("pose", "0.4b"): {"repo": "facebook/sapiens2-pose-0.4b", "filename": "sapiens2_0.4b_pose.safetensors", "config_glob": "**/sapiens2_0.4b_keypoints308*shutterstock_goliath*1024x768*.py", "kind": "pose"},
    ("pose", "0.8b"): {"repo": "facebook/sapiens2-pose-0.8b", "filename": "sapiens2_0.8b_pose.safetensors", "config_glob": "**/sapiens2_0.8b_keypoints308*shutterstock_goliath*1024x768*.py", "kind": "pose"},
    ("pose", "1b"): {"repo": "facebook/sapiens2-pose-1b", "filename": "sapiens2_1b_pose.safetensors", "config_glob": "**/sapiens2_1b_keypoints308*shutterstock_goliath*1024x768*.py", "kind": "pose"},
    # 5B variants run via prebuilt INT8 ONNX from WeReCooking/sapiens2-onnx.
    # fp32 5B PyTorch (~20 GB) won't fit in the free CPU Space's 16 GB; INT8 ONNX is ~5-6 GB.
    # pose-5b is intentionally absent — INT8 wasn't successfully built for it.
    ("seg", "5b"): {"onnx_repo": "WeReCooking/sapiens2-onnx", "onnx_filename": "seg/seg_5b_int8.onnx", "kind": "seg"},
    ("normal", "5b"): {"onnx_repo": "WeReCooking/sapiens2-onnx", "onnx_filename": "normal/normal_5b_int8.onnx", "kind": "normal"},
    ("pointmap", "5b"): {"onnx_repo": "WeReCooking/sapiens2-onnx", "onnx_filename": "pointmap/pointmap_5b_int8.onnx", "kind": "pointmap"},
}
DENSE_KINDS = {"seg", "normal", "pointmap"}

_MODELS: dict = {}        # (task, size) -> dense model (LRU)
_POSE_MODELS: dict = {}   # (task, size) -> pose model (separate cache so DETR survives)
_DETECTOR = None          # tuple(processor, model) — lazily loaded once
_POSE_METAINFO = None
_ORT_SESSIONS: dict = {}  # (task, "5b") -> onnxruntime InferenceSession
_MAX_CACHED = 2
_DOME_CLASSES_29 = None
_SAPIENS_PKG_ROOT = None

def _sapiens_root() -> Path:
    """Return the directory containing the installed sapiens package."""
    global _SAPIENS_PKG_ROOT
    if _SAPIENS_PKG_ROOT is None:
        import sapiens  # imported lazily because it has side effects (mmdet etc.)

        _SAPIENS_PKG_ROOT = Path(sapiens.__file__).resolve().parent
    return _SAPIENS_PKG_ROOT


def _find_config(pattern: str) -> str:
    # config_glob values look like "**/sapiens2_..._1024x768*.py"; keep only the
    # leaf filename pattern and let rglob do the recursive search.
    leaf = pattern.split("/")[-1]
    root = _sapiens_root()
    matches = list(root.rglob(leaf))
    if not matches:
        raise FileNotFoundError(f"No config matching {leaf} under {root}")
    return str(matches[0])

def _get_dense_model(task: str, size: str):
    """Lazy-load + LRU-cache for seg/normal/pointmap."""
    key = (task, size)
    if key in _MODELS:
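        # Re-inserting marks the entry most-recently-used; dicts preserve insertion
        # order, so the first key is always the LRU victim.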
        _MODELS[key] = _MODELS.pop(key)
        return _MODELS[key]

    spec = VARIANTS[key]
    from sapiens.dense.models import init_model

    if spec["kind"] == "normal":
        from sapiens.dense.models import NormalEstimator  # noqa: F401
    elif spec["kind"] == "pointmap":
        from sapiens.dense.models import PointmapEstimator  # noqa: F401
    config = _find_config(spec["config_glob"])

    from huggingface_hub import hf_hub_download

    local_dir = f"/tmp/sapiens_models/{task}-{size}"
    os.makedirs(local_dir, exist_ok=True)
    ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"], local_dir=local_dir)

    # cpu-basic has 16 GB. Loading a 1B dense (~6 GB fp32) on top of cached 0.8b/0.4b
    # dense (~5 GB each) + a 1B pose + DETR OOMs. So before init_model allocates a
    # 1B's weights, evict ALL caches it would race with.
    import gc

    if size == "1b":
        _MODELS.clear()
        _POSE_MODELS.clear()
        _ORT_SESSIONS.clear()
        gc.collect()
    else:
        while len(_MODELS) >= _MAX_CACHED:
            oldest = next(iter(_MODELS))
            del _MODELS[oldest]
            gc.collect()
    model = init_model(config, ckpt, device="cpu")
    _MODELS[key] = model
    return model

def _get_pose_metainfo():
    global _POSE_METAINFO
    if _POSE_METAINFO is None:
        from sapiens.pose.datasets import parse_pose_metainfo

        meta_cfg = _find_config("**/pose/configs/**/keypoints308.py")
        import importlib.util

        spec_obj = importlib.util.spec_from_file_location("keypoints308_meta", meta_cfg)
        mod = importlib.util.module_from_spec(spec_obj)
        spec_obj.loader.exec_module(mod)
        ds_info = getattr(mod, "dataset_info", None)
        if ds_info is None:
            raise RuntimeError(f"No dataset_info in {meta_cfg}")
        _POSE_METAINFO = parse_pose_metainfo(ds_info)
    return _POSE_METAINFO

def _get_pose_model(size: str):
    key = ("pose", size)
    if key in _POSE_MODELS:
        return _POSE_MODELS[key]

    spec = VARIANTS[key]
    from sapiens.pose.models import init_model
    from sapiens.pose.datasets import UDPHeatmap

    config = _find_config(spec["config_glob"])

    from huggingface_hub import hf_hub_download

    local_dir = f"/tmp/sapiens_models/pose-{size}"
    os.makedirs(local_dir, exist_ok=True)
    ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"], local_dir=local_dir)

    # Same hard eviction as the dense 1B path: clear every other resident model
    # before init_model allocates.
    import gc

    if size == "1b":
        _MODELS.clear()
        _POSE_MODELS.clear()
        _ORT_SESSIONS.clear()
    else:
        _POSE_MODELS.clear()  # cap=1
    gc.collect()
    model = init_model(config, ckpt, device="cpu")
    codec_cfg = dict(model.cfg.codec)
    assert codec_cfg.pop("type") == "UDPHeatmap"
    model.codec = UDPHeatmap(**codec_cfg)
    model.pose_metainfo = _get_pose_metainfo()
    _POSE_MODELS[key] = model
    return model

def _get_detector():
    global _DETECTOR
    if _DETECTOR is None:
        import torch  # noqa: F401
        from transformers import DetrImageProcessor, DetrForObjectDetection

        proc = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
        det = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").eval()
        _DETECTOR = (proc, det)
    return _DETECTOR

def _load_dome_classes():
    global _DOME_CLASSES_29
    if _DOME_CLASSES_29 is None:
        from sapiens.dense.src.datasets.seg.seg_utils import DOME_CLASSES_29

        _DOME_CLASSES_29 = DOME_CLASSES_29
    return _DOME_CLASSES_29

def _get_padding(data_samples):
    """Best-effort (pad_left, pad_right, pad_top, pad_bottom) from mm-style data samples."""
    ds = data_samples[0] if isinstance(data_samples, list) and data_samples else data_samples
    if hasattr(ds, "padding_size"):
        return tuple(ds.padding_size)
    if hasattr(ds, "metainfo") and isinstance(ds.metainfo, dict):
        if "padding_size" in ds.metainfo:
            return tuple(ds.metainfo["padding_size"])
        if "pad_shape" in ds.metainfo and "img_shape" in ds.metainfo:
            ph, pw = ds.metainfo["pad_shape"][:2]
            ih, iw = ds.metainfo["img_shape"][:2]
            return (0, pw - iw, 0, ph - ih)
    if isinstance(ds, dict):
        meta = ds.get("meta") or ds
        if "padding_size" in meta:
            return tuple(meta["padding_size"])
    return (0, 0, 0, 0)

# --- Per-task inference -----------------------------------------------------
def _infer_seg(image_bgr, model):
    import torch
    import torch.nn.functional as F
    import cv2

    h0, w0 = image_bgr.shape[:2]
    data = model.pipeline(dict(img=image_bgr))
    data = model.data_preprocessor(data)
    inputs = data["inputs"]
    if inputs.ndim == 3:  # match the other dense paths: F.interpolate needs NCHW
        inputs = inputs.unsqueeze(0)
    with torch.no_grad():
        logits = model(inputs)
    logits = F.interpolate(logits, size=(h0, w0), mode="bilinear", align_corners=False)
    label_map = logits.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.int32)
    classes = _load_dome_classes()
    palette = np.zeros((256, 3), dtype=np.uint8)
    for cid, meta in classes.items():
        palette[cid] = meta["color"][::-1]
    color_mask = palette[label_map]
    overlay_bgr = cv2.addWeighted(image_bgr, 0.5, color_mask, 0.5, 0)
    overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
    uniq = sorted(int(c) for c in np.unique(label_map))
    labels = [classes[c]["name"].replace("_", " ") for c in uniq if c in classes]
    return Image.fromarray(overlay_rgb), f"classes: {', '.join(labels)}"

def _infer_normal(image_bgr, model):
    import torch

    data = model.pipeline(dict(img=image_bgr))
    data = model.data_preprocessor(data)
    inputs, data_samples = data["inputs"], data["data_samples"]
    if inputs.ndim == 3:
        inputs = inputs.unsqueeze(0)
    with torch.no_grad():
        normal = model(inputs)
        normal = normal / normal.norm(dim=1, keepdim=True).clamp_min(1e-8)
    pl, pr, pt, pb = _get_padding(data_samples)
    normal = normal[:, :, pt:inputs.shape[2] - pb, pl:inputs.shape[3] - pr]
    normal_hwc = normal.squeeze(0).cpu().float().numpy().transpose(1, 2, 0)
    rgb = (((normal_hwc + 1.0) / 2.0) * 255.0).clip(0, 255).astype(np.uint8)
    return Image.fromarray(rgb), f"normal map {rgb.shape}"

def _infer_pointmap(image_bgr, model):
    import torch

    data = model.pipeline(dict(img=image_bgr))
    data = model.data_preprocessor(data)
    inputs, data_samples = data["inputs"], data["data_samples"]
    if inputs.ndim == 3:
        inputs = inputs.unsqueeze(0)
    with torch.no_grad():
        out = model(inputs)
    if isinstance(out, tuple) and len(out) == 2:
        pointmap, scale = out
        pointmap = pointmap / scale.clamp_min(1e-8)
    else:
        pointmap = out
    pl, pr, pt, pb = _get_padding(data_samples)
    pointmap = pointmap[:, :, pt:inputs.shape[2] - pb, pl:inputs.shape[3] - pr]
    pmap_hwc = pointmap.squeeze(0).cpu().float().numpy().transpose(1, 2, 0)
    z = pmap_hwc[..., 2]
    z_min, z_max = float(z.min()), float(z.max())
    z_norm = (z - z_min) / max(z_max - z_min, 1e-8)
    z_rgb = (z_norm * 255).astype(np.uint8)
    rgb = np.stack([z_rgb, z_rgb, z_rgb], axis=-1)
    return Image.fromarray(rgb), f"pointmap {pmap_hwc.shape} | Z [{z_min:.2f}, {z_max:.2f}]"

# --- 5B INT8 ONNX path -------------------------------------------------------
def _get_ort_session(task: str):
    """Lazy-load + cache an ORT session for {task}_5b_int8.onnx.

    Each 5B session is 5-6 GB RAM. cpu-basic has 16 GB total, so keep at most one
    5B session live and evict cached dense/pose PyTorch models that would push us OOM."""
    key = (task, "5b")
    sess = _ORT_SESSIONS.get(key)
    if sess is not None:
        return sess

    import onnxruntime as ort
    from huggingface_hub import hf_hub_download

    spec = VARIANTS[key]
    cache_dir = os.environ.get("ONNX_5B_CACHE", "/app/onnx_5b")
    os.makedirs(cache_dir, exist_ok=True)
    fn = spec["onnx_filename"]
    onnx_path = hf_hub_download(repo_id=spec["onnx_repo"], filename=fn, local_dir=cache_dir)
    hf_hub_download(repo_id=spec["onnx_repo"], filename=fn + ".data", local_dir=cache_dir)

    # Evict any prior 5B ORT session and any 1b/0.8b dense or pose models; together
    # they would exceed 16 GB.
    import gc

    if _ORT_SESSIONS:
        _ORT_SESSIONS.clear()
        gc.collect()
    for k in list(_MODELS.keys()):
        if k[1] in ("1b", "0.8b"):
            del _MODELS[k]
    for k in list(_POSE_MODELS.keys()):
        if k[1] in ("1b", "0.8b"):
            del _POSE_MODELS[k]
    gc.collect()
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    _ORT_SESSIONS[key] = sess
    return sess

def _infer_dense_5b(image_bgr, task: str):
    """5B inference: preprocess via the 0.4b PyTorch pipeline (cached), forward via ORT INT8."""
    import torch
    import torch.nn.functional as F
    import cv2

    # Use the 0.4b model's pipeline+preprocessor for image prep — it's already in cache for warm calls.
    proxy = _get_dense_model(task, "0.4b")
    data = proxy.pipeline(dict(img=image_bgr))
    data = proxy.data_preprocessor(data)
    inputs, data_samples = data["inputs"], data["data_samples"]
    if inputs.ndim == 3:
        inputs = inputs.unsqueeze(0)
    sess = _get_ort_session(task)
    out = sess.run(None, {sess.get_inputs()[0].name: inputs.float().cpu().numpy()})
| if task == "seg": | |
| logits = torch.from_numpy(out[0]) | |
| h0, w0 = image_bgr.shape[:2] | |
| logits = F.interpolate(logits, size=(h0, w0), mode="bilinear", align_corners=False) | |
| label_map = logits.argmax(dim=1).squeeze(0).numpy().astype(np.int32) | |
| classes = _load_dome_classes() | |
| palette = np.zeros((256, 3), dtype=np.uint8) | |
| for cid, meta in classes.items(): | |
| palette[cid] = meta["color"][::-1] | |
| color_mask = palette[label_map] | |
| overlay_bgr = cv2.addWeighted(image_bgr, 0.5, color_mask, 0.5, 0) | |
| overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB) | |
| uniq = sorted(int(c) for c in np.unique(label_map)) | |
| labels = [classes[c]["name"].replace("_", " ") for c in uniq if c in classes] | |
| return Image.fromarray(overlay_rgb), f"classes: {', '.join(labels)}" | |
| if task == "normal": | |
| normal = torch.from_numpy(out[0]) | |
| normal = normal / normal.norm(dim=1, keepdim=True).clamp_min(1e-8) | |
| pl, pr, pt, pb = _get_padding(data_samples) | |
| normal = normal[:, :, pt:inputs.shape[2] - pb, pl:inputs.shape[3] - pr] | |
| normal_hwc = normal.squeeze(0).numpy().transpose(1, 2, 0) | |
| rgb = (((normal_hwc + 1.0) / 2.0) * 255.0).clip(0, 255).astype(np.uint8) | |
| return Image.fromarray(rgb), f"normal map {rgb.shape}" | |
| # pointmap — ONNX produces (pointmap [1,3,H,W], scale [1,1]); divide to recover metric depths. | |
| pointmap = torch.from_numpy(out[0]) | |
| if len(out) > 1: | |
| scale = torch.from_numpy(out[1]) | |
| pointmap = pointmap / scale.clamp_min(1e-8) | |
| pl, pr, pt, pb = _get_padding(data_samples) | |
| pointmap = pointmap[:, :, pt:inputs.shape[2] - pb, pl:inputs.shape[3] - pr] | |
| pmap_hwc = pointmap.squeeze(0).numpy().transpose(1, 2, 0) | |
| z = pmap_hwc[..., 2] | |
| z_min, z_max = float(z.min()), float(z.max()) | |
| z_norm = (z - z_min) / max(z_max - z_min, 1e-8) | |
| z_rgb = (z_norm * 255).astype(np.uint8) | |
| rgb = np.stack([z_rgb, z_rgb, z_rgb], axis=-1) | |
| return Image.fromarray(rgb), f"pointmap {pmap_hwc.shape} | Z [{z_min:.2f}, {z_max:.2f}]" | |

# Inlined from the upstream Meta sample (pose keypoint render).
# Draws skeleton links + colored keypoints; thickness/radius are picked by the caller.
def visualize_keypoints(
    image: np.ndarray,
    keypoints,
    keypoints_visible,
    keypoint_scores,
    *,
    radius: int = 4,
    thickness: int = -1,
    color=(255, 0, 0),
    kpt_thr: float = 0.3,
    skeleton: list | None = None,
    kpt_color=None,
    link_color=None,
    show_kpt_idx: bool = False,
) -> np.ndarray:
    import cv2

    img = image.copy()
    H, W = img.shape[:2]
    if skeleton is None:
        skeleton = []
    if kpt_color is None:
        kpt_color = color
    if link_color is None:
        link_color = (0, 255, 0)

    def _as_color_list(c, n):
        if hasattr(c, "detach"):
            c = c.detach().cpu().numpy()
        if isinstance(c, np.ndarray):
            if c.ndim == 2 and c.shape[1] == 3:
                return [tuple(int(v) for v in row) for row in c.tolist()]
            if c.size == 3:
                return [tuple(int(v) for v in c.tolist())] * max(1, n)
        if isinstance(c, (list, tuple)):
            if n and len(c) == n and isinstance(c[0], (list, tuple, np.ndarray)):
                out = []
                for cc in c:
                    cc = np.asarray(cc).reshape(-1)
                    out.append(tuple(int(v) for v in cc.tolist()))
                return out
            c_arr = np.asarray(c).reshape(-1)
            if c_arr.size == 3:
                return [tuple(int(v) for v in c_arr.tolist())] * max(1, n)
        return [(255, 0, 0)] * max(1, n)

    J = keypoints[0].shape[0] if keypoints else 0
    kpt_colors = _as_color_list(kpt_color, J)
    link_colors = _as_color_list(link_color, len(skeleton))

    def in_bounds(x, y):
        return 0 <= x < W and 0 <= y < H

    for kpts, vis, score in zip(keypoints, keypoints_visible, keypoint_scores):
        kpts = np.asarray(kpts, float)
        vis = np.asarray(vis).reshape(-1).astype(bool)
        score = np.asarray(score).reshape(-1)
        for lk, (i, j) in enumerate(skeleton):
            if i >= len(kpts) or j >= len(kpts):
                continue
            if not (vis[i] and vis[j]):
                continue
            if score[i] < kpt_thr or score[j] < kpt_thr:
                continue
            x1, y1 = map(int, np.round(kpts[i]))
            x2, y2 = map(int, np.round(kpts[j]))
            if not (in_bounds(x1, y1) and in_bounds(x2, y2)):
                continue
            cv2.line(img, (x1, y1), (x2, y2), link_colors[lk % len(link_colors)],
                     thickness=max(1, thickness), lineType=cv2.LINE_AA)
        for j_idx, (xy, v, s) in enumerate(zip(kpts, vis, score)):
            if not v or s < kpt_thr:
                continue
            x, y = map(int, np.round(xy))
            if not in_bounds(x, y):
                continue
            c = kpt_colors[min(j_idx, len(kpt_colors) - 1)]
            cv2.circle(img, (x, y), radius, c, thickness=-1, lineType=cv2.LINE_AA)
            if show_kpt_idx:
                cv2.putText(img, str(j_idx), (x + radius, y - radius),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.4, c, 1, cv2.LINE_AA)
    return img

def _detect_persons(image_rgb: np.ndarray, threshold: float = 0.5):
    import torch

    proc, det = _get_detector()
    pil_img = Image.fromarray(image_rgb)
    inputs = proc(images=pil_img, return_tensors="pt")
    with torch.no_grad():
        outputs = det(**inputs)
    target_sizes = torch.tensor([image_rgb.shape[:2]])
    results = proc.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]
    person_mask = results["labels"] == 1  # COCO class 1 = person
    boxes = results["boxes"][person_mask].cpu().numpy()
    scores = results["scores"][person_mask].cpu().numpy().reshape(-1, 1)
    if len(boxes) == 0:
        h, w = image_rgb.shape[:2]
        return np.array([[0, 0, w - 1, h - 1, 1.0]], dtype=np.float32)
    return np.concatenate([boxes, scores], axis=1).astype(np.float32)

def _infer_pose(image_bgr, model, kpt_thr: float = 0.3):
    import torch
    import cv2

    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    bboxes = _detect_persons(image_rgb)
    inputs_list, samples_list = [], []
    for bbox in bboxes:
        data_info = dict(img=image_bgr, bbox=bbox[None, :4], bbox_score=np.ones(1, dtype=np.float32))
        data = model.pipeline(data_info)
        data = model.data_preprocessor(data)
        inputs_list.append(data["inputs"])
        samples_list.append(data["data_samples"])
    inputs = torch.cat(inputs_list, dim=0)
    with torch.no_grad():
        pred = model(inputs).cpu().numpy()
    keypoints, scores = [], []
    for i, sample in enumerate(samples_list):
        kpts_i, scr_i = model.codec.decode(pred[i])
        meta = sample["meta"] if isinstance(sample, dict) else sample.metainfo
        kpts_i = kpts_i / np.array(meta["input_size"]) * meta["bbox_scale"] + meta["bbox_center"] - 0.5 * meta["bbox_scale"]
        keypoints.append(kpts_i[0])
        scores.append(scr_i[0])
    pmeta = model.pose_metainfo
    vis_rgb = image_rgb.copy()
    # Scale render thickness so 308-keypoint dense pose stays visible on high-res input
    short_side = min(vis_rgb.shape[:2])
    radius_px = max(3, short_side // 200)
    thick_px = max(2, short_side // 250)
    box_thick = max(2, short_side // 300)
    for bbox, kpts, scr in zip(bboxes, keypoints, scores):
        x1, y1, x2, y2 = map(int, bbox[:4])
        cv2.rectangle(vis_rgb, (x1, y1), (x2, y2), (0, 255, 0), box_thick)
        vis_rgb = visualize_keypoints(
            image=vis_rgb,
            keypoints=[kpts],
            keypoints_visible=[np.ones(len(scr), dtype=bool)],
            keypoint_scores=[scr],
            radius=radius_px, thickness=thick_px, kpt_thr=kpt_thr,
            skeleton=pmeta["skeleton_links"],
            kpt_color=pmeta["keypoint_colors"],
            link_color=pmeta["skeleton_link_colors"],
        )
    return Image.fromarray(vis_rgb), f"persons={len(bboxes)} | kpts/person={len(keypoints[0]) if keypoints else 0}"

# --- Predict entry point ----------------------------------------------------
def predict(image: Image.Image, task: str, size: str):
    if image is None:
        return None, "No image provided"
    key = (task, size)
    if key not in VARIANTS:
        return None, f"Unknown variant {task}-{size}. Allowed: {sorted(VARIANTS.keys())}"
    t0 = time.time()
    try:
        import cv2

        image_pil = image.convert("RGB")
        in_w, in_h = image_pil.size
        image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
        kind = VARIANTS[key]["kind"]
        if size == "5b":
            out_img, info = _infer_dense_5b(image_bgr, task)
        elif kind == "pose":
            model = _get_pose_model(size)
            out_img, info = _infer_pose(image_bgr, model)
        else:
            model = _get_dense_model(task, size)
            if kind == "seg":
                out_img, info = _infer_seg(image_bgr, model)
            elif kind == "normal":
                out_img, info = _infer_normal(image_bgr, model)
            elif kind == "pointmap":
                out_img, info = _infer_pointmap(image_bgr, model)
            else:
                return None, f"Unhandled kind: {kind}"
        elapsed = time.time() - t0
        out_w, out_h = out_img.size
        return out_img, f"{task}-{size}: done in {elapsed:.1f}s | {in_w}×{in_h} → 1024×768 → {out_w}×{out_h} | {info}"
    except Exception as e:
        return None, f"{type(e).__name__}: {e}\n\n{traceback.format_exc()[:1500]}"

def health():
    return (
        f"Service up | dense cache: {list(_MODELS.keys())} | pose cache: {list(_POSE_MODELS.keys())} | "
        f"detector_loaded={_DETECTOR is not None} | variants={len(VARIANTS)} "
        f"({sorted(set(t for t, _ in VARIANTS))} × {sorted(set(s for _, s in VARIANTS))})"
    )


DEMO_IMAGES = sorted(str(p) for p in Path("/app/assets/images").glob("*.jpg"))

with gr.Blocks(title="Sapiens2 CPU", css="""
#img-in,#img-out{max-height:220px}
#status-box textarea{max-height:60px!important;min-height:60px!important}
#status-box{flex-grow:0!important}
""") as demo:
    with gr.Row(equal_height=False):
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="Input", height=200, elem_id="img-in")
            with gr.Row():
                task_in = gr.Dropdown(choices=["seg", "normal", "pointmap", "pose"], value="seg", label="Task", scale=1)
                size_in = gr.Dropdown(choices=["0.4b", "0.8b", "1b", "5b"], value="0.4b", label="Size", scale=1)
            run_btn = gr.Button("Predict - 1024×768 native", variant="primary")
            gr.Examples(
                examples=[[u] for u in DEMO_IMAGES],
                inputs=[img_in],
                examples_per_page=6,
                cache_examples=False,
                label="Meta demo images",
            )
        with gr.Column(scale=1):
            img_out = gr.Image(type="pil", label="Output", height=200, elem_id="img-out")
            status = gr.Textbox(show_label=False, lines=2, max_lines=2, interactive=False, container=False, placeholder="Status will show here after Predict", elem_id="status-box")
    run_btn.click(
        fn=predict, inputs=[img_in, task_in, size_in], outputs=[img_out, status], api_name="predict"
    )
    # Keep health endpoint accessible via API (no UI button — useless in browser)
    gr.Button(visible=False).click(fn=health, outputs=[gr.Textbox(visible=False)], api_name="health")

demo.queue(default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch()