shank committed on
Commit ·
2b1fbf3
1
Parent(s): 1128de1
Optimize for A100 80GB: 8 generations, batch 4, lr 2e-5, dense logging
Browse files- training/train_grpo.py +10 -10
training/train_grpo.py
CHANGED
|
@@ -311,7 +311,7 @@ def run_baseline(n: int = 20) -> dict:
|
|
| 311 |
completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
| 312 |
r = reward_fn([completion], [prompt], bug_metadata=[bug])
|
| 313 |
rewards.append(r[0])
|
| 314 |
-
if r[0] > 0.
|
| 315 |
solved += 1
|
| 316 |
|
| 317 |
result = {"solve_rate": solved / max(len(bugs), 1), "avg_reward": sum(rewards) / max(len(rewards), 1), "rewards": rewards}
|
|
@@ -334,15 +334,15 @@ def make_dataset(step: int) -> Dataset:
|
|
| 334 |
config = GRPOConfig(
|
| 335 |
output_dir=CHECKPOINT_DIR,
|
| 336 |
max_steps=MAX_STEPS,
|
| 337 |
-
per_device_train_batch_size=
|
| 338 |
-
gradient_accumulation_steps=
|
| 339 |
-
learning_rate=
|
| 340 |
lr_scheduler_type="cosine",
|
| 341 |
-
warmup_steps=20 if args.test else
|
| 342 |
-
num_generations=
|
| 343 |
-
max_new_tokens=
|
| 344 |
-
temperature=0.
|
| 345 |
-
logging_steps=5 if args.test else
|
| 346 |
save_steps=50 if args.test else 100,
|
| 347 |
report_to="wandb" if WANDB_API_KEY else "none",
|
| 348 |
)
|
|
@@ -385,7 +385,7 @@ for bug in bugs:
|
|
| 385 |
completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
| 386 |
r = reward_fn([completion], [prompt], bug_metadata=[bug])
|
| 387 |
post_rewards.append(r[0])
|
| 388 |
-
if r[0] > 0.
|
| 389 |
post_solved += 1
|
| 390 |
|
| 391 |
post_solve_rate = post_solved / max(len(bugs), 1)
|
|
|
|
| 311 |
completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
| 312 |
r = reward_fn([completion], [prompt], bug_metadata=[bug])
|
| 313 |
rewards.append(r[0])
|
| 314 |
+
if r[0] > 0.20: # threshold: any positive structured response counts
|
| 315 |
solved += 1
|
| 316 |
|
| 317 |
result = {"solve_rate": solved / max(len(bugs), 1), "avg_reward": sum(rewards) / max(len(rewards), 1), "rewards": rewards}
|
|
|
|
| 334 |
config = GRPOConfig(
|
| 335 |
output_dir=CHECKPOINT_DIR,
|
| 336 |
max_steps=MAX_STEPS,
|
| 337 |
+
per_device_train_batch_size=4, # A100 80GB handles 4 (was 2)
|
| 338 |
+
gradient_accumulation_steps=2, # effective batch = 8 (same total, less accumulation lag)
|
| 339 |
+
learning_rate=2e-5, # slightly higher lr for faster convergence
|
| 340 |
lr_scheduler_type="cosine",
|
| 341 |
+
warmup_steps=20 if args.test else 40,
|
| 342 |
+
num_generations=8, # GRPO key: more rollouts = stronger learning signal (was 4)
|
| 343 |
+
max_new_tokens=512, # longer responses = more complete fixes (was 400)
|
| 344 |
+
temperature=0.9, # slightly higher temp = more diverse rollouts for GRPO
|
| 345 |
+
logging_steps=5 if args.test else 5, # log every 5 steps for dense W&B curve
|
| 346 |
save_steps=50 if args.test else 100,
|
| 347 |
report_to="wandb" if WANDB_API_KEY else "none",
|
| 348 |
)
|
|
|
|
| 385 |
completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
| 386 |
r = reward_fn([completion], [prompt], bug_metadata=[bug])
|
| 387 |
post_rewards.append(r[0])
|
| 388 |
+
if r[0] > 0.20:
|
| 389 |
post_solved += 1
|
| 390 |
|
| 391 |
post_solve_rate = post_solved / max(len(bugs), 1)
|