SWITCH — Phase 3 (Switch-GRPO) LoRA Adapter for Qwen3-8B
This is the paper-final checkpoint of SWITCH, a switchable latent
chain-of-thought framework that combines Coconut-style hidden-state recurrence
with on-policy reinforcement learning through a single primitive: a pair of
learned boundary tokens <swi> / </swi>.
Headline result. 79.3 % MATH-500 / 89.2 % GSM8K; on MATH-500 this is +25.7 points above the strongest Coconut-style baseline at the same scale.
- 📄 Companion paper: "Demystifying Hidden-State Recurrence: Switchable Latent Reasoning with On-Policy Reinforcement Learning", arXiv:2606.13106.
- 💻 Code: github.com/LARK-AI-Lab/SWITCH
- 📊 Training data: LARK-Lab/SWITCH-Math-Train
What this checkpoint does
The model emits <swi> to enter latent mode and </swi> to exit. Inside the
latent block it performs Coconut-style hidden-state recurrence (each step's
last-layer hidden state becomes the input embedding of the next <latent>
position). Outside the block it decodes ordinary text. The boundary tokens
are ordinary discrete vocabulary items, so on-policy GRPO is well-defined at
every text position; latent positions contribute no policy-gradient term.
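The data flow inside a latent block can be illustrated with a toy. This is a minimal sketch, not the SWITCH implementation: `toy_forward` is a hypothetical stand-in for a causal-LM forward pass that returns next-token logits and the last-layer hidden state at the final position (in the real model this would be a Qwen3-8B forward call on input embeddings). It shows only the recurrence itself: each latent step's hidden state is appended as the next input embedding, and no token is sampled.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]


def toy_forward(embeds: List[Vector]) -> Tuple[Vector, Vector]:
    """Hypothetical stand-in for an LM forward pass over input embeddings.

    Returns (next_token_logits, last_layer_hidden_state_at_final_position).
    The toy 'hidden state' is tanh of the running sum of the input embeddings,
    and the toy logits simply reuse it.
    """
    dim = len(embeds[0])
    summed = [sum(e[i] for e in embeds) for i in range(dim)]
    hidden = [math.tanh(x) for x in summed]
    return hidden, hidden


def run_latent_block(
    forward: Callable[[List[Vector]], Tuple[Vector, Vector]],
    prefix_embeds: List[Vector],
    num_steps: int,
) -> Tuple[List[Vector], List[Vector]]:
    """Run `num_steps` latent steps after the prefix (which ends at <swi>)."""
    embeds = list(prefix_embeds)  # copy so the caller's list is not mutated
    trajectory: List[Vector] = []
    for _ in range(num_steps):
        _, hidden = forward(embeds)  # logits are ignored: nothing is sampled
        # Coconut-style recurrence: the last-layer hidden state becomes the
        # input embedding of the next <latent> position.
        embeds.append(hidden)
        trajectory.append(hidden)
    return embeds, trajectory
```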
This adapter was trained in three phases on Qwen3-8B:
- Phase 1 (SFT). Wrap high-entropy CoT spans in `<swi>`/`</swi>`.
- Phase 2 (Curriculum). Progressively replace the text inside `<swi>` blocks with `<latent>` placeholders (parallel schedule).
- Phase 3 (Switch-GRPO). On-policy RL on the answer reward, with rollouts that perform real hidden-state injection at `<latent>` positions.
This release is the Phase 3 endpoint, the version reported in the paper.
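One possible reading of the Phase 2 curriculum, sketched under stated assumptions. The card gives only "parallel schedule, c=2, K_max=8, per-sample latent cap 48"; the hypothetical `apply_curriculum` below assumes that each `<swi>` block is a list of text steps, that stage `s` replaces the first `s` steps of every block at once ("parallel"), that each replaced step becomes `c` `<latent>` tokens (Coconut's convention), that a block holds at most `K_max` latents, and that a sample holds at most 48 latents in total. The actual schedule is in the SWITCH repository.

```python
from typing import List

LATENT = "<latent>"


def apply_curriculum(blocks: List[List[str]], stage: int, c: int = 2,
                     k_max: int = 8, sample_cap: int = 48) -> List[List[str]]:
    """Hypothetical Phase 2 step: swap leading text steps for <latent> tokens.

    `blocks` holds the text steps inside each <swi>...</swi> block of one
    sample. All blocks advance to the same stage simultaneously (assumed
    meaning of "parallel"); each replaced step becomes exactly c latents.
    """
    budget = sample_cap  # latents still allowed in this sample
    out: List[List[str]] = []
    for steps in blocks:
        # Replace the first `stage` steps, but never exceed the per-block cap
        # (k_max latents) or the remaining per-sample budget.
        n_replace = min(stage, len(steps), k_max // c, budget // c)
        budget -= n_replace * c
        out.append([LATENT] * (n_replace * c) + steps[n_replace:])
    return out
```

With the defaults, a block can lose at most four text steps (8 latents), and once the per-sample budget runs out, later blocks keep their text.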
Quick start
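```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE = "Qwen/Qwen3-8B"
ADAPTER = "LARK-Lab/SWITCH-Phase3-GRPO-LoRA-Qwen3-8B"

# The adapter repo ships the extended tokenizer (contains <swi>, </swi>, <latent>).
tokenizer = AutoTokenizer.from_pretrained(ADAPTER)

# Load the base model in bf16 and spread it across the available devices.
model = AutoModelForCausalLM.from_pretrained(
    BASE, torch_dtype=torch.bfloat16, device_map="auto"
)
# Grow the embedding matrix and LM head to the extended vocabulary *before*
# attaching the adapter, so shapes match the adapter's saved weights.
model.resize_token_embeddings(len(tokenizer))
model = PeftModel.from_pretrained(model, ADAPTER)
model.eval()
```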
⚠️ Important: A naïve `model.generate(...)` will treat `<latent>` as just another token and will not perform the hidden-state recurrence inside `<swi>...</swi>` blocks. To run inference exactly as in the paper, use the SWITCH inference loop in `src/model/coconut_swi_model.py`, which feeds the previous latent step's last-layer hidden state back as the next input embedding and enforces the `K_min` minimum-dwell constraint inside the latent block.
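A sketch of how a minimum-dwell constraint can gate the exit from a latent block. Everything here is a hypothetical illustration: the exit decision is read from the softmax probability of `</swi>` at each latent step, the block closes once that probability reaches `threshold`, and `max_steps` is a safety cap (the card lists `K_max=8` only as a Phase 2 setting). The reference logic is in `src/model/coconut_swi_model.py`.

```python
import math
from typing import Sequence


def should_exit_latent(logits: Sequence[float], end_id: int, steps_done: int,
                       k_min: int = 4, max_steps: int = 8,
                       threshold: float = 0.5) -> bool:
    """Decide whether to close the current latent block with </swi>.

    steps_done: number of latent steps already run in this block.
    max_steps and threshold are hypothetical parameters, not from the card.
    """
    if steps_done < k_min:
        return False  # minimum dwell: exit is masked for the first k_min steps
    if steps_done >= max_steps:
        return True  # hypothetical hard cap so the block always terminates
    # Softmax probability of </swi>, shifted by the max for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    p_end = exps[end_id] / sum(exps)
    return p_end >= threshold
```

With `k_min=4`, the block stays in latent mode for the first four steps even when `</swi>` is near-certain.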
Headline results
| Benchmark | SWITCH (this checkpoint) | Strongest Coconut-style baseline | Gap (points) |
|---|---|---|---|
| MATH-500 | 79.3 % | 53.6 % | +25.7 |
| GSM8K | 89.2 % | 78.5 % | +10.7 |
All numbers under matched data, decoding, and Qwen3-8B base-model settings.
Training details
| Setting | Value |
|---|---|
| Base model | Qwen/Qwen3-8B |
| Phase 1 | LoRA (r=32, α=64) on {q,k,v,o,gate,up,down}_proj + resized embeddings + LM head, bf16 |
| Phase 2 | LoRA continued from Phase 1; parallel curriculum schedule, c=2, K_max=8, per-sample latent cap 48 |
| Phase 3 | Switch-GRPO; group size G=5, clip ε=0.2, KL β=1e-3, lr=1e-6; reward = correctness + format + latent-usage |
| Training data | LARK-Lab/SWITCH-Math-Train |
| Hardware | 8 × NVIDIA H20 (95 GB) |
| K_min (inference) | 4 |
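The Phase 3 row can be made concrete with the two core pieces of a GRPO update, sketched in plain Python under stated assumptions. First, group-relative advantages over the G=5 rollouts of one prompt (normalised here with the population standard deviation; implementations differ). Second, a PPO-style clipped surrogate with ε=0.2, averaged over text positions only, since latent positions contribute no policy-gradient term. The KL penalty (β=1e-3), the weighting of the correctness, format and latent-usage rewards, and batching are omitted; the example rewards in any usage are made up.

```python
import math
from typing import List, Sequence


def group_advantages(rewards: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Group-relative advantages: (r_i - mean) / (std + eps) within one group."""
    n = len(rewards)
    mean = sum(rewards) / n
    std = math.sqrt(sum((r - mean) ** 2 for r in rewards) / n)  # population std
    return [(r - mean) / (std + eps) for r in rewards]


def clipped_token_objective(logp_new: Sequence[float], logp_old: Sequence[float],
                            advantage: float, is_text: Sequence[bool],
                            clip_eps: float = 0.2) -> float:
    """Mean clipped surrogate (to be maximised) over text positions of one rollout.

    Positions with is_text False are <latent> positions: no token was sampled
    there, so they are skipped and contribute no policy-gradient term.
    """
    terms = []
    for ln, lo, text in zip(logp_new, logp_old, is_text):
        if not text:
            continue
        ratio = math.exp(ln - lo)  # pi_new / pi_old for the sampled token
        clipped = max(min(ratio, 1 + clip_eps), 1 - clip_eps)
        terms.append(min(ratio * advantage, clipped * advantage))
    return sum(terms) / len(terms) if terms else 0.0
```

If every rollout in a group gets the same reward, all advantages are zero and that group contributes no policy gradient.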
See §3 of the paper for the full method and §4–§5 for the setup and results.
Special tokens
| Token | Purpose |
|---|---|
| `<swi>` | Enter latent reasoning |
| `</swi>` | Exit latent reasoning |
| `<latent>` | Latent placeholder; no token is sampled, hidden-state injection happens here |
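A small helper that locates the `<latent>` positions in a tokenised sequence, i.e. the positions where hidden-state injection happens and no token is sampled. It is a hypothetical sketch that assumes `<swi>` blocks do not nest and that `<latent>` appears only inside a `<swi>`...`</swi>` block; anything else raises `ValueError`.

```python
from typing import List

SWI, END_SWI, LATENT = "<swi>", "</swi>", "<latent>"


def latent_positions(tokens: List[str]) -> List[int]:
    """Return indices of <latent> placeholders, validating block structure."""
    inside = False
    positions: List[int] = []
    for i, tok in enumerate(tokens):
        if tok == SWI:
            if inside:
                raise ValueError(f"nested {SWI} at position {i}")
            inside = True
        elif tok == END_SWI:
            if not inside:
                raise ValueError(f"unmatched {END_SWI} at position {i}")
            inside = False
        elif tok == LATENT:
            if not inside:
                raise ValueError(f"{LATENT} outside a block at position {i}")
            positions.append(i)  # injection site: no token sampled here
    if inside:
        raise ValueError(f"unclosed {SWI} block")
    return positions
```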
Mechanistic findings (verifiable on this checkpoint)
The boundary tokens make latent computation directly inspectable. Three takeaways from the paper, all reproducible with `scripts/interpret_swi.py`:
- `<swi>` is a learned switching policy, not a stylistic tag. It is sharply localised (rank ≤ 2 at boundaries vs. ~10³ at random positions), forms a clean one-token spike, and is linearly decodable from late hidden states (91.9 %).
- The latent step performs causally important computation. Zeroing the injected hidden states reduces accuracy by roughly two-thirds on the diagnostic subset; same-norm random replacements cost only a few points.
- The work is concentrated at a single hidden-state transition on entry; subsequent steps are near-deterministic exits with p(`</swi>`) ≈ 1. The `K_min` constraint is what protects this single computational step.
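The two interventions in the second finding differ only in what replaces each injected hidden state. A minimal sketch of both, assuming the random control is an isotropic Gaussian direction rescaled to the original vector's L2 norm (the card says only "same-norm random replacements"; `scripts/interpret_swi.py` is the reference). Vectors are plain lists to keep the example self-contained.

```python
import math
import random
from typing import List, Optional


def zero_ablation(h: List[float]) -> List[float]:
    """Replace an injected hidden state with the zero vector."""
    return [0.0] * len(h)


def same_norm_random(h: List[float], rng: Optional[random.Random] = None) -> List[float]:
    """Random vector with the same L2 norm as h (assumed Gaussian direction)."""
    rng = rng or random.Random()
    direction = [rng.gauss(0.0, 1.0) for _ in h]
    d_norm = math.sqrt(sum(x * x for x in direction))
    h_norm = math.sqrt(sum(x * x for x in h))
    # Rescale the random direction so only the content, not the magnitude, changes.
    return [x * h_norm / d_norm for x in direction]
```

The same-norm control separates "the latent step carries information" from "the model only needs an input of the right magnitude": if magnitude alone mattered, both ablations would hurt equally.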
Intended use
- Math-reasoning research on hidden-state-recurrence latent CoT.
- Reproducing the SWITCH paper's main table and mechanistic analysis.
- A starting point for further on-policy RL on latent-recurrent models.
Limitations
- Only English mathematical reasoning is in the training distribution.
- Visible token count (≈1,700 / problem on MATH-500) is much higher than for the pure-Coconut baselines (≈10 visible tokens / problem) because we keep visible CoT outside the `<swi>` block; this is a deliberate design choice for verifiability, not a token-level efficiency claim.
- Naïve `model.generate(...)` does not activate the hidden-state recurrence; you must use the SWITCH inference loop to reproduce the paper numbers.
License
MIT.
Citation
```bibtex
@misc{yang2026demystifyinghiddenstaterecurrenceswitchable,
  title = {Demystifying Hidden-State Recurrence: Switchable Latent Reasoning with On-Policy Reinforcement Learning},
  author = {Jiayu Yang and Chao Chen and Shengen Wu and Yinhong Liu and Yuxuan Fan and Lujundong Li and Songning Lai and Chengwei Qin and Zhijiang Guo},
  year = {2026},
  eprint = {2606.13106},
  archivePrefix = {arXiv},
  primaryClass = {cs.LG},
  url = {https://arxiv.org/abs/2606.13106}
}
```