# Hugging Face Space "sapiens2-cpu", file app.py (by Nekochu).
# Last commit 86547e5: "Hard-evict caches when loading 1B variants".
"""Sapiens2 multi-task CPU: seg / normal / pointmap / pose at 0.4b/0.8b/1b plus 5B INT8 ONNX.
5B (seg, normal, pointmap) runs via INT8 ONNX from WeReCooking/sapiens2-onnx; pose-5b not shipped.
Pose top-down: DETR finds people, sapiens2 estimates 308 keypoints per crop.
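Lazy-load with an LRU cache (up to 2 dense models + 1 pose model resident; loading a 1B variant evicts every cache, and opening a 5B ONNX session evicts cached 0.8b/1b models).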
Per-task API endpoint via Gradio's auto-API (curl-able with Bearer token).
Also exposes a standalone ONNX CLI mode that does not need PyTorch or sapiens2:
python app.py onnx seg 0.4b photo.jpg --output seg.png
python app.py onnx pointmap 5b photo.jpg --output depth.png
"""
# Block mmpretrain: mmdet's reid modules try to import it inside try/except ImportError,
# but importing mmpretrain raises TypeError (transformers API drift), which that
# except clause does not catch, so the whole process dies.
# Registering None in sys.modules makes any later `import mmpretrain` raise
# ImportError (ModuleNotFoundError) instead, which mmdet's guard does handle.
import sys
sys.modules["mmpretrain"] = None
# --- ONNX CLI (standalone, no PyTorch/sapiens2 import) ----------------------
def _onnx_cli():
    """Run a published sapiens2 ONNX model on a local image.

    Entry point for ``python app.py onnx TASK SIZE IMAGE [--output PATH]``.
    Only needs numpy, onnxruntime, huggingface_hub and opencv-python-headless;
    PyTorch and sapiens2 are never imported.

    Steps: download ``{task}/{task}_{size}_{precision}.onnx`` plus its
    ``.data`` companion from ``--repo``, letterbox the image onto a black
    1024x768 canvas, run one onnxruntime forward pass, print a summary, and
    optionally write a preview image to ``--output``.

    Raises:
        SystemExit: on invalid arguments (via argparse), including the
            pose/5b combination, which is not published.
        FileNotFoundError: if OpenCV cannot read the input image.
    """
    import argparse
    import os
    import time

    DEFAULT_REPO = "WeReCooking/sapiens2-onnx"
    # Requested precision: seg-0.4b is fp16; other <=1B variants fp32; 5B int8.
    PRECISIONS = {("seg", "0.4b"): "fp16"}
    INPUT_HW = (1024, 768)  # (height, width) of the model input

    parser = argparse.ArgumentParser(prog="app.py onnx")
    parser.add_argument("task", choices=["seg", "normal", "pointmap", "pose"])
    parser.add_argument("size", choices=["0.4b", "0.8b", "1b", "5b"])
    parser.add_argument("image", help="Local image path")
    parser.add_argument("--cache-dir", default="./onnx_cache")
    parser.add_argument("--token", default=os.environ.get("HF_TOKEN"))
    parser.add_argument("--output", default=None, help="Save the visualization here")
    parser.add_argument("--repo", default=DEFAULT_REPO)
    args = parser.parse_args(sys.argv[2:])
    if args.task == "pose" and args.size == "5b":
        # No pose-5b export is shipped; fail fast instead of a 404 from the Hub.
        parser.error("pose is not available at 5b; choose 0.4b, 0.8b or 1b")

    # Heavy dependencies are imported only after argument validation, so
    # --help and bad arguments return without loading onnxruntime/OpenCV.
    import numpy as np
    import cv2
    import onnxruntime as ort
    from huggingface_hub import hf_hub_download

    precision = PRECISIONS.get((args.task, args.size), "int8" if args.size == "5b" else "fp32")
    filename = f"{args.task}/{args.task}_{args.size}_{precision}.onnx"
    print(f"[1/3] downloading {filename} from {args.repo}", flush=True)
    t0 = time.time()
    onnx_path = hf_hub_download(repo_id=args.repo, filename=filename, local_dir=args.cache_dir, token=args.token)
    # The external-weights file is downloaded into the same local_dir, next to the graph.
    hf_hub_download(repo_id=args.repo, filename=f"{filename}.data", local_dir=args.cache_dir, token=args.token)
    print(f" ready in {time.time()-t0:.1f}s", flush=True)

    img = cv2.imread(args.image, cv2.IMREAD_COLOR)
    if img is None:
        raise FileNotFoundError(args.image)

    # Letterbox: scale to fit inside HxW keeping aspect ratio, centre on a black canvas.
    H, W = INPUT_HW
    h0, w0 = img.shape[:2]
    scale = min(W / w0, H / h0)
    new_w, new_h = int(round(w0 * scale)), int(round(h0 * scale))
    resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    canvas = np.zeros((H, W, 3), dtype=np.uint8)
    top = (H - new_h) // 2
    left = (W - new_w) // 2
    canvas[top:top + new_h, left:left + new_w] = resized

    mean = np.array((123.675, 116.28, 103.53), dtype=np.float32)
    std = np.array((58.395, 57.12, 57.375), dtype=np.float32)
    # NOTE(review): canvas is in OpenCV's BGR order while mean/std are listed in
    # the usual ImageNet RGB order; confirm against the export's preprocessing
    # whether a BGR->RGB flip is expected before normalisation.
    x = ((canvas.astype(np.float32) - mean) / std).transpose(2, 0, 1)[None]

    print(f"[2/3] ORT forward (input {x.shape} {x.dtype})", flush=True)
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    t0 = time.time()
    out = sess.run(None, {sess.get_inputs()[0].name: x})
    print(f" forward {time.time()-t0:.1f}s, outputs={[o.shape for o in out]}", flush=True)
    print("[3/3] postprocess + preview", flush=True)

    if args.task == "pose":
        # Whole-image pose (no person detector): take each keypoint's heatmap argmax.
        heatmaps = out[0][0]
        K, hH, hW = heatmaps.shape
        flat = heatmaps.reshape(K, -1)
        peak = flat.argmax(axis=1)
        ys, xs = np.unravel_index(peak, (hH, hW))
        scores = flat.max(axis=1)
        # heatmap cell -> model-input pixel -> original-image pixel (undo letterbox).
        inp_y = ys * (INPUT_HW[0] / hH)
        inp_x = xs * (INPUT_HW[1] / hW)
        scale_y = h0 / new_h
        scale_x = w0 / new_w
        img_y = (inp_y - top) * scale_y
        img_x = (inp_x - left) * scale_x
        n_visible = int((scores > 0.3).sum())
        print(f" {n_visible}/{K} keypoints above 0.3 confidence (range {scores.min():.3f} to {scores.max():.3f})")
        if args.output:
            for i in range(K):
                if scores[i] < 0.3:
                    continue
                cv2.circle(img, (int(img_x[i]), int(img_y[i])), 4, (0, 255, 0), -1)
            cv2.imwrite(args.output, img)
            print(f" saved {args.output}")
        return

    if args.task == "seg":
        logits = out[0][0]
        class_map = logits.argmax(axis=0).astype(np.int32)
        class_map_crop = class_map[top:top + new_h, left:left + new_w]
        class_map_full = cv2.resize(class_map_crop, (w0, h0), interpolation=cv2.INTER_NEAREST)
        classes = np.unique(class_map_full).tolist()
        print(f" classes detected: {classes[:15]}")
        if args.output:
            # Deterministic pseudo-random palette (seed 42), one colour per class id < 29.
            palette = (np.random.RandomState(42).rand(29, 3) * 255).astype(np.uint8)
            cv2.imwrite(args.output, palette[class_map_full])
            print(f" saved {args.output}")
        return

    if args.task == "normal":
        normal_raw = out[0][0].transpose(1, 2, 0)
        norm = np.linalg.norm(normal_raw, axis=2, keepdims=True)
        normal_unit = normal_raw / np.maximum(norm, 1e-8)
        normal_crop = normal_unit[top:top + new_h, left:left + new_w]
        normal_full = cv2.resize(normal_crop, (w0, h0), interpolation=cv2.INTER_LINEAR)
        if args.output:
            rgb = (((normal_full + 1.0) / 2.0) * 255).clip(0, 255).astype(np.uint8)
            # cv2.imwrite expects BGR; flip so channel 0 lands in red, matching
            # the Gradio (PIL, RGB) visualisation of the same normal map.
            cv2.imwrite(args.output, cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
            print(f" saved {args.output}")
        return

    # pointmap: optional second output is a scale; dividing by it yields metric values.
    pointmap_rel = out[0][0].transpose(1, 2, 0)
    s = out[1][0, 0] if len(out) > 1 else 1.0
    pointmap_metric = pointmap_rel / max(float(s), 1e-8)
    z = pointmap_metric[..., 2]
    z_crop = z[top:top + new_h, left:left + new_w]
    z_full = cv2.resize(z_crop, (w0, h0), interpolation=cv2.INTER_LINEAR)
    zmin, zmax = float(z_full.min()), float(z_full.max())
    print(f" Z range: [{zmin:.2f}, {zmax:.2f}] meters")
    if args.output:
        z_norm = ((z_full - zmin) / max(zmax - zmin, 1e-8) * 255).astype(np.uint8)
        cv2.imwrite(args.output, z_norm)
        print(f" saved {args.output}")
# `python app.py onnx ...` runs the standalone CLI and exits before any
# Gradio / PyTorch import below is executed.
if sys.argv[1:2] == ["onnx"]:
    _onnx_cli()
    sys.exit(0)
# --- Gradio path -----------------------------------------------------------
import glob
import os
import time
import traceback
from pathlib import Path
import gradio as gr
import numpy as np
from PIL import Image
# --- Catalog ----------------------------------------------------------------
# <=1B variants: PyTorch safetensors checkpoints from the facebook/sapiens2-* repos.
# Per task: (config-file stem, training-dataset tag) used to build the config glob.
_PT_CONFIG_PARTS = {
    "seg": ("seg", "shutterstock"),
    "normal": ("normal", "metasim"),
    "pointmap": ("pointmap", "render_people"),
    "pose": ("keypoints308", "shutterstock_goliath"),
}
_PT_SIZES = ("0.4b", "0.8b", "1b")
# (task, size) -> download/config spec. PyTorch entries carry repo/filename/config_glob;
# 5B entries carry onnx_repo/onnx_filename. Every entry has "kind".
VARIANTS = {
    (task, size): {
        "repo": f"facebook/sapiens2-{task}-{size}",
        "filename": f"sapiens2_{size}_{task}.safetensors",
        "config_glob": f"**/sapiens2_{size}_{stem}*{dataset}*1024x768*.py",
        "kind": task,
    }
    for task, (stem, dataset) in _PT_CONFIG_PARTS.items()
    for size in _PT_SIZES
}
# 5B variants run via prebuilt INT8 ONNX from WeReCooking/sapiens2-onnx.
# fp32 5B PyTorch (~20 GB) won't fit in the free CPU Space's 16 GB; INT8 ONNX is ~5-6 GB.
# pose-5b is intentionally absent — INT8 wasn't successfully built for it.
VARIANTS.update({
    (task, "5b"): {
        "onnx_repo": "WeReCooking/sapiens2-onnx",
        "onnx_filename": f"{task}/{task}_5b_int8.onnx",
        "kind": task,
    }
    for task in ("seg", "normal", "pointmap")
})
# Task kinds handled by the dense (seg/normal/pointmap) path; not referenced elsewhere in this file.
DENSE_KINDS = {"seg", "normal", "pointmap"}
_MODELS: dict = {}  # (task, size) -> dense PyTorch model; dict insertion order is the LRU order
_POSE_MODELS: dict = {}  # ("pose", size) -> pose model; separate from _MODELS, capped at one entry by _get_pose_model
_DETECTOR = None  # (DetrImageProcessor, DetrForObjectDetection), loaded once on first use
_POSE_METAINFO = None  # parsed keypoints308 dataset_info, cached after first load
_ORT_SESSIONS: dict = {}  # (task, "5b") -> onnxruntime InferenceSession; _get_ort_session keeps at most one
_MAX_CACHED = 2  # LRU capacity of _MODELS for non-1b dense loads
_DOME_CLASSES_29 = None  # class-id -> {"name", "color", ...} table, imported lazily from sapiens
_SAPIENS_PKG_ROOT = None  # cached install directory of the sapiens package
def _sapiens_root() -> Path:
    """Return (and memoise) the directory the ``sapiens`` package is installed in.

    The package is imported on first call only, because importing it pulls in
    mmdet and friends.
    """
    global _SAPIENS_PKG_ROOT
    if _SAPIENS_PKG_ROOT is not None:
        return _SAPIENS_PKG_ROOT
    import sapiens  # heavy import with side effects; deferred until needed
    _SAPIENS_PKG_ROOT = Path(sapiens.__file__).resolve().parent
    return _SAPIENS_PKG_ROOT
def _find_config(pattern: str) -> str:
    """Return the path of a config file under the sapiens package matching *pattern*.

    *pattern* is a ``pathlib`` glob relative to the package root, e.g.
    ``"**/sapiens2_0.4b_seg*shutterstock*1024x768*.py"`` or
    ``"**/pose/configs/**/keypoints308.py"``. The full pattern, including its
    directory components, is tried first. If it matches nothing, the search
    falls back to the file-name component alone anywhere under the root.

    Matches are sorted so the result is deterministic when several files match,
    instead of depending on filesystem iteration order.

    Raises:
        FileNotFoundError: if neither search finds a file.
    """
    root = _sapiens_root()
    leaf = pattern.split("/")[-1]
    # Honour directory parts (e.g. pose/configs/) before the looser leaf-only search.
    matches = sorted(root.glob(pattern)) or sorted(root.rglob(leaf))
    if not matches:
        raise FileNotFoundError(f"No config matching {leaf} under {root}")
    return str(matches[0])
def _get_dense_model(task: str, size: str):
    """Return the seg/normal/pointmap model for (task, size), loading it on a cache miss.

    ``_MODELS`` acts as an LRU: its insertion order is the recency order, and a
    hit re-inserts the key at the end. On a miss the checkpoint is downloaded
    to /tmp/sapiens_models/{task}-{size}. Then, before ``init_model``
    allocates weights, a 1b load clears every cache (dense, pose, 5B ONNX).
    Smaller sizes evict least-recently-used dense models until fewer than
    ``_MAX_CACHED`` remain.
    """
    key = (task, size)
    if key in _MODELS:
        hit = _MODELS.pop(key)
        _MODELS[key] = hit  # move to the most-recently-used end
        return hit

    spec = VARIANTS[key]
    from sapiens.dense.models import init_model
    kind = spec["kind"]
    # Importing the estimator classes registers them with the model builder.
    if kind == "normal":
        from sapiens.dense.models import NormalEstimator  # noqa: F401
    elif kind == "pointmap":
        from sapiens.dense.models import PointmapEstimator  # noqa: F401
    config = _find_config(spec["config_glob"])

    from huggingface_hub import hf_hub_download
    import gc
    local_dir = f"/tmp/sapiens_models/{task}-{size}"
    os.makedirs(local_dir, exist_ok=True)
    ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"], local_dir=local_dir)

    # cpu-basic has 16 GB: a 1B dense (~6 GB fp32) on top of cached 0.8b/0.4b dense
    # models, a 1B pose model and DETR OOMs, so 1b loads start from empty caches.
    if size == "1b":
        for cache in (_MODELS, _POSE_MODELS, _ORT_SESSIONS):
            cache.clear()
    else:
        while len(_MODELS) >= _MAX_CACHED:
            del _MODELS[next(iter(_MODELS))]
    gc.collect()

    model = init_model(config, ckpt, device="cpu")
    _MODELS[key] = model
    return model
def _get_pose_metainfo():
    """Return the parsed keypoints308 pose metainfo, loading it once.

    The metainfo lives in a Python config file (``keypoints308.py``) inside the
    sapiens package. The file is executed as a throwaway module, and its
    ``dataset_info`` is passed through ``parse_pose_metainfo``.

    Raises:
        RuntimeError: if the config defines no ``dataset_info``.
    """
    global _POSE_METAINFO
    if _POSE_METAINFO is not None:
        return _POSE_METAINFO

    import importlib.util
    from sapiens.pose.datasets import parse_pose_metainfo

    meta_cfg = _find_config("**/pose/configs/**/keypoints308.py")
    module_spec = importlib.util.spec_from_file_location("keypoints308_meta", meta_cfg)
    config_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(config_module)
    dataset_info = getattr(config_module, "dataset_info", None)
    if dataset_info is None:
        raise RuntimeError(f"No dataset_info in {meta_cfg}")
    _POSE_METAINFO = parse_pose_metainfo(dataset_info)
    return _POSE_METAINFO
def _get_pose_model(size: str):
    """Return the sapiens2 keypoints308 pose model for *size*, loading it on a miss.

    The pose cache holds at most one model. On a miss the checkpoint is
    downloaded to /tmp/sapiens_models/pose-{size}. Before ``init_model``
    allocates weights, a 1b load clears every cache (dense, pose, 5B ONNX);
    smaller sizes only drop the previously cached pose model. The returned
    model has a ``UDPHeatmap`` ``codec`` and the keypoints308 ``pose_metainfo``
    attached.

    Raises:
        ValueError: if the model config's codec type is not ``UDPHeatmap``.
    """
    key = ("pose", size)
    if key in _POSE_MODELS:
        return _POSE_MODELS[key]
    spec = VARIANTS[key]
    from sapiens.pose.models import init_model
    from sapiens.pose.datasets import UDPHeatmap
    config = _find_config(spec["config_glob"])
    from huggingface_hub import hf_hub_download
    local_dir = f"/tmp/sapiens_models/pose-{size}"
    os.makedirs(local_dir, exist_ok=True)
    ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"], local_dir=local_dir)
    # Same hard eviction as the dense 1B path: clear every other resident model before init_model allocates.
    import gc
    if size == "1b":
        _MODELS.clear()
        _POSE_MODELS.clear()
        _ORT_SESSIONS.clear()
    else:
        _POSE_MODELS.clear()  # cap=1
    gc.collect()
    model = init_model(config, ckpt, device="cpu")
    codec_cfg = dict(model.cfg.codec)
    # Pop "type" outside of any assert: `python -O` strips asserts, which would
    # skip the pop and pass an unexpected `type=` kwarg to UDPHeatmap.
    codec_type = codec_cfg.pop("type", None)
    if codec_type != "UDPHeatmap":
        raise ValueError(f"Expected a UDPHeatmap codec in {config}, got {codec_type!r}")
    model.codec = UDPHeatmap(**codec_cfg)
    model.pose_metainfo = _get_pose_metainfo()
    _POSE_MODELS[key] = model
    return model
def _get_detector():
    """Return the (DetrImageProcessor, DetrForObjectDetection) pair, loading it once.

    Both halves come from the ``facebook/detr-resnet-50`` checkpoint; the model
    is put in eval mode.
    """
    global _DETECTOR
    if _DETECTOR is not None:
        return _DETECTOR
    import torch  # noqa: F401
    from transformers import DetrImageProcessor, DetrForObjectDetection
    checkpoint = "facebook/detr-resnet-50"
    processor = DetrImageProcessor.from_pretrained(checkpoint)
    detector = DetrForObjectDetection.from_pretrained(checkpoint).eval()
    _DETECTOR = (processor, detector)
    return _DETECTOR
def _load_dome_classes():
    """Return sapiens' DOME_CLASSES_29 segmentation class table, imported on first use."""
    global _DOME_CLASSES_29
    if _DOME_CLASSES_29 is not None:
        return _DOME_CLASSES_29
    from sapiens.dense.src.datasets.seg.seg_utils import DOME_CLASSES_29 as class_table
    _DOME_CLASSES_29 = class_table
    return _DOME_CLASSES_29
def _get_padding(data_samples):
    """Return (left, right, top, bottom) padding in pixels recorded by the preprocessor.

    *data_samples* may be a single sample or a non-empty list (only the first
    element is inspected). Sources are tried in order:

    1. a ``padding_size`` attribute;
    2. ``metainfo["padding_size"]``;
    3. ``metainfo`` ``pad_shape`` vs ``img_shape`` (treated as right/bottom padding);
    4. a dict's ``["meta"]["padding_size"]`` (or top-level ``padding_size``).

    Falls back to no padding, ``(0, 0, 0, 0)``.
    """
    sample = data_samples
    if isinstance(data_samples, list) and data_samples:
        sample = data_samples[0]

    if hasattr(sample, "padding_size"):
        return tuple(sample.padding_size)

    info = getattr(sample, "metainfo", None)
    if isinstance(info, dict):
        if "padding_size" in info:
            return tuple(info["padding_size"])
        if "pad_shape" in info and "img_shape" in info:
            padded_h, padded_w = info["pad_shape"][:2]
            image_h, image_w = info["img_shape"][:2]
            return (0, padded_w - image_w, 0, padded_h - image_h)

    if isinstance(sample, dict):
        meta = sample.get("meta") or sample
        if "padding_size" in meta:
            return tuple(meta["padding_size"])

    return (0, 0, 0, 0)
# --- Per-task inference -----------------------------------------------------
def _infer_seg(image_bgr, model):
    """Run body-part segmentation; return (overlay PIL image, "classes: ..." summary).

    The per-pixel argmax label map is coloured with the DOME class palette and
    blended 50/50 with the input image at its original resolution.
    """
    import torch
    import torch.nn.functional as F
    import cv2

    h0, w0 = image_bgr.shape[:2]
    data = model.pipeline(dict(img=image_bgr))
    data = model.data_preprocessor(data)
    inputs = data["inputs"]
    # Give the model a batched (N, C, H, W) tensor, as the normal/pointmap paths do.
    if inputs.ndim == 3:
        inputs = inputs.unsqueeze(0)
    with torch.no_grad():
        logits = model(inputs)
    # Drop letterbox padding before resizing back to the source size, otherwise the
    # padded border is stretched into the overlay. Padding is in input pixels, so
    # only crop when the logits are at input resolution.
    if logits.shape[-2:] == inputs.shape[-2:]:
        pl, pr, pt, pb = _get_padding(data.get("data_samples"))
        logits = logits[:, :, pt:inputs.shape[2] - pb, pl:inputs.shape[3] - pr]
    logits = F.interpolate(logits, size=(h0, w0), mode="bilinear", align_corners=False)
    label_map = logits.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.int32)

    classes = _load_dome_classes()
    palette = np.zeros((256, 3), dtype=np.uint8)
    for cid, meta in classes.items():
        # Reversed so the colour matches the BGR image (presumably the table stores RGB).
        palette[cid] = meta["color"][::-1]
    color_mask = palette[label_map]
    overlay_bgr = cv2.addWeighted(image_bgr, 0.5, color_mask, 0.5, 0)
    overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
    uniq = sorted(int(c) for c in np.unique(label_map))
    labels = [classes[c]["name"].replace("_", " ") for c in uniq if c in classes]
    return Image.fromarray(overlay_rgb), f"classes: {', '.join(labels)}"
def _infer_normal(image_bgr, model):
    """Predict surface normals; return (RGB visualisation, "normal map <shape>").

    Per-pixel vectors are normalised to unit length, the preprocessor's
    letterbox padding is cropped away, and each component in [-1, 1] is mapped
    linearly to [0, 255] (channels 0/1/2 -> R/G/B). The result stays at the
    model's input resolution minus padding; it is not resized to the source image.
    """
    import torch

    prepared = model.data_preprocessor(model.pipeline(dict(img=image_bgr)))
    batch = prepared["inputs"]
    samples = prepared["data_samples"]
    if batch.ndim == 3:
        batch = batch.unsqueeze(0)
    with torch.no_grad():
        prediction = model(batch)
    prediction = prediction / prediction.norm(dim=1, keepdim=True).clamp_min(1e-8)

    pad_left, pad_right, pad_top, pad_bottom = _get_padding(samples)
    height, width = batch.shape[2], batch.shape[3]
    prediction = prediction[:, :, pad_top:height - pad_bottom, pad_left:width - pad_right]

    hwc = prediction.squeeze(0).cpu().float().numpy().transpose(1, 2, 0)
    rgb = np.clip((hwc + 1.0) / 2.0 * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(rgb), f"normal map {rgb.shape}"
def _infer_pointmap(image_bgr, model):
    """Predict a per-pixel 3D point map; return (grey depth image, summary).

    If the model returns a (pointmap, scale) pair, the point map is divided by
    the scale. After cropping letterbox padding, channel 2 (Z) is min-max
    normalised to 0..255 and replicated into three channels.
    """
    import torch

    prepared = model.data_preprocessor(model.pipeline(dict(img=image_bgr)))
    batch, samples = prepared["inputs"], prepared["data_samples"]
    if batch.ndim == 3:
        batch = batch.unsqueeze(0)
    with torch.no_grad():
        raw = model(batch)

    if isinstance(raw, tuple) and len(raw) == 2:
        points, point_scale = raw
        points = points / point_scale.clamp_min(1e-8)
    else:
        points = raw

    pad_left, pad_right, pad_top, pad_bottom = _get_padding(samples)
    height, width = batch.shape[2], batch.shape[3]
    points = points[:, :, pad_top:height - pad_bottom, pad_left:width - pad_right]
    hwc = points.squeeze(0).cpu().float().numpy().transpose(1, 2, 0)

    depth = hwc[..., 2]
    lowest, highest = float(depth.min()), float(depth.max())
    grey = ((depth - lowest) / max(highest - lowest, 1e-8) * 255).astype(np.uint8)
    rgb = np.repeat(grey[..., None], 3, axis=-1)
    return Image.fromarray(rgb), f"pointmap {hwc.shape} | Z [{lowest:.2f}, {highest:.2f}]"
# --- 5B INT8 ONNX path -------------------------------------------------------
def _get_ort_session(task: str):
    """Return the onnxruntime session for ``{task}_5b_int8.onnx``, creating it on a miss.

    A 5B INT8 session takes roughly 5-6 GB and cpu-basic has 16 GB, so on a
    miss any existing 5B session is dropped, and cached 0.8b/1b dense and pose
    models are evicted before the new session is created. Files are downloaded
    into ``$ONNX_5B_CACHE`` (default /app/onnx_5b).
    """
    key = (task, "5b")
    cached = _ORT_SESSIONS.get(key)
    if cached is not None:
        return cached

    import onnxruntime as ort
    from huggingface_hub import hf_hub_download
    import gc

    spec = VARIANTS[key]
    cache_dir = os.environ.get("ONNX_5B_CACHE", "/app/onnx_5b")
    os.makedirs(cache_dir, exist_ok=True)
    repo_id, graph_file = spec["onnx_repo"], spec["onnx_filename"]
    onnx_path = hf_hub_download(repo_id=repo_id, filename=graph_file, local_dir=cache_dir)
    hf_hub_download(repo_id=repo_id, filename=graph_file + ".data", local_dir=cache_dir)

    # A prior 5B session plus any 0.8b/1b model would together exceed 16 GB.
    if _ORT_SESSIONS:
        _ORT_SESSIONS.clear()
        gc.collect()
    for cache in (_MODELS, _POSE_MODELS):
        for stale in [k for k in cache if k[1] in ("1b", "0.8b")]:
            del cache[stale]
    gc.collect()

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    _ORT_SESSIONS[key] = session
    return session
def _infer_dense_5b(image_bgr, task: str):
    """5B dense inference: preprocess with the cached 0.4b pipeline, forward with INT8 ONNX.

    The 0.4b PyTorch model for the same task is used only for its ``pipeline``
    and ``data_preprocessor``; its own forward pass is never run. Letterbox
    padding is cropped from every output (seg logits only when they are at input
    resolution). Seg logits are then resized back to the source image.

    Returns:
        (PIL image, info string) for seg / normal / pointmap.
    """
    import torch
    import torch.nn.functional as F
    import cv2

    # Use the 0.4b model's pipeline+preprocessor for image prep — it's already in cache for warm calls.
    proxy = _get_dense_model(task, "0.4b")
    data = proxy.pipeline(dict(img=image_bgr))
    data = proxy.data_preprocessor(data)
    inputs, data_samples = data["inputs"], data["data_samples"]
    if inputs.ndim == 3:
        inputs = inputs.unsqueeze(0)
    sess = _get_ort_session(task)
    out = sess.run(None, {sess.get_inputs()[0].name: inputs.float().cpu().numpy()})

    in_h, in_w = inputs.shape[2], inputs.shape[3]
    pl, pr, pt, pb = _get_padding(data_samples)

    def _unpad(t):
        # Crop (N, C, H, W) output back to the un-padded input region.
        return t[:, :, pt:in_h - pb, pl:in_w - pr]

    if task == "seg":
        logits = torch.from_numpy(out[0])
        # Padding is in input pixels, so only crop when logits are at input resolution.
        if logits.shape[-2:] == inputs.shape[-2:]:
            logits = _unpad(logits)
        h0, w0 = image_bgr.shape[:2]
        logits = F.interpolate(logits, size=(h0, w0), mode="bilinear", align_corners=False)
        label_map = logits.argmax(dim=1).squeeze(0).numpy().astype(np.int32)
        classes = _load_dome_classes()
        palette = np.zeros((256, 3), dtype=np.uint8)
        for cid, meta in classes.items():
            # Reversed to match the BGR image (presumably the table stores RGB).
            palette[cid] = meta["color"][::-1]
        color_mask = palette[label_map]
        overlay_bgr = cv2.addWeighted(image_bgr, 0.5, color_mask, 0.5, 0)
        overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
        uniq = sorted(int(c) for c in np.unique(label_map))
        labels = [classes[c]["name"].replace("_", " ") for c in uniq if c in classes]
        return Image.fromarray(overlay_rgb), f"classes: {', '.join(labels)}"

    if task == "normal":
        normal = torch.from_numpy(out[0])
        normal = normal / normal.norm(dim=1, keepdim=True).clamp_min(1e-8)
        normal = _unpad(normal)
        normal_hwc = normal.squeeze(0).numpy().transpose(1, 2, 0)
        rgb = (((normal_hwc + 1.0) / 2.0) * 255.0).clip(0, 255).astype(np.uint8)
        return Image.fromarray(rgb), f"normal map {rgb.shape}"

    # pointmap — ONNX produces (pointmap [1,3,H,W], scale [1,1]); divide to recover metric depths.
    pointmap = torch.from_numpy(out[0])
    if len(out) > 1:
        scale = torch.from_numpy(out[1])
        pointmap = pointmap / scale.clamp_min(1e-8)
    pointmap = _unpad(pointmap)
    pmap_hwc = pointmap.squeeze(0).numpy().transpose(1, 2, 0)
    z = pmap_hwc[..., 2]
    z_min, z_max = float(z.min()), float(z.max())
    z_norm = (z - z_min) / max(z_max - z_min, 1e-8)
    z_rgb = (z_norm * 255).astype(np.uint8)
    rgb = np.stack([z_rgb, z_rgb, z_rgb], axis=-1)
    return Image.fromarray(rgb), f"pointmap {pmap_hwc.shape} | Z [{z_min:.2f}, {z_max:.2f}]"
# Inlined from the upstream Meta sample (pose keypoint render).
# Draws skeleton links + colored keypoints; thickness/radius are picked by the caller.
def visualize_keypoints(
image: np.ndarray,
keypoints,
keypoints_visible,
keypoint_scores,
*,
radius: int = 4,
thickness: int = -1,
color=(255, 0, 0),
kpt_thr: float = 0.3,
skeleton: list | None = None,
kpt_color=None,
link_color=None,
show_kpt_idx: bool = False,
) -> np.ndarray:
import cv2
img = image.copy()
H, W = img.shape[:2]
if skeleton is None:
skeleton = []
if kpt_color is None:
kpt_color = color
if link_color is None:
link_color = (0, 255, 0)
def _as_color_list(c, n):
if hasattr(c, "detach"):
c = c.detach().cpu().numpy()
if isinstance(c, np.ndarray):
if c.ndim == 2 and c.shape[1] == 3:
return [tuple(int(v) for v in row) for row in c.tolist()]
if c.size == 3:
return [tuple(int(v) for v in c.tolist())] * max(1, n)
if isinstance(c, (list, tuple)):
if n and len(c) == n and isinstance(c[0], (list, tuple, np.ndarray)):
out = []
for cc in c:
cc = np.asarray(cc).reshape(-1)
out.append(tuple(int(v) for v in cc.tolist()))
return out
c_arr = np.asarray(c).reshape(-1)
if c_arr.size == 3:
return [tuple(int(v) for v in c_arr.tolist())] * max(1, n)
return [(255, 0, 0)] * max(1, n)
J = keypoints[0].shape[0] if keypoints else 0
kpt_colors = _as_color_list(kpt_color, J)
link_colors = _as_color_list(link_color, len(skeleton))
def in_bounds(x, y):
return 0 <= x < W and 0 <= y < H
for kpts, vis, score in zip(keypoints, keypoints_visible, keypoint_scores):
kpts = np.asarray(kpts, float)
vis = np.asarray(vis).reshape(-1).astype(bool)
score = np.asarray(score).reshape(-1)
for lk, (i, j) in enumerate(skeleton):
if i >= len(kpts) or j >= len(kpts):
continue
if not (vis[i] and vis[j]):
continue
if score[i] < kpt_thr or score[j] < kpt_thr:
continue
x1, y1 = map(int, np.round(kpts[i]))
x2, y2 = map(int, np.round(kpts[j]))
if not (in_bounds(x1, y1) and in_bounds(x2, y2)):
continue
cv2.line(img, (x1, y1), (x2, y2), link_colors[lk % len(link_colors)],
thickness=max(1, thickness), lineType=cv2.LINE_AA)
for j_idx, (xy, v, s) in enumerate(zip(kpts, vis, score)):
if not v or s < kpt_thr:
continue
x, y = map(int, np.round(xy))
if not in_bounds(x, y):
continue
c = kpt_colors[min(j_idx, len(kpt_colors) - 1)]
cv2.circle(img, (x, y), radius, c, thickness=-1, lineType=cv2.LINE_AA)
if show_kpt_idx:
cv2.putText(img, str(j_idx), (x + radius, y - radius),
cv2.FONT_HERSHEY_SIMPLEX, 0.4, c, 1, cv2.LINE_AA)
return img
def _detect_persons(image_rgb: np.ndarray, threshold: float = 0.5):
    """Detect people with DETR; return an (N, 5) float32 array of [x1, y1, x2, y2, score].

    Only COCO label 1 (person) detections at or above *threshold* are kept.
    If none survive, a single full-image box with score 1.0 is returned, so the
    pose stage always has at least one crop.
    """
    import torch

    processor, detector = _get_detector()
    encoded = processor(images=Image.fromarray(image_rgb), return_tensors="pt")
    with torch.no_grad():
        raw = detector(**encoded)
    height_width = torch.tensor([image_rgb.shape[:2]])
    detections = processor.post_process_object_detection(
        raw, target_sizes=height_width, threshold=threshold
    )[0]
    is_person = detections["labels"] == 1  # COCO class 1 = person
    boxes = detections["boxes"][is_person].cpu().numpy()
    if len(boxes) == 0:
        h, w = image_rgb.shape[:2]
        return np.array([[0, 0, w - 1, h - 1, 1.0]], dtype=np.float32)
    box_scores = detections["scores"][is_person].cpu().numpy().reshape(-1, 1)
    return np.concatenate([boxes, box_scores], axis=1).astype(np.float32)
def _infer_pose(image_bgr, model, kpt_thr: float = 0.3):
    """Top-down pose: detect people, estimate keypoints per box, draw everything.

    Each box from ``_detect_persons`` is run through the pose model's pipeline,
    and all crops go through the model as one batch. Heatmaps are decoded with
    ``model.codec``, then mapped from crop-input coordinates back to image
    coordinates using the sample's input_size / bbox_center / bbox_scale
    metadata.

    Returns:
        (PIL RGB image with boxes and skeletons,
         "persons=<n> | kpts/person=<k>" info string).
    """
    import torch
    import cv2

    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    bboxes = _detect_persons(image_rgb)
    inputs_list, samples_list = [], []
    for bbox in bboxes:
        data_info = dict(img=image_bgr, bbox=bbox[None, :4], bbox_score=np.ones(1, dtype=np.float32))
        data = model.pipeline(data_info)
        data = model.data_preprocessor(data)
        crop = data["inputs"]
        # torch.cat below stacks along dim 0, so every crop needs its own batch
        # axis; a bare (C, H, W) tensor would be concatenated into (N*C, H, W).
        if crop.ndim == 3:
            crop = crop.unsqueeze(0)
        inputs_list.append(crop)
        samples_list.append(data["data_samples"])
    inputs = torch.cat(inputs_list, dim=0)
    with torch.no_grad():
        pred = model(inputs).cpu().numpy()

    keypoints, scores = [], []
    for i, sample in enumerate(samples_list):
        # Unwrap list-wrapped samples, the same way _get_padding treats data_samples.
        if isinstance(sample, list) and sample:
            sample = sample[0]
        kpts_i, scr_i = model.codec.decode(pred[i])
        meta = sample["meta"] if isinstance(sample, dict) else sample.metainfo
        # crop-input pixels -> original-image pixels
        kpts_i = kpts_i / np.array(meta["input_size"]) * meta["bbox_scale"] + meta["bbox_center"] - 0.5 * meta["bbox_scale"]
        keypoints.append(kpts_i[0])
        scores.append(scr_i[0])

    pmeta = model.pose_metainfo
    vis_rgb = image_rgb.copy()
    # Scale render thickness so 308-keypoint dense pose stays visible on high-res input
    short_side = min(vis_rgb.shape[:2])
    radius_px = max(3, short_side // 200)
    thick_px = max(2, short_side // 250)
    box_thick = max(2, short_side // 300)
    for bbox, kpts, scr in zip(bboxes, keypoints, scores):
        x1, y1, x2, y2 = map(int, bbox[:4])
        cv2.rectangle(vis_rgb, (x1, y1), (x2, y2), (0, 255, 0), box_thick)
        vis_rgb = visualize_keypoints(
            image=vis_rgb,
            keypoints=[kpts],
            keypoints_visible=[np.ones(len(scr), dtype=bool)],
            keypoint_scores=[scr],
            radius=radius_px, thickness=thick_px, kpt_thr=kpt_thr,
            skeleton=pmeta["skeleton_links"],
            kpt_color=pmeta["keypoint_colors"],
            link_color=pmeta["skeleton_link_colors"],
        )
    return Image.fromarray(vis_rgb), f"persons={len(bboxes)} | kpts/person={len(keypoints[0]) if keypoints else 0}"
# --- Predict entry point ----------------------------------------------------
def predict(image: Image.Image, task: str, size: str):
    """Gradio click handler and ``/predict`` API endpoint.

    Validates the (task, size) pair against ``VARIANTS``, then dispatches to the
    5B ONNX path, the pose path, or the dense seg/normal/pointmap path.

    Returns:
        (output PIL image or None, status string). Exceptions raised during
        inference are caught and reported as (None, "<Type>: <msg>" followed by
        a traceback truncated to 1500 characters).
    """
    if image is None:
        return None, "No image provided"
    key = (task, size)
    if key not in VARIANTS:
        return None, f"Unknown variant {task}-{size}. Allowed: {sorted(VARIANTS.keys())}"

    started = time.time()
    try:
        import cv2
        rgb_image = image.convert("RGB")
        in_w, in_h = rgb_image.size
        image_bgr = cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)
        kind = VARIANTS[key]["kind"]
        dense_runners = {
            "seg": _infer_seg,
            "normal": _infer_normal,
            "pointmap": _infer_pointmap,
        }

        if size == "5b":
            out_img, info = _infer_dense_5b(image_bgr, task)
        elif kind == "pose":
            out_img, info = _infer_pose(image_bgr, _get_pose_model(size))
        else:
            model = _get_dense_model(task, size)
            runner = dense_runners.get(kind)
            if runner is None:
                return None, f"Unhandled kind: {kind}"
            out_img, info = runner(image_bgr, model)

        elapsed = time.time() - started
        out_w, out_h = out_img.size
        return out_img, f"{task}-{size}: done in {elapsed:.1f}s | {in_w}×{in_h} → 1024×768 → {out_w}×{out_h} | {info}"
    except Exception as e:
        return None, f"{type(e).__name__}: {e}\n\n{traceback.format_exc()[:1500]}"
def health():
    """Return a one-line status string for the ``/health`` API endpoint.

    Reports the keys resident in the dense, pose and 5B ONNX caches, whether
    the DETR person detector is loaded, and the task × size catalog.
    """
    return (
        f"Service up | dense cache: {list(_MODELS.keys())} | pose cache: {list(_POSE_MODELS.keys())} | "
        # 5B sessions are the largest residents (~5-6 GB each), so surface them too.
        f"onnx cache: {list(_ORT_SESSIONS.keys())} | "
        f"detector_loaded={_DETECTOR is not None} | variants={len(VARIANTS)} "
        f"({sorted(set(t for t, _ in VARIANTS))} × {sorted(set(s for _, s in VARIANTS))})"
    )
# Example images bundled with the Space; Path.glob yields nothing if the directory is missing.
DEMO_IMAGES = sorted(str(p) for p in Path("/app/assets/images").glob("*.jpg"))
# Two-column UI: input + controls on the left, output + status on the right.
# The inline CSS caps both image panes at 220px and pins the status box to 60px.
with gr.Blocks(title="Sapiens2 CPU", css="""
#img-in,#img-out{max-height:220px}
#status-box textarea{max-height:60px!important;min-height:60px!important}
#status-box{flex-grow:0!important}
""") as demo:
    with gr.Row(equal_height=False):
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="Input", height=200, elem_id="img-in")
            with gr.Row():
                task_in = gr.Dropdown(choices=["seg", "normal", "pointmap", "pose"], value="seg", label="Task", scale=1)
                size_in = gr.Dropdown(choices=["0.4b", "0.8b", "1b", "5b"], value="0.4b", label="Size", scale=1)
            run_btn = gr.Button("Predict - 1024×768 native", variant="primary")
            gr.Examples(
                examples=[[u] for u in DEMO_IMAGES],
                inputs=[img_in],
                examples_per_page=6,
                cache_examples=False,
                label="Meta demo images",
            )
        with gr.Column(scale=1):
            img_out = gr.Image(type="pil", label="Output", height=200, elem_id="img-out")
            status = gr.Textbox(show_label=False, lines=2, max_lines=2, interactive=False, container=False, placeholder="Status will show here after Predict", elem_id="status-box")
    # Also exposed through the auto-API as "predict". The dropdowns allow pose + 5b,
    # but predict() rejects it because VARIANTS has no ("pose", "5b") entry.
    run_btn.click(
        fn=predict, inputs=[img_in, task_in, size_in], outputs=[img_out, status], api_name="predict"
    )
    # Keep health endpoint accessible via API (no UI button — useless in browser)
    gr.Button(visible=False).click(fn=health, outputs=[gr.Textbox(visible=False)], api_name="health")
# Handle requests one at a time (default concurrency limit 1).
demo.queue(default_concurrency_limit=1)
if __name__ == "__main__":
    demo.launch()