# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
"""Gradio demo app for RollingDepth video depth estimation."""

import functools
import itertools
import os
import subprocess
import sys
import tempfile

# NOTE: third-party import order is kept as-is on purpose (`spaces` before
# `torch`); do not re-sort without checking the Spaces runtime requirements.
import av
import numpy as np
import spaces
import gradio as gr
import torch
import einops
from huggingface_hub import login

from gradio_patches.examples import Examples
from colorize import colorize_depth_multi_thread
from video_io import get_video_fps, write_video_from_numpy

# Forwarded as `verbose=` to the pipeline, colorizer and video writer.
VERBOSE = False
# Demo cap: at most this many leading frames of an upload are processed.
MAX_FRAMES = 100


def _probe_video(path_input):
    """Return ``(fps, total_frames)`` for the first video stream of a file.

    ``total_frames`` comes from the container's frame-count metadata when it
    is present. Otherwise it is estimated from the stream duration. As a last
    resort it is counted by decoding at most ``MAX_FRAMES + 1`` frames, which
    is enough to decide whether the demo cap applies. ``fps`` is ``0.0`` when
    the stream reports no average frame rate.

    Raises:
        gr.Error: if the file contains no video stream.
    """
    with av.open(path_input) as container:
        if not container.streams.video:
            raise gr.Error("The uploaded file contains no video stream.")
        stream = container.streams.video[0]
        fps = float(stream.average_rate) if stream.average_rate else 0.0

        total_frames = stream.frames or 0  # PyAV reports 0 when unknown
        if total_frames <= 0 and stream.duration and fps > 0:
            total_frames = int(float(stream.duration * stream.time_base) * fps)
        if total_frames <= 0:
            # No usable metadata: count decoded frames, stopping as soon as we
            # know the clip exceeds the cap (frame_count=0 would mean "all").
            total_frames = sum(
                1 for _ in itertools.islice(container.decode(stream), MAX_FRAMES + 1)
            )
    return fps, total_frames


def process(pipe, device, path_input):
    """Run RollingDepth on one video and return paths to two preview videos.

    Args:
        pipe: loaded RollingDepth pipeline; called with ``input_video_path=...``
            and expected to return an object with ``depth_pred`` and
            ``input_rgb`` tensors.
        device: torch device on which the seeded generator is created.
        path_input: path to the uploaded video, or ``None`` if nothing was
            uploaded.

    Returns:
        ``(path_out_in, path_out_vis)``: an MP4 of the RGB frames returned by
        the pipeline and an MP4 of the colorized depth prediction. Both are
        written to a fresh temporary directory.

    Raises:
        gr.Error: if no video was provided or no frames could be read.
    """
    if path_input is None:
        raise gr.Error("Please upload a video first.")
    print(f"Processing {path_input}")

    # mkdtemp() already creates the directory.
    path_output_dir = tempfile.mkdtemp()
    name_base = os.path.splitext(os.path.basename(path_input))[0]
    path_out_in = os.path.join(path_output_dir, f"{name_base}_depth_input.mp4")
    path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.mp4")

    # Round rather than truncate (29.97 -> 30, not 29) so playback speed is
    # preserved; never hand the writer a zero frame rate.
    output_fps = max(1, round(get_video_fps(path_input)))

    stream_fps, total_frames = _probe_video(path_input)
    # Used only for the duration hint in the warning below.
    fps = stream_fps if stream_fps > 0 else float(output_fps)
    if total_frames <= 0:
        raise gr.Error("Could not read any frames from the input video.")
    if total_frames > MAX_FRAMES:
        gr.Warning(
            f"Only the first {MAX_FRAMES} frames (~{MAX_FRAMES / fps:.1f} sec.) will be processed for demonstration; "
            f"use the code from GitHub for full processing"
        )

    # Fixed seed so repeated runs on the same input are reproducible.
    generator = torch.Generator(device=device)
    generator.manual_seed(2024)

    pipe_out = pipe(  # RollingDepthOutput
        # input setting
        input_video_path=path_input,
        start_frame=0,
        # total_frames > 0 here, so this never becomes the "0 = all" sentinel.
        frame_count=min(MAX_FRAMES, total_frames),  # 0 = all
        processing_res=768,
        # infer setting
        dilations=[1, 25],
        cap_dilation=True,
        snippet_lengths=[3],
        init_infer_steps=[1],
        strides=[1],
        coalign_kwargs=None,
        refine_step=0,  # 0 = off
        max_vae_bs=8,  # batch size for encoder/decoder
        # other settings
        generator=generator,
        verbose=VERBOSE,
        # output settings
        restore_res=False,
        unload_snippet=False,
    )

    depth_pred = pipe_out.depth_pred  # [N 1 H W]

    # Colorize results; .cpu() is a no-op for tensors already on the CPU.
    cmap = "Spectral_r"
    colored_np = colorize_depth_multi_thread(
        depth=depth_pred.cpu().numpy(),
        valid_mask=None,
        chunk_size=4,
        num_threads=4,
        color_map=cmap,
        verbose=VERBOSE,
    )  # [n h w 3], in [0, 255]

    write_video_from_numpy(
        frames=colored_np,
        output_path=path_out_vis,
        fps=output_fps,
        crf=23,
        preset="medium",
        verbose=VERBOSE,
    )

    # Save rgb. The *255 scaling assumes input_rgb lies in [0, 1]; clip so any
    # slight overshoot cannot wrap around in the uint8 cast.
    rgb = np.clip(pipe_out.input_rgb.cpu().numpy() * 255.0, 0, 255)
    rgb = rgb.round().astype(np.uint8)  # [N 3 H W]
    rgb = einops.rearrange(rgb, "n c h w -> n h w c")
    write_video_from_numpy(
        frames=rgb,
        output_path=path_out_in,
        fps=output_fps,
        crf=23,
        preset="medium",
        verbose=VERBOSE,
    )

    return path_out_in, path_out_vis


