# WJAD/scripts/smoke_train.py
"""端到端训练循环烟囱测试:构造随机 batch,跑 1-2 步 trainer。
不依赖磁盘上的数据集,仅验证 forward/backward/loss/PCGrad/GradNorm 链路。
"""
from __future__ import annotations
import os
import sys
from pathlib import Path
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "src"))
import logging
import numpy as np
import torch
from wjad.model import E2EAVModel
from wjad.train.trainer import Trainer, TrainerConfig
def _make_dummy_batch(
B: int = 1,
T: int = 8,
H: int = 64,
W: int = 128,
num_classes: int = 22,
num_objects: int = 3,
) -> dict:
"""构造极小分辨率的随机 batch(CPU 烟囱测试用)。"""
images = torch.randn(B, T, 3, H, W)
ego_6d = torch.zeros(B, T, 6)
intr_vec = torch.tensor([[
W / 2, H / 2, W, H,
0.0, 0.5, 0.0, 0.0, 0.0, 0.0,
1.0,
]] * B)
extr_6d = torch.zeros(B, 6)
ego_future = torch.zeros(B, 24, 3)
ego_future_valid = torch.ones(B, 24, dtype=torch.bool)
targets = []
for _ in range(B):
boxes = torch.zeros(num_objects, 7)
boxes[:, 3:6] = 2.0
targets.append({
"labels": torch.randint(1, num_classes, (num_objects,)),
"boxes": boxes,
"is_dynamic": torch.ones(num_objects, dtype=torch.long),
"future_traj": torch.zeros(num_objects, 24, 3),
"future_valid": torch.ones(num_objects, 24, dtype=torch.bool),
})
return {
"images": images,
"ego_6d": ego_6d,
"intr_vec": intr_vec,
"extr_6d": extr_6d,
"ego_future": ego_future,
"ego_future_valid": ego_future_valid,
"targets": targets,
"meta": [{}] * B,
}
def main() -> None:
    """Run a few end-to-end trainer steps on random data as a smoke test.

    Picks a near-real-scale configuration on GPU and a tiny one on CPU,
    builds the model and trainer, then executes ``n_steps`` training steps
    on freshly generated dummy batches, printing per-step losses.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    torch.manual_seed(0)

    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"
    if has_cuda:
        # On GPU, run close to real scale: full 384x1024 input and all 18 layers.
        # BS=4 OOMs on a10g-small (~22 GB); with gradient checkpointing BS=2 is stable.
        H, W, B = 384, 1024, 2
        num_dense = num_moe = 9
        num_det, num_extra = 1024, 256
        amp = "bf16"
        use_grad_ckpt = True
    else:
        # On CPU, run a minimal configuration purely as a sanity check.
        H, W, B = 64, 128, 1
        num_dense = num_moe = 2
        num_det, num_extra = 32, 16
        amp = "fp32"
        use_grad_ckpt = False
    n_steps = 4
    print(f"[smoke_train] device={device}, H={H} W={W} B={B} amp={amp} grad_ckpt={use_grad_ckpt}")

    model = E2EAVModel(
        dinov3_path=str(ROOT / "dinov3-vitb16-pretrain-lvd1689m"),
        num_dense_layers=num_dense,
        num_moe_layers=num_moe,
        num_detection_tokens=num_det,
        num_control_tokens=24,
        num_ego_tokens=8,
        num_extra_tokens=num_extra,
        num_classes=22,
        image_h=H,
        image_w=W,
        patch_size=16,
    )
    if use_grad_ckpt:
        model.backbone.set_gradient_checkpointing(True)
    # The a10g-small sandbox cannot afford DINOv3 finetuning (22 GB budget), so keep
    # it frozen; we only verify the two-stage path switch here. Full training is done
    # on H100 Jobs.
    model.dinov3.freeze()

    cfg = TrainerConfig(
        total_steps=n_steps,
        warmup_steps=1,
        base_lr=1e-4,
        log_interval=1,
        stage1_steps=2,  # short stage1 so the run reaches stage2 and exercises the switch
        stage1_perturb_start=1,
        enable_gradnorm=True,
        enable_pcgrad=True,  # PCGrad enabled for the whole run
        mixed_precision=amp,
        unfreeze_dinov3_at_stage2=False,  # sandbox memory is limited; path check only
    )
    trainer = Trainer(model, cfg, num_classes=22, device=device)

    rng = np.random.default_rng(0)
    if has_cuda:
        torch.cuda.reset_peak_memory_stats()
    for _ in range(n_steps):
        info = trainer.train_step(_make_dummy_batch(B=B, H=H, W=W), rng)
        print(
            f"step={info['step']} stage={info['stage']} total={info['total_loss']:.4f} "
            f"cls={info['L_cls']:.4f} box={info['L_box']:.4f} traj_obj={info['L_traj_obj']:.4f} "
            f"weights={[f'{w:.2f}' for w in info['weights']]}"
        )
    if has_cuda:
        peak_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"[smoke_train] CUDA peak memory = {peak_gb:.2f} GB")
    print("[smoke_train] OK")
if __name__ == "__main__":
main()