Sudoku GPT-2 Curriculum: Multi-Output Cell Policy

GPT-2 models trained to solve 9×9 Sudoku as a per-cell set-prediction policy: given a grid and a target empty cell, the model emits the set of values consistent with the puzzle constraints as JSON, e.g. {"values":[3,7,9]}. Solving a puzzle means querying the policy repeatedly, cell by cell.
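The output format above can be validated with a few lines of Python. This is a minimal sketch, not the repository's parser: the function name parse_values is hypothetical, and it assumes the completion contains only the JSON object (any surrounding text is treated as malformed).

```python
import json


def parse_values(text):
    """Return the predicted digits as a set, or None if the output is malformed.

    Assumption: the whole string must be a JSON object of the form
    {"values": [...]} with integer entries in 1..9.
    """
    try:
        obj = json.loads(text)
    except json.JSONDecodeError:
        return None
    if not isinstance(obj, dict):
        return None
    values = obj.get("values")
    if not isinstance(values, list):
        return None
    for v in values:
        # bool is a subclass of int in Python, so reject it explicitly.
        if isinstance(v, bool) or not isinstance(v, int) or not 1 <= v <= 9:
            return None
    return set(values)


print(sorted(parse_values('{"values":[3,7,9]}')))  # [3, 7, 9]
print(parse_values("3, 7, 9"))                     # None
```

Returning None rather than raising keeps malformed completions easy to count or penalize separately from wrong-but-well-formed ones.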

This repository holds the code, training logs, and model checkpoints for the GPT-2 line of experiments (companion to the Qwen latent/backtracking work in Avra98/sudoku-latent-backtracking).

Task & curriculum

The supervision target for a cell is its stage-i consistent value set:

Stage   Target = values consistent under …
1       direct row/column/box constraints (depth-1 legal values)
2       + 2-step lookahead consistency
3       + 3-step lookahead consistency

Targets are generated on the fly from each puzzle according to --stage_i, so each stage trains only on its own consistency set. A cell can legitimately have several valid values.
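For Stage 1, the target is simply the set of digits not already used in the cell's row, column, or 3×3 box. The sketch below illustrates that definition only; the function name legal_values and the grid encoding (a 9×9 list of lists with 0 for empty cells) are assumptions, not the repository's data format, and the lookahead stages are not covered.

```python
def legal_values(grid, row, col):
    """Depth-1 legal values for an empty cell.

    A digit is legal if it does not already appear in the cell's
    row, column, or 3x3 box. Empty cells are encoded as 0 (assumption).
    """
    used = set(grid[row])
    used |= {grid[r][col] for r in range(9)}
    box_row, box_col = 3 * (row // 3), 3 * (col // 3)
    used |= {
        grid[r][c]
        for r in range(box_row, box_row + 3)
        for c in range(box_col, box_col + 3)
    }
    return sorted(set(range(1, 10)) - used)


# Example puzzle, used here only for illustration.
PUZZLE = [
    [5, 3, 0, 0, 7, 0, 0, 0, 0],
    [6, 0, 0, 1, 9, 5, 0, 0, 0],
    [0, 9, 8, 0, 0, 0, 0, 6, 0],
    [8, 0, 0, 0, 6, 0, 0, 0, 3],
    [4, 0, 0, 8, 0, 3, 0, 0, 1],
    [7, 0, 0, 0, 2, 0, 0, 0, 6],
    [0, 6, 0, 0, 0, 0, 2, 8, 0],
    [0, 0, 0, 4, 1, 9, 0, 0, 5],
    [0, 0, 0, 0, 8, 0, 0, 7, 9],
]

print(legal_values(PUZZLE, 0, 2))  # [1, 2, 4]
print(legal_values(PUZZLE, 4, 4))  # [5]
```

The first cell shows why the task is framed as set prediction: three values pass the direct constraints even though only one of them appears in the final solution.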

Method

  • Full fine-tuning (--full_finetune): all backbone parameters are trained (small models lack the capacity to learn this through low-rank adapters), and checkpoints store the full model.
  • Local-constraint prompting (--local_constraint_prompt): the prompt appends the explicit per-cell row/column/box constraint block. This is essential for GPT-2: without it the models collapse to constant predictions; with it, GPT-2-large reaches perfect Stage-1 SFT.
  • Latent thought token (--num_cot_tokens k, recurrent_hidden): an optional recurrent latent "thinking" step, introduced after a warm-up.
  • SFT → GRPO per stage. GRPO uses a set-prediction reward (super-linear reward for correct values, penalties for wrong/missing/malformed, exact-set bonus) and the same local-constraint prompt as SFT.
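The set-prediction reward described in the last bullet could take roughly the following shape. This is an illustrative sketch only: the function name set_reward, the exponent, and every penalty and bonus weight are hypothetical placeholders, not the values used in training.

```python
def set_reward(predicted, target, *, power=2.0, wrong_penalty=1.0,
               missing_penalty=0.5, exact_bonus=2.0, malformed_reward=-2.0):
    """Score a predicted value set against the target set.

    predicted: set of ints, or None when the model output was malformed.
    target:    set of ints (the stage's consistent value set).
    All weights are hypothetical defaults chosen for illustration.
    """
    if predicted is None:
        return malformed_reward
    correct = len(predicted & target)
    wrong = len(predicted - target)
    missing = len(target - predicted)
    # Super-linear in the number of correct values (power > 1), so
    # recovering most of the set pays more than a single safe value.
    reward = correct ** power
    reward -= wrong_penalty * wrong
    reward -= missing_penalty * missing
    if predicted == target:
        reward += exact_bonus
    return reward


target = {3, 7, 9}
print(set_reward({3, 7, 9}, target))  # 11.0
print(set_reward({3}, target))        # 0.0
print(set_reward({3, 4}, target))     # -1.0
print(set_reward(None, target))       # -2.0
```

With these placeholder weights, a single correct value is cancelled out by the two missing ones, which shows how the missing penalty discourages always emitting one safe digit.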

Current experiments (Stage 1, gpt2-large 774M)

Two variants, each SFT → GRPO, full-FT + local constraint:

  • baseline (k=0): no latent token. SFT (lc_large) reaches perfect Stage-1 solve; GRPO refines from there.
  • latent (k=1): recurrent latent token switched on after warm-up.

Earlier size/format sweep (runs/local_constraint/): plain_* = no local constraint (collapses), lc_* = local constraint, lat_lc_* = local constraint + latent. Finding: local constraint is the decisive ingredient; the latent token adds optimization difficulty for the smallest models.

Repository layout

cc/                       # training/eval code (this experiment's code copy)
  latent_multi_output_cell_policy/   # SFT + GRPO trainers
  multi_output_cell_policy/          # prompt builder, rewards, shared policy
runs/
  local_constraint/<variant>/checkpoint-step-04000/   # final SFT models (full model)
  stage1_large/<baseline|latent>/{sft,grpo}/          # current Stage-1 runs
logs/                     # training logs
code/                     # launch + Hugging Face push scripts

Checkpoints are full models (model.safetensors); load them directly with AutoModelForCausalLM.from_pretrained(<checkpoint dir>).
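A minimal loading sketch with Hugging Face transformers follows. The helper name load_policy is hypothetical. It assumes the checkpoint directory may or may not include tokenizer files and falls back to the stock gpt2-large tokenizer if it does not, which is only safe if training added no new tokens. Whether the latent (k=1) variants need the repository's own model wrapper rather than a plain causal LM is not stated here, so treat this as applying to the baseline checkpoints.

```python
from pathlib import Path


def load_policy(checkpoint_dir, tokenizer_name="gpt2-large"):
    """Load a full-model checkpoint and a tokenizer for inference.

    tokenizer_name is an assumed fallback, used only when the checkpoint
    directory does not contain tokenizer files.
    """
    # Imported lazily so the helper can be defined without transformers installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint_dir = Path(checkpoint_dir)
    model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)
    try:
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    except (OSError, ValueError):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    model.eval()  # disable dropout for deterministic inference
    return model, tokenizer
```

Usage would look like `model, tokenizer = load_policy("runs/local_constraint/<variant>/checkpoint-step-04000")`, with `<variant>` replaced by one of the run names from the layout above.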

Reproduce

# Stage-1 gpt2-large, baseline (k=0) or latent (k=1): SFT -> GRPO
bash code/stage1_large.sh baseline 0
bash code/stage1_large.sh latent   1

Checkpoints, logs, and code are pushed here automatically every ~15 min during training.
