---
title: GRPO Training Collapse Analysis
description: >-
  Root-cause analysis of GRPO training collapse on Qwen3-1.7B caused by extra
  kwargs in tool calls and advantage collapse
doc_type: exploration
---

# GRPO Training Collapse Analysis

## What happened

After SFT warmup, GRPO training on Qwen3-1.7B collapsed within the first 30 steps. The model degenerated into passing extra null arguments to every tool call (`"sql": null`, `"table_name": "..."`, `"value": null`), triggering `unexpected keyword argument` errors on every rollout. It never recovered across 351 steps (~8 hours on an L4 GPU).

## Timeline

| Step | Reward | What the model does |
|---|---|---|
| 10 | -1.25 | First call has extra args, gets an error, loops with `Episode is over` |
| 20 | 0.01 | Occasionally a correct `describe`, but passes wrong args to `answer` |
| 30 | 0.00 | Stuck: `describe(sql=null, table_name="concert")` infinite loop |
| 40-351 | 0.00 | Complete collapse: every rollout is an identical error loop |

## Why it collapsed

### 1. SFT taught wrong argument patterns

The SFT examples show `describe(table_name=...)` correctly, but the base Qwen3-1.7B model has a strong pretraining prior toward including every available parameter name in each call. The 353-turn SFT warmup (2 epochs, batch=2) wasn't enough to override this for all 4 tools.

### 2. Extra kwargs cause hard failures, not soft degradation

When the model passes `describe(sql=null, table_name="flights")`, TRL dispatches `SQLEnvTRL.describe(sql=None, table_name="flights")`, which raises `TypeError: unexpected keyword argument 'sql'`. This is a hard wall: the model gets zero useful information back, just an error string it can't learn from.

### 3. GRPO advantage collapse

With 6 generations per question:

- All 6 rollouts pass the same extra args → all get reward 0.0
- Advantage = 0.0 for every sample → zero gradient signal
- The model has no way to discover that dropping the extra args would work
- Loss oscillates near 0 throughout training
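The zero-signal failure can be reproduced numerically. Below is a minimal sketch of a group-relative advantage (reward minus group mean, scaled by group std) in the spirit of GRPO; TRL's exact implementation differs in details:

```python
def group_advantages(rewards, eps=1e-8):
    """Group-relative advantage: (r - mean) / (std + eps) per rollout."""
    mean = sum(rewards) / len(rewards)
    std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5
    return [(r - mean) / (std + eps) for r in rewards]

# All 6 rollouts fail identically -> reward 0.0 each -> every advantage is 0.0.
print(group_advantages([0.0] * 6))          # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

# A single successful rollout restores a nonzero learning signal.
print(group_advantages([0.0] * 5 + [1.0]))  # last entry is large and positive
```

With every reward identical, the normalized advantage is exactly zero for each sample, so the policy gradient vanishes no matter how bad the rollouts are.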

### 4. No recovery mechanism

Once the model enters the error loop:

- Error messages say `unexpected keyword argument 'sql'` but don't say "try calling with only `table_name`"
- The model retries the same call pattern endlessly
- The post-episode penalty accumulates negative reward (-1.25 at step 10) but doesn't help, because ALL rollouts are equally bad
- No positive examples exist in any rollout group to provide an advantage signal

## The core problem: kwargs rejection vs. kwargs tolerance

The TRL adapter methods have strict signatures:

```python
def describe(self, table_name: str) -> str: ...
def query(self, sql: str) -> str: ...
def answer(self, value: str) -> str: ...
```

When the model generates `{"table_name": "flights", "sql": null}`, Python raises `TypeError` before the method body executes. The model never gets a schema response, so it has no path to success.
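That hard wall is plain Python calling semantics, independent of TRL. A minimal sketch with a stand-in class (not the real `SQLEnvTRL` adapter):

```python
class StrictTools:
    def describe(self, table_name: str) -> str:
        # Never reached when the call carries unexpected kwargs.
        return f"schema for {table_name}"

tools = StrictTools()
try:
    # Model-generated call: {"table_name": "flights", "sql": null}
    tools.describe(table_name="flights", sql=None)
except TypeError as e:
    print(e)  # ...got an unexpected keyword argument 'sql'
```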

## Fix: Accept and ignore extra kwargs

The simplest fix is to make the tool methods tolerant of extra arguments:

```python
def describe(self, table_name: str, **kwargs) -> str: ...
def query(self, sql: str, **kwargs) -> str: ...
def answer(self, value: str, **kwargs) -> str: ...
def sample(self, table_name: str, **kwargs) -> str: ...
```

This means `describe(sql=null, table_name="flights")` would work: it would ignore `sql` and return the schema. The model gets useful feedback, can write SQL, and has a path to positive reward. GRPO then has signal to learn that the extra args are unnecessary.
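With the tolerant signatures above, the same malformed call succeeds (again a stand-in class for illustration, not the actual adapter):

```python
class TolerantTools:
    def describe(self, table_name: str, **kwargs) -> str:
        # Extra model-generated args (e.g. sql=None) land in kwargs and are ignored.
        return f"schema for {table_name}"

tools = TolerantTools()
print(tools.describe(table_name="flights", sql=None))  # schema for flights
```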

**Why this is the right approach:**

- Small models (1.7B) lack the capacity to perfectly learn function signatures from tool definitions alone
- The tool definitions in the `<tools>` XML clearly state which params are required; the model will converge toward correct signatures over time via the reward signal
- Strict rejection creates an unrecoverable dead end; tolerance creates a learning gradient
- This matches how real APIs work: most accept and ignore unexpected fields

## Other contributing factors

### SFT quality issues

- SFT was only 100 questions × ~3.5 turns ≈ 347 examples
- Only 2 epochs at batch=2 (347 steps total)
- The model learned the tool-call format but not strict argument isolation
- Needed: more SFT data, or more epochs on the existing data

### Missing KL penalty

- No KL divergence penalty against the SFT reference model
- GRPO updated the policy freely, drifting away from the SFT distribution
- A KL penalty (beta = 0.01-0.05) would have anchored the model near the working SFT baseline

### Learning rate may be too high

- The default TRL learning rate (5e-7 or 1e-6) may be too aggressive for a 1.7B model
- A lower LR (1e-7) would make smaller updates, reducing drift risk

## Recommended fixes (priority order)

### 1. Add `**kwargs` to all tool methods (critical)

Prevents the hard wall. The model can still learn the correct signatures from the reward signal.

### 2. Increase SFT warmup

- 4 epochs instead of 2
- Or increase SFT data from 100 to 200 questions
- Verify post-SFT that the model generates correct single-arg calls

### 3. Add KL penalty

```python
GRPOConfig(
    ...,
    beta=0.04,  # KL penalty against SFT reference
)
```

Prevents policy from drifting too far from the working SFT baseline.

### 4. Lower GRPO learning rate

Drop from the default to 1e-7 or 5e-8.
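Fixes 3 and 4 amount to a small config change. A hedged sketch of the relevant `GRPOConfig` fields (field names as in recent TRL versions; `output_dir` is illustrative — verify against the installed TRL):

```python
from trl import GRPOConfig

config = GRPOConfig(
    output_dir="grpo-sql-env",  # illustrative path
    beta=0.04,                  # KL penalty against the SFT reference model
    learning_rate=1e-7,         # lowered from the TRL default
    num_generations=6,          # rollouts per question, as in the failed run
)
```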

## Verification checklist

Before running GRPO again:

- [ ] Post-SFT format check shows `describe(table_name="X")` with NO extra args
- [ ] Tool methods accept `**kwargs` so extra args don't crash
- [ ] First 10 GRPO steps show at least some reward > 0
- [ ] Reward doesn't flatline at 0.0 by step 30
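The first checklist item can be automated with a small parser over generated tool calls. A minimal sketch, assuming calls are serialized as JSON objects with `name` and `arguments` fields (the exact wire format in this repo may differ):

```python
import json

# Allowed parameters per tool, mirroring the strict adapter signatures.
ALLOWED = {
    "describe": {"table_name"},
    "query": {"sql"},
    "answer": {"value"},
    "sample": {"table_name"},
}

def extra_args(call_json: str) -> set:
    """Return the set of unexpected argument names in one tool call."""
    call = json.loads(call_json)
    return set(call["arguments"]) - ALLOWED.get(call["name"], set())

print(extra_args('{"name": "describe", "arguments": {"table_name": "flights"}}'))
# -> set()
print(extra_args('{"name": "describe", "arguments": {"table_name": "flights", "sql": null}}'))
# -> {'sql'}
```

Run this over a batch of post-SFT generations; any nonempty result means the warmup has not yet suppressed the extra-kwargs habit.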