"""
SAM 3D Objects MCP Server
Image + Click → 3D Object (GLB)
"""
import os
import sys
import subprocess
import tempfile
import uuid
from pathlib import Path
import gradio as gr
import numpy as np
import spaces
from huggingface_hub import snapshot_download, login
from PIL import Image
# Login with HF_TOKEN if available
if os.environ.get("HF_TOKEN"):
    login(token=os.environ.get("HF_TOKEN"))

# Clone the sam-3d-objects repository if it is not already present
SAM3D_PATH = Path("/home/user/app/sam-3d-objects")
if not SAM3D_PATH.exists():
    print("Cloning sam-3d-objects repository...")
    subprocess.run([
        "git", "clone",
        "https://github.com/facebookresearch/sam-3d-objects.git",
        str(SAM3D_PATH)
    ], check=True)

# Add the repository to the import path
sys.path.insert(0, str(SAM3D_PATH))
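# The repository is cloned at runtime (rather than declared as a dependency) on
# the assumption that sam-3d-objects is not distributed as a pip package; if you
# vendor or pip-install it instead, the clone step above can be removed.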
# Global models
SAM3D_MODEL = None
SAM_PREDICTOR = None
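# Both models are loaded lazily inside the GPU-decorated handlers and cached in
# the globals above, so a (ZeroGPU) GPU is only requested while a request is
# actually being processed and later calls reuse the already-loaded weights.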
def load_sam_model():
    """Load the SAM3 model for segmentation."""
    global SAM_PREDICTOR
    if SAM_PREDICTOR is not None:
        return SAM_PREDICTOR

    import torch
    from sam3 import SAM3ImagePredictor

    print("Loading SAM3 model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    SAM_PREDICTOR = SAM3ImagePredictor.from_pretrained(
        "facebook/sam3-hiera-large",
        device=device,
        token=os.environ.get("HF_TOKEN")
    )
    print("✓ SAM3 model loaded")
    return SAM_PREDICTOR
def load_sam3d_model():
    """Load the SAM 3D Objects model."""
    global SAM3D_MODEL
    if SAM3D_MODEL is not None:
        return SAM3D_MODEL

    import torch

    print("Loading SAM 3D Objects model...")
    # Download the checkpoint
    checkpoint_dir = snapshot_download(
        repo_id="facebook/sam-3d-objects",
        token=os.environ.get("HF_TOKEN")
    )

    from sam_3d_objects import Sam3dObjects

    device = "cuda" if torch.cuda.is_available() else "cpu"
    SAM3D_MODEL = Sam3dObjects.from_pretrained(checkpoint_dir, device=device)
    print("✓ SAM 3D Objects model loaded")
    return SAM3D_MODEL
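# Note: snapshot_download caches the checkpoint under the Hugging Face cache
# directory, so it is only re-downloaded when the repo changes. The token is
# passed on the assumption that facebook/sam-3d-objects is a gated repository;
# it is harmless if the repo is public.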
@spaces.GPU(duration=60)
def segment_object(image: np.ndarray, evt: gr.SelectData) -> tuple:
    """
    Segment the object at the clicked point using SAM3.

    Args:
        image: Input RGB image
        evt: Click event with coordinates

    Returns:
        tuple: (image with mask overlay, binary mask)
    """
    if image is None:
        return None, None

    try:
        predictor = load_sam_model()

        # Get click coordinates (x, y) from the Gradio select event
        point = np.array([[evt.index[0], evt.index[1]]])
        label = np.array([1])  # 1 = foreground

        # Set image
        predictor.set_image(image)

        # Predict candidate masks for the click prompt
        masks, scores, _ = predictor.predict(
            point_coords=point,
            point_labels=label,
            multimask_output=True
        )

        # Keep the highest-scoring candidate
        best_mask = masks[np.argmax(scores)]

        # Blend a green overlay onto the masked region
        overlay = image.copy()
        overlay[best_mask] = overlay[best_mask] * 0.5 + np.array([0, 255, 0]) * 0.5

        return overlay, best_mask.astype(np.uint8) * 255
    except Exception:
        import traceback
        traceback.print_exc()
        return image, None
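# Sketch (assumption): if the predictor follows the SAM-style prompt API used
# above, a refinement click could be passed as an extra point with label 0
# (background) alongside the foreground point, e.g.:
#
#   point_coords = np.array([[fg_x, fg_y], [bg_x, bg_y]])
#   point_labels = np.array([1, 0])
#   masks, scores, _ = predictor.predict(
#       point_coords=point_coords,
#       point_labels=point_labels,
#       multimask_output=True,
#   )
#
# fg_x/fg_y/bg_x/bg_y are placeholder coordinates; this app currently uses a
# single foreground click only.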
@spaces.GPU(duration=120)
def reconstruct_object(image: np.ndarray, mask: np.ndarray) -> tuple:
    """
    Reconstruct a 3D object from an image and a mask.

    Args:
        image: Input RGB image
        mask: Binary mask indicating the object region

    Returns:
        tuple: (glb_path, status)
    """
    if image is None:
        return None, "❌ No image provided"
    if mask is None:
        return None, "❌ No mask provided - click on the object first"

    try:
        import trimesh

        model = load_sam3d_model()

        # Normalize inputs to numpy arrays
        if isinstance(image, Image.Image):
            image = np.array(image)
        if isinstance(mask, Image.Image):
            mask = np.array(mask)

        # Convert the mask to a single-channel binary mask if needed
        if len(mask.shape) == 3:
            mask = mask[:, :, 0]
        mask = (mask > 127).astype(np.uint8)

        # Run inference
        outputs = model.predict(image, mask)
        if outputs is None:
            return None, "⚠️ Reconstruction failed"

        # Export as GLB via trimesh
        output_dir = tempfile.mkdtemp()
        glb_path = f"{output_dir}/object_{uuid.uuid4().hex[:8]}.glb"

        # The model returns a Gaussian splat; take its point positions and
        # export them as a point-cloud GLB (splats carry no faces to export directly)
        vertices = outputs.get_xyz().cpu().numpy()
        cloud = trimesh.PointCloud(vertices)
        cloud.export(glb_path, file_type='glb')

        return glb_path, f"✓ Object reconstructed ({len(vertices)} points)"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {e}"
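# Optional sketch: if a closed surface is preferred over a raw point cloud,
# trimesh can derive one from the same vertices, e.g. via a convex hull. This
# is an assumption about downstream needs, not part of the pipeline above:
#
#   hull = trimesh.PointCloud(vertices).convex_hull
#   hull.export(glb_path, file_type="glb")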
# Gradio Interface
with gr.Blocks(title="SAM 3D Objects MCP") as demo:
    gr.Markdown("# 📦 SAM 3D Objects MCP Server\n**Click on object → 3D Reconstruction (GLB)**")

    # State holding the current segmentation mask
    mask_state = gr.State(None)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image (click on object)", type="numpy")
            gr.Markdown("*Click on the object you want to reconstruct*")
        with gr.Column():
            preview_image = gr.Image(label="Segmentation Preview", type="numpy", interactive=False)

    with gr.Row():
        btn = gr.Button("🎯 Reconstruct 3D", variant="primary", size="lg")

    with gr.Row():
        with gr.Column():
            output_model = gr.Model3D(label="3D Object")
            output_file = gr.File(label="Download GLB")
        with gr.Column():
            status = gr.Textbox(label="Status")

    # Click on the image to segment the object under the cursor
    input_image.select(
        segment_object,
        inputs=[input_image],
        outputs=[preview_image, mask_state]
    )
    # Reconstruct, then mirror the GLB into the 3D viewer as well as the download link
    btn.click(
        reconstruct_object,
        inputs=[input_image, mask_state],
        outputs=[output_file, status]
    ).then(
        lambda glb: glb,
        inputs=output_file,
        outputs=output_model,
        show_api=False
    )
    gr.Markdown("""
---
### MCP Server
```json
{"mcpServers": {"sam3d-objects": {"command": "npx", "args": ["mcp-remote", "URL/gradio_api/mcp/sse"]}}}
```
""")
if __name__ == "__main__":
    demo.launch(mcp_server=True)
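# With mcp_server=True, Gradio exposes the handlers above as MCP tools at
# /gradio_api/mcp/sse (the endpoint referenced in the config snippet shown in
# the UI); each tool's description is taken from the function docstring.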