| """
|
| Batch VDPM Inference (CLI version of Gradio Demo)
|
|
|
| This script replicates the exact logic of gradio_demo.py but for command-line usage.
|
| It supports processing a folder of video files (treated as synchronized multi-view input)
|
| or a single video file.
|
|
|
| Usage:
|
| python vdpm/infer.py --input path/to/videos_folder --output output/
|
| python vdpm/infer.py --input path/to/video.mp4 --output output/
|
| """
|
|
|
| import os
|
| import sys
|
| import glob
|
| import json
|
| import argparse
|
| import time
|
| import shutil
|
| import gc
|
| from pathlib import Path
|
| from datetime import datetime
|
|
|
| import cv2
|
| import numpy as np
|
| import torch
|
| from hydra import compose, initialize
|
| from hydra.core.global_hydra import GlobalHydra
|
|
|
|
|
| os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
|
|
|
|
| sys.path.insert(0, str(Path(__file__).parent))
|
|
|
| from dpm.model import VDPM
|
| from vggt.utils.load_fn import load_and_preprocess_images
|
| from util.depth import write_depth_to_png
|
|
|
|
|
|
|
|
|
|
|
# Frame-sampling rate (Hz) used when extracting frames from input videos.
VIDEO_SAMPLE_HZ = 1.0

# Precision switches: quantization, when enabled, takes precedence over
# half precision (see load_model).
USE_HALF_PRECISION = True
USE_QUANTIZATION = False

device = "cuda" if torch.cuda.is_available() else "cpu"

# Conservative default frame budget; raised below when a large-VRAM GPU is found.
MAX_FRAMES = 5

if device == "cuda":
    # Allow TF32 matmuls/convolutions for faster inference on Ampere+.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    vram_bytes = torch.cuda.get_device_properties(0).total_memory
    vram_gb = vram_bytes / (1024**3)

    print(f"✓ GPU Detected: {torch.cuda.get_device_name(0)} ({vram_gb:.1f} GB VRAM)")

    # (min GB, frame budget, message template) — first matching tier wins.
    _VRAM_TIERS = (
        (22.0, 80, " -> High VRAM detected! Increased MAX_FRAMES to {}"),
        (14.0, 16, " -> Medium VRAM detected! Increased MAX_FRAMES to {}"),
        (7.5, 8, " -> 8GB VRAM detected. Set MAX_FRAMES to {}"),
        (float("-inf"), 5, " -> Low VRAM (<8GB). Keeping MAX_FRAMES at {} to prevent OOM"),
    )
    for _min_gb, _budget, _message in _VRAM_TIERS:
        if vram_gb >= _min_gb:
            MAX_FRAMES = _budget
            print(_message.format(MAX_FRAMES))
            break
|
|
|
def require_cuda():
    """Raise ValueError unless a CUDA device was detected at startup."""
    if device == "cuda":
        return
    raise ValueError("CUDA is not available. Check your environment.")
|
|
|
|
|
def decode_poses(pose_enc: np.ndarray, image_hw: tuple) -> tuple:
    """Decode VGGT pose encodings into camera matrices.

    Args:
        pose_enc: Pose encodings with a leading batch axis, shape (1, N, D).
        image_hw: (H, W) of the images the poses correspond to.

    Returns:
        A pair ``(extrinsics, intrinsics)``: (N, 4, 4) world-to-camera
        matrices and (N, 3, 3) intrinsics. When vggt is not importable,
        identity extrinsics and a simple pinhole intrinsic guess
        (focal = max(H, W), principal point at the image center) are
        returned instead.
    """
    try:
        from vggt.utils.pose_enc import pose_encoding_to_extri_intri
    except ImportError:
        print("Warning: vggt not available. Using identity poses.")
        num_cams = pose_enc.shape[1]
        eye_poses = np.tile(np.eye(4, dtype=np.float32), (num_cams, 1, 1))

        h, w = image_hw
        focal = max(h, w)
        guess = np.array(
            [[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]],
            dtype=np.float32,
        )
        return eye_poses, np.tile(guess, (num_cams, 1, 1))

    extri, intri = pose_encoding_to_extri_intri(
        torch.from_numpy(pose_enc).float(), image_hw
    )
    extri = extri[0].numpy()
    intri = intri[0].numpy()

    # Promote (N, 3, 4) extrinsics to homogeneous (N, 4, 4) matrices.
    num_cams = extri.shape[0]
    last_row = np.tile(
        np.array([0, 0, 0, 1], dtype=np.float32).reshape(1, 1, 4),
        (num_cams, 1, 1),
    )
    return np.concatenate([extri, last_row], axis=1), intri
