ESM-2 (t6_8M_UR50D) — JAX weights

JAX (Flax) parameter .npz for facebook/esm2_t6_8M_UR50D, converted from the canonical PyTorch checkpoint released by Meta AI (FAIR). The weights are hosted here because re-converting the 6-layer ESM-2 from the HF PyTorch checkpoint requires walking the full parameter tree and is non-trivial to reproduce.

Variant: this repo hosts the small 6-layer / 8M-parameter ESM-2 (facebook/esm2_t6_8M_UR50D). The larger ESM-2 variants (t12_35M, t30_150M, t33_650M, t36_3B, t48_15B) are not converted to JAX here; load them from facebook/esm2_*_UR50D directly into PyTorch.

Files

| File | Format | Size |
|---|---|---|
| esm2_jax_weights.npz | NumPy .npz (parameter tree) | ~30 MB |

Loadable as a Flax pytree by xaitalk's protein adapter. Numerically equivalent to the HF PyTorch checkpoint (Pearson r=1.0 on logits and embeddings).
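This card does not document the key layout inside esm2_jax_weights.npz. A minimal sketch, assuming the nested Flax parameter tree was flattened into `/`-joined keys (a hypothetical convention, e.g. `layer_0/dense/kernel`), that rebuilds a nested dict and counts its parameters. `load_npz_tree` and `count_params` are illustrative helpers, not part of xaitalk.

```python
import numpy as np

def load_npz_tree(path, sep="/"):
    """Load a flat .npz into a nested dict, assuming keys are sep-joined paths."""
    tree = {}
    # allow_pickle=False: if the real file stores pickled objects this raises,
    # and allow_pickle=True would be needed (only for files you trust).
    with np.load(path, allow_pickle=False) as flat:
        for key in flat.files:
            node = tree
            *parents, leaf = key.split(sep)
            for name in parents:
                node = node.setdefault(name, {})
            node[leaf] = flat[key]  # reads the array into memory
    return tree

def count_params(tree):
    """Total number of scalar parameters in a nested dict of arrays."""
    return sum(count_params(v) if isinstance(v, dict) else int(np.size(v))
               for v in tree.values())
```

Usage: `tree = load_npz_tree(npz_path)` followed by `count_params(tree)`. If the keys turn out to use a different separator, change `sep`.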

Architecture

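| Property | Value |
|---|---|
| Layers | 6 pre-LN transformer blocks |
| Hidden dim | 320 |
| Attention heads | 20 (head dim 16) |
| Position encoding | Rotary (RoPE) |
| Parameters | 7.4 M |
| Vocabulary | 33 tokens (20 standard amino acids + 13 others: non-standard residue codes, gap symbols, and special tokens) |
| Input | Tokenized protein sequence |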
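The 7.4 M figure in the table can be reproduced by hand. A minimal sketch, assuming an FFN width of 1280 (from the published esm2_t6_8M config, not listed in the table) and that the count covers token embeddings plus the encoder, excluding the LM head and the contact-prediction head. RoPE adds no parameters.

```python
def esm2_encoder_param_count(num_layers=6, hidden=320, ffn=1280, vocab=33):
    """Embedding + encoder parameter count for a pre-LN ESM-2-style model.

    Assumes biased Q/K/V/output and FFN projections, two LayerNorms per block,
    one final LayerNorm, and no learned position embeddings (RoPE).
    """
    embed = vocab * hidden
    attn = 4 * (hidden * hidden + hidden)                  # Q, K, V, output projections
    ffn_params = (hidden * ffn + ffn) + (ffn * hidden + hidden)  # up + down projections
    layer_norms = 2 * (2 * hidden)                         # scale + bias, two per block
    per_layer = attn + ffn_params + layer_norms
    final_ln = 2 * hidden
    return embed + num_layers * per_layer + final_ln

n = esm2_encoder_param_count()
print(f"{n:,} parameters (~{n / 1e6:.1f} M)")  # 7,408,960 parameters (~7.4 M)
```

About two thirds of each block's parameters sit in the FFN (820,800 of 1,232,960), which is typical for a 4× FFN expansion.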

Reference: Lin et al., Evolutionary-scale prediction of atomic-level protein structure, Science 2023.
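RoPE has no learned weights. It rotates each pair of query/key dimensions by a position-dependent angle, so attention scores depend only on the offset between positions. A minimal NumPy sketch, assuming the rotate-half convention and base 10000 used by the reference ESM implementation (check against the converted forward before relying on it), with head_dim = 320 / 20 = 16:

```python
import numpy as np

def rope_cos_sin(positions, head_dim, base=10000.0):
    """Cos/sin tables for rotary embeddings; shape (len(positions), head_dim)."""
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    angles = np.outer(np.asarray(positions, dtype=np.float64), inv_freq)
    angles = np.concatenate([angles, angles], axis=-1)  # dim i pairs with i + head_dim/2
    return np.cos(angles), np.sin(angles)

def rotate_half(x):
    """Map (x1, x2) -> (-x2, x1) on the two halves of the last axis."""
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)

def apply_rope(x, positions, base=10000.0):
    """Rotate query or key vectors x of shape (seq_len, head_dim) by their positions."""
    cos, sin = rope_cos_sin(positions, x.shape[-1], base)
    return x * cos + rotate_half(x) * sin

# Attention scores depend only on the offset between query and key positions.
rng = np.random.default_rng(0)
q, k = rng.normal(size=(1, 16)), rng.normal(size=(1, 16))  # head_dim = 320 / 20
s1 = apply_rope(q, [3]) @ apply_rope(k, [1]).T
s2 = apply_rope(q, [10]) @ apply_rope(k, [8]).T
print(np.allclose(s1, s2))  # True
```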

Cross-framework verification

These JAX weights were validated against the PyTorch and TensorFlow Hugging Face checkpoints by xaitalk's protein benchmark (5 proteins, 20 attribution methods):

| Methods | Passing at r ≥ 0.95 | Worst-case r, Min(min_r) | Verified |
|---|---|---|---|
| 20 | 20/20 | 0.9950 | 2026-05-12 |
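The benchmark harness itself is not shown in this card. A minimal sketch of the pass/fail bookkeeping, assuming each method's attribution maps from two frameworks are compared per protein with Pearson r, and assuming Min(min_r) means the lowest per-method minimum across all proteins (our reading of the header). `summarize` and its input layout are hypothetical.

```python
import numpy as np

def pearson_r(a, b):
    """Pearson correlation of two arrays, flattened (undefined for constant inputs)."""
    a = np.ravel(np.asarray(a, dtype=np.float64))
    b = np.ravel(np.asarray(b, dtype=np.float64))
    a = a - a.mean()
    b = b - b.mean()
    return float(a @ b / np.sqrt((a @ a) * (b @ b)))

def summarize(results, threshold=0.95):
    """results: {method: [(attr_jax, attr_ref), ...]}, one pair per protein.

    Returns (methods passing, total methods, worst-case r over all methods).
    """
    min_r = {m: min(pearson_r(x, y) for x, y in pairs) for m, pairs in results.items()}
    passing = sum(r >= threshold for r in min_r.values())
    return passing, len(min_r), min(min_r.values())
```

Taking the per-method minimum over proteins means a single poorly matched protein fails the method, which is stricter than averaging r.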

The 20 methods include the LRP rule family (ε, γ, α-β, z+, flat, w², SIGN, SIGN-μ, EpsStdXSIGN), routed through the transformer's attention and LayerNorm primitives following AttnLRP (Achtibat et al. 2024), as well as the gradient family, DeepLIFT, and the SmoothGrad family.

Notable: this is the first publicly released cross-framework ESM-2 JAX port verified with conservation-preserving LRP at r ≥ 0.99 against the PyTorch reference.
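"Conservation" means the relevance summed over a layer's inputs equals the relevance summed over its outputs. A minimal NumPy sketch of the generic LRP-ε rule for one dense layer shows where this comes from. It is not xaitalk's implementation and not the AttnLRP-specific rules for softmax attention or LayerNorm. With zero bias and a small ε, redistribution is conservative up to the stabilizer. A nonzero bias absorbs part of the relevance.

```python
import numpy as np

def lrp_epsilon_dense(a, W, b, R_out, eps=1e-6):
    """Redistribute relevance R_out of z = a @ W + b onto the inputs a (LRP-ε).

    a: (d_in,), W: (d_in, d_out), b: (d_out,), R_out: (d_out,).
    """
    z = a @ W + b
    # Push the denominator away from zero while keeping its sign (sign(0) -> +1).
    denom = z + eps * np.where(z >= 0, 1.0, -1.0)
    s = R_out / denom
    # Input i receives the share a_i * W_ij of each output j's relevance.
    return a * (W @ s)

rng = np.random.default_rng(0)
a = rng.normal(size=8)
W = rng.normal(size=(8, 4))
b = np.zeros(4)
R_out = a @ W + b  # start from the layer's own output as relevance
R_in = lrp_epsilon_dense(a, W, b, R_out)
print(R_in.sum(), R_out.sum())  # nearly equal: relevance is conserved
```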

Usage

```python
import numpy as np

import xaitalk
from xaitalk.adapters.protein import build_3framework_protein
from xaitalk.hub import ensure_model

# Download (or reuse a cached copy of) the converted weights.
npz_path = ensure_model('esm2/jax')  # → local path to .npz

# Optional: inspect the raw parameter arrays directly.
# Only enable pickle loading for files you trust.
with np.load(npz_path, allow_pickle=True) as f:
    params = dict(f)

# xaitalk's protein adapter handles the architecture + forward.
bundle = build_3framework_protein(
    jax_weights_path=npz_path,  # short-circuit; otherwise it converts from PT
)

# Run XAI on a protein sequence (sequence truncated here).
expl = xaitalk.explain(bundle.jax_forward, bundle.jax_params,
                       x='MQIFVKTLTGKTITLE...', method='lrp_epsilon',
                       framework='jax')
```
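The `explain` call above passes a raw sequence string. Whether the adapter tokenizes it internally is an assumption based on that call. For reference, here is a minimal sketch of ESM-2 tokenization that assumes the 33-token ordering of the HF EsmTokenizer (verify against the tokenizer's vocab.txt). It prepends `<cls>`, appends `<eos>`, and maps one residue character to one token. Multi-character tokens such as `<mask>` are not parsed from the string.

```python
ESM2_VOCAB = [
    "<cls>", "<pad>", "<eos>", "<unk>",
    "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K", "Q", "N",
    "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z", "O", ".", "-",
    "<null_1>", "<mask>",
]
TOKEN_TO_ID = {tok: i for i, tok in enumerate(ESM2_VOCAB)}

def tokenize(seq: str) -> list[int]:
    """Map a protein sequence to ESM-2 token ids, wrapped in <cls> ... <eos>."""
    unk = TOKEN_TO_ID["<unk>"]
    ids = [TOKEN_TO_ID.get(ch, unk) for ch in seq.upper()]
    return [TOKEN_TO_ID["<cls>"], *ids, TOKEN_TO_ID["<eos>"]]

print(tokenize("MQIF"))  # [0, 20, 16, 12, 18, 2]
```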

Training data

ESM-2 was trained on UniRef50 (September 2021 release), which clusters UniProt protein sequences at a 50% sequence-identity threshold. Over the course of training, the model saw roughly 65M unique sequences. We provide only the converted JAX weights and did no retraining. The original training used the standard masked language modeling objective described in the ESM-2 paper.
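The conversion does not involve training. For context, here is a minimal sketch of BERT-style masking for masked language modeling. It assumes that 15% of residue positions are selected and that, of those, 80% become `<mask>`, 10% become a random amino acid, and 10% stay unchanged. The token ids are assumptions based on the HF EsmTokenizer ordering, and `mask_for_mlm` is a hypothetical helper. Consult the ESM-2 paper for the exact recipe.

```python
import numpy as np

# Assumed token ids (HF EsmTokenizer ordering): <cls>=0, <pad>=1, <eos>=2, <mask>=32;
# the 20 standard amino acids occupy ids 4..23.
CLS_ID, PAD_ID, EOS_ID, MASK_ID = 0, 1, 2, 32
AA_IDS = np.arange(4, 24)
IGNORE = -100  # PyTorch cross-entropy's default ignore_index

def mask_for_mlm(tokens, rng, p=0.15):
    """Return (inputs, labels) for one BERT-style MLM training example."""
    tokens = np.asarray(tokens)
    inputs = tokens.copy()
    labels = np.full_like(tokens, IGNORE)
    maskable = ~np.isin(tokens, [CLS_ID, PAD_ID, EOS_ID])  # never mask special tokens
    selected = maskable & (rng.random(tokens.shape) < p)
    labels[selected] = tokens[selected]                    # predict the original residue
    roll = rng.random(tokens.shape)
    to_mask = selected & (roll < 0.8)                      # 80%: replace with <mask>
    to_random = selected & (roll >= 0.8) & (roll < 0.9)    # 10%: random amino acid
    inputs[to_mask] = MASK_ID
    inputs[to_random] = rng.choice(AA_IDS, size=int(to_random.sum()))
    # The remaining 10% of selected positions keep their original token.
    return inputs, labels
```

The loss is computed only where `labels != IGNORE`. Keeping some selected tokens unchanged forces the model to build a useful representation for every position, not just for the `<mask>` ones.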

License

The original ESM-2 weights are released by Meta under the MIT license. This converted JAX version is distributed under the same license.

Citation

ESM-2 paper:

@article{lin2023esm2,
  author  = {Lin, Zeming and Akin, Halil and Rao, Roshan
             and Hie, Brian and Zhu, Zhongkai and Lu, Wenting
             and Smetanin, Nikita and Verkuil, Robert and Kabeli, Ori
             and Shmueli, Yaniv and dos Santos Costa, Allan and
             Fazel-Zarandi, Maryam and Sercu, Tom and Candido, Salvatore
             and Rives, Alexander},
  title   = {Evolutionary-scale prediction of atomic-level protein structure},
  journal = {Science},
  year    = {2023},
  volume  = {379},
  number  = {6637}
}

AttnLRP (the conservation-preserving LRP for transformer attention used by xaitalk on ESM-2):

@inproceedings{achtibat2024attnlrp,
  author    = {Achtibat, Reduan and Hatefi, Sayed Mohammad Vakilzadeh and Dreyer, Maximilian
               and Jain, Aakriti and Wiegand, Thomas and Lapuschkin, Sebastian
               and Samek, Wojciech},
  title     = {{AttnLRP: Attention-Aware Layer-Wise Relevance Propagation
                for Transformers}},
  booktitle = {International Conference on Machine Learning (ICML)},
  year      = {2024}
}

xaitalk infrastructure:

@software{paul2026xaitalk,
  author = {Paul, Alexander},
  title  = {xaitalk: Cross-Framework Explainable AI Library},
  year   = {2026},
  url    = {https://xaitalk.com}
}
