sakares committed on
Commit
6aba5c0
1 Parent(s): b0729e9

update script to follow https://github.com/huggingface/transformers/pull/12608

Files changed (1)
  1. run_mlm_flax.py +26 -12
run_mlm_flax.py CHANGED
@@ -431,7 +431,8 @@ if __name__ == "__main__":
         total_length = len(concatenated_examples[list(examples.keys())[0]])
         # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
         # customize this part to your needs.
-        total_length = (total_length // max_seq_length) * max_seq_length
+        if total_length >= max_seq_length:
+            total_length = (total_length // max_seq_length) * max_seq_length
         # Split by chunks of max_len.
         result = {
             k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
@@ -478,7 +479,14 @@ if __name__ == "__main__":
     rng = jax.random.PRNGKey(training_args.seed)
     dropout_rngs = jax.random.split(rng, jax.local_device_count())
 
-    model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+    if model_args.model_name_or_path:
+        model = FlaxAutoModelForMaskedLM.from_pretrained(
+            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+    else:
+        model = FlaxAutoModelForMaskedLM.from_config(
+            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
 
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
@@ -513,17 +521,24 @@ if __name__ == "__main__":
         return traverse_util.unflatten_dict(flat_mask)
 
     # create adam optimizer
-    adamw = optax.adamw(
-        learning_rate=linear_decay_lr_schedule_fn,
-        b1=training_args.adam_beta1,
-        b2=training_args.adam_beta2,
-        eps=1e-8,
-        weight_decay=training_args.weight_decay,
-        mask=decay_mask_fn,
-    )
+    if training_args.adafactor:
+        # We use the default parameters here to initialize adafactor,
+        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+        optimizer = optax.adafactor(
+            learning_rate=linear_decay_lr_schedule_fn,
+        )
+    else:
+        optimizer = optax.adamw(
+            learning_rate=linear_decay_lr_schedule_fn,
+            b1=training_args.adam_beta1,
+            b2=training_args.adam_beta2,
+            eps=training_args.adam_epsilon,
+            weight_decay=training_args.weight_decay,
+            mask=decay_mask_fn,
+        )
 
     # Setup train state
-    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
+    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
 
     # Define gradient update step fn
     def train_step(state, batch, dropout_rng):
@@ -648,7 +663,6 @@ if __name__ == "__main__":
 
         # Save metrics
        if has_tensorboard and jax.process_index() == 0:
-            cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
             write_eval_metric(summary_writer, eval_metrics, cur_step)
 
         if cur_step % training_args.save_steps == 0 and cur_step > 0:
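
Two of the hunks above are worth a brief note.

The `group_texts` guard keeps short inputs from being lost: with the old line, a concatenated batch shorter than `max_seq_length` (say `total_length == 300` with `max_seq_length == 512`) would be rounded down to `0` and silently dropped, while the new `if total_length >= max_seq_length:` leaves such a batch intact as a single shorter example.

The optimizer hunks plug a flag-selected optax transformation into Flax's `TrainState`. Below is a minimal, self-contained sketch of that pattern, not the training script itself; the toy `nn.Dense` model, the `use_adafactor` flag, and the constant learning rate are placeholders standing in for `FlaxAutoModelForMaskedLM`, `training_args.adafactor`, and `linear_decay_lr_schedule_fn`.

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state

use_adafactor = True   # placeholder for training_args.adafactor
learning_rate = 1e-3   # placeholder for linear_decay_lr_schedule_fn

if use_adafactor:
    # Adafactor keeps factored second-moment statistics, so its optimizer state
    # is much smaller than AdamW's per-parameter moments.
    optimizer = optax.adafactor(learning_rate=learning_rate)
else:
    optimizer = optax.adamw(
        learning_rate=learning_rate,
        b1=0.9,
        b2=0.999,
        eps=1e-8,
        weight_decay=0.01,
    )

# A tiny stand-in model; the real script builds FlaxAutoModelForMaskedLM.
model = nn.Dense(features=8)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))["params"]

# TrainState.create accepts any optax GradientTransformation as `tx`, which is
# why the rest of the training loop is unchanged no matter which branch picked
# `optimizer`.
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)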