shank committed
Commit 2b499e7 · 1 Parent(s): b8172c5

Fix: batch%num_generations math

Files changed (1)
  1. training/train_grpo.py +19 -12
training/train_grpo.py CHANGED
@@ -303,23 +303,30 @@ if torch.cuda.is_available():
 
 COMPUTE_DTYPE = torch.bfloat16 if _is_ampere_plus else torch.float16
 
-# Scale batch/generation config to available VRAM
-if _gpu_vram_gb >= 40:  # A100 40GB / A100 80GB
-    _batch = 2
-    _grad_accum = 4  # effective batch = 8
-    _num_gen = 8
+# Scale batch/generation config to available VRAM.
+# GRPO constraint: per_device_train_batch_size % num_generations == 0
+if _gpu_vram_gb >= 70:  # A100 80GB
+    _batch = 8
+    _grad_accum = 1  # effective batch = 8
+    _num_gen = 8  # 8 % 8 == 0
+    _max_comp = 256
+    _lora_r = 16
+elif _gpu_vram_gb >= 40:  # A100 40GB
+    _batch = 4
+    _grad_accum = 2  # effective batch = 8
+    _num_gen = 4  # 4 % 4 == 0
     _max_comp = 256
     _lora_r = 16
-elif _gpu_vram_gb >= 20:  # A10G 24GB / V100 32GB — float16 model ~14GB
-    _batch = 1
-    _grad_accum = 8
-    _num_gen = 4
+elif _gpu_vram_gb >= 20:  # A10G 24GB / V100 32GB
+    _batch = 2
+    _grad_accum = 4
+    _num_gen = 2  # 2 % 2 == 0
     _max_comp = 192
     _lora_r = 8
 else:  # T4 15GB / anything smaller
-    _batch = 1
-    _grad_accum = 8
-    _num_gen = 4
+    _batch = 2
+    _grad_accum = 4
+    _num_gen = 2  # 2 % 2 == 0
     _max_comp = 160
     _lora_r = 8
 
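Context on the commit message: the old top tier paired _batch = 2 with _num_gen = 8 (2 % 8 != 0), and the lower tiers paired _batch = 1 with _num_gen = 4 (1 % 4 != 0), both violating the divisibility constraint the new comment states. Below is a minimal standalone sketch that re-derives the new tiers and asserts the two invariants the fix restores; pick_grpo_config is a hypothetical helper for illustration, not a function in train_grpo.py.

def pick_grpo_config(gpu_vram_gb: float) -> dict:
    # Hypothetical re-derivation of the VRAM tiers from training/train_grpo.py.
    if gpu_vram_gb >= 70:    # A100 80GB
        cfg = {"batch": 8, "grad_accum": 1, "num_gen": 8, "max_comp": 256, "lora_r": 16}
    elif gpu_vram_gb >= 40:  # A100 40GB
        cfg = {"batch": 4, "grad_accum": 2, "num_gen": 4, "max_comp": 256, "lora_r": 16}
    elif gpu_vram_gb >= 20:  # A10G 24GB / V100 32GB
        cfg = {"batch": 2, "grad_accum": 4, "num_gen": 2, "max_comp": 192, "lora_r": 8}
    else:                    # T4 15GB / anything smaller
        cfg = {"batch": 2, "grad_accum": 4, "num_gen": 2, "max_comp": 160, "lora_r": 8}
    # The two invariants this commit restores: per-device batch divisible by
    # num_generations (the stated GRPO constraint), and effective batch held at 8.
    assert cfg["batch"] % cfg["num_gen"] == 0
    assert cfg["batch"] * cfg["grad_accum"] == 8
    return cfg

for vram in (80, 40, 24, 15):
    print(vram, pick_grpo_config(vram))

A hedged usage sketch, assuming the script wires these values into TRL's GRPOConfig (the parameter names per_device_train_batch_size, gradient_accumulation_steps, num_generations, and max_completion_length are TRL's; output_dir="out" is illustrative):

from trl import GRPOConfig

cfg = pick_grpo_config(40)
training_args = GRPOConfig(
    output_dir="out",
    per_device_train_batch_size=cfg["batch"],
    gradient_accumulation_steps=cfg["grad_accum"],
    num_generations=cfg["num_gen"],
    max_completion_length=cfg["max_comp"],
)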