title: GRPO Training Collapse Analysis
description: >-
Root-cause analysis of GRPO training collapse on Qwen3-1.7B caused by extra
kwargs in tool calls and advantage collapse
doc_type: exploration
GRPO Training Collapse Analysis
What happened
After SFT warmup, GRPO training on Qwen3-1.7B collapsed within the first 30 steps. The model degenerated into padding every tool call with extra null arguments ("sql": null, "table_name": "...", "value": null), triggering unexpected-keyword-argument errors on every rollout. It never recovered across 351 steps (~8 hours on an L4 GPU).
Timeline
| Step | Reward | What the model does |
|---|---|---|
| 10 | -1.25 | First call has extra args, gets an error, then loops emitting "Episode is over" |
| 20 | 0.01 | Occasionally correct describe, but passes wrong args to answer |
| 30 | 0.00 | Stuck: describe(sql=null, table_name="concert") infinite loop |
| 40-351 | 0.00 | Complete collapse: every rollout is identical error loops |
Why it collapsed
1. SFT taught wrong argument patterns
The SFT examples show describe(table_name=...) correctly, but the base Qwen3-1.7B model has a strong prior from pretraining to include all available parameter names in every call. The 353-turn SFT warmup (2 epochs, batch=2) wasn't enough to override this for all 4 tools.
2. Extra kwargs cause hard failures, not soft degradation
When the model passes describe(sql=null, table_name="flights"), TRL dispatches SQLEnvTRL.describe(sql=None, table_name="flights") which raises TypeError: unexpected keyword argument 'sql'. This is a hard wall — the model gets zero useful information back, just an error string it can't learn from.
3. GRPO advantage collapse
With 6 generations per question:
- All 6 rollouts pass the same extra args → all get reward 0.0
- Advantage = 0.0 for every sample → zero gradient signal
- The model has no way to discover that dropping the extra args would work
- Loss oscillates near 0 throughout training
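The zero-signal failure falls directly out of the group-relative advantage normalization GRPO uses. A minimal sketch (TRL's internals differ in details such as the epsilon and optional std scaling):

```python
from statistics import mean, pstdev

def group_advantages(rewards, eps=1e-4):
    """GRPO-style advantage: normalize each reward against its rollout group."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# Healthy group: mixed outcomes produce a nonzero gradient signal.
print(group_advantages([1.0, 0.0, 0.0, 1.0, 0.0, 0.0]))

# Collapsed group: all 6 rollouts fail identically, so every advantage is 0.
print(group_advantages([0.0] * 6))  # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

When every rollout in the group earns the same reward, each term is (0 - 0) / eps = 0, so the policy gradient vanishes no matter how bad the rollouts are.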
4. No recovery mechanism
Once the model enters the error loop:
- Error messages say "unexpected keyword argument 'sql'" but don't say "try calling with only table_name"
- The model retries the same call pattern endlessly
- Post-episode penalty accumulates negative reward (-1.25 at step 10) but doesn't help because ALL rollouts are equally bad
- No positive examples exist in any rollout group to provide advantage signal
The core problem: kwargs rejection vs. kwargs tolerance
The TRL adapter methods have strict signatures:
```python
def describe(self, table_name: str) -> str:
def query(self, sql: str) -> str:
def answer(self, value: str) -> str:
def sample(self, table_name: str) -> str:
```
When the model generates {"table_name": "flights", "sql": null}, Python raises TypeError before the method body executes. The model never gets a schema response, so it has no path to success.
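A minimal reproduction of that hard wall, with SQLEnvTRL stubbed down to one method (the stub body is illustrative, not the real adapter):

```python
import json

class SQLEnvTRL:
    def describe(self, table_name: str) -> str:
        return f"schema for {table_name}"

env = SQLEnvTRL()
model_args = json.loads('{"table_name": "flights", "sql": null}')

try:
    env.describe(**model_args)  # raises before the method body runs
except TypeError as e:
    print(e)  # "... got an unexpected keyword argument 'sql'"
```

The model's generated JSON is unpacked as keyword arguments, so the mismatch is caught by Python's call machinery itself: no schema string is ever produced, only the error text.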
Fix: Accept and ignore extra kwargs
The simplest fix is to make the tool methods tolerant of extra arguments:
```python
def describe(self, table_name: str, **kwargs) -> str:
def query(self, sql: str, **kwargs) -> str:
def answer(self, value: str, **kwargs) -> str:
def sample(self, table_name: str, **kwargs) -> str:
```
This means describe(sql=null, table_name="flights") would work — it would ignore sql and return the schema. The model gets useful feedback, can write SQL, and has a path to positive reward. GRPO then has signal to learn that the extra args are unnecessary.
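A sketch of the tolerant adapter, again with an illustrative stub body; the extra keys land in **kwargs and the call succeeds:

```python
class SQLEnvTRL:
    def describe(self, table_name: str, **kwargs) -> str:
        # Extra args like sql=None are absorbed by **kwargs and ignored.
        if kwargs:
            print(f"ignoring extra args: {sorted(kwargs)}")
        return f"schema for {table_name}"

env = SQLEnvTRL()
print(env.describe(sql=None, table_name="flights"))
# ignoring extra args: ['sql']
# schema for flights
```

Logging the ignored keys is optional but useful: it leaves a trace of how often the model still pads its calls as training progresses.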
Why this is the right approach:
- Small models (1.7B) lack the capacity to perfectly learn function signatures from tool definitions alone
- The tool definitions in the <tools> XML clearly state which params are required — the model will converge toward correct signatures over time via reward signal
- Strict rejection creates an unrecoverable dead end; tolerance creates a learning gradient
- This matches how real APIs work — most accept and ignore unexpected fields
Other contributing factors
SFT quality issues
- SFT was only 100 questions x ~3.5 turns = 347 examples
- Only 2 epochs at batch=2 (total 347 steps)
- The model learned tool-call format but not strict argument isolation
- Need: more SFT data or more epochs on existing data
Missing KL penalty
- No KL divergence penalty against the SFT reference model
- GRPO updated the policy freely, drifting away from the SFT distribution
- A KL penalty (beta=0.01-0.05) would have anchored the model near the working SFT baseline
Learning rate may be too high
- Default TRL learning rate (5e-7 or 1e-6) may be too aggressive for 1.7B
- Lower LR (1e-7) would make smaller updates, reducing drift risk
Recommended fixes (priority order)
1. Add **kwargs to all tool methods (critical)
Prevents the hard wall. Model can still learn correct signatures from reward signal.
2. Increase SFT warmup
- 4 epochs instead of 2
- Or increase SFT data from 100 to 200 questions
- Verify post-SFT that the model generates correct single-arg calls
3. Add KL penalty
```python
GRPOConfig(
    ...,
    beta=0.04,  # KL penalty against SFT reference
)
```
Prevents policy from drifting too far from the working SFT baseline.
4. Lower GRPO learning rate
From default to 1e-7 or 5e-8.
Verification checklist
Before running GRPO again:
- Post-SFT format check shows describe(table_name="X") with NO extra args
- Tool methods accept **kwargs so extra args don't crash
- First 10 GRPO steps show at least some reward > 0
- Reward doesn't flatline at 0.0 by step 30
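The format check can be a small script that validates each generated tool call's argument set. The expected-argument table and the check_call helper below are assumptions sketched from the adapter signatures, not part of TRL:

```python
import json

# Expected parameters per tool (assumed from the adapter signatures above).
EXPECTED = {
    "describe": {"table_name"},
    "query": {"sql"},
    "answer": {"value"},
    "sample": {"table_name"},
}

def check_call(name: str, args_json: str) -> list[str]:
    """Return a list of problems with one parsed tool call (empty = clean)."""
    args = set(json.loads(args_json))
    problems = []
    extra = args - EXPECTED[name]
    missing = EXPECTED[name] - args
    if extra:
        problems.append(f"{name}: extra args {sorted(extra)}")
    if missing:
        problems.append(f"{name}: missing args {sorted(missing)}")
    return problems

# A collapsed-style call fails the check; a clean call passes.
print(check_call("describe", '{"sql": null, "table_name": "concert"}'))
print(check_call("describe", '{"table_name": "concert"}'))  # []
```

Run this over a batch of post-SFT sampled completions before starting GRPO; any nonempty result means the SFT checkpoint still pads its calls and will hit the advantage-collapse failure mode again.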