ydshieh
commited on
Commit
•
a231b72
1
Parent(s):
efe62b9
fix speed 2
Browse files
run_image_captioning_flax.py
CHANGED
@@ -297,7 +297,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
|
|
297 |
steps_per_epoch = len(dataset) // batch_size
|
298 |
|
299 |
if shuffle:
|
300 |
-
batch_idx =
|
301 |
else:
|
302 |
batch_idx = np.arange(len(dataset))
|
303 |
|
@@ -847,7 +847,7 @@ def main():
|
|
847 |
if shuffle:
|
848 |
indices = np.random.permutation(len(train_dataset))
|
849 |
else:
|
850 |
-
indices =
|
851 |
|
852 |
for idx in range(num_splits):
|
853 |
|
|
|
297 |
steps_per_epoch = len(dataset) // batch_size
|
298 |
|
299 |
if shuffle:
|
300 |
+
batch_idx = np.random.permutation(len(dataset))
|
301 |
else:
|
302 |
batch_idx = np.arange(len(dataset))
|
303 |
|
|
|
847 |
if shuffle:
|
848 |
indices = np.random.permutation(len(train_dataset))
|
849 |
else:
|
850 |
+
indices = np.arange(len(ds))
|
851 |
|
852 |
for idx in range(num_splits):
|
853 |
|