Add GRPO RL training script with execution reward function
Browse files- train_grpo.py +139 -0
train_grpo.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Reference GRPO training script for agentic coding RL.
|
| 4 |
+
Uses execution-verified pass_rate as the reward signal.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python train_grpo.py \
|
| 8 |
+
--model ./nexus-coder-sft \
|
| 9 |
+
--output_dir ./nexus-coder-rl
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import json
|
| 14 |
+
import subprocess
|
| 15 |
+
import tempfile
|
| 16 |
+
from datasets import load_dataset
|
| 17 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 18 |
+
from trl import GRPOTrainer, GRPOConfig
|
| 19 |
+
|
| 20 |
+
# ---------------------------------------------------------------------------
|
| 21 |
+
# Execution reward function (simplified — adapt to your sandbox)
|
| 22 |
+
# ---------------------------------------------------------------------------
|
| 23 |
+
|
| 24 |
+
def execution_reward_fn(completions: list, **kwargs) -> list:
    """Execution-verified reward function for GRPO.

    Extracts the last ```bash fenced block from each completion and executes
    it, rewarding 1.0 on a zero exit status and 0.0 otherwise (no bash block,
    non-zero exit, timeout, or any execution error).

    Args:
        completions: Generated completions. Each item is either a plain string
            or a conversational list of {"role", "content"} message dicts (TRL
            passes either form depending on the dataset format); for the
            conversational form the final message's content is scored.
        **kwargs: Extra dataset columns forwarded by GRPOTrainer; unused.

    Returns:
        A list of float rewards, one per completion.

    SECURITY: this runs model-generated shell commands directly on the host
    (shell=True). Only use inside a throwaway container/sandbox; in a real
    setup, replay the commands in a Docker sandbox and return the pass_rate.
    """
    rewards = []
    for completion in completions:
        # Conversational format: score the content of the final message.
        if isinstance(completion, list):
            completion = completion[-1].get("content", "") if completion else ""
        try:
            # Take the text after the *last* ```bash fence, up to the next fence.
            if "```bash" in completion:
                cmd = completion.split("```bash")[-1].split("```")[0].strip()
                result = subprocess.run(
                    cmd,
                    shell=True,
                    capture_output=True,
                    timeout=30,
                    cwd=tempfile.gettempdir(),
                )
                reward = 1.0 if result.returncode == 0 else 0.0
            else:
                # No executable block found — nothing to verify.
                reward = 0.0
        except Exception:
            # Timeouts, OS errors, and malformed commands all count as failure.
            reward = 0.0
        rewards.append(reward)
    return rewards
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
# Dataset prep
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
|
| 50 |
+
def load_rl_dataset(split: str = "train", min_prompt_chars: int = 50):
    """Load the Nemotron RL SWE pivot dataset and normalize it to
    {"prompt", "completion"} pairs for GRPO.

    Each raw example stores the request under "responses_create_params"
    (an OpenAI responses-style payload whose "input" is a list of messages)
    and the reference trajectory under "ref_message". The first input
    message's content becomes the prompt; the reference reasoning becomes
    the completion. Examples whose prompt ends up no longer than
    *min_prompt_chars* (i.e. normalization found nothing useful) are dropped.

    Args:
        split: Dataset split to load (default "train").
        min_prompt_chars: Keep only examples with prompts strictly longer
            than this many characters (default 50, matching the original
            hard-coded threshold).

    Returns:
        A datasets.Dataset with "prompt" and "completion" string columns.
    """
    ds = load_dataset("nvidia/Nemotron-RL-Agentic-SWE-Pivot-v1", split=split)

    def normalize(example):
        # `or {}` also guards against keys present with an explicit None
        # value, which .get(key, {}) would pass through and crash on.
        params = example.get("responses_create_params") or {}
        inp = params.get("input") or []
        if len(inp) > 0 and isinstance(inp[0], dict):
            system = inp[0].get("content", "")
            ref = example.get("ref_message") or {}
            reasoning = ref.get("reasoning_content", "") if isinstance(ref, dict) else ""
            return {
                "prompt": system,
                "completion": reasoning,
            }
        # Malformed example: emit empty fields so the filter below drops it.
        return {"prompt": "", "completion": ""}

    ds = ds.map(normalize, remove_columns=ds.column_names)
    ds = ds.filter(lambda x: len(x["prompt"]) > min_prompt_chars)
    return ds
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ---------------------------------------------------------------------------
|
| 71 |
+
# Main
|
| 72 |
+
# ---------------------------------------------------------------------------
|
| 73 |
+
|
| 74 |
+
def main():
    """CLI entry point: GRPO-train an SFT checkpoint with the execution
    reward on the normalized RL dataset, then save (and optionally push)
    the resulting model."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, help="Path to SFT checkpoint")
    parser.add_argument("--output_dir", default="./nexus-coder-rl")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--grad_accum", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-6)
    parser.add_argument("--max_prompt_length", type=int, default=4096)
    parser.add_argument("--max_completion_length", type=int, default=12288)
    parser.add_argument("--num_generations", type=int, default=8)
    parser.add_argument("--hub_model_id", default=None)
    args = parser.parse_args()

    print("[1/4] Loading SFT model and tokenizer...")
    policy = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype="bfloat16",
        device_map="auto",
        trust_remote_code=True,
    )
    tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    # Some checkpoints ship without a pad token; fall back to EOS so batching works.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    print("[2/4] Loading RL dataset...")
    train_ds = load_rl_dataset()
    print(f" RL dataset size: {len(train_ds)} examples")

    print("[3/4] Configuring GRPO trainer...")
    config = GRPOConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        num_generations=args.num_generations,
        temperature=0.7,
        logging_strategy="steps",
        logging_steps=5,
        logging_first_step=True,
        bf16=True,
        gradient_checkpointing=True,
        disable_tqdm=True,
        # Push only when an explicit hub repo id was supplied.
        push_to_hub=args.hub_model_id is not None,
        hub_model_id=args.hub_model_id,
    )

    trainer = GRPOTrainer(
        model=policy,
        reward_funcs=[execution_reward_fn],
        args=config,
        train_dataset=train_ds,
        processing_class=tok,
    )

    print("[4/4] Starting GRPO training...")
    trainer.train()
    trainer.save_model(args.output_dir)
    print(f"Done. Model saved to {args.output_dir}")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# Run training only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|