# NOTE: web-page residue from the file viewer, kept as a comment so the file parses:
# toshas's picture | initial commit | a45988a | raw | history blame | 20.3 kB
# Copyright 2024 Bingxin Ke, ETH Zurich. All rights reserved.
# Last modified: 2024-11-28
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/RollingDepth#-citation
# More information about the method can be found at https://rollingdepth.github.io
# ---------------------------------------------------------------------------------
import argparse
import logging
import os
from pathlib import Path
import numpy as np
import torch
from tqdm.auto import tqdm
import einops
from omegaconf import OmegaConf
from rollingdepth import (
RollingDepthOutput,
RollingDepthPipeline,
write_video_from_numpy,
get_video_fps,
concatenate_videos_horizontally_torch,
)
from src.util.colorize import colorize_depth_multi_thread
from src.util.config import str2bool
if "__main__" == __name__:
    # Command-line entry point: parse arguments, merge them with the selected
    # preset, load the RollingDepth pipeline, then run it on every input video
    # and save the requested outputs (npy / snippets / colorized videos).
    logging.basicConfig(level=logging.INFO)

    # -------------------- Arguments --------------------
    parser = argparse.ArgumentParser(
        description="Run video depth estimation using RollingDepth."
    )
    parser.add_argument(
        "-i",
        "--input-video",
        type=str,
        required=True,
        help=(
            "Path to the input video(s) to be processed. Accepts: "
            "- Single video file path (e.g., 'video.mp4') "
            "- Text file containing a list of video paths (one per line) "
            "- Directory path containing video files "
            "Required argument."
        ),
        dest="input_video",
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        type=str,
        required=True,
        help=(
            "Directory path where processed outputs will be saved. "
            "Will be created if it doesn't exist. "
            "Required argument."
        ),
        dest="output_dir",
    )
    parser.add_argument(
        "-p",
        "--preset",
        type=str,
        choices=["fast", "fast1024", "full", "paper", "none"],
        # Without a default the value is None when the flag is omitted, and the
        # preset lookup below fails with `KeyError: None`.
        default="none",
        help="Inference preset. TODO: write detailed explanation",
    )
    parser.add_argument(
        "--start-frame",
        "--from",
        type=int,
        default=0,
        help=(
            "Specifies the starting frame index for processing. "
            "Use 0 to start from the beginning of the video. "
            "Default: 0"
        ),
        dest="start_frame",
    )
    parser.add_argument(
        "--frame-count",
        "--frames",
        type=int,
        default=0,
        help=(
            "Number of frames to process after the starting frame. "
            "Set to 0 to process until the end of the video. "
            "Default: 0 (process all frames)"
        ),
        dest="frame_count",
    )
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default="prs-eth/rollingdepth-v1-0",
        help=(
            "Path to the model checkpoint to use for inference. Can be either: "
            "- A local path to checkpoint files "
            "- A Hugging Face model hub identifier (e.g., 'prs-eth/rollingdepth-v1-0') "
            "Default: 'prs-eth/rollingdepth-v1-0'"
        ),
        dest="checkpoint",
    )
    parser.add_argument(
        "--res",
        "--processing-resolution",
        type=int,
        default=None,
        help=(
            "Specifies the maximum resolution (in pixels) at which image processing will be performed. "
            "If set to None, uses the preset configuration value. "
            "If set to 0, processes at the original input image resolution. "
            "Default: None"
        ),
        dest="res",
    )
    parser.add_argument(
        "--max-vae-bs",
        type=int,
        default=4,
        help=(
            "Maximum batch size for the Variational Autoencoder (VAE) processing. "
            "Higher values increase memory usage but may improve processing speed. "
            "Reduce this value if encountering out-of-memory errors. "
            "Default: 4"
        ),
    )

    # Output settings
    parser.add_argument(
        "--fps",
        "--output-fps",
        type=int,
        default=0,
        help=(
            "Frame rate (FPS) for the output video. "
            "Set to 0 to match the input video's frame rate. "
            "Default: 0"
        ),
        dest="output_fps",
    )
    # NOTE: the boolean options below use nargs="?". `const=True` makes a bare
    # flag (e.g. `--restore-res` with no value) mean True; without it argparse
    # yields None, which is rejected by the undefined-argument check.
    parser.add_argument(
        "--restore-resolution",
        "--restore-res",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help=(
            "Whether to restore the output to the original input resolution after processing. "
            "Only applies when input has been resized during processing. "
            "Default: False"
        ),
        dest="restore_res",
    )
    parser.add_argument(
        # Two separate option strings; a missing comma here would concatenate
        # them into the single flag "--save-sbs--save-side-by-side".
        "--save-sbs",
        "--save-side-by-side",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help=(
            "Whether to save RGB and colored depth videos side-by-side. "
            "If True, the first color map will be used. "
            "Default: True"
        ),
        dest="save_sbs",
    )
    parser.add_argument(
        "--save-npy",
        type=str2bool,
        nargs="?",
        const=True,
        default=True,
        help=(
            "Whether to save depth maps as NumPy (.npy) files. "
            "Enables further processing and analysis of raw depth data. "
            "Default: True"
        ),
    )
    parser.add_argument(
        "--save-snippets",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help=(
            "Whether to save visualization snippets of the depth estimation process. "
            "Useful for debugging and quality assessment. "
            "Default: False"
        ),
    )
    parser.add_argument(
        "--cmap",
        "--color-maps",
        type=str,
        nargs="+",
        default=["Spectral_r", "Greys_r"],
        help=(
            "One or more matplotlib color maps for depth visualization. "
            "Multiple maps can be specified for different visualization styles. "
            "Common options: 'Spectral_r', 'Greys_r', 'viridis', 'magma'. "
            "Use '' (empty string) to skip colorization. "
            "Default: ['Spectral_r', 'Greys_r']"
        ),
        dest="color_maps",
    )

    # Inference setting
    parser.add_argument(
        "-d",
        "--dilations",
        type=int,
        nargs="+",
        default=None,
        help=(
            "Spacing between frames for temporal analysis. "
            "Set to None to use preset configurations based on video length. "
            "Custom configurations: "
            "- [1, 10, 25]: Best accuracy, slower processing "
            "- [1, 25]: Balanced speed and accuracy "
            "- [1, 10]: For short videos (<78 frames) "
            "Default: None (auto-select based on video length)"
        ),
        dest="dilations",
    )
    parser.add_argument(
        "--cap-dilation",
        type=str2bool,
        default=None,
        help=(
            "Whether to automatically reduce dilation spacing for short videos. "
            "Set to None to use preset configuration. "
            "Enabling this prevents temporal windows from extending beyond video length. "
            "Default: None (automatically determined based on video length)"
        ),
        dest="cap_dilation",
    )
    parser.add_argument(
        "--dtype",
        "--data-type",
        type=str,
        choices=["fp16", "fp32", None],
        default=None,
        help=(
            "Specifies the floating-point precision for inference operations. "
            "Options: 'fp16' (16-bit), 'fp32' (32-bit), or None. "
            "If None, uses the preset configuration value. "
            "Lower precision (fp16) reduces memory usage but may affect accuracy. "
            "Default: None"
        ),
        dest="dtype",
    )
    parser.add_argument(
        "--snip-len",
        "--snippet-lengths",
        type=int,
        nargs="+",
        choices=[2, 3, 4],
        default=None,
        help=(
            "Number of consecutive frames to analyze in each temporal window. "
            "Set to None to use preset value (3). "
            "Can specify multiple values corresponding to different dilation rates. "
            "Example: '--dilations 1 25 --snippet-length 2 3' uses "
            "2 frames for dilation 1 and 3 frames for dilation 25. "
            "Allowed values: 2, 3, or 4 frames. "
            "Default: None"
        ),
        dest="snippet_lengths",
    )
    parser.add_argument(
        "--refine-step",
        type=int,
        default=None,
        help=(
            "Number of refinement iterations to improve depth estimation accuracy. "
            "Set to None to use preset configuration. "
            "Set to 0 to disable refinement. "
            "Higher values may improve accuracy but increase processing time. "
            "Default: None (uses 0, no refinement)"
        ),
        dest="refine_step",
    )
    parser.add_argument(
        "--refine-snippet-len",
        type=int,
        default=None,
        help=(
            # The previous wording described "text snippets" and "sentences",
            # which does not apply to a video depth tool.
            "Snippet length (number of frames per snippet) used during the refinement phase. "
            "If not specified (None), the preset value will be used. "
            "Default: None"
        ),
    )
    parser.add_argument(
        "--refine-start-dilation",
        type=int,
        default=None,
        help=(
            "Initial dilation factor for the coarse-to-fine refinement process. "
            "Controls the starting granularity of the refinement steps. "
            "Higher values result in larger initial search windows. "
            "If not specified (None), uses system default. "
            "Default: None"
        ),
    )

    # Other settings
    parser.add_argument(
        "--resample-method",
        type=str,
        choices=["BILINEAR", "NEAREST_EXACT", "BICUBIC"],
        default="BILINEAR",
        help="Resampling method used to resize images.",
    )
    parser.add_argument(
        "--unload-snippet",
        type=str2bool,
        default=False,
        help=(
            "Controls memory optimization by moving processed data snippets to CPU. "
            "When enabled, reduces GPU memory usage at the cost of slower processing. "
            "Useful for systems with limited GPU memory or large datasets. "
            "Default: False"
        ),
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help=("Enable detailed progress and information reporting during processing. "),
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help=(
            "Random number generator seed for reproducibility (up to computational randomness). "
            "Using the same seed value will produce identical results across runs. "
            "If not specified (None), a random seed will be used. "
            "Default: None"
        ),
    )

    # -------------------- Config preset arguments --------------------
    input_args = parser.parse_args()

    # Base configuration shared by all presets; presets and explicit CLI values
    # are layered on top of it below.
    args = OmegaConf.create(
        {
            "res": 768,
            "snippet_lengths": [3],
            "cap_dilation": True,
            "dtype": "fp16",
            "refine_snippet_len": 3,
            "refine_start_dilation": 6,
        }
    )

    # Keys must match the `--preset` choices (except "none", which adds nothing).
    preset_args_dict = {
        "fast": OmegaConf.create(
            {
                "dilations": [1, 25],
                "refine_step": 0,
            }
        ),
        "fast1024": OmegaConf.create(
            {
                "res": 1024,
                "dilations": [1, 25],
                "refine_step": 0,
            }
        ),
        "full": OmegaConf.create(
            {
                "res": 1024,
                "dilations": [1, 10, 25],
                "refine_step": 10,
            }
        ),
        "paper": OmegaConf.create(
            {
                "dilations": [1, 10, 25],
                "cap_dilation": False,
                "dtype": "fp32",
                "refine_step": 10,
            }
        ),
    }
    if "none" != input_args.preset:
        logging.info(f"Using preset: {input_args.preset}")
        args.update(preset_args_dict[input_args.preset])

    # Merge or overwrite arguments
    for key, value in vars(input_args).items():
        if key in args.keys():
            # overwrite if value is set and different from preset
            if value is not None and value != args[key]:
                logging.warning(f"Overwritting argument: {key} = {value}")
                args[key] = value
        else:
            # add argument
            args[key] = value
            # Sanity check: anything not covered by the base config / preset
            # must be given explicitly (only the seed may stay None).
            # parser.error is used instead of assert so the check survives -O.
            if value is None and key not in ["seed"]:
                parser.error(f"Undefined argument: {key}")

    msg = f"arguments: {args}"
    if args.verbose:
        logging.info(msg)
    else:
        logging.debug(msg)

    # Argument check
    if args.save_sbs and not len(args.color_maps) > 0:
        parser.error("No color map is given, can not save side-by-side videos.")

    input_video = Path(args.input_video)
    output_dir = Path(args.output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # -------------------- Device --------------------
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        logging.warning("CUDA is not available. Running on CPU will be slow.")
    logging.info(f"device = {device}")

    # -------------------- Data --------------------
    if input_video.is_dir():
        # Every directory entry is treated as a video path.
        input_video_ls = [
            input_video.joinpath(v_name) for v_name in os.listdir(input_video)
        ]
    elif ".txt" == input_video.suffix:
        # One path per line; blank lines are skipped so they do not turn into
        # Path("") (i.e. the current directory).
        with open(input_video, "r") as f:
            input_video_ls = [Path(s.strip()) for s in f if s.strip()]
    else:
        input_video_ls = [Path(input_video)]
    input_video_ls = sorted(input_video_ls)
    logging.info(f"Found {len(input_video_ls)} videos.")

    # -------------------- Model --------------------
    if "fp16" == args.dtype:
        dtype = torch.float16
    elif "fp32" == args.dtype:
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    pipe: RollingDepthPipeline = RollingDepthPipeline.from_pretrained(
        args.checkpoint, torch_dtype=dtype
    )  # type: ignore

    # xformers is optional: fall back silently (with a warning) if not installed
    try:
        pipe.enable_xformers_memory_efficient_attention()
        logging.info("xformers enabled")
    except ImportError:
        logging.warning("Run without xformers")

    pipe = pipe.to(device)

    # -------------------- Inference and saving --------------------
    with torch.no_grad():
        if args.verbose:
            video_iterable = tqdm(input_video_ls, desc="Processing videos", leave=True)
        else:
            video_iterable = input_video_ls

        for video_path in video_iterable:
            # Random number generator (re-seeded per video so each video gets
            # the same starting state)
            if args.seed is None:
                generator = None
            else:
                generator = torch.Generator(device=device)
                generator.manual_seed(args.seed)

            # Predict depth
            pipe_out: RollingDepthOutput = pipe(
                # input setting
                input_video_path=video_path,
                start_frame=args.start_frame,
                frame_count=args.frame_count,
                processing_res=args.res,
                resample_method=args.resample_method,
                # infer setting
                dilations=list(args.dilations),
                cap_dilation=args.cap_dilation,
                snippet_lengths=list(args.snippet_lengths),
                init_infer_steps=[1],
                strides=[1],
                coalign_kwargs=None,
                refine_step=args.refine_step,
                refine_snippet_len=args.refine_snippet_len,
                refine_start_dilation=args.refine_start_dilation,
                # other settings
                generator=generator,
                verbose=args.verbose,
                max_vae_bs=args.max_vae_bs,
                # output settings
                restore_res=args.restore_res,
                unload_snippet=args.unload_snippet,
            )

            depth_pred = pipe_out.depth_pred  # [N 1 H W]

            os.makedirs(output_dir, exist_ok=True)

            # Save prediction as npy
            if args.save_npy:
                save_to = output_dir.joinpath(f"{video_path.stem}_pred.npy")
                if args.verbose:
                    logging.info(f"Saving predictions to {save_to}")
                np.save(save_to, depth_pred.numpy().squeeze(1))  # [N H W]

            # Save intermediate snippets
            if args.save_snippets and pipe_out.snippet_ls is not None:
                save_to = output_dir.joinpath(f"{video_path.stem}_snippets.npz")
                if args.verbose:
                    logging.info(f"Saving snippets to {save_to}")
                snippet_dict = {}
                for i_dil, snippets in enumerate(pipe_out.snippet_ls):
                    dilation = args.dilations[i_dil]
                    snippet_dict[f"dilation{dilation}"] = snippets.numpy().squeeze(
                        2
                    )  # [n_snip, snippet_len, H W]
                np.savez_compressed(save_to, **snippet_dict)

            # Resolve the output frame rate once per video: an explicit --fps
            # wins, otherwise (0) match the input video. This resolved value,
            # not the raw CLI value, must be handed to the video writers.
            if args.output_fps > 0:
                output_fps = args.output_fps
            else:
                output_fps = int(get_video_fps(video_path))

            # Colorize results
            for i_cmap, cmap in enumerate(args.color_maps):
                if "" == cmap:
                    continue
                colored_np = colorize_depth_multi_thread(
                    depth=depth_pred.numpy(),
                    valid_mask=None,
                    chunk_size=4,
                    num_threads=4,
                    color_map=cmap,
                    verbose=args.verbose,
                )  # [n h w 3], in [0, 255]
                save_to = output_dir.joinpath(f"{video_path.stem}_{cmap}.mp4")
                write_video_from_numpy(
                    frames=colored_np,
                    output_path=save_to,
                    fps=output_fps,
                    crf=23,
                    preset="medium",
                    verbose=args.verbose,
                )

                # Save side-by-side videos (only for the first color map)
                if args.save_sbs and 0 == i_cmap:
                    rgb = pipe_out.input_rgb * 255  # [N 3 H W]
                    colored_depth = einops.rearrange(
                        torch.from_numpy(colored_np), "n h w c -> n c h w"
                    )
                    concat_video = (
                        concatenate_videos_horizontally_torch(rgb, colored_depth, gap=10)
                        .int()
                        .numpy()
                        .astype(np.uint8)
                    )
                    concat_video = einops.rearrange(concat_video, "n c h w -> n h w c")
                    save_to = output_dir.joinpath(f"{video_path.stem}_rgbd.mp4")
                    write_video_from_numpy(
                        frames=concat_video,
                        output_path=save_to,
                        fps=output_fps,
                        crf=23,
                        preset="medium",
                        verbose=args.verbose,
                    )

    # Count the input list itself rather than the (possibly tqdm-wrapped) iterable.
    logging.info(
        f"Finished. {len(input_video_ls)} predictions are saved to {output_dir}"
    )