Alem RL Baselines

Trained multi-agent RL checkpoints for Alem, a JAX benchmark for open-ended multi-agent coordination with paired RL/symbolic and LLM/text interfaces. These are the reference RL baselines from the paper Benchmarking Open-Ended Multi-Agent Coordination in Language Agents.

What's here

120 checkpoints = 2 training budgets (100M and 1B training env steps) × 4 algorithms × 3 difficulties × 5 seeds.

Axis                 Values
Budget (env steps)   100M, 1B
Algorithm            ippo-rnn (IPPO), mappo-rnn (MAPPO), hypermarl-rnn (HyperMARL-IPPO), pqn-vdn-rnn (PQN-VDN)
Difficulty           easy, medium, hard
Seed                 seed0 to seed4
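You can enumerate the grid directly to check the count. The following is a minimal Python sketch: the axis values are copied from the table, and representing each run as a tuple is only an illustrative choice.

```python
from itertools import product

# Axis values copied from the table above.
BUDGETS = ["100M", "1B"]
ALGORITHMS = ["ippo-rnn", "mappo-rnn", "hypermarl-rnn", "pqn-vdn-rnn"]
DIFFICULTIES = ["easy", "medium", "hard"]
SEEDS = [f"seed{n}" for n in range(5)]  # seed0 to seed4

# Every (budget, algorithm, difficulty, seed) combination is one checkpoint.
GRID = list(product(BUDGETS, ALGORITHMS, DIFFICULTIES, SEEDS))

print(len(GRID))  # prints 120 (2 * 4 * 3 * 5)
```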

Layout

<budget>/<algorithm>/<difficulty>/seed<N>/
    checkpoint/      # Orbax (OCDBT) checkpoint: params, opt_state, step
    config.json      # algorithm, env config, reload_overrides, wandb summary, full training config

For example: 1B/ippo-rnn/hard/seed2/checkpoint.
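A small helper can build a checkpoint path from the four axes and parse one back, which is handy when grouping evaluation results by algorithm or seed. This is a sketch only. checkpoint_path and parse_checkpoint_path are hypothetical names that are not part of the Alem repo. The sketch also assumes that the 100M budget directory is literally named 100M, mirroring 1B in the example path.

```python
from pathlib import PurePosixPath


def checkpoint_path(root: str, budget: str, algorithm: str,
                    difficulty: str, seed: int) -> PurePosixPath:
    """Build <root>/<budget>/<algorithm>/<difficulty>/seed<N>/checkpoint (hypothetical helper)."""
    return PurePosixPath(root, budget, algorithm, difficulty, f"seed{seed}", "checkpoint")


def parse_checkpoint_path(path: str) -> dict:
    """Recover the four axes from a path ending in .../seed<N>/checkpoint."""
    parts = PurePosixPath(path).parts  # a trailing slash is dropped automatically
    if len(parts) < 5 or parts[-1] != "checkpoint" or not parts[-2].startswith("seed"):
        raise ValueError(f"not a checkpoint path: {path}")
    budget, algorithm, difficulty, seed_dir = parts[-5:-1]
    return {
        "budget": budget,
        "algorithm": algorithm,
        "difficulty": difficulty,
        "seed": int(seed_dir[len("seed"):]),
    }


p = checkpoint_path("alem-rl-baselines", "1B", "ippo-rnn", "hard", 2)
print(p)  # alem-rl-baselines/1B/ippo-rnn/hard/seed2/checkpoint
print(parse_checkpoint_path(str(p)))
# {'budget': '1B', 'algorithm': 'ippo-rnn', 'difficulty': 'hard', 'seed': 2}
```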

⚠️ The env config must match training

Checkpoint shapes are fixed at training time. The environment and network configuration (agent count, communication channels, action masking, network sizes) must therefore match the training run, or the Orbax restore fails with a shape mismatch. All of these checkpoints were trained with 4 communication channels (NUM_COMM_CHANNELS=4). The exact overrides needed to reload each policy are stored under reload_overrides in its config.json.
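Because the overrides are stored in config.json, you can assemble the reload command from that file instead of typing it by hand. This is a minimal sketch with several assumptions:

- reload_overrides is assumed to be a flat JSON object mapping trainer flag names to values, for example {"NUM_COMM_CHANNELS": 4}. Its exact schema is not documented here, so inspect one file before relying on this.
- Values are rendered in the KEY=VALUE style that the trainer commands use, such as VISUALIZE=False or EVAL_DIFFICULTIES=[hard].
- reload_args and build_reload_command are hypothetical helpers.
- ippo_rnn.py is only the default trainer.

```python
import json
from pathlib import Path


def format_override(value) -> str:
    """Render a JSON value as a KEY=VALUE right-hand side (e.g. False, [hard], 4)."""
    if isinstance(value, bool):  # check bool before int: bool is a subclass of int
        return "True" if value else "False"
    if isinstance(value, list):
        return "[" + ",".join(format_override(v) for v in value) + "]"
    return str(value)


def reload_args(config: dict, run_dir: str, trainer: str = "ippo_rnn.py") -> list[str]:
    """Build argv from an already-parsed config.json (hypothetical helper)."""
    # Absolute path, so the directory the trainer runs from does not matter.
    ckpt = Path(run_dir).resolve() / "checkpoint"
    argv = ["python", trainer, f"LOAD_CHECKPOINT={ckpt}"]
    for key, val in config.get("reload_overrides", {}).items():
        argv.append(f"{key}={format_override(val)}")
    return argv


def build_reload_command(run_dir: str, trainer: str = "ippo_rnn.py") -> list[str]:
    """Read <run_dir>/config.json and build the reload argv."""
    config = json.loads((Path(run_dir) / "config.json").read_text())
    return reload_args(config, run_dir, trainer)
```

For example, subprocess.run(build_reload_command("alem-rl-baselines/1B/ippo-rnn/hard/seed0"), cwd="alem-env/baselines", check=True) would launch the reload from the repo's baselines/ directory. Passing argv as a list bypasses the shell, so bracketed values need no quoting. Other algorithms need their own trainer script via the trainer argument, but those script names are not listed here.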

Loading and running a policy

Use the matching trainer script in the baselines/ directory of the Alem repo. Each trainer accepts a LOAD_CHECKPOINT flag that skips training, restores the policy, and runs the standard evaluation:

# 1. Get the env + baselines
git clone https://github.com/alem-world/alem-env
cd alem-env
uv pip install -e ".[baselines-rl]"

# 2. Download the checkpoints
hf download alem-world/alem-rl-baselines --local-dir alem-rl-baselines

# 3. Reload and evaluate an IPPO policy (note NUM_COMM_CHANNELS=4)
cd baselines
python ippo_rnn.py \
    LOAD_CHECKPOINT=../alem-rl-baselines/1B/ippo-rnn/hard/seed0/checkpoint \
    NUM_COMM_CHANNELS=4 \
    'EVAL_DIFFICULTIES=[hard]' \
    VISUALIZE=False
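Step 2 downloads all 120 checkpoints. If you only need one run, huggingface_hub.snapshot_download accepts fnmatch-style allow_patterns, so a single run directory can be fetched on its own. The following is a sketch: run_pattern and download_run are hypothetical helpers, and the pattern assumes the layout described above. Files land in the same relative layout under local_dir, so the LOAD_CHECKPOINT path in step 3 still works.

```python
import fnmatch


def run_pattern(budget: str, algorithm: str, difficulty: str, seed: int) -> str:
    """Glob for one run directory; in fnmatch, "*" also matches "/", so checkpoint/ is covered."""
    return f"{budget}/{algorithm}/{difficulty}/seed{seed}/*"


def download_run(budget: str, algorithm: str, difficulty: str, seed: int,
                 local_dir: str = "alem-rl-baselines") -> str:
    """Fetch only one run's config.json and checkpoint/ (hypothetical helper)."""
    from huggingface_hub import snapshot_download  # imported lazily; needs huggingface_hub installed

    return snapshot_download(
        repo_id="alem-world/alem-rl-baselines",
        local_dir=local_dir,
        allow_patterns=[run_pattern(budget, algorithm, difficulty, seed)],
    )


pattern = run_pattern("1B", "ippo-rnn", "hard", 0)
print(fnmatch.fnmatch("1B/ippo-rnn/hard/seed0/config.json", pattern))  # True
print(fnmatch.fnmatch("1B/ippo-rnn/hard/seed1/config.json", pattern))  # False
```

Calling download_run("1B", "ippo-rnn", "hard", 0) would then replace step 2 when only that policy is needed.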

To restore the raw weights directly in JAX:

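import os

import jax
import numpy as np
import orbax.checkpoint as ocp

# Orbax may reject relative checkpoint paths, so pass an absolute one.
path = os.path.abspath("alem-rl-baselines/1B/ippo-rnn/hard/seed0/checkpoint")
ckptr = ocp.PyTreeCheckpointer()
# Read the saved tree structure, then restore every leaf as a host NumPy array
# (avoids depending on the device sharding used at training time).
meta = ckptr.metadata(path)
restore_args = jax.tree.map(lambda _: ocp.RestoreArgs(restore_type=np.ndarray), meta)
state = ckptr.restore(path, restore_args=restore_args)   # {"params": ..., "opt_state": ..., "step": ...}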
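With restore_type=np.ndarray, the restored params are NumPy arrays. A quick sanity check is to walk the params tree and count parameters per top-level module. This is a sketch. It assumes state["params"] is built from nested dicts, lists, or tuples (as Flax produces), and the module names and shapes in the toy tree are made up. param_counts and flatten are hypothetical helpers.

```python
import numpy as np


def flatten(tree, prefix=()):
    """Yield (path, array) pairs from a nested dict/list/tuple of arrays."""
    if isinstance(tree, dict):
        for key in sorted(tree):
            yield from flatten(tree[key], prefix + (str(key),))
    elif isinstance(tree, (list, tuple)):
        for i, sub in enumerate(tree):
            yield from flatten(sub, prefix + (str(i),))
    else:
        yield prefix, np.asarray(tree)


def param_counts(params) -> dict:
    """Number of scalars per top-level key, plus an overall total."""
    counts = {}
    for path, arr in flatten(params):
        top = path[0] if path else "<root>"
        counts[top] = counts.get(top, 0) + int(arr.size)
    counts["total"] = sum(counts.values())
    return counts


# Toy tree standing in for state["params"]; names and shapes are hypothetical.
toy = {
    "Dense_0": {"kernel": np.zeros((8, 4)), "bias": np.zeros(4)},
    "GRUCell_0": {"hn": {"kernel": np.zeros((4, 4))}},
}
print(param_counts(toy))  # {'Dense_0': 36, 'GRUCell_0': 16, 'total': 52}
```

On a real checkpoint, call param_counts(state["params"]). Runs of the same algorithm should report identical counts across seeds and difficulties. A mismatch against a freshly initialized network points to the configuration mismatch described in the warning above.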

Citation

@article{tessera2026alem,
  title   = {Benchmarking Open-Ended Multi-Agent Coordination in Language Agents},
  author  = {Tessera, {Kale-ab} Abebe and Szecsenyi, Andras and Barker, Cameron and
             Rutherford, Alexander and Paglieri, Davide and Scannell, Aidan and
             Gouk, Henry and Crowley, Elliot J. and Rockt\"{a}schel, Tim and
             Storkey, Amos},
  year    = {2026},
  url     = {https://arxiv.org/abs/2606.08340}
}

License

Released under the MIT License, matching the Alem codebase.
