"""Gradio demo that runs the Bridge depth model on an uploaded image."""

import os
import tempfile

import cv2
import gradio as gr
import matplotlib
import numpy as np
import spaces
import torch
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from PIL import Image

from bridge.dpt import Bridge

# ====== Gradio CSS styles ======
css = """
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
"""

# ====== device ======
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# ====== model load ======
model = Bridge()
filepath = hf_hub_download(
    repo_id="Dingning/BRIDGE", filename="bridge.pth", repo_type="model"
)
# NOTE(review): torch.load unpickles the downloaded checkpoint; consider
# weights_only=True once the checkpoint is confirmed to be a plain state dict.
state_dict = torch.load(filepath, map_location="cpu")
model.load_state_dict(state_dict)
model = model.to(DEVICE).eval()

# ====== description ======
title = "# Bridge Simplified Demo"
description = """
Official demo for Bridge using Gradio.
[project page](https://dingning-liu.github.io/bridge.github.io/), [github](https://github.com/lnbxldn/BRIDGE).
"""

cmap = matplotlib.colormaps.get_cmap("Spectral_r")


# ====== inference ======
@spaces.GPU
def predict_depth(image: np.ndarray) -> np.ndarray:
    """Run depth inference on an image array and return the model output.

    The channel axis is reversed before the image is handed to
    ``model.infer_image``. Gradio delivers ``type="numpy"`` images in RGB
    order, so the reversal yields BGR.
    """
    # presumably infer_image expects OpenCV-style BGR input; confirm in bridge.dpt
    return model.infer_image(image[:, :, ::-1])  # RGB -> BGR


def _normalize_to_uint8(depth: np.ndarray) -> np.ndarray:
    """Min-max scale *depth* to the 0-255 range and cast it to uint8."""
    lowest = depth.min()
    span = depth.max() - lowest
    if span > 0:
        scaled = (depth - lowest) / span * 255.0
    else:
        # A constant (or all-NaN) map has no usable range; emit zeros rather
        # than dividing by zero and casting NaN/inf to uint8.
        scaled = np.zeros_like(depth)
    return scaled.astype(np.uint8)


def _save_png(image: Image.Image) -> str:
    """Write *image* to a persistent temporary PNG file and return its path."""
    # delete=False keeps the file on disk so Gradio can serve it afterwards;
    # the context manager still closes the handle instead of leaking it.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image.save(tmp, format="PNG")
        return tmp.name


def on_submit(image: np.ndarray):
    """Compute depth for *image* and build the three UI outputs.

    Returns:
        A list of: the ``(original image, colorized depth)`` pair for the
        slider, the path of the 8-bit grayscale depth PNG, and the path of
        the 16-bit depth PNG.

    Raises:
        gr.Error: if no image has been provided.
    """
    if image is None:
        raise gr.Error("Please upload an image before computing depth.")

    original_image = image.copy()
    depth = predict_depth(image)

    # 16-bit depth map (the cast truncates the model output to integers)
    raw_depth_path = _save_png(Image.fromarray(depth.astype("uint16")))

    # normalization and colorize; drop the alpha channel returned by the colormap
    depth_uint8 = _normalize_to_uint8(depth)
    colored_depth = (cmap(depth_uint8)[:, :, :3] * 255).astype(np.uint8)

    # save grayscale depth map
    gray_depth_path = _save_png(Image.fromarray(depth_uint8))

    return [(original_image, colored_depth), gray_depth_path, raw_depth_path]


# ====== Gradio UI ======
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Depth Prediction Demo")

    with gr.Row():
        input_image = gr.Image(
            label="Input Image",
            type="numpy",
            elem_id="img-display-input"
        )
        depth_image_slider = ImageSlider(
            label="Depth Map with Slider View",
            elem_id="img-display-output",
            position=0.5
        )

    submit = gr.Button(value="Compute Depth")
    gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
    raw_file = gr.File(label="16-bit raw output", elem_id="download")

    submit.click(
        on_submit,
        inputs=[input_image],
        outputs=[depth_image_slider, gray_depth_file, raw_file]
    )

    # examples
    if os.path.exists("assets/examples"):
        example_files = [
            os.path.join("assets/examples", f)
            for f in sorted(os.listdir("assets/examples"))
        ]
        gr.Examples(
            examples=example_files,
            inputs=[input_image],
            outputs=[depth_image_slider, gray_depth_file, raw_file],
            fn=on_submit
        )

if __name__ == "__main__":
    demo.queue().launch(share=True)