Alem RL Baselines
Trained multi-agent RL checkpoints for Alem -- a JAX benchmark for open-ended multi-agent coordination, with paired RL/symbolic and LLM/text interfaces. These are the reference RL baselines from the paper, Benchmarking Open-Ended Multi-Agent Coordination in Language Agents.
What's here
120 checkpoints = 2 training budgets (100M and 1B training env steps) × 4 algorithms × 3 difficulties × 5 seeds.
| Axis | Values |
|---|---|
| Budget (env steps) | 100M, 1B |
| Algorithm | ippo-rnn (IPPO), mappo-rnn (MAPPO), hypermarl-rnn (HyperMARL-IPPO), pqn-vdn-rnn (PQN-VDN) |
| Difficulty | easy, medium, hard |
| Seed | seed0 … seed4 |
Layout
<budget>/<algorithm>/<difficulty>/seed<N>/
checkpoint/ # Orbax (OCDBT) checkpoint — params, opt_state, step
config.json # algorithm, env config, reload_overrides, wandb summary, full training config
For example: 1B/ippo-rnn/hard/seed2/checkpoint.
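The layout above can be enumerated programmatically, which also confirms the count of 120. This is a minimal sketch, assuming the budget directories are named exactly `100M` and `1B` as in the table and that the repo was downloaded to `alem-rl-baselines`; the helper name `checkpoint_paths` is hypothetical.

```python
from itertools import product

BUDGETS = ["100M", "1B"]  # assumed directory names, taken from the table
ALGORITHMS = ["ippo-rnn", "mappo-rnn", "hypermarl-rnn", "pqn-vdn-rnn"]
DIFFICULTIES = ["easy", "medium", "hard"]
SEEDS = range(5)  # seed0 ... seed4


def checkpoint_paths(root="alem-rl-baselines"):
    """Return every <root>/<budget>/<algorithm>/<difficulty>/seed<N>/checkpoint path."""
    return [
        f"{root}/{budget}/{algo}/{difficulty}/seed{seed}/checkpoint"
        for budget, algo, difficulty, seed in product(
            BUDGETS, ALGORITHMS, DIFFICULTIES, SEEDS
        )
    ]


paths = checkpoint_paths()
print(len(paths))  # 2 * 4 * 3 * 5 = 120
```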
⚠️ The env config must match training
Checkpoint shapes are fixed at training time, so the env configuration used at reload (agent count, communication channels, action masking, network sizes) must match the one used for training; otherwise the Orbax restore fails with a shape mismatch. All of these checkpoints were trained with 4 communication channels (NUM_COMM_CHANNELS=4). The exact overrides needed to reload each policy are stored under reload_overrides in its config.json.
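The stored overrides can be read back and turned into command-line arguments. This is a minimal sketch, assuming `reload_overrides` is a flat JSON object mapping option names to values and that the trainers accept `KEY=value` arguments in the style shown elsewhere in this README; the helper names are hypothetical, and the actual structure of `config.json` should be checked before relying on this.

```python
import json
from pathlib import Path


def format_override(key, value):
    """Render one override as KEY=value; lists become [a,b] without quotes."""
    if isinstance(value, list):
        value = "[" + ",".join(str(v) for v in value) + "]"
    return f"{key}={value}"


def load_reload_overrides(config_path):
    """Read config.json and return its reload_overrides as KEY=value strings.

    Assumes reload_overrides is a flat JSON object; adjust if it is nested.
    """
    config = json.loads(Path(config_path).read_text())
    overrides = config.get("reload_overrides", {})
    return [format_override(key, value) for key, value in overrides.items()]
```

For instance, `load_reload_overrides("alem-rl-baselines/1B/ippo-rnn/hard/seed0/config.json")` would return a list of strings that can be appended to the trainer command.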
Loading and running a policy
Use the matching trainer in the Alem repo's baselines/. Each trainer accepts a LOAD_CHECKPOINT flag that skips training, restores the policy, and runs the standard evaluation:
# 1. Get the env + baselines
git clone https://github.com/alem-world/alem-env
cd alem-env
uv pip install -e ".[baselines-rl]"
# 2. Download the checkpoints
hf download alem-world/alem-rl-baselines --local-dir alem-rl-baselines
# 3. Reload and evaluate an IPPO policy (note NUM_COMM_CHANNELS=4)
cd baselines
python ippo_rnn.py \
LOAD_CHECKPOINT=../alem-rl-baselines/1B/ippo-rnn/hard/seed0/checkpoint \
NUM_COMM_CHANNELS=4 \
EVAL_DIFFICULTIES=[hard] \
VISUALIZE=False
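To evaluate every seed of one configuration, the step-3 command can be generated in a loop. This is a minimal sketch that only builds and prints the argument lists without running them; the flags are copied from the command above, the helper name `eval_command` is hypothetical, and the relative checkpoint path assumes the commands are launched from `baselines/`.

```python
import shlex


def eval_command(seed, root="../alem-rl-baselines"):
    """Build the step-3 IPPO evaluation command for one seed (1B budget, hard)."""
    checkpoint = f"{root}/1B/ippo-rnn/hard/seed{seed}/checkpoint"
    return [
        "python", "ippo_rnn.py",
        f"LOAD_CHECKPOINT={checkpoint}",
        "NUM_COMM_CHANNELS=4",
        "EVAL_DIFFICULTIES=[hard]",
        "VISUALIZE=False",
    ]


commands = [eval_command(seed) for seed in range(5)]  # seed0 ... seed4
for cmd in commands:
    # Each list can be passed to subprocess.run(cmd, check=True) from baselines/.
    print(shlex.join(cmd))
```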
To restore the raw weights directly in JAX:
import jax
import numpy as np
import orbax.checkpoint as ocp
path = "alem-rl-baselines/1B/ippo-rnn/hard/seed0/checkpoint"
ckptr = ocp.PyTreeCheckpointer()
# Read the saved tree structure, then request every leaf as a NumPy array.
meta = ckptr.metadata(path)
restore_args = jax.tree.map(lambda _: ocp.RestoreArgs(restore_type=np.ndarray), meta)
state = ckptr.restore(path, restore_args=restore_args)  # {"params": ..., "opt_state": ..., "step": ...}
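Once restored, the tree can be inspected to check shapes against the expected env configuration. This is a minimal sketch, assuming the restored state is made of nested dicts, lists, and tuples with NumPy array leaves; the helper names are hypothetical, and the dummy tree below only stands in for `state["params"]`, whose real layer names will differ.

```python
import math

import numpy as np


def iter_leaves(tree, prefix=""):
    """Yield (path, leaf) pairs from nested dicts, lists, and tuples."""
    if tree is None:
        return  # skip empty entries
    if isinstance(tree, dict):
        for key, value in tree.items():
            yield from iter_leaves(value, f"{prefix}/{key}")
    elif isinstance(tree, (list, tuple)):
        for index, value in enumerate(tree):
            yield from iter_leaves(value, f"{prefix}/{index}")
    else:
        yield prefix, tree


def count_params(tree):
    """Total number of scalar entries across all array leaves."""
    return sum(math.prod(np.shape(leaf)) for _, leaf in iter_leaves(tree))


# Dummy stand-in for state["params"]; names and shapes are illustrative only.
dummy_params = {"dense": {"kernel": np.zeros((8, 4)), "bias": np.zeros(4)}}
for path, leaf in iter_leaves(dummy_params):
    print(path, np.shape(leaf), leaf.dtype)
print(count_params(dummy_params))  # 8 * 4 + 4 = 36
```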
Citation
@article{tessera2026alem,
title = {Benchmarking Open-Ended Multi-Agent Coordination in Language Agents},
author = {Tessera, {Kale-ab} Abebe and Szecsenyi, Andras and Barker, Cameron and
Rutherford, Alexander and Paglieri, Davide and Scannell, Aidan and
Gouk, Henry and Crowley, Elliot J. and Rockt\"{a}schel, Tim and
Storkey, Amos},
year = {2026},
url = {https://arxiv.org/abs/2606.08340}
}
License
Released under the MIT License, matching the Alem codebase.