| | """ |
| | Pre-Processing Pipeline: Compute BA and oracle uncertainty offline. |
| | |
| | This module handles the offline preprocessing phase that runs OUTSIDE the training |
| | loop to pre-compute expensive operations: |
| | - BA validation (CPU, expensive, slow) |
| | - Oracle uncertainty propagation (CPU, moderate) |
| | - Oracle target selection (BA vs ARKit) |
| | |
| | Results are cached to disk and loaded during training for fast iteration. |
| | |
| | Key Design: |
| | The training pipeline is split into two phases: |
| | 1. **Pre-Processing Phase** (offline, expensive): Compute BA and oracle uncertainty |
| | 2. **Training Phase** (online, fast): Load pre-computed results and train |
| | |
| | This separation allows: |
| | - BA computation outside training loop (can be parallelized) |
| | - Reuse of expensive computations across training runs |
| | - Continuous confidence weighting (not binary rejection) |
| | - Efficient training iteration (100-1000x faster) |
| | |
| | See `docs/TRAINING_PIPELINE_ARCHITECTURE.md` for detailed architecture. |
| | """ |
| |
|
| | import json |
| | import logging |
| | from pathlib import Path |
| | from typing import Dict, Optional |
| | import numpy as np |
| |
|
| | from ..utils.oracle_uncertainty import OracleUncertaintyPropagator |
| | from .arkit_processor import ARKitProcessor |
| | from .ba_validator import BAValidator |
| |
|
# Module-scoped logger: records are namespaced under this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
| |
|
| |
|
def _stitch_chunk_poses(chunk_poses, anchor_pose):
    """Re-express one inference chunk's poses so its first pose equals ``anchor_pose``.

    Every DA3 chunk is solved in its own reference frame. Consecutive chunks
    overlap by one frame, so the first pose of a new chunk and the last
    (already stitched) pose of the previous chunk describe the same camera.
    The rigid correction that maps one onto the other is applied to the whole
    chunk. Scale differences between chunks are not corrected.

    Args:
        chunk_poses: Per-frame poses of the chunk. Each entry is written into the
            top three rows of a 4x4 matrix, so 3x4 entries are expected.
            Modified in place.
        anchor_pose: 3x4 pose of the overlapping frame from the previous chunk.

    Returns:
        ``chunk_poses`` (the same object), expressed in the anchor's frame.
    """
    anchor = np.eye(4)
    anchor[:3, :] = anchor_pose
    chunk_start = np.eye(4)
    chunk_start[:3, :] = chunk_poses[0]

    # NOTE(review): left-multiplying the correction is the composition rule for
    # camera-to-world poses. If the model's extrinsics are world-to-camera, the
    # correction belongs on the right instead. Confirm the DA3 convention.
    correction = anchor @ np.linalg.inv(chunk_start)

    for j in range(len(chunk_poses)):
        pose = np.eye(4)
        pose[:3, :] = chunk_poses[j]
        chunk_poses[j] = (correction @ pose)[:3, :]
    return chunk_poses


