updates
Browse files- run.sh +2 -2
- run_mlm_flax_stream.py +6 -2
run.sh
CHANGED
@@ -6,8 +6,8 @@ python run_mlm_flax_stream.py \
|
|
6 |
--dataset_name="NbAiLab/scandinavian" \
|
7 |
--max_seq_length="512" \
|
8 |
--weight_decay="0.01" \
|
9 |
-
--per_device_train_batch_size="
|
10 |
-
--per_device_eval_batch_size="
|
11 |
--learning_rate="1e-4" \
|
12 |
--warmup_steps="10000" \
|
13 |
--overwrite_output_dir \
|
|
|
6 |
--dataset_name="NbAiLab/scandinavian" \
|
7 |
--max_seq_length="512" \
|
8 |
--weight_decay="0.01" \
|
9 |
+
--per_device_train_batch_size="12" \
|
10 |
+
--per_device_eval_batch_size="12" \
|
11 |
--learning_rate="1e-4" \
|
12 |
--warmup_steps="10000" \
|
13 |
--overwrite_output_dir \
|
run_mlm_flax_stream.py
CHANGED
@@ -395,11 +395,11 @@ if __name__ == "__main__":
|
|
395 |
|
396 |
if model_args.tokenizer_name:
|
397 |
tokenizer = AutoTokenizer.from_pretrained(
|
398 |
-
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
399 |
)
|
400 |
elif model_args.model_name_or_path:
|
401 |
tokenizer = AutoTokenizer.from_pretrained(
|
402 |
-
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
403 |
)
|
404 |
else:
|
405 |
raise ValueError(
|
@@ -451,6 +451,10 @@ if __name__ == "__main__":
|
|
451 |
num_epochs = int(training_args.num_train_epochs)
|
452 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
453 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
|
|
|
|
|
|
|
|
454 |
|
455 |
# define number steps per stream epoch
|
456 |
num_train_steps = data_args.num_train_steps
|
|
|
395 |
|
396 |
if model_args.tokenizer_name:
|
397 |
tokenizer = AutoTokenizer.from_pretrained(
|
398 |
+
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer,model_max_length=512
|
399 |
)
|
400 |
elif model_args.model_name_or_path:
|
401 |
tokenizer = AutoTokenizer.from_pretrained(
|
402 |
+
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer,model_max_length=512
|
403 |
)
|
404 |
else:
|
405 |
raise ValueError(
|
|
|
451 |
num_epochs = int(training_args.num_train_epochs)
|
452 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
453 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
454 |
+
|
455 |
+
print("***************************")
|
456 |
+
print(f"Train Batch Size: {train_batch_size}")
|
457 |
+
print("***************************")
|
458 |
|
459 |
# define number steps per stream epoch
|
460 |
num_train_steps = data_args.num_train_steps
|