MergeDNA (reimplementation)
Faithful reimplementation of MergeDNA: Context-aware Genome Modeling with Dynamic Tokenization through Token Merging (Li et al., 2025, arXiv:2511.14806). The model is a hierarchical autoencoder for DNA that learns variable-length tokenization via differentiable token merging; the tokenizer and context model are trained jointly under three self-supervised objectives.
This release is the best-performing checkpoint from a single-GPU reimplementation. It is undertrained relative to the paper (≈0.24% of the paper's training tokens) and is intended for reproducibility studies and as a starting point for further pre-training or fine-tuning, not as a state-of-the-art model.
Model details
| Property | Value |
|---|---|
| Parameters | ~380 M |
| d_model | 1024 |
| Attention heads | 16 |
| Layers | 4 Local Enc / 20 Latent Enc / 4 Latent Dec / 2 Local Dec |
| Local window size | 16 |
| Max sequence length | 2048 (paper uses 4096; RoPE permits longer inputs, but this is untested here) |
| Vocab | 4 (A, T, C, G; N collapses to A) |
| Building blocks | RMSNorm, RoPE, SwiGLU FFN, pre-norm (LLaMA-style) |
| Precision (training) | fp32 (this checkpoint); a faster bf16 + Triton variant exists (see source repo) |
The novel architectural pieces are:
- Differentiable token merging inside local-window attention. Within each window, tokens are split into even and odd positions for bipartite matching, and a DTEM-style decoupled grouping projection produces the similarity scores; the most similar adjacent tokens are merged by averaging. A source matrix tracks merge lineage, so the unmerge step is exact.
- Three pre-training losses, trained jointly:
  - MTR: full-autoencoder reconstruction. It collapses to ~0 early via a trivial shortcut; this is expected behavior.
  - λ · MTR(θ\{φ}): second-stage reconstruction with the tokenizer detached and a tighter K = L/2 latent bottleneck, which forces the latent stack to do real work. λ = 0.25.
  - AMTM: adaptive masked-token modeling, biased toward tokens in small merge groups ("important" tokens).
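The merge-and-unmerge idea, and how merge lineage can feed AMTM, can be sketched in plain Python. This is a simplified, non-differentiable illustration, not the repo's implementation: it pairs each even position only with its right-hand odd neighbour, scores pairs with cosine similarity on the raw vectors (standing in for the learned DTEM-style grouping projection), and makes a hard top-r choice. The helper names `merge_window`, `unmerge` and `amtm_weights`, and the 1 / group-size weighting, are hypothetical.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / (norm + 1e-8)


def merge_window(x, r):
    """Merge the r most similar (even, odd) neighbour pairs in one local window.

    x: list of token vectors. Returns (merged, source), where source[k] lists
    the original positions averaged into merged[k] (the merge lineage).
    """
    n = len(x)
    # Candidate pairs are (i, i + 1) with i even, so merged groups stay contiguous.
    scored = [(cosine(x[i], x[i + 1]), i) for i in range(0, n - 1, 2)]
    # Keep the r highest-similarity pairs; ties are broken by position.
    chosen = {i for _, i in sorted(scored, key=lambda s: (-s[0], s[1]))[:r]}
    merged, source = [], []
    i = 0
    while i < n:
        group = [i, i + 1] if i in chosen else [i]
        source.append(group)
        # A merged token is the plain average of its members.
        merged.append([sum(x[j][d] for j in group) / len(group) for d in range(len(x[i]))])
        i += len(group)
    return merged, source


def unmerge(merged, source, n):
    """Copy each merged vector back to every original position it covers."""
    out = [None] * n
    for vec, group in zip(merged, source):
        for j in group:
            out[j] = list(vec)
    return out


def amtm_weights(source, n):
    """Hypothetical AMTM sampling weights: 1 / group size, normalised to sum to 1.

    Positions that stayed in small merge groups get a higher chance of being masked.
    """
    w = [0.0] * n
    for group in source:
        for j in group:
            w[j] = 1.0 / len(group)
    total = sum(w)
    return [v / total for v in w]


if __name__ == "__main__":
    x = [[1, 0], [1, 0], [0, 1], [1, 0], [0, 1], [0, 1]]
    merged, source = merge_window(x, r=2)
    print(source)                   # [[0, 1], [2], [3], [4, 5]]
    print(amtm_weights(source, 6))  # [0.125, 0.125, 0.25, 0.25, 0.125, 0.125]
```

Because the lineage is kept, unmerging always restores the original length and positions; the values are exact only where merged tokens were identical, as in this toy input.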
Intended use
- Feature extraction for DNA classification tasks (frozen-backbone embeddings via forward_classify).
- LoRA fine-tuning on task-specific genomic data (transcription factor binding, splice sites, promoters, etc.); the protocol is described below.
- Continued pre-training as a starting point for domain adaptation to a specific organism, tissue, or task family.
Not intended for clinical decision-making, variant interpretation, or any use where errors carry health consequences.
How to use
```python
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# MergeDNA.py comes from your local copy of the source repo (not a pip package)
from MergeDNA import MergeDNA, MergeDNAConfig, encode_dna_sequence

REPO = "Sharmistha-NLP/mergedna-400M"
weights_path = hf_hub_download(REPO, "model.safetensors")

# Must match the checkpoint's training config; mismatched layer counts or
# widths will fail strict loading
cfg = MergeDNAConfig(
    d_model=1024, n_heads=16,
    local_enc_layers=4, latent_enc_layers=20,
    latent_dec_layers=4, local_dec_layers=2,
    window_size=16, max_seq_len=2048,
)
model = MergeDNA(cfg)
model.load_state_dict(load_file(weights_path), strict=True)
model.eval()

# Frozen-backbone feature extraction
seq = "ACGT" * 256  # 1024 nt example
ids = encode_dna_sequence(seq).unsqueeze(0)  # [1, 1024]
with torch.no_grad():
    features = model.forward_classify(ids)  # mean-pooled latent embedding
```
To resume pre-training, download ckpt_best.pt instead; it carries the full optimizer, scheduler, and RNG state.
Sandboxed sanity test
A self-contained script that downloads the weights and runs them on a tiny in-memory random-DNA dataset, so no external data is needed. It is useful for verifying that the upload round-tripped before plugging the model into a real pipeline. Save it as HF/HF-test/test_hf_model.py inside a clone of the source repo and run:

```bash
uv run python HF/HF-test/test_hf_model.py --repo-id Sharmistha-NLP/mergedna-400M
```

The validation loss will be high because random ACGT is out-of-distribution. The point is to confirm that the weights load, the forward pass produces finite values, and forward_classify returns the expected (B, D) tensor.
Full script:
```python
"""Download the uploaded MergeDNA checkpoint from Hugging Face and sanity-check
it on a tiny, sandboxed dataset that lives entirely inside this folder.

No dependency on code/train.py and no validation data on disk: sequences are
generated in-memory with a fixed seed, so the test runs anywhere the package
(and MergeDNA.py on sys.path) is available.
"""
import argparse
import sys
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from torch.utils.data import DataLoader, IterableDataset


def _find_project_root(start: Path) -> Path:
    """Walk upward from `start` to the first directory containing pyproject.toml."""
    for p in [start, *start.parents]:
        if (p / "pyproject.toml").exists():
            return p
    raise RuntimeError(f"pyproject.toml not found upward from {start}")


ROOT = _find_project_root(Path(__file__).resolve())
sys.path.insert(0, str(ROOT / "code"))  # make code/MergeDNA.py importable

from MergeDNA import MergeDNA, MergeDNAConfig, encode_dna_sequence  # noqa: E402


class RandomDNAIterable(IterableDataset):
    """Yields `n_samples` seeded random ACGT token-id sequences of length `max_len`."""

    def __init__(self, n_samples: int, max_len: int, seed: int = 0):
        super().__init__()
        self.n_samples = n_samples
        self.max_len = max_len
        self.seed = seed

    def __iter__(self):
        # Re-seeding on every pass makes each iteration yield identical data.
        g = torch.Generator().manual_seed(self.seed)
        for _ in range(self.n_samples):
            yield torch.randint(0, 4, (self.max_len,), generator=g, dtype=torch.long)


def build_val_loader(args) -> DataLoader:
    n_samples = args.val_batches * args.batch_size
    ds = RandomDNAIterable(n_samples=n_samples, max_len=args.max_len, seed=args.seed)
    return DataLoader(ds, batch_size=args.batch_size, num_workers=0)


def main() -> None:
    p = argparse.ArgumentParser()
    p.add_argument("--repo-id", default="Sharmistha-NLP/mergedna-400M")
    p.add_argument("--weights-file", default="model.safetensors")
    p.add_argument("--val-batches", type=int, default=4)
    p.add_argument("--batch-size", type=int, default=2)
    p.add_argument("--max-len", type=int, default=2048,
                   help="Must match the max_seq_len the checkpoint was trained with.")
    p.add_argument("--d-model", type=int, default=1024)
    p.add_argument("--n-heads", type=int, default=16)
    p.add_argument("--window", type=int, default=16)
    p.add_argument("--latent-enc-layers", type=int, default=20)
    p.add_argument("--latent-dec-layers", type=int, default=4)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = p.parse_args()

    print(f"[hf] downloading {args.weights_file} from {args.repo_id}")
    weights_path = hf_hub_download(args.repo_id, args.weights_file)

    cfg = MergeDNAConfig(
        d_model=args.d_model,
        n_heads=args.n_heads,
        local_enc_layers=4,
        latent_enc_layers=args.latent_enc_layers,
        latent_dec_layers=args.latent_dec_layers,
        local_dec_layers=2,
        window_size=args.window,
        max_seq_len=args.max_len,
    )
    print(f"[config] {cfg}")
    model = MergeDNA(cfg).to(args.device)
    state = load_file(weights_path)
    model.load_state_dict(state, strict=True)
    model.eval()
    n_params = sum(t.numel() for t in model.parameters())
    print(f"[model] loaded {n_params / 1e6:.1f}M params")

    # Pre-training losses on seeded random DNA
    val_loader = build_val_loader(args)
    sums = {"total": 0.0, "mtr": 0.0, "latent_mtr": 0.0, "amtm": 0.0, "compression": 0.0}
    n = 0
    with torch.no_grad():
        for batch in val_loader:
            if n >= args.val_batches:
                break
            ids = batch.to(args.device)
            losses = model.forward_pretrain(ids)
            # float() accepts both Python numbers and 0-d tensors
            sums["total"] += float(losses["total_loss"])
            sums["mtr"] += float(losses["loss_mtr"])
            sums["latent_mtr"] += float(losses["loss_latent_mtr"])
            sums["amtm"] += float(losses["loss_amtm"])
            sums["compression"] += float(losses["compression_ratio"])
            n += 1

    if n == 0:
        print("[error] no val batches produced; increase --val-batches.")
        sys.exit(1)

    print(f"\n[val] over {n} batches (batch_size={args.batch_size}, max_len={args.max_len}):")
    for k, v in sums.items():
        print(f"  {k:>12s}: {v / n:.4f}")

    # Classification path: one 1024-nt sequence through forward_classify
    seq = "ACGT" * 256
    ids = encode_dna_sequence(seq).unsqueeze(0).to(args.device)
    with torch.no_grad():
        features = model.forward_classify(ids)
    print(f"\n[classify] features shape={tuple(features.shape)} "
          f"finite={torch.isfinite(features).all().item()} "
          f"mean={features.mean().item():+.4f} std={features.std().item():.4f}")


if __name__ == "__main__":
    main()
```
Training
Data
- Corpus: Multi-Species Genomes (NCBI RefSeq), 849 species, ~174 B nucleotides total.
- Splits: species-disjoint, with 749 species for training, 50 for validation, and 50 for test. Train is bacteria-heavy; val/test are weighted toward larger eukaryote genomes.
- Chunking: 6,200 bp chunks (6 kbp config) with 100 bp overlap on each side, randomly cropped at training time so each epoch sees a different window.
- Tokens seen during this training run: ~256 M (vs the paper's ~105 B; this is a major caveat).
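The chunking scheme can be illustrated with a small sketch. It assumes each chunk is a 6,000 bp core plus 100 bp of flanking context borrowed from each neighbour (6,000 + 2 × 100 = 6,200 bp), that cores are tiled back-to-back, and that the training-time crop takes a random 2,048-token window. `make_chunks` and `random_crop` are hypothetical helpers, not functions from the source repo.

```python
import random


def make_chunks(genome, core=6_000, flank=100):
    """Tile `genome` into chunks of up to core + 2 * flank bp.

    Each chunk covers one `core`-bp window plus `flank` bp from each neighbour,
    so adjacent chunks overlap by 2 * flank bp. Chunks at the genome ends are
    shorter because there is no neighbour to borrow from.
    """
    chunks = []
    for start in range(0, len(genome), core):
        lo = max(0, start - flank)
        hi = min(len(genome), start + core + flank)
        chunks.append(genome[lo:hi])
    return chunks


def random_crop(chunk, crop_len=2_048, rng=random):
    """Pick a random crop_len window, so each epoch sees a different slice."""
    if len(chunk) <= crop_len:
        return chunk
    offset = rng.randrange(len(chunk) - crop_len + 1)
    return chunk[offset:offset + crop_len]


chunks = make_chunks("ACGT" * 5_000)  # 20,000 bp toy genome
print([len(c) for c in chunks])       # [6100, 6200, 6200, 2100]
```

Passing a seeded `random.Random` as `rng` makes the crops reproducible; with the default module-level generator, each call draws a fresh offset.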
Hyperparameters
- Optimizer: AdamW, β = (0.9, 0.95), weight decay 0.1
- LR: base 1e-4, cosine schedule with 10K warmup steps
- Batch size: 4 sequences × 2048 tokens = 8,192 tokens/step
- Steps: 32,000 (the run was terminated early when a dataset URL returned 404; the remaining 18K steps of the planned 50K were not run)
- Hardware: single RTX PRO 6000 Blackwell, 96 GB
- Wall-clock: ~30 hours
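The learning-rate schedule can be written as a small function. This sketch assumes linear warmup over the first 10,000 steps and cosine decay to zero over the planned 50,000 steps; the warmup shape, the zero floor (`MIN_LR`), and the name `lr_at` are assumptions, not details taken from the training code.

```python
import math

BASE_LR = 1e-4
WARMUP_STEPS = 10_000
TOTAL_STEPS = 50_000  # planned length; this run stopped at 32,000
MIN_LR = 0.0          # assumed floor of the cosine decay


def lr_at(step):
    """Linear warmup to BASE_LR, then cosine decay to MIN_LR at TOTAL_STEPS."""
    if step < WARMUP_STEPS:
        return BASE_LR * (step + 1) / WARMUP_STEPS
    progress = (step - WARMUP_STEPS) / max(1, TOTAL_STEPS - WARMUP_STEPS)
    progress = min(progress, 1.0)  # hold at MIN_LR past the planned end
    return MIN_LR + 0.5 * (BASE_LR - MIN_LR) * (1 + math.cos(math.pi * progress))


print(f"{lr_at(31_999):.2e}")  # LR near the last executed step
```

Under these assumptions the learning rate at step 32,000 is still about 4.2e-5 (roughly 42% of the peak), so the run ended well before the low-LR tail of the schedule.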
Training trajectory
- Steps 0–1K: MTR collapses to ~0.01 (the trivial autoencoder shortcut; this is diagnostic of the loss, not a failure).
- Steps 1K–10K: Latent MTR and AMTM begin to fall (1.39 → ~1.2).
- Steps 10K–30K: Latent MTR keeps falling, reaching 0.28 at the final step (31,999). The linear-probe and LoRA results below indicate that the backbone has learned useful representations.
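Assuming the reported losses are mean per-base cross-entropies in nats (the card does not state the unit), they can be read against a simple reference point: a model that spreads probability uniformly over the four bases scores ln 4 ≈ 1.386, which matches the 1.39 starting value. Exponentiating a loss gives a per-base perplexity, the effective number of equally likely bases the model is still choosing between.

```python
import math

# Cross-entropy (nats) of a uniform guess over the four nucleotides.
UNIFORM_CE = math.log(4)


def perplexity(loss_nats):
    """Effective number of equally likely choices implied by a cross-entropy loss."""
    return math.exp(loss_nats)


for label, loss in [("uniform guess", UNIFORM_CE),
                    ("steps 1K-10K", 1.2),
                    ("latent MTR", 0.28)]:
    print(f"{label:>14s}: loss={loss:.3f}  perplexity={perplexity(loss):.2f}")
```

Under that reading, a latent MTR of 0.28 corresponds to a per-base perplexity of about 1.32, against 4.0 for the uniform baseline.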
Evaluation
Evaluated on the GUE Mouse TF-3 task (DNABERT-2 benchmark; smallest task: 1,904 train / 239 test).
| Setup | val MCC (×100) | test MCC (×100) |
|---|---|---|
| Random baseline | 0 | 0 |
| Linear probe (frozen, sklearn LogisticRegression on mean-pooled features) | ~22 | n/a |
| LoRA fine-tune (rank 8, α 16, 10 epochs, lr 1e-4) | 69.46 (epoch 9) | 56.34 |
| Paper Mouse TF-3 (full 100K-step backbone) | n/a | 73.46 |
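The LoRA val MCC (69.46) is ~4 points below the paper's test MCC (73.46) even though this checkpoint saw ~400× less pre-training data. Note that this compares a validation score with a test score; on the test split the gap is ~17 points (56.34 vs 73.46). The test/val gap is most likely driven by the variance of the small 239-sample test split and mild overfitting by epoch 10.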
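For reference, the MCC values in the table are Matthews correlation coefficients scaled by 100, as is conventional for GUE. A minimal binary implementation follows; the helper name `mcc` is illustrative, and scikit-learn's `matthews_corrcoef` computes the same quantity for binary labels.

```python
import math


def mcc(y_true, y_pred):
    """Matthews correlation coefficient for binary 0/1 labels, in [-1, 1]."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # A degenerate confusion matrix (e.g. all-one predictions) is treated as 0.
    return (tp * tn - fp * fn) / denom if denom else 0.0


# GUE-style reporting multiplies by 100: 0.5774 -> 57.74.
print(round(100 * mcc([1, 1, 0, 0], [1, 0, 0, 0]), 2))
```

Because MCC uses all four confusion-matrix cells, a constant predictor scores 0 regardless of class balance, which is why the random baseline row reads 0.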
Full GUE benchmark (36 tasks): not run. Estimated 6 GPU-days at this checkpoint's throughput. See REPORT.md in the source repo for the tradeoff discussion.
Reproducing the LoRA result
```bash
uv run --extra data --extra eval python code/lora_finetune.py \
    --checkpoint ckpt_best.pt \
    --task-dir data/GUE/mouse/3 \
    --epochs 10 --max-len 256 --batch-size 32 \
    --rank 8 --lora-alpha 16 --lr 1e-4 \
    --head-hidden 256
```
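With rank 8 and alpha 16, a LoRA adapter adds a trainable low-rank update B·A to a frozen weight matrix W, scaled by alpha / rank = 2. The sketch below shows that update on plain Python lists. Which projections code/lora_finetune.py actually adapts, and how it initialises A, are not stated in this card, so `LoRALinear` is purely illustrative; it follows the common convention of a small random A and a zero B.

```python
import random


class LoRALinear:
    """y = W x + (alpha / rank) * B (A x), with W frozen and only A, B trainable."""

    def __init__(self, weight, rank=8, alpha=16, seed=0):
        rng = random.Random(seed)
        self.weight = weight  # frozen base matrix, d_out x d_in
        d_out, d_in = len(weight), len(weight[0])
        self.scale = alpha / rank
        # Common LoRA init: A small random, B zero, so training starts exactly at W.
        self.A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(rank)]
        self.B = [[0.0] * rank for _ in range(d_out)]

    @staticmethod
    def _matvec(m, v):
        return [sum(a * b for a, b in zip(row, v)) for row in m]

    def __call__(self, x):
        base = self._matvec(self.weight, x)
        update = self._matvec(self.B, self._matvec(self.A, x))
        return [b + self.scale * u for b, u in zip(base, update)]

    def trainable_params(self):
        """rank * (d_in + d_out): the entries of A plus the entries of B."""
        return sum(len(row) for row in self.A) + sum(len(row) for row in self.B)


layer = LoRALinear([[1.0, 2.0], [3.0, 4.0]], rank=8, alpha=16)
print(layer([1.0, 1.0]))  # equals W x at init because B is zero
```

For a 1024 × 1024 projection (d_model = 1024), one rank-8 adapter adds 8 × (1024 + 1024) = 16,384 trainable parameters, about 1.6% of the 1,048,576 in the frozen matrix.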
Limitations and known issues
- Undertrained. ~256 M tokens vs the paper's ~105 B (≈0.24%). Expect a gap on most GUE tasks vs the published numbers.
- Max sequence length 2048 at training; RoPE allows longer at inference, but this has not been validated.
- N nucleotides become A. The byte-LUT tokenizer maps anything not A/T/C/G to 0 (= A). Assembly gaps and ambiguous bases are silently lost; downstream interpretation on N-rich regions should be done carefully.
- Bacteria-heavy training distribution. 667 of 749 training species are bacteria. Performance on plant and viral genomes (excluded from training) is unknown; reasonable performance on large eukaryotic genomes is plausible but held back by undertraining.
- No safety or bias evaluation. This is a pre-training base; downstream classifiers must do their own validation.
- Single-GPU reimplementation. Architecture follows the paper but optimization details, data ordering, and tokenization quirks may differ in subtle ways from the original release.
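Because ambiguous bases are silently rewritten, it can be worth measuring how much of an input is not A/C/G/T before encoding it. The sketch below mimics the described byte-LUT behaviour with the standard library. The ids for T, C and G (1, 2, 3, following the vocab order listed above) and the upper-casing of soft-masked input are assumptions; `encode_like_lut` and `ambiguous_fraction` are hypothetical helpers, not the repo's encode_dna_sequence.

```python
# 256-entry byte lookup table: every byte defaults to 0 (= A).
_LUT = bytearray(256)
for _code, _base in enumerate(b"ATCG"):
    _LUT[_base] = _code
_LUT = bytes(_LUT)


def encode_like_lut(seq):
    """Map each base to an id; anything outside A/T/C/G silently becomes 0 (= A)."""
    # Upper-casing first is an assumption about soft-masked (lowercase) input.
    return list(seq.upper().encode("ascii").translate(_LUT))


def ambiguous_fraction(seq):
    """Share of positions the LUT would collapse to A without them being A."""
    if not seq:
        return 0.0
    return sum(ch not in "ACGT" for ch in seq.upper()) / len(seq)


print(encode_like_lut("ACGTN"))     # [0, 2, 3, 1, 0]: N is indistinguishable from A
print(ambiguous_fraction("ACGTNN"))
```

A simple policy is to skip or flag windows whose ambiguous fraction exceeds a chosen threshold before extracting features from them.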
Files in this repo
| File | Size (approx) | Purpose |
|---|---|---|
| model.safetensors | ~1.5 GB | Weights only. Use for inference and fine-tuning. |
| ckpt_best.pt | ~4.5 GB | Full training state (model + optimizer + scheduler + RNG + step + cfg). Use for resuming pre-training. |
| best.json | <1 KB | Per-checkpoint validation loss history. |
| README.md | n/a | This card. |
Citation
If you use this checkpoint, please cite the original paper:
```bibtex
@article{li2025mergedna,
  title   = {MergeDNA: Context-aware Genome Modeling with Dynamic Tokenization through Token Merging},
  author  = {Li and others},
  journal = {arXiv preprint arXiv:2511.14806},
  year    = {2025}
}
```
And acknowledge this reimplementation:
```bibtex
@misc{jat2026mergedna_reimpl,
  title        = {MergeDNA reimplementation (single-GPU)},
  author       = {Jat, Sharmistha},
  year         = {2026},
  howpublished = {Hugging Face model repository},
  url          = {https://huggingface.co/Sharmistha-NLP/mergedna-400M}
}
```
License
Apache 2.0. Data is sourced from NCBI RefSeq (public domain) via the InstaDeepAI multi-species genomes dataset.