CODI-Qwen3-8B (answer-only)

A CODI latent-reasoning model trained on top of Qwen/Qwen3-8B. The model compresses chain-of-thought reasoning into a small number of continuous latent vectors (6 "thought" tokens) rather than emitting it as text, then reads the answer off those latents. It is a Qwen3-8B-scale reproduction of the experiment in the LessWrong post "Can we interpret latent reasoning using current mechanistic interpretability?" (the post used Llama-3.2-1B).

⚠️ This is not a standard AutoModelForCausalLM checkpoint. CODI wraps the base model in a custom nn.Module (LoRA adapters + a projection MLP + 2 special tokens, with the answer produced after iterating latent embeddings). There is intentionally no config.json; load it with the CODI code (see Usage below).
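To illustrate what the custom module does at inference time, here is a minimal framework-free sketch of the latent loop. It assumes the CODI-style procedure: the last hidden state passes through the projection and is fed back as the next input embedding, six times, before the answer is decoded. `run_latent_cot`, `model_step`, `project` and the toy callables are hypothetical stand-ins, not the CODI API. The real model also uses a KV cache and brackets the loop with begin- and end-of-thought special tokens (presumably the 2 special tokens mentioned above).

```python
from typing import Callable, List

Vector = List[float]


def run_latent_cot(
    question_embeds: List[Vector],
    model_step: Callable[[List[Vector]], Vector],  # hypothetical: input embeddings -> last hidden state
    project: Callable[[Vector], Vector],           # hypothetical stand-in for the projection MLP
    num_latent: int = 6,
) -> List[Vector]:
    """Return the latent trace: one projected hidden state per latent step."""
    inputs = list(question_embeds)
    trace: List[Vector] = []
    for _ in range(num_latent):
        hidden = model_step(inputs)  # forward pass over the sequence so far
        latent = project(hidden)     # continuous "thought" vector, never decoded to text
        trace.append(latent)
        inputs.append(latent)        # fed back as the next input embedding
    return trace


# Toy stand-ins so the sketch runs: "hidden state" = column-wise mean, projection = halve.
def toy_step(inputs: List[Vector]) -> Vector:
    return [sum(col) / len(inputs) for col in zip(*inputs)]


def toy_project(h: Vector) -> Vector:
    return [0.5 * x for x in h]


trace = run_latent_cot([[1.0, 2.0], [3.0, 4.0]], toy_step, toy_project)
print(len(trace))  # 6
```

In the actual model, the answer would then be decoded conditioned on both the question and this trace.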

📓 How this model was made (the full development log, including the dead-ends — an 8e-4 run that scored 1%, the harness control that diagnosed it, and the fix): see REPRODUCTION_NOTES.md.


Antecedents (lineage / provenance)

This model is a derivative work; here is the full chain it descends from.

  1. Base model — Qwen/Qwen3-8B (Apache-2.0). All transformer weights are inherited from Qwen3-8B; only LoRA adapters (rank 128) + a projection module + 2 resized embedding rows are trained. Hidden size 4096, 36 layers.

  2. Method — CODI ("Compressing Chain-of-Thought into Continuous Space via Self-Distillation"), Shen et al., arXiv:2502.21074. A single model is co-trained on three heads: (a) direct answer (SFT), (b) verbalized chain-of-thought (SFT — the "teacher" path), and (c) latent CoT, where the latent hidden states are distilled to match the teacher's worded-CoT hidden states. At inference any of the three modes can be selected.

  3. Reproduction target — LessWrong post "Can we interpret latent reasoning using current mechanistic interpretability?" by B. Cywiński & B. Wójcik (https://www.lesswrong.com/posts/YGAimivLxycZcqRFR), and its companion model bcywinski/codi_llama1b-answer_only.

  4. Training code — github.com/cywinski/codi (the post authors' fork of the original CODI repo). Trained config: configs/llama1b_gsm8k-aug-nl.yaml, re-pointed to Qwen3-8B.

  5. Training data — GSM8k-Aug-NL (zen-E/GSM8k-Aug-NL, grade-school math with step-by-step solutions) + CommonsenseQA (zen-E/CommonsenseQA-GPT4omini), ≈388k examples total, in answer_only formatting ("Output only the answer and nothing else.").
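Below is a minimal sketch of how the co-trained objectives might be combined into one training loss. It assumes several things:

- The cross-entropy terms are summed without weights.
- The latent path also has an answer cross-entropy term.
- "std-normalized" means dividing the distillation MSE by the teacher hidden state's standard deviation.

`codi_total_loss` and its arguments are hypothetical, not the repo's code. The weight 20 is the distill_loss_factor listed in the Training section. A single flat vector stands in for the hidden states; the actual method may align several layers. In real training the teacher states would also be detached (no gradient).

```python
from statistics import pstdev
from typing import List


def mse(a: List[float], b: List[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def codi_total_loss(
    ce_direct: float,           # (a) direct-answer cross-entropy
    ce_teacher: float,          # (b) verbalized-CoT cross-entropy (teacher path)
    ce_latent_answer: float,    # (c) answer cross-entropy after the latent steps (assumed term)
    student_hidden: List[float],
    teacher_hidden: List[float],
    distill_loss_factor: float = 20.0,
) -> float:
    # Distillation: MSE between student (latent) and teacher (worded-CoT) hidden states,
    # divided by the teacher's std; this reading of "std-normalized" is an assumption.
    distill = mse(student_hidden, teacher_hidden) / pstdev(teacher_hidden)
    return ce_direct + ce_teacher + ce_latent_answer + distill_loss_factor * distill


print(codi_total_loss(0.5, 0.5, 0.5, [1.0, 1.0], [1.0, 3.0]))
```

With distill_loss_factor = 20, even a small hidden-state mismatch dominates the cross-entropy terms. This is the intended pressure toward matching the teacher.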


Training

| Setting | Value |
|---|---|
| Base | Qwen/Qwen3-8B (bf16, unquantized) |
| Adapter | LoRA r=128, α=32, on q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj |
| Latent tokens | num_latent = 6 |
| Projection | 2-layer MLP, prj_dim = 4096 (= hidden size), GELU + LayerNorm |
| Distillation | MSE on hidden states, distill_loss_factor = 20, std-normalized |
| Data | GSM8k-Aug-NL + CommonsenseQA, answer_only=True |
| Optimizer | AdamW, lr 2e-4, cosine schedule, warmup ratio 0.03, weight decay 0.1, grad-clip 2.0 |
| Batch | 8 per device × 4 grad-accum × 8 GPUs = 256 effective |
| Schedule | 5 epochs (~7,580 steps), seed 11 |
| Hardware | 8× H100 (bf16), NCCL_NVLS_ENABLE=0, PyTorch expandable_segments allocator |
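The step count and learning-rate schedule follow from the table values. Here is a sketch assuming Hugging Face Trainer conventions: warmup steps = ceil(ratio × total steps), then linear warmup followed by cosine decay to zero. The card does not show the repo's actual scheduler code. Because the example count is approximate (≈388k), these numbers are approximate too, though they do reproduce the ~7,580 steps.

```python
import math

# Values from the Training table; num_examples is the card's approximate "≈388k".
num_examples = 388_000
per_device_batch, grad_accum, num_gpus = 8, 4, 8
epochs, peak_lr, warmup_ratio = 5, 2e-4, 0.03

effective_batch = per_device_batch * grad_accum * num_gpus  # 256
steps_per_epoch = math.ceil(num_examples / effective_batch)  # 1516
total_steps = steps_per_epoch * epochs                      # 7580
warmup_steps = math.ceil(warmup_ratio * total_steps)         # 228


def lr_at(step: int) -> float:
    """Linear warmup, then cosine decay to 0 (HF-style; assumed, not confirmed for this run)."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


print(effective_batch, total_steps, warmup_steps)  # 256 7580 228
```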

Note on learning rate. This model was trained from the base model (no warm-start). At 8B scale, the post's Llama-1B learning rate of 8e-4 fails catastrophically (answer head never converges; GSM8k ≈ 1%). 2e-4 is required — that is the single most important hyperparameter for reproducing this at 8B. (The original published Llama-1B model additionally warm-started from a prior CODI checkpoint; this one did not.)


Results (GSM8k test, 1319 problems, 3-sample self-consistency)

Accuracy (%) of the three co-trained inference modes, per training epoch:

| Epoch | Latent CoT | Verbalized CoT | Direct answer | Latent − verbal (pp) |
|---|---|---|---|---|
| 1 | 44.6 | 52.8 | 36.2 | −8.2 |
| 2 | 43.6 | 47.5 | 39.8 | −3.9 |
| 3 | 47.7 | 47.9 | 43.8 | −0.3 |
| 4 | 49.0 | 48.8 | 45.2 | +0.2 |
| 5 (final) | 49.3 | 49.9 | 45.9 | −0.6 |

Finding: from epoch 3 onward, latent CoT matches verbalized CoT (latent ≈ verbal > direct), reproducing the post's central result. For reference, the post's Llama-1B numbers were 41.6 / 42.1 / 36.7 (latent / verbalized / direct). That is the same near-tie, shifted up by roughly 8 points with the stronger 8B base. A control evaluation of the published Llama-1B checkpoint through the same harness scored 41.5%, close to the post's 41.6%, confirming the evaluation is faithful.
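The "~8 points" can be checked per mode with a short calculation. This assumes the post's numbers are in latent / verbalized / direct order, matching the table.

```python
# Final-epoch accuracies (%) from the table and the post's Llama-1B numbers.
qwen8b = {"latent": 49.3, "verbal": 49.9, "direct": 45.9}
llama1b = {"latent": 41.6, "verbal": 42.1, "direct": 36.7}


def shifts(a: dict, b: dict) -> dict:
    """Per-mode accuracy difference a - b, in percentage points (rounded to 0.1)."""
    return {mode: round(a[mode] - b[mode], 1) for mode in a}


print(shifts(qwen8b, llama1b))  # {'latent': 7.7, 'verbal': 7.8, 'direct': 9.2}
```

The latent and verbalized heads move almost in lockstep (+7.7 and +7.8 pp), while the direct head gains slightly more (+9.2 pp).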

The earlier checkpoints (epochs 1–4) exist in the training run's history, but this repo ships only the final epoch-5 weights.


Examples (real GSM8k test outputs)

All outputs below are verbatim from the epoch-5 evaluation (GSM8k test, sampling at T=1.0). In latent mode the model produces no visible reasoning: it iterates 6 continuous latent vectors internally and emits only the final answer. Latent-mode accuracy in this evaluation run was ~49.8% (657/1319), so the model is right roughly half the time. The examples below are correct cases, chosen to show the mechanism working across difficulty levels.

| GSM8k question | Latent-CoT answer |
|---|---|
| Charmaine will be 16 years old in 12 years. How old will she be 4 years from now? | 8 |
| Matteo traveled at 55 mph for 4 hours. Shandy traveled at 45 mph for 10 hours. How many miles farther did Shandy drive than Matteo? | 230 |
| Mom bought a set of pots for $19 and a sack of garden soil for $26, and used a $7-off coupon. How much did she spend? | 38 |
| For every pound of beeswax Charlie makes 10 tapered candles; one pound of beeswax + wicks costs $10.00. If he sells each candle for $2.00, what is his net profit on 20 candles? | 20 |
| Richard's driveway is 24 ft wide; he puts a soda bottle every 3 ft. Starting at the first bottle, it takes 5 s to go from one bottle to the next. How many seconds total to set off all fountains? | 35 |

Are the latent thoughts load-bearing? (ablations)

Four causal interventions on the latent trace, run on the full GSM8k test set (n=1319, greedy decoding — hence the baseline here is 52.0% rather than the sampled 49.3% above). Method: record each example's clean 6-latent trace, then re-run answer generation on a modified trace (downstream latents are not regenerated, except in truncation where the prefix is unchanged anyway).
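The four trace edits can be sketched as pure list operations on a recorded, frozen trace. The function names and the 1-indexed positions are illustrative assumptions. The real experiments then feed the edited trace to CODI's answer readout, which is not shown. Note that `drop_one(trace, 6)` and `truncate(trace, 5)` yield the same trace, which is the identity behind the internal validity check described after the transplant results.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")  # one latent vector; kept generic (strings in the usage example)


def truncate(trace: Sequence[T], k: int) -> List[T]:
    """Experiment 1: keep only the first k latents."""
    return list(trace[:k])


def drop_one(trace: Sequence[T], i: int) -> List[T]:
    """Experiment 2: delete latent i (1-indexed), keeping the other five in order."""
    return [z for j, z in enumerate(trace, start=1) if j != i]


def shuffle(trace: Sequence[T], rng: random.Random) -> List[T]:
    """Experiment 3: random permutation of the latents."""
    out = list(trace)
    rng.shuffle(out)
    return out


def transplant(trace: Sequence[T], donor: Sequence[T], i: int = 0) -> List[T]:
    """Experiment 4: swap in the donor's latents; all of them if i == 0, else only position i."""
    if i == 0:
        return list(donor)
    out = list(trace)
    out[i - 1] = donor[i - 1]
    return out


trace = [f"z{i}" for i in range(1, 7)]  # stand-ins for the 6 recorded latents
donor = [f"d{i}" for i in range(1, 7)]  # latents recorded on another problem
print(transplant(trace, donor, 3))      # ['z1', 'z2', 'd3', 'z4', 'z5', 'z6']
```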

1. Truncation — "how many thoughts does it need?" (cf. the post's experiment 1)

| Latents kept (k) | 0 | 1 | 2 | 3 | 4 | 5 | 6 (full) |
|---|---|---|---|---|---|---|---|
| Accuracy (%) | 47.0 | 44.5 | 46.6 | 47.6 | 50.1 | 51.1 | 52.0 |

[Figure: truncation curve]

Deleting the entire latent CoT costs −5.0 pp (47.0%, roughly the no-reasoning floor; the co-trained direct-answer head scores 45.9% sampled). Accuracy climbs mostly as latents 3–5 are added (+4.5 pp from k=2 to k=5, with the largest single step, +2.5 pp, at latent 4). The 6th latent adds only +0.9 pp, consistent with the post's finding on Llama-1B that the model barely uses its final latent.
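The marginal gain of each added latent can be read off the truncation table by differencing consecutive entries:

```python
# Accuracy (%) for k = 0..6 latents kept, from the truncation table.
acc = [47.0, 44.5, 46.6, 47.6, 50.1, 51.1, 52.0]

# gains[j] = accuracy change (pp) from adding latent j + 1 to the first j latents.
gains = [round(b - a, 1) for a, b in zip(acc, acc[1:])]
print(gains)  # [-2.5, 2.1, 1.0, 2.5, 1.0, 0.9]
```

The negative first entry shows that keeping a single latent scores below keeping none. The middle entries sum to the +4.5 pp gained from latents 3–5.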

2. Single-latent dropout — "which thought matters?" (delete one, keep 5 in order)

| Dropped latent | 1 | 2 | 3 | 4 | 5 | 6 |
|---|---|---|---|---|---|---|
| Accuracy (%) | 49.3 | 52.2 | 48.4 | 48.6 | 50.5 | 51.1 |
| Δ vs full (pp) | −2.7 | +0.2 | −3.6 | −3.4 | −1.5 | −0.9 |

[Figure: dropout and shuffle deltas]

Latents 3 and 4 are the most load-bearing single thoughts (−3.6 and −3.4 pp), with latent 1 close behind (−2.7 pp). This echoes the post's logit-lens finding that intermediate calculation values live in the middle of the trace (theirs: latents 3 and 5). Latent 2 can be deleted at no cost (+0.2 pp).
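Ranking the single-latent deletions by damage makes the ordering explicit:

```python
full = 52.0  # greedy baseline with all 6 latents
dropped_acc = {1: 49.3, 2: 52.2, 3: 48.4, 4: 48.6, 5: 50.5, 6: 51.1}

# Delta vs. the full trace, in percentage points.
deltas = {i: round(a - full, 1) for i, a in dropped_acc.items()}
ranking = sorted(deltas, key=deltas.get)  # most damaging deletion first
print(ranking)  # [3, 4, 1, 5, 6, 2]
```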

3. Shuffle control — "does the order matter?" (random permutation of the 6 latents per example)

Shuffling drops accuracy to 48.6% (−3.4 pp), about as damaging as deleting the most important single latent (−3.6 pp for latent 3). The trace is consumed as an ordered sequential computation, not a bag of hints.

4. CoT transplant — "can you swap in someone else's thoughts?" Replace the trace (fully, or one position) with the latents recorded from a different GSM8k problem (the example's batch neighbor), and measure both accuracy on the model's own question and how often it instead produces the donor problem's answer.

| Condition | Own-answer % | Donor-answer % |
|---|---|---|
| clean | 52.0 | 1.0 (coincidence base rate) |
| swap all 6 | 44.2 (−7.8) | 1.1 (≈ base rate) |
| swap latent 3 | 46.9 (−5.2) | 0.9 |
| swap latent 4 | 47.6 (−4.4) | 1.3 |
| swap latent 1 / 2 / 5 / 6 | 52.2 / 52.5 / 51.1 / 50.6 | ~1.0 |

[Figure: swap own vs donor rates]

Two findings:

- (a) A foreign CoT is worse than no CoT (−7.8 pp, vs −5.0 pp for deleting the trace). Inconsistent thoughts actively interfere rather than merely withholding help.
- (b) The thoughts do not transplant the answer. In every condition the donor-answer rate stays within about 0.3 pp of the ~1% coincidence base rate. The answer readout integrates the latents with the question (it attends to both), so mismatched thoughts corrupt the computation instead of redirecting it.

Once more, the damage concentrates in latents 3–4. This is the third independent line of evidence, after dropout and the post's logit-lens, that the middle of the trace is where the computation lives.

Internal validity checks: drop latent 6 and keep first 5 are logically the same condition and scored identically (51.10%) via different code paths; the clean donor-answer rate (~1%) gives the coincidental-match base rate for the transplant experiment.
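Here is a sketch of the two transplant metrics. Running `transplant_rates` on clean predictions gives the coincidence base rate. Two details are illustrative assumptions:

- Answers are compared by exact string match; the real harness likely parses numeric answers.
- Each example's donor is taken to be the next example in the batch, wrapping around. The card only says "batch neighbor", and `batch_neighbor_donors` is a hypothetical pairing.

```python
from typing import List, Sequence, Tuple


def transplant_rates(
    preds: Sequence[str], own_gold: Sequence[str], donor_gold: Sequence[str]
) -> Tuple[float, float]:
    """Percent of predictions equal to the example's own answer vs. the donor problem's answer."""
    n = len(preds)
    own = 100.0 * sum(p == g for p, g in zip(preds, own_gold)) / n
    donor = 100.0 * sum(p == d for p, d in zip(preds, donor_gold)) / n
    return own, donor


def batch_neighbor_donors(golds: Sequence[str]) -> List[str]:
    """Hypothetical pairing: example i takes its donor answer from example i + 1 (wrapping)."""
    return list(golds[1:]) + list(golds[:1])


gold = ["18", "230", "5", "7"]
preds = ["18", "230", "7", "7"]  # third prediction coincides with its neighbor's answer
print(transplant_rates(preds, gold, batch_neighbor_donors(gold)))  # (75.0, 25.0)
```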

Caveats:

- Effects are moderate (~5 pp), likely because Qwen3-8B solves much of GSM8k in a single pass. The stronger the base, the less headroom any CoT (latent or verbal) has.
- Ablations replace the readout's inputs on a frozen trace; they don't regenerate downstream latents. A propagated-corruption variant would likely show larger effects.

Latent vs. verbalized on the same problem

The model is co-trained with a verbalized-CoT head, so you can ask it to show its work — and the latent head lands on the same answer without writing anything:

Q. Janet's ducks lay 16 eggs per day. She eats three for breakfast and bakes muffins with four. She sells the remainder at $2 per egg. How much does she make daily at the market?

Verbalized CoT → "The total number of eggs Janet uses for breakfast and muffins is 3 + 4 = 7. Janet sells 16 − 7 = 9 eggs every day at the farmers' market." → $18

Latent CoT → (no text; 6 internal vectors) → $18

(Note: because the model was trained answer_only, the verbalized-CoT text is occasionally noisy — it can contain stray tokens — but the reasoning content and final answer are sound. This is why the interesting comparison is accuracy, not text quality: see the per-epoch table above.)


Usage

Load with the CODI code from github.com/cywinski/codi. Two fixes are required in test.py for non-Llama (e.g. Qwen) checkpoints — they are upstream bugs that happen to be masked when the base hidden size is 2048:

  1. CODI.from_pretrained(...) must forward prj_dim/prj_dropout/prj_no_ln (otherwise it builds a default-2048 projection and the state-dict mismatches a 4096 model).
  2. model.generate(...) does not accept ablate_latent= — drop that kwarg for non-ablation eval.

```python
from src.model import CODI   # from github.com/cywinski/codi

model = CODI.from_pretrained(
    checkpoint_path="cds-jb/codi_qwen3-8b-answer_only",
    model_name_or_path="Qwen/Qwen3-8B",
    lora_r=128, lora_alpha=32,
    num_latent=6,
    use_prj=True, prj_dim=4096, prj_no_ln=False, prj_dropout=0.0,
    dtype="bfloat16",
)
# generate(...) with num_latent_iterations=6 runs latent CoT;
# --skip_thinking answers directly; --verbalize_cot emits worded CoT.
```

Evaluate exactly as in the repo's scripts/test_llama1b.sh, but with the base model set to Qwen/Qwen3-8B, prj_dim set to 4096, and data_names set to gsm8k.


Intended use & limitations

  • Intended for research on the interpretability of latent (continuous) reasoning — activation patching, logit-lens, and ablations on the 6 latent vectors, as in the post.
  • Trained on grade-school math (GSM8k-Aug) + CommonsenseQA in answer-only format; it is not a general-purpose chat/instruction model and should not be used as one.
  • The latent path is competitive with worded CoT on these tasks; this does not imply general reasoning ability.

License

Apache-2.0, inherited from the Qwen/Qwen3-8B base model. Trainable weights (LoRA + projection) are released under the same terms. Please also respect the licenses/terms of the training datasets (zen-E/GSM8k-Aug-NL, zen-E/CommonsenseQA-GPT4omini) and cite the CODI paper and the LessWrong post if you use this model.
