ydshieh
commited on
Commit
•
91d8939
1
Parent(s):
0b49c18
fix
Browse files
run_image_captioning_flax.py
CHANGED
@@ -929,6 +929,9 @@ def main():
|
|
929 |
|
930 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
931 |
|
|
|
|
|
|
|
932 |
if training_args.do_train:
|
933 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
934 |
num_train_examples_per_epoch = steps_per_epoch * train_batch_size
|
929 |
|
930 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
931 |
|
932 |
+
if training_args.block_size % train_batch_size > 0:
|
933 |
+
raise ValueError(f"`training_args.block_size` needs to be a multiple of the global batch size. Got {training_args.block_size} and {train_batch_size} instead.")
|
934 |
+
|
935 |
if training_args.do_train:
|
936 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
937 |
num_train_examples_per_epoch = steps_per_epoch * train_batch_size
|