# nexus-coder-alpha / train_grpo.py
# olanigan — Add GRPO RL training script with execution reward function
# commit 33bac25 (verified)
#!/usr/bin/env python3
"""
Reference GRPO training script for agentic coding RL.
Uses execution-verified pass_rate as the reward signal.
Usage:
python train_grpo.py \
--model ./nexus-coder-sft \
--output_dir ./nexus-coder-rl
"""
import argparse
import json
import subprocess
import tempfile
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOTrainer, GRPOConfig
# ---------------------------------------------------------------------------
# Execution reward function (simplified — adapt to your sandbox)
# ---------------------------------------------------------------------------
def execution_reward_fn(completions: list, **kwargs) -> list:
    """
    Reward function for GRPO: 1.0 if the extracted bash block exits 0, else 0.0.

    Expects completions that contain ```bash ... ``` fenced blocks.
    In a real setup, replay commands in a Docker sandbox and return pass_rate.

    Args:
        completions: Model-generated strings, one per sampled completion.
        **kwargs: Extra dataset columns forwarded by GRPOTrainer (unused).

    Returns:
        A list of float rewards (0.0 or 1.0), one per completion.
    """
    rewards = []
    for completion in completions:
        reward = 0.0
        # Use the LAST ```bash block so a final "answer" command wins over
        # earlier scratch work in the completion.
        if "```bash" in completion:
            cmd = completion.split("```bash")[-1].split("```")[0].strip()
            # Guard against an empty fenced block: `sh -c ""` exits 0, which
            # would otherwise hand out reward for emitting no command at all.
            if cmd:
                try:
                    # SECURITY: shell=True executes model-generated text on the
                    # host. cwd is pinned to a temp dir, but any real training
                    # run must replay this inside a sandbox (Docker/gVisor).
                    result = subprocess.run(
                        cmd,
                        shell=True,
                        capture_output=True,
                        timeout=30,
                        cwd=tempfile.gettempdir(),
                    )
                    reward = 1.0 if result.returncode == 0 else 0.0
                except (subprocess.TimeoutExpired, OSError, ValueError):
                    # Hung or unspawnable commands count as failed executions.
                    reward = 0.0
        rewards.append(reward)
    return rewards
# ---------------------------------------------------------------------------
# Dataset prep
# ---------------------------------------------------------------------------
def load_rl_dataset():
    """Load the Nemotron RL SWE pivot dataset as prompt/completion pairs.

    The system prompt is taken from the first entry of
    ``responses_create_params.input``; the completion is the reference
    reasoning trace from ``ref_message``. Rows without a usable prompt
    (missing, malformed, or <= 50 chars) are dropped.
    """
    raw = load_dataset("nvidia/Nemotron-RL-Agentic-SWE-Pivot-v1", split="train")

    def to_pair(example):
        # Guard clauses: anything without a dict-shaped first input message
        # normalizes to an empty pair and is filtered out below.
        messages = example.get("responses_create_params", {}).get("input", [])
        if not messages or not isinstance(messages[0], dict):
            return {"prompt": "", "completion": ""}
        ref = example.get("ref_message", {})
        reasoning = ref.get("reasoning_content", "") if isinstance(ref, dict) else ""
        return {
            "prompt": messages[0].get("content", ""),
            "completion": reasoning,
        }

    normalized = raw.map(to_pair, remove_columns=raw.column_names)
    return normalized.filter(lambda row: len(row["prompt"]) > 50)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _parse_args():
    """Parse command-line options for the GRPO run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, help="Path to SFT checkpoint")
    parser.add_argument("--output_dir", default="./nexus-coder-rl")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--grad_accum", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-6)
    parser.add_argument("--max_prompt_length", type=int, default=4096)
    parser.add_argument("--max_completion_length", type=int, default=12288)
    parser.add_argument("--num_generations", type=int, default=8)
    parser.add_argument("--hub_model_id", default=None)
    return parser.parse_args()


def main():
    """Run GRPO RL training on top of an SFT checkpoint.

    Pipeline: load the SFT model/tokenizer, build the RL dataset, configure
    GRPO with the execution reward function, train, and save the result.
    """
    args = _parse_args()

    print("[1/4] Loading SFT model and tokenizer...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype="bfloat16",
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    # Decoder-only checkpoints often ship without a pad token; reuse EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("[2/4] Loading RL dataset...")
    dataset = load_rl_dataset()
    print(f" RL dataset size: {len(dataset)} examples")

    print("[3/4] Configuring GRPO trainer...")
    config = GRPOConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        num_generations=args.num_generations,
        temperature=0.7,
        logging_strategy="steps",
        logging_steps=5,
        logging_first_step=True,
        bf16=True,
        gradient_checkpointing=True,
        disable_tqdm=True,
        # Only push when an explicit hub repo was requested.
        push_to_hub=args.hub_model_id is not None,
        hub_model_id=args.hub_model_id,
    )
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[execution_reward_fn],
        args=config,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    print("[4/4] Starting GRPO training...")
    trainer.train()
    trainer.save_model(args.output_dir)
    print(f"Done. Model saved to {args.output_dir}")


if __name__ == "__main__":
    main()