# Provenance (Hugging Face Hub page header captured with the file):
# uploaded by jboth - "Upload app.py with huggingface_hub" - commit 7b8ab13 (verified)
"""SAM 3D Objects – kaolin+pytorch3d stubbed for ZeroGPU (PyTorch 2.10+cu128)."""
import os, sys, subprocess
# Toolchain locations: filled in only when the environment does not already
# define them.
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
os.environ.setdefault("CONDA_PREFIX", "/usr/local")
# Backend switches: always overwritten, whatever the environment held before.
os.environ.update(
    {
        "LIDRA_SKIP_INIT": "true",
        "ATTN_BACKEND": "sdpa",
        "SPARSE_ATTN_BACKEND": "sdpa",
        "SPARSE_BACKEND": "spconv",
    }
)
# MUST import spaces before torch
import spaces
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download, login
import tempfile
from pathlib import Path
# Log in to the Hugging Face Hub only when a non-empty HF_TOKEN is configured.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
# --- Stubs (must be before sam3d imports) ---
STUB_KAOLIN = Path("/home/user/app/kaolin_stub")
STUB_PT3D = Path("/home/user/app/pytorch3d_stub")
STUB_FA = Path("/home/user/app/flash_attn_stub")
# Put every stub directory that is present at the front of the import search
# path; directories that are missing are skipped silently.
for stub in (STUB_KAOLIN, STUB_PT3D, STUB_FA):
    if not stub.exists():
        continue
    sys.path.insert(0, str(stub))
    print(f"Stub added: {stub.name}")
# --- Runtime pip installs ---
def _pip(*a):
    """Run ``pip install --no-cache-dir`` with extra arguments and report success.

    Args:
        *a: Command-line arguments appended after ``--no-cache-dir`` (options
            and/or requirement specifiers). The last one, cut to 50
            characters, is used as the label in the log line.

    Returns:
        True when pip exits with status 0. False when pip exits non-zero or
        does not finish within the 1200 second timeout.
    """
    tag = a[-1][:50] if a else "?"
    cmd = [sys.executable, "-m", "pip", "install", "--no-cache-dir", *a]
    try:
        r = subprocess.run(cmd, capture_output=True, text=True, timeout=1200)
    except subprocess.TimeoutExpired:
        # A hung install is reported like any other failed install instead of
        # propagating and aborting the whole startup sequence.
        print(f" pip FAIL: {tag}")
        print(" timed out after 1200s")
        return False
    if r.returncode == 0:
        print(f" pip OK: {tag}")
        return True
    print(f" pip FAIL: {tag}")
    # Only the tail of stderr is shown to keep the startup log short.
    print(f" {r.stderr[-300:]}")
    return False
print("=== Runtime installs ===")
_pip("open3d>=0.18.0")
# --no-deps: skip jupyter dependency
_pip("--no-deps", "git+https://github.com/EasternJournalist/utils3d.git")
_pip("iopath")
_pip("--no-deps", "sam2>=1.1.0")
_pip("--no-deps", "git+https://github.com/microsoft/MoGe.git@a8c37341bc0325ca99b9d57981cc3bb2bd3e255b")
# gsplat: walk the wheel indexes in order and stop at the first one that installs.
for idx in (
    "https://docs.gsplat.studio/whl/pt210cu128",
    "https://docs.gsplat.studio/whl/pt28cu128",
):
    gsplat_installed = _pip("--no-deps", f"--extra-index-url={idx}", "gsplat")
    if gsplat_installed:
        break
