Spaces:
Running
on
Zero
Running
on
Zero
import rerun as rr | |
import rerun.blueprint as rrb | |
import depth_pro | |
import subprocess | |
import torch | |
import cv2 | |
import os | |
from pathlib import Path | |
import gradio as gr | |
from gradio_rerun import Rerun | |
import spaces | |
# Fetch the pretrained checkpoint the first time the app starts.
if not os.path.exists("checkpoints/depth_pro.pt"):
    print("downloading pretrained model")
    # check=True makes a failed download raise CalledProcessError here
    # instead of silently continuing without a checkpoint.
    subprocess.run(["bash", "get_pretrained_models.sh"], check=True)

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model and preprocessing transform
print("loading model...")
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()
def run_ml_depth_pro(frame):
    """Run Depth Pro on one image and stream the result as Rerun data.

    Logs the input image, a pinhole camera built from the predicted focal
    length, and the predicted depth map, then yields whatever has been
    written to the Rerun binary stream.

    Args:
        frame: Image array indexed as ``(height, width, ...)``. It is passed
            to the preprocessing transform and logged as-is.
            NOTE(review): presumably an RGB numpy array from ``gr.Image``;
            confirm against the component's configured type.

    Yields:
        The data read from the Rerun binary stream after logging.

    Raises:
        ValueError: If ``frame`` is None, e.g. the button was clicked before
            an image was provided.
        RuntimeError: If the model or the transform has not been loaded.
    """
    # Validate the actual argument first. The original guard checked the
    # unrelated module-level ``frames`` list, so a missing image slipped
    # through to the logging calls.
    if frame is None:
        raise ValueError("No image provided: upload an image before streaming.")
    if model is None:
        raise RuntimeError("Model is None")
    if transform is None:
        raise RuntimeError("Transform is None")

    stream = rr.binary_stream()

    # A 3D view of the whole scene plus side-by-side 2D views of the depth
    # map and the input image.
    blueprint = rrb.Blueprint(
        rrb.Spatial3DView(origin="/"),
        rrb.Horizontal(
            rrb.Spatial2DView(origin="/world/camera/depth"),
            rrb.Spatial2DView(origin="/world/camera/image"),
        ),
        collapse_panels=True,
    )
    rr.send_blueprint(blueprint)

    # Only a single frame is processed, so it is always logged at index 0.
    rr.set_time_sequence("frame", 0)
    rr.log("world/camera/image", rr.Image(frame))

    image = transform(frame)
    prediction = model.infer(image)
    depth = prediction["depth"].squeeze().detach().cpu().numpy()

    height, width = frame.shape[0], frame.shape[1]
    rr.log(
        "world/camera",
        rr.Pinhole(
            width=width,
            height=height,
            focal_length=prediction["focallength_px"].item(),
            # The principal point is taken to be the image centre.
            principal_point=(width / 2, height / 2),
            # Draw the image plane at the largest predicted depth value.
            image_plane_distance=depth.max(),
        ),
    )

    rr.log(
        "world/camera/depth",
        # TODO: needs rerun 0.19 stable before passing
        # depth_range=(depth.min(), depth.max()) to rr.DepthImage.
        rr.DepthImage(depth, meter=1),
    )

    yield stream.read()
video_path = Path("hd-cat.mp4")

# NOTE(review): `video_path` is not used below; the capture opens
# "hd-cat2.mp4" instead. Confirm which file is intended.

# Load video: decode every frame, downscale to 320x240 and convert from
# OpenCV's BGR channel order to RGB.
frames = []
video = cv2.VideoCapture("hd-cat2.mp4")
try:
    while True:
        read, frame = video.read()
        if not read:
            # No more frames (or the file could not be read).
            break
        frame = cv2.resize(frame, (320, 240))
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame)
finally:
    # Release the capture handle even if decoding raises.
    video.release()
# Gradio UI: an image input with a trigger button, and a Rerun viewer that
# receives the streamed recording.
# NOTE(review): the layout nesting below is reconstructed; confirm that the
# Column belongs inside the first Row.
with gr.Blocks() as demo, gr.Tab("Streaming"):
    with gr.Row():
        img = gr.Image(label="Image", interactive=True)
        with gr.Column():
            stream_ml_depth_pro = gr.Button(value="Stream Ml Depth Pro")

    with gr.Row():
        # Collapse the time panel and hide the blueprint and selection panels.
        viewer_panel_states = {
            "time": "collapsed",
            "blueprint": "hidden",
            "selection": "hidden",
        }
        viewer = Rerun(streaming=True, panel_states=viewer_panel_states)

    # Clicking the button runs depth estimation on the image and streams the
    # output into the viewer.
    stream_ml_depth_pro.click(
        fn=run_ml_depth_pro,
        inputs=[img],
        outputs=[viewer],
    )

if __name__ == "__main__":
    demo.launch()