"""Gradio demo that streams Depth Pro depth predictions into a Rerun viewer."""

import os
import subprocess
from pathlib import Path

import rerun as rr
import rerun.blueprint as rrb
import depth_pro
import torch
import cv2
import gradio as gr
from gradio_rerun import Rerun
import spaces

CHECKPOINT_PATH = "checkpoints/depth_pro.pt"

# Run the script to get pretrained models
if not os.path.exists(CHECKPOINT_PATH):
    print("downloading pretrained model")
    # check=True: stop here if the download script fails instead of
    # continuing towards model loading without a checkpoint.
    subprocess.run(["bash", "get_pretrained_models.sh"], check=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model and preprocessing transform
print("loading model...")
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()


@rr.thread_local_stream("rerun_example_ml_depth_pro")
@spaces.GPU(duration=20)
def run_ml_depth_pro(frame):
    """Run Depth Pro on a single image and stream the result as Rerun data.

    Logs the input image, a pinhole camera built from the predicted focal
    length, and the predicted depth map, then yields what was written to the
    Rerun binary stream so the Gradio ``Rerun`` component can display it.

    Args:
        frame: Image coming from the ``gr.Image`` input. It is indexed as
            ``frame.shape[0]`` (height) and ``frame.shape[1]`` (width).
            ``None`` when the user has not provided an image.

    Yields:
        The data read from the Rerun binary stream after logging.

    Raises:
        gr.Error: If no image was provided.
    """
    if frame is None:
        # Surface a readable message in the UI instead of failing deep inside
        # the logging / preprocessing calls.
        raise gr.Error("Please provide an image before streaming.")

    stream = rr.binary_stream()

    blueprint = rrb.Blueprint(
        rrb.Spatial3DView(origin="/"),
        rrb.Horizontal(
            rrb.Spatial2DView(origin="/world/camera/depth"),
            rrb.Spatial2DView(origin="/world/camera/image"),
        ),
        collapse_panels=True,
    )
    rr.send_blueprint(blueprint)

    # Only one image is processed per call, so everything is logged at
    # sequence index 0.
    rr.set_time_sequence("frame", 0)
    rr.log("world/camera/image", rr.Image(frame))

    # The model lives on `device`; `.to(device)` is a no-op when the transform
    # already produced a tensor there.
    # NOTE(review): create_model_and_transforms() is called without a device
    # argument, so the transform presumably yields a CPU tensor - confirm.
    image = transform(frame).to(device)
    prediction = model.infer(image)
    depth = prediction["depth"].squeeze().detach().cpu().numpy()

    height, width = frame.shape[:2]
    rr.log(
        "world/camera",
        rr.Pinhole(
            width=width,
            height=height,
            focal_length=prediction["focallength_px"].item(),
            # The principal point is placed at the image centre.
            principal_point=(width / 2, height / 2),
            # Draw the image plane at the largest predicted depth value.
            image_plane_distance=depth.max(),
        ),
    )

    # TODO: once rerun 0.19 stable is available, also pass
    # depth_range=(depth.min(), depth.max()) to rr.DepthImage.
    rr.log("world/camera/depth", rr.DepthImage(depth, meter=1))

    yield stream.read()


def _load_video_frames(path, size=(320, 240)):
    """Read all frames of a video, resized to ``size`` and converted to RGB.

    Args:
        path: Path of the video file to open.
        size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        A list with one RGB array per frame, in reading order. The list is
        empty when no frame could be read.
    """
    capture = cv2.VideoCapture(str(path))
    loaded = []
    try:
        while True:
            ok, bgr = capture.read()
            if not ok:
                break
            resized = cv2.resize(bgr, size)
            loaded.append(cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))
    finally:
        # Always free the capture handle, even if resizing/conversion fails.
        capture.release()
    return loaded


# NOTE(review): this path used to be declared as "hd-cat.mp4" while the
# capture actually opened "hd-cat2.mp4"; the file that was really read is
# kept here - confirm which one is intended.
video_path = Path("hd-cat2.mp4")

# Load video
frames = _load_video_frames(video_path)

with gr.Blocks() as demo:
    with gr.Tab("Streaming"):
        with gr.Row():
            img = gr.Image(interactive=True, label="Image")
            with gr.Column():
                stream_ml_depth_pro = gr.Button("Stream Ml Depth Pro")
        with gr.Row():
            viewer = Rerun(
                streaming=True,
                panel_states={
                    "time": "collapsed",
                    "blueprint": "hidden",
                    "selection": "hidden",
                },
            )
        # Clicking the button runs inference on the uploaded image and streams
        # the logged data into the viewer.
        stream_ml_depth_pro.click(run_ml_depth_pro, inputs=[img], outputs=[viewer])

if __name__ == "__main__":
    demo.launch()