# WJAD — src/wjad/model.py
"""End-to-end autonomous-driving model ``E2EAVModel``.

forward pipeline:
1. ``DINOv3`` extracts patch features from the 8 frames.
2. ``OnlineCalibration`` uses the raw ego/intr/extr (symlog) + DINOv3 patches
   as KV, outputs symlog-space residuals; these are added and symexp-restored
   to obtain the corrected_* quantities.
3. corrected_intr / corrected_extr / corrected_ego are used to compute
   - per-token ego-frame unit rays (used only for the visual tokens' first
     RoPE head group);
   - 8 ego tokens (symlog followed by a linear projection).
4. 2x2x2 spatio-temporal compression -> 1536 visual tokens.
5. Concatenation [vision(1536) | ego(8) | det(1024) | ctrl(24) | extra(256)]
   = 2848 tokens; each non-visual slice gets its own learnable PE.
6. 18 backbone layers (3D RoPE applied to the visual slice only).
7. Slices are fed into ``DetectionTrajHead`` and ``ControlHead``.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from .backbone import Backbone, BackboneOutput
from .calibration import OnlineCalibration, CalibrationOutput
from .encoders import DINOv3Wrapper
from .heads import (
ControlHead,
ControlOutput,
DetectionTrajHead,
DetectionTrajOutput,
)
from .modules.learned_pe import LearnedTokenPE
from .modules.normalization import symlog
from .modules.pos_encoding import RoPE3D
from .modules.rays import compute_ego_rays
from .modules.temporal_compress import TemporalCompress2x2x2
@dataclass
class E2EOutput:
    """Complete model output bundle returned by ``E2EAVModel.forward``."""
    # Detection / per-object trajectory predictions from ``DetectionTrajHead``.
    detection: DetectionTrajOutput
    # Ego trajectory / action predictions from ``ControlHead``.
    control: ControlOutput
    # Full backbone output (its ``hidden_states`` carry the whole token sequence).
    backbone_out: BackboneOutput
    # Online-calibration output (corrected ego / intrinsics / extrinsics).
    calibration: CalibrationOutput
class E2EAVModel(nn.Module):
    """End-to-end autonomous-driving model.

    Forward pipeline:
      1. ``DINOv3`` extracts patch features from the history frames.
      2. ``OnlineCalibration`` takes the raw ego/intr/extr (symlog space) with
         the DINOv3 patches as KV, predicts symlog-space residuals, and returns
         the symexp-restored ``corrected_*`` quantities.
      3. The corrected intrinsics / extrinsics / ego poses drive
         - per-token ego-frame unit rays (RoPE, visual tokens only), and
         - the ego tokens (symlog then linear projection).
      4. 2x2x2 spatio-temporal compression of the patch grid.
      5. Token sequence [vision | ego | det | ctrl | extra]; each non-visual
         slice gets its own learned positional embedding.
      6. Backbone (3D RoPE applied only to the visual slice).
      7. The detection and control slices feed ``DetectionTrajHead`` and
         ``ControlHead`` respectively.
    """

    def __init__(
        self,
        dinov3_path: str = "./dinov3-vitb16-pretrain-lvd1689m",
        backbone_dim: int = 768,
        num_heads: int = 12,
        num_dense_layers: int = 9,
        num_moe_layers: int = 9,
        num_routed_experts: int = 7,
        num_shared_experts: int = 1,
        topk_experts: int = 3,
        ffn_mult: int = 4,
        # token counts
        num_history_frames: int = 8,
        num_detection_tokens: int = 1024,
        num_control_tokens: int = 24,
        num_ego_tokens: int = 8,
        num_extra_tokens: int = 256,
        # input resolution
        image_h: int = 384,
        image_w: int = 1024,
        patch_size: int = 16,
        # head hyper-parameters
        num_classes: int = 22,
        traj_horizon: int = 24,
        det_head_hidden: int = 384,
        ctrl_head_hidden: int = 384,
        # online calibration
        calib_dim: int = 256,
        calib_num_query: int = 256,
        calib_num_blocks: int = 2,
        calib_num_self_per_block: int = 2,
        calib_num_heads: int = 8,
        calib_residual_range: float = 0.1,
        calib_intr_dim: int = 11,
        # DINOv3
        freeze_dinov3: bool = True,
        attn_implementation: str = "sdpa",
    ) -> None:
        super().__init__()
        # The 2x2x2 compression halves the time axis and both patch-grid axes,
        # so the frame count must be even and the grid sides must be even.
        # Fail fast here instead of producing shape errors deep inside forward().
        if num_history_frames % 2 != 0:
            raise ValueError(f"num_history_frames must be even, got {num_history_frames}")
        if image_h % (2 * patch_size) != 0 or image_w % (2 * patch_size) != 0:
            raise ValueError(
                f"image_h/image_w ({image_h}, {image_w}) must be divisible by "
                f"2 * patch_size = {2 * patch_size}"
            )
        self.image_h = image_h
        self.image_w = image_w
        self.patch_size = patch_size
        self.num_history = num_history_frames
        self.num_det = num_detection_tokens
        self.num_ctrl = num_control_tokens
        self.num_ego = num_ego_tokens
        self.num_extra = num_extra_tokens
        # === 1) DINOv3 feature extractor ===
        self.dinov3 = DINOv3Wrapper(
            pretrained_path=dinov3_path,
            attn_implementation=attn_implementation,
            freeze=freeze_dinov3,
        )
        dino_dim = self.dinov3.hidden_size
        # === 2) Online calibration ===
        self.calib = OnlineCalibration(
            dino_dim=dino_dim,
            hidden_dim=calib_dim,
            num_query_tokens=calib_num_query,
            num_blocks=calib_num_blocks,
            num_self_attn_per_block=calib_num_self_per_block,
            num_heads=calib_num_heads,
            residual_range=calib_residual_range,
            num_history_frames=num_history_frames,
            intr_dim=calib_intr_dim,
        )
        # === 3) Spatio-temporal compression ===
        self.compress = TemporalCompress2x2x2(dim=dino_dim)
        # Patch-grid size (divisibility by 2 is enforced above).
        self.gh = image_h // patch_size
        self.gw = image_w // patch_size
        # === 4) Non-visual tokens + learned positional embeddings ===
        self.ego_proj = nn.Linear(6, backbone_dim)  # 6D pose -> backbone dim
        self.det_tokens = nn.Parameter(torch.empty(num_detection_tokens, backbone_dim))
        nn.init.trunc_normal_(self.det_tokens, std=0.02)
        self.ctrl_tokens = nn.Parameter(torch.empty(num_control_tokens, backbone_dim))
        nn.init.trunc_normal_(self.ctrl_tokens, std=0.02)
        self.extra_tokens = nn.Parameter(torch.empty(num_extra_tokens, backbone_dim))
        nn.init.trunc_normal_(self.extra_tokens, std=0.02)
        self.ego_pe = LearnedTokenPE(num_ego_tokens, backbone_dim)
        self.det_pe = LearnedTokenPE(num_detection_tokens, backbone_dim)
        self.ctrl_pe = LearnedTokenPE(num_control_tokens, backbone_dim)
        self.extra_pe = LearnedTokenPE(num_extra_tokens, backbone_dim)
        # === 5) 3D RoPE (visual tokens only, on the compressed T/2 x gh/2 x gw/2 grid) ===
        self.rope = RoPE3D(
            num_heads=num_heads,
            head_dim=backbone_dim // num_heads,
            time_size=num_history_frames // 2,
            height_size=self.gh // 2,
            width_size=self.gw // 2,
        )
        # === 6) Backbone ===
        self.backbone = Backbone(
            dim=backbone_dim,
            num_heads=num_heads,
            ffn_mult=ffn_mult,
            num_dense_layers=num_dense_layers,
            num_moe_layers=num_moe_layers,
            num_routed=num_routed_experts,
            num_shared=num_shared_experts,
            topk=topk_experts,
        )
        # === 7) Output heads ===
        self.det_traj_head = DetectionTrajHead(
            in_dim=backbone_dim,
            hidden_size=det_head_hidden,
            num_classes=num_classes,
            traj_horizon=traj_horizon,
        )
        self.ctrl_head = ControlHead(
            in_dim=backbone_dim,
            hidden_size=ctrl_head_hidden,
            num_traj_tokens=12,
            num_action_tokens=num_control_tokens - 12,
            ego_traj_horizon=traj_horizon,
        )

    # ---------- helpers ----------
    @property
    def num_visual_tokens(self) -> int:
        """Number of visual tokens after the 2x2x2 compression."""
        return (self.num_history // 2) * (self.gh // 2) * (self.gw // 2)

    def _build_ego_tokens(self, ego_6d_corrected: torch.Tensor) -> torch.Tensor:
        """``[B, T, 6]`` corrected ego poses -> symlog -> Linear -> ``[B, T, D]``."""
        return self.ego_proj(symlog(ego_6d_corrected))

    def _build_visual_rays(
        self,
        intr_corrected: torch.Tensor,  # [B, calib_intr_dim]
        extr_corrected_se3: torch.Tensor,  # [B, 4, 4] cam2vehicle
        compressed_thw: tuple[int, int, int],
    ) -> torch.Tensor:
        """Compute ray directions for the compressed visual-token grid.

        After 2x2x2 compression every visual token covers a 2x2 patch region
        over 2 time frames.  Each token is approximated by the ray through the
        centre pixel of its region, and the same (h, w) ray is reused for all
        time frames: the camera pose is rigid in the camera frame across the
        history, and the ego-motion difference is carried by the ego tokens.

        Returns ``[B, T_ * h_ * w_, 3]`` ego-frame unit rays.
        """
        b = intr_corrected.shape[0]
        t_, h_, w_ = compressed_thw
        rays_grid = compute_ego_rays(
            intr_vec=intr_corrected,
            cam2vehicle=extr_corrected_se3,
            height=self.image_h,
            width=self.image_w,
            grid_h=h_,
            grid_w=w_,
            device=intr_corrected.device,
            dtype=intr_corrected.dtype,
        )  # [B, h_, w_, 3]
        # Broadcast over time: [B, T_, h_, w_, 3] -> flatten to [B, N_v, 3].
        rays = rays_grid.unsqueeze(1).expand(-1, t_, -1, -1, -1).contiguous()
        return rays.reshape(b, t_ * h_ * w_, 3)

    # ---------- forward ----------
    def forward(
        self,
        images: torch.Tensor,  # [B, T=8, 3, H, W]
        ego_6d_raw: torch.Tensor,  # [B, 8, 6]
        intr_raw: torch.Tensor,  # [B, calib_intr_dim], must match construction
        extr_6d_raw: torch.Tensor,  # [B, 6]
    ) -> E2EOutput:
        """Run the full pipeline (see the class docstring for the stages).

        Raises:
            ValueError: if ``images`` does not match the configured history
                length or input resolution.
        """
        b, t, _, h, w = images.shape
        # Explicit raises instead of assert: asserts vanish under ``python -O``.
        if t != self.num_history:
            raise ValueError(f"history frames mismatch: {t} vs {self.num_history}")
        # The ray computation uses the configured resolution; a mismatched
        # input would otherwise produce silently wrong rays.
        if (h, w) != (self.image_h, self.image_w):
            raise ValueError(
                f"image size mismatch: got {h}x{w}, expected {self.image_h}x{self.image_w}"
            )
        # 1) DINOv3 patch tokens [B, T, gh, gw, D_dino].
        dino_feats = self.dinov3(images)
        # 2) Calibration (symlog-space residual, symexp-restored outputs).
        calib_out: CalibrationOutput = self.calib(
            dino_feats=dino_feats,
            ego_raw=ego_6d_raw,
            intr_raw=intr_raw,
            extr_raw=extr_6d_raw,
        )
        corrected_ego = calib_out.corrected_ego
        corrected_intr = calib_out.corrected_intr
        corrected_extr_6d = calib_out.corrected_extr
        # 3) Convert the corrected 6D extrinsics to a 4x4 SE(3) matrix.
        # NOTE(review): imported locally, presumably to avoid an import cycle
        # with .data — confirm before hoisting to module level.
        from .data.se3 import six_d_to_matrix
        cam2veh_corrected = six_d_to_matrix(corrected_extr_6d)  # [B, 4, 4]
        # 4) 2x2x2 spatio-temporal compression.
        compressed, thw = self.compress(dino_feats)  # [B, N_v, D]
        n_v = compressed.shape[1]
        # 5) Visual rays from the corrected intrinsics / extrinsics.
        rays = self._build_visual_rays(corrected_intr, cam2veh_corrected, thw)
        rope_cos, rope_sin = self.rope.compute_freqs(rays)
        # 6) Build the non-visual tokens.
        ego_tok = self._build_ego_tokens(corrected_ego)  # [B, 8, D]
        det_tok = self.det_tokens.unsqueeze(0).expand(b, -1, -1)
        ctrl_tok = self.ctrl_tokens.unsqueeze(0).expand(b, -1, -1)
        extra_tok = self.extra_tokens.unsqueeze(0).expand(b, -1, -1)
        ego_tok = self.ego_pe(ego_tok)
        det_tok = self.det_pe(det_tok)
        ctrl_tok = self.ctrl_pe(ctrl_tok)
        extra_tok = self.extra_pe(extra_tok)
        # 7) Token sequence: [vision | ego | det | ctrl | extra].
        seq = torch.cat([compressed, ego_tok, det_tok, ctrl_tok, extra_tok], dim=1)
        visual_slice = (0, n_v)
        # 8) Backbone (RoPE applied only inside the visual slice).
        bb_out = self.backbone(seq, rope_cos=rope_cos, rope_sin=rope_sin, visual_slice=visual_slice)
        # 9) Slice the backbone output into head inputs.
        offset_det = n_v + self.num_ego
        offset_ctrl = offset_det + self.num_det
        det_feats = bb_out.hidden_states[:, offset_det : offset_det + self.num_det]
        ctrl_feats = bb_out.hidden_states[:, offset_ctrl : offset_ctrl + self.num_ctrl]
        det_out = self.det_traj_head(det_feats)
        ctrl_out = self.ctrl_head(ctrl_feats)
        return E2EOutput(
            detection=det_out,
            control=ctrl_out,
            backbone_out=bb_out,
            calibration=calib_out,
        )