# LiDAR-Perfect-Depth / code / ppd / lpd / prompt_encoder.py
"""
Sparse-LiDAR prompt encoder.
Per-pixel sparse depth (B,1,H,W) + binary mask (B,1,H,W) are pooled to multiple
scales via masked average pooling. At each scale we keep both the pooled depth
and the *density* (fraction of observed pixels per cell) — paper §3.1 calls
this the per-token confidence signal that drives the prompt gate.
The output token grid is sized to match the DiT's stage-2 token grid (H/p, W/p),
which is where prompt fusion happens.
"""
from __future__ import annotations
import math
from typing import Iterable
import torch
import torch.nn as nn
import torch.nn.functional as F
def masked_avg_pool(depth: torch.Tensor, mask: torch.Tensor, kernel: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Masked average pooling over non-overlapping `kernel`×`kernel` cells.

    Args:
        depth: (B,1,H,W) per-pixel values.
        mask:  (B,1,H,W) observation mask, bool or 0/1.
        kernel: cell edge in pixels (stride == kernel; trailing edge truncated).

    Returns:
        pooled:  (B,1,H/k,W/k) mean over observed pixels per cell (0 for empty cells).
        density: (B,1,H/k,W/k) fraction of observed pixels per cell.
    """
    valid = mask.float()
    area = kernel * kernel
    # avg_pool2d scaled by the cell area is exactly sum pooling.
    cell_sum = F.avg_pool2d(depth * valid, kernel_size=kernel, stride=kernel, ceil_mode=False) * area
    cell_cnt = F.avg_pool2d(valid, kernel_size=kernel, stride=kernel, ceil_mode=False) * area
    # Empty cells divide by 1 rather than 0, so their pooled value is 0.
    pooled = cell_sum / cell_cnt.clamp_min(1.0)
    return pooled, cell_cnt / area
def quantile_log_normalize(depth: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Per-sample 2/98 quantile log-depth normalization, matches PPD's GT scheme.

    For each sample, observed log-depths are affinely mapped so the 2nd
    percentile lands at -0.5 and the 98th at +0.5, then clamped to
    [-0.5, 0.5]. Pixels with mask == 0 are set to 0 so they look like
    "no observation" downstream.

    Args:
        depth: (B,1,H,W) raw sparse depth; negative values are treated as 0.
        mask:  (B,1,H,W) observation mask, bool or 0/1.

    Returns:
        (B,1,H,W) normalized depth in [-0.5, 0.5]; 0 wherever mask == 0
        (and for samples with no observations at all).
    """
    out = torch.zeros_like(depth)
    # log1p is the numerically stable spelling of log(x + 1).
    log_depth = torch.log1p(depth.clamp_min(0.0))
    for i in range(depth.shape[0]):
        m = mask[i].bool()
        if not m.any():
            continue  # no observations: leave this sample all-zero
        vals = log_depth[i][m]
        d_min = torch.quantile(vals, 0.02)
        d_max = torch.quantile(vals, 0.98)
        # Guard near-constant samples against division by ~0.
        if (d_max - d_min) < 1e-6:
            d_max = d_min + 1e-6
        norm = (log_depth[i] - d_min) / (d_max - d_min) - 0.5
        # Bug fix: upper clamp was 1.0, contradicting the documented
        # [-0.5, 0.5] range of the symmetric quantile mapping.
        norm = torch.clamp(norm, -0.5, 0.5)
        out[i] = norm * m.float()
    return out
class _SmallCNN(nn.Module):
def __init__(self, in_ch: int, hidden: int, out_ch: int):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(in_ch, hidden, kernel_size=3, padding=1),
nn.GELU(),
nn.Conv2d(hidden, out_ch, kernel_size=3, padding=1),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class SparsePromptEncoder(nn.Module):
    """Multi-scale sparse-prompt encoder.

    Args
    ----
    scales : pool kernels (in pixels). Paper §3.1 uses {4, 8, 16, 32} — kernel=4
             gives sub-token granularity (4×4 pixels per cell), kernel=32 gives
             global context. All scales are bilinearly resampled to the DiT
             stage-2 token grid before fusion.
    embed_dim : output token embedding dim (matches the DiT's hidden_size).
    out_grid_div : the model fuses prompts at the stage-2 grid which is H/p2,
                   W/p2 with p2 = 8 by default.
    """

    def __init__(
        self,
        scales: Iterable[int] = (4, 8, 16, 32),
        embed_dim: int = 1024,
        out_grid_div: int = 8,
        hidden: int = 128,
    ):
        super().__init__()
        self.scales = tuple(scales)
        self.embed_dim = embed_dim
        self.out_grid_div = out_grid_div
        # One small CNN per scale; each consumes 2 channels (depth + density)
        # and emits embed_dim channels.
        self.per_scale = nn.ModuleList(
            _SmallCNN(2, hidden, embed_dim) for _ in self.scales
        )
        # Mixes the concatenated multi-scale features down to embed_dim.
        self.fuse = nn.Linear(embed_dim * len(self.scales), embed_dim)
        # Zero-init the final projection: an untrained encoder then emits
        # all-zero tokens, so the model initially behaves like plain PPD.
        nn.init.zeros_(self.fuse.weight)
        nn.init.zeros_(self.fuse.bias)

    def forward(
        self, sparse_depth: torch.Tensor, sparse_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns (tokens, density_per_token).

        tokens: (B, T, embed_dim)
        density_per_token: (B, T, 1) — averaged density across scales, used by
            the prompt gate as a confidence weight.
        """
        # Normalize once at the input resolution; all scales pool from this.
        depth_n = quantile_log_normalize(sparse_depth, sparse_mask)
        _, _, height, width = sparse_depth.shape
        grid = (height // self.out_grid_div, width // self.out_grid_div)

        scale_feats: list[torch.Tensor] = []
        scale_dens: list[torch.Tensor] = []
        for kernel, cnn in zip(self.scales, self.per_scale):
            pooled, dens = masked_avg_pool(depth_n, sparse_mask, kernel=kernel)
            feat = cnn(torch.cat((pooled, dens), dim=1))
            # Resample each scale onto the stage-2 token grid.
            scale_feats.append(
                F.interpolate(feat, size=grid, mode="bilinear", align_corners=False)
            )
            scale_dens.append(
                F.interpolate(dens, size=grid, mode="bilinear", align_corners=False)
            )

        # (B, embed_dim*S, h, w) → (B, T, embed_dim*S) → (B, T, embed_dim)
        tokens = torch.cat(scale_feats, dim=1).flatten(2).transpose(1, 2)
        tokens = self.fuse(tokens)
        # Mean density across scales, flattened to one confidence per token.
        density = torch.stack(scale_dens).mean(dim=0).flatten(2).transpose(1, 2)
        return tokens, density