
# TODO.md — Next Steps for NSGF++ Reproduction

## Current Status

| Experiment | Pool Building | Phase 1 (NSGF) | Phase 2 (NSF) | Phase 3 (Predictor) | Inference | Eval |
|---|---|---|---|---|---|---|
| 2D 8gaussians | ✅ | ✅ | — | — | ✅ | ✅ W2=2.04 (small run) |
| MNIST | ✅ | 🔶 runs, loss converging (~0.03), interrupted at 9.5K/100K | untested on GPU | untested on GPU | untested | untested |
| CIFAR-10 | 🔶 OOM fixed (batch 128→32), untested on GPU | untested | untested | untested | untested | untested |

✅ = verified working · 🔶 = partially done · ❌ = blocked


## Immediate — Run Full Experiments

### 1. MNIST full run on T4

The most important next step. All known code bugs are fixed; all that's needed now is one clean Kaggle run.

```bash
cd /kaggle/working/ && rm -rf nsgf-plusplus
git clone https://huggingface.co/rogermt/nsgf-plusplus
cd nsgf-plusplus && pip install -r requirements.txt

# Phase 1: pool (~7 min) + NSGF training (100K steps, ~2.5 hrs)
python main.py --experiment mnist

# If the session runs out, in the next session:
python main.py --experiment mnist --resume-phase 2

# If Phase 2 is done:
python main.py --experiment mnist --resume-phase 3
```

Expected runtimes on T4:

- Pool building (1,500 batches): ~7 min
- Phase 1 NSGF (100K steps): ~2.5 hours
- Phase 2 NSF (100K steps): ~3-4 hours (each step does NSGF inference plus an NSF forward/backward pass)
- Phase 3 Predictor (40K steps): ~1.5 hours
- Total: ~7-8 hours — tight for one 9-hour Kaggle session

Alternative: use `--train-iters 50000` for Phases 1 and 2 to fit in one session, accepting lower quality.

Paper target: FID ≈ 3.8 at NFE=60


### 2. CIFAR-10 first test on T4

Once MNIST works, test CIFAR-10 with the reduced Sinkhorn batch size.

```bash
# Smoke test first (should run in ~2 min)
python main.py --experiment cifar10 --pool-batches 10 --train-iters 50

# If the smoke test passes, the real Phase 1:
python main.py --experiment cifar10 --train-iters 50000

# Subsequent sessions:
python main.py --experiment cifar10 --resume-phase 2 --train-iters 50000
python main.py --experiment cifar10 --resume-phase 3
```

If it still OOMs, try `--sinkhorn-batch 16 --pool-batches 20000`.

Paper target: FID ≈ 5.55, IS ≈ 8.86 at NFE=59


### 3. 2D full-scale run

A quick win to validate against the paper's numbers; should take ~20 min on a T4.

```bash
python main.py --experiment 2d --dataset 8gaussians --steps 10
```

Paper target: W2 ≈ 0.285 for 8gaussians

The current small-run W2 = 2.04 is expected — it used only 10 pool batches and 1,000 iterations. A full run (200 batches, 20K iterations) should drop it dramatically.
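The reported W2 can also be cross-checked against an exact OT solver on a few thousand samples. A minimal sketch using the POT library (`pip install pot`; the arrays below are stand-ins, not this repo's sample-loading code):

```python
import numpy as np
import ot  # POT: pip install pot

def exact_w2(x: np.ndarray, y: np.ndarray) -> float:
    """Exact 2-Wasserstein distance between two equal-weight point clouds."""
    a = np.full(len(x), 1.0 / len(x))  # uniform weights on generated samples
    b = np.full(len(y), 1.0 / len(y))  # uniform weights on target samples
    M = ot.dist(x, y)  # pairwise squared Euclidean costs
    return float(np.sqrt(ot.emd2(a, b, M)))  # sqrt of the optimal transport cost

# Stand-in arrays; replace with generated and target 2D samples.
rng = np.random.default_rng(0)
print(exact_w2(rng.normal(size=(2000, 2)), rng.normal(size=(2000, 2))))
```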

Also run the other 2D datasets:

```bash
python main.py --experiment 2d --dataset moons --steps 10
python main.py --experiment 2d --dataset scurve --steps 10
python main.py --experiment 2d --dataset checkerboard --steps 10
```

## Medium-term — Code Improvements

### 4. Step-level resume within phases

The current `--resume-phase` skips completed phases but restarts the in-progress phase from step 0, so for 100K-step phases a mid-phase interruption still loses all progress. Needed (see the sketch after this list):

- Load `nsgf_checkpoint.pt` / `nsf_checkpoint.pt` / `predictor_checkpoint.pt`
- Restore the optimizer state and the step counter
- Continue training from the last checkpointed step
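A minimal sketch of the save/load pair, assuming the training loop holds `model`, `optimizer`, and a `step` counter (function names are hypothetical, not this repo's API):

```python
import torch

CKPT = "nsgf_checkpoint.pt"  # same pattern for the NSF and predictor checkpoints

def save_checkpoint(model, optimizer, step, path=CKPT):
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),  # Adam moments, etc.
        "step": step,
    }, path)

def load_checkpoint(model, optimizer, path=CKPT, device="cuda"):
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["step"]  # the training loop resumes from step + 1
```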

### 5. EMA (Exponential Moving Average) for image models

The paper uses EMA for MNIST and CIFAR-10, as is standard in diffusion/flow models. The current code doesn't implement it, which likely hurts FID significantly.
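A minimal EMA sketch: a shadow copy of the model, updated after each optimizer step (the decay value is a common default, not taken from the paper):

```python
import copy
import torch

class EMA:
    """Shadow copy of a model, updated as ema = decay*ema + (1-decay)*param."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.9999):
        self.decay = decay
        self.shadow = copy.deepcopy(model).eval()
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        for ema_p, p in zip(self.shadow.parameters(), model.parameters()):
            ema_p.lerp_(p, 1.0 - self.decay)

# Usage: call ema.update(model) after every optimizer.step(),
# then sample and evaluate with ema.shadow instead of model.
```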

### 6. Learning rate scheduler

The paper may use cosine decay or warmup; the current code uses a constant learning rate. Check whether this matters for convergence.
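If it does, linear warmup plus cosine decay is the common choice; a sketch using `torch.optim.lr_scheduler.LambdaLR` (warmup length and total steps are placeholders, not paper values):

```python
import math
import torch

def warmup_cosine(optimizer, warmup_steps=1000, total_steps=100_000):
    """Linear warmup to the base lr, then cosine decay to zero."""
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * t))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Usage: scheduler.step() once per training step, after optimizer.step().
```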

### 7. FID evaluation correctness

Verify that evaluation.py's FID computation matches the standard protocol (cross-check sketch after this list):

- InceptionV3 features from the pool3 layer (2048-dim)
- 10K generated vs. 10K test samples
- Proper image preprocessing (resize to 299×299 for Inception)
- Compare against pytorch-fid or clean-fid as a sanity check
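A minimal cross-check using clean-fid, assuming generated and real samples have been dumped as image files into two directories (the paths are hypothetical):

```python
from cleanfid import fid  # pip install clean-fid

# Two directories of PNGs: 10K generated samples vs. 10K real test images
# (hypothetical paths; dump them from the inference/eval code first).
score = fid.compute_fid("samples/generated", "samples/test",
                        mode="clean", batch_size=64)
print(f"clean-fid: {score:.2f}")  # should roughly match evaluation.py's number
```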

### 8. Inception Score evaluation

Implement this properly for CIFAR-10 if it isn't already correct; the paper reports IS = 8.86 (sketch below).
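Rather than hand-rolling it, torchmetrics ships an implementation; a sketch assuming generated images are available as a uint8 tensor (random noise stands in for real samples here):

```python
import torch
from torchmetrics.image.inception import InceptionScore  # pip install "torchmetrics[image]"

metric = InceptionScore(splits=10)  # standard 10-split protocol

# imgs: (N, 3, H, W) uint8 generated samples; random noise as a stand-in.
imgs = torch.randint(0, 255, (1000, 3, 32, 32), dtype=torch.uint8)
metric.update(imgs)
mean, std = metric.compute()
print(f"IS = {mean.item():.2f} +/- {std.item():.2f}")  # paper: 8.86 on CIFAR-10
```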


## Longer-term — Towards Paper Numbers

### 9. Full paper hyperparameters

Once the code is stable, run with the exact paper configs (no iteration reduction):

- MNIST: 100K + 100K + 40K iterations
- CIFAR-10: 200K + 200K + 40K iterations
- This requires an A100 or multiple Kaggle sessions with checkpointing (see item 4)

### 10. Ablation: NSGF vs NSGF++

Run NSGF-only (Phase 1 only, no straight flow) and compare FID/W2 against full NSGF++ to verify that the two-phase approach actually helps; the paper shows a clear improvement.

### 11. NFE sweep

The paper reports results at various NFE (numbers of function evaluations). Test:

- MNIST: NFE = 10, 20, 40, 60
- CIFAR-10: NFE = 10, 20, 40, 59
- Compare the FID-vs-NFE curve against the paper's Figure 3

### 12. pykeops for faster Sinkhorn

Install pykeops to enable geomloss's online backend. This avoids materializing the full N×N cost matrix, so it should be much faster and use far less VRAM on the image experiments; it could enable the paper's original batch_size=128 on a T4.

```bash
pip install pykeops
```

Then switch the geomloss backend from `"tensorized"` to `"online"` in the config or code.
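Concretely, geomloss's `SamplesLoss` takes the backend as a constructor argument; a sketch (the blur value and shapes are illustrative, not this repo's config):

```python
import torch
from geomloss import SamplesLoss

# "online" streams the cost matrix through pykeops instead of materializing
# the full NxN tensor that the default "tensorized" backend builds.
sinkhorn = SamplesLoss("sinkhorn", p=2, blur=0.05, backend="online")

x = torch.randn(128, 3 * 32 * 32, device="cuda")  # e.g. a flattened CIFAR batch
y = torch.randn(128, 3 * 32 * 32, device="cuda")
loss = sinkhorn(x, y)  # differentiable entropic OT loss
```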

## Known Limitations

- Single-GPU only — no DDP, so on a T4×2 instance one GPU sits idle
- No EMA — standard in flow/diffusion models; likely hurts FID (see item 5)
- No mixed precision — fp32 only; fp16/bf16 could roughly halve VRAM (see the sketch below)
- No gradient accumulation — batch size is hard-limited by VRAM (see the sketch below)
- Kaggle checkpoint persistence — checkpoints are lost between sessions unless saved manually
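The two training-loop limitations above are cheap to lift together; a sketch using torch.amp with gradient accumulation (the loop structure and names are illustrative, not this repo's actual training code):

```python
import torch

ACCUM_STEPS = 4  # effective batch = micro-batch size * ACCUM_STEPS
scaler = torch.cuda.amp.GradScaler()

def train_step(model, optimizer, micro_batches, loss_fn):
    """One optimizer step over ACCUM_STEPS micro-batches in fp16."""
    optimizer.zero_grad(set_to_none=True)
    for micro in micro_batches:  # iterable of ACCUM_STEPS micro-batches
        with torch.autocast("cuda", dtype=torch.float16):
            loss = loss_fn(model, micro) / ACCUM_STEPS  # average over micros
        scaler.scale(loss).backward()  # scale to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()
```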