# spconv (sparse convolution – needed for SAM3D's SLatFlowModel)
# cu124 wheel is forward-compatible with cu128
_pip("spconv-cu124==2.3.8")
# DO NOT import CUDA-dependent packages here!
# --- Clone sam-3d-objects ---
SAM3D_PATH = Path("/home/user/app/sam-3d-objects")
# NOTE(review): the original indentation was lost. The editable install is
# assumed to belong to the fresh-clone branch and the Hydra patch to run on
# every start - confirm against the deployed file.
if not SAM3D_PATH.exists():
    print("Cloning sam-3d-objects...")
    # check=True: a failed clone aborts startup.
    subprocess.run(["git", "clone", "--depth", "1",
                    "https://github.com/facebookresearch/sam-3d-objects.git",
                    str(SAM3D_PATH)], check=True)
    _editable = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", str(SAM3D_PATH), "--no-deps"],
        capture_output=True, text=True)
    # Still best-effort, but a failure is now visible in the log instead of
    # being discarded together with the captured output.
    if _editable.returncode != 0:
        print(f" pip FAIL: -e {SAM3D_PATH}")
        print(f" {_editable.stderr[-300:]}")
# Hydra patch
patch = SAM3D_PATH / "patching" / "hydra"
if patch.exists():
    _hydra = subprocess.run(["bash", str(patch)], capture_output=True, text=True,
                            cwd=str(SAM3D_PATH))
    if _hydra.returncode != 0:
        print(f" Hydra patch FAIL (exit {_hydra.returncode})")
        print(f" {_hydra.stderr[-300:]}")
# CRITICAL PATCH: Prevent SAM3D from overriding ATTN_BACKEND to flash_attn
# inference_pipeline.py auto-detects H200/A100 and forces flash_attn,
# but we don't have the real flash_attn package.
ip_file = SAM3D_PATH / "sam3d_objects" / "pipeline" / "inference_pipeline.py"
if ip_file.exists():
    ip_src = ip_file.read_text()
    # Presence of this assignment means the file still forces flash_attn.
    old_marker = 'os.environ["ATTN_BACKEND"] = "flash_attn"'
    if old_marker in ip_src:
        # Rewrite only the assigned values. Unlike replacing the whole
        # multi-line if-block, this does not depend on the file's exact
        # indentation or comments, it cannot leave an indented body without
        # its `if` header, and it always changes the file when the marker is
        # present, so the "PATCHED" message is never printed for a no-op.
        ip_src = ip_src.replace(
            old_marker,
            'os.environ["ATTN_BACKEND"] = "sdpa"',
        ).replace(
            'os.environ["SPARSE_ATTN_BACKEND"] = "flash_attn"',
            'os.environ["SPARSE_ATTN_BACKEND"] = "sdpa"',
        )
        ip_file.write_text(ip_src)
        print("PATCHED: inference_pipeline.py - forced sdpa backend")
    else:
        print("INFO: inference_pipeline.py already patched or different version")
# Prepend both directories in one step; the resulting search order is
# <repo>/notebook first, then <repo>, then everything that was already there.
sys.path[:0] = [str(SAM3D_PATH / "notebook"), str(SAM3D_PATH)]
# --- Monkey-patch: inject depth_edge into utils3d.numpy ---
# utils3d package lacks depth_edge in newer versions; SAM3D needs it for layout post-optimization
def _depth_edge(depth, rtol=0.03, mask=None):
    """Flag pixels whose Sobel depth gradient is large relative to the depth.

    Args:
        depth: 2-D depth map (array-like); converted to float64 so integer
            inputs neither wrap in the Sobel filter nor truncate the floor.
        rtol: Threshold on ``gradient magnitude / |depth|``.
        mask: Optional boolean array of the same shape. Pixels outside it are
            treated as depth 0 for the gradient and are never reported.

    Returns:
        Boolean array with the shape of ``depth``.
    """
    import numpy as _np
    from scipy.ndimage import sobel

    d = _np.asarray(depth, dtype=_np.float64)
    if mask is not None:
        d = _np.where(mask, d, 0.0)
    gx = sobel(d, axis=1)
    gy = sobel(d, axis=0)
    grad = _np.sqrt(gx**2 + gy**2)
    # Floor the denominator so zero depth does not divide by zero.
    denom = _np.maximum(_np.abs(d), 1e-6)
    edge = (grad / denom) > rtol
    if mask is not None:
        edge = edge & mask
    return edge


