shank committed on
Commit ·
2b499e7
1
Parent(s): b8172c5
Fix: batch%num_generations math
Browse files - training/train_grpo.py +19 -12
training/train_grpo.py
CHANGED
|
@@ -303,23 +303,30 @@ if torch.cuda.is_available():
|
|
| 303 |
|
| 304 |
COMPUTE_DTYPE = torch.bfloat16 if _is_ampere_plus else torch.float16
|
| 305 |
|
| 306 |
-
# Scale batch/generation config to available VRAM
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
_max_comp = 256
|
| 312 |
_lora_r = 16
|
| 313 |
-
elif _gpu_vram_gb >= 20: # A10G 24GB / V100 32GB
|
| 314 |
-
_batch =
|
| 315 |
-
_grad_accum =
|
| 316 |
-
_num_gen =
|
| 317 |
_max_comp = 192
|
| 318 |
_lora_r = 8
|
| 319 |
else: # T4 15GB / anything smaller
|
| 320 |
-
_batch =
|
| 321 |
-
_grad_accum =
|
| 322 |
-
_num_gen =
|
| 323 |
_max_comp = 160
|
| 324 |
_lora_r = 8
|
| 325 |
|
|
|
|
| 303 |
|
| 304 |
COMPUTE_DTYPE = torch.bfloat16 if _is_ampere_plus else torch.float16
|
| 305 |
|
| 306 |
+
# Scale batch/generation config to available VRAM.
|
| 307 |
+
# GRPO constraint: per_device_train_batch_size % num_generations == 0
|
| 308 |
+
if _gpu_vram_gb >= 70: # A100 80GB
|
| 309 |
+
_batch = 8
|
| 310 |
+
_grad_accum = 1 # effective batch = 8
|
| 311 |
+
_num_gen = 8 # 8 % 8 == 0
|
| 312 |
+
_max_comp = 256
|
| 313 |
+
_lora_r = 16
|
| 314 |
+
elif _gpu_vram_gb >= 40: # A100 40GB
|
| 315 |
+
_batch = 4
|
| 316 |
+
_grad_accum = 2 # effective batch = 8
|
| 317 |
+
_num_gen = 4 # 4 % 4 == 0
|
| 318 |
_max_comp = 256
|
| 319 |
_lora_r = 16
|
| 320 |
+
elif _gpu_vram_gb >= 20: # A10G 24GB / V100 32GB
|
| 321 |
+
_batch = 2
|
| 322 |
+
_grad_accum = 4
|
| 323 |
+
_num_gen = 2 # 2 % 2 == 0
|
| 324 |
_max_comp = 192
|
| 325 |
_lora_r = 8
|
| 326 |
else: # T4 15GB / anything smaller
|
| 327 |
+
_batch = 2
|
| 328 |
+
_grad_accum = 4
|
| 329 |
+
_num_gen = 2 # 2 % 2 == 0
|
| 330 |
_max_comp = 160
|
| 331 |
_lora_r = 8
|
| 332 |
|