shank commited on
Commit Β·
1128de1
1
Parent(s): 3152fa9
Restore full 1000-step training with original curriculum
Browse files- training/train_grpo.py +4 -4
training/train_grpo.py
CHANGED
|
@@ -33,7 +33,7 @@ parser.add_argument("--test", action="store_true", help="Run 10 steps for testin
|
|
| 33 |
parser.add_argument("--test-local", action="store_true", dest="test_local",
|
| 34 |
help="Sanity-check reward function locally without any model or GPU")
|
| 35 |
parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint")
|
| 36 |
-
parser.add_argument("--max_steps", type=int, default=
|
| 37 |
args = parser.parse_args()
|
| 38 |
|
| 39 |
# ββ Install dependencies (for Colab/HF Spaces) βββββββββββββββββββββββββββββββ
|
|
@@ -104,9 +104,9 @@ def load_bugs(tier: int) -> list[dict]:
|
|
| 104 |
|
| 105 |
def get_bugs_for_step(step: int) -> list[dict]:
|
| 106 |
tier1 = load_bugs(1)
|
| 107 |
-
if step <
|
| 108 |
return tier1
|
| 109 |
-
elif step <
|
| 110 |
return tier1 + load_bugs(2)
|
| 111 |
return tier1 + load_bugs(2) + load_bugs(3)
|
| 112 |
|
|
@@ -359,7 +359,7 @@ trainer = GRPOTrainer(
|
|
| 359 |
class CurriculumCallback(TrainerCallback):
|
| 360 |
def on_step_end(self, args, state, control, **kwargs):
|
| 361 |
step = state.global_step
|
| 362 |
-
if step in [
|
| 363 |
trainer.train_dataset = make_dataset(step)
|
| 364 |
print(f"\nCurriculum advanced at step {step}!")
|
| 365 |
if WANDB_API_KEY:
|
|
|
|
| 33 |
parser.add_argument("--test-local", action="store_true", dest="test_local",
|
| 34 |
help="Sanity-check reward function locally without any model or GPU")
|
| 35 |
parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint")
|
| 36 |
+
parser.add_argument("--max_steps", type=int, default=1000)
|
| 37 |
args = parser.parse_args()
|
| 38 |
|
| 39 |
# ββ Install dependencies (for Colab/HF Spaces) βββββββββββββββββββββββββββββββ
|
|
|
|
| 104 |
|
| 105 |
def get_bugs_for_step(step: int) -> list[dict]:
|
| 106 |
tier1 = load_bugs(1)
|
| 107 |
+
if step < 300:
|
| 108 |
return tier1
|
| 109 |
+
elif step < 600:
|
| 110 |
return tier1 + load_bugs(2)
|
| 111 |
return tier1 + load_bugs(2) + load_bugs(3)
|
| 112 |
|
|
|
|
| 359 |
class CurriculumCallback(TrainerCallback):
|
| 360 |
def on_step_end(self, args, state, control, **kwargs):
|
| 361 |
step = state.global_step
|
| 362 |
+
if step in [300, 600]:
|
| 363 |
trainer.train_dataset = make_dataset(step)
|
| 364 |
print(f"\nCurriculum advanced at step {step}!")
|
| 365 |
if WANDB_API_KEY:
|