ydshieh committed
Commit 0c9b4f3
1 Parent(s): 5306066

update debug.py

Files changed (1)
  debug.py +78 -119
debug.py CHANGED
@@ -299,17 +299,45 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
     if shuffle:
         batch_idx = jax.random.permutation(rng, len(dataset))
     else:
-        batch_idx = jnp.arange(len(dataset))
+        s = time.time()
+        # batch_idx = jnp.arange(len(dataset))
+        batch_idx = np.arange(len(dataset))
+        e = time.time()
+        print(f'get arange indices with np - time: {e-s}')
 
+    s = time.time()
     batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
+    e = time.time()
+    print(f'skip incomplete batch - time: {e-s}')
+
+    s = time.time()
     batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
+    e = time.time()
+    print(f'reshape batch indices - time: {e-s}')
 
     for idx in batch_idx:
+
+        print(f'type idx: {type(idx)}')
+
+        print(f'pixel values type: {type(dataset["pixel_values"])}')
+        print(f'pixel values shape: {dataset["pixel_values"].shape}')
+
+        s = time.time()
         batch = dataset[idx]
+        e = time.time()
+        print(f'get one batch - time: {e-s}')
+
+        exit(0)
+
+        s = time.time()
         batch = {k: jnp.array(v) for k, v in batch.items()}
+        e = time.time()
+        print(f'convert one batch from np to jax - time: {e-s}')
 
+        s = time.time()
         batch = shard(batch)
-
+        e = time.time()
+        print(f'shard one batch with jax - time: {e-s}')
         yield batch
 
 
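Note: the hunk above swaps jnp.arange for np.arange and instruments every stage of the loader. Below is a minimal, self-contained sketch of the comparison it is probing: how long fancy-indexing a datasets.Dataset takes with a NumPy index array versus a JAX one. The synthetic dataset and shapes are illustrative; only the "pixel_values" column name comes from debug.py, and whether a JAX array is accepted directly as an index depends on the datasets version (debug.py does pass one).

    import time

    import jax
    import numpy as np
    from datasets import Dataset

    # Synthetic stand-in for the real dataset; shapes are made up.
    ds = Dataset.from_dict({"pixel_values": np.zeros((4096, 8), dtype=np.float32)})
    ds = ds.with_format("numpy")

    batch_size = 256
    idx_np = np.random.permutation(len(ds))[:batch_size]
    idx_jax = jax.random.permutation(jax.random.PRNGKey(0), len(ds))[:batch_size]

    for name, idx in (("np", idx_np), ("jax", idx_jax)):
        s = time.time()
        batch = ds[idx]  # fetch one batch via fancy indexing
        e = time.time()
        print(f"get one batch with {name} index - time: {e - s}")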
@@ -750,126 +778,43 @@ def main():
     )
 
     if training_args.do_train:
-
         if "train" not in dataset:
-            raise ValueError("--do_train requires a train dataset")
+            raise ValueError("--do_train requires a train dataset")
         train_dataset = dataset["train"]
-        train_dataset = datasets.concatenate_datasets([train_dataset] * 205)
-
         # remove problematic examples
-        s = time.time()
         train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
-        e = time.time()
-        print(f'filter time: {e-s}')
-        print(len(train_dataset))
-
-        rng = jax.random.PRNGKey(training_args.seed)
-        rng, input_rng = jax.random.split(rng)
-
-        s = time.time()
-        indices_jax = jax.random.permutation(input_rng, len(train_dataset))
-        e = time.time()
-        print(f'get permutation indices for the whole dataset with jax - time: {e-s}')
-
-        s = time.time()
-        indices_np = np.random.permutation(len(train_dataset))
-        e = time.time()
-        print(f'get permutation indices for the whole dataset with np - time: {e-s}')
-
-        # indices = jnp.arange(len(ds))
-
-        block_size = 4096
-        for idx in range(4):
-
-            start_idx = block_size * idx
-            end_idx = block_size * (idx + 1)
-
-            s = time.time()
-            selected_indices_jax = indices_jax[start_idx:end_idx]
-            e = time.time()
-            print(f'get block indices with jax - time: {e-s}')
-            print(type(selected_indices_jax))
-
-            s = time.time()
-            selected_indices_np = indices_np[start_idx:end_idx]
-            e = time.time()
-            print(f'get block indices with np - time: {e-s}')
-            print(type(selected_indices_np))
-
-
-            s = time.time()
-            _ds = train_dataset.select(selected_indices_jax)
-            e = time.time()
-            print(f'select block with jax - time: {e-s}')
-
-            s = time.time()
-            _ds = train_dataset.select(selected_indices_np)
-            e = time.time()
-            print(f'select block with np - time: {e-s}')
-
-            s = time.time()
-            _selected_indices_np = np.array(selected_indices_jax)
-            e = time.time()
-            print(f'convert jax to np - time: {e-s}')
-
-
-            batch_size = 256
-
-            steps_per_epoch = len(_ds) // batch_size
-
-            s = time.time()
-            batch_idx_jax = jax.random.permutation(rng, len(_ds))
-            e = time.time()
-            print(f'get permutation indices for the block with jax - time: {e-s}')
-            # batch_idx = jnp.arange(len(dataset))
-
-            s = time.time()
-            batch_idx_np = np.random.permutation(len(_ds))
-            e = time.time()
-            print(f'get permutation indices for the block with np - time: {e-s}')
-
-            s = time.time()
-            batch_idx_jax = batch_idx_jax[: steps_per_epoch * batch_size]  # Skip incomplete batch.
-            e = time.time()
-            print(f'skip incomplete batch with jax - time: {e-s}')
-
-            s = time.time()
-            batch_idx_np = batch_idx_np[: steps_per_epoch * batch_size]  # Skip incomplete batch.
-            e = time.time()
-            print(f'skip incomplete batch with np - time: {e-s}')
-
-            s = time.time()
-            batch_idx_jax = batch_idx_jax.reshape((steps_per_epoch, batch_size))
-            e = time.time()
-            print(f'reshape block indices with jax - time: {e-s}')
-
-            s = time.time()
-            batch_idx_np = batch_idx_np.reshape((steps_per_epoch, batch_size))
-            e = time.time()
-            print(f'reshape block indices with np - time: {e-s}')
-
-            for idx in batch_idx_jax:
-
-                s = time.time()
-                batch = _ds[idx]
-                e = time.time()
-                print(f'get one batch with jax - time: {e-s}')
-
-                # s = time.time()
-                # batch = {k: jnp.array(v) for k, v in batch.items()}
-                # e = time.time()
-                # print(f'convert one batch to jnp time: {e-s}')
-
-            for idx in batch_idx_np:
-
-                s = time.time()
-                batch = _ds[idx]
-                e = time.time()
-                print(f'get one batch with np - time: {e-s}')
-
-
-            exit(0)
+        train_dataset = datasets.concatenate_datasets([train_dataset] * 205)
+        if data_args.max_train_samples is not None:
+            train_dataset = train_dataset.select(range(data_args.max_train_samples))
+        train_dataset = train_dataset.map(
+            tokenization_fn,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            # kept image paths
+            remove_columns=[x for x in column_names if x != image_column],
+            load_from_cache_file=not data_args.overwrite_cache,
+            desc="Running tokenizer on train dataset",
+            fn_kwargs={"max_target_length": data_args.max_target_length},
+        )
 
+    if training_args.do_eval:
+        if "validation" not in dataset:
+            raise ValueError("--do_eval requires a validation dataset")
+        eval_dataset = dataset["validation"]
+        # remove problematic examples
+        eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
+        if data_args.max_eval_samples is not None:
+            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+        eval_dataset = eval_dataset.map(
+            tokenization_fn,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            # kept image paths
+            remove_columns=[x for x in column_names if x != image_column],
+            load_from_cache_file=not data_args.overwrite_cache,
+            desc="Running tokenizer on validation dataset",
+            fn_kwargs={"max_target_length": data_args.val_max_target_length},
+        )
 
     if training_args.do_predict:
         if "test" not in dataset:
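Note: the block deleted above was a one-off micro-benchmark: permute all row indices (once with JAX, once with NumPy), slice 4096-row blocks, then time Dataset.select and per-batch indexing for each index type. A condensed, runnable sketch of that pattern follows; the dataset and sizes are illustrative, not the script's real data.

    import time

    import jax
    import numpy as np
    from datasets import Dataset

    ds = Dataset.from_dict({"x": list(range(20_000))})  # illustrative size
    block_size = 4096

    s = time.time()
    indices_jax = jax.random.permutation(jax.random.PRNGKey(0), len(ds))
    print(f"permutation with jax - time: {time.time() - s}")

    s = time.time()
    indices_np = np.random.permutation(len(ds))
    print(f"permutation with np - time: {time.time() - s}")

    for name, indices in (("jax", indices_jax), ("np", indices_np)):
        s = time.time()
        _ds = ds.select(indices[:block_size])  # builds an indices mapping, no data copy
        print(f"select block with {name} - time: {time.time() - s}")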
@@ -929,7 +874,11 @@ def main():
     num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
 
     if shuffle:
-        indices = jax.random.permutation(input_rng, len(ds))
+        s = time.time()
+        # indices = jax.random.permutation(input_rng, len(ds))
+        indices = np.random.permutation(len(ds))
+        e = time.time()
+        print(f'get permutation indices for the whole dataset with np - time: {e-s}')
     else:
         indices = jnp.arange(len(ds))
 
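Note: the hunk above swaps jax.random.permutation for np.random.permutation when building the shuffle indices. A small sketch for timing the two calls in isolation; block_until_ready() forces JAX's asynchronous dispatch to finish so the measurement is honest, and absolute numbers depend on backend and device. The size here is illustrative.

    import time

    import jax
    import numpy as np

    n = 500_000  # illustrative dataset size
    key = jax.random.PRNGKey(0)

    s = time.time()
    idx_jax = jax.random.permutation(key, n).block_until_ready()
    print(f"jax.random.permutation - time: {time.time() - s}")

    s = time.time()
    idx_np = np.random.permutation(n)
    print(f"np.random.permutation - time: {time.time() - s}")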
@@ -938,9 +887,15 @@ def main():
         start_idx = block_size * idx
         end_idx = block_size * (idx + 1)
 
+        s = time.time()
         selected_indices = indices[start_idx:end_idx]
+        e = time.time()
+        print(f'get block indices - time: {e-s}')
 
+        s = time.time()
         _ds = ds.select(selected_indices)
+        e = time.time()
+        print(f'select block - time: {e-s}')
 
         names = {
             "train": "train",
@@ -948,6 +903,7 @@ def main():
             "test": "prediction",
         }
 
+        s = time.time()
         _ds = _ds.map(
             feature_extraction_fn,
             batched=True,
@@ -955,10 +911,13 @@ def main():
             remove_columns=[image_column],
             load_from_cache_file=not data_args.overwrite_cache,
             features=features,
-            keep_in_memory=keep_in_memory,
+            # keep_in_memory=keep_in_memory,
+            keep_in_memory=False,
             desc=f"Running feature extraction on {names[split]} dataset".replace("  ", " "),
         )
         _ds = _ds.with_format("numpy")
+        e = time.time()
+        print(f'map feature extraction - time: {e-s}')
 
         # No need to shuffle here
         loader = data_loader(rng, _ds, batch_size=batch_size, shuffle=False)
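Note: the final hunks pin keep_in_memory=False and time the feature-extraction map plus the numpy formatting. A minimal sketch of that map-then-read pattern; extract_fn, the column names, and the shapes are placeholders, not the script's real feature_extraction_fn.

    import numpy as np
    from datasets import Dataset

    def extract_fn(examples):
        # placeholder for feature_extraction_fn in debug.py
        return {"pixel_values": [np.zeros((3, 4, 4), dtype=np.float32) for _ in examples["image"]]}

    ds = Dataset.from_dict({"image": [f"img_{i}.jpg" for i in range(512)]})
    ds = ds.map(
        extract_fn,
        batched=True,
        remove_columns=["image"],
        keep_in_memory=False,  # write results to the on-disk Arrow cache, not RAM
    )
    ds = ds.with_format("numpy")  # rows now come back as numpy arrays
    print(ds[0]["pixel_values"].shape)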
 