ydshieh committed · Commit 7245cb4 · Parent(s): 9ca46fa

update debug.py
debug.py CHANGED
@@ -298,46 +298,22 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
 
     if shuffle:
         batch_idx = jax.random.permutation(rng, len(dataset))
+        batch_idx = np.asarray(batch_idx)
     else:
-        s = time.time()
-        # batch_idx = jnp.arange(len(dataset))
         batch_idx = np.arange(len(dataset))
-        e = time.time()
-        print(f'get permutation indices for the block with jax - time: {e-s}')
 
-    s = time.time()
     batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
-    e = time.time()
-    print(f'skip incomplete batch with jax - time: {e-s}')
-
-    s = time.time()
     batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
-    e = time.time()
-    print(f'reshape block indices with np - time: {e-s}')
 
     for idx in batch_idx:
-
-        print(f'type idx: {type(idx)}')
-
-        print(f'pixel values type: {type(dataset["pixel_values"])}')
-        print(f'pixel values shape: {dataset["pixel_values"].shape}')
-
         s = time.time()
         batch = dataset[idx]
        e = time.time()
-        print(f'
-
-        exit(0)
-
-        s = time.time()
+        print(f'fetch batch time: {e-s}')
         batch = {k: jnp.array(v) for k, v in batch.items()}
-        e = time.time()
-        print(f'convert one batch from np to jax - time: {e-s}')
 
-        s = time.time()
         batch = shard(batch)
-
-        print(f'shard one batch with jax - time: {e-s}')
+
         yield batch
 
 
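The substantive change in data_loader is the np.asarray(batch_idx) added after jax.random.permutation: presumably so that batch indexing runs with a host-side NumPy index array rather than a JAX device array, which is what the removed debug prints were timing. A minimal sketch of the cleaned-up loader, using a plain dict of NumPy arrays as a stand-in for the script's datasets.Dataset:

# Minimal sketch of the cleaned-up batching pattern. `dataset` here is a
# plain dict of NumPy arrays standing in for the script's datasets.Dataset;
# batch_size is assumed divisible by jax.local_device_count() for shard().
import jax
import jax.numpy as jnp
import numpy as np
from flax.training.common_utils import shard

def data_loader(rng, dataset, batch_size, shuffle=False):
    n = len(dataset["pixel_values"])
    steps_per_epoch = n // batch_size
    if shuffle:
        batch_idx = jax.random.permutation(rng, n)
        # Convert once up front: the fancy indexing below then stays on the
        # host instead of pulling a device array back for every batch.
        batch_idx = np.asarray(batch_idx)
    else:
        batch_idx = np.arange(n)
    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    for idx in batch_idx:
        batch = {k: jnp.array(v[idx]) for k, v in dataset.items()}
        yield shard(batch)  # Split the leading batch axis across local devices.

With a dict of in-memory arrays the conversion matters little; with an Arrow-backed datasets.Dataset, a NumPy index array avoids the per-batch device round-trip the old prints were measuring.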
@@ -781,9 +757,9 @@ def main():
         if "train" not in dataset:
             raise ValueError("--do_train requires a train dataset")
         train_dataset = dataset["train"]
+        train_dataset = datasets.concatenate_datasets([train_dataset] * 205)
         # remove problematic examples
         train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
-        train_dataset = datasets.concatenate_datasets([train_dataset] * 205)
         if data_args.max_train_samples is not None:
             train_dataset = train_dataset.select(range(data_args.max_train_samples))
         train_dataset = train_dataset.map(
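All three splits use the same inflation trick: concatenate_datasets repeats the split (205x for train and eval, 1024x for test in the hunks below) to stress-test the pipeline, and this hunk moves the train-split copy ahead of the filter. A toy illustration, with hypothetical column names:

# Toy illustration of the dataset-inflation trick; columns are hypothetical.
import datasets

base = datasets.Dataset.from_dict({"image_id": [0, 1, 2], "caption": ["a", "b", "c"]})
big = datasets.concatenate_datasets([base] * 205)
print(len(base), len(big))  # 3 615

concatenate_datasets joins the underlying Arrow tables, so repeating a dataset this way should be far cheaper than physically copying rows.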
@@ -803,6 +779,7 @@ def main():
         eval_dataset = dataset["validation"]
         # remove problematic examples
         eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
+        eval_dataset = datasets.concatenate_datasets([eval_dataset] * 205)
         if data_args.max_eval_samples is not None:
             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
         eval_dataset = eval_dataset.map(
@@ -820,6 +797,7 @@ def main():
         if "test" not in dataset:
             raise ValueError("--do_predict requires a test dataset")
         predict_dataset = dataset["test"]
+        predict_dataset = datasets.concatenate_datasets([predict_dataset] * 1024)
         # remove problematic examples
         predict_dataset = predict_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
         if data_args.max_predict_samples is not None:
@@ -840,7 +818,7 @@ def main():
     # Split the dataset into several chunks - each chunk is processed (.map) without cache to create a
     # data loader separately (in a sequential order).
     block_size = training_args.block_size
-
+
     # Store some constant
 
     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
@@ -874,28 +852,22 @@ def main():
         num_splits = steps // steps_per_split + int(steps % steps_per_split > 0)
 
         if shuffle:
-
-
-            indices = np.random.permutation(len(train_dataset))
-            e = time.time()
-            print(f'get permutation indices for the whole dataset with jax - time: {e-s}')
+            indices = jax.random.permutation(rng, len(train_dataset))
+            indices = np.asarray(indices)
         else:
-            indices =
+            indices = np.arange(len(ds))
 
         for idx in range(num_splits):
 
             start_idx = block_size * idx
             end_idx = block_size * (idx + 1)
 
-            s = time.time()
             selected_indices = indices[start_idx:end_idx]
-            e = time.time()
-            print(f'get block indices with jax - time: {e-s}')
 
             s = time.time()
             _ds = ds.select(selected_indices)
             e = time.time()
-            print(f'select block
+            print(f'select block time: {e-s}')
 
             names = {
                 "train": "train",
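The shuffle branch now draws one global permutation with the seeded jax.random.permutation and converts it to NumPy once; each block then slices its index window and materializes it with ds.select. A compact sketch of that loop, where iter_blocks is a hypothetical helper and deriving num_splits from len(ds) is an assumption (the script derives it from its step counts):

# Sketch of the block-wise split; iter_blocks is a hypothetical helper, and
# computing num_splits from len(ds) is an assumption made for self-containment.
import numpy as np
import jax

def iter_blocks(ds, rng, block_size, shuffle=True):
    num_splits = len(ds) // block_size + int(len(ds) % block_size > 0)
    if shuffle:
        # One seeded permutation for the whole dataset, converted to NumPy once.
        indices = np.asarray(jax.random.permutation(rng, len(ds)))
    else:
        indices = np.arange(len(ds))
    for idx in range(num_splits):
        selected_indices = indices[block_size * idx : block_size * (idx + 1)]
        # Dataset.select builds a lightweight view over the chosen rows.
        yield ds.select(selected_indices)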
@@ -904,20 +876,19 @@ def main():
             }
 
             s = time.time()
-            _ds =_ds.map(
+            _ds = _ds.map(
                 feature_extraction_fn,
                 batched=True,
                 num_proc=data_args.preprocessing_num_workers,
                 remove_columns=[image_column],
                 load_from_cache_file=not data_args.overwrite_cache,
                 features=features,
-
-                keep_in_memory=False,
+                keep_in_memory=keep_in_memory,
                 desc=f"Running feature extraction on {names[split]} dataset".replace("  ", " "),
             )
             _ds = _ds.with_format("numpy")
             e = time.time()
-            print(f'map
+            print(f'map time: {e-s}')
 
             # No need to shuffle here
             loader = data_loader(rng, _ds, batch_size=batch_size, shuffle=False)
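Each block is feature-extracted with .map, and keep_in_memory is now a variable rather than a hard-coded False, so transient blocks can skip writing Arrow cache files. A self-contained sketch, with a stub feature_extraction_fn standing in for the real image-processor call:

# Self-contained sketch of the per-block map. The toy dataset and the stub
# feature_extraction_fn are hypothetical stand-ins for the script's real ones.
import numpy as np
import datasets

block = datasets.Dataset.from_dict({"image_path": ["a.jpg", "b.jpg"]})

def feature_extraction_fn(examples):
    # Stand-in: the real function would run the image processor on each file.
    n = len(examples["image_path"])
    return {"pixel_values": np.zeros((n, 3, 224, 224), dtype=np.float32)}

block = block.map(
    feature_extraction_fn,
    batched=True,
    remove_columns=["image_path"],
    keep_in_memory=True,        # no cache file is written for a transient block
    load_from_cache_file=False,
)
block = block.with_format("numpy")  # indexing now returns NumPy arrays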