def _normals_edge(normals, tol=0.1, mask=None):
    """Detect normal discontinuities.

    Args:
        normals: Array whose last axis holds the normal components; the first
            two axes are taken as the image grid.
        tol: Threshold on the Sobel gradient magnitude of each component.
        mask: Optional boolean array over the image grid. Components outside
            it are zeroed before filtering and never reported.

    Returns:
        Boolean array over the first two axes of ``normals``; True where any
        component's gradient magnitude exceeds ``tol``.
    """
    import numpy as _np
    from scipy.ndimage import sobel

    edges = _np.zeros(normals.shape[:2], dtype=bool)
    for c in range(normals.shape[-1]):
        ch = normals[..., c]
        if mask is not None:
            ch = _np.where(mask, ch, 0.0)
        gx = sobel(ch, axis=1)
        gy = sobel(ch, axis=0)
        edges |= _np.sqrt(gx**2 + gy**2) > tol
    if mask is not None:
        edges = edges & mask
    return edges


def _make_u3d_catchall(orig_getattr=None):
    """Build a module-level ``__getattr__`` that degrades to a dummy function.

    Args:
        orig_getattr: The module's previous ``__getattr__`` hook, or None.
            It is consulted first, so names it can resolve keep working.

    Returns:
        A function ``name -> attribute``. It raises AttributeError for dunder
        names; otherwise it returns what ``orig_getattr`` returns, or, when
        that is missing or raises AttributeError, a dummy callable that
        returns ``numpy.zeros(1)`` (after emitting a warning).
    """
    def _u3d_catchall(name):
        # Dunder probes (copy, pickle, inspect, ...) must keep failing normally.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        if orig_getattr is not None:
            try:
                return orig_getattr(name)
            except AttributeError:
                pass
        import warnings
        warnings.warn(f"utils3d.numpy stub: {name} not implemented, returning dummy")

        def _dummy(*a, **kw):
            import numpy as _np
            return _np.zeros(1)
        return _dummy
    return _u3d_catchall


# NOTE(review): the original indentation was lost; all three injections are
# assumed to be conditional on depth_edge being absent - confirm.
try:
    import utils3d.numpy as _u3d_np
    if not hasattr(_u3d_np, 'depth_edge'):
        _u3d_np.depth_edge = _depth_edge
        _u3d_np.normals_edge = _normals_edge
        # Also inject a catch-all __getattr__ for any future missing functions.
        # It chains to the hook that was installed before, which the previous
        # code captured but never called.
        _orig_getattr = getattr(_u3d_np, '__getattr__', None)
        _u3d_np.__getattr__ = _make_u3d_catchall(_orig_getattr)
        print("Injected depth_edge + normals_edge + catch-all into utils3d.numpy")
except Exception as e:
    print(f"depth_edge patch skipped: {e}")
# --- Pre-download checkpoints ---
print("Downloading SAM3D checkpoints...")
CKPT_DIR = snapshot_download(repo_id="facebook/sam-3d-objects",
                             token=os.environ.get("HF_TOKEN"))
hf_ckpt = Path(CKPT_DIR) / "checkpoints"
local_ckpt = SAM3D_PATH / "checkpoints" / "hf"
if hf_ckpt.exists():
    # Path.exists() follows symlinks, so a dangling link looks "missing" and
    # symlink_to() would then fail with FileExistsError. Drop such a link first.
    if local_ckpt.is_symlink() and not local_ckpt.exists():
        local_ckpt.unlink()
    if not local_ckpt.exists():
        local_ckpt.parent.mkdir(parents=True, exist_ok=True)
        local_ckpt.symlink_to(hf_ckpt)
