LiDAR-Perfect-Depth / code /ppd /lpd /sparse_simulator.py
chenming-wu's picture
code
436b829 verified
"""
Simulate sparse-LiDAR observations from dense ground-truth depth.
Patterns: random / scan-line / grid / hybrid. Used during training so the prompt
encoder sees realistic sparsity. Simulation runs on tensors so it can sit
inside the data loader or the training step.
"""
from __future__ import annotations
import math
import torch
def _validate(depth: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert depth.dim() == 4 and mask.dim() == 4, "expect (B,1,H,W)"
assert depth.shape == mask.shape
return depth, mask
def random_pattern(mask: torch.Tensor, density: float, generator: torch.Generator | None = None) -> torch.Tensor:
    """Thin out `mask` by an independent per-pixel coin flip.

    Each pixel draws a uniform sample in [0, 1) and survives only when that
    sample is below `density` and the pixel is already set in `mask`.
    """
    uniform = torch.rand(mask.shape, generator=generator, device=mask.device)
    return mask & uniform.lt(density)
def scan_line_pattern(
    mask: torch.Tensor,
    n_lines: int,
    line_density: float,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Simulate Velodyne-like horizontal scan lines.

    For every sample in the batch, a fresh set of distinct image rows is drawn
    (at least one, at most all H rows) and each pixel on those rows is kept
    with probability `line_density`. Only channel 0 is populated, and the
    result is intersected with `mask`.
    """
    batch, _, height, width = mask.shape
    device = mask.device
    # The row count is identical for every sample, so compute it once.
    rows_per_sample = max(1, min(n_lines, height))

    lines = torch.zeros_like(mask)
    for idx in range(batch):
        # Draw order (row permutation first, then per-pixel coin flips) is
        # kept fixed so a seeded generator yields reproducible masks.
        chosen = torch.randperm(height, generator=generator, device=device)[:rows_per_sample]
        hits = torch.rand((rows_per_sample, width), generator=generator, device=device) < line_density
        plane = lines[idx, 0]  # view into `lines`; writes land in place
        plane[chosen] = hits
    return mask & lines
def grid_pattern(mask: torch.Tensor, stride: int) -> torch.Tensor:
    """Keep only pixels on a regular lattice anchored at the top-left corner.

    Along both of the last two axes every `stride`-th index (0, stride,
    2*stride, ...) is retained; the lattice is then intersected with `mask`.
    """
    # Unpacking doubles as a rank check: a non-4-D mask fails right here.
    _batch, _chan, _rows, _cols = mask.shape
    lattice = torch.zeros_like(mask)
    lattice[..., ::stride, ::stride] = True
    return mask & lattice
def hybrid_pattern(
    mask: torch.Tensor,
    density: float,
    n_lines: int,
    line_density: float,
    grid_stride: int,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Combine the random, scan-line and grid patterns into one mask.

    A pixel is kept when at least one of the three component patterns keeps
    it. The random pattern is drawn before the scan-line pattern, so a seeded
    `generator` produces a reproducible result.
    """
    # Tuple elements are evaluated left to right, which fixes the order in
    # which the shared generator is consumed.
    components = (
        random_pattern(mask, density, generator=generator),
        scan_line_pattern(mask, n_lines, line_density, generator=generator),
        grid_pattern(mask, grid_stride),
    )
    union = components[0]
    for extra in components[1:]:
        union = union | extra
    return union
def simulate(
    depth: torch.Tensor,
    mask: torch.Tensor,
    pattern: str = "hybrid",
    *,
    density: float = 0.005,
    n_lines: int = 64,
    line_density: float = 0.5,
    grid_stride: int = 32,
    min_points: int = 16,
    max_attempts: int = 4,
    measurement_noise_std: float = 0.0,
    generator: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (sparse_depth, sparse_mask) for a batch of dense GT.

    The chosen pattern is drawn up to `max_attempts` times, loosening the
    sampling parameters after every attempt in which some sample still holds
    fewer than `min_points` observed pixels.

    * A sample is frozen as soon as one draw reaches `min_points`; later,
      looser draws never overwrite it, so its sparsity does not depend on the
      other samples in the batch.
    * A sample that never reaches `min_points` keeps the densest draw seen
      (best effort) rather than an empty mask.

    Args:
        depth: dense depth, 4-D, same shape as `mask`.
        mask: validity mask of the dense depth.
        pattern: one of "random", "scan_line", "grid", "hybrid".
        density: keep probability for the random pattern.
        n_lines: number of rows for the scan-line pattern.
        line_density: keep probability within each scan line.
        grid_stride: lattice step for the grid pattern.
        min_points: target minimum number of observed pixels per sample.
        max_attempts: maximum number of draws.
        measurement_noise_std: std of Gaussian noise added at observed pixels
            (no noise when not positive).
        generator: optional RNG used for every random draw.

    Returns:
        `(sparse_depth, sparse_mask)`. `sparse_depth` is exactly zero wherever
        `sparse_mask` is false, even if `depth` holds NaN/inf there.

    Raises:
        ValueError: if `pattern` is not a known pattern name.
    """
    depth, mask = _validate(depth, mask)
    B = depth.shape[0]
    out_mask = torch.zeros_like(mask)
    # Per-sample state: `done` marks samples that already reached min_points,
    # `held` is the point count of the draw currently stored in out_mask
    # (-1 means nothing stored yet, so the first draw is always accepted).
    done = torch.zeros(B, dtype=torch.bool, device=mask.device)
    held = torch.full((B,), -1, dtype=torch.long, device=mask.device)

    for _ in range(max_attempts):
        if pattern == "random":
            cur = random_pattern(mask, density, generator=generator)
        elif pattern == "scan_line":
            cur = scan_line_pattern(mask, n_lines, line_density, generator=generator)
        elif pattern == "grid":
            cur = grid_pattern(mask, grid_stride)
        elif pattern == "hybrid":
            cur = hybrid_pattern(
                mask, density, n_lines, line_density, grid_stride, generator=generator
            )
        else:
            raise ValueError(f"unknown pattern: {pattern}")

        counts = cur.flatten(1).sum(dim=1)
        # Only pending samples may change, and only towards more points.
        take = ~done & (counts > held)
        out_mask = torch.where(take[:, None, None, None], cur, out_mask)
        held = torch.where(take, counts, held)
        done = done | (counts >= min_points)
        if done.all():
            break

        # Loosen for the next attempt. Each update is guarded so that it can
        # never make the pattern sparser than what the caller asked for.
        density = max(density, min(density * 2.0, 0.5))
        n_lines = min(n_lines * 2, depth.shape[-2])
        line_density = min(line_density * 1.5, 1.0)
        grid_stride = min(grid_stride, max(grid_stride // 2, 2))

    observed = out_mask.bool()
    weights = out_mask.float()
    # Multiplying by zero does not clear NaN/inf, so scrub unobserved pixels
    # explicitly to honour the "zero where mask is false" contract.
    sparse_depth = (depth * weights).masked_fill(~observed, 0.0)
    if measurement_noise_std > 0.0:
        noise = torch.randn(sparse_depth.shape, generator=generator, device=sparse_depth.device)
        sparse_depth = sparse_depth + noise * measurement_noise_std * weights
        # Noise can push small depths below zero; depth is kept non-negative.
        sparse_depth = sparse_depth.clamp_min(0.0)
    return sparse_depth, out_mask
def random_pattern_choice(rng: torch.Generator | None = None) -> str:
    """Pick one of the four pattern names with equal probability.

    Draws a single integer from `rng`, or from torch's default generator when
    `rng` is None, and returns the pattern name at that position.
    """
    names = ("random", "scan_line", "grid", "hybrid")
    # `generator=None` selects the default generator, so one call covers both cases.
    draw = torch.randint(0, len(names), (1,), generator=rng)
    return names[int(draw.item())]