import gc
import os
import random
import tempfile
import time

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import AutoencoderKLTemporalDecoder, FlowMatchEulerDiscreteScheduler
from easydict import EasyDict

from dav.models import UNetSpatioTemporalRopeConditionModel
from dav.pipelines import DAVPipeline
from dav.utils import img_utils


def seed_all(seed: int = 0):
    """
    Set random seeds for reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Each entry matches the Gradio inputs: video, denoise_steps, num_frames,
# decode_chunk_size, num_interp_frames, num_overlap_frames, max_resolution.
examples = [
    ["demos/wooly_mammoth.mp4", 3, 32, 8, 16, 6, 768],
]


def load_models(model_base, device):
    # Load the VAE, flow-matching scheduler, and the two RoPE-conditioned UNets:
    # `unet` denoises key frames, `unet_interp` interpolates the frames between them.
    vae = AutoencoderKLTemporalDecoder.from_pretrained(model_base, subfolder="vae")
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        model_base, subfolder="scheduler"
    )
    unet = UNetSpatioTemporalRopeConditionModel.from_pretrained(
        model_base, subfolder="unet"
    )
    unet_interp = UNetSpatioTemporalRopeConditionModel.from_pretrained(
        model_base, subfolder="unet_interp"
    )
    pipe = DAVPipeline(
        vae=vae,
        unet=unet,
        unet_interp=unet_interp,
        scheduler=scheduler,
    )
    pipe = pipe.to(device)
    return pipe


model_base = "hhyangcs/depth-any-video"
device_type = "cuda"
device = torch.device(device_type)
pipe = load_models(model_base, device)


@spaces.GPU(duration=140)
def infer_depth(
    file: str,
    denoise_steps: int = 3,
    num_frames: int = 32,
    decode_chunk_size: int = 16,
    num_interp_frames: int = 16,
    num_overlap_frames: int = 6,
    max_resolution: int = 1024,
    seed: int = 66,
    output_dir: str = "./outputs",
):
    seed_all(seed)

    # Cap the number of frames read so the key-frame / interpolation windows fit.
    max_frames = (num_interp_frames + 2 - num_overlap_frames) * (num_frames // 2)
    image, fps = img_utils.read_video(file, max_frames=max_frames)

    # Resize and crop to model-friendly dimensions, then scale to [0, 1] CHW tensors.
    image = img_utils.imresize_max(image, max_resolution)
    image = img_utils.imcrop_multi(image)
    image_tensor = np.ascontiguousarray(
        [_img.transpose(2, 0, 1) / 255.0 for _img in image]
    )
    image_tensor = torch.from_numpy(image_tensor).to(device)
    print(f"==> video name: {file}, frames shape: {image_tensor.shape}")

    # Run the denoising and interpolation passes under fp16 autocast.
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.float16):
        pipe_out = pipe(
            image_tensor,
            num_frames=num_frames,
            num_overlap_frames=num_overlap_frames,
            num_interp_frames=num_interp_frames,
            decode_chunk_size=decode_chunk_size,
            num_inference_steps=denoise_steps,
        )

    disparity = pipe_out.disparity
    disparity_colored = pipe_out.disparity_colored
    image = pipe_out.image
    # Concatenate RGB and colored disparity side by side: (N, H, 2 * W, 3).
    merged = np.concatenate(
        [
            image,
            disparity_colored,
        ],
        axis=2,
    )

    file_name = os.path.splitext(os.path.basename(file))[0]
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{file_name}_depth.mp4")
    img_utils.write_video(
        output_path,
        merged,
        fps,
    )

    # Clear the cache for the next video.
    gc.collect()
    torch.cuda.empty_cache()

    return output_path


def construct_demo():
    with gr.Blocks(analytics_enabled=False) as depthanyvideo_iface:
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                input_video = gr.Video(label="Input Video")
            with gr.Column(scale=2):
                with gr.Row(equal_height=True):
                    output_video = gr.Video(
                        label="Output Video & Depth",
                        interactive=False,
                        autoplay=True,
                        loop=True,
                        show_share_button=True,
                        scale=2,
                    )

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Row(equal_height=False):
                    with gr.Accordion("Advanced Settings", open=False):
                        denoise_steps = gr.Slider(
                            label="Denoise Steps",
                            minimum=1,
                            maximum=10,
                            value=3,
                            step=1,
                        )
                        num_frames = gr.Slider(
                            label="Number of Key Frames",
                            minimum=16,
                            maximum=32,
                            value=24,
                            step=2,
                        )
                        decode_chunk_size = gr.Slider(
                            label="Decode Chunk Size",
                            minimum=8,
                            maximum=32,
                            value=8,
                            step=1,
                        )
                        num_interp_frames = gr.Slider(
                            label="Number of Interpolation Frames",
                            minimum=8,
                            maximum=32,
                            value=16,
                            step=1,
                        )
                        num_overlap_frames = gr.Slider(
                            label="Number of Overlap Frames",
                            minimum=2,
                            maximum=10,
                            value=6,
                            step=1,
                        )
                        max_resolution = gr.Slider(
                            label="Maximum Resolution",
                            minimum=512,
                            maximum=2048,
                            value=768,
                            step=32,
                        )
                    generate_btn = gr.Button("Generate")
            with gr.Column(scale=2):
                pass

        gr.Examples(
            examples=examples,
            inputs=[
                input_video,
                denoise_steps,
                num_frames,
                decode_chunk_size,
                num_interp_frames,
                num_overlap_frames,
                max_resolution,
            ],
            outputs=output_video,
            fn=infer_depth,
            cache_examples="lazy",
        )

        generate_btn.click(
            fn=infer_depth,
            inputs=[
                input_video,
                denoise_steps,
                num_frames,
                decode_chunk_size,
                num_interp_frames,
                num_overlap_frames,
                max_resolution,
            ],
            outputs=output_video,
        )

    return depthanyvideo_iface


demo = construct_demo()

if __name__ == "__main__":
    demo.queue()
    demo.launch(share=True)
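
# A minimal sketch of calling the pipeline headlessly (without launching the
# Gradio UI), assuming the bundled demo clip and the same values used in the
# `examples` entry above:
#
#     output_path = infer_depth(
#         "demos/wooly_mammoth.mp4",
#         denoise_steps=3,
#         num_frames=32,
#         decode_chunk_size=8,
#         num_interp_frames=16,
#         num_overlap_frames=6,
#         max_resolution=768,
#     )
#     print(output_path)  # side-by-side RGB / depth video written under ./outputs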