# DVD / app.py
# Last commit by: haodongli
# Commit message: Improve ZeroGPU compatibility (#2)
# Commit: 13ad79c
import spaces # must be first!
import os
from pathlib import Path
# Keep Gradio's temp/cache files inside the repository. This is set before
# `gradio` is imported further down — presumably because Gradio reads
# GRADIO_TEMP_DIR from the environment at import/startup time.
REPO_ROOT = Path(__file__).resolve().parent
GRADIO_TMP = REPO_ROOT.joinpath(".gradio_cache")
GRADIO_TMP.mkdir(exist_ok=True, parents=True)
os.environ.update(GRADIO_TEMP_DIR=str(GRADIO_TMP))
print(f"Gradio temp/cache dir: {GRADIO_TMP}")
import torch
from argparse import Namespace
import subprocess
from test_script.test_single_video import *
import gradio as gr
# NOTE(review): `device` is not referenced anywhere in this file's visible
# code; kept in case other code relies on it.
device = "cuda" if torch.cuda.is_available() else "cpu"

CKPT_DIR = f"{REPO_ROOT}/ckpt"

# Fetch the checkpoint *before* reading anything from ckpt/. Previously the
# config was loaded first, so a fresh checkout whose ckpt/ is populated by the
# download script would crash before the download ever ran.
# TODO(review): confirm whether model_config.yaml is committed or downloaded.
if not os.path.exists(f"{CKPT_DIR}/model.safetensors"):
    subprocess.run(["bash", f"{REPO_ROOT}/infer_bash/download_ckpt.sh"], check=True)

# OmegaConf and load_model are expected to come from the
# `test_script.test_single_video` star import above.
yaml_args = OmegaConf.load(f"{CKPT_DIR}/model_config.yaml")
pipeline = load_model(CKPT_DIR, yaml_args)

# Upper bound on frames decoded per request (see `fn`).
MAX_FRAMES = 300
def read_video_limited(video_path, max_frames):
    """Read up to ``max_frames`` frames from a video without loading the entire file.

    Frames are decoded sequentially with OpenCV, converted from BGR to RGB,
    stacked, and scaled to floats in [0, 1].

    Args:
        video_path: Path of the video file to open.
        max_frames: Maximum number of frames to decode; decoding also stops
            early at the end of the stream.

    Returns:
        A tuple ``(video_tensor, fps)``. ``video_tensor`` has shape
        ``(1, T, C, H, W)`` with ``T <= max_frames``; ``fps`` is the value
        OpenCV reports for ``CAP_PROP_FPS``.

    Raises:
        gr.Error: If the video cannot be opened or no frame can be decoded.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise gr.Error(f"Cannot open video: {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        # The reported frame count is only used to warn about truncation.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames > max_frames:
            gr.Warning(
                f"Video has {total_frames} frames, processing only the first {max_frames}."
            )
        frames = []
        while len(frames) < max_frames:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture on every path, including the early error above.
        cap.release()

    if not frames:
        # np.stack([]) would otherwise raise an opaque
        # "need at least one array to stack" ValueError.
        raise gr.Error(f"No frames could be decoded from video: {video_path}")

    video_np = np.stack(frames)
    # (T, H, W, C) uint8 -> (T, C, H, W) float in [0, 1]
    video_tensor = (
        torch.from_numpy(video_np).permute(0, 3, 1, 2).float() / 255.0
    )
    # Add a leading batch dimension: (1, T, C, H, W).
    return video_tensor.unsqueeze(0), fps
@spaces.GPU(duration=90)
def fn(input_video):
    """Estimate per-frame depth for an uploaded video and return the rendered result.

    Runs inside a ZeroGPU allocation (``@spaces.GPU``, 90 s budget).

    Args:
        input_video: Path of the video supplied by the Gradio ``Video`` input;
            ``None`` when the user submits without uploading a video.

    Returns:
        Whatever ``save_results`` returns for the depth video written under
        ``GRADIO_TMP`` (presumably the output file path — TODO confirm).

    Raises:
        gr.Error: If no video was provided, or the video cannot be read.
    """
    if not input_video:
        # Without this guard, cv2.VideoCapture(None) fails with an opaque error.
        raise gr.Error("Please upload a video first.")

    input_tensor, origin_fps = read_video_limited(input_video, MAX_FRAMES)
    # NOTE(review): 480 / 640 presumably are the target inference resolution
    # (height, width) — confirm against resize_for_training_scale.
    input_tensor, orig_size = resize_for_training_scale(
        input_tensor, 480, 640
    )
    # Sliding-window inference settings (window length and overlap in frames,
    # presumably — confirm against predict_depth).
    depth = predict_depth(pipeline, input_tensor, orig_size, Namespace(
        window_size=81,
        overlap=21,
    ))
    output_video = save_results(depth, origin_fps, Namespace(
        input_video=input_video,
        output_dir=GRADIO_TMP,
        grayscale=False,
    ))
    return output_video
if __name__ == "__main__":
    # HTML blurb shown under the demo title.
    description = (
        "\n"
        '<strong>Please consider starring <span style="color: orange">&#9733;</span> '
        'our <a href="https://github.com/EnVision-Research/DVD" target="_blank" '
        'rel="noopener noreferrer">GitHub Repo</a> if you find this demo useful!</strong>'
        "\n"
    )
    # Bundled sample clips, each wrapped as a one-input example row.
    example_rows = [
        [f"{REPO_ROOT}/demo/{clip}"] for clip in ("drone.mp4", "robot_navi.mp4")
    ]

    # Single-function UI: one video in, one depth video out.
    demo = gr.Interface(
        fn=fn,
        title="DVD: Deterministic Video Depth Estimation with Generative Priors",
        description=description,
        inputs=[gr.Video(label="Input Video", autoplay=True)],
        outputs=[gr.Video(label="Output Video", autoplay=True)],
        examples=example_rows,
    )

    # Process one request at a time.
    demo.queue(default_concurrency_limit=1)
    # Pass server_name="0.0.0.0" / server_port=... here to bind a specific
    # interface or port when running outside Spaces.
    demo.launch()