# Path of the pipeline config handed to the Inference constructor by the endpoints.
CONFIG_PATH = str(local_ckpt / "pipeline.yaml")
print(f"Config exists: {Path(CONFIG_PATH).exists()}")
print("=== Startup complete ===")
# --- Endpoints ---
@spaces.GPU(duration=60)
def diagnose():
    """Report the torch/CUDA state and which optional modules import cleanly.

    Returns:
        A newline-joined text report: torch version, CUDA availability, GPU
        name (when CUDA is available), one OK/FAIL line per probed module,
        and whether the pipeline config file exists.
    """
    import torch

    cuda_ok = torch.cuda.is_available()
    report = [f"torch={torch.__version__}", f"cuda={cuda_ok}"]
    if cuda_ok:
        report.append(f"gpu={torch.cuda.get_device_name()}")

    # Plain top-level imports; the version is shown when the module exposes one.
    for mod_name in ("kaolin", "utils3d", "iopath", "pytorch3d", "open3d", "gsplat", "moge"):
        try:
            module = __import__(mod_name)
            report.append(f"{mod_name}: OK ({getattr(module, '__version__', '-')})")
        except Exception as exc:
            report.append(f"{mod_name}: FAIL - {exc}")

    # These two are probed through the exact symbols the endpoints import.
    try:
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
    except Exception as exc:
        report.append(f"sam2: FAIL - {exc}")
    else:
        report.append("sam2: OK")

    try:
        from inference import Inference
    except Exception as exc:
        report.append(f"SAM3D Inference: FAIL - {exc}")
    else:
        report.append("SAM3D Inference: importable")

    report.append(f"config: {Path(CONFIG_PATH).exists()}")
    return "\n".join(report)
def _vram_gb():
    """Return the CUDA memory currently allocated by torch, in units of 1e9 bytes."""
    import torch
    return torch.cuda.memory_allocated() / 1e9


def _overlay_mask(image_np, mask):
    """Return a copy of ``image_np`` with the masked pixels blended 50/50 with green."""
    preview = image_np.copy()
    preview[mask] = (preview[mask] * 0.5 + np.array([0, 255, 0]) * 0.5).astype(np.uint8)
    return preview


def _find_gaussians(result):
    """Locate the splat-like object inside a SAM3D result, or return None.

    The result itself is used when it has ``save_ply``. Otherwise, for a dict,
    the first non-empty entry among "gs", "gaussian", "gaussians", "scene" is
    used (first element when that entry is a list or tuple).
    """
    if hasattr(result, "save_ply"):
        return result
    if not isinstance(result, dict):
        return None
    for key in ("gs", "gaussian", "gaussians", "scene"):
        value = result.get(key)
        if value is None:
            continue
        if isinstance(value, (list, tuple)):
            if not value:
                # An empty container carries nothing to export; try the next key.
                continue
            value = value[0]
        return value
    return None


def _export_glb(result, out_dir):
    """Write ``<out_dir>/object.glb`` from a SAM3D result.

    Point data (via ``save_ply`` or the ``_xyz`` attribute) is meshed with
    open3d's Poisson reconstruction at depth 8. Failing that, a dict entry
    "mesh" that has an ``export`` method is exported directly.

    Returns:
        The GLB path, or None when nothing exportable was found, so callers
        never report success for a file that was not written.
    """
    glb = f"{out_dir}/object.glb"
    gs = _find_gaussians(result)
    if gs is not None and hasattr(gs, "save_ply"):
        import open3d as o3d
        ply = f"{out_dir}/temp.ply"
        gs.save_ply(ply)
        pcd = o3d.io.read_point_cloud(ply)
    elif gs is not None and hasattr(gs, "_xyz"):
        import open3d as o3d
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(gs._xyz.detach().cpu().numpy())
    else:
        mesh_obj = result.get("mesh") if isinstance(result, dict) else None
        if not hasattr(mesh_obj, "export"):
            return None
        mesh_obj.export(glb)
        return glb
    pcd.estimate_normals()
    mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=8)
    o3d.io.write_triangle_mesh(glb, mesh)
    return glb


