#!/usr/bin/env python3
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
"""Gradio demo for rgbd-depth on Hugging Face Spaces."""

import logging
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from PIL import Image

from rgbddepth import RGBDDepth

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)

# Global model cache
MODELS = {}

# Model mappings from the Hugging Face Hub (all use the vitl encoder)
# Format: "camera_model": ("repo_id", "checkpoint_filename")
HF_MODELS = {
    "d435": ("depth-anything/camera-depth-model-d435", "cdm_d435.ckpt"),
    "d405": ("depth-anything/camera-depth-model-d405", "cdm_d405.ckpt"),
    "l515": ("depth-anything/camera-depth-model-l515", "cdm_l515.ckpt"),
    "zed2i": ("depth-anything/camera-depth-model-zed2i", "cdm_zed2i.ckpt"),
}

# Default model
DEFAULT_MODEL = "d435"
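# A local checkpoint at checkpoints/<camera_model>.pt is used if present (see
# load_model below); otherwise the file listed in HF_MODELS is downloaded on first use.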


def download_model(camera_model: str = DEFAULT_MODEL):
    """Download model from HuggingFace Hub."""
    try:
        from huggingface_hub import hf_hub_download

        repo_id, filename = HF_MODELS.get(camera_model, HF_MODELS[DEFAULT_MODEL])
        logger.info(f"Downloading {camera_model} model from {repo_id}/{filename}...")

        # Download the checkpoint
        checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=".cache")
        logger.info(f"Downloaded to {checkpoint_path}")
        return checkpoint_path
    except Exception as e:
        logger.error(f"Failed to download model: {e}")
        return None


def load_model(camera_model: str = DEFAULT_MODEL, use_xformers: bool = False):
    """Load model with automatic download from HuggingFace."""
    cache_key = f"{camera_model}_{use_xformers}"
    if cache_key not in MODELS:
        # All HF models use vitl encoder
        config = {
            "encoder": "vitl",
            "features": 256,
            "out_channels": [256, 512, 1024, 1024],
            "use_xformers": use_xformers,
        }
        model = RGBDDepth(**config)

        # Try to load weights
        checkpoint_path = None

        # 1. Try local checkpoints/ directory first
        local_path = Path(f"checkpoints/{camera_model}.pt")
        if local_path.exists():
            checkpoint_path = str(local_path)
            logger.info(f"Using local checkpoint: {checkpoint_path}")
        else:
            # 2. Download from HuggingFace
            checkpoint_path = download_model(camera_model)

        # Load checkpoint if available
        if checkpoint_path:
            try:
                checkpoint = torch.load(checkpoint_path, map_location="cpu")
                if "model" in checkpoint:
                    states = {k[7:]: v for k, v in checkpoint["model"].items()}
                elif "state_dict" in checkpoint:
                    states = {k[9:]: v for k, v in checkpoint["state_dict"].items()}
                else:
                    states = checkpoint
                model.load_state_dict(states, strict=False)
                logger.info(f"Loaded checkpoint for {camera_model}")
            except Exception as e:
                logger.warning(f"Failed to load checkpoint: {e}, using random weights")
        else:
            logger.warning(
                f"No checkpoint available for {camera_model}, using random weights (demo only)"
            )

        # Move to GPU if available (CUDA or MPS for macOS)
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
        model = model.to(device).eval()

        MODELS[cache_key] = model
    return MODELS[cache_key]


def process_depth(
    rgb_image: np.ndarray,
    depth_image: np.ndarray,
    camera_model: str = DEFAULT_MODEL,
    input_size: int = 518,
    depth_scale: float = 1000.0,
    max_depth: float = 25.0,
    use_xformers: bool = False,
    precision: str = "fp32",
    colormap: str = "Spectral",
) -> tuple[Image.Image, str]:
    """Process RGB-D depth refinement.

    Args:
        rgb_image: RGB image as numpy array [H, W, 3]
        depth_image: Depth image as numpy array [H, W] or [H, W, 3]
        camera_model: Camera model to use (d435, d405, l515, zed2i)
        input_size: Input size for inference
        depth_scale: Divisor that converts raw depth values to meters (e.g. 1000 for millimeter-encoded depth)
        max_depth: Maximum valid depth in meters; larger values are treated as invalid
        use_xformers: Whether to use xFormers (CUDA only)
        precision: Precision mode (fp32/fp16/bf16)
        colormap: Matplotlib colormap for visualization

    Returns:
        Tuple of (refined depth image, info message)
    """
    try:
        # Validate inputs
        if rgb_image is None:
            return None, "❌ Please upload an RGB image"
        if depth_image is None:
            return None, "❌ Please upload a depth image"

        # Convert depth to single channel if needed
        if depth_image.ndim == 3:
            depth_image = depth_image[:, :, 0]

        # Normalize depth
        depth_normalized = depth_image.astype(np.float32) / depth_scale
        depth_normalized[depth_normalized > max_depth] = 0.0

        # Create inverse depth (similarity depth)
        simi_depth = np.zeros_like(depth_normalized)
        valid_mask = depth_normalized > 0
        simi_depth[valid_mask] = 1.0 / depth_normalized[valid_mask]
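        # Worked example (hypothetical values, default depth_scale=1000.0): a raw depth
        # reading of 1500 becomes 1.5 m after scaling, and its inverse-depth value in
        # simi_depth is 1 / 1.5 ≈ 0.667; invalid pixels (raw 0) stay at 0.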

        # Load model
        model = load_model(camera_model, use_xformers and torch.cuda.is_available())
        device = next(model.parameters()).device

        # Determine precision
        if precision == "fp16" and device.type in ["cuda", "mps"]:
            dtype = torch.float16
        elif precision == "bf16" and device.type == "cuda":
            dtype = torch.bfloat16
        else:
            dtype = None  # FP32

        # Log input statistics
        logger.debug(f"depth_image raw: min={depth_image.min():.1f}, max={depth_image.max():.1f}")
        logger.debug(
            f"depth_normalized: min={depth_normalized[depth_normalized > 0].min():.4f}, max={depth_normalized.max():.4f}"
        )
        logger.debug(
            f"simi_depth: min={simi_depth[simi_depth > 0].min():.4f}, max={simi_depth.max():.4f}"
        )

        # Run inference
        if dtype is not None:
            device_type = "cuda" if device.type == "cuda" else "cpu"
            with torch.amp.autocast(device_type=device_type, dtype=dtype):
                pred = model.infer_image(rgb_image, simi_depth, input_size=input_size)
        else:
            pred = model.infer_image(rgb_image, simi_depth, input_size=input_size)

        # Log prediction statistics
        logger.debug(f"pred (inverse depth): min={pred[pred > 0].min():.4f}, max={pred.max():.4f}")

        # Convert from inverse depth to depth
        pred = np.where(pred > 1e-8, 1.0 / pred, 0.0)

        # Log final depth statistics
        logger.debug(f"pred (depth): min={pred[pred > 0].min():.4f}, max={pred.max():.4f}")

        # Depth range, computed up front so both the colormap path, the grayscale
        # fallback, and the info message can use it
        pred_min, pred_max = pred.min(), pred.max()

        # Colorize for visualization
        try:
            import matplotlib

            # Normalize to [0, 1]
            if pred_max - pred_min > 1e-8:
                pred_norm = (pred - pred_min) / (pred_max - pred_min)
            else:
                pred_norm = np.zeros_like(pred)

            # Apply colormap
            cm_func = matplotlib.colormaps[colormap]
            pred_colored = cm_func(pred_norm, bytes=True)[:, :, :3]  # RGB only

            # Create PIL Image
            output_image = Image.fromarray(pred_colored)
        except ImportError:
            # Fallback to grayscale if matplotlib is not available
            pred_norm = ((pred - pred_min) / (pred_max - pred_min + 1e-8) * 255).astype(np.uint8)
            output_image = Image.fromarray(pred_norm, mode="L").convert("RGB")

        # Create info message
        info = f"""
✅ **Refinement complete!**

**Camera Model:** {camera_model.upper()}
**Precision:** {precision.upper()}
**Device:** {device.type.upper()}
**Input size:** {input_size}px
**Depth range:** {pred_min:.3f}m - {pred_max:.3f}m
**xFormers:** {'✅ Enabled' if use_xformers and torch.cuda.is_available() else '❌ Disabled'}
"""
        return output_image, info.strip()

    except Exception as e:
        return None, f"❌ Error: {str(e)}"


# Create Gradio interface
with gr.Blocks(title="rgbd-depth Demo") as demo:
    gr.Markdown(
        """
# rgbd-depth: RGB-D Depth Refinement

High-quality depth map refinement using Vision Transformers. Based on [ByteDance's camera-depth-models](https://manipulation-as-in-simulation.github.io/).

**Models are automatically downloaded from Hugging Face on first use!**
Choose your camera model (D435, D405, L515, or ZED 2i) and the trained weights will be downloaded automatically.
"""
    )

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Inputs")
            rgb_input = gr.Image(
                label="RGB Image",
                type="numpy",
                height=300,
            )
            depth_input = gr.Image(
                label="Input Depth Map",
                type="numpy",
                height=300,
            )

            with gr.Accordion("Advanced Settings", open=False):
                camera_choice = gr.Dropdown(
                    choices=["d435", "d405", "l515", "zed2i"],
                    value=DEFAULT_MODEL,
                    label="Camera Model",
                    info="Choose the camera model for trained weights (auto-downloads from HF)",
                )
                input_size = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    value=518,
                    step=2,
                    label="Input Size",
                    info="Resolution for processing (higher = better but slower)",
                )
                depth_scale = gr.Number(
                    value=1000.0,
                    label="Depth Scale",
                    info="Scale factor to convert depth values to meters",
                )
                max_depth = gr.Number(
                    value=25.0,
                    label="Max Depth (m)",
                    info="Maximum valid depth value",
                )
                precision_choice = gr.Radio(
                    choices=["fp32", "fp16", "bf16"],
                    value="fp32",
                    label="Precision",
                    info="fp16/bf16 = faster but slightly less accurate (fp16 needs CUDA or MPS, bf16 needs CUDA)",
                )
                use_xformers = gr.Checkbox(
                    value=False,  # Set to True to test xFormers by default
                    label="Use xFormers (CUDA only)",
                    info="~8% faster on CUDA with xFormers installed",
                )
                colormap_choice = gr.Dropdown(
                    choices=["Spectral", "viridis", "plasma", "inferno", "magma", "turbo"],
                    value="Spectral",
                    label="Colormap",
                    info="Visualization colormap",
                )

            process_btn = gr.Button("Refine Depth", variant="primary", size="lg")

        with gr.Column():
            gr.Markdown("### Output")
            output_image = gr.Image(
                label="Refined Depth Map",
                type="pil",
                height=600,
            )
            output_info = gr.Markdown()

    # Example inputs
    gr.Markdown("### Examples")
    gr.Examples(
        examples=[
            ["example_data/color_12.png", "example_data/depth_12.png"],
        ],
        inputs=[rgb_input, depth_input],
        label="Try with example images",
    )

    # Process button click
    process_btn.click(
        fn=process_depth,
        inputs=[
            rgb_input,
            depth_input,
            camera_choice,
            input_size,
            depth_scale,
            max_depth,
            use_xformers,
            precision_choice,
            colormap_choice,
        ],
        outputs=[output_image, output_info],
    )

    # Footer
    gr.Markdown(
        """
---
### Links
- **GitHub:** [Aedelon/camera-depth-models](https://github.com/Aedelon/camera-depth-models)
- **PyPI:** [rgbd-depth](https://pypi.org/project/rgbd-depth/)
- **Paper:** [Manipulation-as-in-Simulation](https://manipulation-as-in-simulation.github.io/)

### Install
```bash
pip install rgbd-depth
```

### CLI Usage
```bash
rgbd-depth \\
    --model-path model.pt \\
    --rgb-image input.jpg \\
    --depth-image depth.png \\
    --output refined.png
```
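
### Python Usage
A minimal sketch of calling the model from Python, based on the calls this app itself makes
(`RGBDDepth(...)` and `model.infer_image(...)`). Checkpoint key handling is an assumption here;
see `load_model()` in this app's source for the exact logic.
```python
import numpy as np
import torch
from PIL import Image
from rgbddepth import RGBDDepth

# Same ViT-L configuration this demo uses
model = RGBDDepth(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
ckpt = torch.load("cdm_d435.ckpt", map_location="cpu")
state = ckpt.get("model", ckpt)
state = {k.removeprefix("module."): v for k, v in state.items()}  # assumed DataParallel-style prefix
model.load_state_dict(state, strict=False)
model.eval()

rgb = np.array(Image.open("color.png"))                                  # [H, W, 3] RGB
depth_m = np.array(Image.open("depth.png")).astype(np.float32) / 1000.0  # raw depth (mm) -> meters
simi = np.zeros_like(depth_m)
simi[depth_m > 0] = 1.0 / depth_m[depth_m > 0]                           # inverse ("similarity") depth
pred = model.infer_image(rgb, simi, input_size=518)                      # predicted inverse depth
refined_m = np.where(pred > 1e-8, 1.0 / pred, 0.0)                       # refined depth in meters
```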

---
Built with ❤️ by [Aedelon](https://github.com/Aedelon) | Powered by [Gradio](https://gradio.app)
"""
    )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=True)
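    # Note: share=True only matters for local runs (Hugging Face Spaces serves the app
    # itself); locally, Gradio listens on port 7860 unless GRADIO_SERVER_PORT is set.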