Commit c401d3e by Justin Wood: "Initial backend"
Files changed:
- README.md +20 -0
- app.py +161 -0
- depth.py +52 -0
- reconstruction.py +123 -0
- requirements.txt +13 -0
- segmentation.py +91 -0
README.md ADDED
@@ -0,0 +1,20 @@
+---
+title: Jiggle Physics Backend
+emoji: 🎯
+colorFrom: purple
+colorTo: indigo
+sdk: gradio
+sdk_version: "4.44.0"
+app_file: app.py
+pinned: false
+license: mit
+---
+
+REST API for the Jiggle Physics Simulator. Endpoints are mounted on Gradio's underlying FastAPI instance for ZeroGPU compatibility.
+
+| Endpoint | Method | What it does |
+|---|---|---|
+| `/health` | GET | Liveness check |
+| `/segment` | POST | SAM2 body region masks |
+| `/depth` | POST | Apple Depth Pro metric depth |
+| `/reconstruct` | POST | TripoSR → GLB mesh |
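
For orientation, here is a minimal client sketch against these endpoints. The base URL and image path are placeholders; the request and response shapes follow app.py below, including decoding the run-length-encoded masks that `/segment` returns:

```python
import requests
import numpy as np

BASE = "https://example-jiggle-backend.hf.space"  # placeholder Space URL

# Liveness check, then segmentation of one region.
assert requests.get(f"{BASE}/health").json()["status"] == "ok"
with open("person.jpg", "rb") as f:  # hypothetical input image
    resp = requests.post(
        f"{BASE}/segment", files={"image": f}, data={"regions": "hair"}
    ).json()

# Decode the RLE mask back into a boolean array (mirrors the encoder in app.py).
region = resp["regions"]["hair"]
flat, current = [], region["rle_start"]
for run in region["rle"]:
    flat.extend([current] * run)
    current = not current
mask = np.array(flat, dtype=bool).reshape(region["shape"])
print(mask.shape, int(mask.sum()), "masked pixels")
```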
app.py ADDED
@@ -0,0 +1,161 @@
+"""
+Jiggle Physics Simulator — HuggingFace Space backend
+ZeroGPU pattern: routes are added to Gradio's underlying FastAPI instance.
+Deploy with sdk: gradio and ZeroGPU hardware selected in Space settings.
+"""
+import base64
+import io
+import json
+from typing import Optional
+
+import gradio as gr
+import spaces
+import numpy as np
+from fastapi import File, Form, UploadFile, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import Response
+from PIL import Image
+
+# ── Minimal Gradio UI (required for ZeroGPU Spaces) ──────────────────────────
+with gr.Blocks(title="Jiggle Physics API") as demo:
+    gr.Markdown("## Jiggle Physics ML API\nREST endpoints: `/segment` `/depth` `/reconstruct`")
+
+# Grab Gradio's underlying FastAPI app and add CORS
+app = demo.app
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_methods=["POST", "GET"],
+    allow_headers=["*"],
+)
+
+
+def _load_image(upload: UploadFile) -> Image.Image:
+    data = upload.file.read()
+    img = Image.open(io.BytesIO(data)).convert("RGB")
+    max_dim = 1024
+    if max(img.size) > max_dim:
+        ratio = max_dim / max(img.size)
+        img = img.resize(
+            (int(img.width * ratio), int(img.height * ratio)), Image.LANCZOS
+        )
+    return img
+
+
+@app.get("/health")
+def health():
+    return {"status": "ok"}
+
+
+@app.post("/segment")
+@spaces.GPU
+async def segment(
+    image: UploadFile = File(...),
+    regions: str = Form("breast_left,breast_right,buttocks"),
+    click_points: Optional[str] = Form(None),
+):
+    """SAM2 body region segmentation. Returns RLE-encoded masks + bounding boxes."""
+    from segmentation import segment_regions
+
+    img = _load_image(image)
+    region_list = [r.strip() for r in regions.split(",") if r.strip()]
+    clicks = json.loads(click_points) if click_points else None
+
+    try:
+        result = segment_regions(img, region_list, clicks)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+    encoded = {}
+    for region, data in result.items():
+        mask_arr = np.array(data["mask"], dtype=bool)
+        flat = mask_arr.flatten()
+        rle: list[int] = []
+        current = bool(flat[0])
+        count = 0
+        for val in flat:
+            if bool(val) == current:
+                count += 1
+            else:
+                rle.append(count)
+                count = 1
+                current = bool(val)
+        rle.append(count)
+        encoded[region] = {
+            "rle": rle,
+            "rle_start": bool(flat[0]),
+            "shape": list(mask_arr.shape),
+            "bbox": data["bbox"],
+        }
+
+    return {"regions": encoded, "image_size": [img.width, img.height]}
+
+
+@app.post("/depth")
+@spaces.GPU
+async def depth(image: UploadFile = File(...)):
+    """Apple Depth Pro depth estimation. Returns base64 float32 depth map."""
+    from depth import estimate_depth
+
+    img = _load_image(image)
+    try:
+        result = estimate_depth(img)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+    arr = np.array(result["depth"], dtype=np.float32)
+    b64 = base64.b64encode(arr.tobytes()).decode("ascii")
+
+    return {
+        "depth_b64": b64,
+        "width": result["width"],
+        "height": result["height"],
+        "min": result["min"],
+        "max": result["max"],
+        "dtype": "float32",
+    }
+
+
+@app.post("/reconstruct")
+@spaces.GPU
+async def reconstruct(
+    image: UploadFile = File(...),
+    mask_rle: str = Form(...),
+    mask_shape: str = Form(...),
+    mask_rle_start: str = Form("false"),
+    bbox: str = Form(...),
+    use_triposr: str = Form("true"),
+):
+    """TripoSR single-image 3D reconstruction. Returns GLB binary."""
+    from reconstruction import reconstruct_region, depth_to_mesh
+    from depth import estimate_depth
+
+    img = _load_image(image)
+
+    rle = json.loads(mask_rle)
+    shape = json.loads(mask_shape)
+    start_val = mask_rle_start.lower() == "true"
+    flat: list[bool] = []
+    current = start_val
+    for run in rle:
+        flat.extend([current] * run)
+        current = not current
+    mask = np.array(flat, dtype=bool).reshape(shape)
+
+    bbox_list = json.loads(bbox)
+
+    try:
+        if use_triposr.lower() == "true":
+            glb_bytes = reconstruct_region(img, mask.tolist(), bbox_list)
+        else:
+            depth_result = estimate_depth(img)
+            glb_bytes = depth_to_mesh(depth_result["depth"], mask.tolist(), img)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+    return Response(content=glb_bytes, media_type="application/octet-stream")
+
+
+# ── Entry point ───────────────────────────────────────────────────────────────
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860)
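
The `/depth` endpoint ships the depth map as base64-encoded raw float32 bytes. A client can rebuild it like this (a minimal sketch; the field names match the JSON built above):

```python
import base64
import numpy as np

def decode_depth(payload: dict) -> np.ndarray:
    """Rebuild the float32 depth map from the /depth JSON payload."""
    raw = base64.b64decode(payload["depth_b64"])
    return np.frombuffer(raw, dtype=np.float32).reshape(payload["height"], payload["width"])
```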
depth.py ADDED
@@ -0,0 +1,52 @@
+import numpy as np
+import torch
+from PIL import Image
+
+
+_depth_cache = None
+
+
+def get_depth_model():
+    global _depth_cache
+    if _depth_cache is None:
+        from transformers import pipeline as hf_pipeline
+        # Depth Pro from Apple — best metric depth, ~600MB
+        _depth_cache = hf_pipeline(
+            "depth-estimation",
+            model="apple/DepthPro-hf",
+            device=0 if torch.cuda.is_available() else -1,
+        )
+    return _depth_cache
+
+
+def estimate_depth(image: Image.Image) -> dict:
+    """
+    Returns {"depth": [[float]], "width": int, "height": int, "min": float, "max": float}
+    Depth values are metric (meters) when Depth Pro is used.
+    """
+    pipe = get_depth_model()
+    result = pipe(image)
+    depth_map = result["depth"]  # PIL image or numpy array
+
+    if isinstance(depth_map, Image.Image):
+        arr = np.array(depth_map).astype(np.float32)
+    else:
+        arr = np.array(depth_map, dtype=np.float32)
+
+    # Resize to match source image if needed
+    if arr.shape[:2] != (image.height, image.width):
+        depth_pil = Image.fromarray(arr).resize(
+            (image.width, image.height), Image.BILINEAR
+        )
+        arr = np.array(depth_pil)
+
+    dmin = float(arr.min())
+    dmax = float(arr.max())
+
+    return {
+        "depth": arr.tolist(),
+        "width": image.width,
+        "height": image.height,
+        "min": dmin,
+        "max": dmax,
+    }
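
A quick local sanity check for this module: normalize the returned metric depth into an 8-bit grayscale preview (the input file name is hypothetical):

```python
import numpy as np
from PIL import Image
from depth import estimate_depth

result = estimate_depth(Image.open("person.jpg").convert("RGB"))
arr = np.array(result["depth"], dtype=np.float32)
span = max(result["max"] - result["min"], 1e-6)  # avoid divide-by-zero on flat maps
preview = ((arr - result["min"]) / span * 255.0).astype(np.uint8)
Image.fromarray(preview).save("depth_preview.png")
```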
reconstruction.py ADDED
@@ -0,0 +1,123 @@
+import numpy as np
+import torch
+import trimesh
+import io
+from PIL import Image
+
+
+_triposr_cache = None
+
+
+def get_triposr():
+    global _triposr_cache
+    if _triposr_cache is None:
+        from transformers import TripoSRForImageTo3D, TripoSRImageProcessor
+        processor = TripoSRImageProcessor.from_pretrained("stabilityai/TripoSR")
+        model = TripoSRForImageTo3D.from_pretrained("stabilityai/TripoSR")
+        model.eval()
+        _triposr_cache = (model, processor)
+    return _triposr_cache
+
+
+def reconstruct_region(image: Image.Image, mask: list[list[bool]], bbox: list[int]) -> bytes:
+    """
+    Crop the masked region from the image, run TripoSR, return GLB bytes.
+    """
+    model, processor = get_triposr()
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = model.to(device)
+
+    # Crop to bounding box with 20% padding
+    x, y, w, h = bbox
+    pad_x = int(w * 0.20)
+    pad_y = int(h * 0.20)
+    W, H = image.size
+    x0 = max(0, x - pad_x)
+    y0 = max(0, y - pad_y)
+    x1 = min(W, x + w + pad_x)
+    y1 = min(H, y + h + pad_y)
+    cropped = image.crop((x0, y0, x1, y1)).resize((512, 512), Image.LANCZOS)
+
+    # Apply mask as alpha channel so TripoSR focuses on the region
+    mask_arr = np.array(mask, dtype=np.uint8)[y0:y1, x0:x1]
+    mask_resized = np.array(
+        Image.fromarray(mask_arr * 255).resize((512, 512), Image.NEAREST)
+    )
+    rgba = np.array(cropped.convert("RGBA"))
+    rgba[:, :, 3] = mask_resized
+    input_img = Image.fromarray(rgba)
+
+    inputs = processor(images=input_img, return_tensors="pt").to(device)
+
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    # Export as GLB via trimesh
+    mesh_data = outputs.mesh  # TripoSR returns a trimesh-compatible object
+    if hasattr(mesh_data, "export"):
+        glb_bytes = mesh_data.export(file_type="glb")
+    else:
+        # Fallback: build trimesh from vertices/faces tensors
+        verts = mesh_data.verts_list()[0].cpu().numpy()
+        faces = mesh_data.faces_list()[0].cpu().numpy()
+        mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False)
+        buf = io.BytesIO()
+        mesh.export(buf, file_type="glb")
+        glb_bytes = buf.getvalue()
+
+    return glb_bytes
+
+
+def depth_to_mesh(depth: list[list[float]], mask: list[list[bool]], image: Image.Image) -> bytes:
+    """
+    Fallback when TripoSR isn't available: lift depth map into a 3D mesh
+    constrained to the masked region, textured with the source image.
+    """
+    depth_arr = np.array(depth, dtype=np.float32)
+    mask_arr = np.array(mask, dtype=bool)
+    H, W = depth_arr.shape
+
+    # Normalize depth to [0, 1] then scale to reasonable Z range
+    dmin, dmax = depth_arr.min(), depth_arr.max()
+    if dmax > dmin:
+        depth_norm = (depth_arr - dmin) / (dmax - dmin)
+    else:
+        depth_norm = np.zeros_like(depth_arr)
+    depth_scaled = depth_norm * 0.5  # 0.5 units of Z range
+
+    # Build vertex grid only for masked pixels
+    ys, xs = np.where(mask_arr)
+    if len(xs) == 0:
+        # Empty mask — return a flat quad
+        verts = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=np.float32)
+        faces = np.array([[0, 1, 2], [0, 2, 3]])
+        mesh = trimesh.Trimesh(vertices=verts, faces=faces)
+        buf = io.BytesIO()
+        mesh.export(buf, file_type="glb")
+        return buf.getvalue()
+
+    # Normalize to [-0.5, 0.5] XY space
+    x_norm = (xs / W) - 0.5
+    y_norm = 0.5 - (ys / H)
+    z_vals = depth_scaled[ys, xs]
+    vertices = np.stack([x_norm, y_norm, z_vals], axis=1).astype(np.float32)
+
+    # UV = source pixel position
+    uvs = np.stack([xs / W, 1.0 - ys / H], axis=1).astype(np.float32)
+
+    # Triangulate the masked grid using Delaunay
+    from scipy.spatial import Delaunay
+    points_2d = np.stack([x_norm, y_norm], axis=1)
+    tri = Delaunay(points_2d)
+    faces = tri.simplices.astype(np.int32)
+
+    # Build mesh with texture
+    img_arr = np.array(image.convert("RGB"))
+    texture = trimesh.visual.texture.TextureVisuals(
+        uv=uvs,
+        image=Image.fromarray(img_arr),
+    )
+    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, visual=texture, process=False)
+    buf = io.BytesIO()
+    mesh.export(buf, file_type="glb")
+    return buf.getvalue()
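
Two caveats worth flagging: the `TripoSRForImageTo3D` / `TripoSRImageProcessor` imports assume a transformers port of TripoSR (the upstream stabilityai/TripoSR release ships as a standalone `tsr` package, so this import path may need adjusting), and `scipy` is imported for Delaunay triangulation but is missing from requirements.txt below. The GLB bytes either branch returns can be sanity-checked with trimesh, e.g.:

```python
import io
import trimesh

# Inspect GLB bytes returned by reconstruct_region / depth_to_mesh
# (assumes they were saved to a local file named region.glb).
with open("region.glb", "rb") as f:
    scene = trimesh.load(io.BytesIO(f.read()), file_type="glb")
mesh = scene.dump(concatenate=True) if isinstance(scene, trimesh.Scene) else scene
print(len(mesh.vertices), "vertices,", len(mesh.faces), "faces")
```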
requirements.txt ADDED
@@ -0,0 +1,13 @@
+gradio>=4.44.0
+fastapi
+uvicorn
+python-multipart
+torch
+torchvision
+transformers
+Pillow
+numpy
+trimesh
+pygltflib
+opencv-python-headless
+spaces
segmentation.py ADDED
@@ -0,0 +1,91 @@
+import numpy as np
+from PIL import Image
+import torch
+
+
+def load_sam2():
+    from transformers import Sam2Model, Sam2Processor
+    processor = Sam2Processor.from_pretrained("facebook/sam2-hiera-large")
+    model = Sam2Model.from_pretrained("facebook/sam2-hiera-large")
+    model.eval()
+    return model, processor
+
+
+_sam2_cache = None
+
+
+def get_sam2():
+    global _sam2_cache
+    if _sam2_cache is None:
+        _sam2_cache = load_sam2()
+    return _sam2_cache
+
+
+# Region labels understood by the frontend
+REGION_LABELS = ["breast_left", "breast_right", "buttocks", "ponytail", "hair"]
+
+# Approximate point prompts on a normalized [0,1] image for each region
+# (x, y) from top-left. Used when user hasn't provided click points.
+DEFAULT_PROMPTS = {
+    "breast_left": (0.38, 0.38),
+    "breast_right": (0.62, 0.38),
+    "buttocks": (0.50, 0.72),
+    "ponytail": (0.50, 0.05),
+    "hair": (0.50, 0.08),
+}
+
+
+def segment_regions(image: Image.Image, requested: list[str], click_points: dict | None = None) -> dict:
+    """
+    Returns a dict of {region_label: {"mask": [[bool]], "bbox": [x,y,w,h]}}
+    """
+    model, processor = get_sam2()
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = model.to(device)
+
+    W, H = image.size
+    results = {}
+
+    for region in requested:
+        if region not in DEFAULT_PROMPTS:
+            continue
+
+        # Use user-supplied click or fall back to default
+        if click_points and region in click_points:
+            px, py = click_points[region]
+        else:
+            nx, ny = DEFAULT_PROMPTS[region]
+            px, py = nx * W, ny * H
+
+        inputs = processor(
+            images=image,
+            input_points=[[[px, py]]],
+            return_tensors="pt",
+        ).to(device)
+
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        masks = processor.post_process_masks(
+            outputs.pred_masks.cpu(),
+            inputs["original_sizes"].cpu(),
+            inputs["reshaped_input_sizes"].cpu(),
+        )[0]  # shape: [1, num_masks, H, W]
+
+        # Pick highest-score mask
+        scores = outputs.iou_scores[0].cpu().numpy()
+        best = int(np.argmax(scores))
+        mask = masks[0, best].numpy().astype(bool)  # [H, W]
+
+        # Compute bounding box
+        rows = np.any(mask, axis=1)
+        cols = np.any(mask, axis=0)
+        rmin, rmax = np.where(rows)[0][[0, -1]]
+        cmin, cmax = np.where(cols)[0][[0, -1]]
+
+        results[region] = {
+            "mask": mask.tolist(),
+            "bbox": [int(cmin), int(rmin), int(cmax - cmin), int(rmax - rmin)],
+        }
+
+    return results
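
For local testing, the module can be exercised directly (the image path is hypothetical; click points are pixel coordinates). Note that the bounding-box computation assumes at least one positive pixel, so an all-False mask from SAM2 would raise an IndexError; callers may want to guard for that:

```python
from PIL import Image
from segmentation import segment_regions

img = Image.open("person.jpg").convert("RGB")
masks = segment_regions(
    img,
    ["hair", "ponytail"],
    click_points={"hair": (img.width * 0.5, img.height * 0.1)},  # override default prompt
)
for region, data in masks.items():
    print(region, "bbox:", data["bbox"])
```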