Fix OOM: reduce batch/gen/tokens, add grad checkpointing + adafactor
run_training.py (+8 -5)

@@ -4,6 +4,7 @@ Runs env-grounded GRPO training, saves model + plots,
 then starts a FastAPI server to serve/download results.
 """
 import os
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 import sys
 import json
 import copy
@@ -125,7 +126,7 @@ def run_grpo_training():
     obs_contexts = []
     rng = np.random.RandomState(base_seed)
 
-    for episode in range(
+    for episode in range(10):  # 10 episodes → ~600 prompts, fits training time
         ep_config = copy.deepcopy(task_config)
         ep_config['seed'] = base_seed + episode
         env = OpenGridEnv(ep_config)
@@ -208,17 +209,19 @@ def run_grpo_training():
     grpo_config = GRPOConfig(
         output_dir="training/outputs/grpo_checkpoints",
         num_train_epochs=3,
-        per_device_train_batch_size=
-        gradient_accumulation_steps=
+        per_device_train_batch_size=2,
+        gradient_accumulation_steps=8,
         learning_rate=1e-5,
         logging_steps=5,
         save_steps=50,
-        max_completion_length=
-        num_generations=
+        max_completion_length=128,
+        num_generations=2,
         report_to="none",
         remove_unused_columns=False,
         bf16=_bf16,
         fp16=_fp16,
+        gradient_checkpointing=True,
+        optim="adafactor",
     )
 
     train_dataset = Dataset.from_dict({"prompt": prompts, "obs_context": obs_contexts})
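
Note on the numbers: the effective batch size is unchanged, since per_device_train_batch_size * gradient_accumulation_steps = 2 * 8 = 16; only peak per-step activation memory drops. A minimal sketch of the memory-relevant knobs in isolation, assuming trl's GRPOConfig (which this script appears to use; output_dir below is a placeholder):

    from trl import GRPOConfig

    config = GRPOConfig(
        output_dir="outputs",            # placeholder path
        per_device_train_batch_size=2,   # fewer sequences per forward/backward pass
        gradient_accumulation_steps=8,   # keeps the effective batch at 2 * 8 = 16
        num_generations=2,               # fewer sampled completions per prompt
        max_completion_length=128,       # shorter completions bound KV-cache growth
        gradient_checkpointing=True,     # recompute activations in backward instead of storing them
        optim="adafactor",               # Adafactor stores far less optimizer state than AdamW
    )

GRPO samples num_generations completions per prompt, so generation-time memory scales with num_generations * max_completion_length; cutting both is the "gen/tokens" part of the commit title.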
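
The PYTORCH_CUDA_ALLOC_CONF line sits directly after import os because expandable_segments must be set before PyTorch first initializes CUDA; it reduces allocator fragmentation rather than raw usage. A quick way to check the headroom gained, as a sketch assuming a CUDA build of PyTorch:

    import os
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # before any CUDA allocation
    import torch

    # ... after running a training step:
    if torch.cuda.is_available():
        print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")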
|