shank commited on
Commit
1128de1
Β·
1 Parent(s): 3152fa9

Restore full 1000-step training with original curriculum

Browse files
Files changed (1) hide show
  1. training/train_grpo.py +4 -4
training/train_grpo.py CHANGED
@@ -33,7 +33,7 @@ parser.add_argument("--test", action="store_true", help="Run 10 steps for testin
33
  parser.add_argument("--test-local", action="store_true", dest="test_local",
34
  help="Sanity-check reward function locally without any model or GPU")
35
  parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint")
36
- parser.add_argument("--max_steps", type=int, default=500)
37
  args = parser.parse_args()
38
 
39
  # ── Install dependencies (for Colab/HF Spaces) ───────────────────────────────
@@ -104,9 +104,9 @@ def load_bugs(tier: int) -> list[dict]:
104
 
105
  def get_bugs_for_step(step: int) -> list[dict]:
106
  tier1 = load_bugs(1)
107
- if step < 150:
108
  return tier1
109
- elif step < 350:
110
  return tier1 + load_bugs(2)
111
  return tier1 + load_bugs(2) + load_bugs(3)
112
 
@@ -359,7 +359,7 @@ trainer = GRPOTrainer(
359
  class CurriculumCallback(TrainerCallback):
360
  def on_step_end(self, args, state, control, **kwargs):
361
  step = state.global_step
362
- if step in [150, 350]:
363
  trainer.train_dataset = make_dataset(step)
364
  print(f"\nCurriculum advanced at step {step}!")
365
  if WANDB_API_KEY:
 
33
  parser.add_argument("--test-local", action="store_true", dest="test_local",
34
  help="Sanity-check reward function locally without any model or GPU")
35
  parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint")
36
+ parser.add_argument("--max_steps", type=int, default=1000)
37
  args = parser.parse_args()
38
 
39
  # ── Install dependencies (for Colab/HF Spaces) ───────────────────────────────
 
104
 
105
  def get_bugs_for_step(step: int) -> list[dict]:
106
  tier1 = load_bugs(1)
107
+ if step < 300:
108
  return tier1
109
+ elif step < 600:
110
  return tier1 + load_bugs(2)
111
  return tier1 + load_bugs(2) + load_bugs(3)
112
 
 
359
  class CurriculumCallback(TrainerCallback):
360
  def on_step_end(self, args, state, control, **kwargs):
361
  step = state.global_step
362
+ if step in [300, 600]:
363
  trainer.train_dataset = make_dataset(step)
364
  print(f"\nCurriculum advanced at step {step}!")
365
  if WANDB_API_KEY: