NYU DS-GA-3001: World Models (v3, velocity actions)
Trained world-model checkpoints for the lquan9/collect1 drone trajectory dataset.
Eight configurations across two encoders (CNN / DINOv2-ViT-B/14) × two dynamics
(MLP / RSSM) × two training splits (filtered _c1 / unfiltered _c1u).
All checkpoints use velocity actions (6-dim: vx, vy, vz, wx, wy, wz) and odom with
per-episode position deltas from the start frame.
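As a rough illustration of "per-episode position deltas from the start frame", the sketch below subtracts the first frame's position from every frame of an episode. The function name `to_episode_deltas` and the plain (x, y, z) tuple layout are assumptions made for this example; the actual odom vector in the training code may carry additional fields.

```python
def to_episode_deltas(positions):
    """Express each (x, y, z) position relative to the episode's first frame.

    Hypothetical helper: assumes odom positions are plain (x, y, z) tuples.
    """
    if not positions:
        return []
    x0, y0, z0 = positions[0]
    return [(x - x0, y - y0, z - z0) for x, y, z in positions]


episode = [(2.0, 1.0, 0.5), (2.5, 1.0, 0.75), (3.0, 1.5, 1.0)]
deltas = to_episode_deltas(episode)
print(deltas)  # [(0.0, 0.0, 0.0), (0.5, 0.0, 0.25), (1.0, 0.5, 0.5)]
```

With this convention every episode starts at the origin, so the model never sees absolute world coordinates.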
Trajectory Evaluation (ADE / FDE, meters)
| experiment | train (min) | ADE@1.6s | FDE@1.6s | ADE@3.2s | FDE@3.2s | ADE@5s | FDE@5s |
|---|---|---|---|---|---|---|---|
| cnn_mlp_c1 | 29.9 | 0.978 | 1.190 | 1.333 | 2.188 | 1.874 | 3.473 |
| cnn_mlp_c1u | 42.7 | 0.850 | 1.012 | 1.082 | 1.620 | 1.410 | 2.367 |
| cnn_rssm_c1 | 30.2 | 1.179 | 1.354 | 1.457 | 2.111 | 1.880 | 3.172 |
| cnn_rssm_c1u | 43.2 | 0.860 | 1.028 | 1.105 | 1.682 | 1.457 | 2.464 |
| dino_mlp_c1 | 137.3 | 1.051 | 1.250 | 1.337 | 2.002 | 1.761 | 3.038 |
| dino_mlp_c1u | 224.4 | 1.010 | 1.153 | 1.216 | 1.703 | 1.521 | 2.420 |
| dino_rssm_c1 | 137.6 | 1.056 | 1.235 | 1.345 | 2.027 | 1.765 | 2.998 |
| dino_rssm_c1u | 227.4 | 1.054 | 1.185 | 1.240 | 1.684 | 1.522 | 2.348 |
ADE = Average Displacement Error over the rollout horizon. FDE = Final Displacement Error at the last timestep. Both are reported in meters on the held-out validation split.
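The two metrics can be sketched for a single rollout as follows. This is a minimal pure-Python illustration of the definitions above, not the evaluation script used to produce the table; the function name `ade_fde` and the tuple-based trajectory format are assumptions, and averaging over validation rollouts is left out.

```python
import math


def ade_fde(pred, target):
    """Return (ADE, FDE) in the same units as the inputs.

    pred and target are equal-length sequences of (x, y, z) positions,
    one entry per rollout timestep. Hypothetical helper for illustration.
    """
    if len(pred) != len(target) or not pred:
        raise ValueError("pred and target must be non-empty and equal length")
    # Euclidean displacement error at every timestep of the rollout.
    errors = [math.dist(p, t) for p, t in zip(pred, target)]
    ade = sum(errors) / len(errors)  # mean error over the horizon
    fde = errors[-1]                 # error at the final timestep
    return ade, fde


pred = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0)]
target = [(0.0, 0.0, 0.0), (1.0, 1.0, 0.0), (2.0, 2.0, 0.0)]
ade, fde = ade_fde(pred, target)
print(ade, fde)  # 1.0 2.0
```

FDE is never smaller than the smallest per-step error and usually exceeds ADE when error accumulates over the rollout, which matches the pattern in the table.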
Files per experiment
| file | purpose |
|---|---|
| latest.pt | final weights + model_config + training args (epoch 99) |
| norm_stats.pt | odom/imu/action mean+std |
| eval_metrics.json | training metadata (train time, etc.) |
| trajectory_eval.json | ADE/FDE per rollout horizon |
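To fetch files for several experiments at once, it can help to enumerate the experiment names and their in-repo paths. The sketch below builds the eight names from the encoder, dynamics, and split labels used in the table. The `{exp}/{file}` layout is taken from the loading snippet for `latest.pt` and `norm_stats.pt` and is assumed to hold for the two JSON files as well; the helper names are hypothetical.

```python
from itertools import product

ENCODERS = ("cnn", "dino")
DYNAMICS = ("mlp", "rssm")
SPLITS = ("c1", "c1u")
FILES = ("latest.pt", "norm_stats.pt", "eval_metrics.json", "trajectory_eval.json")


def experiment_names():
    """All encoder x dynamics x split combinations, e.g. 'cnn_mlp_c1'."""
    return [f"{enc}_{dyn}_{split}"
            for enc, dyn, split in product(ENCODERS, DYNAMICS, SPLITS)]


def repo_paths(exp):
    """In-repo paths for one experiment (layout assumed, see note above)."""
    return [f"{exp}/{name}" for name in FILES]


names = experiment_names()
print(len(names))            # 8
print(repo_paths(names[0]))
```

Each returned path can be passed as the `filename` argument of `hf_hub_download`.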
Loading
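import torch
from huggingface_hub import hf_hub_download
from world_model.architecture import WorldModel
from world_model.data.normalization import NormStats

# Pick one of the eight experiment names from the table above.
exp = "dino_rssm_c1"

# Download the checkpoint and its normalization statistics from the Hub.
ckpt_path = hf_hub_download(repo_id="dmody1/nyu-wm-v3", filename=f"{exp}/latest.pt")
ns_path = hf_hub_download(repo_id="dmody1/nyu-wm-v3", filename=f"{exp}/norm_stats.pt")

# weights_only=False unpickles arbitrary Python objects (the checkpoint stores
# config dicts alongside the weights), so only load checkpoints you trust.
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
ns = NormStats.load(ns_path)

# Rebuild the model from the stored config and training args.
mcfg, targs = ckpt["model_config"], ckpt["args"]
model = WorldModel(
    odom_dim=mcfg["odom_dim"], imu_dim=mcfg["imu_dim"], act_dim=mcfg["act_dim"],
    state_dim=mcfg["state_dim"],
    backbone=targs.get("backbone", "dinov2"),
    dynamics_type=targs.get("dynamics_type", "rssm"),
)
model.load_state_dict(ckpt["model"])
model.eval()  # inference mode: disables dropout and similar training-only behavior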
See cell 35 of docs/notebooks/world_model_pipeline.ipynb
for MPC inference (CEM planning with Δ-state cost).
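For readers unfamiliar with CEM (the cross-entropy method), the toy sketch below shows the general shape of the planning loop. It is not the notebook's implementation: the `rollout` function is a hypothetical 1-D stand-in for the world model that simply integrates velocity commands, and the cost is assumed to be the squared gap between the predicted and the desired change in state, which is one plausible reading of "Δ-state cost". All names and hyperparameters here are illustrative.

```python
import random
import statistics


def rollout(actions, dt=0.1):
    """Toy stand-in for the world model: integrate 1-D velocity commands
    and return the resulting change in position."""
    delta = 0.0
    for v in actions:
        delta += v * dt
    return delta


def cem_plan(target_delta, horizon=10, pop=64, elites=8, iters=20, seed=0):
    """Cross-entropy method over action sequences of length `horizon`."""
    rng = random.Random(seed)
    mean = [0.0] * horizon
    std = [1.0] * horizon
    for _ in range(iters):
        # Sample candidate action sequences from the current Gaussian.
        samples = [[rng.gauss(m, s) for m, s in zip(mean, std)]
                   for _ in range(pop)]
        # Rank by cost: squared gap between predicted and desired delta.
        samples.sort(key=lambda a: (rollout(a) - target_delta) ** 2)
        best = samples[:elites]
        # Refit the Gaussian to the elite set; the small floor keeps
        # the sampling distribution from collapsing to a point.
        mean = [statistics.fmean(col) for col in zip(*best)]
        std = [statistics.pstdev(col) + 1e-3 for col in zip(*best)]
    return mean


plan = cem_plan(target_delta=0.5)
print(round(rollout(plan), 2))
```

In an MPC loop, only the first action of `plan` would typically be executed before replanning from the new observation.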
Dataset
Training data: lquan9/collect1.
Train/val split: episodes with eid % 5 == 0 are validation, the rest are training.
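The split rule can be expressed as a small helper. This is a sketch assuming `eid` is an integer episode index; the function name `split_episodes` is hypothetical.

```python
def split_episodes(episode_ids):
    """Split episode ids into (train, val): ids divisible by 5 go to val."""
    train, val = [], []
    for eid in episode_ids:
        (val if eid % 5 == 0 else train).append(eid)
    return train, val


train_ids, val_ids = split_episodes(range(12))
print(val_ids)         # [0, 5, 10]
print(len(train_ids))  # 9
```

This yields roughly a 20% validation share, and because the rule depends only on the episode id, the split is deterministic across runs.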