ydshieh
commited on
Commit
•
0c9b4f3
1
Parent(s):
5306066
update debug.py
Browse files
debug.py
CHANGED
@@ -299,17 +299,45 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
|
|
299 |
if shuffle:
|
300 |
batch_idx = jax.random.permutation(rng, len(dataset))
|
301 |
else:
|
302 |
-
|
|
|
|
|
|
|
|
|
303 |
|
|
|
304 |
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
|
|
|
|
|
|
|
|
305 |
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
|
|
|
|
306 |
|
307 |
for idx in batch_idx:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
308 |
batch = dataset[idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
batch = {k: jnp.array(v) for k, v in batch.items()}
|
|
|
|
|
310 |
|
|
|
311 |
batch = shard(batch)
|
312 |
-
|
|
|
313 |
yield batch
|
314 |
|
315 |
|
@@ -750,126 +778,43 @@ def main():
|
|
750 |
)
|
751 |
|
752 |
if training_args.do_train:
|
753 |
-
|
754 |
if "train" not in dataset:
|
755 |
-
|
756 |
train_dataset = dataset["train"]
|
757 |
-
train_dataset = datasets.concatenate_datasets([train_dataset] * 205)
|
758 |
-
|
759 |
# remove problematic examples
|
760 |
-
s = time.time()
|
761 |
train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
|
767 |
-
|
768 |
-
|
769 |
-
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
|
775 |
-
indices_np = np.random.permutation(len(train_dataset))
|
776 |
-
e = time.time()
|
777 |
-
print(f'get permutation indices for the whole dataset with np - time: {e-s}')
|
778 |
-
|
779 |
-
# indices = jnp.arange(len(ds))
|
780 |
-
|
781 |
-
block_size = 4096
|
782 |
-
for idx in range(4):
|
783 |
-
|
784 |
-
start_idx = block_size * idx
|
785 |
-
end_idx = block_size * (idx + 1)
|
786 |
-
|
787 |
-
s = time.time()
|
788 |
-
selected_indices_jax = indices_jax[start_idx:end_idx]
|
789 |
-
e = time.time()
|
790 |
-
print(f'get block indices with jax - time: {e-s}')
|
791 |
-
print(type(selected_indices_jax))
|
792 |
-
|
793 |
-
s = time.time()
|
794 |
-
selected_indices_np = indices_np[start_idx:end_idx]
|
795 |
-
e = time.time()
|
796 |
-
print(f'get block indices with np - time: {e-s}')
|
797 |
-
print(type(selected_indices_np))
|
798 |
-
|
799 |
-
|
800 |
-
s = time.time()
|
801 |
-
_ds = train_dataset.select(selected_indices_jax)
|
802 |
-
e = time.time()
|
803 |
-
print(f'select block with jax - time: {e-s}')
|
804 |
-
|
805 |
-
s = time.time()
|
806 |
-
_ds = train_dataset.select(selected_indices_np)
|
807 |
-
e = time.time()
|
808 |
-
print(f'select block with np - time: {e-s}')
|
809 |
-
|
810 |
-
s = time.time()
|
811 |
-
_selected_indices_np = np.array(selected_indices_jax)
|
812 |
-
e = time.time()
|
813 |
-
print(f'convert jax to np - time: {e-s}')
|
814 |
-
|
815 |
-
|
816 |
-
batch_size = 256
|
817 |
-
|
818 |
-
steps_per_epoch = len(_ds) // batch_size
|
819 |
-
|
820 |
-
s = time.time()
|
821 |
-
batch_idx_jax = jax.random.permutation(rng, len(_ds))
|
822 |
-
e = time.time()
|
823 |
-
print(f'get permutation indices for the block with jax - time: {e-s}')
|
824 |
-
# batch_idx = jnp.arange(len(dataset))
|
825 |
-
|
826 |
-
s = time.time()
|
827 |
-
batch_idx_np = np.random.permutation(len(_ds))
|
828 |
-
e = time.time()
|
829 |
-
print(f'get permutation indices for the block with np - time: {e-s}')
|
830 |
-
|
831 |
-
s = time.time()
|
832 |
-
batch_idx_jax = batch_idx_jax[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
833 |
-
e = time.time()
|
834 |
-
print(f'skip incomplete batch with jax - time: {e-s}')
|
835 |
-
|
836 |
-
s = time.time()
|
837 |
-
batch_idx_np = batch_idx_np[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
838 |
-
e = time.time()
|
839 |
-
print(f'skip incomplete batch with np - time: {e-s}')
|
840 |
-
|
841 |
-
s = time.time()
|
842 |
-
batch_idx_jax = batch_idx_jax.reshape((steps_per_epoch, batch_size))
|
843 |
-
e = time.time()
|
844 |
-
print(f'reshape block indices with jax - time: {e-s}')
|
845 |
-
|
846 |
-
s = time.time()
|
847 |
-
batch_idx_np = batch_idx_np.reshape((steps_per_epoch, batch_size))
|
848 |
-
e = time.time()
|
849 |
-
print(f'reshape block indices with np - time: {e-s}')
|
850 |
-
|
851 |
-
for idx in batch_idx_jax:
|
852 |
-
|
853 |
-
s = time.time()
|
854 |
-
batch = _ds[idx]
|
855 |
-
e = time.time()
|
856 |
-
print(f'get one batch with jax - time: {e-s}')
|
857 |
-
|
858 |
-
#s = time.time()
|
859 |
-
#batch = {k: jnp.array(v) for k, v in batch.items()}
|
860 |
-
#e = time.time()
|
861 |
-
#print(f'convert one batch to jnp time: {e-s}')
|
862 |
-
|
863 |
-
for idx in batch_idx_np:
|
864 |
-
|
865 |
-
s = time.time()
|
866 |
-
batch = _ds[idx]
|
867 |
-
e = time.time()
|
868 |
-
print(f'get one batch with np - time: {e-s}')
|
869 |
-
|
870 |
-
|
871 |
-
exit(0)
|
872 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
873 |
|
874 |
if training_args.do_predict:
|
875 |
if "test" not in dataset:
|
@@ -929,7 +874,11 @@ def main():
|
|
929 |
num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
|
930 |
|
931 |
if shuffle:
|
932 |
-
|
|
|
|
|
|
|
|
|
933 |
else:
|
934 |
indices = jnp.arange(len(ds))
|
935 |
|
@@ -938,9 +887,15 @@ def main():
|
|
938 |
start_idx = block_size * idx
|
939 |
end_idx = block_size * (idx + 1)
|
940 |
|
|
|
941 |
selected_indices = indices[start_idx:end_idx]
|
|
|
|
|
942 |
|
|
|
943 |
_ds = ds.select(selected_indices)
|
|
|
|
|
944 |
|
945 |
names = {
|
946 |
"train": "train",
|
@@ -948,6 +903,7 @@ def main():
|
|
948 |
"test": "prediction",
|
949 |
}
|
950 |
|
|
|
951 |
_ds =_ds.map(
|
952 |
feature_extraction_fn,
|
953 |
batched=True,
|
@@ -955,10 +911,13 @@ def main():
|
|
955 |
remove_columns=[image_column],
|
956 |
load_from_cache_file=not data_args.overwrite_cache,
|
957 |
features=features,
|
958 |
-
keep_in_memory=keep_in_memory,
|
|
|
959 |
desc=f"Running feature extraction on {names[split]} dataset".replace(" ", " "),
|
960 |
)
|
961 |
_ds = _ds.with_format("numpy")
|
|
|
|
|
962 |
|
963 |
# No need to shuffle here
|
964 |
loader = data_loader(rng, _ds, batch_size=batch_size, shuffle=False)
|
|
|
299 |
if shuffle:
|
300 |
batch_idx = jax.random.permutation(rng, len(dataset))
|
301 |
else:
|
302 |
+
s = time.time()
|
303 |
+
# batch_idx = jnp.arange(len(dataset))
|
304 |
+
batch_idx = np.arange(len(dataset))
|
305 |
+
e = time.time()
|
306 |
+
print(f'get permutation indices for the block with jax - time: {e-s}')
|
307 |
|
308 |
+
s = time.time()
|
309 |
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
310 |
+
e = time.time()
|
311 |
+
print(f'skip incomplete batch with jax - time: {e-s}')
|
312 |
+
|
313 |
+
s = time.time()
|
314 |
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
315 |
+
e = time.time()
|
316 |
+
print(f'reshape block indices with np - time: {e-s}')
|
317 |
|
318 |
for idx in batch_idx:
|
319 |
+
|
320 |
+
print(f'type idx: {type(idx)}')
|
321 |
+
|
322 |
+
print(f'pixel values type: {type(dataset["pixel_values"])}')
|
323 |
+
print(f'pixel values shape: {dataset["pixel_values"].shape}')
|
324 |
+
|
325 |
+
s = time.time()
|
326 |
batch = dataset[idx]
|
327 |
+
e = time.time()
|
328 |
+
print(f'get one batch with jax - time: {e-s}')
|
329 |
+
|
330 |
+
exit(0)
|
331 |
+
|
332 |
+
s = time.time()
|
333 |
batch = {k: jnp.array(v) for k, v in batch.items()}
|
334 |
+
e = time.time()
|
335 |
+
print(f'convert one batch from np to jax - time: {e-s}')
|
336 |
|
337 |
+
s = time.time()
|
338 |
batch = shard(batch)
|
339 |
+
e = time.time()
|
340 |
+
print(f'shard one batch with jax - time: {e-s}')
|
341 |
yield batch
|
342 |
|
343 |
|
|
|
778 |
)
|
779 |
|
780 |
if training_args.do_train:
|
|
|
781 |
if "train" not in dataset:
|
782 |
+
raise ValueError("--do_train requires a train dataset")
|
783 |
train_dataset = dataset["train"]
|
|
|
|
|
784 |
# remove problematic examples
|
|
|
785 |
train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
786 |
+
train_dataset = datasets.concatenate_datasets([train_dataset] * 205)
|
787 |
+
if data_args.max_train_samples is not None:
|
788 |
+
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
789 |
+
train_dataset = train_dataset.map(
|
790 |
+
tokenization_fn,
|
791 |
+
batched=True,
|
792 |
+
num_proc=data_args.preprocessing_num_workers,
|
793 |
+
# kept image paths
|
794 |
+
remove_columns=[x for x in column_names if x != image_column],
|
795 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
796 |
+
desc=f"Running tokenizer on train dataset",
|
797 |
+
fn_kwargs={"max_target_length": data_args.max_target_length},
|
798 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
799 |
|
800 |
+
if training_args.do_eval:
|
801 |
+
if "validation" not in dataset:
|
802 |
+
raise ValueError("--do_eval requires a validation dataset")
|
803 |
+
eval_dataset = dataset["validation"]
|
804 |
+
# remove problematic examples
|
805 |
+
eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
806 |
+
if data_args.max_eval_samples is not None:
|
807 |
+
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
808 |
+
eval_dataset = eval_dataset.map(
|
809 |
+
tokenization_fn,
|
810 |
+
batched=True,
|
811 |
+
num_proc=data_args.preprocessing_num_workers,
|
812 |
+
# kept image paths
|
813 |
+
remove_columns=[x for x in column_names if x != image_column],
|
814 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
815 |
+
desc=f"Running tokenizer on validation dataset",
|
816 |
+
fn_kwargs={"max_target_length": data_args.val_max_target_length},
|
817 |
+
)
|
818 |
|
819 |
if training_args.do_predict:
|
820 |
if "test" not in dataset:
|
|
|
874 |
num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
|
875 |
|
876 |
if shuffle:
|
877 |
+
s = time.time()
|
878 |
+
#indices = jax.random.permutation(input_rng, len(ds))
|
879 |
+
indices = np.random.permutation(len(train_dataset))
|
880 |
+
e = time.time()
|
881 |
+
print(f'get permutation indices for the whole dataset with jax - time: {e-s}')
|
882 |
else:
|
883 |
indices = jnp.arange(len(ds))
|
884 |
|
|
|
887 |
start_idx = block_size * idx
|
888 |
end_idx = block_size * (idx + 1)
|
889 |
|
890 |
+
s = time.time()
|
891 |
selected_indices = indices[start_idx:end_idx]
|
892 |
+
e = time.time()
|
893 |
+
print(f'get block indices with jax - time: {e-s}')
|
894 |
|
895 |
+
s = time.time()
|
896 |
_ds = ds.select(selected_indices)
|
897 |
+
e = time.time()
|
898 |
+
print(f'select block with jax - time: {e-s}')
|
899 |
|
900 |
names = {
|
901 |
"train": "train",
|
|
|
903 |
"test": "prediction",
|
904 |
}
|
905 |
|
906 |
+
s = time.time()
|
907 |
_ds =_ds.map(
|
908 |
feature_extraction_fn,
|
909 |
batched=True,
|
|
|
911 |
remove_columns=[image_column],
|
912 |
load_from_cache_file=not data_args.overwrite_cache,
|
913 |
features=features,
|
914 |
+
#keep_in_memory=keep_in_memory,
|
915 |
+
keep_in_memory=False,
|
916 |
desc=f"Running feature extraction on {names[split]} dataset".replace(" ", " "),
|
917 |
)
|
918 |
_ds = _ds.with_format("numpy")
|
919 |
+
e = time.time()
|
920 |
+
print(f'map feature extraction - time: {e-s}')
|
921 |
|
922 |
# No need to shuffle here
|
923 |
loader = data_loader(rng, _ds, batch_size=batch_size, shuffle=False)
|