import os
import subprocess
from pathlib import Path

import cv2
import depth_pro
import gradio as gr
import rerun as rr
import rerun.blueprint as rrb
import spaces  # imported before torch so ZeroGPU can patch CUDA initialization
import torch
from gradio_rerun import Rerun

# Run the script to get pretrained models
if not os.path.exists("checkpoints/depth_pro.pt"):
    print("downloading pretrained model")
    subprocess.run(["bash", "get_pretrained_models.sh"])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load model and preprocessing transform
print("loading model...")
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()
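
# `rr.thread_local_stream` gives each call its own Rerun recording stream, and
# `spaces.GPU` requests a ZeroGPU slot for at most 20 seconds per call.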
@rr.thread_local_stream("rerun_example_ml_depth_pro")
@spaces.GPU(duration=20)
def run_ml_depth_pro(frame):
    stream = rr.binary_stream()

    assert model is not None, "Model is None"
    assert transform is not None, "Transform is None"
    assert frame is not None, "Frame is None"
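
    # Lay out the viewer: a 3D view of the reconstructed scene, plus the
    # depth map and the input image side by side.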
    blueprint = rrb.Blueprint(
        rrb.Spatial3DView(origin="/"),
        rrb.Horizontal(
            rrb.Spatial2DView(
                origin="/world/camera/depth",
            ),
            rrb.Spatial2DView(origin="/world/camera/image"),
        ),
        collapse_panels=True,
    )
    rr.send_blueprint(blueprint)
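
    # Log a single input frame at sequence 0 and run Depth Pro on it.
    # (A loop over the preloaded video frames was disabled here.)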
    rr.set_time_sequence("frame", 0)
    rr.log("world/camera/image", rr.Image(frame))

    image = transform(frame)
    prediction = model.infer(image)
    depth = prediction["depth"].squeeze().detach().cpu().numpy()
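
    # Log a pinhole camera whose focal length comes from the model's own
    # estimate; Rerun uses it to back-project the depth image into the 3D view.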
    rr.log(
        "world/camera",
        rr.Pinhole(
            width=frame.shape[1],
            height=frame.shape[0],
            focal_length=prediction["focallength_px"].item(),
            principal_point=(frame.shape[1] / 2, frame.shape[0] / 2),
            image_plane_distance=depth.max(),
        ),
    )
    rr.log(
        "world/camera/depth",
        # needs rerun 0.19 stable:
        # rr.DepthImage(depth, meter=1, depth_range=(depth.min(), depth.max())),
        rr.DepthImage(depth, meter=1),
    )

    yield stream.read()
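
# Preload and downscale the demo video. The frames feed the disabled
# video-streaming path above; the current UI streams a single image instead.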
video_path = Path("hd-cat2.mp4")
frames = []
video = cv2.VideoCapture(str(video_path))
while True:
    read, frame = video.read()
    if not read:
        break
    frame = cv2.resize(frame, (320, 240))
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frames.append(frame)
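
# Gradio UI: an image input, a button that triggers inference, and an embedded
# Rerun viewer that renders the binary stream yielded by run_ml_depth_pro.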
with gr.Blocks() as demo:
    with gr.Tab("Streaming"):
        with gr.Row():
            img = gr.Image(interactive=True, label="Image")
            with gr.Column():
                stream_ml_depth_pro = gr.Button("Stream ML Depth Pro")
        with gr.Row():
            viewer = Rerun(
                streaming=True,
                panel_states={
                    "time": "collapsed",
                    "blueprint": "hidden",
                    "selection": "hidden",
                },
            )
        stream_ml_depth_pro.click(run_ml_depth_pro, inputs=[img], outputs=[viewer])

if __name__ == "__main__":
    demo.launch()