def _cannot_extract(result):
    """Build the status text listing what the unusable result contained."""
    keys = list(result.keys()) if isinstance(result, dict) else dir(result)
    return f"Cannot extract 3D. Keys: {keys}"


def _count_faces(glb_path):
    """Return the face count of the exported mesh, or 0 when it cannot be read."""
    try:
        import trimesh
        return len(trimesh.load(glb_path, force="mesh").faces)
    except Exception as exc:
        # The count is informational only; a failure must not discard the model.
        print(f" Face count unavailable: {exc}")
        return 0


@spaces.GPU(duration=300)
def reconstruct_objects(image: np.ndarray):
    """Segment the image with SAM2 and reconstruct the largest mask with SAM3D.

    Args:
        image: Input image array, or None when nothing was provided.

    Returns:
        Tuple ``(glb_path, preview, status)``. ``glb_path`` and ``preview``
        are None on failure; ``status`` is a human-readable message or, on an
        exception, the tail of the traceback.
    """
    if image is None:
        return None, None, "No image"
    try:
        import time
        import torch
        t0 = time.time()
        print(f"GPU: {torch.cuda.get_device_name()}")
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
        print(f" Loading SAM2... (VRAM: {_vram_gb():.1f}GB)")
        sam2_gen = SAM2AutomaticMaskGenerator.from_pretrained("facebook/sam2-hiera-small")
        print(f" SAM2 loaded ({time.time()-t0:.0f}s, VRAM: {_vram_gb():.1f}GB)")
        image_np = np.array(image) if not isinstance(image, np.ndarray) else image
        masks = sam2_gen.generate(image_np)
        if not masks:
            return None, image_np, "No objects detected"
        # Only the mask with the largest "area" is reconstructed.
        masks = sorted(masks, key=lambda x: x["area"], reverse=True)
        best_mask = masks[0]["segmentation"]
        preview = _overlay_mask(image_np, best_mask)
        print(f" {len(masks)} masks ({time.time()-t0:.0f}s)")
        # Free SAM2 to save VRAM for SAM3D
        del sam2_gen
        torch.cuda.empty_cache()
        print(f" SAM2 freed (VRAM: {_vram_gb():.1f}GB)")
        from inference import Inference
        print(f" Loading SAM3D... (VRAM: {_vram_gb():.1f}GB)")
        sam3d = Inference(CONFIG_PATH, compile=False)
        print(f" SAM3D loaded ({time.time()-t0:.0f}s, VRAM: {_vram_gb():.1f}GB)")
        print(f" Running reconstruction... (VRAM: {_vram_gb():.1f}GB)")
        result = sam3d(image=image_np, mask=best_mask, seed=42)
        print(f" Reconstructed ({time.time()-t0:.0f}s, VRAM: {_vram_gb():.1f}GB)")
        if result is None:
            return None, preview, "Reconstruction returned None"
        # The directory is not removed here: the returned path must stay
        # readable after this function returns.
        od = tempfile.mkdtemp()
        glb = _export_glb(result, od)
        if glb is None:
            return None, preview, _cannot_extract(result)
        n = _count_faces(glb)
        elapsed = int(time.time() - t0)
        return glb, preview, f"OK: {len(masks)} objects, {n:,} faces ({elapsed}s)"
    except Exception:
        import traceback
        tb = traceback.format_exc()
        print(tb)
        return None, None, f"Error:\n{tb[-1500:]}"


