"""SHARP inference + optional CUDA video rendering utilities.
Design goals:
- Reuse SHARP's own predict/render pipeline (no subprocess calls).
- Be robust on Hugging Face Spaces + ZeroGPU.
- Cache model weights and predictor construction across requests.
Public API (used by the Gradio app):
- TrajectoryType
- predict_and_maybe_render_gpu(...)

Other module-level entrypoints:
- configure_gpu_mode(...): switch between local CUDA and Spaces `@spaces.GPU`.
- get_gpu_status(): report CUDA/Spaces availability and visible GPUs.
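
Configuration (environment variables read by this module):
- SHARP_CHECKPOINT_PATH / SHARP_CHECKPOINT: local checkpoint file to load.
- SHARP_HF_REPO_ID / SHARP_HF_FILENAME / SHARP_HF_REVISION: Hugging Face Hub
  fallback for the checkpoint.
- SHARP_KEEP_MODEL_ON_DEVICE: set to "0" to move the model back to CPU after
  each request.
- CUDA_VISIBLE_DEVICES: standard PyTorch/CUDA GPU selection.

Example (illustrative sketch; the image path and parameter values are
placeholders, assuming this module is importable as `model_utils`):

    import model_utils

    model_utils.configure_gpu_mode(use_spaces=False)
    video_path, ply_path = model_utils.predict_and_maybe_render_gpu(
        "input.jpg",
        trajectory_type="rotate_forward",
        num_frames=60,
        fps=30,
        output_long_side=None,
    )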
"""
from __future__ import annotations
import os
import threading
import time
import uuid
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Final, Literal
import torch
# Optional Spaces GPU support (for Hugging Face Spaces deployment)
try:
import spaces
_SPACES_AVAILABLE = True
except ImportError:
spaces = None # type: ignore[assignment]
_SPACES_AVAILABLE = False
try:
# Prefer HF cache / Hub downloads (works with Spaces `preload_from_hub`).
from huggingface_hub import hf_hub_download, try_to_load_from_cache
except Exception: # pragma: no cover
hf_hub_download = None # type: ignore[assignment]
try_to_load_from_cache = None # type: ignore[assignment]
from sharp.cli.predict import DEFAULT_MODEL_URL, predict_image
from sharp.cli.render import render_gaussians as sharp_render_gaussians
from sharp.models import PredictorParams, create_predictor
from sharp.utils import camera, io
from sharp.utils.gaussians import Gaussians3D, SceneMetaData, save_ply
from sharp.utils.gsplat import GSplatRenderer
TrajectoryType = Literal["swipe", "shake", "rotate", "rotate_forward"]
# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def _now_ms() -> int:
return int(time.time() * 1000)
def _ensure_dir(path: Path) -> Path:
path.mkdir(parents=True, exist_ok=True)
return path
def _make_even(x: int) -> int:
return x if x % 2 == 0 else x + 1
def _select_device(preference: str = "auto") -> torch.device:
"""Select the best available device for inference (CPU/CUDA/MPS)."""
if preference not in {"auto", "cpu", "cuda", "mps"}:
raise ValueError("device preference must be one of: auto|cpu|cuda|mps")
if preference == "cpu":
return torch.device("cpu")
if preference == "cuda":
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
if preference == "mps":
return torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# auto
if torch.cuda.is_available():
return torch.device("cuda")
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
# -----------------------------------------------------------------------------
# Prediction outputs
# -----------------------------------------------------------------------------
@dataclass(frozen=True, slots=True)
class PredictionOutputs:
"""Outputs of SHARP inference (plus derived metadata for rendering)."""
ply_path: Path
gaussians: Gaussians3D
metadata_for_render: SceneMetaData
input_resolution_hw: tuple[int, int]
focal_length_px: float
# -----------------------------------------------------------------------------
# Patch SHARP VideoWriter to properly close the optional depth writer
# -----------------------------------------------------------------------------
class _PatchedVideoWriter(io.VideoWriter):
"""Ensure depth writer is closed so files can be safely cleaned up."""
def __init__(
self, output_path: Path, fps: float = 30.0, render_depth: bool = True
) -> None:
super().__init__(output_path, fps=fps, render_depth=render_depth)
# Ensure attribute exists for downstream code paths.
if not hasattr(self, "depth_writer"):
self.depth_writer = None # type: ignore[attribute-defined-outside-init]
    def close(self) -> None:
super().close()
depth_writer = getattr(self, "depth_writer", None)
try:
if depth_writer is not None:
depth_writer.close()
except Exception:
pass
@contextmanager
def _patched_sharp_videowriter():
"""Temporarily patch `sharp.utils.io.VideoWriter` used by `sharp.cli.render`."""
original = io.VideoWriter
io.VideoWriter = _PatchedVideoWriter # type: ignore[assignment]
try:
yield
finally:
io.VideoWriter = original # type: ignore[assignment]
# -----------------------------------------------------------------------------
# Model wrapper
# -----------------------------------------------------------------------------
class ModelWrapper:
"""Cached SHARP model wrapper for Gradio/Spaces."""
def __init__(
self,
*,
outputs_dir: str | Path = "outputs",
checkpoint_url: str = DEFAULT_MODEL_URL,
checkpoint_path: str | Path | None = None,
device_preference: str = "auto",
keep_model_on_device: bool | None = None,
hf_repo_id: str | None = None,
hf_filename: str | None = None,
hf_revision: str | None = None,
) -> None:
self.outputs_dir = _ensure_dir(Path(outputs_dir))
self.checkpoint_url = checkpoint_url
env_ckpt = os.getenv("SHARP_CHECKPOINT_PATH") or os.getenv("SHARP_CHECKPOINT")
if checkpoint_path:
self.checkpoint_path = Path(checkpoint_path)
elif env_ckpt:
self.checkpoint_path = Path(env_ckpt)
else:
self.checkpoint_path = None
# Optional Hugging Face Hub fallback (useful when direct CDN download fails).
self.hf_repo_id = hf_repo_id or os.getenv("SHARP_HF_REPO_ID", "apple/Sharp")
self.hf_filename = hf_filename or os.getenv(
"SHARP_HF_FILENAME", "sharp_2572gikvuh.pt"
)
self.hf_revision = hf_revision or os.getenv("SHARP_HF_REVISION") or None
self.device_preference = device_preference
# Local CUDA: keep model on device by default for better performance
if keep_model_on_device is None:
keep_env = os.getenv("SHARP_KEEP_MODEL_ON_DEVICE", "1")
self.keep_model_on_device = keep_env != "0"
else:
self.keep_model_on_device = keep_model_on_device
        # GPU selection is delegated to PyTorch via CUDA_VISIBLE_DEVICES; no
        # explicit device mapping is needed here.
self._lock = threading.RLock()
self._predictor: torch.nn.Module | None = None
self._predictor_device: torch.device | None = None
self._state_dict: dict | None = None
def has_cuda(self) -> bool:
return torch.cuda.is_available()
def _load_state_dict(self) -> dict:
with self._lock:
if self._state_dict is not None:
return self._state_dict
# 1) Explicit local checkpoint path
if self.checkpoint_path is not None:
try:
self._state_dict = torch.load(
self.checkpoint_path,
weights_only=True,
map_location="cpu",
)
return self._state_dict
except Exception as e:
raise RuntimeError(
"Failed to load SHARP checkpoint from local path.\n\n"
f"Path:\n {self.checkpoint_path}\n\n"
f"Original error:\n {type(e).__name__}: {e}"
) from e
# 2) HF cache (no-network): best match for Spaces `preload_from_hub`.
hf_cache_error: Exception | None = None
if try_to_load_from_cache is not None:
try:
cached = try_to_load_from_cache(
repo_id=self.hf_repo_id,
filename=self.hf_filename,
revision=self.hf_revision,
repo_type="model",
)
except TypeError:
cached = try_to_load_from_cache(self.hf_repo_id, self.hf_filename) # type: ignore[misc]
try:
if isinstance(cached, str) and Path(cached).exists():
self._state_dict = torch.load(
cached, weights_only=True, map_location="cpu"
)
return self._state_dict
except Exception as e:
hf_cache_error = e
# 3) HF Hub download (reuse cache when available; may download otherwise).
hf_error: Exception | None = None
if hf_hub_download is not None:
# Attempt "local only" mode if supported (avoids network).
try:
import inspect
if "local_files_only" in inspect.signature(hf_hub_download).parameters:
ckpt_path = hf_hub_download(
repo_id=self.hf_repo_id,
filename=self.hf_filename,
revision=self.hf_revision,
local_files_only=True,
)
if Path(ckpt_path).exists():
self._state_dict = torch.load(
ckpt_path, weights_only=True, map_location="cpu"
)
return self._state_dict
except Exception:
pass
try:
ckpt_path = hf_hub_download(
repo_id=self.hf_repo_id,
filename=self.hf_filename,
revision=self.hf_revision,
)
self._state_dict = torch.load(
ckpt_path,
weights_only=True,
map_location="cpu",
)
return self._state_dict
except Exception as e:
hf_error = e
# 4) Default upstream CDN (torch hub cache). Last resort.
url_error: Exception | None = None
try:
self._state_dict = torch.hub.load_state_dict_from_url(
self.checkpoint_url,
progress=True,
map_location="cpu",
)
return self._state_dict
except Exception as e:
url_error = e
# If we got here: all options failed.
hint_lines = [
"Failed to load SHARP checkpoint.",
"",
"Tried (in order):",
f" 1) HF cache (preload_from_hub): repo_id={self.hf_repo_id}, filename={self.hf_filename}, revision={self.hf_revision or 'None'}",
f" 2) HF Hub download: repo_id={self.hf_repo_id}, filename={self.hf_filename}, revision={self.hf_revision or 'None'}",
f" 3) URL (torch hub): {self.checkpoint_url}",
"",
"If network access is restricted, set a local checkpoint path:",
" - SHARP_CHECKPOINT_PATH=/path/to/sharp_2572gikvuh.pt",
"",
"Original errors:",
]
if try_to_load_from_cache is None:
hint_lines.append(" HF cache: huggingface_hub not installed")
elif hf_cache_error is not None:
hint_lines.append(
f" HF cache: {type(hf_cache_error).__name__}: {hf_cache_error}"
)
else:
hint_lines.append(" HF cache: (not found in cache)")
if hf_hub_download is None:
hint_lines.append(" HF download: huggingface_hub not installed")
else:
hint_lines.append(f" HF download: {type(hf_error).__name__}: {hf_error}")
hint_lines.append(f" URL: {type(url_error).__name__}: {url_error}")
raise RuntimeError("\n".join(hint_lines))
def _get_predictor(self, device: torch.device) -> torch.nn.Module:
with self._lock:
if self._predictor is None:
state_dict = self._load_state_dict()
predictor = create_predictor(PredictorParams())
predictor.load_state_dict(state_dict)
predictor.eval()
self._predictor = predictor
self._predictor_device = torch.device("cpu")
assert self._predictor is not None
assert self._predictor_device is not None
if self._predictor_device != device:
self._predictor.to(device)
self._predictor_device = device
return self._predictor
def _maybe_move_model_back_to_cpu(self) -> None:
if self.keep_model_on_device:
return
with self._lock:
if self._predictor is not None and self._predictor_device is not None:
if self._predictor_device.type != "cpu":
self._predictor.to("cpu")
self._predictor_device = torch.device("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
def _make_output_stem(self, input_path: Path) -> str:
return f"{input_path.stem}-{_now_ms()}-{uuid.uuid4().hex[:8]}"
def predict_to_ply(self, image_path: str | Path) -> PredictionOutputs:
"""Run SHARP inference and export a .ply file."""
image_path = Path(image_path)
if not image_path.exists():
raise FileNotFoundError(f"Image does not exist: {image_path}")
device = _select_device(self.device_preference)
predictor = self._get_predictor(device)
image_np, _, f_px = io.load_rgb(image_path)
height, width = image_np.shape[:2]
with torch.no_grad():
gaussians = predict_image(predictor, image_np, f_px, device)
stem = self._make_output_stem(image_path)
ply_path = self.outputs_dir / f"{stem}.ply"
# save_ply expects (height, width).
save_ply(gaussians, f_px, (height, width), ply_path)
# SceneMetaData expects (width, height) for resolution.
metadata_for_render = SceneMetaData(
focal_length_px=float(f_px),
resolution_px=(int(width), int(height)),
color_space="linearRGB",
)
self._maybe_move_model_back_to_cpu()
return PredictionOutputs(
ply_path=ply_path,
gaussians=gaussians,
metadata_for_render=metadata_for_render,
input_resolution_hw=(int(height), int(width)),
focal_length_px=float(f_px),
)
def _render_video_impl(
self,
*,
gaussians: Gaussians3D,
metadata: SceneMetaData,
output_path: Path,
trajectory_type: TrajectoryType,
num_frames: int,
fps: int,
output_long_side: int | None,
) -> Path:
if not torch.cuda.is_available():
raise RuntimeError("Rendering requires CUDA (gsplat).")
if num_frames < 2:
raise ValueError("num_frames must be >= 2")
if fps < 1:
raise ValueError("fps must be >= 1")
        # When the caller wants the native resolution at 30 fps, reuse the
        # upstream CLI render pipeline directly (keeps behavior aligned);
        # otherwise fall back to the adapted loop below.
        if output_long_side is None and int(fps) == 30:
params = camera.TrajectoryParams(
type=trajectory_type,
num_steps=int(num_frames),
num_repeats=1,
)
with _patched_sharp_videowriter():
sharp_render_gaussians(
gaussians=gaussians,
metadata=metadata,
params=params,
output_path=output_path,
)
depth_path = output_path.with_suffix(".depth.mp4")
try:
if depth_path.exists():
depth_path.unlink()
except Exception:
pass
return output_path
# Adapted pipeline for custom output resolution / FPS.
src_w, src_h = metadata.resolution_px
src_f = float(metadata.focal_length_px)
if output_long_side is None:
out_w, out_h, out_f = src_w, src_h, src_f
else:
long_side = max(src_w, src_h)
scale = float(output_long_side) / float(long_side)
out_w = _make_even(max(2, int(round(src_w * scale))))
out_h = _make_even(max(2, int(round(src_h * scale))))
out_f = src_f * scale
traj_params = camera.TrajectoryParams(
type=trajectory_type,
num_steps=int(num_frames),
num_repeats=1,
)
device = torch.device("cuda")
gaussians_cuda = gaussians.to(device)
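        # Pinhole intrinsics for the output resolution: the focal length is
        # scaled with the resize and the principal point sits at the image
        # center (pixel-center convention, hence the (size - 1) / 2 offsets).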
intrinsics = torch.tensor(
[
[out_f, 0.0, (out_w - 1) / 2.0, 0.0],
[0.0, out_f, (out_h - 1) / 2.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
],
device=device,
dtype=torch.float32,
)
cam_model = camera.create_camera_model(
gaussians_cuda,
intrinsics,
resolution_px=(out_w, out_h),
lookat_mode=traj_params.lookat_mode,
)
trajectory = camera.create_eye_trajectory(
gaussians_cuda,
traj_params,
resolution_px=(out_w, out_h),
f_px=out_f,
)
renderer = GSplatRenderer(color_space=metadata.color_space)
# IMPORTANT: Keep render_depth=True (avoids upstream AttributeError).
video_writer = _PatchedVideoWriter(output_path, fps=float(fps), render_depth=True)
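        # Render one frame per trajectory eye position and stream it straight
        # into the MP4 writer; the .depth.mp4 sidecar written alongside is
        # removed after encoding (see below).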
for eye_position in trajectory:
cam_info = cam_model.compute(eye_position)
rendering = renderer(
gaussians_cuda,
extrinsics=cam_info.extrinsics[None].to(device),
intrinsics=cam_info.intrinsics[None].to(device),
image_width=cam_info.width,
image_height=cam_info.height,
)
color = (rendering.color[0].permute(1, 2, 0) * 255.0).to(dtype=torch.uint8)
depth = rendering.depth[0]
video_writer.add_frame(color, depth)
video_writer.close()
depth_path = output_path.with_suffix(".depth.mp4")
try:
if depth_path.exists():
depth_path.unlink()
except Exception:
pass
return output_path
def render_video(
self,
*,
gaussians: Gaussians3D,
metadata: SceneMetaData,
output_stem: str,
trajectory_type: TrajectoryType = "rotate_forward",
num_frames: int = 60,
fps: int = 30,
output_long_side: int | None = None,
) -> Path:
"""Render a camera trajectory as an MP4 (CUDA-only)."""
output_path = self.outputs_dir / f"{output_stem}.mp4"
return self._render_video_impl(
gaussians=gaussians,
metadata=metadata,
output_path=output_path,
trajectory_type=trajectory_type,
num_frames=num_frames,
fps=fps,
output_long_side=output_long_side,
)
def predict_and_maybe_render(
self,
image_path: str | Path,
*,
trajectory_type: TrajectoryType,
num_frames: int,
fps: int,
output_long_side: int | None,
render_video: bool = True,
) -> tuple[Path | None, Path]:
"""One-shot helper for the UI: returns (video_path, ply_path)."""
pred = self.predict_to_ply(image_path)
if not render_video:
return None, pred.ply_path
if not torch.cuda.is_available():
return None, pred.ply_path
output_stem = pred.ply_path.with_suffix("").name
video_path = self.render_video(
gaussians=pred.gaussians,
metadata=pred.metadata_for_render,
output_stem=output_stem,
trajectory_type=trajectory_type,
num_frames=num_frames,
fps=fps,
output_long_side=output_long_side,
)
return video_path, pred.ply_path
# -----------------------------------------------------------------------------
# Module-level entrypoints
# -----------------------------------------------------------------------------
DEFAULT_OUTPUTS_DIR: Final[Path] = _ensure_dir(Path(__file__).resolve().parent / "outputs")
_GLOBAL_MODEL: ModelWrapper | None = None
_GLOBAL_MODEL_INIT_LOCK: Final[threading.Lock] = threading.Lock()
def get_global_model(*, outputs_dir: str | Path = DEFAULT_OUTPUTS_DIR) -> ModelWrapper:
global _GLOBAL_MODEL
with _GLOBAL_MODEL_INIT_LOCK:
if _GLOBAL_MODEL is None:
_GLOBAL_MODEL = ModelWrapper(outputs_dir=outputs_dir)
return _GLOBAL_MODEL
def predict_and_maybe_render(
image_path: str | Path,
*,
trajectory_type: TrajectoryType,
num_frames: int,
fps: int,
output_long_side: int | None,
render_video: bool = True,
) -> tuple[Path | None, Path]:
model = get_global_model()
return model.predict_and_maybe_render(
image_path,
trajectory_type=trajectory_type,
num_frames=num_frames,
fps=fps,
output_long_side=output_long_side,
render_video=render_video,
)
# -----------------------------------------------------------------------------
# GPU-wrapped entrypoint (Spaces or local)
# -----------------------------------------------------------------------------
def _create_spaces_gpu_wrapper(duration: int = 180):
"""Create a Spaces GPU-wrapped version of predict_and_maybe_render.
This is called dynamically based on hardware configuration.
"""
if spaces is not None and _SPACES_AVAILABLE:
return spaces.GPU(duration=duration)(predict_and_maybe_render)
return predict_and_maybe_render
# Default export: call predict_and_maybe_render directly (local CUDA or CPU).
# Spaces deployments can rewrap it at runtime via configure_gpu_mode().
predict_and_maybe_render_gpu = predict_and_maybe_render
def configure_gpu_mode(use_spaces: bool = False, duration: int = 180) -> None:
"""Configure the GPU mode at runtime.
Args:
        use_spaces: If True and the `spaces` package is available, wrap the
            entrypoint with the @spaces.GPU decorator.
        duration: GPU allocation duration in seconds passed to @spaces.GPU.
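
    Example (illustrative; the duration value is arbitrary):
        configure_gpu_mode(use_spaces=True, duration=120)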
"""
global predict_and_maybe_render_gpu
if use_spaces and _SPACES_AVAILABLE and spaces is not None:
predict_and_maybe_render_gpu = spaces.GPU(duration=duration)(predict_and_maybe_render)
else:
predict_and_maybe_render_gpu = predict_and_maybe_render
def get_gpu_status() -> dict:
"""Get current GPU status information."""
import torch
status = {
"cuda_available": torch.cuda.is_available(),
"spaces_available": _SPACES_AVAILABLE,
"device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
"devices": [],
}
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
props = torch.cuda.get_device_properties(i)
status["devices"].append({
"index": i,
"name": props.name,
"total_memory_gb": round(props.total_memory / (1024**3), 2),
"compute_capability": f"{props.major}.{props.minor}",
})
return status
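
# -----------------------------------------------------------------------------
# Manual smoke test (illustrative sketch; not used by the Gradio app)
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Print CUDA / Spaces availability without loading the model, so a
    # deployment can be sanity-checked before running inference.
    import json

    print(json.dumps(get_gpu_status(), indent=2))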