"""
Reference GRPO training script for agentic coding RL.
Uses execution-verified pass_rate as the reward signal.

Usage:
    python train_grpo.py \
        --model ./nexus-coder-sft \
        --output_dir ./nexus-coder-rl
"""
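
# A multi-GPU launch would typically look something like the sketch below.
# This is an illustration only: the launcher choice and process count are
# assumptions, not something this script configures.
#
#   accelerate launch --num_processes 8 train_grpo.py \
#       --model ./nexus-coder-sft \
#       --output_dir ./nexus-coder-rl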
|
|
import argparse
import json
import subprocess
import tempfile

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOTrainer, GRPOConfig
|
|
def execution_reward_fn(completions: list, **kwargs) -> list:
    """
    Reward function for GRPO.
    Expects completions that contain bash commands or patches.
    In a real setup, replay commands in a Docker sandbox and return pass_rate.
    """
    rewards = []
    for completion in completions:
        try:
            if "```bash" in completion:
                # Take the last ```bash fenced block and execute it in a temp directory.
                cmd = completion.split("```bash")[-1].split("```")[0].strip()
                result = subprocess.run(
                    cmd,
                    shell=True,
                    capture_output=True,
                    timeout=30,
                    cwd=tempfile.gettempdir(),
                )
                reward = 1.0 if result.returncode == 0 else 0.0
            else:
                # No executable command found: nothing to verify.
                reward = 0.0
        except Exception:
            # Timeouts and execution errors all count as failures.
            reward = 0.0
        rewards.append(reward)
    return rewards
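

# The docstring above gestures at the real setup: replay the agent's commands in a
# Docker sandbox and score with pass_rate. The sketch below illustrates one way that
# could look. It keeps the same (completions, **kwargs) signature so it could be
# swapped into reward_funcs, but the image name, mount point, and test command are
# placeholder assumptions, not part of this script's actual environment.
def sandboxed_pass_rate_reward_fn(completions: list, **kwargs) -> list:
    """Hypothetical reward: replay commands in a container and return the test pass_rate."""
    import re  # local import keeps this optional sketch self-contained

    rewards = []
    for completion in completions:
        reward = 0.0
        try:
            if "```bash" in completion:
                cmd = completion.split("```bash")[-1].split("```")[0].strip()
                with tempfile.TemporaryDirectory() as workdir:
                    # Run the model's commands followed by the test suite inside a
                    # throwaway container, with a scratch directory mounted as the workdir.
                    result = subprocess.run(
                        [
                            "docker", "run", "--rm",
                            "-v", f"{workdir}:/workspace",
                            "-w", "/workspace",
                            "python:3.11-slim",  # placeholder image
                            "bash", "-c", f"{cmd} && python -m pytest -q",
                        ],
                        capture_output=True,
                        text=True,
                        timeout=300,
                    )
                    # Crude pass_rate from pytest's summary line, e.g. "2 failed, 5 passed in 1.2s".
                    passed = sum(int(n) for n in re.findall(r"(\d+) passed", result.stdout))
                    failed = sum(int(n) for n in re.findall(r"(\d+) failed", result.stdout))
                    total = passed + failed
                    reward = passed / total if total else 0.0
        except Exception:
            reward = 0.0
        rewards.append(reward)
    return rewards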
|
|
|
|
def load_rl_dataset():
    """Load the Nemotron RL SWE pivot dataset and normalize prompts."""
    ds = load_dataset("nvidia/Nemotron-RL-Agentic-SWE-Pivot-v1", split="train")

    def normalize(example):
        # Use the first input message (the system/task prompt) as the GRPO prompt,
        # and keep the reference reasoning trace as an auxiliary completion column.
        params = example.get("responses_create_params", {})
        inp = params.get("input", [])
        if len(inp) > 0 and isinstance(inp[0], dict):
            system = inp[0].get("content", "")
            ref = example.get("ref_message", {})
            reasoning = ref.get("reasoning_content", "") if isinstance(ref, dict) else ""
            return {"prompt": system, "completion": reasoning}
        return {"prompt": "", "completion": ""}

    ds = ds.map(normalize, remove_columns=ds.column_names)
    # Drop rows whose prompt is empty or trivially short.
    ds = ds.filter(lambda x: len(x["prompt"]) > 50)
    return ds
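
# For reference, the fields read above correspond to records shaped roughly like
# the sketch below. This is inferred from the accesses in normalize(), not an
# authoritative description of the full dataset schema:
#
#   {
#       "responses_create_params": {"input": [{"content": "<system/task prompt>", ...}], ...},
#       "ref_message": {"reasoning_content": "<reference reasoning trace>", ...},
#   }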
|
|
|
|
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, help="Path to SFT checkpoint")
    parser.add_argument("--output_dir", default="./nexus-coder-rl")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--grad_accum", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-6)
    parser.add_argument("--max_prompt_length", type=int, default=4096)
    parser.add_argument("--max_completion_length", type=int, default=12288)
    parser.add_argument("--num_generations", type=int, default=8)
    parser.add_argument("--hub_model_id", default=None)
    args = parser.parse_args()
|
|
    print("[1/4] Loading SFT model and tokenizer...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype="bfloat16",
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
|
|
    print("[2/4] Loading RL dataset...")
    dataset = load_rl_dataset()
    print(f"    RL dataset size: {len(dataset)} examples")
|
|
    print("[3/4] Configuring GRPO trainer...")
    grpo_config = GRPOConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        num_generations=args.num_generations,
        temperature=0.7,
        logging_strategy="steps",
        logging_steps=5,
        logging_first_step=True,
        bf16=True,
        gradient_checkpointing=True,
        disable_tqdm=True,
        push_to_hub=args.hub_model_id is not None,
        hub_model_id=args.hub_model_id,
    )
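    # Note: depending on the installed TRL version, num_generations must evenly
    # divide the effective batch used for sampling (roughly per-device batch size
    # x gradient accumulation x number of processes). The defaults above
    # (batch_size=1, grad_accum=16, num_generations=8) satisfy that on a single
    # process; consult your TRL version's GRPOConfig docs if the trainer raises a
    # divisibility error.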
|
|
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[execution_reward_fn],
        args=grpo_config,
        train_dataset=dataset,
        processing_class=tokenizer,
    )
|
|
    print("[4/4] Starting GRPO training...")
    trainer.train()
    trainer.save_model(args.output_dir)
    print(f"Done. Model saved to {args.output_dir}")
|
|
|
|
if __name__ == "__main__":
    main()
|
|