DepthAnyPanorama-coreml / export_and_validate_coreml.py
Kyle Pearson
intermediate seam fix algorithm
ac6a5ca
#!/usr/bin/env python3
"""
Export DAP (Depth Any Panoramas) to CoreML for iOS/macOS and validate the result against PyTorch.
Produces a single-output CoreML model (depth map only) with ImageType input,
compatible with Vision framework and the included DepthPredictor.swift.
Usage:
python export_and_validate_coreml.py
python export_and_validate_coreml.py --height 768 --width 1536
"""
import os
import sys
import time
import numpy as np
import torch
import coremltools as ct
from PIL import Image
from argparse import ArgumentParser
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from networks.dap import make_model
# ---------------------------------------------------------------------------
# Wrapper
# ---------------------------------------------------------------------------
class DAPSingleOutputWrapper(torch.nn.Module):
    """Export-time wrapper that turns a DAP model into a single-output depth net.

    The input ``x`` is expected to be a float32 image batch already scaled to
    [0, 1] (CoreML's ImageType applies scale=1/255 before the graph runs).
    The wrapper applies ImageNet mean/std normalization, runs the wrapped
    model, and returns only its ``"pred_depth"`` entry.

    The depth is sanitized with ``torch.where`` instead of ``torch.clamp`` for
    CoreML compatibility:
    - values below 0 become 0;
    - values above 10 become 10;
    - NaNs become 0.
    """

    IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    def __init__(self, dap_model):
        super().__init__()
        self.dap = dap_model
        # Registered as buffers so they move together with the module (e.g. .to()).
        self.register_buffer("mean", self.IMAGENET_MEAN)
        self.register_buffer("std", self.IMAGENET_STD)

    def forward(self, x):
        normalized = (x - self.mean) / self.std
        # Presumably [B, 1, H, W], per the model's output convention.
        pred = self.dap(normalized)["pred_depth"]

        too_low = pred < 0.0
        pred = torch.where(too_low, torch.zeros_like(pred), pred)
        too_high = pred > 10.0
        pred = torch.where(too_high, torch.full_like(pred, 10.0), pred)
        # NaN fails both comparisons above, so it is zeroed only here.
        is_nan = torch.isnan(pred)
        return torch.where(is_nan, torch.zeros_like(pred), pred)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def load_weights(model, weight_path):
    """Load the checkpoint at ``weight_path`` into ``model`` (non-strict).

    The checkpoint may be a plain state dict, or a dict that nests it under a
    ``"state_dict"`` or ``"model"`` key. Before loading:
    - a leading ``"module."`` prefix (as added by DataParallel wrappers) is
      stripped from each key;
    - non-tensor entries are ignored;
    - checkpoint keys the model does not have are skipped with an info
      message;
    - model keys that the checkpoint does not provide are reported, so a
      partially loaded model is not silent.

    Exits with status 1 in two cases:
    - no checkpoint tensor matches a model key;
    - a matched tensor's shape differs from the model's (e.g. a checkpoint
      for a different ``--model_type``).
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(weight_path, map_location="cpu", weights_only=False)
    for wrapper_key in ("state_dict", "model"):
        if isinstance(checkpoint, dict) and isinstance(checkpoint.get(wrapper_key), dict):
            checkpoint = checkpoint[wrapper_key]
            break

    cleaned = {}
    for key, value in checkpoint.items():
        if not isinstance(value, torch.Tensor):
            continue
        if key.startswith("module."):
            key = key[len("module."):]
        cleaned[key] = value

    expected = model.state_dict()
    matched = {k: v for k, v in cleaned.items() if k in expected}
    unmatched = cleaned.keys() - expected.keys()
    if unmatched:
        print(f" [info] Skipping {len(unmatched)} unmatched weight keys")
    if not matched:
        print(" [error] No weights matched!")
        sys.exit(1)

    # load_state_dict raises on size mismatches even with strict=False;
    # report them clearly instead of surfacing a raw traceback.
    wrong_shape = [k for k, v in matched.items() if v.shape != expected[k].shape]
    if wrong_shape:
        first = wrong_shape[0]
        print(
            f" [error] {len(wrong_shape)} weight tensors have mismatched shapes "
            f"(e.g. {first}: checkpoint {tuple(matched[first].shape)} "
            f"vs model {tuple(expected[first].shape)})"
        )
        sys.exit(1)

    missing = expected.keys() - matched.keys()
    if missing:
        print(f" [warn] {len(missing)} model keys not found in checkpoint (left at initial values)")

    model.load_state_dict(matched, strict=False)
    print(f" Loaded {len(matched)} weight tensors from {weight_path}")
def prepare_image(image_path, height, width):
    """Load an image, convert it to RGB and resize it to ``width`` x ``height``.

    Returns:
        tuple: ``(img_tensor, img_resized)``.
        - ``img_tensor``: float32 tensor of shape (1, 3, height, width) with
          values in [0, 1].
        - ``img_resized``: the resized RGB PIL image.
    """
    # The context manager closes the underlying file handle promptly.
    with Image.open(image_path) as img:
        img_resized = img.convert("RGB").resize((width, height), Image.LANCZOS)
    pixels = np.asarray(img_resized, dtype=np.float32) / 255.0
    # HWC -> CHW, then add a batch dimension.
    img_tensor = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)
    return img_tensor, img_resized
def validate_trace(original, traced, example):
    """Compare a traced module against its eager source on one example input.

    Both modules are put in eval mode and run under ``torch.no_grad()``.
    Prints the largest absolute output difference and whether the traced
    output contains NaN. Exits with status 1 if it does.

    Returns:
        float: max absolute difference between the two outputs.
    """
    for module in (original, traced):
        module.eval()
    with torch.no_grad():
        reference = original(example)
        candidate = traced(example)
    max_diff = (reference - candidate).abs().max().item()
    has_nan = torch.isnan(candidate).any().item()
    print(f" Trace validation: max_diff={max_diff:.2e}, has_nan={has_nan}")
    if not has_nan:
        return max_diff
    print(" [error] Traced output contains NaN!")
    sys.exit(1)
def run_pytorch_inference(model, img_tensor):
    """Run the bare DAP model (not the export wrapper) on ``img_tensor``.

    ImageNet normalization is applied here by hand. The output is then
    post-processed like ``DAPSingleOutputWrapper`` does:
    - negatives become 0;
    - values above 10 become 10;
    - NaN becomes 0.

    Returns:
        np.ndarray: the ``"pred_depth"`` output with size-1 axes squeezed out.
    """
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    model.eval()
    with torch.no_grad():
        raw = model((img_tensor - mean) / std)["pred_depth"]
    depth = raw.squeeze().cpu().numpy()
    depth = np.where(depth < 0.0, 0.0, depth)
    depth = np.where(depth > 10.0, 10.0, depth)
    return np.nan_to_num(depth, nan=0.0, posinf=10.0, neginf=0.0)
def run_coreml_inference(mlpackage_path, pil_image):
    """Run the CoreML model at ``mlpackage_path`` on ``pil_image`` and time it.

    The image is passed under the ``"image"`` input name used at export. Only
    the ``predict`` call is timed; model loading is excluded.

    Returns:
        tuple: ``(depth, elapsed)``.
        - ``depth``: the ``[0, 0, :, :]`` slice of the model's first output.
        - ``elapsed``: prediction time in seconds.
    """
    model = ct.models.MLModel(mlpackage_path)
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # so it is the right clock for measuring a duration.
    start = time.perf_counter()
    output = model.predict({"image": pil_image})
    elapsed = time.perf_counter() - start
    # Single-output model: take the first (only) output, presumably [1, 1, H, W].
    output_tensor = next(iter(output.values()))
    depth = output_tensor[0, 0, :, :]
    return depth, elapsed
def compute_metrics(depth_pt, depth_cl):
    """Summarize how far the CoreML depth map is from the PyTorch reference.

    If the shapes differ, a warning is printed and ``depth_cl`` is resampled
    with ``scipy.ndimage.zoom`` to the first two dimensions of ``depth_pt``
    before comparing.

    Returns:
        dict with keys:
        - ``max_abs_diff``, ``mean_abs_diff``, ``rmse``;
        - ``mean_rel_error``: relative to ``depth_pt`` over pixels where it
          exceeds 1e-6, or NaN if there are none;
        - ``correlation``: Pearson coefficient of the flattened maps.
    """
    if depth_pt.shape != depth_cl.shape:
        print(f" [warn] Shape mismatch: PyTorch {depth_pt.shape} vs CoreML {depth_cl.shape}")
        from scipy.ndimage import zoom
        scale = (
            depth_pt.shape[0] / depth_cl.shape[0],
            depth_pt.shape[1] / depth_cl.shape[1],
        )
        depth_cl = zoom(depth_cl, scale)

    residual = depth_pt - depth_cl
    abs_residual = np.abs(residual)
    positive = depth_pt > 1e-6
    if positive.any():
        mean_rel = (abs_residual[positive] / depth_pt[positive]).mean()
    else:
        mean_rel = float("nan")

    return {
        "max_abs_diff": abs_residual.max(),
        "mean_abs_diff": abs_residual.mean(),
        "rmse": np.sqrt(np.mean(residual ** 2)),
        "mean_rel_error": mean_rel,
        "correlation": np.corrcoef(depth_pt.flatten(), depth_cl.flatten())[0, 1],
    }
def save_comparison_viz(depth_pt, depth_cl, metrics, output_dir):
    """Save a stacked PyTorch / CoreML / abs-diff figure as ``comparison.png``.

    ``output_dir`` is created if needed. If the two maps differ in shape,
    ``depth_cl`` is first resampled to ``depth_pt``'s first two dimensions
    with ``scipy.ndimage.zoom`` (the same approach ``compute_metrics`` uses),
    so the diff panel can be computed.

    Args:
        depth_pt: PyTorch reference depth map (2-D array).
        depth_cl: CoreML depth map (2-D array).
        metrics: not used by this function; kept for call-site compatibility.
        output_dir: directory the figure is written to.
    """
    os.makedirs(output_dir, exist_ok=True)
    if depth_pt.shape != depth_cl.shape:
        # compute_metrics only resizes its own local copy, so without this the
        # subtraction below fails to broadcast on a shape mismatch.
        from scipy.ndimage import zoom
        depth_cl = zoom(
            depth_cl,
            (depth_pt.shape[0] / depth_cl.shape[0], depth_pt.shape[1] / depth_cl.shape[1]),
        )

    diff = np.abs(depth_pt - depth_cl)
    # Shared color range so the two depth panels are visually comparable.
    vmax = max(depth_pt.max(), depth_cl.max())

    fig, axes = plt.subplots(3, 1, figsize=(6, 12))
    panels = [
        (depth_pt, {"cmap": "Spectral", "vmin": 0, "vmax": vmax},
         f"PyTorch Depth\n[{depth_pt.min():.4f}, {depth_pt.max():.4f}]"),
        (depth_cl, {"cmap": "Spectral", "vmin": 0, "vmax": vmax},
         f"CoreML Depth\n[{depth_cl.min():.4f}, {depth_cl.max():.4f}]"),
        (diff, {"cmap": "hot"},
         f"Abs Diff\nmax={diff.max():.6f}, mean={diff.mean():.6f}"),
    ]
    images = []
    for ax, (data, style, title) in zip(axes, panels):
        images.append(ax.imshow(data, **style))
        ax.set_title(title)
        ax.axis("off")
    for ax, im in zip(axes, images):
        fig.colorbar(im, ax=ax, fraction=0.046)
    fig.tight_layout()

    viz_path = os.path.join(output_dir, "comparison.png")
    fig.savefig(viz_path, dpi=150)
    plt.close(fig)
    print(f" Saved visualization to {viz_path}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Export DAP to CoreML (unless --skip_export) and validate it against PyTorch.

    Steps:
    1. Load and resize the test image.
    2. Run the PyTorch model and save its depth to ``pytorch_depth.npy``.
    3. Trace and convert the wrapped model to ``--output``.
    4. Run the CoreML model on the same image and compare the two outputs.
    5. Print a summary.

    Exits with status 1 when:
    - the dimensions are not multiples of the patch size;
    - the image is missing;
    - the comparison fails (max abs diff not below ``--threshold``).
    """
    parser = ArgumentParser(description="Export DAP to CoreML and validate against PyTorch")
    parser.add_argument("--image", default=os.path.join(os.path.dirname(__file__), "test", "test.png"))
    parser.add_argument("--height", type=int, default=768)
    parser.add_argument("--width", type=int, default=1536)
    parser.add_argument("--model_type", choices=["vits", "vitb", "vitl", "vitg"], default="vitl")
    parser.add_argument("--weights", default=os.path.join(os.path.dirname(__file__), "model.pth"))
    parser.add_argument("--output", default=os.path.join(os.path.dirname(__file__), "DAPModel.mlpackage"))
    parser.add_argument("--results", default=os.path.join(os.path.dirname(__file__), "test_output"))
    parser.add_argument("--threshold", type=float, default=0.05)
    parser.add_argument("--skip_export", action="store_true", help="Skip CoreML export, only validate")
    args = parser.parse_args()

    patch_size = 16
    if args.height % patch_size != 0 or args.width % patch_size != 0:
        print(f" [error] Dimensions must be multiples of {patch_size}. Got {args.height}x{args.width}")
        sys.exit(1)

    # =========================================================================
    # 1. Load image
    # =========================================================================
    print("=" * 60)
    print("DAP CoreML Export + Validation (single output, ImageType)")
    print("=" * 60)
    print(f"\n[1/5] Loading test image: {args.image}")
    if not os.path.exists(args.image):
        print(f" [error] Image not found: {args.image}")
        sys.exit(1)
    img_tensor, pil_image = prepare_image(args.image, args.height, args.width)
    print(f" Image resized to {args.width}x{args.height}")

    # =========================================================================
    # 2. Build model + PyTorch inference
    # =========================================================================
    print(f"\n[2/5] Building DAP model ({args.model_type}) + PyTorch inference ...")
    model = make_model(midas_model_type=args.model_type)
    load_weights(model, args.weights)
    model.eval()
    # perf_counter: monotonic, high-resolution clock for durations.
    pt_start = time.perf_counter()
    depth_pt = run_pytorch_inference(model, img_tensor)
    pt_time = time.perf_counter() - pt_start
    print(f" PyTorch time: {pt_time*1000:.1f}ms")
    print(f" Depth: {depth_pt.shape}, range=[{depth_pt.min():.4f}, {depth_pt.max():.4f}]")

    # Save ground truth
    os.makedirs(args.results, exist_ok=True)
    np.save(os.path.join(args.results, "pytorch_depth.npy"), depth_pt)
    print(f" Ground truth saved to {args.results}/pytorch_depth.npy")

    # =========================================================================
    # 3. Export CoreML (single output, ImageType)
    # =========================================================================
    if not args.skip_export:
        print(f"\n[3/5] Exporting CoreML model (ImageType input) ...")
        wrapped = DAPSingleOutputWrapper(model)
        wrapped.eval()
        # Example input for tracing: float32 tensor in [0,1]
        example_input = torch.rand(1, 3, args.height, args.width)
        # Tracing needs no gradients; skipping autograd bookkeeping keeps peak
        # memory down for large backbones.
        with torch.no_grad():
            traced = torch.jit.trace(wrapped, example_input)
        validate_trace(wrapped, traced, example_input)

        print(" Converting to CoreML (this may take a few minutes) ...")
        # scale=1/255 makes the graph input [0, 1], matching what the wrapper expects.
        image_input = ct.ImageType(
            name="image",
            shape=(1, 3, args.height, args.width),
            scale=1 / 255.0,
            bias=[0.0, 0.0, 0.0],
            color_layout=ct.colorlayout.RGB,
            channel_first=True,
        )
        mlmodel = ct.convert(
            traced,
            inputs=[image_input],
            outputs=[ct.TensorType(name="depth", dtype=np.float32)],
            minimum_deployment_target=ct.target.iOS18,
            compute_precision=ct.precision.FLOAT32,
            compute_units=ct.ComputeUnit.ALL,
        )
        mlmodel.save(args.output)
        # An .mlpackage is a directory; sum the sizes of all files inside it.
        total_size = sum(
            os.path.getsize(os.path.join(dp, f))
            for dp, _, fnames in os.walk(args.output)
            for f in fnames
        )
        print(f" Saved to {args.output} ({total_size / (1024**2):.0f} MB)")
    else:
        print("\n[3/5] Skipping CoreML export (--skip_export)")

    # =========================================================================
    # 4. Validate CoreML
    # =========================================================================
    metrics = None
    passed = False
    cl_time = None  # stays None when no CoreML model could be run
    print(f"\n[4/5] Validating CoreML model ...")
    if os.path.exists(args.output):
        depth_cl, cl_time = run_coreml_inference(args.output, pil_image)
        print(f" CoreML time: {cl_time*1000:.1f}ms")
        print(f" Depth: {depth_cl.shape}, range=[{depth_cl.min():.4f}, {depth_cl.max():.4f}]")
        metrics = compute_metrics(depth_pt, depth_cl)
        passed = metrics["max_abs_diff"] < args.threshold
        print(f" Max Abs Diff: {metrics['max_abs_diff']:.2e}")
        print(f" Mean Abs Diff: {metrics['mean_abs_diff']:.2e}")
        print(f" RMSE: {metrics['rmse']:.2e}")
        print(f" Rel Error: {metrics['mean_rel_error']:.2e}")
        print(f" Correlation: {metrics['correlation']:.6f}")
        print(f" {'PASS' if passed else 'FAIL'} (threshold: {args.threshold})")
        save_comparison_viz(depth_pt, depth_cl, metrics, args.results)
    else:
        print(f" [warn] Model not found: {args.output}")

    # =========================================================================
    # 5. Summary
    # =========================================================================
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f" Model: DAP {args.model_type}")
    print(f" Input: {args.width}x{args.height} (ImageType)")
    print(f" Image: {args.image}")
    print(f" Threshold: {args.threshold}")
    print("-" * 60)
    if metrics is not None:
        status = "PASS" if passed else "FAIL"
        print(f" Validation: {status} (max_diff={metrics['max_abs_diff']:.2e})")
        print("-" * 60)
    print(f" PyTorch: {pt_time*1000:.1f}ms")
    if cl_time is not None:
        print(f" CoreML: {cl_time*1000:.1f}ms")
    print(f" Results: {args.results}/")
    print("=" * 60)

    # Exit non-zero if validation ran and failed
    if metrics is not None and not passed:
        sys.exit(1)
if __name__ == "__main__":
    # Run the export + validation pipeline only when executed as a script,
    # not when this module is imported.
    main()