ydshieh
commited on
Commit
•
9ca46fa
1
Parent(s):
a231b72
use jax rng
Browse files
run_image_captioning_flax.py
CHANGED
@@ -297,7 +297,8 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
|
|
297 |
steps_per_epoch = len(dataset) // batch_size
|
298 |
|
299 |
if shuffle:
|
300 |
-
batch_idx =
|
|
|
301 |
else:
|
302 |
batch_idx = np.arange(len(dataset))
|
303 |
|
@@ -845,7 +846,8 @@ def main():
|
|
845 |
num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
|
846 |
|
847 |
if shuffle:
|
848 |
-
indices =
|
|
|
849 |
else:
|
850 |
indices = np.arange(len(ds))
|
851 |
|
|
|
297 |
steps_per_epoch = len(dataset) // batch_size
|
298 |
|
299 |
if shuffle:
|
300 |
+
batch_idx = jax.random.permutation(rng, len(dataset))
|
301 |
+
batch_idx = np.asarray(batch_idx)
|
302 |
else:
|
303 |
batch_idx = np.arange(len(dataset))
|
304 |
|
|
|
846 |
num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
|
847 |
|
848 |
if shuffle:
|
849 |
+
indices = jax.random.permutation(rng, len(train_dataset))
|
850 |
+
indices = np.asarray(indices)
|
851 |
else:
|
852 |
indices = np.arange(len(ds))
|
853 |
|