def run_demo_server(pipe, device):
    """Build the RollingDepth Gradio UI around ``pipe`` and launch it.

    ``process`` is bound to ``pipe`` and ``device`` and wrapped with
    ``spaces.GPU``, which presumably requests GPU allocation per call on
    Hugging Face Spaces. The UI has one input video, two output videos, a
    Generate button and cached examples. The server listens on
    0.0.0.0:7860 with the API disabled.
    """
    process_pipe = spaces.GPU(functools.partial(process, pipe, device))
    os.environ["GRADIO_ALLOW_FLAGGING"] = "never"

    with gr.Blocks(
        analytics_enabled=False,
        title="RollingDepth",
        css="""
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
        """,
    ) as demo:
        gr.HTML(
            """

🛹 RollingDepth: Video Depth without Video Models

Website Badge arXiv Badge GitHub Stars Badge social

RollingDepth is the state-of-the-art depth estimator for videos in the wild. Upload your video into the left pane, or click any of the examples below. The result preview will be computed and appear in the right panes. For full functionality, use the code on GitHub. TIP: When running out of GPU time, fork the demo.

"""
        )

        # Top row: input video (1/3 width) next to the two outputs (2/3 width).
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                input_video = gr.Video(label="Input Video")
            with gr.Column(scale=2):
                with gr.Row(equal_height=True):
                    output_video_1 = gr.Video(
                        label="Preprocessed video",
                        interactive=False,
                        autoplay=True,
                        loop=True,
                        show_share_button=True,
                        scale=5,
                    )
                    output_video_2 = gr.Video(
                        label="Generated Depth Video",
                        interactive=False,
                        autoplay=True,
                        loop=True,
                        show_share_button=True,
                        scale=5,
                    )

        # Button row mirrors the 1:2 column split so the button sits under the input.
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Row(equal_height=False):
                    generate_btn = gr.Button("Generate")
            with gr.Column(scale=2):
                pass  # spacer

        # cache_examples=True: example outputs are computed with process_pipe.
        Examples(
            examples=[
                ["files/gokart.mp4"],
                ["files/horse.mp4"],
                ["files/walking.mp4"],
            ],
            inputs=[input_video],
            outputs=[output_video_1, output_video_2],
            fn=process_pipe,
            cache_examples=True,
            directory_name="examples_video",
        )

        generate_btn.click(
            fn=process_pipe,
            inputs=[input_video],
            outputs=[output_video_1, output_video_2],
        )

    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )


def _reinstall_bundled_diffusers(path_diffusers):
    """Replace the installed ``diffusers`` with the local copy at ``path_diffusers``.

    Runs ``sys.executable -m pip`` so the environment of the *running*
    interpreter is modified, rather than whatever ``pip`` is first on PATH.
    ``pip freeze`` runs before and after so the log shows the change. Return
    codes are ignored (best-effort), so a failed install is not reported here.
    """
    pip = [sys.executable, "-m", "pip"]
    for args in (
        ["freeze"],
        ["uninstall", "-y", "diffusers"],
        ["install", path_diffusers],
        ["freeze"],
    ):
        subprocess.run(pip + args, check=False)


def main():
    """Prepare the environment, load RollingDepth and start the demo server.

    The steps are:

    1. Swap in the ``diffusers`` copy bundled under ``rollingdepth_src``.
    2. Log in to the Hugging Face Hub if ``HF_TOKEN_LOGIN`` is set.
    3. Pick a device, preferring CUDA, then MPS, then CPU.
    4. Load the ``prs-eth/rollingdepth-v1-0`` weights.
    5. Hand off to ``run_demo_server``.
    """
    # Resolve relative to this file (not the CWD) for both pip and sys.path.
    src_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "rollingdepth_src")
    _reinstall_bundled_diffusers(os.path.join(src_dir, "diffusers"))

    hf_token = os.environ.get("HF_TOKEN_LOGIN")
    if hf_token:
        login(token=hf_token)

    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # rollingdepth is imported only after the bundled diffusers is installed.
    sys.path.append(src_dir)
    from rollingdepth import RollingDepthPipeline

    # Half precision on accelerators; many CPU kernels lack float16 support.
    dtype = torch.float32 if device.type == "cpu" else torch.float16
    pipe: RollingDepthPipeline = RollingDepthPipeline.from_pretrained(
        "prs-eth/rollingdepth-v1-0",
        torch_dtype=dtype,
    )
    pipe.set_progress_bar_config(disable=True)

    # xformers is optional: enable it when it imports and works, otherwise
    # fall back to default attention (but never swallow KeyboardInterrupt).
    try:
        import xformers  # noqa: F401  (availability probe)

        pipe.enable_xformers_memory_efficient_attention()
    except Exception as err:
        print(f"Running without xformers: {err}")

    pipe = pipe.to(device)

    run_demo_server(pipe, device)


if __name__ == "__main__":
    main()