ydshieh commited on
Commit
91d8939
1 Parent(s): 0b49c18
Files changed (1) hide show
  1. run_image_captioning_flax.py +3 -0
run_image_captioning_flax.py CHANGED
@@ -929,6 +929,9 @@ def main():
929
 
930
  train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
931
 
 
 
 
932
  if training_args.do_train:
933
  steps_per_epoch = len(train_dataset) // train_batch_size
934
  num_train_examples_per_epoch = steps_per_epoch * train_batch_size
929
 
930
  train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
931
 
932
+ if training_args.block_size % train_batch_size > 0:
933
+ raise ValueError(f"`training_args.block_size` needs to be a multiple of the global batch size. Got {training_args.block_size} and {train_batch_size} instead.")
934
+
935
  if training_args.do_train:
936
  steps_per_epoch = len(train_dataset) // train_batch_size
937
  num_train_examples_per_epoch = steps_per_epoch * train_batch_size