NYU DS-GA-3001: World Models (v3, velocity actions)

Trained world-model checkpoints for the lquan9/collect1 drone trajectory dataset. There are eight configurations: two encoders (CNN / DINOv2-ViT-B/14) × two dynamics models (MLP / RSSM) × two training splits (filtered _c1 / unfiltered _c1u).
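The experiment names used throughout this card follow an <encoder>_<dynamics>_<split> pattern. The sketch below enumerates all eight, which is handy for looping the Loading snippet over every checkpoint. The pattern is read off the evaluation table rather than taken from the training code, and `experiment_names` is a hypothetical helper.

```python
from itertools import product

# Naming pattern inferred from the evaluation table: <encoder>_<dynamics>_<split>.
ENCODERS = ("cnn", "dino")   # CNN or DINOv2-ViT-B/14 encoder
DYNAMICS = ("mlp", "rssm")   # MLP or RSSM dynamics model
SPLITS = ("c1", "c1u")       # filtered (_c1) or unfiltered (_c1u) training split


def experiment_names() -> list[str]:
    """Return the eight experiment directory names, one per configuration."""
    return [f"{e}_{d}_{s}" for e, d, s in product(ENCODERS, DYNAMICS, SPLITS)]


if __name__ == "__main__":
    for name in experiment_names():
        print(name)
```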

All checkpoints use 6-dim velocity actions (vx, vy, vz, wx, wy, wz) and odom inputs whose positions are expressed as per-episode deltas from the episode's start frame.
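The start-frame convention can be sketched as follows. This assumes each episode's odom is a (T, D) array with x, y, z position in the first three columns; the actual column layout is not documented here, and `to_start_relative` is a hypothetical helper, not part of the world_model package.

```python
import numpy as np

# Order of the 6-dim velocity action vector, as listed above.
ACTION_FIELDS = ("vx", "vy", "vz", "wx", "wy", "wz")


def to_start_relative(odom: np.ndarray, pos_cols: slice = slice(0, 3)) -> np.ndarray:
    """Express one episode's positions as deltas from its first frame.

    odom: (T, D) array for a single episode. ASSUMPTION: positions live in
    columns `pos_cols` (x, y, z by default); other columns are left unchanged.
    """
    out = np.asarray(odom, dtype=np.float64).copy()
    start = out[0, pos_cols].copy()  # start-frame position
    out[:, pos_cols] -= start        # every frame relative to the start
    return out


if __name__ == "__main__":
    episode = np.array([[10.0, 5.0, 2.0, 0.3],
                        [10.5, 5.0, 2.1, 0.3],
                        [11.0, 4.8, 2.2, 0.4]])
    print(to_start_relative(episode))
```

The first row always becomes zero in the position columns, so every episode starts at the origin regardless of where the drone was in the world frame.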

Trajectory Evaluation (ADE / FDE, meters)

experiment     train (min)  ADE@1.6s  FDE@1.6s  ADE@3.2s  FDE@3.2s  ADE@5s  FDE@5s
cnn_mlp_c1     29.9         0.978     1.190     1.333     2.188     1.874   3.473
cnn_mlp_c1u    42.7         0.850     1.012     1.082     1.620     1.410   2.367
cnn_rssm_c1    30.2         1.179     1.354     1.457     2.111     1.880   3.172
cnn_rssm_c1u   43.2         0.860     1.028     1.105     1.682     1.457   2.464
dino_mlp_c1    137.3        1.051     1.250     1.337     2.002     1.761   3.038
dino_mlp_c1u   224.4        1.010     1.153     1.216     1.703     1.521   2.420
dino_rssm_c1   137.6        1.056     1.235     1.345     2.027     1.765   2.998
dino_rssm_c1u  227.4        1.054     1.185     1.240     1.684     1.522   2.348

ADE = Average Displacement Error over the rollout horizon. FDE = Final Displacement Error at the last timestep. Both are reported in meters on the held-out validation split.
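A minimal sketch of both metrics for a single rollout, assuming the standard definitions with Euclidean (L2) distance on 3-D positions. The evaluation code behind trajectory_eval.json may differ in details such as 2-D vs 3-D distance or how per-rollout errors are aggregated over the validation split; `ade_fde` is a hypothetical helper.

```python
import numpy as np


def ade_fde(pred: np.ndarray, gt: np.ndarray) -> tuple[float, float]:
    """ADE and FDE for one rollout.

    pred, gt: (T, 3) predicted and ground-truth positions in meters over the
    rollout horizon. ASSUMPTION: per-step error is the Euclidean distance.
    """
    err = np.linalg.norm(pred - gt, axis=-1)  # (T,) per-timestep displacement
    return float(err.mean()), float(err[-1])  # ADE = mean, FDE = last step


if __name__ == "__main__":
    gt = np.zeros((4, 3))
    pred = np.array([[0, 0, 0], [1, 0, 0], [0, 2, 0], [3, 4, 0]], dtype=float)
    ade, fde = ade_fde(pred, gt)
    print(f"ADE={ade:.3f} m, FDE={fde:.3f} m")  # ADE=2.000 m, FDE=5.000 m
```

Because FDE looks only at the last step while ADE averages over all of them, FDE grows faster with horizon length, which matches the pattern in the table.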

Files per experiment

file                  purpose
latest.pt             final weights, model_config, and training args (epoch 99)
norm_stats.pt         odom/imu/action mean and std
eval_metrics.json     training metadata (train time, etc.)
trajectory_eval.json  ADE/FDE per rollout horizon
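norm_stats.pt holds per-channel mean and standard deviation for odom, imu, and action. The sketch below shows the standard z-score transform such statistics are normally used for, plus its inverse (for example, to map normalized planner outputs back to physical units). It is not the NormStats API: `normalize` and `denormalize` are hypothetical helpers, and whether NormStats adds an epsilon like this is an assumption.

```python
import numpy as np


def normalize(x, mean, std, eps=1e-8):
    """Z-score each channel: (x - mean) / std; eps guards zero-variance channels."""
    return (np.asarray(x) - mean) / (std + eps)


def denormalize(z, mean, std, eps=1e-8):
    """Inverse of normalize: map normalized values back to physical units."""
    return np.asarray(z) * (std + eps) + mean


if __name__ == "__main__":
    # Made-up 6-dim action stats (vx, vy, vz, wx, wy, wz), for illustration only.
    mean = np.array([0.5, 0.0, 0.0, 0.0, 0.0, 0.1])
    std = np.array([0.2, 0.1, 0.1, 0.05, 0.05, 0.3])
    a = np.array([0.7, -0.1, 0.05, 0.0, 0.02, 0.4])
    z = normalize(a, mean, std)
    print(np.round(z, 3))
    print(np.allclose(denormalize(z, mean, std), a))  # round trip recovers a
```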

Loading

import torch
from huggingface_hub import hf_hub_download
from world_model.architecture import WorldModel
from world_model.data.normalization import NormStats

exp = "dino_rssm_c1"  # any experiment name from the evaluation table
ckpt_path = hf_hub_download(repo_id="dmody1/nyu-wm-v3", filename=f"{exp}/latest.pt")
ns_path   = hf_hub_download(repo_id="dmody1/nyu-wm-v3", filename=f"{exp}/norm_stats.pt")

# weights_only=False: the checkpoint bundles model_config and training args with the weights
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
ns   = NormStats.load(ns_path)  # odom/imu/action mean and std
mcfg, targs = ckpt["model_config"], ckpt["args"]

# Rebuild the architecture from the saved config; backbone and dynamics_type
# fall back to DINOv2 / RSSM only if the saved args lack these keys.
model = WorldModel(
    odom_dim=mcfg["odom_dim"], imu_dim=mcfg["imu_dim"], act_dim=mcfg["act_dim"],
    state_dim=mcfg["state_dim"],
    backbone=targs.get("backbone", "dinov2"),
    dynamics_type=targs.get("dynamics_type", "rssm"),
)
model.load_state_dict(ckpt["model"])
model.eval()  # evaluation mode for inference (e.g. disables dropout)

See cell 35 of docs/notebooks/world_model_pipeline.ipynb for MPC inference (CEM planning with Δ-state cost).
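The notebook's planner is not reproduced here. As a rough illustration of the technique, the sketch below runs a cross-entropy method (CEM) planner over 6-dim velocity action sequences and returns a plan whose first action would be executed before replanning (MPC). Everything model-specific is an assumption: `toy_dynamics` is a hypothetical velocity integrator standing in for `WorldModel`, "Δ-state cost" is interpreted as the squared distance between predicted and goal position deltas from the start frame, and the horizon, sample counts, and dt are illustrative.

```python
import numpy as np

ACT_DIM = 6  # vx, vy, vz, wx, wy, wz


def toy_dynamics(state: np.ndarray, action: np.ndarray, dt: float = 0.1) -> np.ndarray:
    """HYPOTHETICAL stand-in for the learned world model: integrate linear velocity.

    state: (N, 3) position deltas from the start frame; action: (N, 6).
    Angular rates (wx, wy, wz) are ignored by this toy model.
    """
    return state + dt * action[:, :3]


def rollout(state0: np.ndarray, actions: np.ndarray, dt: float = 0.1) -> np.ndarray:
    """Roll N action sequences of shape (N, H, 6) forward; return final states (N, 3)."""
    state = np.repeat(state0[None, :], actions.shape[0], axis=0)
    for t in range(actions.shape[1]):
        state = toy_dynamics(state, actions[:, t], dt)
    return state


def delta_state_cost(final_state: np.ndarray, goal_delta: np.ndarray) -> np.ndarray:
    """ASSUMED cost: squared distance between predicted and goal position deltas."""
    return np.sum((final_state - goal_delta) ** 2, axis=-1)


def cem_plan(state0, goal_delta, horizon=10, n_samples=256, n_elite=32, iters=5, seed=0):
    """Cross-entropy method: return the mean action sequence, shape (horizon, 6)."""
    rng = np.random.default_rng(seed)
    mean = np.zeros((horizon, ACT_DIM))
    std = np.ones((horizon, ACT_DIM))
    for _ in range(iters):
        # Sample candidate sequences around the current Gaussian.
        actions = mean + std * rng.standard_normal((n_samples, horizon, ACT_DIM))
        cost = delta_state_cost(rollout(state0, actions), goal_delta)
        # Keep the lowest-cost sequences and refit the Gaussian to them.
        elite = actions[np.argsort(cost)[:n_elite]]
        mean, std = elite.mean(axis=0), elite.std(axis=0) + 1e-6
    return mean


if __name__ == "__main__":
    state0, goal = np.zeros(3), np.array([1.0, 0.0, 0.0])
    plan = cem_plan(state0, goal)
    # MPC: execute only plan[0], then replan from the newly observed state.
    print("first action:", np.round(plan[0], 3))
    print("planned cost:", delta_state_cost(rollout(state0, plan[None]), goal)[0])
```

CEM only needs a way to score batches of candidate sequences, which is why the dynamics are called on all N samples at once rather than one sequence at a time.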

Dataset

Training data: lquan9/collect1. Train/val split: episodes whose episode ID satisfies eid % 5 == 0 are held out for validation; all other episodes are used for training.
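The split rule is simple enough to state directly in code. This sketch assumes integer episode IDs; with consecutive IDs it holds out roughly one episode in five (about 20%) for validation. `split_episodes` is a hypothetical helper, not a function from the training code.

```python
def split_episodes(episode_ids):
    """Validation = episodes with eid % 5 == 0; training = all other episodes."""
    val = [eid for eid in episode_ids if eid % 5 == 0]
    train = [eid for eid in episode_ids if eid % 5 != 0]
    return train, val


if __name__ == "__main__":
    train, val = split_episodes(range(12))
    print("val:", val)      # val: [0, 5, 10]
    print("train:", train)
```

Splitting by whole episodes rather than by frames keeps near-duplicate consecutive frames from leaking between the training and validation sets.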
