ydshieh commited on
Commit
a231b72
1 Parent(s): efe62b9

fix speed 2

Browse files
Files changed (1) hide show
  1. run_image_captioning_flax.py +2 -2
run_image_captioning_flax.py CHANGED
@@ -297,7 +297,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
297
  steps_per_epoch = len(dataset) // batch_size
298
 
299
  if shuffle:
300
- batch_idx = jax.random.permutation(rng, len(dataset))
301
  else:
302
  batch_idx = np.arange(len(dataset))
303
 
@@ -847,7 +847,7 @@ def main():
847
  if shuffle:
848
  indices = np.random.permutation(len(train_dataset))
849
  else:
850
- indices = jnp.arange(len(ds))
851
 
852
  for idx in range(num_splits):
853
 
 
297
  steps_per_epoch = len(dataset) // batch_size
298
 
299
  if shuffle:
300
+ batch_idx = np.random.permutation(len(dataset))
301
  else:
302
  batch_idx = np.arange(len(dataset))
303
 
 
847
  if shuffle:
848
  indices = np.random.permutation(len(train_dataset))
849
  else:
850
+ indices = np.arange(len(ds))
851
 
852
  for idx in range(num_splits):
853