ESM-2 (t6_8M_UR50D) — JAX weights
JAX (Flax) parameter .npz for facebook/esm2_t6_8M_UR50D, converted
from the canonical PyTorch checkpoint released by Meta AI's FAIR.
Hosted because re-converting the 6-layer ESM-2 from the Hugging Face
PyTorch checkpoint requires walking the parameter tree end-to-end and is
non-trivial to reproduce.
Variant: this repo hosts the small 6-layer / 8M-parameter ESM-2 (facebook/esm2_t6_8M_UR50D). The larger ESM-2 variants (t12_35M, t30_150M, t33_650M, t36_3B, t48_15B) are not converted to JAX here; load them from facebook/esm2_*_UR50D directly into PyTorch.
Files
| File | Format | Size |
|---|---|---|
| esm2_jax_weights.npz | numpy .npz (param tree) | ~30 MB |
Loadable as a Flax pytree by xaitalk's protein adapter. Numerically equivalent to the HF PyTorch checkpoint (Pearson r=1.0 on logits and embeddings).
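Before handing the archive to an adapter, it can help to look inside it. The following is a minimal sketch, not xaitalk's loader: it assumes the archive stores plain arrays under flat string keys joined by "/", which is not documented here, and it runs against a tiny stand-in archive with hypothetical key names rather than the real esm2_jax_weights.npz.

```python
import os
import tempfile

import numpy as np


def load_flat_params(path):
    """Read every array in an .npz archive into a plain {key: ndarray} dict."""
    # Plain numeric arrays load without allow_pickle; object arrays would not.
    with np.load(path) as archive:
        return {key: archive[key] for key in archive.files}


def nest(flat, sep="/"):
    """Turn {'a/b': x} into {'a': {'b': x}}. The separator is an assumption."""
    tree = {}
    for key, value in flat.items():
        node = tree
        *parents, leaf = key.split(sep)
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return tree


def count_params(flat):
    """Total number of scalar parameters across all arrays."""
    return sum(int(a.size) for a in flat.values())


# Stand-in archive with hypothetical key names, so the sketch runs offline.
demo_path = os.path.join(tempfile.mkdtemp(), "demo_weights.npz")
np.savez(
    demo_path,
    **{
        "embed/weight": np.zeros((33, 320), dtype=np.float32),
        "layer_0/attn/q/kernel": np.zeros((320, 320), dtype=np.float32),
        "layer_0/attn/q/bias": np.zeros((320,), dtype=np.float32),
    },
)

flat = load_flat_params(demo_path)
tree = nest(flat)
print("parameters:", count_params(flat))
print("top-level groups:", sorted(tree))
```

Pointing `load_flat_params` at the downloaded file and printing `sorted(flat)` shows the real key layout, which may differ from the separator and names assumed above.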
Architecture
| Property | Value |
|---|---|
| Layers | 6 pre-LN transformer blocks |
| Hidden dim | 320 |
| Heads | 20 |
| Position encoding | RoPE |
| Parameters | 7.4 M |
| Vocab | 33 (20 standard AAs + 13 other tokens: non-standard/ambiguous residue codes, gap symbols, special tokens) |
| Input | Tokenized protein sequence |
Reference: Lin et al., Evolutionary-scale prediction of atomic-level protein structure, Science 2023.
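The parameter figure in the table can be sanity-checked with a back-of-the-envelope count. This is a rough sketch under stated assumptions: the feed-forward width is taken to be 4 × hidden dim (1280), every linear layer has a bias, and only the six encoder blocks plus the token embedding are counted (the final LayerNorm and the language-model head are left out). RoPE adds no learned parameters.

```python
def encoder_block_params(d_model, d_ffn):
    """Learned parameters in one pre-LN transformer block (biases included)."""
    attn = 4 * (d_model * d_model + d_model)  # Q, K, V and output projections
    ffn = (d_model * d_ffn + d_ffn) + (d_ffn * d_model + d_model)
    norms = 2 * (2 * d_model)  # two LayerNorms, scale and shift each
    return attn + ffn + norms


d_model, n_layers, n_heads, vocab = 320, 6, 20, 33
d_ffn = 4 * d_model  # assumption: feed-forward width is 4x the hidden dim

blocks = n_layers * encoder_block_params(d_model, d_ffn)
embedding = vocab * d_model
head_dim = d_model // n_heads

print("per-head dim:", head_dim)
print("encoder blocks:", blocks)
print("blocks + embedding:", blocks + embedding)
```

Under these assumptions the blocks and embedding come to about 7.4 M parameters, with 16 dimensions per attention head.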
Cross-framework verification
These JAX weights are validated against the PyTorch and TF
HuggingFace checkpoints by xaitalk's protein benchmark (5 proteins,
20 methods):
| Methods | Passing at r ≥ 0.95 | Min(min_r) | Verified |
|---|---|---|---|
| 20 | 20/20 | 0.9950 | 2026-05-12 |
Includes the full AttnLRP rule family (ε, γ, α-β, z+, flat, w², SIGN, SIGN-μ, EpsStdXSIGN; Achtibat et al. 2024 routing for the transformer attention / LayerNorm primitives), the gradient family, DeepLIFT, and the SmoothGrad family.
Notable: to our knowledge, this is the first publicly released cross-framework ESM-2 JAX port verified with conservation-preserving LRP at r ≥ 0.99 against the PyTorch reference.
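The pass criterion above is based on Pearson correlation between outputs from different frameworks. The sketch below shows one way such a check could be computed with NumPy; it is not xaitalk's benchmark code. The helper names are hypothetical, and the aggregation (a method passes if its lowest per-protein r meets the threshold) is an assumption about how the minimum r in the table is formed. The arrays are synthetic stand-ins for attribution maps.

```python
import numpy as np


def pearson_r(a, b):
    """Pearson correlation between two arrays, compared element by element."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    if a.shape != b.shape:
        raise ValueError("inputs must have the same number of elements")
    a = a - a.mean()
    b = b - b.mean()
    denom = np.sqrt((a * a).sum() * (b * b).sum())
    if denom == 0.0:
        raise ValueError("correlation is undefined for constant input")
    return float((a * b).sum() / denom)


def method_passes(per_protein_r, threshold=0.95):
    """Assumed rule: the worst protein must still meet the threshold."""
    return min(per_protein_r) >= threshold


# Synthetic stand-ins: a reference map and a copy with small numerical noise.
rng = np.random.default_rng(0)
reference = rng.normal(size=(64, 320))
candidate = reference + 1e-3 * rng.normal(size=reference.shape)

r = pearson_r(reference, candidate)
print("r =", r)
print("passes:", method_passes([r]))
```

Pearson r is invariant to scale and offset, so a check like this is usually paired with an absolute-error comparison when exact numerical agreement matters.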
Usage
from xaitalk.hub import ensure_model
npz_path = ensure_model('esm2/jax')  # → local path to .npz

# Optional: read the raw arrays into a dict for inspection
import numpy as np
params = dict(np.load(npz_path, allow_pickle=True))

# xaitalk's protein adapter handles the architecture + forward
from xaitalk.adapters.protein import build_3framework_protein
bundle = build_3framework_protein(
    jax_weights_path=npz_path,  # short-circuit; otherwise it converts from PT
)

# Run XAI on a protein sequence
import xaitalk
expl = xaitalk.explain(bundle.jax_forward, bundle.jax_params,
                       x='MQIFVKTLTGKTITLE...', method='lrp_epsilon',
                       framework='jax')
Training data
ESM-2 was trained on UniRef50 (September 2021 release), roughly 65M UniProt protein sequences clustered at a 50% identity threshold. This repo provides the converted JAX weights only; the training procedure is the standard masked language modeling objective described in the ESM-2 paper.
License
The original ESM-2 weights are released under the MIT license by Meta. The converted JAX weights are distributed under the same license.
Citation
ESM-2 paper:
@article{lin2023esm2,
author = {Lin, Zeming and Akin, Halil and Rao, Roshan
and Hie, Brian and Zhu, Zhongkai and Lu, Wenting
and Smetanin, Nikita and Verkuil, Robert and Kabeli, Ori
and Shmueli, Yaniv and dos Santos Costa, Allan and
Fazel-Zarandi, Maryam and Sercu, Tom and Candido, Salvatore
and Rives, Alexander},
title = {Evolutionary-scale prediction of atomic-level protein structure},
journal = {Science},
year = {2023},
volume = {379},
number = {6637}
}
AttnLRP (the conservation-preserving LRP for transformer attention used by xaitalk on ESM-2):
@inproceedings{achtibat2024attnlrp,
author = {Achtibat, Reduan and Hatefi, Saeed and Dreyer, Maximilian
and Jain, Aakriti and Wiegand, Thomas and Lapuschkin, Sebastian
and Samek, Wojciech},
title = {{AttnLRP: Attention-Aware Layer-Wise Relevance Propagation
for Transformers}},
booktitle = {International Conference on Machine Learning (ICML)},
year = {2024}
}
xaitalk infrastructure:
@software{paul2026xaitalk,
author = {Paul, Alexander},
title = {xaitalk: Cross-Framework Explainable AI Library},
year = {2026},
url = {https://xaitalk.com}
}
Links
- xaitalk website: https://xaitalk.com
- Framework GitHub: https://github.com/alexanderfpaul/xaitalk-framework
- Protein comparison script: examples/comparison/run_protein_3framework_comparison.py
- Original ESM-2 PyTorch checkpoint: https://huggingface.co/facebook/esm2_t6_8M_UR50D