PokeDreamer Checkpoints & Demos (v1)

This repository hosts the trained model checkpoints and demonstration videos for PokeDreamer v1, a model-based reinforcement learning project navigating Pokémon Red on the Game Boy.

Model Checkpoints (v1)

The models in v1 are trained on low-resolution $40 \times 36 \times 3$ observations:

1. 🖼️ Variational Autoencoder (VAE)

  • Location: vae/
  • Architecture: CNN bottlenecking observations to a 32-dimensional continuous latent space $z$.
  • Loss Criteria: Reconstruction (BCE) + KL Divergence.
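
For reference, the "Reconstruction (BCE) + KL Divergence" objective can be written out as below. This is the standard VAE form and only a sketch of what v1 may use: it assumes pixel values scaled to $[0, 1]$, a diagonal Gaussian posterior with mean $\mu$ and variance $\sigma^2$, a standard normal prior, and an unweighted sum of the two terms. None of these details are confirmed by the checkpoints themselves.

```latex
\mathcal{L}(x) =
  \underbrace{-\sum_{i} \bigl[ x_i \log \hat{x}_i + (1 - x_i) \log (1 - \hat{x}_i) \bigr]}_{\text{BCE reconstruction}}
  \;+\;
  \underbrace{-\tfrac{1}{2} \sum_{j=1}^{32} \bigl( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \bigr)}_{D_{\mathrm{KL}}\left( q(z \mid x) \,\|\, \mathcal{N}(0, I) \right)}
```

Here $i$ runs over the $40 \times 36 \times 3 = 4320$ input values and $j$ over the 32 latent dimensions.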

2. 🔀 Recurrent Latent Dynamics Model

  • Location: dynamics/
  • Architecture: Autoregressive GRU cell taking $(z_t, a_t)$ and predicting $z_{t+1}$.
  • Training Regime: Scheduled sampling (teacher forcing probability decayed over epochs) to prevent compounding drift in long-horizon rollouts.
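
The scheduled-sampling regime described above can be sketched as follows. This is a minimal illustration, not the v1 training code: the `LatentDynamics` class, the hidden size, the number of actions, the linear decay schedule, and the MSE loss are all assumptions made for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

LATENT_DIM = 32    # matches the VAE latent size
NUM_ACTIONS = 8    # assumed; the real action space may differ
HIDDEN = 128       # assumed GRU hidden size


class LatentDynamics(nn.Module):
    """Hypothetical stand-in: GRU cell on (z_t, a_t) plus a linear head for z_{t+1}."""

    def __init__(self):
        super().__init__()
        self.cell = nn.GRUCell(LATENT_DIM + NUM_ACTIONS, HIDDEN)
        self.head = nn.Linear(HIDDEN, LATENT_DIM)

    def forward(self, z, a_onehot, h):
        h = self.cell(torch.cat([z, a_onehot], dim=-1), h)
        return self.head(h), h


def teacher_forcing_prob(epoch, num_epochs, start=1.0, end=0.0):
    """Linearly decay the teacher-forcing probability from `start` to `end`."""
    frac = min(epoch / max(num_epochs - 1, 1), 1.0)
    return start + (end - start) * frac


def rollout_loss(model, z_seq, a_seq, tf_prob):
    """z_seq: (T+1, B, LATENT_DIM) encoded frames; a_seq: (T, B) integer actions."""
    T, B = a_seq.shape
    h = torch.zeros(B, HIDDEN, device=z_seq.device)
    z_in = z_seq[0]
    loss = 0.0
    for t in range(T):
        a = F.one_hot(a_seq[t], NUM_ACTIONS).float()
        z_pred, h = model(z_in, a, h)
        loss = loss + F.mse_loss(z_pred, z_seq[t + 1])
        # Per-sample coin flip: feed the true next latent with probability
        # tf_prob, otherwise feed the model's own (detached) prediction.
        use_truth = torch.rand(B, 1, device=z_seq.device) < tf_prob
        z_in = torch.where(use_truth, z_seq[t + 1], z_pred.detach())
    return loss / T
```

Early in training (`tf_prob` near 1) the cell mostly sees ground-truth latents. As the probability decays it is increasingly fed its own predictions, which is the same condition it faces during a long autoregressive rollout.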

3. 🎯 Coordinate RAM Probe

  • Location: probe/
  • Architecture: Linear probing layers mapping continuous latent states $z_t$ to overworld coordinates $(x_t, y_t)$ and map_id values extracted directly from the WRAM.
  • Purpose: Tests whether the latents encode spatial position, i.e. whether coordinates and map_id are linearly decodable from $z_t$.
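
A linear probe of this kind might look like the sketch below. The class name, the two-head layout, the loss choice, and `NUM_MAPS = 256` (which assumes map_id fits in one byte) are illustrative assumptions, not the layout of the checkpoint in probe/.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

LATENT_DIM = 32   # matches the VAE latent size
NUM_MAPS = 256    # assumed upper bound on distinct map_id values


class CoordinateProbe(nn.Module):
    """Hypothetical probe: one linear head for (x, y), one for map_id logits."""

    def __init__(self):
        super().__init__()
        self.xy = nn.Linear(LATENT_DIM, 2)
        self.map_id = nn.Linear(LATENT_DIM, NUM_MAPS)

    def forward(self, z):
        return self.xy(z), self.map_id(z)


def probe_loss(probe, z, xy_target, map_target):
    """Regress coordinates and classify map_id from frozen latents."""
    # detach() keeps probe gradients out of the encoder, so the probe only
    # measures what the latents already contain.
    xy_pred, map_logits = probe(z.detach())
    return F.mse_loss(xy_pred, xy_target) + F.cross_entropy(map_logits, map_target)
```

Keeping the probe linear is the point of the exercise: if a single linear layer can recover position, the spatial information is present in the latent space in an easily accessible form.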

🎥 Demos

  • side_by_side_demo.mp4: Side-by-side comparison of actual emulator frames (left) and dynamics model reconstructions (right) rolled out autoregressively for 30 steps.
  • planner_navigation_demo.mp4: A 100-step trajectory showing an MPC lookahead search planner using coordinate probes to navigate the overworld.
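
The planner in the second demo can be pictured with the toy sketch below. It is not the v1 planner: the four-direction action set, the exhaustive depth-3 search, the Manhattan-distance cost, and the `toy_step` / `toy_probe` stand-ins (which replace the dynamics model and the coordinate probe) are all assumptions chosen to keep the example self-contained.

```python
from itertools import product

# Hypothetical action set; the real v1 action space is not documented here.
ACTIONS = {"up": (0, -1), "down": (0, 1), "left": (-1, 0), "right": (1, 0)}


def toy_step(z, action):
    """Stand-in for the dynamics model: the 'latent' is just an (x, y) tuple."""
    dx, dy = ACTIONS[action]
    return (z[0] + dx, z[1] + dy)


def toy_probe(z):
    """Stand-in for the coordinate probe: identity on the toy latent."""
    return z


def plan_first_action(z0, goal, step_fn, probe_fn, horizon=3):
    """Exhaustive lookahead over all action sequences of length `horizon`.

    Each sequence is rolled out with step_fn and scored by the summed
    Manhattan distance between the probed position and the goal at every
    step. The first action of the cheapest sequence is returned.
    """
    best_seq, best_cost = None, float("inf")
    for seq in product(ACTIONS, repeat=horizon):
        z, cost = z0, 0
        for a in seq:
            z = step_fn(z, a)
            x, y = probe_fn(z)
            cost += abs(x - goal[0]) + abs(y - goal[1])
        if cost < best_cost:
            best_seq, best_cost = seq, cost
    return best_seq[0]


def navigate(z0, goal, step_fn, probe_fn, max_steps=100):
    """MPC loop: replan at every step and execute only the first planned action.

    In this toy, step_fn doubles as the environment; a real run would step
    the emulator and re-encode the new frame instead.
    """
    z, path = z0, [probe_fn(z0)]
    for _ in range(max_steps):
        if probe_fn(z) == goal:
            break
        z = step_fn(z, plan_first_action(z, goal, step_fn, probe_fn))
        path.append(probe_fn(z))
    return path
```

Calling `navigate((0, 0), (2, -1), toy_step, toy_probe)` walks the toy agent to the goal. Summing the distance over the whole rollout, rather than scoring only the final state, avoids oscillation when the goal is closer than the planning horizon.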

Usage

Checkpoints can be loaded in PyTorch:

import torch
# The VAE class lives in the archived v1 source tree; make sure v1/src is importable.
from v1.src.vae import VAE

device = "cuda" if torch.cuda.is_available() else "cpu"
vae = VAE(latent_dim=32).to(device)

# The checkpoint is a dict; the VAE weights are stored under "vae_state_dict".
checkpoint = torch.load("vae/best_vae.pt", map_location=device)
vae.load_state_dict(checkpoint["vae_state_dict"])
vae.eval()  # switch to inference mode