Spaces:
Running
Running
Fix GRPO batch/generation mismatch: auto-adjust num_generations; set launcher default to 2.
Browse files
- launch_job.py +2 -2
- ultimate_sota_training.py +17 -3
launch_job.py
CHANGED
|
@@ -18,7 +18,7 @@ Environment (optional):
|
|
| 18 |
TRAIN_REPO_GIT_URL, OPENENV_BASE_URL
|
| 19 |
TRAIN_MAX_STEPS default: 80 (faster run; raise for stronger fit)
|
| 20 |
ROWS_PER_TASK default: 32
|
| 21 |
-
GRPO_NUM_GENERATIONS default: 8
|
| 22 |
SKIP_HUB_PUSH default: 0
|
| 23 |
"""
|
| 24 |
from __future__ import annotations
|
|
@@ -33,7 +33,7 @@ _REPO_URL = os.environ.get("TRAIN_REPO_GIT_URL", _DEFAULT_REPO)
|
|
| 33 |
_OPENENV = os.environ.get("OPENENV_BASE_URL", "https://md896-sql-debug-env.hf.space")
|
| 34 |
_MAX_STEPS = os.environ.get("TRAIN_MAX_STEPS", "80")
|
| 35 |
_ROWS = os.environ.get("ROWS_PER_TASK", "32")
|
| 36 |
-
_NUM_GEN = os.environ.get("GRPO_NUM_GENERATIONS", "8")
|
| 37 |
_SKIP_PUSH = os.environ.get("SKIP_HUB_PUSH", "0")
|
| 38 |
_TIMEOUT = os.environ.get("HF_JOB_TIMEOUT", "8h")
|
| 39 |
# l4x1: newer GPU, good for Unsloth; use HF_JOB_FLAVOR=t4-small if queue or cost is better for you
|
|
|
|
| 18 |
TRAIN_REPO_GIT_URL, OPENENV_BASE_URL
|
| 19 |
TRAIN_MAX_STEPS default: 80 (faster run; raise for stronger fit)
|
| 20 |
ROWS_PER_TASK default: 32
|
| 21 |
+
GRPO_NUM_GENERATIONS default: 2
|
| 22 |
SKIP_HUB_PUSH default: 0
|
| 23 |
"""
|
| 24 |
from __future__ import annotations
|
|
|
|
| 33 |
_OPENENV = os.environ.get("OPENENV_BASE_URL", "https://md896-sql-debug-env.hf.space")
|
| 34 |
_MAX_STEPS = os.environ.get("TRAIN_MAX_STEPS", "80")
|
| 35 |
_ROWS = os.environ.get("ROWS_PER_TASK", "32")
|
| 36 |
+
_NUM_GEN = os.environ.get("GRPO_NUM_GENERATIONS", "2")
|
| 37 |
_SKIP_PUSH = os.environ.get("SKIP_HUB_PUSH", "0")
|
| 38 |
_TIMEOUT = os.environ.get("HF_JOB_TIMEOUT", "8h")
|
| 39 |
# l4x1: newer GPU, good for Unsloth; use HF_JOB_FLAVOR=t4-small if queue or cost is better for you
|
ultimate_sota_training.py
CHANGED
|
@@ -394,12 +394,26 @@ def run_sota_train():
|
|
| 394 |
if report_to == "tensorboard":
|
| 395 |
_ensure_dir(tb_dir)
|
| 396 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
_cfg: Dict[str, Any] = dict(
|
| 398 |
output_dir=out_dir,
|
| 399 |
learning_rate=float(os.environ.get("TRAIN_LR", "5e-6")),
|
| 400 |
-
per_device_train_batch_size=int(os.environ.get("PER_DEVICE_TRAIN_BS", "1")),
|
| 401 |
-
gradient_accumulation_steps=int(os.environ.get("GRAD_ACCUM", "2")),
|
| 402 |
-
num_generations=int(os.environ.get("GRPO_NUM_GENERATIONS", "8")),
|
| 403 |
max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "256")),
|
| 404 |
temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.9")),
|
| 405 |
num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),
|
|
|
|
| 394 |
if report_to == "tensorboard":
|
| 395 |
_ensure_dir(tb_dir)
|
| 396 |
|
| 397 |
+
per_device_bs = int(os.environ.get("PER_DEVICE_TRAIN_BS", "1"))
|
| 398 |
+
grad_accum = int(os.environ.get("GRAD_ACCUM", "2"))
|
| 399 |
+
requested_num_gen = int(os.environ.get("GRPO_NUM_GENERATIONS", "8"))
|
| 400 |
+
effective_bs = max(1, per_device_bs * grad_accum)
|
| 401 |
+
if effective_bs % requested_num_gen != 0:
|
| 402 |
+
valid = [d for d in range(2, effective_bs + 1) if effective_bs % d == 0]
|
| 403 |
+
num_gen = valid[-1] if valid else 2
|
| 404 |
+
print(
|
| 405 |
+
f"Adjusting GRPO_NUM_GENERATIONS from {requested_num_gen} to {num_gen} "
|
| 406 |
+
f"for effective batch size {effective_bs}."
|
| 407 |
+
)
|
| 408 |
+
else:
|
| 409 |
+
num_gen = requested_num_gen
|
| 410 |
+
|
| 411 |
_cfg: Dict[str, Any] = dict(
|
| 412 |
output_dir=out_dir,
|
| 413 |
learning_rate=float(os.environ.get("TRAIN_LR", "5e-6")),
|
| 414 |
+
per_device_train_batch_size=per_device_bs,
|
| 415 |
+
gradient_accumulation_steps=grad_accum,
|
| 416 |
+
num_generations=num_gen,
|
| 417 |
max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "256")),
|
| 418 |
temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.9")),
|
| 419 |
num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),
|