@spaces.GPU(duration=60)
def test_sam3d_only(image: np.ndarray):
    """Test SAM3D reconstruction with center-crop mask (no SAM2).

    Args:
        image: Input image array, or None when nothing was provided.

    Returns:
        Tuple ``(glb_path, preview, status)`` with the same conventions as
        ``reconstruct_objects``.
    """
    if image is None:
        return None, None, "No image"
    try:
        import time
        import torch
        t0 = time.time()
        print(f"GPU: {torch.cuda.get_device_name()}, VRAM: {_vram_gb():.1f}GB")
        image_np = np.array(image) if not isinstance(image, np.ndarray) else image
        h, w = image_np.shape[:2]
        # Create a center mask (middle 60% of image)
        mask = np.zeros((h, w), dtype=bool)
        y1, y2 = int(h * 0.2), int(h * 0.8)
        x1, x2 = int(w * 0.2), int(w * 0.8)
        mask[y1:y2, x1:x2] = True
        preview = _overlay_mask(image_np, mask)
        print(f" Mask created: {mask.sum()} pixels ({time.time()-t0:.0f}s)")
        from inference import Inference
        print(f" Loading SAM3D... VRAM: {_vram_gb():.1f}GB")
        sam3d = Inference(CONFIG_PATH, compile=False)
        print(f" SAM3D loaded ({time.time()-t0:.0f}s, VRAM: {_vram_gb():.1f}GB)")
        print(f" Running reconstruction...")
        result = sam3d(image=image_np, mask=mask, seed=42)
        print(f" Done ({time.time()-t0:.0f}s, VRAM: {_vram_gb():.1f}GB)")
        if result is None:
            return None, preview, "Reconstruction returned None"
        od = tempfile.mkdtemp()
        glb = _export_glb(result, od)
        if glb is None:
            return None, preview, _cannot_extract(result)
        n = _count_faces(glb)
        elapsed = int(time.time() - t0)
        return glb, preview, f"OK: {n:,} faces ({elapsed}s)"
    except Exception:
        import traceback
        tb = traceback.format_exc()
        print(tb)
        return None, None, f"Error:\n{tb[-1500:]}"
# --- UI ---
# NOTE(review): the nesting below was rebuilt from the flattened source;
# confirm the row/column placement against the deployed layout.
with gr.Blocks(title="SAM 3D Objects") as demo:
    gr.Markdown("# SAM 3D Objects\nImage → 3D (GLB). SAM2 detection + SAM3D reconstruction.")

    # Tab 1: SAM2 detection followed by SAM3D reconstruction.
    with gr.Tab("Reconstruct"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input", type="numpy")
                run_button = gr.Button("Reconstruct", variant="primary", size="lg")
            with gr.Column():
                detection_view = gr.Image(label="Detection", type="numpy", interactive=False)
                status_box = gr.Textbox(label="Status")
        with gr.Row():
            model_view = gr.Model3D(label="3D Preview")
            glb_download = gr.File(label="Download GLB")
        run_button.click(
            reconstruct_objects,
            inputs=[input_image],
            outputs=[model_view, detection_view, status_box],
        )
        # Mirror the previewed model's value into the download widget.
        model_view.change(lambda x: x, inputs=[model_view], outputs=[glb_download])

    # Tab 2: SAM3D alone, driven by a fixed center mask.
    with gr.Tab("Test SAM3D Only"):
        with gr.Row():
            with gr.Column():
                test_image = gr.Image(label="Input", type="numpy")
                test_button = gr.Button("Test SAM3D (no SAM2)", variant="primary")
            with gr.Column():
                test_mask_view = gr.Image(label="Mask Preview", type="numpy", interactive=False)
                test_status_box = gr.Textbox(label="Status")
        with gr.Row():
            test_model_view = gr.Model3D(label="3D Preview")
        test_button.click(
            test_sam3d_only,
            inputs=[test_image],
            outputs=[test_model_view, test_mask_view, test_status_box],
        )

    # Tab 3: environment report.
    with gr.Tab("Diagnose"):
        diagnose_button = gr.Button("Diagnose GPU & Modules")
        diagnose_output = gr.Textbox(lines=15)
        diagnose_button.click(diagnose, outputs=[diagnose_output])

demo.launch(mcp_server=True)