ydshieh committed
Commit 7245cb4
1 Parent(s): 9ca46fa

update debug.py

Files changed (1)
  1. debug.py +14 -43
debug.py CHANGED
@@ -298,46 +298,22 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
 
     if shuffle:
         batch_idx = jax.random.permutation(rng, len(dataset))
+        batch_idx = np.asarray(batch_idx)
     else:
-        s = time.time()
-        # batch_idx = jnp.arange(len(dataset))
         batch_idx = np.arange(len(dataset))
-        e = time.time()
-        print(f'get permutation indices for the block with jax - time: {e-s}')
 
-    s = time.time()
     batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
-    e = time.time()
-    print(f'skip incomplete batch with jax - time: {e-s}')
-
-    s = time.time()
     batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
-    e = time.time()
-    print(f'reshape block indices with np - time: {e-s}')
 
     for idx in batch_idx:
-
-        print(f'type idx: {type(idx)}')
-
-        print(f'pixel values type: {type(dataset["pixel_values"])}')
-        print(f'pixel values shape: {dataset["pixel_values"].shape}')
-
         s = time.time()
         batch = dataset[idx]
         e = time.time()
-        print(f'get one batch with jax - time: {e-s}')
-
-        exit(0)
-
-        s = time.time()
+        print(f'fetch batch time: {e-s}')
         batch = {k: jnp.array(v) for k, v in batch.items()}
-        e = time.time()
-        print(f'convert one batch from np to jax - time: {e-s}')
 
-        s = time.time()
         batch = shard(batch)
-        e = time.time()
-        print(f'shard one batch with jax - time: {e-s}')
+
         yield batch
 
 
@@ -781,9 +757,9 @@ def main():
     if "train" not in dataset:
         raise ValueError("--do_train requires a train dataset")
     train_dataset = dataset["train"]
+    train_dataset = datasets.concatenate_datasets([train_dataset] * 205)
     # remove problematic examples
     train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
-    train_dataset = datasets.concatenate_datasets([train_dataset] * 205)
     if data_args.max_train_samples is not None:
         train_dataset = train_dataset.select(range(data_args.max_train_samples))
     train_dataset = train_dataset.map(
@@ -803,6 +779,7 @@ def main():
     eval_dataset = dataset["validation"]
     # remove problematic examples
     eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
+    eval_dataset = datasets.concatenate_datasets([eval_dataset] * 205)
     if data_args.max_eval_samples is not None:
         eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
     eval_dataset = eval_dataset.map(
@@ -820,6 +797,7 @@ def main():
     if "test" not in dataset:
         raise ValueError("--do_predict requires a test dataset")
     predict_dataset = dataset["test"]
+    predict_dataset = datasets.concatenate_datasets([predict_dataset] * 1024)
     # remove problematic examples
     predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
     if data_args.max_predict_samples is not None:
@@ -840,7 +818,7 @@ def main():
     # Split the dataset into several chunks - each chunk is processed (.map) without cache to create a
     # data loader separately (in a sequential order).
     block_size = training_args.block_size
-
+
     # Store some constant
 
     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
@@ -874,28 +852,22 @@ def main():
     num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
 
     if shuffle:
-        s = time.time()
-        #indices = jax.random.permutation(input_rng, len(ds))
-        indices = np.random.permutation(len(train_dataset))
-        e = time.time()
-        print(f'get permutation indices for the whole dataset with jax - time: {e-s}')
+        indices = jax.random.permutation(rng, len(train_dataset))
+        indices = np.asarray(indices)
     else:
-        indices = jnp.arange(len(ds))
+        indices = np.arange(len(ds))
 
     for idx in range(num_splits):
 
         start_idx = block_size * idx
         end_idx = block_size * (idx + 1)
 
-        s = time.time()
         selected_indices = indices[start_idx:end_idx]
-        e = time.time()
-        print(f'get block indices with jax - time: {e-s}')
 
         s = time.time()
         _ds = ds.select(selected_indices)
         e = time.time()
-        print(f'select block with jax - time: {e-s}')
+        print(f'select block time: {e-s}')
 
         names = {
             "train": "train",
@@ -904,20 +876,19 @@ def main():
         }
 
         s = time.time()
-        _ds =_ds.map(
+        _ds = _ds.map(
            feature_extraction_fn,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=[image_column],
            load_from_cache_file=not data_args.overwrite_cache,
            features=features,
-           #keep_in_memory=keep_in_memory,
-           keep_in_memory=False,
+           keep_in_memory=keep_in_memory,
            desc=f"Running feature extraction on {names[split]} dataset".replace("  ", " "),
         )
         _ds = _ds.with_format("numpy")
         e = time.time()
-        print(f'map feature extraction - time: {e-s}')
+        print(f'map time: {e-s}')
 
         # No need to shuffle here
         loader = data_loader(rng, _ds, batch_size=batch_size, shuffle=False)
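
For reference, the data_loader that results from this commit reads as the sketch below, reconstructed from the first hunk. It is a minimal approximation: the imports, the steps_per_epoch computation, and the assumption that shard comes from flax.training.common_utils are filled in here and are not part of the diff.

import time

import jax
import jax.numpy as jnp
import numpy as np
from flax.training.common_utils import shard  # splits the leading batch axis across local devices


def data_loader(rng, dataset, batch_size, shuffle=False):
    # Assumed: drop the last incomplete batch, as the slicing below implies.
    steps_per_epoch = len(dataset) // batch_size

    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
        # Convert the JAX permutation to NumPy once, so indexing the
        # datasets.Dataset below uses plain host-side indices.
        batch_idx = np.asarray(batch_idx)
    else:
        batch_idx = np.arange(len(dataset))

    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:
        s = time.time()
        batch = dataset[idx]
        e = time.time()
        print(f'fetch batch time: {e-s}')
        batch = {k: jnp.array(v) for k, v in batch.items()}
        # batch_size must be divisible by jax.local_device_count() for shard().
        batch = shard(batch)
        yield batch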
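
The second half of the diff applies the same idea block-wise: select one slice of the (NumPy) permutation, run feature extraction on just that slice, then hand the block to data_loader. Below is a runnable toy version of that select/map loop; the dataset, column name, and extraction function are hypothetical stand-ins, not the script's real ones.

import numpy as np
from datasets import Dataset

# Hypothetical 10-example dataset standing in for the script's ds.
ds = Dataset.from_dict({"image_path": [f"img_{i}.jpg" for i in range(10)]})

block_size = 4
num_splits = len(ds) // block_size + int(len(ds) % block_size > 0)
indices = np.random.permutation(len(ds))


def feature_extraction_fn(examples):
    # Stand-in for the real image feature extraction.
    return {"pixel_values": [np.zeros((3, 2, 2)) for _ in examples["image_path"]]}


for idx in range(num_splits):
    selected_indices = indices[block_size * idx : block_size * (idx + 1)]

    # select() is cheap; the expensive map() runs per block, so only one
    # block of extracted features is materialized at a time.
    _ds = ds.select(selected_indices)
    _ds = _ds.map(
        feature_extraction_fn,
        batched=True,
        keep_in_memory=True,  # mirrors the diff's keep_in_memory flag: no on-disk cache per block
        remove_columns=["image_path"],
    )
    _ds = _ds.with_format("numpy")
    print(f"block {idx}: {len(_ds)} examples, pixel_values shape {_ds['pixel_values'].shape}")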