# Copyright 2024 Bingxin Ke, ETH Zurich. All rights reserved.
# Last modified: 2024-11-28
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/RollingDepth#-citation
# More information about the method can be found at https://rollingdepth.github.io
# ---------------------------------------------------------------------------------

import argparse
import logging
import os
from pathlib import Path

import numpy as np
import torch
from tqdm.auto import tqdm
import einops
from omegaconf import OmegaConf

from rollingdepth import (
    RollingDepthOutput,
    RollingDepthPipeline,
    write_video_from_numpy,
    get_video_fps,
    concatenate_videos_horizontally_torch,
)
from src.util.colorize import colorize_depth_multi_thread
from src.util.config import str2bool

if "__main__" == __name__:
    logging.basicConfig(level=logging.INFO)

    # -------------------- Arguments --------------------
    parser = argparse.ArgumentParser(
        description="Run video depth estimation using RollingDepth."
    )
    parser.add_argument(
        "-i",
        "--input-video",
        type=str,
        required=True,
        help=(
            "Path to the input video(s) to be processed. Accepts: "
            "- Single video file path (e.g., 'video.mp4') "
            "- Text file containing a list of video paths (one per line) "
            "- Directory path containing video files. "
            "Required argument."
        ),
        dest="input_video",
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        type=str,
        required=True,
        help=(
            "Directory path where processed outputs will be saved. "
            "Will be created if it doesn't exist. "
            "Required argument."
        ),
        dest="output_dir",
    )
    parser.add_argument(
        "-p",
        "--preset",
        type=str,
        choices=["fast", "fast1024", "full", "paper", "none"],
        help=(
            "Inference preset: "
            "'fast' (dilations [1, 25], no refinement), "
            "'fast1024' (same as 'fast' but processed at resolution 1024), "
            "'full' (resolution 1024, dilations [1, 10, 25], 10 refinement steps), "
            "'paper' (dilations [1, 10, 25], fp32, uncapped dilation, 10 refinement steps), "
            "or 'none' (all inference settings must be given explicitly)."
        ),
    )
    parser.add_argument(
        "--start-frame",
        "--from",
        type=int,
        default=0,
        help=(
            "Specifies the starting frame index for processing. "
            "Use 0 to start from the beginning of the video. "
            "Default: 0"
        ),
        dest="start_frame",
    )
    parser.add_argument(
        "--frame-count",
        "--frames",
        type=int,
        default=0,
        help=(
            "Number of frames to process after the starting frame. "
            "Set to 0 to process until the end of the video. "
            "Default: 0 (process all frames)"
        ),
        dest="frame_count",
    )
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default="prs-eth/rollingdepth-v1-0",
        help=(
            "Path to the model checkpoint to use for inference. Can be either: "
            "- A local path to checkpoint files "
            "- A Hugging Face model hub identifier (e.g., 'prs-eth/rollingdepth-v1-0') "
            "Default: 'prs-eth/rollingdepth-v1-0'"
        ),
        dest="checkpoint",
    )
    parser.add_argument(
        "--res",
        "--processing-resolution",
        type=int,
        default=None,
        help=(
            "Specifies the maximum resolution (in pixels) at which image processing will be performed. "
            "If set to None, uses the preset configuration value. "
            "If set to 0, processes at the original input image resolution. "
            "Default: None"
        ),
        dest="res",
    )
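
    # Note (assumption based on the help texts above and below): the processing
    # resolution is a main driver of GPU memory usage, so when inference runs
    # out of memory, lowering --res (e.g. to 512) or --max-vae-bs (defined
    # below) are the first knobs to try.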
" "Default: None" ), dest="res", ) parser.add_argument( "--max-vae-bs", type=int, default=4, help=( "Maximum batch size for the Variational Autoencoder (VAE) processing. " "Higher values increase memory usage but may improve processing speed. " "Reduce this value if encountering out-of-memory errors. " "Default: 4" ), ) # Output settings parser.add_argument( "--fps", "--output-fps", type=int, default=0, help=( "Frame rate (FPS) for the output video. " "Set to 0 to match the input video's frame rate. " "Default: 0" ), dest="output_fps", ) parser.add_argument( "--restore-resolution", "--restore-res", type=str2bool, nargs="?", default=False, help=( "Whether to restore the output to the original input resolution after processing. " "Only applies when input has been resized during processing. " "Default: False" ), dest="restore_res", ) parser.add_argument( "--save-sbs" "--save-side-by-side", type=str2bool, nargs="?", default=True, help=( "Whether to save RGB and colored depth videos side-by-side. " "If True, the first color map will be used. " "Default: True" ), dest="save_sbs", ) parser.add_argument( "--save-npy", type=str2bool, nargs="?", default=True, help=( "Whether to save depth maps as NumPy (.npy) files. " "Enables further processing and analysis of raw depth data. " "Default: True" ), ) parser.add_argument( "--save-snippets", type=str2bool, nargs="?", default=False, help=( "Whether to save visualization snippets of the depth estimation process. " "Useful for debugging and quality assessment. " "Default: False" ), ) parser.add_argument( "--cmap", "--color-maps", type=str, nargs="+", default=["Spectral_r", "Greys_r"], help=( "One or more matplotlib color maps for depth visualization. " "Multiple maps can be specified for different visualization styles. " "Common options: 'Spectral_r', 'Greys_r', 'viridis', 'magma'. " "Use '' (empty string) to skip colorization. " "Default: ['Spectral_r', 'Greys_r']" ), dest="color_maps", ) # Inference setting parser.add_argument( "-d", "--dilations", type=int, nargs="+", default=None, help=( "Spacing between frames for temporal analysis. " "Set to None to use preset configurations based on video length. " "Custom configurations: " "- [1, 10, 25]: Best accuracy, slower processing " "- [1, 25]: Balanced speed and accuracy " "- [1, 10]: For short videos (<78 frames) " "Default: None (auto-select based on video length)" ), dest="dilations", ) parser.add_argument( "--cap-dilation", type=str2bool, default=None, help=( "Whether to automatically reduce dilation spacing for short videos. " "Set to None to use preset configuration. " "Enabling this prevents temporal windows from extending beyond video length. " "Default: None (automatically determined based on video length)" ), dest="cap_dilation", ) parser.add_argument( "--dtype", "--data-type", type=str, choices=["fp16", "fp32", None], default=None, help=( "Specifies the floating-point precision for inference operations. " "Options: 'fp16' (16-bit), 'fp32' (32-bit), or None. " "If None, uses the preset configuration value. " "Lower precision (fp16) reduces memory usage but may affect accuracy. " "Default: None" ), dest="dtype", ) parser.add_argument( "--snip-len", "--snippet-lengths", type=int, nargs="+", choices=[2, 3, 4], default=None, help=( "Number of consecutive frames to analyze in each temporal window. " "Set to None to use preset value (3). " "Can specify multiple values corresponding to different dilation rates. 
" "Example: '--dilations 1 25 --snippet-length 2 3' uses " "2 frames for dilation 1 and 3 frames for dilation 25. " "Allowed values: 2, 3, or 4 frames. " "Default: None" ), dest="snippet_lengths", ) parser.add_argument( "--refine-step", type=int, default=None, help=( "Number of refinement iterations to improve depth estimation accuracy. " "Set to None to use preset configuration. " "Set to 0 to disable refinement. " "Higher values may improve accuracy but increase processing time. " "Default: None (uses 0, no refinement)" ), dest="refine_step", ) parser.add_argument( "--refine-snippet-len", type=int, default=None, help=( "Length of text snippets used during the refinement phase. " "Specifies the number of sentences or segments to process at once. " "If not specified (None), system-defined preset values will be used. " "Default: None" ), ) parser.add_argument( "--refine-start-dilation", type=int, default=None, help=( "Initial dilation factor for the coarse-to-fine refinement process. " "Controls the starting granularity of the refinement steps. " "Higher values result in larger initial search windows. " "If not specified (None), uses system default. " "Default: None" ), ) # Other settings parser.add_argument( "--resample-method", type=str, choices=["BILINEAR", "NEAREST_EXACT", "BICUBIC"], default="BILINEAR", help="Resampling method used to resize images.", ) parser.add_argument( "--unload-snippet", type=str2bool, default=False, help=( "Controls memory optimization by moving processed data snippets to CPU. " "When enabled, reduces GPU memory usage at the cost of slower processing. " "Useful for systems with limited GPU memory or large datasets. " "Default: False" ), ) parser.add_argument( "--verbose", action="store_true", help=("Enable detailed progress and information reporting during processing. "), ) parser.add_argument( "--seed", type=int, default=None, help=( "Random number generator seed for reproducibility (up to computational randomness). " "Using the same seed value will produce identical results across runs. " "If not specified (None), a random seed will be used. " "Default: None" ), ) # -------------------- Config preset arguments -------------------- input_args = parser.parse_args() args = OmegaConf.create( { "res": 768, "snippet_lengths": [3], "cap_dilation": True, "dtype": "fp16", "refine_snippet_len": 3, "refine_start_dilation": 6, } ) preset_args_dict = { "fast": OmegaConf.create( { "dilations": [1, 25], "refine_step": 0, } ), "fasthr": OmegaConf.create( { "res": 1024, "dilations": [1, 25], "refine_step": 0, } ), "full": OmegaConf.create( { "res": 1024, "dilations": [1, 10, 25], "refine_step": 10, } ), "paper": OmegaConf.create( { "dilations": [1, 10, 25], "cap_dilation": False, "dtype": "fp32", "refine_step": 10, } ), } if "none" != input_args.preset: logging.info(f"Using preset: {input_args.preset}") args.update(preset_args_dict[input_args.preset]) # Merge or overwrite arguments for key, value in vars(input_args).items(): if key in args.keys(): # overwrite if value is set and different from preset if value is not None and value != args[key]: logging.warning(f"Overwritting argument: {key} = {value}") args[key] = value else: # add argument args[key] = value # sanity check assert value is not None or key in ["seed"], f"Undefined argument: {key}" msg = f"arguments: {args}" if args.verbose: logging.info(msg) else: logging.debug(msg) # Argument check if args.save_sbs: assert ( len(args.color_maps) > 0 ), "No color map is given, can not save side-by-side videos." 

    input_video = Path(args.input_video)
    output_dir = Path(args.output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # -------------------- Device --------------------
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        logging.warning("CUDA is not available. Running on CPU will be slow.")
    logging.info(f"device = {device}")

    # -------------------- Data --------------------
    if input_video.is_dir():
        input_video_ls = os.listdir(input_video)
        input_video_ls = [input_video.joinpath(v_name) for v_name in input_video_ls]
    elif ".txt" == input_video.suffix:
        with open(input_video, "r") as f:
            input_video_ls = f.readlines()
        input_video_ls = [Path(s.strip()) for s in input_video_ls]
    else:
        input_video_ls = [Path(input_video)]
    input_video_ls = sorted(input_video_ls)
    logging.info(f"Found {len(input_video_ls)} videos.")

    # -------------------- Model --------------------
    if "fp16" == args.dtype:
        dtype = torch.float16
    elif "fp32" == args.dtype:
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    pipe: RollingDepthPipeline = RollingDepthPipeline.from_pretrained(
        args.checkpoint, torch_dtype=dtype
    )  # type: ignore

    try:
        pipe.enable_xformers_memory_efficient_attention()
        logging.info("xformers enabled")
    except ImportError:
        logging.warning("Running without xformers")

    pipe = pipe.to(device)

    # -------------------- Inference and saving --------------------
    with torch.no_grad():
        if args.verbose:
            video_iterable = tqdm(input_video_ls, desc="Processing videos", leave=True)
        else:
            video_iterable = input_video_ls
        for video_path in video_iterable:
            # Random number generator
            if args.seed is None:
                generator = None
            else:
                generator = torch.Generator(device=device)
                generator.manual_seed(args.seed)

            # Predict depth
            pipe_out: RollingDepthOutput = pipe(
                # input setting
                input_video_path=video_path,
                start_frame=args.start_frame,
                frame_count=args.frame_count,
                processing_res=args.res,
                resample_method=args.resample_method,
                # infer setting
                dilations=list(args.dilations),
                cap_dilation=args.cap_dilation,
                snippet_lengths=list(args.snippet_lengths),
                init_infer_steps=[1],
                strides=[1],
                coalign_kwargs=None,
                refine_step=args.refine_step,
                refine_snippet_len=args.refine_snippet_len,
                refine_start_dilation=args.refine_start_dilation,
                # other settings
                generator=generator,
                verbose=args.verbose,
                max_vae_bs=args.max_vae_bs,
                # output settings
                restore_res=args.restore_res,
                unload_snippet=args.unload_snippet,
            )

            depth_pred = pipe_out.depth_pred  # [N 1 H W]

            os.makedirs(output_dir, exist_ok=True)

            # Save prediction as npy
            if args.save_npy:
                save_to = output_dir.joinpath(f"{video_path.stem}_pred.npy")
                if args.verbose:
                    logging.info(f"Saving predictions to {save_to}")
                np.save(save_to, depth_pred.numpy().squeeze(1))  # [N H W]

            # Save intermediate snippets
            if args.save_snippets and pipe_out.snippet_ls is not None:
                save_to = output_dir.joinpath(f"{video_path.stem}_snippets.npz")
                if args.verbose:
                    logging.info(f"Saving snippets to {save_to}")
                snippet_dict = {}
                for i_dil, snippets in enumerate(pipe_out.snippet_ls):
                    dilation = args.dilations[i_dil]
                    snippet_dict[f"dilation{dilation}"] = snippets.numpy().squeeze(
                        2
                    )  # [n_snip snippet_len H W]
                np.savez_compressed(save_to, **snippet_dict)

            # Colorize results
            for i_cmap, cmap in enumerate(args.color_maps):
                if "" == cmap:
                    continue
                colored_np = colorize_depth_multi_thread(
                    depth=depth_pred.numpy(),
                    valid_mask=None,
                    chunk_size=4,
                    num_threads=4,
                    color_map=cmap,
                    verbose=args.verbose,
                )  # [n h w 3], in [0, 255]
                save_to = output_dir.joinpath(f"{video_path.stem}_{cmap}.mp4")
                # Fall back to the input video's frame rate if no output FPS is given
                if args.output_fps > 0:
                    output_fps = args.output_fps
                else:
                    output_fps = int(get_video_fps(video_path))
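                # Note: crf=23 with preset "medium" matches common x264 encoder
                # defaults (a balanced quality/speed tradeoff), assuming
                # write_video_from_numpy forwards these settings to the encoder;
                # lower crf values yield higher quality at larger file sizes.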
                write_video_from_numpy(
                    frames=colored_np,
                    output_path=save_to,
                    fps=output_fps,
                    crf=23,
                    preset="medium",
                    verbose=args.verbose,
                )

                # Save side-by-side videos
                if args.save_sbs and 0 == i_cmap:
                    rgb = pipe_out.input_rgb * 255  # [N 3 H W]
                    colored_depth = einops.rearrange(
                        torch.from_numpy(colored_np), "n h w c -> n c h w"
                    )
                    concat_video = (
                        concatenate_videos_horizontally_torch(
                            rgb, colored_depth, gap=10
                        )
                        .int()
                        .numpy()
                        .astype(np.uint8)
                    )
                    concat_video = einops.rearrange(concat_video, "n c h w -> n h w c")
                    save_to = output_dir.joinpath(f"{video_path.stem}_rgbd.mp4")
                    write_video_from_numpy(
                        frames=concat_video,
                        output_path=save_to,
                        fps=output_fps,
                        crf=23,
                        preset="medium",
                        verbose=args.verbose,
                    )

        logging.info(
            f"Finished. {len(video_iterable)} predictions are saved to {output_dir}"
        )
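
    # A minimal sketch of consuming the saved outputs (illustrative only; the
    # file names follow the --save-npy / --save-snippets naming used above):
    #
    #   import numpy as np
    #   depth = np.load("output/video_pred.npy")         # [N H W] depth maps
    #   snippets = np.load("output/video_snippets.npz")  # keys like "dilation1"
    #   print(depth.shape, list(snippets.keys()))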