def _run_chunked_da3_inference(model, images, arkit_poses_c2w, intrinsics, stitch_poses):
    """Run DA3 over overlapping windows of ``images`` and merge the per-frame outputs.

    Windows of 8 frames advance by 7, so consecutive windows share one frame.
    That shared frame is dropped from every window after the first, which means
    each frame contributes exactly once to the merged result.

    Args:
        model: DA3 model whose ``inference(images, extrinsics=, intrinsics=)``
            result exposes ``depth`` and ``extrinsics`` and optionally ``intrinsics``.
        images: Sliceable sequence of frames.
        arkit_poses_c2w: Optional per-frame ARKit poses, sliced per window and
            passed to the model as ``extrinsics``.
        intrinsics: Optional per-frame intrinsics, sliced per window.
        stitch_poses: If True, chain each window's poses onto the previous one
            through the shared frame (needed when no common ARKit world frame exists).

    Returns:
        ``(depth, poses, intrinsics)`` concatenated along the frame axis.
        ``intrinsics`` is None when the model returned none.
    """
    import torch

    batch_size = 8  # frames per inference window
    overlap = 1  # frames shared with the previous window; stitching relies on exactly one
    step = batch_size - overlap
    num_frames = len(images)

    depth_chunks = []
    pose_chunks = []
    intrinsic_chunks = []
    last_pose = None

    for start in range(0, num_frames, step):
        end = min(start + batch_size, num_frames)
        chunk_images = images[start:end]
        # A trailing window holding only the already-seen overlap frame adds nothing.
        if len(chunk_images) < 2 and start > 0:
            break

        chunk_extrinsics = arkit_poses_c2w[start:end] if arkit_poses_c2w is not None else None
        chunk_intrinsics = intrinsics[start:end] if intrinsics is not None else None

        # NOTE(review): camera-to-world ARKit poses are passed as ``extrinsics``.
        # Confirm this matches the convention model.inference expects.
        with torch.no_grad():
            output = model.inference(
                chunk_images,
                extrinsics=chunk_extrinsics,
                intrinsics=chunk_intrinsics,
            )

        chunk_depth = output.depth
        chunk_poses = output.extrinsics
        chunk_ix = getattr(output, "intrinsics", None)

        if stitch_poses and last_pose is not None:
            chunk_poses = _stitch_chunk_poses(chunk_poses, last_pose)

        # Drop the frame(s) already emitted by the previous window.
        skip = overlap if start > 0 else 0
        depth_chunks.append(chunk_depth[skip:])
        pose_chunks.append(chunk_poses[skip:])
        if chunk_ix is not None:
            intrinsic_chunks.append(chunk_ix[skip:])

        last_pose = chunk_poses[-1]
        if end == num_frames:
            break

    depth = np.concatenate(depth_chunks, axis=0)
    poses = np.concatenate(pose_chunks, axis=0)
    merged_intrinsics = np.concatenate(intrinsic_chunks, axis=0) if intrinsic_chunks else None
    return depth, poses, merged_intrinsics


