pere commited on
Commit
d2beff6
1 Parent(s): cb63820
Files changed (2) hide show
  1. run.sh +2 -2
  2. run_mlm_flax_stream.py +6 -2
run.sh CHANGED
@@ -6,8 +6,8 @@ python run_mlm_flax_stream.py \
6
  --dataset_name="NbAiLab/scandinavian" \
7
  --max_seq_length="512" \
8
  --weight_decay="0.01" \
9
- --per_device_train_batch_size="62" \
10
- --per_device_eval_batch_size="62" \
11
  --learning_rate="1e-4" \
12
  --warmup_steps="10000" \
13
  --overwrite_output_dir \
 
6
  --dataset_name="NbAiLab/scandinavian" \
7
  --max_seq_length="512" \
8
  --weight_decay="0.01" \
9
+ --per_device_train_batch_size="12" \
10
+ --per_device_eval_batch_size="12" \
11
  --learning_rate="1e-4" \
12
  --warmup_steps="10000" \
13
  --overwrite_output_dir \
run_mlm_flax_stream.py CHANGED
@@ -395,11 +395,11 @@ if __name__ == "__main__":
395
 
396
  if model_args.tokenizer_name:
397
  tokenizer = AutoTokenizer.from_pretrained(
398
- model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
399
  )
400
  elif model_args.model_name_or_path:
401
  tokenizer = AutoTokenizer.from_pretrained(
402
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
403
  )
404
  else:
405
  raise ValueError(
@@ -451,6 +451,10 @@ if __name__ == "__main__":
451
  num_epochs = int(training_args.num_train_epochs)
452
  train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
453
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
 
 
 
 
454
 
455
  # define number steps per stream epoch
456
  num_train_steps = data_args.num_train_steps
 
395
 
396
  if model_args.tokenizer_name:
397
  tokenizer = AutoTokenizer.from_pretrained(
398
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer,model_max_length=512
399
  )
400
  elif model_args.model_name_or_path:
401
  tokenizer = AutoTokenizer.from_pretrained(
402
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer,model_max_length=512
403
  )
404
  else:
405
  raise ValueError(
 
451
  num_epochs = int(training_args.num_train_epochs)
452
  train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
453
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
454
+
455
+ print("***************************")
456
+ print(f"Train Batch Size: {train_batch_size}")
457
+ print("***************************")
458
 
459
  # define number steps per stream epoch
460
  num_train_steps = data_args.num_train_steps