# Composer Replication Framework — User Guide
A zero-to-training walkthrough for the open replication of Cursor Composer 2.5.
Audience: an ML engineer who knows GRPO/DPO at a textbook level but has never
opened this repo. Every step references real code, and every kwarg name
listed below has been imported and verified against the
`composer_replication/` source.
---
## 1. What is this framework?
A pure-PyTorch replication of the **3-channel composer loss** that powers
agentic-coding model training. One model, one optimizer, three additive
loss terms — composed every step:
```
            ┌────────────────────────────────────────────┐
            │         compose_loss(model, batch)         │
            └────────────────────────────────────────────┘
          ┌────────────────────────┼──────────────────────────┐
          ▼                        ▼                          ▼
┌───────────────────┐   ┌──────────────────────┐   ┌──────────────────────┐
│ Channel 1 (RL)    │   │ Channel 2 (SDPO)     │   │ Channel 3 (replay)   │
│ GRPO              │   │ hint-distillation    │   │ multi-teacher DPO    │
│ → lm_ce stub in   │   │ generalized JSD      │   │ on (chosen,          │
│   verification    │   │ student vs teacher   │   │ rejected) pairs      │
│   harness         │   │ (hint-conditioned)   │   │ from N teachers      │
└─────────┬─────────┘   └──────────┬───────────┘   └──────────┬───────────┘
          │ weight = 1 (always on) │ alpha_sdpo   beta_replay │
          └────────────┬───────────┴────────────┬─────────────┘
                       ▼                        ▼
            total = lm_ce + α·sdpo_jsd + β·trace_replay_dpo
   (channel auto-disables if its weight=0 OR its inputs are missing)
```
Two API surfaces, on purpose:
- **Verification harness:** `compose_loss(model, batch, ...)` is a free
function (channel 1 = LM cross-entropy, the GRPO limit under deterministic
rewards). Use it for CPU smokes, unit tests, and gradient-flow debugging.
- **Production trainer:** `ComposerReplicationTrainer` is a `trl.GRPOTrainer`
subclass that overrides `_compute_loss` with the same 3 channels on top of
TRL's real reward + advantage machinery.
The verification harness is what you'll use for sections 2–6; the production
trainer (and its alternates VeRL/PRIME-RL/Monarch) is section 8.
Source of truth: `composer_replication/loss.py` for `compose_loss`,
`composer_replication/trainer/composer_trainer.py` for the trainer subclass.
---
## 2. Install — which extras to pick
Always start with the core install:
```bash
git clone https://huggingface.co/Codeseys/composer-replication-framework
cd composer-replication-framework
pip install -e .
```
That gets you `torch>=2.0` + `transformers>=4.46` and is enough for the
verification harness on CPU (sections 3, 5, 6).
The seven optional extras are declared in `pyproject.toml` `[project.optional-dependencies]`:
```
                                   Do you need …
         ┌───────────────────────────────┼─────────────────────────────────┐
         ▼                               ▼                                 ▼
real teacher calls              DiLoCo on                         production
over OpenRouter?                >1 GPU?                           GRPO training?
         │                               │                                 │
         │ yes                           │ yes                             │ yes
         ▼                               ▼                                 ▼
pip install -e ".[replay]"      pip install -e ".[diloco]"        pip install -e ".[train]"
(httpx)                         (torchft-nightly)                 (trl, peft, accelerate, datasets)
         │                               │                                 │
         │ + want CPU-side               │ + scaling beyond a              │ + want PRIME-RL
         │   DPO normalization?          │   single host?                  │   (Recipe C)?
         ▼                               ▼                                 ▼
pip install -e ".[replaysim]"   pip install -e ".[serverless]"    pip install -e ".[prime-rl]"
(data-juicer; depends           (fsspec, huggingface_hub)         (prime-rl>=0.5)
 on [replay])
                                                                           │ + Monarch actor mesh?
                                                                  pip install -e ".[monarch]"
                                                                  (monarch>=0.4.1)
```
Quick decision table:
| Goal | Install |
|-------------------------------------------------------|------------------------------------------|
| CPU smoke / verification (sections 3, 5, 6) | `pip install -e .` |
| Section 4 (replaysim DJNormalizer) | `pip install -e ".[replaysim]"` |
| Section 7 dev loop (LocalProcessExecutor + file://) | `pip install -e ".[serverless]"` |
| Real DiLoCo outer-loop | `pip install -e ".[diloco,serverless]"` |
| Section 8 Recipe A (TRL GRPO) | `pip install -e ".[train]"` |
| Section 8 Recipe C (PRIME-RL) | `pip install -e ".[prime-rl]"` |
| Section 8 Recipe C+D (PRIME-RL + Monarch) | `pip install -e ".[prime-rl,monarch]"` |
| Everything for development | `pip install -e ".[dev]"` |
---
## 3. Quickstart: `examples/qwen_05b_quickstart` end-to-end on CPU
The fastest way to convince yourself the framework works on a real HF model.
~3–5 min wall-clock on CPU, ~1 GB disk for Qwen2.5-0.5B weights.
```bash
pip install -e .
python examples/qwen_05b_quickstart/run.py
```
What the script does (read the source at
`examples/qwen_05b_quickstart/run.py`):
1. Pin RNG (`random.seed(42)`, `torch.manual_seed(42)`) so the per-step
numbers below are reproducible.
2. Load `Qwen/Qwen2.5-0.5B-Instruct` on CPU in fp32, set `model.train()`.
3. `batch = build_batch(tokenizer, device="cpu")` — a real chat-template-formatted
batch with all keys the 3-channel composer might consume.
4. Five backward steps with `compose_loss(model, batch, alpha_sdpo=0.1,
beta_replay=0.05)`; an `AdamW(lr=1e-5)` optimizer; finite-grad check
after each step.
Expected output (transcribed from `examples/qwen_05b_quickstart/run.log`):
```
[quickstart] loading Qwen/Qwen2.5-0.5B-Instruct (CPU, fp32) ...
[quickstart] loaded — 0.494B params
[quickstart] building real chat-template batch ...
[quickstart] running 5 backward steps ...
step 0: total=0.7390 lm_ce=0.7358 sdpo=0.0000 dpo=0.0639 finite=True
step 1: total=0.0379 lm_ce=0.0351 sdpo=0.0000 dpo=0.0563 finite=True
step 2: total=0.0122 lm_ce=0.0110 sdpo=0.0000 dpo=0.0240 finite=True
step 3: total=0.0060 lm_ce=0.0055 sdpo=0.0000 dpo=0.0098 finite=True
step 4: total=0.0031 lm_ce=0.0029 sdpo=0.0000 dpo=0.0044 finite=True
========================================================
Initial loss: 0.7390 → Final loss: 0.0031 → Reduction: 99.6%
Verdict: PASS
========================================================
```
How to read this:
- **`total` collapses by ~99%.** The model memorizes the single batch,
which is what you expect from five AdamW steps on a 0.5B model with one
fixed input. This is a wiring check, not a generalization claim.
- **`lm_ce` carries almost all the magnitude.** Channel 1 (the GRPO stub)
is doing the work — the response tokens are short and have low entropy
under the trained model.
- **`sdpo=0.0000` on every step.** Channel 2 has auto-disabled because the
default `build_batch` does not include `ctx_teacher_input_ids`. Compare
the conditional in `compose_loss`:
```python
if (alpha_sdpo > 0.0
        and "ctx_teacher_input_ids" in inputs
        and inputs["ctx_teacher_input_ids"].numel() > 0):
```
The channel switches off if its weight is zero or its inputs are missing.
- **`dpo > 0` and trending down.** The batch *does* include
`dpo_chosen_input_ids`, `dpo_chosen_response_mask`,
`dpo_chosen_ref_logprobs` (and the rejected counterparts), so channel 3
is live.
- **`finite=True`** — every step's `p.grad` was finite for every parameter.
This is the wiring contract; if it ever flips to `False` the smoke fails.
If you see `Verdict: PASS`, the framework is correctly installed and
gradients flow through all live channels. You are ready for section 4.
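As a sanity check on the log above, the composition rule from section 1 (`total = lm_ce + α·sdpo_jsd + β·trace_replay_dpo`) can be replayed by hand with the quickstart weights. The sketch below only redoes the arithmetic on the transcribed numbers; the `rows` list, the `recompose` helper, and the tolerance are illustrative choices, not part of the repo.

```python
# Replay the composition rule on the logged per-step components.
ALPHA_SDPO = 0.1    # weights passed to compose_loss in the quickstart
BETA_REPLAY = 0.05

# (total, lm_ce, sdpo, dpo) transcribed from the run.log excerpt above
rows = [
    (0.7390, 0.7358, 0.0000, 0.0639),
    (0.0379, 0.0351, 0.0000, 0.0563),
    (0.0122, 0.0110, 0.0000, 0.0240),
    (0.0060, 0.0055, 0.0000, 0.0098),
    (0.0031, 0.0029, 0.0000, 0.0044),
]


def recompose(lm_ce, sdpo, dpo):
    """Hypothetical helper: weighted sum of the three logged components."""
    return lm_ce + ALPHA_SDPO * sdpo + BETA_REPLAY * dpo


# Largest gap between the logged total and the recomputed one
# (nonzero only because the log rounds to 4 decimals).
max_err = max(abs(total - recompose(lm_ce, sdpo, dpo))
              for total, lm_ce, sdpo, dpo in rows)
reduction = 1.0 - rows[-1][0] / rows[0][0]

print(f"max |logged - recomputed| = {max_err:.1e}")
print(f"reduction = {reduction:.1%}")
```

If the recomputed totals drift from the logged ones by more than rounding error, the weights in your run differ from the quickstart defaults.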
---
## 4. Adding the trace-replay channel
The quickstart batch *had* DPO inputs, but they were synthetic — the
`build_batch` helper bakes them in. To get **real** DPO pairs from
multi-teacher disagreement, use the replaysim package.
### 4a. Spin up `replay_trace`
```python
import asyncio
from composer_replication import (
DEFAULT_TEACHERS, replay_trace, extract_dpo_pairs,
)
# Trace must be a list[TraceState]; see composer_replication/teacher_replay.py
# for the exact TypedDict shape. Each state holds a chat-messages prefix +
# the student's actual action at that step.
states = [...] # your frozen agentic trace; see spike 001 for a 50-step example
teacher_actions = asyncio.run(
    replay_trace(
        states=states,
        teachers=DEFAULT_TEACHERS,  # claude-opus-4.7 + gpt-5 + deepseek-v4-pro
        max_total_usd=10.0,         # hard ceiling (spike 001 measured $0.98/trace mean)
    )
)
```
The 3 teachers are queried in parallel via OpenRouter
(`OPENROUTER_API_KEY` in env or `~/.hermes/.env`), latencies recorded,
costs tracked.
### 4b. Get `DPOPair`s from disagreement
```python
pairs = extract_dpo_pairs(
    states=states,
    teacher_actions=teacher_actions,
    agreement_threshold=2,  # at least 2/3 teachers must agree on the chosen action
)
```
Each pair is a `DPOPair` TypedDict with the exact shape the
`DJNormalizer` and downstream training expects:
```python
class DPOPair(TypedDict):
    state_id: str
    state_messages: list[dict]  # conversation context
    chosen: str                 # teacher-consensus action
    rejected: str               # student action
    n_teachers_agreeing: int
```
(verified in `composer_replication/teacher_replay.py:99–105`).
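To make the agreement rule concrete, here is a minimal sketch of how a single state could turn into a pair. It is not the repo's `extract_dpo_pairs`: the helper name `pair_from_votes`, exact-string matching as the notion of "agreement", and skipping states where the student already matches the consensus are all assumptions made for illustration.

```python
from collections import Counter
from typing import Optional, TypedDict


class DPOPair(TypedDict):
    state_id: str
    state_messages: list
    chosen: str
    rejected: str
    n_teachers_agreeing: int


def pair_from_votes(state_id, state_messages, student_action, teacher_actions,
                    agreement_threshold=2) -> Optional[DPOPair]:
    """Hypothetical helper: emit a pair only when enough teachers agree."""
    if not teacher_actions:
        return None
    # Most frequent teacher action and how many teachers proposed it.
    action, votes = Counter(teacher_actions).most_common(1)[0]
    if votes < agreement_threshold or action == student_action:
        return None  # no consensus, or nothing to correct
    return DPOPair(state_id=state_id, state_messages=state_messages,
                   chosen=action, rejected=student_action,
                   n_teachers_agreeing=votes)


pair = pair_from_votes(
    "state-000", [{"role": "user", "content": "Fix the failing test"}],
    student_action="(a)", teacher_actions=["(c)", "(c)", "(b)"],
)
no_pair = pair_from_votes("state-x", [], "(a)", ["(a)", "(b)", "(c)"])
print(pair["chosen"], pair["n_teachers_agreeing"], no_pair)
```

With three teachers and `agreement_threshold=2`, a 2-to-1 split still produces a pair, while a three-way split produces none.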
### 4c. Run `DJNormalizer` with `default.yaml`
```python
from composer_replication.replaysim import DJNormalizer
normalizer = DJNormalizer() # uses recipes/replaysim/default.yaml
normalized = normalizer.normalize(pairs)
# → list[NormalizedDPOPair]
```
`DJNormalizer` shells out to data-juicer's `DefaultExecutor` under the hood
(file-in / file-out contract). The default recipe at
`composer_replication/recipes/replaysim/default.yaml` runs four CPU-only ops
in order:
1. `text_length_filter` (8 ≤ chars ≤ 32000) on `chosen` and `rejected`
2. `words_num_filter` (2 ≤ words ≤ 4096) on both
3. `special_characters_filter` (≤50% non-alpha) on both
4. `document_deduplicator` (per-batch hashing, lowercase, ignore non-character) on `chosen`
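For intuition about what the four ops above accept and reject, the bounds can be approximated in a few lines of plain Python. This is a rough sketch, not data-juicer's implementation: the helper names are hypothetical, and treating "special" as "neither alphanumeric nor whitespace" is an assumption about how the ratio is computed.

```python
def passes_default_bounds(text: str) -> bool:
    """Approximate the three filter bounds from the default recipe for one string."""
    n_chars = len(text)
    n_words = len(text.split())
    # Assumption: a "special" character is neither alphanumeric nor whitespace.
    n_special = sum(1 for ch in text if not (ch.isalnum() or ch.isspace()))
    special_ratio = n_special / n_chars if n_chars else 1.0
    return (8 <= n_chars <= 32000) and (2 <= n_words <= 4096) and (special_ratio <= 0.5)


def dedup_on_chosen(pairs):
    """Keep the first pair per lowercased, alphanumeric-only `chosen` string."""
    seen, kept = set(), []
    for pair in pairs:
        key = "".join(ch for ch in pair["chosen"].lower() if ch.isalnum())
        if key not in seen:
            seen.add(key)
            kept.append(pair)
    return kept


print(passes_default_bounds("read the schema before editing"))  # within all bounds
print(passes_default_bounds("ok"))                              # too short
print(len(dedup_on_chosen([{"chosen": "Read more files."},
                           {"chosen": "read more files"}])))
```

A pair survives only if both its `chosen` and `rejected` strings pass the bounds; deduplication looks at `chosen` alone.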
Records carry **two parallel shapes** for `chosen`/`rejected`:
- flat strings (`chosen`, `rejected`) → consumed by data-juicer's text_key-based filters
- chat-messages lists (`chosen_messages`, `rejected_messages`) → preserved for chat-aware ops + round-trip
This dual-shape design (verified in the test
`test_dpo_pair_to_dj_record_shape`,
`composer_replication/replaysim/tests/test_replaysim.py:44`) is what
unblocked the data-juicer integration in Wave 14.
### 4d. The 3-record fixture from spike 001
The fixture lives at
`spikes/001-teacher-replay-cost/states.jsonl` (50 states) and
`spikes/001-teacher-replay-cost/results.jsonl` (the teacher responses, all
priced and timed). The first 3 states are:
```jsonl
{"id": "state-000", "task": "Fix the failing test in tests/test_auth.py::test_login_with_email", ...}
{"id": "state-001", "task": "Add rate-limiting middleware to the Flask app", ...}
{"id": "state-002", "task": "Refactor the parse_config function — it's 200 lines and has 3 responsibilities", ...}
```
For each, all 3 teachers answered (claude-opus-4.7, gpt-5, deepseek-v4-pro).
For state-000 and state-001, teacher agreement on choice `(c)` (read more
files / check schema first) yields a clean DPO pair in which the student's
action is the rejected side. For state-002, all 3 agreed on `(c)` (write
characterization tests first), giving another clean pair. These three records
pass through the `DJNormalizer` default recipe unchanged (length, word count,
and special-character ratio all in bounds; no duplicates).
The full 50-state trace cost **$0.98 mean** end-to-end across all three
teachers (spike 001 verdict). The framework's cost ceiling
(`max_total_usd`) and VOI gating drop this to ~$0.30/trace projected.
### 4e. End-to-end one-liner
```python
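from composer_replication import DEFAULT_TEACHERS
from composer_replication.replaysim import replay_and_normalize_trace

# `await` needs an async context (an `async def` body or a notebook cell).
teacher_actions, normalized_pairs = await replay_and_normalize_trace(
    states=states,
    teachers=DEFAULT_TEACHERS,
    agreement_threshold=2,
    max_total_usd=10.0,
)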
```
(`async def`; for sync callers use the sibling `replay_and_normalize_trace_sync`
in `composer_replication.replaysim.normalize`.)
---
## 5. Switching DPO → SimPO: one kwarg
```python
components = compose_loss(
    model, batch,
    alpha_sdpo=0.1,
    beta_replay=0.05,
    dpo_variant="simpo",  # ← the only line that changes
    simpo_beta=2.0,       # paper default
    simpo_gamma=1.0,      # paper default
)
```
The kwarg is verified in `compose_loss`'s signature
(`composer_replication/loss.py:81`):
```python
dpo_variant: Literal["dpo", "simpo"] = "dpo",
```
### What changes in the loss curve
- **Channel 3 input requirements drop.** `compose_loss` no longer reads
`dpo_chosen_ref_logprobs` / `dpo_rejected_ref_logprobs`. Reference-model
VRAM cost goes to zero. (Source: `composer_replication/loss.py:111–113`
and `composer_replication/distillation/simpo.py:23–27`.)
- **Loss scale shifts.** Standard DPO is
`-logsigmoid(β·[(logπ(c) - logπ_ref(c)) - (logπ(r) - logπ_ref(r))])`.
SimPO is `-logsigmoid(β·[avg_logπ(c) - avg_logπ(r)] - γ)` — average
per-token log-prob (length-normalized) and a constant target margin γ.
- **Loss falls as chosen/rejected separation grows.** The
unit test `test_simpo_loss_lower_for_better_separation`
(`composer_replication/distillation/tests/test_distillation_losses.py:35`)
verifies that a wider chosen-vs-rejected gap drives lower SimPO loss. The
test checks monotonicity in the gap only; it does not compare SimPO against
DPO. In practice, expect SimPO curves to be *steeper* than DPO when the
preference signal is strong, and *flatter* when it's weak.
- **No KL-against-reference regularization.** This is both the upside (no
ref-model serving) and the risk (more tendency to drift). Watch for
reward-hacking-style degeneracies if your preference data has noise.
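The two formulas above are easy to compare numerically on scalar log-probs. The sketch below is a standalone illustration, not the code in `composer_replication/distillation/simpo.py`; the function names are hypothetical, and `beta=0.1` for DPO is a commonly used value assumed here, not a default confirmed by this guide.

```python
import math


def logsigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)).
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def dpo_loss(lp_c, lp_r, ref_c, ref_r, beta=0.1):
    """Summed sequence log-probs; needs reference-policy log-probs."""
    return -logsigmoid(beta * ((lp_c - ref_c) - (lp_r - ref_r)))


def simpo_loss(lp_c, len_c, lp_r, len_r, beta=2.0, gamma=1.0):
    """Length-normalized log-probs plus a fixed margin; no reference policy."""
    return -logsigmoid(beta * (lp_c / len_c - lp_r / len_r) - gamma)


# Wider chosen-vs-rejected gap gives lower SimPO loss (same lengths).
narrow = simpo_loss(-20.0, 20, -22.0, 20)   # average gap 0.1
wide = simpo_loss(-10.0, 20, -40.0, 20)     # average gap 1.5

# Long chosen (40 tokens) vs short rejected (5 tokens): the summed log-prob
# favors the rejected answer (-10 > -40), the per-token average favors chosen.
imbalanced = simpo_loss(-40.0, 40, -10.0, 5)

print(f"narrow={narrow:.4f} wide={wide:.4f} imbalanced={imbalanced:.4f}")
```

The last case is the length-imbalance situation described under "When to use SimPO": averaging per token keeps a long correction from being penalized just for being long.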
### When to use SimPO
- **GPU-poor.** You can't afford to keep a frozen reference policy resident
alongside the trainer.
- **Cold-start preference data.** Length-normalization (avg_logπ vs sum)
helps when chosen/rejected lengths are wildly imbalanced — common in
agentic traces where the student's failed attempt is short and the
teacher's correction is long.
- **You don't have ref logprobs precomputed.** SimPO needs nothing from
the reference policy.
When to **stay on DPO**: when you need the explicit KL anchor against
a known-good reference (e.g., when training over a long horizon and you
want to bound the drift), or when your preference data is very noisy and
the reference acts as a regularizer.
---
## 6. Adding TAID / Entropy-Aware OPD wrappers
Channel 2 (SDPO/OPSD) can be replaced by **TAID** (Sakana AI,
arXiv:2501.16937) for capacity-gap distillation, or by
**Entropy-Aware OPD** (ICLR 2026 Spotlight) for per-token forward/reverse-KL
gating. Both are wired through `compose_loss`:
```python
sdpo_wrapper: Literal["none", "taid", "entropy_opd"] = "none",
taid_t: float | None = None, # current TAID interpolation coeff
entropy_opd_h_max: float | None = None,
```
(verified at `composer_replication/loss.py:82–93`.)
### TAID (upstream-faithful port)
> **Wave 15 rewrite, breaking change.** The previous in-tree TAID was
> algorithmically different from the paper (it mixed in probability space
> against a frozen step-0 student snapshot and wrapped a symmetric JSD
> criterion). It has been replaced with an upstream-faithful port:
> logit-space mix, current-student-detached anchor, forward-KL criterion.
> Old kwargs `taid_schedule_step`, `taid_total_steps`, `taid_schedule`,
> `taid_alpha_min`, `taid_alpha_max`, plus `inputs["student_init_logits"]` /
> `inputs["student_init_input_ids"]` are **gone**. They have no upstream
> analogue. Use `taid_t` (and optionally `TAIDScheduler`) instead.
The TAID criterion is forward-KL against a logit-space-interpolated target:
```
p_t = softmax( (1 - t) · stop_grad(student_logits) + t · teacher_logits )
L = - mean_token Σ_v p_t(v) · log_softmax(student_logits)(v)
```
where `t ∈ [0, 1]` is the interpolation coefficient. At `t=0` the target
is the (detached) student itself — the loss is the entropy of that
distribution and contributes no gradient to the student. At `t=1` it
reduces to standard forward-KL distillation against the teacher.
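The `t=0` claim can be checked by hand on a single token. Because the student logits inside the target are detached, the gradient of `-Σ_v p_t(v) · log q(v)` with respect to the student logits is the usual cross-entropy gradient `q - p_t`, where `q = softmax(student_logits)`. The sketch below uses plain Python lists instead of tensors; `taid_token_loss_and_grad` is a hypothetical helper, not the repo's `taid_loss`.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def taid_token_loss_and_grad(student, teacher, t):
    """Loss -sum_v p_t(v) log q(v) for one token, and its gradient w.r.t. student logits.

    The student logits inside the target are treated as constants (stop_grad),
    so the gradient is q - p_t.
    """
    mixed = [(1.0 - t) * s + t * te for s, te in zip(student, teacher)]
    p_t = softmax(mixed)
    q = softmax(student)
    loss = -sum(p * math.log(qv) for p, qv in zip(p_t, q))
    grad = [qv - p for qv, p in zip(q, p_t)]
    return loss, grad


student = [0.5, -1.0, 2.0, 0.0]
teacher = [10.0, 0.0, 0.0, 0.0]

loss0, grad0 = taid_token_loss_and_grad(student, teacher, t=0.0)
loss1, grad1 = taid_token_loss_and_grad(student, teacher, t=1.0)
student_entropy = -sum(q * math.log(q) for q in softmax(student))

print(f"t=0: loss={loss0:.4f} entropy={student_entropy:.4f} max|grad|={max(abs(g) for g in grad0):.1e}")
print(f"t=1: loss={loss1:.4f} max|grad|={max(abs(g) for g in grad1):.3f}")
```

At `t=0` the loss equals the student's own entropy and the gradient vanishes; at `t=1` the gradient is `q - softmax(teacher_logits)`, the standard distillation signal.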
The schedule that produces `t` is the **trainer's** responsibility. The
package ships an optional `TAIDScheduler` that mirrors the paper's
adaptive momentum scheme:
```python
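from composer_replication.distillation import TAIDScheduler
from composer_replication.loss import compose_loss

num_train_steps = 10_000
sched = TAIDScheduler(num_train_steps=num_train_steps)  # paper defaults

# `model`, `batch`, and `optimizer` are assumed to be set up as in section 3.
for step in range(num_train_steps):
    optimizer.zero_grad()
    components = compose_loss(
        model, batch,
        sdpo_wrapper="taid",
        taid_t=sched.t,
    )
    components.total.backward()
    optimizer.step()
    sched.update_t(components.sdpo_jsd.detach(), global_step=step)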
```
`TAIDScheduler` defaults match upstream: `t_start=0.4`, `t_end=1.0`,
`alpha=5e-4`, `beta=0.99`. Pass `disable_adaptive=True` to fall back to
the deterministic linear schedule
`t = t_start + progress · (t_end - t_start)`.
If you want a simple fixed schedule (no scheduler), just compute `t`
yourself and pass it in — `compose_loss` validates `taid_t ∈ [0, 1]`.
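If you go the do-it-yourself route, the linear formula above is a one-liner. A minimal sketch, assuming `progress` means `step / num_train_steps` clamped to `[0, 1]` (the guide does not spell out the definition); `linear_taid_t` is a hypothetical helper name.

```python
def linear_taid_t(step: int, num_train_steps: int,
                  t_start: float = 0.4, t_end: float = 1.0) -> float:
    """Deterministic linear ramp from t_start to t_end, clamped at both ends."""
    # Assumption: progress is the fraction of training completed.
    progress = min(max(step / max(num_train_steps, 1), 0.0), 1.0)
    return t_start + progress * (t_end - t_start)


for step in (0, 2_500, 5_000, 10_000):
    print(step, round(linear_taid_t(step, 10_000), 3))
```

The result can be passed straight to `compose_loss(..., sdpo_wrapper="taid", taid_t=...)`; the clamp keeps it inside the `[0, 1]` range that `compose_loss` validates.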
### Upstream-parity test
`composer_replication/distillation/tests/test_taid_parity.py` skip-imports
the upstream reference at `/tmp/taid-clone` (clone with
`git clone --depth 1 https://github.com/SakanaAI/TAID /tmp/taid-clone`)
and asserts our `taid_loss(student, teacher, mask, t)` matches upstream
`TAID.compute_loss(...)` within `atol=rtol=1e-5` across `t ∈ {0.0, 0.1, 0.4,
0.5, 0.9, 1.0}`. This is the load-bearing parity guarantee.
### Entropy-Aware OPD
Drop-in for channel 2 — gates between forward KL (mode-covering) and
reverse KL (mode-seeking) per token, weighted by the teacher's entropy:
```
L = Σ_t [ w(t) · KL_fwd_t + (1 - w(t)) · KL_rev_t ]
w(t) = clamp(H_teacher(t) / h_max, 0, 1)
```
`entropy_opd_h_max=None` (the default) auto-sets `h_max = log(V)` (the
maximum-entropy bound for a vocab-V softmax).
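The gate is easiest to see on one token at a time. The sketch below is a plain-Python illustration of the formula, not the repo's kernel. It assumes the common distillation convention that forward KL is `KL(teacher‖student)` and reverse KL is `KL(student‖teacher)`, and that all probabilities are strictly positive; the helper names are hypothetical.

```python
import math


def entropy(probs):
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def kl(p, q):
    # KL(p || q); assumes q is strictly positive wherever p is.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def entropy_opd_token(teacher_probs, student_probs, h_max=None):
    """One token's gated loss: w * KL_fwd + (1 - w) * KL_rev."""
    if h_max is None:
        h_max = math.log(len(teacher_probs))  # max entropy for this vocab size
    w = min(max(entropy(teacher_probs) / h_max, 0.0), 1.0)
    kl_fwd = kl(teacher_probs, student_probs)   # mode-covering
    kl_rev = kl(student_probs, teacher_probs)   # mode-seeking
    return w, w * kl_fwd + (1.0 - w) * kl_rev


student = [0.4, 0.3, 0.2, 0.1]
uniform_teacher = [0.25, 0.25, 0.25, 0.25]
peaked_teacher = [0.97, 0.01, 0.01, 0.01]

w_uniform, loss_uniform = entropy_opd_token(uniform_teacher, student)
w_peaked, loss_peaked = entropy_opd_token(peaked_teacher, student)
_, loss_match = entropy_opd_token(student, student)

print(f"uniform teacher: w={w_uniform:.3f} loss={loss_uniform:.4f}")
print(f"peaked teacher:  w={w_peaked:.3f} loss={loss_peaked:.4f}")
print(f"matching distributions: loss={loss_match:.4f}")
```

An uncertain (high-entropy) teacher pushes `w` toward 1 and the loss toward forward KL; a confident teacher pushes `w` toward 0 and the loss toward reverse KL.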
### Boundary-condition unit test (proof of correctness)
The test `test_taid_loss_t_zero_target_matches_detached_student`
(`composer_replication/distillation/tests/test_distillation_losses.py`)
pins TAID's `t=0` invariant — the teacher is *completely* hidden from the
gradient because the target collapses to `softmax(student.detach())`:
```python
def test_taid_loss_t_zero_target_matches_detached_student():
    s1 = torch.randn(1, 2, 4, requires_grad=True)
    teacher_a = torch.zeros(1, 2, 4); teacher_a[..., 0] = 10.0
    teacher_b = torch.zeros(1, 2, 4); teacher_b[..., 3] = 10.0
    mask = torch.ones(1, 2)
    loss_a = taid_loss(s1, teacher_a, mask, t=0.0)
    loss_b = taid_loss(s1, teacher_b, mask, t=0.0)
    # Two completely different teachers must give the same loss at t=0.
    assert abs(float(loss_a) - float(loss_b)) < 1e-6
```
This is the load-bearing test for TAID: if the `t=0` endpoint ever leaks
teacher signal into the gradient, this test fires and the contract is
broken. The companion test `test_taid_loss_t_one_is_pure_forward_kl`
pins the `t=1` endpoint by hand-computing `-Σ p_teacher · log_q` and
asserting equality.
---
## 7. Going multi-replica with serverless DiLoCo
DiLoCo is the outer-loop optimizer that lets you run N replicas in
parallel, sync them every H inner steps, and tolerate slow links — see
`docs/adrs/ADR-005-serverless-diloco.md` for the design. The framework
gives you three increasingly-distant deployments:
### Step 1 — `LocalProcessExecutor` for development
```python
import tempfile

from composer_replication.diloco.serverless import (
    LocalProcessExecutor, ObjectStoreAllReduce,
)

with tempfile.TemporaryDirectory() as td:
    rendezvous = ObjectStoreAllReduce(td, rank=0, world_size=4)
    executor = LocalProcessExecutor()
    handles = executor.launch_replicas(
        n_replicas=4,
        entrypoint="composer_replication.diloco.serverless.replica_entrypoint",
        entrypoint_args={"rendezvous_uri": td, "rank_env": "REPLICA_RANK"},
    )
    results = executor.collect(handles, timeout=600)
```
`LocalProcessExecutor` (`composer_replication/diloco/serverless/executor.py:160`)
spawns N child processes via `multiprocessing.get_context("spawn")` and
sets `REPLICA_RANK={0..N-1}` in each child's env. It satisfies the
`ServerlessExecutor` Protocol (line 35) — the same Protocol the cloud
adapters implement. So the dev-loop code is byte-identical to the cloud
deploy: only the executor instance changes.
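The "only the executor instance changes" property comes from structural typing: any object with the right methods satisfies the Protocol, with no inheritance involved. The sketch below illustrates the pattern with a stand-in Protocol. The method names `launch_replicas`, `poll`, and `collect` come from this guide, but the signatures, `ExecutorLike`, `InlineExecutor`, and the entrypoint string are hypothetical and do not mirror `executor.py`.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class ExecutorLike(Protocol):
    """Hypothetical stand-in for the repo's ServerlessExecutor Protocol."""

    def launch_replicas(self, n_replicas: int, entrypoint: str,
                        entrypoint_args: dict) -> list: ...

    def poll(self, handles: list) -> list: ...

    def collect(self, handles: list, timeout: float) -> list: ...


class InlineExecutor:
    """Toy executor that fakes replicas in-process; it does not subclass anything."""

    def launch_replicas(self, n_replicas, entrypoint, entrypoint_args):
        return [{"rank": r, "entrypoint": entrypoint, **entrypoint_args}
                for r in range(n_replicas)]

    def poll(self, handles):
        return ["done"] * len(handles)

    def collect(self, handles, timeout):
        return [{"rank": h["rank"], "ok": True} for h in handles]


def run(executor: ExecutorLike, n: int) -> list:
    # Driver code depends only on the Protocol, never on a concrete class.
    handles = executor.launch_replicas(
        n, "my_pkg.replica_entrypoint", {"rendezvous_uri": "/tmp/run"})
    return executor.collect(handles, timeout=600)


results = run(InlineExecutor(), 4)
print(isinstance(InlineExecutor(), ExecutorLike), [r["rank"] for r in results])
```

Swapping in a cloud-backed executor would leave `run` untouched, which is the design point the paragraph above makes.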
### Step 2 — `ObjectStoreAllReduce` as the rendezvous
```python
# Local file:// for tests
rendezvous = ObjectStoreAllReduce("/tmp/diloco-runs/run42/", rank=0, world_size=4)
# Real S3 (after `pip install -e ".[serverless]"`)
rendezvous = ObjectStoreAllReduce(
    "s3://my-bucket/diloco-runs/run42/",
    rank=0, world_size=4,
    timeout_s=1800.0,
)
```
The communication pattern is one S3 `PutObject` plus N `GetObject` calls per
replica, once every H inner steps (matching DiLoCo's sync cadence,
arXiv:2311.08105 §3.2). For a 1B-parameter model in bf16, each upload is
~2 GB (10⁹ parameters × 2 bytes); with a sync every 30 minutes that is ~2 GB
up and up to ~2N GB down per replica per half hour. Request counts are tiny,
so budget for storage and transfer volume rather than request quotas. On the
inner side the framework exposes a `MockManager` that drops into the
`torchft.Manager` slot, so you can validate the rendezvous logic before
plugging in real torchft (verified by `test_serverless_diloco_integration.py`).
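The per-sync traffic is simple to estimate before you pick a bucket or a sync interval. A back-of-the-envelope sketch, assuming the synced payload is one full copy of the parameters at the training dtype and that each of the N `GetObject` calls fetches one such payload; `sync_traffic_gb` is a hypothetical helper, not part of the package.

```python
def sync_traffic_gb(n_params: float, bytes_per_param: int, world_size: int):
    """Return (upload_gb, download_gb) per replica for one outer sync."""
    payload_gb = n_params * bytes_per_param / 1e9
    upload_gb = payload_gb                  # one PutObject
    download_gb = payload_gb * world_size   # N GetObjects, one payload each (assumed)
    return upload_gb, download_gb


up, down = sync_traffic_gb(n_params=1e9, bytes_per_param=2, world_size=4)
print(f"per replica per sync: {up:.1f} GB up, {down:.1f} GB down")
```

Uploads stay constant as you add replicas, while downloads grow linearly with `world_size`, so the download side is the number to watch when scaling out.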
### Step 3 — point at `ModalExecutor` / `HFJobsExecutor`
```python
# Modal (skeleton at composer_replication/diloco/serverless/modal.py)
from composer_replication.diloco.serverless.modal import ModalExecutor
executor = ModalExecutor(image="modal:python3.11", gpu="A100")
# HuggingFace Jobs (skeleton at composer_replication/diloco/serverless/hf_jobs.py)
from composer_replication.diloco.serverless.hf_jobs import HFJobsExecutor
executor = HFJobsExecutor(hardware="a10g-large")
# Same Protocol — same launch_replicas / poll / collect calls as Local
handles = executor.launch_replicas(n_replicas=4, ...)
```
Both adapters check their cloud SDK at `__init__` time (not at module
import) so they don't break the package if you don't have `modal` or
`huggingface_hub` installed. Production maturity: dev-ready for cloud
trial; per ADR-005, full HA-cluster fan-out lives in v0.2+.
---
## 8. Picking an RL backend
Four canonical recipes, each tied to an upstream framework. Source:
`docs/INTEGRATION_ARCHITECTURE.md` Recipes A–D plus
`docs/adrs/ADR-006-rl-frameworks.md`.
### Recipe A — TRL `GRPOTrainer` subclass
`ComposerReplicationTrainer` is a `trl.GRPOTrainer` subclass that
overrides `_compute_loss(model, inputs)` to compose the same 3 channels
on top of TRL's real reward + advantage machinery. Install:
`pip install -e ".[train]"`. Then:
```python
from composer_replication import ComposerReplicationTrainer
trainer = ComposerReplicationTrainer(model=..., reward_funcs=[...], ...)
trainer.train()
```
**When to use it:** This is the v0.0/v0.1 recommended path. You want
real GRPO with rewards, you have HF model + dataset + (mostly) standard
GRPO infrastructure, and you don't need >100B-param scale. TRL is
mature, the trainer is a small subclass, and the same `compose_loss`
math runs in both the verification harness and in production with no
re-coding.
→ See `docs/INTEGRATION_ARCHITECTURE.md` § "Recipe A: TRL `GRPOTrainer`
subclass" (line 205).
### Recipe B — VeRL custom `adv_estimator` + DataProto extension
VeRL replaces TRL's reward+advantage machinery with a Ray-driven actor
graph that's specifically optimized for distributed RL training of
large LMs. Composition with the framework: extend `DataProto` with the
hint-conditioned columns + DPO pair fields, register a custom
`adv_estimator` that calls the same `compose_loss` body.
**When to use it:** You're past 7B-param, you have multi-host setup
(Ray cluster), and TRL's single-process trainer is the bottleneck. VeRL
is the recommended scale path for v0.2+. Trade-off: the extension surface
is larger and the docs are sparser than TRL's.
→ See `docs/INTEGRATION_ARCHITECTURE.md` § "Recipe B: VeRL custom
`adv_estimator`" (line 289).
### Recipe C — PRIME-RL with DPPO-clip details
`composer_replication/recipes/prime_rl/composer_loss.py` ships a
`loss_fn(inputs, *, alpha_sdpo=0.0, beta_dpo=0.0, dppo_mask_high=0.2,
dppo_mask_low=0.2, adv_tau=1.0, kl_tau=1e-3)` adapter that maps
PRIME-RL's `LossInputs` struct (1-D per-sample tensors:
`trainer_logprobs`, `inference_logprobs`, `teacher_logprobs`,
`advantages`, `loss_mask`) into our 3-channel composition.
The DPPO+KL bit is what makes PRIME-RL distinctive — and we mirror
PRIME-RL's upstream `default_loss_fn` exactly (verified against
`prime_rl/trainer/rl/loss.py` lines 116-165):
```python
import torch

# Inputs are the 1-D per-sample tensors from `LossInputs`; `loss_mask` is boolean.
log_ir = trainer_logprobs - inference_logprobs
ir = torch.exp(log_ir)  # importance ratio
probs_diff = torch.exp(trainer_logprobs) - torch.exp(inference_logprobs)
invalid_high = probs_diff > dppo_mask_high   # for positive-advantage tokens
invalid_low = probs_diff < -dppo_mask_low    # for negative-advantage tokens
invalid = torch.where(advantages > 0, invalid_high, invalid_low)
keep = loss_mask & ~invalid
pg_loss = keep * (adv_tau * advantages) * ir
kl_loss = loss_mask * log_ir**2
loss = (-pg_loss + kl_tau * kl_loss).sum()
```
Three things to remember: (1) the mask gate is on **probability-space**
`exp(trainer_lp) - exp(inference_lp)`, not on the log-ratio; (2) the
policy-gradient term is multiplied by the importance ratio
`exp(trainer_lp - inference_lp)`, not by `trainer_lp` directly (proper
IS-corrected gradient, not REINFORCE); (3) the mask is **conditioned on
the sign of the advantage** — positive-advantage tokens are dropped on
the upper bound, negative-advantage tokens on the lower. Defaults
`dppo_mask_high=dppo_mask_low=0.2` and `adv_tau=1.0, kl_tau=1e-3`
match PRIME-RL's `DefaultLossConfig` (all fields `Field(..., ge=0)`).
Requesting SDPO (channel 2) raises `NotImplementedError` in v0 because PRIME-RL
exposes log-probs, not full-vocab logits; channel 3 (trace-replay DPO)
emits a warning if `beta_dpo != 0`.
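A single-token walk-through makes the sign-conditioned mask concrete. The sketch below is a scalar, plain-Python restatement of the loss snippet above, not the adapter in `composer_loss.py`; `dppo_token_terms` and its `in_loss` flag (standing in for `loss_mask`) are hypothetical names.

```python
import math


def dppo_token_terms(trainer_lp, inference_lp, advantage, in_loss=True,
                     mask_high=0.2, mask_low=0.2, adv_tau=1.0, kl_tau=1e-3):
    """Return (keep, loss) for one token under the masked PG + KL loss."""
    log_ir = trainer_lp - inference_lp
    ir = math.exp(log_ir)  # importance ratio
    probs_diff = math.exp(trainer_lp) - math.exp(inference_lp)
    # The bound that applies depends on the sign of the advantage.
    invalid = probs_diff > mask_high if advantage > 0 else probs_diff < -mask_low
    keep = in_loss and not invalid
    pg = adv_tau * advantage * ir if keep else 0.0
    kl = log_ir ** 2 if in_loss else 0.0
    return keep, -pg + kl_tau * kl


# Probability rose 0.5 -> 0.8 (diff 0.3 > 0.2): masked for a positive advantage...
keep_a, loss_a = dppo_token_terms(math.log(0.8), math.log(0.5), advantage=1.0)
# ...but kept for a negative advantage, where only the lower bound applies.
keep_b, loss_b = dppo_token_terms(math.log(0.8), math.log(0.5), advantage=-1.0)
# Probability rose 0.5 -> 0.6 (diff 0.1): inside the band, kept.
keep_c, loss_c = dppo_token_terms(math.log(0.6), math.log(0.5), advantage=1.0)

print(keep_a, round(loss_a, 6))
print(keep_b, round(loss_b, 4))
print(keep_c, round(loss_c, 4))
```

Note that the masked token still contributes the `kl_tau`-weighted squared log-ratio: the mask removes only the policy-gradient term.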
**When to use it:** You're already in the PRIME-Intellect / decentralized
training universe, you want INTELLECT-style scaling on a long-horizon
training run, and DPPO masking is part of your existing reward+advantage
recipe. Install: `pip install -e ".[prime-rl]"`.
→ See `composer_replication/recipes/prime_rl/prime_rl_recipe.md` and
`docs/INTEGRATION_ARCHITECTURE.md` § "Recipe C: TorchForge + Monarch"
(line 356).
### Recipe C+D — Monarch as actor mesh
Monarch (the actor framework underpinning TorchForge) hosts the
trainer/generator/manager actors in a topology-aware mesh. The framework
ships *skeleton* actor definitions at
`composer_replication/recipes/monarch/actors.py` (TrainerActor,
GeneratorActor) and a layout doc at `monarch_actor_layout.md`. v0
intentionally *fails fast* if you try to instantiate them
(`raise NotImplementedError("v0 skeleton; deferred to v0.2 per ADR-006")`)
because the upstream Monarch API is still moving.
**When to use it:** Reference-pattern reading only in v0. Decision point
is v0.2 once the upstream actor API stabilizes. Treat the skeleton as
shape-of-the-final-answer documentation, not as a production target.
Install: `pip install -e ".[prime-rl,monarch]"` for the full surface.
→ See `composer_replication/recipes/monarch/monarch_actor_layout.md`
and `docs/adrs/ADR-006-rl-frameworks.md`.
---
## Common pitfalls + what tests catch them
The framework's 115-test suite (post-Wave-15) is structured so each pitfall has a
specific test-file home. If you hit one of these in production, the
corresponding test is your fastest reproducer.
| Pitfall | Symptom | Test file (catches it) |
|-----------------------------------------------------------------------------------------------|------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------|
| Forgetting `taid_t` when `sdpo_wrapper="taid"` | `ValueError` at first step | `composer_replication/tests/test_compose_loss_integration.py` (kwarg validation) |
| TAID `t=0` endpoint leaks teacher signal | Teacher swap changes the loss when `t` should be 0 | `test_taid_loss_t_zero_target_matches_detached_student` in `composer_replication/distillation/tests/test_distillation_losses.py` |
| TAID `t=1` endpoint differs from plain forward-KL distillation | Mismatch vs hand-computed `-Σ p_teacher · log_q` | `test_taid_loss_t_one_is_pure_forward_kl` in `composer_replication/distillation/tests/test_distillation_losses.py` |
| SimPO loss not differentiable through the log-sigmoid path | `chosen.grad is None` after backward | `test_simpo_loss_differentiable` in `composer_replication/distillation/tests/test_distillation_losses.py:50` |
| SimPO shape-mismatch slips through silently | Broadcasting bug, NaN downstream | `test_simpo_loss_shape_mismatch_raises` in `composer_replication/distillation/tests/test_distillation_losses.py:61` |
| Entropy-OPD failing to zero out when distributions match | Loss > 0 when student≡teacher | `test_entropy_aware_opd_zero_when_distributions_match` in `composer_replication/distillation/tests/test_distillation_losses.py:217` |
| Entropy of one-hot ≠ 0 / entropy of uniform ≠ log(V) | Wrong gating weights w(t) | `test_teacher_entropy_one_hot_is_zero` and `test_teacher_entropy_uniform_is_log_v` in `composer_replication/distillation/tests/test_distillation_losses.py:175,183` |
| `DJNormalizer` records missing the chat-messages shape | Filters silently no-op or crash | `test_dpo_pair_to_dj_record_shape` in `composer_replication/replaysim/tests/test_replaysim.py:44` |
| `DJNormalizer` round-trip drops `state_messages` / metadata | Lost provenance | `test_dj_record_to_normalized_roundtrip` and `test_dj_record_to_normalized_preserves_state_messages` in `composer_replication/replaysim/tests/test_replaysim.py` |
| `ObjectStoreAllReduce` accepts an out-of-bounds rank | Silent corruption of the all-reduce average | `test_object_store_allreduce_init_validates_rank` in `composer_replication/diloco/serverless/tests/test_serverless_local.py:31` |
| `ObjectStoreAllReduce(world_size=1)` doesn't passthrough cleanly | False all-reduce on single replica | `test_object_store_allreduce_world_size_1_passthrough` in `composer_replication/diloco/serverless/tests/test_serverless_local.py:46` |
| `LocalProcessExecutor` doesn't propagate child failures to `collect()` | Silent test pass when a replica crashed | `test_serverless_diloco_integration.py` in `composer_replication/diloco/serverless/tests/` |
| PRIME-RL adapter accidentally uses `(B, T)` shape instead of per-sample `(seq,)` | Shape mismatch / wrong reduction | `composer_replication/recipes/prime_rl/tests/test_composer_loss.py` (10 tests covering shape and DPPO mask edges) |
| Channel 2/3 fails to auto-disable when its inputs are absent | Crash on missing key, not graceful skip | `composer_replication/tests/test_compose_loss_integration.py` (`(a) defaults reproduce existing compose_loss output bit-exact`) |
Run the full suite with `pytest` from the repo root.