| """ |
| Simulate sparse-LiDAR observations from dense ground-truth depth. |
| |
| Patterns: random / scan-line / grid / hybrid. Used during training so the prompt |
| encoder sees realistic sparsity. Simulation runs on tensors so it can sit |
| inside the data loader or the training step. |
| """ |
| from __future__ import annotations |
|
|
| import math |
| import torch |
|
|
|
|
def _validate(depth: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Check that `depth` and `mask` are 4-D tensors of identical shape.

    Args:
        depth: Dense depth tensor, expected layout (B,1,H,W).
        mask: Validity mask with the same shape as `depth`.

    Returns:
        The same `(depth, mask)` objects, unchanged.

    Raises:
        ValueError: If either tensor is not 4-D or the shapes differ.
    """
    # Explicit raises instead of `assert`: asserts are removed under
    # `python -O`, which would silently disable this validation.
    if depth.dim() != 4 or mask.dim() != 4:
        raise ValueError(
            "expect (B,1,H,W); got depth shape "
            f"{tuple(depth.shape)} and mask shape {tuple(mask.shape)}"
        )
    if depth.shape != mask.shape:
        raise ValueError(
            "depth and mask must have the same shape; got "
            f"{tuple(depth.shape)} vs {tuple(mask.shape)}"
        )
    return depth, mask
|
|
|
|
def random_pattern(mask: torch.Tensor, density: float, generator: torch.Generator | None = None) -> torch.Tensor:
    """Keep each pixel of `mask` independently with probability `density`.

    A uniform sample in [0, 1) is drawn per element of `mask` (on the mask's
    device, optionally from `generator`); elements whose sample is below
    `density` are selected, and the selection is AND-ed with `mask`.
    """
    uniform = torch.rand(mask.shape, generator=generator, device=mask.device)
    selected = uniform.lt(density)
    return mask & selected
|
|
|
|
def scan_line_pattern(
    mask: torch.Tensor,
    n_lines: int,
    line_density: float,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sparse mask made of randomly chosen horizontal rows.

    For every sample in the batch, a fresh random permutation of the row
    indices is drawn and its first `n_lines` entries (clamped to the range
    [1, H]) become the scan lines. Inside each chosen row, a pixel is kept
    when a uniform sample falls below `line_density`. Only channel 0 of each
    sample is populated, and the result is AND-ed with `mask`.
    """
    batch_size, _channels, height, width = mask.shape
    device = mask.device
    # The clamp does not depend on the sample, so compute it once.
    n_rows = max(1, min(n_lines, height))

    selected = torch.zeros_like(mask)
    for sample in range(batch_size):
        # Draw order (row permutation first, then per-pixel uniforms) is kept
        # so a seeded generator yields the same pattern.
        chosen_rows = torch.randperm(height, generator=generator, device=device)[:n_rows]
        row_hits = torch.rand((n_rows, width), generator=generator, device=device) < line_density
        plane = torch.zeros((height, width), dtype=torch.bool, device=device)
        plane[chosen_rows] = row_hits
        selected[sample, 0] = plane
    return mask & selected
|
|
|
|
def grid_pattern(mask: torch.Tensor, stride: int) -> torch.Tensor:
    """Regular grid: keep every `stride`-th pixel along each spatial axis.

    Grid points sit at rows and columns 0, stride, 2*stride, ... of the last
    two dimensions; the grid is AND-ed with `mask`.

    Args:
        mask: 4-D validity mask, expected layout (B,1,H,W).
        stride: Spacing between kept pixels; must be at least 1.

    Returns:
        Tensor shaped like `mask`, true only at grid points where `mask` is true.

    Raises:
        ValueError: If `mask` is not 4-D or `stride` is smaller than 1.
    """
    # Explicit rank check replaces the shape unpacking that used to enforce it
    # implicitly through otherwise-unused locals.
    if mask.dim() != 4:
        raise ValueError(f"expect (B,1,H,W) mask; got shape {tuple(mask.shape)}")
    # A non-positive slice step would otherwise fail deep inside indexing with
    # a message that never mentions `stride`.
    if stride < 1:
        raise ValueError(f"stride must be >= 1; got {stride}")

    lattice = torch.zeros_like(mask)
    lattice[:, :, ::stride, ::stride] = True
    return mask & lattice
|
|
|
|
def hybrid_pattern(
    mask: torch.Tensor,
    density: float,
    n_lines: int,
    line_density: float,
    grid_stride: int,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Combine the random, scan-line and grid patterns with a logical OR.

    The random pattern is sampled first and the scan-line pattern second, so
    both consume `generator` in that order; the grid pattern is deterministic.
    """
    scattered = random_pattern(mask, density, generator=generator)
    lines = scan_line_pattern(mask, n_lines, line_density, generator=generator)
    lattice = grid_pattern(mask, grid_stride)

    combined = scattered | lines
    combined = combined | lattice
    return combined
|
|
|
|
def simulate(
    depth: torch.Tensor,
    mask: torch.Tensor,
    pattern: str = "hybrid",
    *,
    density: float = 0.005,
    n_lines: int = 64,
    line_density: float = 0.5,
    grid_stride: int = 32,
    min_points: int = 16,
    max_attempts: int = 4,
    measurement_noise_std: float = 0.0,
    generator: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (sparse_depth, sparse_mask) for a batch of dense GT.

    A sparsity pattern is sampled up to `max_attempts` times, loosening the
    sampling parameters after every attempt in which some sample still has
    fewer than `min_points` observed pixels. Acceptance is per sample: the
    first candidate with at least `min_points` pixels is frozen and is not
    replaced by later, denser attempts. A sample that never reaches
    `min_points` (for example because its GT mask has too few valid pixels)
    keeps the candidate with the most observed pixels, so `min_points` is a
    best-effort target rather than a hard guarantee.

    Args:
        depth: Dense depth, expected layout (B,1,H,W).
        mask: Validity mask with the same shape as `depth`.
        pattern: One of "random", "scan_line", "grid", "hybrid".
        density: Keep probability for the random pattern.
        n_lines: Number of scan lines for the scan-line pattern.
        line_density: Keep probability within a scan line.
        grid_stride: Pixel spacing for the grid pattern.
        min_points: Target minimum number of observed pixels per sample.
        max_attempts: Maximum number of sampling attempts.
        measurement_noise_std: Std of Gaussian noise added at observed pixels;
            when positive, the noisy depth is clamped to be non-negative.
        generator: Optional RNG used for pattern sampling and noise.

    Returns:
        `(sparse_depth, sparse_mask)`; `sparse_depth` is zero wherever
        `sparse_mask` is false.

    Raises:
        ValueError: If `pattern` is not a known pattern name.
    """
    depth, mask = _validate(depth, mask)
    # Reject a bad pattern name up front rather than on the first loop pass.
    if pattern not in ("random", "scan_line", "grid", "hybrid"):
        raise ValueError(f"unknown pattern: {pattern}")

    batch = depth.shape[0]
    height = depth.shape[-2]

    out_mask = torch.zeros_like(mask)
    # Per-sample bookkeeping: best point count stored so far (-1 = nothing
    # stored yet) and whether the sample has already met `min_points`.
    best_count = torch.full((batch,), -1, dtype=torch.long, device=mask.device)
    accepted = torch.zeros(batch, dtype=torch.bool, device=mask.device)

    for _ in range(max_attempts):
        if pattern == "random":
            cur = random_pattern(mask, density, generator=generator)
        elif pattern == "scan_line":
            cur = scan_line_pattern(mask, n_lines, line_density, generator=generator)
        elif pattern == "grid":
            cur = grid_pattern(mask, grid_stride)
        else:
            cur = hybrid_pattern(
                mask, density, n_lines, line_density, grid_stride, generator=generator
            )

        counts = cur.flatten(1).sum(dim=1).long()
        # Only samples that are not yet accepted may change, and only when the
        # new candidate has more points than the one stored. Accepted samples
        # stay frozen so their sparsity does not depend on other batch members.
        take = ~accepted & (counts > best_count)
        out_mask = torch.where(take[:, None, None, None], cur, out_mask)
        best_count = torch.where(take, counts, best_count)
        accepted = accepted | (counts >= min_points)
        if accepted.all():
            break

        # Loosen parameters for the next attempt. Each update is guarded so it
        # can never make the pattern sparser than the caller's setting (e.g.
        # density above 0.5 or grid_stride of 1 stay as they are).
        density = max(density, min(density * 2.0, 0.5))
        n_lines = min(n_lines * 2, height)
        line_density = min(line_density * 1.5, 1.0)
        grid_stride = min(grid_stride, max(grid_stride // 2, 2))

    keep = out_mask.float()
    observed = out_mask.to(torch.bool)
    # Multiplying alone would turn NaN/inf at unobserved pixels into NaN
    # (NaN * 0 is NaN); the `where` forces exact zeros there while keeping the
    # dtype produced by the multiplication.
    masked = depth * keep
    sparse_depth = torch.where(observed, masked, torch.zeros_like(masked))
    if measurement_noise_std > 0.0:
        noise = torch.randn(sparse_depth.shape, generator=generator, device=sparse_depth.device)
        sparse_depth = sparse_depth + noise * measurement_noise_std * keep
        sparse_depth = sparse_depth.clamp_min(0.0)
    return sparse_depth, out_mask
|
|
|
|
def random_pattern_choice(rng: torch.Generator | None = None) -> str:
    """Pick one of the four pattern names uniformly at random.

    Draws a single index with `torch.randint`, using `rng` when given and the
    default torch RNG otherwise.
    """
    names = ("random", "scan_line", "grid", "hybrid")
    # `generator=None` is torch's default, so one call covers both cases.
    drawn = torch.randint(0, len(names), (1,), generator=rng)
    return names[int(drawn.item())]
|
|