# Copyright 2024 Bingxin Ke, ETH Zurich. All rights reserved.
# Last modified: 2024-11-28
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/RollingDepth#-citation
# More information about the method can be found at https://rollingdepth.github.io
# ---------------------------------------------------------------------------------
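# Example invocations (script name and paths are illustrative):
#   python run_video.py -i video.mp4 -o output/ -p fast
#   python run_video.py -i video_list.txt -o output/ -p full --save-snippets true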
import argparse
import logging
import os
from pathlib import Path

import numpy as np
import torch
from tqdm.auto import tqdm
import einops
from omegaconf import OmegaConf

from rollingdepth import (
    RollingDepthOutput,
    RollingDepthPipeline,
    write_video_from_numpy,
    get_video_fps,
    concatenate_videos_horizontally_torch,
)
from src.util.colorize import colorize_depth_multi_thread
from src.util.config import str2bool
if "__main__" == __name__: | |
logging.basicConfig(level=logging.INFO) | |
# -------------------- Arguments -------------------- | |
parser = argparse.ArgumentParser( | |
description="Run video depth estimation using RollingDepth." | |
) | |
    parser.add_argument(
        "-i",
        "--input-video",
        type=str,
        required=True,
        help=(
            "Path to the input video(s) to be processed. Accepts: "
            "- Single video file path (e.g., 'video.mp4') "
            "- Text file containing a list of video paths (one per line) "
            "- Directory path containing video files "
            "Required argument."
        ),
        dest="input_video",
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        type=str,
        required=True,
        help=(
            "Directory path where processed outputs will be saved. "
            "Will be created if it doesn't exist. "
            "Required argument."
        ),
        dest="output_dir",
    )
    parser.add_argument(
        "-p",
        "--preset",
        type=str,
        choices=["fast", "fast1024", "full", "paper", "none"],
        help=(
            "Inference preset: 'fast' (768 res, dilations [1, 25], no refinement), "
            "'fast1024' (as 'fast' but at 1024 res), "
            "'full' (1024 res, dilations [1, 10, 25], 10 refinement steps), "
            "'paper' (settings used in the paper: fp32, uncapped dilation, 10 refinement steps), "
            "or 'none' to configure all settings explicitly."
        ),
    )
    parser.add_argument(
        "--start-frame",
        "--from",
        type=int,
        default=0,
        help=(
            "Specifies the starting frame index for processing. "
            "Use 0 to start from the beginning of the video. "
            "Default: 0"
        ),
        dest="start_frame",
    )
    parser.add_argument(
        "--frame-count",
        "--frames",
        type=int,
        default=0,
        help=(
            "Number of frames to process after the starting frame. "
            "Set to 0 to process until the end of the video. "
            "Default: 0 (process all frames)"
        ),
        dest="frame_count",
    )
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default="prs-eth/rollingdepth-v1-0",
        help=(
            "Path to the model checkpoint to use for inference. Can be either: "
            "- A local path to checkpoint files "
            "- A Hugging Face model hub identifier (e.g., 'prs-eth/rollingdepth-v1-0') "
            "Default: 'prs-eth/rollingdepth-v1-0'"
        ),
        dest="checkpoint",
    )
    parser.add_argument(
        "--res",
        "--processing-resolution",
        type=int,
        default=None,
        help=(
            "Specifies the maximum resolution (in pixels) at which image processing will be performed. "
            "If set to None, uses the preset configuration value. "
            "If set to 0, processes at the original input image resolution. "
            "Default: None"
        ),
        dest="res",
    )
    parser.add_argument(
        "--max-vae-bs",
        type=int,
        default=4,
        help=(
            "Maximum batch size for the Variational Autoencoder (VAE) processing. "
            "Higher values increase memory usage but may improve processing speed. "
            "Reduce this value if encountering out-of-memory errors. "
            "Default: 4"
        ),
    )

    # Output settings
    parser.add_argument(
        "--fps",
        "--output-fps",
        type=int,
        default=0,
        help=(
            "Frame rate (FPS) for the output video. "
            "Set to 0 to match the input video's frame rate. "
            "Default: 0"
        ),
        dest="output_fps",
    )
    parser.add_argument(
        "--restore-resolution",
        "--restore-res",
        type=str2bool,
        nargs="?",
        default=False,
        help=(
            "Whether to restore the output to the original input resolution after processing. "
            "Only applies when input has been resized during processing. "
            "Default: False"
        ),
        dest="restore_res",
    )
"--save-sbs" "--save-side-by-side", | |
type=str2bool, | |
nargs="?", | |
default=True, | |
help=( | |
"Whether to save RGB and colored depth videos side-by-side. " | |
"If True, the first color map will be used. " | |
"Default: True" | |
), | |
dest="save_sbs", | |
) | |
    parser.add_argument(
        "--save-npy",
        type=str2bool,
        nargs="?",
        default=True,
        help=(
            "Whether to save depth maps as NumPy (.npy) files. "
            "Enables further processing and analysis of raw depth data. "
            "Default: True"
        ),
    )
    parser.add_argument(
        "--save-snippets",
        type=str2bool,
        nargs="?",
        default=False,
        help=(
            "Whether to save visualization snippets of the depth estimation process. "
            "Useful for debugging and quality assessment. "
            "Default: False"
        ),
    )
    parser.add_argument(
        "--cmap",
        "--color-maps",
        type=str,
        nargs="+",
        default=["Spectral_r", "Greys_r"],
        help=(
            "One or more matplotlib color maps for depth visualization. "
            "Multiple maps can be specified for different visualization styles. "
            "Common options: 'Spectral_r', 'Greys_r', 'viridis', 'magma'. "
            "Use '' (empty string) to skip colorization. "
            "Default: ['Spectral_r', 'Greys_r']"
        ),
        dest="color_maps",
    )
    # Inference settings
    parser.add_argument(
        "-d",
        "--dilations",
        type=int,
        nargs="+",
        default=None,
        help=(
            "Spacing between frames for temporal analysis. "
            "Set to None to use preset configurations based on video length. "
            "Custom configurations: "
            "- [1, 10, 25]: Best accuracy, slower processing "
            "- [1, 25]: Balanced speed and accuracy "
            "- [1, 10]: For short videos (<78 frames) "
            "Default: None (auto-select based on video length)"
        ),
        dest="dilations",
    )
    parser.add_argument(
        "--cap-dilation",
        type=str2bool,
        default=None,
        help=(
            "Whether to automatically reduce dilation spacing for short videos. "
            "Set to None to use preset configuration. "
            "Enabling this prevents temporal windows from extending beyond video length. "
            "Default: None (automatically determined based on video length)"
        ),
        dest="cap_dilation",
    )
    parser.add_argument(
        "--dtype",
        "--data-type",
        type=str,
        choices=["fp16", "fp32", None],
        default=None,
        help=(
            "Specifies the floating-point precision for inference operations. "
            "Options: 'fp16' (16-bit), 'fp32' (32-bit), or None. "
            "If None, uses the preset configuration value. "
            "Lower precision (fp16) reduces memory usage but may affect accuracy. "
            "Default: None"
        ),
        dest="dtype",
    )
    parser.add_argument(
        "--snip-len",
        "--snippet-lengths",
        type=int,
        nargs="+",
        choices=[2, 3, 4],
        default=None,
        help=(
            "Number of consecutive frames to analyze in each temporal window. "
            "Set to None to use preset value (3). "
            "Can specify multiple values corresponding to different dilation rates. "
            "Example: '--dilations 1 25 --snip-len 2 3' uses "
            "2 frames for dilation 1 and 3 frames for dilation 25. "
            "Allowed values: 2, 3, or 4 frames. "
            "Default: None"
        ),
        dest="snippet_lengths",
    )
    parser.add_argument(
        "--refine-step",
        type=int,
        default=None,
        help=(
            "Number of refinement iterations to improve depth estimation accuracy. "
            "Set to None to use preset configuration. "
            "Set to 0 to disable refinement. "
            "Higher values may improve accuracy but increase processing time. "
            "Default: None (uses 0, no refinement)"
        ),
        dest="refine_step",
    )
    parser.add_argument(
        "--refine-snippet-len",
        type=int,
        default=None,
        help=(
            "Number of consecutive frames per snippet during the refinement phase. "
            "If not specified (None), system-defined preset values will be used. "
            "Default: None"
        ),
    )
    parser.add_argument(
        "--refine-start-dilation",
        type=int,
        default=None,
        help=(
            "Initial dilation factor for the coarse-to-fine refinement process. "
            "Controls the starting granularity of the refinement steps. "
            "Higher values result in larger initial search windows. "
            "If not specified (None), uses system default. "
            "Default: None"
        ),
    )

    # Other settings
    parser.add_argument(
        "--resample-method",
        type=str,
        choices=["BILINEAR", "NEAREST_EXACT", "BICUBIC"],
        default="BILINEAR",
        help="Resampling method used to resize images.",
    )
    parser.add_argument(
        "--unload-snippet",
        type=str2bool,
        default=False,
        help=(
            "Controls memory optimization by moving processed data snippets to CPU. "
            "When enabled, reduces GPU memory usage at the cost of slower processing. "
            "Useful for systems with limited GPU memory or large datasets. "
            "Default: False"
        ),
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable detailed progress and information reporting during processing.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help=(
            "Random number generator seed for reproducibility (up to computational randomness). "
            "Using the same seed value will produce identical results across runs. "
            "If not specified (None), a random seed will be used. "
            "Default: None"
        ),
    )
    # -------------------- Config preset arguments --------------------
    input_args = parser.parse_args()

    args = OmegaConf.create(
        {
            "res": 768,
            "snippet_lengths": [3],
            "cap_dilation": True,
            "dtype": "fp16",
            "refine_snippet_len": 3,
            "refine_start_dilation": 6,
        }
    )
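    # Settings precedence: the built-in defaults above are overridden by the
    # chosen preset bundle below, which is in turn overridden by any explicitly
    # set CLI flags (see the merge loop after the preset table).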
    preset_args_dict = {
        "fast": OmegaConf.create(
            {
                "dilations": [1, 25],
                "refine_step": 0,
            }
        ),
        "fast1024": OmegaConf.create(
            {
                "res": 1024,
                "dilations": [1, 25],
                "refine_step": 0,
            }
        ),
        "full": OmegaConf.create(
            {
                "res": 1024,
                "dilations": [1, 10, 25],
                "refine_step": 10,
            }
        ),
        "paper": OmegaConf.create(
            {
                "dilations": [1, 10, 25],
                "cap_dilation": False,
                "dtype": "fp32",
                "refine_step": 10,
            }
        ),
    }
if "none" != input_args.preset: | |
logging.info(f"Using preset: {input_args.preset}") | |
args.update(preset_args_dict[input_args.preset]) | |
# Merge or overwrite arguments | |
for key, value in vars(input_args).items(): | |
if key in args.keys(): | |
# overwrite if value is set and different from preset | |
if value is not None and value != args[key]: | |
logging.warning(f"Overwritting argument: {key} = {value}") | |
args[key] = value | |
else: | |
# add argument | |
args[key] = value | |
# sanity check | |
assert value is not None or key in ["seed"], f"Undefined argument: {key}" | |
msg = f"arguments: {args}" | |
if args.verbose: | |
logging.info(msg) | |
else: | |
logging.debug(msg) | |
    # Argument check
    if args.save_sbs:
        assert (
            len(args.color_maps) > 0
        ), "No color map is given, cannot save side-by-side videos."

    input_video = Path(args.input_video)
    output_dir = Path(args.output_dir)
    os.makedirs(output_dir, exist_ok=True)
    # -------------------- Device --------------------
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        logging.warning("CUDA is not available. Running on CPU will be slow.")
    logging.info(f"device = {device}")
    # -------------------- Data --------------------
    if input_video.is_dir():
        input_video_ls = os.listdir(input_video)
        input_video_ls = [input_video.joinpath(v_name) for v_name in input_video_ls]
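        # NOTE: every entry in the directory is treated as a video;
        # no filtering by file extension is applied.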
elif ".txt" == input_video.suffix: | |
with open(input_video, "r") as f: | |
input_video_ls = f.readlines() | |
input_video_ls = [Path(s.strip()) for s in input_video_ls] | |
else: | |
input_video_ls = [Path(input_video)] | |
input_video_ls = sorted(input_video_ls) | |
logging.info(f"Found {len(input_video_ls)} videos.") | |
    # -------------------- Model --------------------
    if "fp16" == args.dtype:
        dtype = torch.float16
    elif "fp32" == args.dtype:
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")
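    # fp16 roughly halves GPU memory use relative to fp32; the 'paper' preset
    # uses fp32 for maximum numerical fidelity.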
    pipe: RollingDepthPipeline = RollingDepthPipeline.from_pretrained(
        args.checkpoint, torch_dtype=dtype
    )  # type: ignore

    try:
        pipe.enable_xformers_memory_efficient_attention()
        logging.info("xformers enabled")
    except ImportError:
        logging.warning("Running without xformers")

    pipe = pipe.to(device)
    # -------------------- Inference and saving --------------------
    with torch.no_grad():
        if args.verbose:
            video_iterable = tqdm(input_video_ls, desc="Processing videos", leave=True)
        else:
            video_iterable = input_video_ls
        for video_path in video_iterable:
            # Random number generator
            if args.seed is None:
                generator = None
            else:
                generator = torch.Generator(device=device)
                generator.manual_seed(args.seed)
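            # Re-seeding per video makes each video's output reproducible
            # regardless of how many videos precede it in the list.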
            # Predict depth
            pipe_out: RollingDepthOutput = pipe(
                # input setting
                input_video_path=video_path,
                start_frame=args.start_frame,
                frame_count=args.frame_count,
                processing_res=args.res,
                resample_method=args.resample_method,
                # infer setting
                dilations=list(args.dilations),
                cap_dilation=args.cap_dilation,
                snippet_lengths=list(args.snippet_lengths),
                init_infer_steps=[1],
                strides=[1],
                coalign_kwargs=None,
                refine_step=args.refine_step,
                refine_snippet_len=args.refine_snippet_len,
                refine_start_dilation=args.refine_start_dilation,
                # other settings
                generator=generator,
                verbose=args.verbose,
                max_vae_bs=args.max_vae_bs,
                # output settings
                restore_res=args.restore_res,
                unload_snippet=args.unload_snippet,
            )
            depth_pred = pipe_out.depth_pred  # [N 1 H W]

            os.makedirs(output_dir, exist_ok=True)

            # Save prediction as npy
            if args.save_npy:
                save_to = output_dir.joinpath(f"{video_path.stem}_pred.npy")
                if args.verbose:
                    logging.info(f"Saving predictions to {save_to}")
                np.save(save_to, depth_pred.numpy().squeeze(1))  # [N H W]
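                # The saved array can be reloaded for downstream analysis, e.g.:
                #   depth = np.load("video_pred.npy")  # shape [N, H, W]; filename is illustrative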
            # Save intermediate snippets
            if args.save_snippets and pipe_out.snippet_ls is not None:
                save_to = output_dir.joinpath(f"{video_path.stem}_snippets.npz")
                if args.verbose:
                    logging.info(f"Saving snippets to {save_to}")
                snippet_dict = {}
                for i_dil, snippets in enumerate(pipe_out.snippet_ls):
                    dilation = args.dilations[i_dil]
                    snippet_dict[f"dilation{dilation}"] = snippets.numpy().squeeze(
                        2
                    )  # [n_snip, snippet_len, H, W]
                np.savez_compressed(save_to, **snippet_dict)
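                # The archive can be reloaded with np.load(save_to); snippets for
                # each dilation rate are stored under keys "dilation<d>", e.g. "dilation1".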
            # Colorize results
            for i_cmap, cmap in enumerate(args.color_maps):
                if "" == cmap:
                    continue
                colored_np = colorize_depth_multi_thread(
                    depth=depth_pred.numpy(),
                    valid_mask=None,
                    chunk_size=4,
                    num_threads=4,
                    color_map=cmap,
                    verbose=args.verbose,
                )  # [n h w 3], in [0, 255]
                save_to = output_dir.joinpath(f"{video_path.stem}_{cmap}.mp4")
                # Output FPS of 0 means "match the input video's frame rate"
                if args.output_fps > 0:
                    output_fps = args.output_fps
                else:
                    output_fps = int(get_video_fps(video_path))
                write_video_from_numpy(
                    frames=colored_np,
                    output_path=save_to,
                    fps=output_fps,
                    crf=23,
                    preset="medium",
                    verbose=args.verbose,
                )
                # Save side-by-side videos
                if args.save_sbs and 0 == i_cmap:
                    rgb = pipe_out.input_rgb * 255  # [N 3 H W]
                    colored_depth = einops.rearrange(
                        torch.from_numpy(colored_np), "n h w c -> n c h w"
                    )
                    concat_video = (
                        concatenate_videos_horizontally_torch(rgb, colored_depth, gap=10)
                        .int()
                        .numpy()
                        .astype(np.uint8)
                    )
                    concat_video = einops.rearrange(concat_video, "n c h w -> n h w c")
                    save_to = output_dir.joinpath(f"{video_path.stem}_rgbd.mp4")
                    write_video_from_numpy(
                        frames=concat_video,
                        output_path=save_to,
                        fps=output_fps,
                        crf=23,
                        preset="medium",
                        verbose=args.verbose,
                    )
    logging.info(
        f"Finished. {len(video_iterable)} predictions are saved to {output_dir}"
    )