|
|
|
|
|
def compute_depths(world_points: np.ndarray, extrinsics: np.ndarray, num_views: int) -> np.ndarray:
    """
    Compute depth maps from world points and camera extrinsics.

    Args:
        world_points: (T, V, H, W, 3) world-space 3D points
        extrinsics: (N, 4, 4) camera extrinsics (world-to-camera)
        num_views: Number of camera views

    Returns:
        depths: (T, V, H, W) depth maps (Z in camera coordinates)
    """
    T, V, H, W, _ = world_points.shape
    depths = np.zeros((T, V, H, W), dtype=np.float32)

    for t in range(T):
        for v in range(V):
            # Pick the extrinsic for this (timestep, view); if the pose list
            # is shorter than T*V, fall back to the first timestep's camera.
            cam_idx = t * num_views + v
            if cam_idx >= len(extrinsics):
                cam_idx = v

            w2c = extrinsics[cam_idx]

            # Camera-frame Z is the third row of the rigid transform applied
            # to each point: z = R[2] . p + t[2]; no need to materialize the
            # full transformed point cloud.
            flat_pts = world_points[t, v].reshape(-1, 3)
            z_cam = flat_pts @ w2c[2, :3] + w2c[2, 3]
            depths[t, v] = z_cam.reshape(H, W)

    return depths
|
|
|
| def load_cfg_from_cli() -> "omegaconf.DictConfig":
|
| if GlobalHydra.instance().is_initialized():
|
| GlobalHydra.instance().clear()
|
|
|
| with initialize(config_path="configs"):
|
| return compose(config_name="visualise")
|
|
|
def load_model(cfg) -> VDPM:
    """Build a VDPM model, load (cached or downloaded) weights, and optimize it.

    Args:
        cfg: Hydra config (as produced by load_cfg_from_cli) passed to VDPM.

    Returns:
        The model in eval mode on `device`, optionally converted to
        BF16/FP16, INT8-quantized, and/or wrapped by torch.compile,
        depending on the USE_HALF_PRECISION / USE_QUANTIZATION globals.
    """
    model = VDPM(cfg).to(device)

    # Weights are cached under ~/.cache/vdpm so repeat runs skip the download.
    cache_dir = os.path.expanduser("~/.cache/vdpm")
    os.makedirs(cache_dir, exist_ok=True)
    model_path = os.path.join(cache_dir, "vdpm_model.pt")

    _URL = "https://huggingface.co/edgarsucar/vdpm/resolve/main/model.pt"

    if not os.path.exists(model_path):
        print(f"Downloading model to {model_path}...")
        sd = torch.hub.load_state_dict_from_url(
            _URL,
            file_name="vdpm_model.pt",
            progress=True,
            map_location=device
        )
        # Re-save under our own cache path so the exists() check above hits
        # on subsequent runs.
        torch.save(sd, model_path)
        print(f"✓ Model cached at {model_path}")
    else:
        print(f"✓ Loading cached model from {model_path}")
        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects; acceptable for a cache file we wrote ourselves,
        # but worth confirming.
        sd = torch.load(model_path, map_location=device)

    # strict=True ensures the checkpoint exactly matches the architecture;
    # the returned keys report is printed for visibility.
    print(model.load_state_dict(sd, strict=True))

    model.eval()

    if USE_HALF_PRECISION and not USE_QUANTIZATION:
        # BF16 requires compute capability >= 8 (Ampere+); otherwise fall
        # back to FP16.
        if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
            print("Converting model to BF16 precision...")
            model = model.to(torch.bfloat16)
        else:
            print("Converting model to FP16 precision...")
            model = model.half()

    if USE_QUANTIZATION:
        try:
            print("Applying INT8 dynamic quantization...")
            # Dynamic quantization runs on CPU; move the model off the GPU
            # first, then back onto `device` afterwards.
            model = model.cpu()
            model = torch.quantization.quantize_dynamic(
                model,
                {torch.nn.Linear, torch.nn.Conv2d},
                dtype=torch.qint8
            )
            model = model.to(device)
        except Exception as e:
            # Best-effort: fall back to the unquantized model on `device`.
            print(f"⚠️ Quantization failed: {e}")
            model = model.to(device)

    if not USE_QUANTIZATION:
        try:
            print("Compiling model with torch.compile...")
            model = torch.compile(model, mode="reduce-overhead")
        except Exception as e:
            # torch.compile is an optimization only; keep the eager model.
            print(f"Warning: torch.compile failed: {e}")

    return model
