ydshieh commited on
Commit
9ca46fa
1 Parent(s): a231b72

use jax rng

Browse files
Files changed (1) hide show
  1. run_image_captioning_flax.py +4 -2
run_image_captioning_flax.py CHANGED
@@ -297,7 +297,8 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
297
  steps_per_epoch = len(dataset) // batch_size
298
 
299
  if shuffle:
300
- batch_idx = np.random.permutation(len(dataset))
 
301
  else:
302
  batch_idx = np.arange(len(dataset))
303
 
@@ -845,7 +846,8 @@ def main():
845
  num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
846
 
847
  if shuffle:
848
- indices = np.random.permutation(len(train_dataset))
 
849
  else:
850
  indices = np.arange(len(ds))
851
 
 
297
  steps_per_epoch = len(dataset) // batch_size
298
 
299
  if shuffle:
300
+ batch_idx = jax.random.permutation(rng, len(dataset))
301
+ batch_idx = np.asarray(batch_idx)
302
  else:
303
  batch_idx = np.arange(len(dataset))
304
 
 
846
  num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
847
 
848
  if shuffle:
849
+ indices = jax.random.permutation(rng, len(train_dataset))
850
+ indices = np.asarray(indices)
851
  else:
852
  indices = np.arange(len(ds))
853