olanigan commited on
Commit
33bac25
·
verified ·
1 Parent(s): 90e71a6

Add GRPO RL training script with execution reward function

Browse files
Files changed (1) hide show
  1. train_grpo.py +139 -0
train_grpo.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Reference GRPO training script for agentic coding RL.
4
+ Uses execution-verified pass_rate as the reward signal.
5
+
6
+ Usage:
7
+ python train_grpo.py \
8
+ --model ./nexus-coder-sft \
9
+ --output_dir ./nexus-coder-rl
10
+ """
11
+
12
+ import argparse
13
+ import json
14
+ import subprocess
15
+ import tempfile
16
+ from datasets import load_dataset
17
+ from transformers import AutoModelForCausalLM, AutoTokenizer
18
+ from trl import GRPOTrainer, GRPOConfig
19
+
20
+ # ---------------------------------------------------------------------------
21
+ # Execution reward function (simplified — adapt to your sandbox)
22
+ # ---------------------------------------------------------------------------
23
+
24
def execution_reward_fn(completions: list, **kwargs) -> list:
    """
    Execution-verified reward function for GRPO.

    Expects completions that contain bash commands or patches. In a real
    setup, replay commands in a Docker sandbox and return pass_rate.

    Parameters
    ----------
    completions : list
        Model-generated completion strings, one per rollout.
    **kwargs
        Extra dataset columns forwarded by GRPOTrainer; ignored here.

    Returns
    -------
    list
        One float per completion: 1.0 if the extracted bash block exits
        with status 0, else 0.0 (including no block, timeout, or error).
    """
    # SECURITY: this runs model-generated shell commands directly on the
    # host via shell=True. Replace with an isolated sandbox (Docker,
    # gVisor, firejail, ...) before any real training run.
    rewards = []
    for completion in completions:
        reward = 0.0
        try:
            # Extract the last ```bash ... ``` fenced block.
            if "```bash" in completion:
                tail = completion.split("```bash")[-1]
                # Require a closing fence; without this check an unclosed
                # block would execute all trailing completion text as shell.
                if "```" in tail:
                    cmd = tail.split("```")[0].strip()
                    result = subprocess.run(
                        cmd,
                        shell=True,
                        capture_output=True,
                        timeout=30,
                        cwd=tempfile.gettempdir(),
                    )
                    reward = 1.0 if result.returncode == 0 else 0.0
        except subprocess.TimeoutExpired:
            # Hung or looping commands score zero instead of crashing training.
            reward = 0.0
        except Exception:
            # Any other sandbox failure is treated as a failed rollout.
            reward = 0.0
        rewards.append(reward)
    return rewards
44
+
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # Dataset prep
48
+ # ---------------------------------------------------------------------------
49
+
50
def load_rl_dataset():
    """Load the Nemotron RL SWE pivot dataset and normalize each row to a
    prompt/completion pair, dropping rows with trivially short prompts."""
    raw = load_dataset("nvidia/Nemotron-RL-Agentic-SWE-Pivot-v1", split="train")

    def to_pair(row):
        # Guard clause: rows without a usable system message become empty
        # pairs and are removed by the length filter below.
        messages = row.get("responses_create_params", {}).get("input", [])
        if not messages or not isinstance(messages[0], dict):
            return {"prompt": "", "completion": ""}
        ref = row.get("ref_message", {})
        reasoning = ref.get("reasoning_content", "") if isinstance(ref, dict) else ""
        return {
            "prompt": messages[0].get("content", ""),
            "completion": reasoning,
        }

    normalized = raw.map(to_pair, remove_columns=raw.column_names)
    return normalized.filter(lambda row: len(row["prompt"]) > 50)
68
+
69
+
70
+ # ---------------------------------------------------------------------------
71
+ # Main
72
+ # ---------------------------------------------------------------------------
73
+
74
def main():
    """CLI entry point: load the SFT checkpoint, build the RL dataset,
    configure GRPO with the execution reward, and train."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", required=True, help="Path to SFT checkpoint")
    # Remaining flags registered data-driven; names/defaults are the CLI contract.
    for flag, opts in [
        ("--output_dir", {"default": "./nexus-coder-rl"}),
        ("--epochs", {"type": int, "default": 1}),
        ("--batch_size", {"type": int, "default": 1}),
        ("--grad_accum", {"type": int, "default": 16}),
        ("--lr", {"type": float, "default": 1e-6}),
        ("--max_prompt_length", {"type": int, "default": 4096}),
        ("--max_completion_length", {"type": int, "default": 12288}),
        ("--num_generations", {"type": int, "default": 8}),
        ("--hub_model_id", {"default": None}),
    ]:
        cli.add_argument(flag, **opts)
    args = cli.parse_args()

    print("[1/4] Loading SFT model and tokenizer...")
    policy = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype="bfloat16",
        device_map="auto",
        trust_remote_code=True,
    )
    tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tok.pad_token is None:
        # Batched generation needs a pad token; reuse EOS when none is set.
        tok.pad_token = tok.eos_token

    print("[2/4] Loading RL dataset...")
    dataset = load_rl_dataset()
    print(f" RL dataset size: {len(dataset)} examples")

    print("[3/4] Configuring GRPO trainer...")
    train_args = GRPOConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        num_generations=args.num_generations,
        temperature=0.7,
        logging_strategy="steps",
        logging_steps=5,
        logging_first_step=True,
        bf16=True,
        gradient_checkpointing=True,
        disable_tqdm=True,
        # Push to the Hub only when a repo id was explicitly provided.
        push_to_hub=args.hub_model_id is not None,
        hub_model_id=args.hub_model_id,
    )

    grpo = GRPOTrainer(
        model=policy,
        reward_funcs=[execution_reward_fn],
        args=train_args,
        train_dataset=dataset,
        processing_class=tok,
    )

    print("[4/4] Starting GRPO training...")
    grpo.train()
    grpo.save_model(args.output_dir)
    print(f"Done. Model saved to {args.output_dir}")
136
+
137
+
138
# Script entry point — runs only when executed directly, not on import.
if __name__ == "__main__":
    main()