# Hugging Face Space page header (scraped): "Running on Zero"
# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
import functools
import os
import subprocess
import sys
import tempfile

import av
import numpy as np
import spaces
import gradio as gr
import torch
import einops
from huggingface_hub import login

from gradio_patches.examples import Examples
from colorize import colorize_depth_multi_thread
from video_io import get_video_fps, write_video_from_numpy
# Forwarded as `verbose=` to the pipeline, the colorizer and the video writer.
VERBOSE: bool = False
# Per-request cap on the number of frames the demo will process.
MAX_FRAMES: int = 100
def _probe_video(path_input):
    """Estimate ``(fps, total_frames)`` from the first video stream's metadata.

    ``total_frames`` is ``int(duration_sec * fps)``. It is 0 when the stream
    reports no duration or no average frame rate, so callers must treat 0 as
    "unknown length", not as "empty video".
    """
    # The context manager guarantees the container is closed after probing.
    with av.open(path_input) as container:
        stream = container.streams.video[0]
        fps = float(stream.average_rate) if stream.average_rate else 0.0
        duration_sec = (
            float(stream.duration * stream.time_base) if stream.duration else 0.0
        )
    return fps, int(duration_sec * fps)


def _write_preview_video(frames, output_path, fps):
    """Encode ``frames`` to ``output_path`` with the demo's shared encoder settings."""
    write_video_from_numpy(
        frames=frames,
        output_path=output_path,
        fps=fps,
        crf=23,
        preset="medium",
        verbose=VERBOSE,
    )


def process(pipe, device, path_input):
    """Run RollingDepth on a video and write two preview videos.

    At most ``MAX_FRAMES`` frames are processed; longer videos trigger a UI
    warning.

    Args:
        pipe: Loaded RollingDepth pipeline, called with the settings below.
        device: Torch device used to create the seeded random generator.
        path_input: Path of the uploaded video, or ``None`` when "Generate"
            is pressed without an upload.

    Returns:
        ``(path_out_in, path_out_vis)``: paths of the re-encoded preprocessed
        RGB video and of the colorized depth video, both in a new temp dir.

    Raises:
        gr.Error: If no input video was provided.
    """
    if path_input is None:
        # Without this guard os.path.basename(None) raises an opaque TypeError.
        raise gr.Error("Please upload a video first.")
    print(f"Processing {path_input}")

    path_output_dir = tempfile.mkdtemp()  # mkdtemp already creates the directory
    name_base = os.path.splitext(os.path.basename(path_input))[0]
    path_out_in = os.path.join(path_output_dir, f"{name_base}_depth_input.mp4")
    path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.mp4")

    output_fps = int(get_video_fps(path_input))
    fps, total_frames = _probe_video(path_input)
    if total_frames > MAX_FRAMES:
        # total_frames > 0 here implies fps > 0, so the division is safe.
        gr.Warning(
            f"Only the first {MAX_FRAMES} frames (~{MAX_FRAMES / fps:.1f} sec.) will be processed for demonstration; "
            f"use the code from GitHub for full processing"
        )
    if total_frames > 0:
        frame_count = min(MAX_FRAMES, total_frames)
    else:
        # Unknown length: the pipeline's frame_count=0 means "all frames",
        # which would silently bypass the demo cap, so cap explicitly.
        # NOTE(review): assumes the pipeline stops at the end of the video when
        # frame_count exceeds the real length — confirm.
        frame_count = MAX_FRAMES

    generator = torch.Generator(device=device)
    generator.manual_seed(2024)  # fixed seed -> reproducible demo output

    # Not annotated: RollingDepthOutput is only imported inside main().
    pipe_out = pipe(
        # input setting
        input_video_path=path_input,
        start_frame=0,
        frame_count=frame_count,
        processing_res=768,
        # infer setting
        dilations=[1, 25],
        cap_dilation=True,
        snippet_lengths=[3],
        init_infer_steps=[1],
        strides=[1],
        coalign_kwargs=None,
        refine_step=0,  # 0 = off
        max_vae_bs=8,  # batch size for encoder/decoder
        # other settings
        generator=generator,
        verbose=VERBOSE,
        # output settings
        restore_res=False,
        unload_snippet=False,
    )

    # Colorize depth. .cpu() is a no-op for CPU tensors and keeps .numpy()
    # working if the pipeline returns device tensors.
    depth_pred = pipe_out.depth_pred.cpu()  # [N 1 H W]
    colored_np = colorize_depth_multi_thread(
        depth=depth_pred.numpy(),
        valid_mask=None,
        chunk_size=4,
        num_threads=4,
        color_map="Spectral_r",
        verbose=VERBOSE,
    )  # [n h w 3], in [0, 255]
    _write_preview_video(colored_np, path_out_vis, output_fps)

    # Save the preprocessed RGB returned by the pipeline. Presumably in [0, 1]
    # (implied by the *255 scaling) — clip anyway so slightly out-of-range
    # values cannot wrap around in the uint8 cast.
    rgb = pipe_out.input_rgb.cpu().numpy()  # [N 3 H W]
    rgb = np.clip(rgb * 255, 0, 255).astype(np.uint8)
    rgb = einops.rearrange(rgb, "n c h w -> n h w c")
    _write_preview_video(rgb, path_out_in, output_fps)

    return path_out_in, path_out_vis
