shank committed on
Commit
2b1fbf3
·
1 Parent(s): 1128de1

Optimize for A100 80GB: 8 generations, batch 4, lr 2e-5, dense logging

Browse files
Files changed (1) hide show
  1. training/train_grpo.py +10 -10
training/train_grpo.py CHANGED
@@ -311,7 +311,7 @@ def run_baseline(n: int = 20) -> dict:
311
  completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
312
  r = reward_fn([completion], [prompt], bug_metadata=[bug])
313
  rewards.append(r[0])
314
- if r[0] > 0.30:
315
  solved += 1
316
 
317
  result = {"solve_rate": solved / max(len(bugs), 1), "avg_reward": sum(rewards) / max(len(rewards), 1), "rewards": rewards}
@@ -334,15 +334,15 @@ def make_dataset(step: int) -> Dataset:
334
  config = GRPOConfig(
335
  output_dir=CHECKPOINT_DIR,
336
  max_steps=MAX_STEPS,
337
- per_device_train_batch_size=2,
338
- gradient_accumulation_steps=4,
339
- learning_rate=1e-5,
340
  lr_scheduler_type="cosine",
341
- warmup_steps=20 if args.test else 50,
342
- num_generations=4,
343
- max_new_tokens=400,
344
- temperature=0.8,
345
- logging_steps=5 if args.test else 10,
346
  save_steps=50 if args.test else 100,
347
  report_to="wandb" if WANDB_API_KEY else "none",
348
  )
@@ -385,7 +385,7 @@ for bug in bugs:
385
  completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
386
  r = reward_fn([completion], [prompt], bug_metadata=[bug])
387
  post_rewards.append(r[0])
388
- if r[0] > 0.30:
389
  post_solved += 1
390
 
391
  post_solve_rate = post_solved / max(len(bugs), 1)
 
311
  completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
312
  r = reward_fn([completion], [prompt], bug_metadata=[bug])
313
  rewards.append(r[0])
314
+ if r[0] > 0.20: # threshold: any positive structured response counts
315
  solved += 1
316
 
317
  result = {"solve_rate": solved / max(len(bugs), 1), "avg_reward": sum(rewards) / max(len(rewards), 1), "rewards": rewards}
 
334
  config = GRPOConfig(
335
  output_dir=CHECKPOINT_DIR,
336
  max_steps=MAX_STEPS,
337
+ per_device_train_batch_size=4, # A100 80GB handles 4 (was 2)
338
+ gradient_accumulation_steps=2, # effective batch = 8 (same total, less accumulation lag)
339
+ learning_rate=2e-5, # slightly higher lr for faster convergence
340
  lr_scheduler_type="cosine",
341
+ warmup_steps=20 if args.test else 40,
342
+ num_generations=8, # GRPO key: more rollouts = stronger learning signal (was 4)
343
+ max_new_tokens=512, # longer responses = more complete fixes (was 400)
344
+ temperature=0.9, # slightly higher temp = more diverse rollouts for GRPO
345
+ logging_steps=5 if args.test else 5, # log every 5 steps for dense W&B curve
346
  save_steps=50 if args.test else 100,
347
  report_to="wandb" if WANDB_API_KEY else "none",
348
  )
 
385
  completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
386
  r = reward_fn([completion], [prompt], bug_metadata=[bug])
387
  post_rewards.append(r[0])
388
+ if r[0] > 0.20:
389
  post_solved += 1
390
 
391
  post_solve_rate = post_solved / max(len(bugs), 1)