|
|
|
|
|
|
|
|
|
|
|
def process_videos_interleaved(input_video_list, target_dir_images):
    """
    Extract frames from multiple videos in a synchronized/interleaved manner.
    Matches handle_uploads logic from gradio_demo.py.

    One frame is read from every still-open video per step, so the numbered
    PNGs stay interleaved view-by-view over time.

    Args:
        input_video_list: Paths of the input videos (one per camera view).
        target_dir_images: Directory that receives the numbered PNG frames.

    Returns:
        List of written image paths, in interleaved (time, view) order.
    """
    saved_paths = []
    next_frame_idx = 0

    readers = []
    sample_intervals = []

    for pos, video_path in enumerate(input_video_list):
        print(f"Preparing video {pos+1}/{len(input_video_list)}: {video_path}")

        cap = cv2.VideoCapture(video_path)
        fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
        if fps <= 0:
            fps = 30.0  # some containers report 0 fps; assume a sane default

        readers.append(cap)
        # Keep one frame every `interval` steps to hit VIDEO_SAMPLE_HZ.
        sample_intervals.append(max(int(fps / max(VIDEO_SAMPLE_HZ, 1e-6)), 1))

    print("Processing videos in interleaved mode...")
    step = 0

    while True:
        any_alive = False
        for pos, cap in enumerate(readers):
            if not cap.isOpened():
                continue

            ok, frame = cap.read()
            if not ok:
                # This video is exhausted; release it so later passes skip it.
                cap.release()
                continue

            any_alive = True
            if step % sample_intervals[pos] == 0:
                out_path = os.path.join(target_dir_images, f"{next_frame_idx:06}.png")
                cv2.imwrite(out_path, frame)
                saved_paths.append(out_path)
                next_frame_idx += 1

        step += 1
        if not any_alive:
            break

    return saved_paths
