Justin Wood committed
Commit c401d3e · 0 parents

Initial backend
Files changed (6)
  1. README.md +20 -0
  2. app.py +161 -0
  3. depth.py +52 -0
  4. reconstruction.py +123 -0
  5. requirements.txt +13 -0
  6. segmentation.py +91 -0
README.md ADDED
@@ -0,0 +1,20 @@
+ ---
+ title: Jiggle Physics Backend
+ emoji: 🎯
+ colorFrom: purple
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: "4.44.0"
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ REST API for the Jiggle Physics Simulator. The REST endpoints and the Gradio UI share a single FastAPI app, so the `@spaces.GPU` decorator keeps working on ZeroGPU hardware.
+
+ | Endpoint | Method | What it does |
+ |---|---|---|
+ | `/health` | GET | Liveness check |
+ | `/segment` | POST | SAM2 body region masks |
+ | `/depth` | POST | Apple Depth Pro metric depth |
+ | `/reconstruct` | POST | TripoSR → GLB mesh |
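+
+ ## Example client
+
+ A minimal Python sketch of the request/response flow. The `BASE` URL is a
+ placeholder for your Space's hostname, and `photo.jpg` is any test image:
+
+ ```python
+ import base64
+
+ import numpy as np
+ import requests
+
+ BASE = "https://YOUR-SPACE.hf.space"  # placeholder — substitute your Space URL
+
+ # Liveness check
+ assert requests.get(f"{BASE}/health").json()["status"] == "ok"
+
+ # /segment: multipart image upload plus a comma-separated region list.
+ # Masks come back run-length encoded: `rle` holds alternating run lengths,
+ # `rle_start` gives the boolean value of the first run.
+ with open("photo.jpg", "rb") as f:
+     seg = requests.post(
+         f"{BASE}/segment",
+         files={"image": f},
+         data={"regions": "breast_left,breast_right"},
+     ).json()
+
+ # Decode one RLE mask back into a 2-D boolean array
+ region = seg["regions"]["breast_left"]
+ flat = np.zeros(int(np.prod(region["shape"])), dtype=bool)
+ pos, val = 0, region["rle_start"]
+ for run in region["rle"]:
+     flat[pos:pos + run] = val
+     pos, val = pos + run, not val
+ mask = flat.reshape(region["shape"])
+
+ # /depth: decode the base64 float32 buffer into an (H, W) array
+ with open("photo.jpg", "rb") as f:
+     dep = requests.post(f"{BASE}/depth", files={"image": f}).json()
+ depth = np.frombuffer(
+     base64.b64decode(dep["depth_b64"]), dtype=np.float32
+ ).reshape(dep["height"], dep["width"])
+ ```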
app.py ADDED
@@ -0,0 +1,161 @@
+ """
+ Jiggle Physics Simulator — HuggingFace Space backend
+ ZeroGPU pattern: REST routes are registered on a FastAPI app and the Gradio
+ UI is mounted onto it, so `@spaces.GPU` functions run inside the Space.
+ Deploy with sdk: gradio and ZeroGPU hardware selected in Space settings.
+ """
+ import base64
+ import io
+ import json
+ from typing import Optional
+
+ import gradio as gr
+ import spaces
+ import numpy as np
+ from fastapi import FastAPI, File, Form, UploadFile, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import Response
+ from PIL import Image
+
+ # ── Minimal Gradio UI (required for ZeroGPU Spaces) ──────────────────────────
+ with gr.Blocks(title="Jiggle Physics API") as demo:
+     gr.Markdown("## Jiggle Physics ML API\nREST endpoints: `/segment` `/depth` `/reconstruct`")
+
+ # `demo.app` only exists after launch(), so build our own FastAPI app here,
+ # register the REST routes on it, and mount the Gradio UI at the bottom of
+ # this module with gr.mount_gradio_app().
+ app = FastAPI()
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_methods=["POST", "GET"],
+     allow_headers=["*"],
+ )
+
+
+ def _load_image(upload: UploadFile) -> Image.Image:
+     data = upload.file.read()
+     img = Image.open(io.BytesIO(data)).convert("RGB")
+     max_dim = 1024
+     if max(img.size) > max_dim:
+         ratio = max_dim / max(img.size)
+         img = img.resize(
+             (int(img.width * ratio), int(img.height * ratio)), Image.LANCZOS
+         )
+     return img
+
+
+ @app.get("/health")
+ def health():
+     return {"status": "ok"}
+
+
+ # Endpoints are plain `def` (not async): FastAPI runs sync handlers in a
+ # worker thread, which is the safer choice under the `spaces.GPU` wrapper.
+ @app.post("/segment")
+ @spaces.GPU
+ def segment(
+     image: UploadFile = File(...),
+     regions: str = Form("breast_left,breast_right,buttocks"),
+     click_points: Optional[str] = Form(None),
+ ):
+     """SAM2 body region segmentation. Returns RLE-encoded masks + bounding boxes."""
+     from segmentation import segment_regions
+
+     img = _load_image(image)
+     region_list = [r.strip() for r in regions.split(",") if r.strip()]
+     clicks = json.loads(click_points) if click_points else None
+
+     try:
+         result = segment_regions(img, region_list, clicks)
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+     # Wire format: run-length encoding over the flattened mask. `rle` holds
+     # alternating run lengths; `rle_start` is the value of the first run.
+     encoded = {}
+     for region, data in result.items():
+         mask_arr = np.array(data["mask"], dtype=bool)
+         flat = mask_arr.flatten()
+         # Vectorized RLE: indices where the value flips delimit the runs
+         changes = np.flatnonzero(flat[1:] != flat[:-1]) + 1
+         bounds = np.concatenate(([0], changes, [flat.size]))
+         rle = np.diff(bounds).tolist()
+         encoded[region] = {
+             "rle": rle,
+             "rle_start": bool(flat[0]),
+             "shape": list(mask_arr.shape),
+             "bbox": data["bbox"],
+         }
+
+     return {"regions": encoded, "image_size": [img.width, img.height]}
+
+
+ @app.post("/depth")
+ @spaces.GPU
+ def depth(image: UploadFile = File(...)):
+     """Apple Depth Pro depth estimation. Returns base64-encoded float32 depth map."""
+     from depth import estimate_depth
+
+     img = _load_image(image)
+     try:
+         result = estimate_depth(img)
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+     # Ship the raw float32 buffer as base64; clients rebuild it with
+     # np.frombuffer(..., dtype=np.float32).reshape(height, width)
+     arr = np.array(result["depth"], dtype=np.float32)
+     b64 = base64.b64encode(arr.tobytes()).decode("ascii")
+
+     return {
+         "depth_b64": b64,
+         "width": result["width"],
+         "height": result["height"],
+         "min": result["min"],
+         "max": result["max"],
+         "dtype": "float32",
+     }
+
+
+ @app.post("/reconstruct")
+ @spaces.GPU
+ def reconstruct(
+     image: UploadFile = File(...),
+     mask_rle: str = Form(...),
+     mask_shape: str = Form(...),
+     mask_rle_start: str = Form("false"),
+     bbox: str = Form(...),
+     use_triposr: str = Form("true"),
+ ):
+     """TripoSR single-image 3D reconstruction. Returns GLB binary."""
+     from reconstruction import reconstruct_region, depth_to_mesh
+     from depth import estimate_depth
+
+     img = _load_image(image)
+
+     # Decode the RLE produced by /segment back into a 2-D boolean mask
+     rle = json.loads(mask_rle)
+     shape = json.loads(mask_shape)
+     current = mask_rle_start.lower() == "true"
+     flat: list[bool] = []
+     for run in rle:
+         flat.extend([current] * run)
+         current = not current
+     mask = np.array(flat, dtype=bool).reshape(shape)
+
+     bbox_list = json.loads(bbox)
+
+     try:
+         if use_triposr.lower() == "true":
+             glb_bytes = reconstruct_region(img, mask.tolist(), bbox_list)
+         else:
+             depth_result = estimate_depth(img)
+             glb_bytes = depth_to_mesh(depth_result["depth"], mask.tolist(), img)
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+     return Response(content=glb_bytes, media_type="model/gltf-binary")
+
+
+ # ── Entry point ───────────────────────────────────────────────────────────────
+ # Mount the Gradio UI onto the API app; the REST routes registered above take
+ # precedence over the mounted Gradio routes.
+ app = gr.mount_gradio_app(app, demo, path="/")
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     uvicorn.run(app, host="0.0.0.0", port=7860)
depth.py ADDED
@@ -0,0 +1,52 @@
+ import numpy as np
+ import torch
+ from PIL import Image
+
+
+ _depth_cache = None
+
+
+ def get_depth_model():
+     global _depth_cache
+     if _depth_cache is None:
+         from transformers import pipeline as hf_pipeline
+         # Apple's Depth Pro — metric depth from a single image
+         _depth_cache = hf_pipeline(
+             "depth-estimation",
+             model="apple/DepthPro-hf",
+             device=0 if torch.cuda.is_available() else -1,
+         )
+     return _depth_cache
+
+
+ def estimate_depth(image: Image.Image) -> dict:
+     """
+     Returns {"depth": [[float]], "width": int, "height": int, "min": float, "max": float}
+     Depth values are metric (meters) when Depth Pro is used.
+     """
+     pipe = get_depth_model()
+     result = pipe(image)
+
+     # Prefer the raw tensor: the pipeline's "depth" entry is a PIL image
+     # rescaled for visualization, which would destroy the metric values.
+     pred = result.get("predicted_depth")
+     if pred is not None:
+         arr = pred.squeeze().detach().cpu().numpy().astype(np.float32)
+     else:
+         arr = np.array(result["depth"], dtype=np.float32)
+
+     # Resize to match source image if needed (mode "F" keeps float32)
+     if arr.shape[:2] != (image.height, image.width):
+         depth_pil = Image.fromarray(arr).resize(
+             (image.width, image.height), Image.BILINEAR
+         )
+         arr = np.array(depth_pil)
+
+     dmin = float(arr.min())
+     dmax = float(arr.max())
+
+     return {
+         "depth": arr.tolist(),
+         "width": image.width,
+         "height": image.height,
+         "min": dmin,
+         "max": dmax,
+     }
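+
+
+ if __name__ == "__main__":
+     # Local smoke test — a sketch assuming a sample image at ./test.jpg.
+     # The first run downloads the Depth Pro checkpoint from the Hub.
+     img = Image.open("test.jpg").convert("RGB")
+     out = estimate_depth(img)
+     print(f"{out['width']}x{out['height']} depth in [{out['min']:.2f}, {out['max']:.2f}] m")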
reconstruction.py ADDED
@@ -0,0 +1,123 @@
+ import io
+
+ import numpy as np
+ import torch
+ import trimesh
+ from PIL import Image
+
+
+ _triposr_cache = None
+
+
+ def get_triposr():
+     # TripoSR is not shipped in `transformers`. This uses the `tsr` package
+     # from the open-source TripoSR repository
+     # (github.com/VAST-AI-Research/TripoSR), vendored into the Space.
+     global _triposr_cache
+     if _triposr_cache is None:
+         from tsr.system import TSR
+         model = TSR.from_pretrained(
+             "stabilityai/TripoSR",
+             config_name="config.yaml",
+             weight_name="model.ckpt",
+         )
+         model.eval()
+         _triposr_cache = model
+     return _triposr_cache
+
+
+ def reconstruct_region(image: Image.Image, mask: list[list[bool]], bbox: list[int]) -> bytes:
+     """
+     Crop the masked region from the image, run TripoSR, return GLB bytes.
+     """
+     model = get_triposr()
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = model.to(device)
+
+     # Crop to bounding box with 20% padding
+     x, y, w, h = bbox
+     pad_x = int(w * 0.20)
+     pad_y = int(h * 0.20)
+     W, H = image.size
+     x0 = max(0, x - pad_x)
+     y0 = max(0, y - pad_y)
+     x1 = min(W, x + w + pad_x)
+     y1 = min(H, y + h + pad_y)
+     cropped = image.crop((x0, y0, x1, y1)).resize((512, 512), Image.LANCZOS)
+
+     # Apply the mask so TripoSR focuses on the region: composite masked-out
+     # pixels onto neutral gray, mirroring TripoSR's own preprocessing
+     mask_arr = np.array(mask, dtype=np.uint8)[y0:y1, x0:x1]
+     alpha = np.array(
+         Image.fromarray(mask_arr * 255).resize((512, 512), Image.NEAREST)
+     ).astype(np.float32)[:, :, None] / 255.0
+     rgb = np.array(cropped, dtype=np.float32) / 255.0
+     composited = rgb * alpha + 0.5 * (1.0 - alpha)
+     input_img = Image.fromarray((composited * 255).astype(np.uint8))
+
+     with torch.no_grad():
+         scene_codes = model([input_img], device=device)
+         # Signature per the current TripoSR repo: (scene_codes, has_vertex_color,
+         # resolution); returns a list of trimesh meshes
+         mesh = model.extract_mesh(scene_codes, True, resolution=256)[0]
+
+     # trimesh returns raw bytes when exporting GLB
+     glb_bytes = mesh.export(file_type="glb")
+     return glb_bytes
+
+
+ def depth_to_mesh(depth: list[list[float]], mask: list[list[bool]], image: Image.Image) -> bytes:
72
+ """
73
+ Fallback when TripoSR isn't available: lift depth map into a 3D mesh
74
+ constrained to the masked region, textured with the source image.
75
+ """
76
+ depth_arr = np.array(depth, dtype=np.float32)
77
+ mask_arr = np.array(mask, dtype=bool)
78
+ H, W = depth_arr.shape
79
+
80
+ # Normalize depth to [0, 1] then scale to reasonable Z range
81
+ dmin, dmax = depth_arr.min(), depth_arr.max()
82
+ if dmax > dmin:
83
+ depth_norm = (depth_arr - dmin) / (dmax - dmin)
84
+ else:
85
+ depth_norm = np.zeros_like(depth_arr)
86
+ depth_scaled = depth_norm * 0.5 # 0.5 units of Z range
87
+
88
+ # Build vertex grid only for masked pixels
89
+ ys, xs = np.where(mask_arr)
90
+ if len(xs) == 0:
91
+ # Empty mask — return a flat quad
92
+ verts = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=np.float32)
93
+ faces = np.array([[0, 1, 2], [0, 2, 3]])
94
+ mesh = trimesh.Trimesh(vertices=verts, faces=faces)
95
+ buf = io.BytesIO()
96
+ mesh.export(buf, file_type="glb")
97
+ return buf.getvalue()
98
+
99
+ # Normalize to [-0.5, 0.5] XY space
100
+ x_norm = (xs / W) - 0.5
101
+ y_norm = 0.5 - (ys / H)
102
+ z_vals = depth_scaled[ys, xs]
103
+ vertices = np.stack([x_norm, y_norm, z_vals], axis=1).astype(np.float32)
104
+
105
+ # UV = source pixel position
106
+ uvs = np.stack([xs / W, 1.0 - ys / H], axis=1).astype(np.float32)
107
+
108
+ # Triangulate the masked grid using Delaunay
109
+ from scipy.spatial import Delaunay
110
+ points_2d = np.stack([x_norm, y_norm], axis=1)
111
+ tri = Delaunay(points_2d)
112
+ faces = tri.simplices.astype(np.int32)
113
+
114
+ # Build mesh with texture
115
+ img_arr = np.array(image.convert("RGB"))
116
+ texture = trimesh.visual.texture.TextureVisuals(
117
+ uv=uvs,
118
+ image=Image.fromarray(img_arr),
119
+ )
120
+ mesh = trimesh.Trimesh(vertices=vertices, faces=faces, visual=texture, process=False)
121
+ buf = io.BytesIO()
122
+ mesh.export(buf, file_type="glb")
123
+ return buf.getvalue()
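+
+
+ if __name__ == "__main__":
+     # Synthetic smoke test for the depth fallback — no model downloads needed.
+     # Builds a radial "bump" depth map with a circular mask and exports a GLB.
+     H, W = 64, 64
+     yy, xx = np.mgrid[0:H, 0:W]
+     r = np.hypot(yy - H / 2, xx - W / 2)
+     depth_map = (r / r.max()).tolist()
+     circle = (r < H / 3).tolist()
+     tex = Image.new("RGB", (W, H), (210, 170, 160))
+     glb = depth_to_mesh(depth_map, circle, tex)
+     print(f"exported {len(glb)} bytes of GLB")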
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ gradio>=4.44.0
+ fastapi
+ uvicorn
+ python-multipart
+ torch
+ torchvision
+ transformers
+ Pillow
+ numpy
+ scipy
+ trimesh
+ pygltflib
+ opencv-python-headless
+ spaces
+ # TripoSR itself is not on PyPI; vendor the `tsr` package from
+ # https://github.com/VAST-AI-Research/TripoSR into the Space.
segmentation.py ADDED
@@ -0,0 +1,91 @@
+ import numpy as np
+ from PIL import Image
+ import torch
+
+
+ def load_sam2():
+     from transformers import Sam2Model, Sam2Processor
+     processor = Sam2Processor.from_pretrained("facebook/sam2-hiera-large")
+     model = Sam2Model.from_pretrained("facebook/sam2-hiera-large")
+     model.eval()
+     return model, processor
+
+
+ _sam2_cache = None
+
+
+ def get_sam2():
+     global _sam2_cache
+     if _sam2_cache is None:
+         _sam2_cache = load_sam2()
+     return _sam2_cache
+
+
+ # Region labels understood by the frontend
+ REGION_LABELS = ["breast_left", "breast_right", "buttocks", "ponytail", "hair"]
+
+ # Approximate point prompts on a normalized [0,1] image for each region,
+ # (x, y) from top-left. Used when the user hasn't provided click points.
+ DEFAULT_PROMPTS = {
+     "breast_left": (0.38, 0.38),
+     "breast_right": (0.62, 0.38),
+     "buttocks": (0.50, 0.72),
+     "ponytail": (0.50, 0.05),
+     "hair": (0.50, 0.08),
+ }
+
+
+ def segment_regions(image: Image.Image, requested: list[str], click_points: dict | None = None) -> dict:
39
+ """
40
+ Returns a dict of {region_label: {"mask": [[bool]], "bbox": [x,y,w,h]}}
41
+ """
42
+ model, processor = get_sam2()
43
+ device = "cuda" if torch.cuda.is_available() else "cpu"
44
+ model = model.to(device)
45
+
46
+ W, H = image.size
47
+ results = {}
48
+
49
+ for region in requested:
50
+ if region not in DEFAULT_PROMPTS:
51
+ continue
52
+
53
+ # Use user-supplied click or fall back to default
54
+ if click_points and region in click_points:
55
+ px, py = click_points[region]
56
+ else:
57
+ nx, ny = DEFAULT_PROMPTS[region]
58
+ px, py = nx * W, ny * H
59
+
60
+ inputs = processor(
61
+ images=image,
62
+ input_points=[[[px, py]]],
63
+ return_tensors="pt",
64
+ ).to(device)
65
+
66
+ with torch.no_grad():
67
+ outputs = model(**inputs)
68
+
69
+ masks = processor.post_process_masks(
70
+ outputs.pred_masks.cpu(),
71
+ inputs["original_sizes"].cpu(),
72
+ inputs["reshaped_input_sizes"].cpu(),
73
+ )[0] # shape: [1, num_masks, H, W]
74
+
75
+ # Pick highest-score mask
76
+ scores = outputs.iou_scores[0].cpu().numpy()
77
+ best = int(np.argmax(scores))
78
+ mask = masks[0, best].numpy().astype(bool) # [H, W]
79
+
80
+ # Compute bounding box
81
+ rows = np.any(mask, axis=1)
82
+ cols = np.any(mask, axis=0)
83
+ rmin, rmax = np.where(rows)[0][[0, -1]]
84
+ cmin, cmax = np.where(cols)[0][[0, -1]]
85
+
86
+ results[region] = {
87
+ "mask": mask.tolist(),
88
+ "bbox": [int(cmin), int(rmin), int(cmax - cmin), int(rmax - rmin)],
89
+ }
90
+
91
+ return results
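+
+
+ if __name__ == "__main__":
+     # Smoke test — a sketch assuming a sample image at ./test.jpg;
+     # the first run downloads the SAM2 checkpoint from the Hub.
+     img = Image.open("test.jpg").convert("RGB")
+     out = segment_regions(img, ["buttocks"])
+     for name, data in out.items():
+         print(name, "bbox:", data["bbox"])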