def preprocess_arkit_sequence(
    arkit_dir: Path,
    output_cache_dir: Path,
    model,
    ba_validator: BAValidator,
    oracle_propagator: OracleUncertaintyPropagator,
    device: str = "cuda",
    prefer_arkit_poses: bool = True,
    min_arkit_quality: float = 0.8,
    use_lidar: bool = True,
    use_ba_depth: bool = False,
) -> Dict:
    """
    Pre-process a single ARKit sequence: compute BA and oracle uncertainty.

    This runs OUTSIDE the training loop and can be parallelized across sequences.
    Expensive operations are computed once and cached for fast training iteration.

    Processing Steps:
        1. Extract frames and ARKit data (poses, intrinsics, LiDAR depth).
        2. Run DA3 inference in overlapping 8-frame windows.
        3. Select oracle poses: ARKit directly, or BA refinement (see below).
        4. Propagate oracle uncertainty into confidence maps.
        5. Save results to ``output_cache_dir / <sequence_id>``.

    Oracle Target Selection:
        - If the fraction of well-tracked frames is below 0.5, the sequence is
          treated as video-only. DA3 chunks are stitched together and BA always runs.
        - If ``prefer_arkit_poses`` is True and the tracking ratio is at least
          ``min_arkit_quality``, ARKit poses are used directly (no BA).
        - Otherwise BA validation runs. If it fails, ARKit poses are used as a
          fallback, except in video-only mode, where the sequence is skipped.

    Args:
        arkit_dir: Directory containing the ARKit sequence (video plus
            metadata), read via ``ARKitProcessor``. Its name is the sequence id.
        output_cache_dir: Root directory for cached results. Each sequence gets
            a subdirectory containing ``oracle_targets.npz``,
            ``uncertainty_results.npz``, ``arkit_data.npz``, ``metadata.json``
            and ``image_paths.txt``.
        model: DA3 model used for initial depth/pose inference.
        ba_validator: Provides ``validate(images=, poses_model=, intrinsics=)``,
            used whenever ARKit poses are not taken directly.
        oracle_propagator: Provides ``propagate_uncertainty(...)``, returning
            confidence (and optionally uncertainty) maps.
        device: Not referenced by this function. The model is called without a
            device argument. NOTE(review): the model is presumably already
            placed on its device by the caller.
        prefer_arkit_poses: If True, use ARKit poses when tracking quality is
            good enough, avoiding BA. Default True.
        min_arkit_quality: Minimum fraction (0-1) of well-tracked frames needed
            to use ARKit poses directly. Default 0.8.
        use_lidar: Load ARKit LiDAR depth and use it for uncertainty and as the
            preferred oracle depth. Default True.
        use_ba_depth: Use BA depth maps as oracle depth when LiDAR depth is not
            used or unavailable. Default False.

    Returns:
        Dictionary describing the outcome. ``status`` is one of:

        - ``'success'``: also contains ``sequence_id``, ``num_frames``,
          ``pose_source`` (``'arkit'``, ``'arkit_fallback'`` or ``'ba'``),
          ``depth_source`` (``'lidar'``, ``'ba'`` or ``'none'``),
          ``mean_confidence``, ``tracking_quality`` (fraction of well-tracked
          frames) and ``cache_path`` (str).
        - ``'skipped'``: also contains ``sequence_id`` and ``reason``
          (``'insufficient_frames'``, ``'ba_failed'`` or ``'no_oracle_poses'``).
        - ``'failed'``: also contains ``sequence_id`` and ``error`` (the
          exception message). Exceptions are logged, not raised.

    Example:
        >>> from ylff.services.preprocessing import preprocess_arkit_sequence
        >>> from ylff.services.ba_validator import BAValidator
        >>> from ylff.utils.oracle_uncertainty import OracleUncertaintyPropagator
        >>>
        >>> result = preprocess_arkit_sequence(
        ...     arkit_dir=Path("data/arkit_sequences/seq001"),
        ...     output_cache_dir=Path("cache/preprocessed"),
        ...     model=da3_model,
        ...     ba_validator=ba_validator,
        ...     oracle_propagator=oracle_propagator,
        ...     prefer_arkit_poses=True,
        ...     min_arkit_quality=0.8,
        ... )

    Note:
        This function is designed to be called in parallel across multiple sequences.
        Each sequence is processed independently and results are cached separately.
        See `ylff preprocess arkit` CLI command for batch processing.
    """
    sequence_id = arkit_dir.name
    sequence_cache_dir = output_cache_dir / sequence_id
    sequence_cache_dir.mkdir(parents=True, exist_ok=True)

    try:
        # --- 1. Frames and ARKit metadata ---------------------------------
        logger.info(f"Extracting ARKit data for {sequence_id}...")
        processor = ARKitProcessor(arkit_dir=arkit_dir)
        images = processor.extract_frames(
            output_dir=None, max_frames=None, frame_interval=1, return_images=True
        )

        if len(images) < 2:
            return {
                "status": "skipped",
                "reason": "insufficient_frames",
                "sequence_id": sequence_id,
            }

        # len(images) >= 2 here, so no zero-division guard is needed. A
        # truthiness test on `images` would raise if it were a NumPy array.
        good_indices = processor.filter_good_frames()
        good_tracking_ratio = len(good_indices) / len(images)

        # With under 50% well-tracked frames, ARKit poses are not trusted and
        # the sequence is reconstructed from video alone (BA-driven).
        is_video_only = good_tracking_ratio < 0.5
        if is_video_only:
            logger.info(
                f"ARKit tracking missing or poor for {sequence_id} ({good_tracking_ratio:.1%}). "
                "Falling back to Video-only (BA-driven) mode."
            )

        arkit_poses_c2w, intrinsics = processor.get_arkit_poses()
        arkit_poses_w2c = processor.convert_arkit_to_w2c(arkit_poses_c2w)

        # Video and metadata can disagree on frame count. Truncate to the common
        # prefix so per-frame arrays stay index-aligned.
        if arkit_poses_c2w is not None and len(arkit_poses_c2w) > 0:
            min_len = min(len(images), len(arkit_poses_c2w))
            if len(images) != len(arkit_poses_c2w):
                logger.warning(
                    f"Syncing {sequence_id}: video has {len(images)} frames, "
                    f"metadata has {len(arkit_poses_c2w)}. Slicing to {min_len}."
                )
            images = images[:min_len]
            arkit_poses_c2w = arkit_poses_c2w[:min_len]
            arkit_poses_w2c = arkit_poses_w2c[:min_len]
            if intrinsics is not None and len(intrinsics) > 0:
                intrinsics = intrinsics[:min_len]

        # Normalise empty arrays to None so later code only has to test for None.
        if arkit_poses_c2w is not None and arkit_poses_c2w.size == 0:
            arkit_poses_c2w = None
        if arkit_poses_w2c is not None and arkit_poses_w2c.size == 0:
            arkit_poses_w2c = None
        if intrinsics is not None and intrinsics.size == 0:
            intrinsics = None

        # NOTE(review): lidar_depth is not truncated by the sync step above.
        # Confirm get_lidar_depths() is already aligned with `images`.
        lidar_depth = processor.get_lidar_depths() if use_lidar else None

        # --- 2. DA3 inference -------------------------------------------------
        logger.info(f"Running DA3 inference for {sequence_id} (length: {len(images)})...")
        da3_depth, da3_poses, da3_intrinsics = _run_chunked_da3_inference(
            model,
            images,
            arkit_poses_c2w,
            intrinsics,
            stitch_poses=is_video_only,
        )
        if da3_intrinsics is None:
            # The model returned no intrinsics: reuse the ARKit ones (may be None).
            da3_intrinsics = intrinsics

        # --- 3. Oracle pose selection -----------------------------------------
        use_arkit_poses = (
            prefer_arkit_poses
            and good_tracking_ratio >= min_arkit_quality
            and not is_video_only
        )

        ba_poses = None
        ba_depths = None
        if use_arkit_poses:
            logger.info(
                f"Using ARKit poses for {sequence_id} "
                f"(tracking quality: {good_tracking_ratio:.1%})"
            )
            oracle_poses = arkit_poses_w2c
            pose_source = "arkit"
        else:
            if is_video_only:
                logger.info(f"Running video-only BA reconstruction for {sequence_id}...")
            else:
                logger.info(
                    f"Running BA validation for {sequence_id} "
                    f"(ARKit tracking quality: {good_tracking_ratio:.1%} < {min_arkit_quality:.1%})"
                )
            ba_result = ba_validator.validate(
                images=images,
                poses_model=da3_poses,
                intrinsics=da3_intrinsics,
            )

            ba_poses_extracted = ba_result.get("poses_ba")
            if ba_poses_extracted is None:
                if is_video_only:
                    # No trustworthy ARKit poses exist to fall back on.
                    logger.warning(f"BA reconstruction failed for video-only sequence {sequence_id}")
                    return {
                        "status": "skipped",
                        "reason": "ba_failed",
                        "sequence_id": sequence_id,
                    }
                logger.warning(f"BA failed for {sequence_id}, falling back to ARKit poses")
                oracle_poses = arkit_poses_w2c
                pose_source = "arkit_fallback"
            else:
                oracle_poses = ba_poses_extracted
                pose_source = "ba"
                ba_poses = ba_poses_extracted
                ba_depths = ba_result.get("ba_depths") if use_ba_depth else None

        # Caching None as poses would write an object array that np.load
        # (allow_pickle=False by default) cannot read back. Such a sample would
        # be unusable for training, so skip it explicitly.
        if oracle_poses is None:
            logger.warning(
                f"No oracle poses available for {sequence_id} "
                f"(pose source: {pose_source}); skipping"
            )
            return {
                "status": "skipped",
                "reason": "no_oracle_poses",
                "sequence_id": sequence_id,
            }

        # --- 4. Oracle uncertainty ----------------------------------------------
        logger.info(f"Computing oracle uncertainty for {sequence_id}...")
        uncertainty_results = oracle_propagator.propagate_uncertainty(
            da3_poses=da3_poses,
            da3_depth=da3_depth,
            intrinsics=intrinsics,
            arkit_poses=arkit_poses_c2w,
            ba_poses=ba_poses,
            lidar_depth=lidar_depth,
        )

        # Oracle depth preference: LiDAR first, then BA depth, else none.
        oracle_depth = None
        if use_lidar and lidar_depth is not None:
            oracle_depth = lidar_depth
            depth_source = "lidar"
        elif use_ba_depth and ba_depths is not None:
            oracle_depth = ba_depths
            depth_source = "ba"
        else:
            depth_source = "none"

        # --- 5. Cache -------------------------------------------------------------
        logger.info(f"Saving pre-processed results for {sequence_id}...")

        # A (1, 1, 1) zeros array is the cache's "no data" placeholder for depth maps.
        missing_depth = np.zeros((1, 1, 1))

        np.savez_compressed(
            sequence_cache_dir / "oracle_targets.npz",
            poses=oracle_poses,
            depth=oracle_depth if oracle_depth is not None else missing_depth,
        )

        np.savez_compressed(
            sequence_cache_dir / "uncertainty_results.npz",
            pose_confidence=uncertainty_results["pose_confidence"],
            depth_confidence=uncertainty_results["depth_confidence"],
            collective_confidence=uncertainty_results["collective_confidence"],
            pose_uncertainty=uncertainty_results.get(
                "pose_uncertainty",
                np.zeros((len(images), 6)),
            ),
            depth_uncertainty=uncertainty_results.get(
                "depth_uncertainty", np.zeros_like(da3_depth)
            ),
        )

        # Store a missing ARKit trajectory as an empty array rather than None.
        # None would become an unreadable object array (see note above).
        np.savez_compressed(
            sequence_cache_dir / "arkit_data.npz",
            poses=arkit_poses_c2w if arkit_poses_c2w is not None else np.zeros((0, 4, 4)),
            lidar_depth=lidar_depth if lidar_depth is not None else missing_depth,
        )

        metadata = {
            "sequence_id": sequence_id,
            "num_frames": len(images),
            "tracking_quality": float(good_tracking_ratio),
            "pose_source": pose_source,
            "depth_source": depth_source,
            "has_lidar": lidar_depth is not None,
            "has_ba_depth": ba_depths is not None,
            "mean_pose_confidence": float(uncertainty_results["pose_confidence"].mean()),
            "mean_depth_confidence": float(uncertainty_results["depth_confidence"].mean()),
        }
        with open(sequence_cache_dir / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        # Records where the source sequence lives so frames can be re-read later.
        with open(sequence_cache_dir / "image_paths.txt", "w", encoding="utf-8") as f:
            f.write(f"{arkit_dir}\n")

        logger.info(f"Pre-processing complete for {sequence_id}")

        return {
            "status": "success",
            "sequence_id": sequence_id,
            "num_frames": len(images),
            "pose_source": pose_source,
            "depth_source": depth_source,
            "mean_confidence": float(uncertainty_results["collective_confidence"].mean()),
            "tracking_quality": float(good_tracking_ratio),
            "cache_path": str(sequence_cache_dir),
        }

    except Exception as e:
        # Top-level boundary: one bad sequence must not abort a batch run.
        logger.error(f"Pre-processing failed for {sequence_id}: {e}", exc_info=True)
        return {"status": "failed", "sequence_id": sequence_id, "error": str(e)}
| |
|
| |
|
def _read_npz_array(npz, key: str, *, placeholder_is_missing: bool = False) -> Optional[np.ndarray]:
    """Read ``key`` from an open ``.npz`` archive, returning None when it is unusable.

    Args:
        npz: Archive object returned by ``np.load`` on an ``.npz`` file.
        key: Array name inside the archive.
        placeholder_is_missing: Also treat a ``(1, 1, 1)`` array (the cache's
            "no data" placeholder) or an empty array as missing.

    Returns:
        The array, or None if ``key`` is absent, is a placeholder (when
        requested), or is an object array. NumPy writes a saved ``None`` as an
        object array, and reading one raises ValueError under ``np.load``'s
        default ``allow_pickle=False``.
    """
    if key not in npz.files:
        return None
    try:
        value = npz[key]
    except ValueError:
        # Object array (e.g. a None written by an older preprocessing run).
        return None
    if placeholder_is_missing and (value.shape == (1, 1, 1) or value.size == 0):
        return None
    return value


def load_preprocessed_sample(cache_dir: Path, sequence_id: str) -> Optional[Dict]:
    """
    Load one sequence's pre-processed results from the cache.

    Args:
        cache_dir: Root cache directory (the one passed as ``output_cache_dir``
            when preprocessing).
        sequence_id: Sequence identifier, i.e. the per-sequence subdirectory name.

    Returns:
        None if ``oracle_targets.npz`` or ``uncertainty_results.npz`` is missing
        or any cache file cannot be read. Otherwise a dict with:

        - ``oracle_targets``: ``{"poses", "depth"}``. ``depth`` is None when only
          the placeholder was stored.
        - ``uncertainty_results``: ``pose_confidence``, ``depth_confidence``,
          ``collective_confidence``, plus ``pose_uncertainty`` and
          ``depth_uncertainty`` (None if not stored).
        - ``arkit_data``: ``{"poses", "lidar_depth"}`` with entries None when
          not stored, or None if ``arkit_data.npz`` does not exist.
        - ``metadata``: parsed ``metadata.json``, or ``{}`` if absent.
        - ``sequence_id``.
    """
    sequence_cache_dir = cache_dir / sequence_id
    oracle_file = sequence_cache_dir / "oracle_targets.npz"
    uncertainty_file = sequence_cache_dir / "uncertainty_results.npz"

    # The sequence directory may exist without results (e.g. a skipped or
    # failed preprocessing run), so test for the required files, not the directory.
    if not (oracle_file.is_file() and uncertainty_file.is_file()):
        return None

    try:
        # `with` closes each archive. Indexing an NpzFile reads the array fully
        # into memory, so the arrays stay valid after the archive is closed.
        with np.load(oracle_file) as data:
            oracle_targets = {
                "poses": data["poses"],
                "depth": _read_npz_array(data, "depth", placeholder_is_missing=True),
            }

        with np.load(uncertainty_file) as data:
            uncertainty_results = {
                "pose_confidence": data["pose_confidence"],
                "depth_confidence": data["depth_confidence"],
                "collective_confidence": data["collective_confidence"],
                "pose_uncertainty": _read_npz_array(data, "pose_uncertainty"),
                "depth_uncertainty": _read_npz_array(data, "depth_uncertainty"),
            }

        arkit_data = None
        arkit_data_file = sequence_cache_dir / "arkit_data.npz"
        if arkit_data_file.exists():
            with np.load(arkit_data_file) as data:
                arkit_data = {
                    "poses": _read_npz_array(data, "poses", placeholder_is_missing=True),
                    "lidar_depth": _read_npz_array(
                        data, "lidar_depth", placeholder_is_missing=True
                    ),
                }

        metadata = {}
        metadata_file = sequence_cache_dir / "metadata.json"
        if metadata_file.exists():
            with open(metadata_file, encoding="utf-8") as f:
                metadata = json.load(f)

    except Exception as e:
        logger.error(f"Failed to load pre-processed sample {sequence_id}: {e}")
        return None

    return {
        "oracle_targets": oracle_targets,
        "uncertainty_results": uncertainty_results,
        "arkit_data": arkit_data,
        "metadata": metadata,
        "sequence_id": sequence_id,
    }
| |
|