def run_demo_server(pipe, device):
    """Build the RollingDepth Gradio UI and launch it on 0.0.0.0:7860.

    Args:
        pipe: Loaded RollingDepth pipeline, bound into ``process``.
        device: Torch device, bound into ``process``.

    The bound ``process`` is wrapped with ``spaces.GPU`` and serves both the
    "Generate" button and the cached examples.
    """
    process_pipe = spaces.GPU(functools.partial(process, pipe, device))
    os.environ["GRADIO_ALLOW_FLAGGING"] = "never"

    with gr.Blocks(
        analytics_enabled=False,
        title="RollingDepth",
        css="""
            h1, h2, h3 {
                text-align: center;
                display: block;
            }
        """,
    ) as demo:
        gr.HTML(
            """
            <h1>🛹 RollingDepth: Video Depth without Video Models</h1>
            <div style="text-align: center; margin-top: 20px;">
                <a title="Website" href="https://rollingdepth.github.io" target="_blank" rel="noopener noreferrer" style="display: inline-block; margin-right: 4px;">
                    <img src="https://www.obukhov.ai/img/badges/badge-website.svg" alt="Website Badge">
                </a>
                <a title="arXiv" href="https://arxiv.org/abs/2411.19189" target="_blank" rel="noopener noreferrer" style="display: inline-block; margin-right: 4px;">
                    <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg" alt="arXiv Badge">
                </a>
                <a title="GitHub" href="https://github.com/prs-eth/rollingdepth" target="_blank" rel="noopener noreferrer" style="display: inline-block; margin-right: 4px;">
                    <img src="https://img.shields.io/github/stars/prs-eth/rollingdepth?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="GitHub Stars Badge">
                </a>
                <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block; margin-right: 4px;">
                    <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
                </a>
            </div>
            <p style="margin-top: 20px; text-align: justify;">
                RollingDepth is the state-of-the-art depth estimator for videos in the wild. Upload your video into the
                <b>left</b> pane, or click any of the <b>examples</b> below. The result preview will be computed and
                appear in the <b>right</b> panes. For full functionality, use the code on GitHub.
                <b>TIP:</b> When running out of GPU time, fork the demo.
            </p>
            """
        )

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                input_video = gr.Video(label="Input Video")
            with gr.Column(scale=2):
                with gr.Row(equal_height=True):
                    output_video_1 = gr.Video(
                        label="Preprocessed video",
                        interactive=False,
                        autoplay=True,
                        loop=True,
                        show_share_button=True,
                        scale=5,
                    )
                    output_video_2 = gr.Video(
                        label="Generated Depth Video",
                        interactive=False,
                        autoplay=True,
                        loop=True,
                        show_share_button=True,
                        scale=5,
                    )

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Row(equal_height=False):
                    generate_btn = gr.Button("Generate")
            with gr.Column(scale=2):
                # Empty spacer so the button column lines up with the input
                # column (scale=1) above.
                pass

        Examples(
            examples=[
                ["files/gokart.mp4"],
                ["files/horse.mp4"],
                ["files/walking.mp4"],
            ],
            inputs=[input_video],
            outputs=[output_video_1, output_video_2],
            fn=process_pipe,
            cache_examples=True,
            directory_name="examples_video",
        )

        generate_btn.click(
            fn=process_pipe,
            inputs=[input_video],
            outputs=[output_video_1, output_video_2],
        )

    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
def _pip(*args, check=True):
    """Run ``<this interpreter> -m pip <args>`` and return the CompletedProcess.

    ``sys.executable`` makes pip act on the same Python that runs this app;
    a bare ``pip`` found on PATH does not guarantee that.
    """
    return subprocess.run([sys.executable, "-m", "pip", *args], check=check)


def main():
    """Prepare the environment, load RollingDepth, and launch the demo server.

    Reinstalls ``diffusers`` from ``rollingdepth_src/diffusers``, optionally
    logs in to the Hugging Face Hub (``HF_TOKEN_LOGIN``), picks the first
    available device (CUDA, then MPS, then CPU), loads the fp16 pipeline and
    serves it.

    Raises:
        subprocess.CalledProcessError: If installing the local diffusers fails.
    """
    app_dir = os.path.dirname(__file__)

    _pip("freeze", check=False)  # diagnostic log of the environment
    # Best-effort: the uninstall return code is deliberately ignored.
    _pip("uninstall", "-y", "diffusers", check=False)
    # Fail fast: diffusers was just uninstalled, so silently continuing after a
    # failed install would likely break later with a far less obvious error.
    _pip("install", os.path.join(app_dir, "rollingdepth_src", "diffusers"))
    _pip("freeze", check=False)

    if "HF_TOKEN_LOGIN" in os.environ:
        login(token=os.environ["HF_TOKEN_LOGIN"])

    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # `rollingdepth` is imported from the rollingdepth_src/ directory next to
    # this file, so it must be on sys.path before the import below.
    sys.path.append(os.path.join(app_dir, "rollingdepth_src"))
    from rollingdepth import RollingDepthPipeline

    pipe: RollingDepthPipeline = RollingDepthPipeline.from_pretrained(
        "prs-eth/rollingdepth-v1-0",
        torch_dtype=torch.float16,
    )
    pipe.set_progress_bar_config(disable=True)
    try:
        import xformers  # noqa: F401  -- availability probe only

        pipe.enable_xformers_memory_efficient_attention()
    except Exception as err:  # xformers is optional; never abort startup over it
        print(f"Running without xformers: {err}")
    pipe = pipe.to(device)
    run_demo_server(pipe, device)


if __name__ == "__main__":
    main()