SWITCH — Phase 3 (Switch-GRPO) LoRA Adapter for Qwen3-8B

This is the paper-final checkpoint of SWITCH, a switchable latent chain-of-thought framework that combines Coconut-style hidden-state recurrence with on-policy reinforcement learning through a single primitive: a pair of learned boundary tokens <swi> / </swi>.

Headline result. 79.3 % on MATH-500 and 89.2 % on GSM8K, which is +25.7 and +10.7 points respectively above the strongest Coconut-style baseline at the same scale.

📄 Companion paper: "Demystifying Hidden-State Recurrence: Switchable Latent Reasoning with On-Policy Reinforcement Learning" (arXiv:2606.13106)
💻 Code: github.com/LARK-AI-Lab/SWITCH
📊 Training data: LARK-Lab/SWITCH-Math-Train

What this checkpoint does

The model emits <swi> to enter latent mode and </swi> to exit. Inside the latent block it performs Coconut-style hidden-state recurrence (each step's last-layer hidden state becomes the input embedding of the next <latent> position). Outside the block it decodes ordinary text. The boundary tokens are ordinary discrete vocabulary items, so on-policy GRPO is well-defined at every text position; latent positions contribute no policy-gradient term.
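The mode switching described above can be pictured as a small state machine. The following is a toy sketch only, not the repository's inference loop: `decode_with_switch`, `make_scripted_step`, and `embed` are hypothetical names, the "model" is a scripted stub that halves its input to produce a fake hidden state, and the exit rule (leave latent mode when the model predicts `</swi>`) is an assumption made for illustration.

```python
from typing import Callable, List, Tuple

Vector = List[float]
# A step function maps one input embedding to (last-layer hidden state, predicted token).
StepFn = Callable[[Vector], Tuple[Vector, str]]


def decode_with_switch(step: StepFn, embed: Callable[[str], Vector],
                       start_token: str, max_steps: int = 32) -> List[str]:
    """Toy decode loop: text mode feeds token embeddings, latent mode feeds hidden states."""
    trace: List[str] = []
    in_latent = False
    inp = embed(start_token)
    for _ in range(max_steps):
        hidden, token = step(inp)
        if in_latent and token != "</swi>":
            # Latent position: no token is emitted; the hidden state is the next input.
            trace.append("<latent>")
            inp = hidden
            continue
        trace.append(token)
        if token == "<swi>":
            in_latent = True
        elif token == "</swi>":
            in_latent = False
        elif token == "<eos>":
            break
        inp = embed(token)  # ordinary text decoding
    return trace


def make_scripted_step(script: List[str], inputs: List[Vector]) -> StepFn:
    """Hypothetical stand-in for a transformer forward pass (records what it was fed)."""
    it = iter(script)

    def step(inp: Vector) -> Tuple[Vector, str]:
        inputs.append(list(inp))
        return [x * 0.5 for x in inp], next(it)

    return step


def embed(token: str) -> Vector:
    # Hypothetical embedding: just enough structure to tell inputs apart.
    return [float(len(token)), 1.0]


inputs: List[Vector] = []
script = ["The", "<swi>", "x", "y", "</swi>", "42", "<eos>"]
trace = decode_with_switch(make_scripted_step(script, inputs), embed, "<bos>")
print(trace)
```

In the recorded `inputs`, the two entries consumed at latent positions are the previous step's hidden state rather than a token embedding, which is the property that a plain token-by-token decode loop does not have.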

This adapter was trained in three phases on Qwen3-8B:

  1. Phase 1 (SFT). Wrap high-entropy CoT spans in <swi>/</swi>.
  2. Phase 2 (Curriculum). Replace text inside <swi> blocks with <latent> placeholders progressively (parallel schedule).
  3. Phase 3 (Switch-GRPO). On-policy RL on the answer reward, with rollouts that perform real hidden-state injection at <latent> positions.

This release is the Phase 3 endpoint, the version reported in the paper.
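To make the Phase 2 idea concrete, here is a minimal sketch of progressively replacing text inside `<swi>` blocks with `<latent>` placeholders. The exact schedule is not spelled out in this card, so everything about it below is an assumption: the function name `apply_latent_curriculum` is hypothetical, stage `s` is assumed to replace the first `s * c` text tokens of every block one-for-one (all blocks advancing together), and the per-block and per-sample caps are assumed to act as simple upper bounds.

```python
from typing import List


def apply_latent_curriculum(tokens: List[str], stage: int, c: int = 2,
                            k_max: int = 8, sample_cap: int = 48) -> List[str]:
    """Replace leading text tokens inside each <swi> block with <latent> (assumed schedule)."""
    per_block = min(stage * c, k_max)  # assumption: c placeholders added per stage, capped
    out: List[str] = []
    inside = False
    replaced_in_block = 0
    replaced_total = 0
    for tok in tokens:
        if tok == "<swi>":
            inside, replaced_in_block = True, 0
            out.append(tok)
        elif tok == "</swi>":
            inside = False
            out.append(tok)
        elif inside and replaced_in_block < per_block and replaced_total < sample_cap:
            out.append("<latent>")
            replaced_in_block += 1
            replaced_total += 1
        else:
            out.append(tok)  # text outside blocks, or beyond the current caps
    return out


sample = ["a", "<swi>", "b", "c", "d", "</swi>", "e"]
for stage in range(3):
    print(stage, apply_latent_curriculum(sample, stage))
```

At stage 0 the sample is unchanged; later stages convert more of each block to placeholders while text outside the blocks is never touched.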

Quick start

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE    = "Qwen/Qwen3-8B"
ADAPTER = "LARK-Lab/SWITCH-Phase3-GRPO-LoRA-Qwen3-8B"

tokenizer = AutoTokenizer.from_pretrained(ADAPTER)   # contains <swi>, </swi>, <latent>
model = AutoModelForCausalLM.from_pretrained(
    BASE, torch_dtype=torch.bfloat16, device_map="auto"
)
model.resize_token_embeddings(len(tokenizer))
model = PeftModel.from_pretrained(model, ADAPTER)
model.eval()

⚠️ Important: A naïve model.generate(...) will treat <latent> as just another token and will not perform the hidden-state recurrence inside <swi>...</swi> blocks. To run inference exactly as in the paper, use the SWITCH inference loop in src/model/coconut_swi_model.py, which feeds the previous latent step's last-layer hidden state back as the next input embedding and enforces the K_min minimum-dwell constraint inside the latent block.
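One plausible way to picture the K_min minimum-dwell constraint is as a mask on the exit token: until K_min latent steps have run, `</swi>` cannot be selected. The sketch below assumes that reading; `mask_exit_logit` and the dictionary-of-logits representation are hypothetical and are not taken from `src/model/coconut_swi_model.py`.

```python
import math
from typing import Dict


def mask_exit_logit(logits: Dict[str, float], latent_steps_done: int,
                    k_min: int = 4, exit_token: str = "</swi>") -> Dict[str, float]:
    """Return a copy of the logits with the exit token disabled before k_min latent steps."""
    masked = dict(logits)
    if latent_steps_done < k_min and exit_token in masked:
        masked[exit_token] = -math.inf
    return masked


def softmax(logits: Dict[str, float]) -> Dict[str, float]:
    """Numerically stable softmax over a token -> logit mapping."""
    top = max(logits.values())
    exps = {tok: math.exp(val - top) for tok, val in logits.items()}
    total = sum(exps.values())
    return {tok: val / total for tok, val in exps.items()}


# Hypothetical logits at a latent position where the model strongly prefers to exit.
raw = {"</swi>": 5.0, "<latent>": 0.0}
early = softmax(mask_exit_logit(raw, latent_steps_done=1))
late = softmax(mask_exit_logit(raw, latent_steps_done=4))
print(early["</swi>"], late["</swi>"])
```

With the default `k_min=4`, the exit probability is exactly zero during the first four latent steps and is left untouched afterwards.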

Headline results

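| Benchmark | SWITCH (this checkpoint) | Strongest Coconut-style baseline | Gap (points) |
|-----------|--------------------------|----------------------------------|--------------|
| MATH-500  | 79.3 %                   | 53.6 %                           | +25.7        |
| GSM8K     | 89.2 %                   | 78.5 %                           | +10.7        |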

All numbers are reported under matched data, decoding, and Qwen3-8B base-model settings.

Training details

| Setting           | Value |
|-------------------|-------|
| Base model        | Qwen/Qwen3-8B |
| Phase 1           | LoRA (r=32, α=64) on {q,k,v,o,gate,up,down}_proj + resized embeddings + LM head, bf16 |
| Phase 2           | LoRA continued from Phase 1; parallel curriculum schedule, c=2, K_max=8, per-sample latent cap 48 |
| Phase 3           | Switch-GRPO; group size G=5, clip ε=0.2, KL β=1e-3, lr=1e-6; reward = correctness + format + latent-usage |
| Training data     | LARK-Lab/SWITCH-Math-Train |
| Hardware          | 8 × NVIDIA H20 (95 GB) |
| K_min (inference) | 4 |

See §3 of the paper for the full method and §4–§5 for setup and results.
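As a rough illustration of how the Phase 3 hyperparameters fit together, the sketch below computes group-relative advantages for one group of G=5 rollouts and a clipped surrogate that skips latent positions, matching the statement that latent positions contribute no policy-gradient term. It is a generic GRPO-style sketch under assumptions, not the paper's implementation: the function names are hypothetical, rewards are taken as already-summed scalars (the weighting of correctness, format, and latent-usage is not given here), population standard deviation is an arbitrary choice, and the KL penalty (β=1e-3) is omitted.

```python
from statistics import mean, pstdev
from typing import List


def group_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Normalise rewards within one rollout group (zero mean, roughly unit scale)."""
    mu, sigma = mean(rewards), pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def clipped_token_objective(ratio: float, advantage: float, clip_eps: float = 0.2) -> float:
    """PPO-style clipped surrogate for a single text token."""
    clipped = max(1.0 - clip_eps, min(1.0 + clip_eps, ratio))
    return min(ratio * advantage, clipped * advantage)


def sequence_objective(ratios: List[float], is_latent: List[bool],
                       advantage: float, clip_eps: float = 0.2) -> float:
    """Average the clipped surrogate over text positions only; latent positions are skipped."""
    terms = [clipped_token_objective(r, advantage, clip_eps)
             for r, latent in zip(ratios, is_latent) if not latent]
    return sum(terms) / len(terms) if terms else 0.0


# Hypothetical rewards for one group of G=5 rollouts on the same prompt.
rewards = [1.0, 0.0, 0.0, 1.0, 0.0]
advantages = group_advantages(rewards)
print([round(a, 3) for a in advantages])

# The middle position is latent, so its (meaningless) ratio is ignored.
print(sequence_objective([1.5, 9.0, 1.0], [False, True, False], advantage=1.0))
```

Because the boundary tokens are ordinary vocabulary items, the decision to enter or exit latent mode still receives a gradient through the text positions that remain in the sum.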

Special tokens

| Token    | Purpose |
|----------|---------|
| <swi>    | Enter latent reasoning |
| </swi>   | Exit latent reasoning |
| <latent> | Latent placeholder; no token sampled, hidden-state injection happens here |

Mechanistic findings (verifiable on this checkpoint)

The boundary tokens make latent computation directly inspectable. Three takeaways from the paper, all reproducible with scripts/interpret_swi.py:

  1. <swi> is a learned switching policy, not a stylistic tag. It is sharply localised (rank ≤ 2 at boundaries vs 10³ at random positions), forms a clean one-token spike, and is linearly decodable from late hidden states (91.9 %).
  2. The latent step performs causally important computation. Zeroing the injected hidden states reduces accuracy by roughly two-thirds on the diagnostic subset; same-norm random replacements cost only a few points.
  3. The work is concentrated at a single hidden-state transition on entry; subsequent steps are near-deterministic exits with p(</swi>) ≈ 1. The K_min constraint is what protects this single computational step.
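The two interventions in finding 2 differ only in what replaces the injected hidden state. A minimal sketch of both replacements follows, assuming hidden states are plain float vectors; the helper names are hypothetical and this is not the code in `scripts/interpret_swi.py`.

```python
import math
import random
from typing import List


def l2_norm(vec: List[float]) -> float:
    return math.sqrt(sum(x * x for x in vec))


def zero_ablation(hidden: List[float]) -> List[float]:
    """Replace the injected hidden state with zeros."""
    return [0.0] * len(hidden)


def same_norm_random(hidden: List[float], rng: random.Random) -> List[float]:
    """Replace the hidden state with Gaussian noise rescaled to the original L2 norm."""
    noise = [rng.gauss(0.0, 1.0) for _ in hidden]
    scale = l2_norm(hidden) / l2_norm(noise)
    return [x * scale for x in noise]


rng = random.Random(0)
hidden = [0.3, -1.2, 2.0, 0.7]  # hypothetical injected hidden state
zeroed = zero_ablation(hidden)
randomised = same_norm_random(hidden, rng)
print(l2_norm(hidden), l2_norm(zeroed), l2_norm(randomised))
```

Matching the norm makes the random control differ from the original in direction only, so it is a like-for-like comparison against the zeroing intervention, which also changes the magnitude.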

Intended use

  • Math-reasoning research on hidden-state-recurrence latent CoT.
  • Reproducing the SWITCH paper's main table and mechanistic analysis.
  • A starting point for further on-policy RL on latent-recurrent models.

Limitations

  • Only English mathematical reasoning is in the training distribution.
  • Visible token count (1,700 / problem on MATH-500) is much higher than the pure-Coconut baselines (10 visible tokens / problem) because we keep visible CoT outside the <swi> block; this is a deliberate design choice for verifiability, not a token-level efficiency claim.
  • Naïve model.generate(...) does not activate the hidden-state recurrence; you must use the SWITCH inference loop to reproduce the paper numbers.

License

MIT.

Citation

@misc{yang2026demystifyinghiddenstaterecurrenceswitchable,
  title         = {Demystifying Hidden-State Recurrence: Switchable Latent Reasoning with On-Policy Reinforcement Learning},
  author        = {Jiayu Yang and Chao Chen and Shengen Wu and Yinhong Liu and Yuxuan Fan and Lujundong Li and Songning Lai and Chengwei Qin and Zhijiang Guo},
  year          = {2026},
  eprint        = {2606.13106},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG},
  url           = {https://arxiv.org/abs/2606.13106}
}