|
|
|
|
|
def run_model(target_dir: str, model: VDPM, frame_id_arg=0) -> dict:
    """Run VDPM inference on the frames in ``target_dir/images``.

    Loads and preprocesses the extracted frames, runs the model, then writes:
    tracks.npz (world points + confidences), poses.npz (when the model emits
    pose encodings), depths.npz plus per-frame depth PNGs (when poses are
    available), and the combined output_4d.npz.

    Args:
        target_dir: Directory containing ``images/`` and an optional meta.json
            with the number of synchronized views.
        model: Loaded VDPM model (see load_model).
        frame_id_arg: Unused; kept for interface compatibility with the
            Gradio demo callback signature.

    Returns:
        Dict with "tracks_path", "output_path" and "depths_path"
        ("depths_path" is None when no pose encodings were produced).

    Raises:
        ValueError: If CUDA is unavailable or no images are found.
    """
    require_cuda()

    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    if not image_names:
        raise ValueError("No images found in target_dir.")

    # meta.json records how many synchronized camera views were provided;
    # default to a single view when it is missing or unreadable.
    meta_path = os.path.join(target_dir, "meta.json")
    num_views = 1
    if os.path.exists(meta_path):
        try:
            with open(meta_path, 'r') as f:
                num_views = json.load(f).get("num_views", 1)
        except (OSError, ValueError):
            # Narrowed from a bare except: only I/O failures and JSON decode
            # errors (json.JSONDecodeError subclasses ValueError) should fall
            # back to the single-view default.
            pass

    # Trim to a whole number of timesteps that fits the GPU frame budget.
    if len(image_names) > MAX_FRAMES:
        limit = (MAX_FRAMES // num_views) * num_views
        if limit == 0:
            limit = num_views
            print(f"⚠️ Warning: MAX_FRAMES={MAX_FRAMES} is smaller than num_views={num_views}. Processing 1 full timestep anyway.")

        print(f"⚠️ Limiting to {limit} frames ({limit // num_views} timesteps * {num_views} views) to fit in GPU memory")
        image_names = image_names[:limit]

    print(f"Loading {len(image_names)} images...")
    images = load_and_preprocess_images(image_names).to(device)

    if device == "cuda":
        print(f"GPU memory before inference: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

    print(f"Running inference on {len(image_names)} images ({num_views} synchronized views)...")

    # Build one view dict per image, tagging each with its (camera, timestep)
    # pair; images are interleaved view-by-view per timestep.
    views = []
    for i in range(len(image_names)):
        t_idx = i // num_views
        cam_idx = i % num_views
        views.append({
            "img": images[i].unsqueeze(0),
            "view_idxs": torch.tensor([[cam_idx, t_idx]], device=device, dtype=torch.long)
        })

    inference_start = time.time()

    with torch.no_grad():
        with torch.amp.autocast('cuda'):
            predictions = model.inference(views=views)

    inference_time = time.time() - inference_start
    print(f"✓ Inference completed in {inference_time:.2f}s ({inference_time/len(image_names):.2f}s per frame)")

    pts_list = [pm["pts3d"].detach().cpu().numpy() for pm in predictions["pointmaps"]]
    conf_list = [pm["conf"].detach().cpu().numpy() for pm in predictions["pointmaps"]]

    pose_enc = None
    if "pose_enc" in predictions:
        pose_enc = predictions["pose_enc"].detach().cpu().numpy()

    # Release GPU memory before the numpy-side post-processing below.
    del predictions
    if device == "cuda":
        torch.cuda.empty_cache()

    world_points_raw = np.concatenate(pts_list, axis=0)
    world_points_conf_raw = np.concatenate(conf_list, axis=0)

    T = world_points_raw.shape[0]
    S = world_points_raw.shape[1]
    num_timesteps = T

    if num_views > 1 and S == num_views * T:
        # Multi-view: timestep t owns the contiguous slice
        # [t*num_views, (t+1)*num_views) along the second axis.
        print(f"DEBUG: Multi-view mode - extracting ALL {num_views} views")
        world_points_list = []
        world_points_conf_list = []
        for t in range(T):
            start_idx = t * num_views
            end_idx = start_idx + num_views
            world_points_list.append(world_points_raw[t, start_idx:end_idx])
            world_points_conf_list.append(world_points_conf_raw[t, start_idx:end_idx])

        world_points_mv = np.stack(world_points_list, axis=0)
        world_points_conf_mv = np.stack(world_points_conf_list, axis=0)

        world_points_full = world_points_mv
        world_points_conf_full = world_points_conf_mv
    else:
        # Single-view / fallback handling of the raw prediction layout.
        if world_points_raw.ndim == 5 and world_points_raw.shape[0] == 1:
            world_points = world_points_raw[0]
            world_points_conf = world_points_conf_raw[0]
        elif world_points_raw.ndim == 5:
            # Take the diagonal (t, t) entries: each timestep's own pointmap.
            world_points_list = []
            world_points_conf_list = []
            for t in range(min(T, S)):
                world_points_list.append(world_points_raw[t, t])
                world_points_conf_list.append(world_points_conf_raw[t, t])
            world_points = np.stack(world_points_list, axis=0)
            world_points_conf = np.stack(world_points_conf_list, axis=0)
        else:
            world_points = world_points_raw
            world_points_conf = world_points_conf_raw

        world_points_full = world_points
        world_points_conf_full = world_points_conf

    tracks_path = os.path.join(target_dir, "tracks.npz")
    print(f"Saving tracks (clean) to {tracks_path}")
    np.savez_compressed(
        tracks_path,
        world_points=world_points_full,
        world_points_conf=world_points_conf_full,
        num_views=num_views,
        num_timesteps=num_timesteps
    )

    if pose_enc is not None:
        poses_path = os.path.join(target_dir, "poses.npz")
        print(f"Saving poses to {poses_path}")
        np.savez_compressed(poses_path, pose_enc=pose_enc)

    depths = None
    if pose_enc is not None:
        print("Computing depth maps from world points and camera poses...")

        # Normalize world points to (T, V, H, W, 3) for compute_depths.
        if world_points_full.ndim == 5:
            _, _, H, W, _ = world_points_full.shape
        elif world_points_full.ndim == 4:
            _, H, W, _ = world_points_full.shape
            world_points_full = world_points_full[:, np.newaxis, :, :, :]
        else:
            H, W = 518, 518
            print(f"Warning: Unexpected world_points shape {world_points_full.shape}")

        # Intrinsics are not needed for Z-depth; only the extrinsics are used.
        extrinsics, _intrinsics = decode_poses(pose_enc, (H, W))
        depths = compute_depths(world_points_full, extrinsics, num_views)

        depths_path = os.path.join(target_dir, "depths.npz")
        print(f"Saving depths to {depths_path}")
        np.savez_compressed(
            depths_path,
            depths=depths,
            num_views=num_views,
            num_timesteps=num_timesteps
        )

        depths_dir = os.path.join(target_dir, "depths")
        os.makedirs(depths_dir, exist_ok=True)
        print(f"Saving depth images to {depths_dir}/")

        T_depth = depths.shape[0]
        V_depth = depths.shape[1]
        for t in range(T_depth):
            for v in range(V_depth):
                depth_map = depths[t, v]
                png_path = os.path.join(depths_dir, f"depth_t{t:04d}_v{v:02d}.png")
                write_depth_to_png(png_path, depth_map)

        print(f"✓ Saved {T_depth * V_depth} depth images")
    else:
        print("⚠ No pose encodings available - skipping depth computation")

    # Combined archive with everything downstream consumers need.
    output_path = os.path.join(target_dir, "output_4d.npz")
    save_dict = {
        "world_points": world_points_full,
        "world_points_conf": world_points_conf_full,
        "timestamps": np.arange(num_timesteps),
        "num_views": num_views,
        "num_timesteps": num_timesteps
    }
    if depths is not None:
        save_dict["depths"] = depths
    np.savez_compressed(output_path, **save_dict)

    return {
        "tracks_path": tracks_path,
        "output_path": output_path,
        "depths_path": os.path.join(target_dir, "depths.npz") if depths is not None else None
    }
|
|
|
def main():
    """CLI entry point: gather input videos, extract frames, run inference."""
    parser = argparse.ArgumentParser(description="Run VDPM Inference (CLI)")
    parser.add_argument("--input", required=True, help="Input video file or folder containing videos")
    parser.add_argument("--output", required=True, help="Output directory")
    parser.add_argument("--name", help="Optional name for the reconstruction folder")
    args = parser.parse_args()

    input_path = Path(args.input)
    output_root = Path(args.output)

    # Resolve the input into a list of video paths: either a single file, or
    # every video found directly inside a folder.
    if input_path.is_file():
        videos = [str(input_path)]
    elif input_path.is_dir():
        # Deduplicate via absolute paths so case-insensitive filesystems do
        # not return the same file for both *.mp4 and *.MP4.
        discovered = set()
        for pattern in ['*.mp4', '*.mov', '*.avi', '*.mkv']:
            hits = glob.glob(str(input_path / pattern)) + glob.glob(str(input_path / pattern.upper()))
            for hit in hits:
                discovered.add(os.path.abspath(hit))

        videos = sorted(discovered)
        if not videos:
            print(f"No videos found in {input_path}")
            return
        print(f"Found {len(videos)} videos in {input_path}")
    else:
        print(f"Input {input_path} not found")
        return

    # Reconstruction folder: user-supplied name, or a timestamped default.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    folder_name = args.name or f"reconstruction_{timestamp}"
    target_dir = output_root / folder_name
    target_dir_images = target_dir / "images"

    if target_dir.exists():
        print(f"Cleaning existing output dir: {target_dir}")
        shutil.rmtree(target_dir)
    target_dir_images.mkdir(parents=True, exist_ok=True)

    # Extract frames from all videos in interleaved (time, view) order.
    process_videos_interleaved(videos, str(target_dir_images))

    # Record the view count so run_model can group frames per timestep.
    num_views = len(videos)
    with open(target_dir / "meta.json", "w") as f:
        json.dump({"num_views": num_views}, f)
    print(f"Metadata saved: {num_views} view(s)")

    print("Loading model...")
    model = load_model(load_cfg_from_cli())

    print("Running inference...")
    run_model(str(target_dir), model)

    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Success! Output saved to:\n{target_dir}")
    print("Next step: Train Gaussian Splats using:")
    print(f"python gs/train_vdpm.py --input {target_dir} --output output/splats")
    print(banner)
|
| if __name__ == "__main__":
|
| main()
|
|
|