boris committed on
Commit 600ad79
1 Parent(s): 498559f

feat: add adafactor

Files changed (1)
  1. seq2seq/run_seq2seq_flax.py +16 -9
seq2seq/run_seq2seq_flax.py CHANGED
@@ -623,17 +623,24 @@ def main():
         return traverse_util.unflatten_dict(flat_mask)
 
     # create adam optimizer
-    adamw = optax.adamw(
-        learning_rate=linear_decay_lr_schedule_fn,
-        b1=training_args.adam_beta1,
-        b2=training_args.adam_beta2,
-        eps=training_args.adam_epsilon,
-        weight_decay=training_args.weight_decay,
-        mask=decay_mask_fn,
-    )
+    if training_args.adafactor:
+        # We use the default parameters here to initialize adafactor,
+        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+        optimizer = optax.adafactor(
+            learning_rate=linear_decay_lr_schedule_fn,
+        )
+    else:
+        optimizer = optax.adamw(
+            learning_rate=linear_decay_lr_schedule_fn,
+            b1=training_args.adam_beta1,
+            b2=training_args.adam_beta2,
+            eps=training_args.adam_epsilon,
+            weight_decay=training_args.weight_decay,
+            mask=decay_mask_fn,
+        )
 
     # Setup train state
-    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
+    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
 
     # label smoothed cross entropy
     def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
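
For reference, the optimizer switch introduced by this commit can be exercised on its own. The following is a minimal sketch, not the script itself: `adafactor_enabled`, the linear-schedule values, the AdamW hyperparameters, and the toy parameters are illustrative stand-ins for `training_args.adafactor`, the script's `linear_decay_lr_schedule_fn`, the configured Adam settings, and the real model parameters; the `mask=decay_mask_fn` argument from the diff is omitted to keep the sketch self-contained.

# Minimal standalone sketch of the optimizer selection added in this commit.
# `adafactor_enabled` is a hypothetical stand-in for `training_args.adafactor`;
# schedule values, hyperparameters, and the toy params/grads are illustrative only.
import jax.numpy as jnp
import optax

# Stand-in for the script's linear_decay_lr_schedule_fn.
linear_decay_lr_schedule_fn = optax.linear_schedule(
    init_value=5e-3, end_value=0.0, transition_steps=1_000
)

adafactor_enabled = True  # stands in for training_args.adafactor

if adafactor_enabled:
    # Adafactor with optax's default settings; only the learning rate is passed,
    # mirroring the diff.
    optimizer = optax.adafactor(learning_rate=linear_decay_lr_schedule_fn)
else:
    # AdamW counterpart; the diff additionally passes mask=decay_mask_fn,
    # which is omitted here.
    optimizer = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=0.9,
        b2=0.999,
        eps=1e-8,
        weight_decay=0.0,
    )

# One gradient step on toy parameters to show the optimizer is usable as-is.
params = {"w": jnp.ones((3,)), "b": jnp.zeros((3,))}
grads = {"w": jnp.full((3,), 0.1), "b": jnp.full((3,), -0.2)}
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

Adafactor is commonly chosen over AdamW to reduce optimizer memory, since it factors the second-moment statistics rather than storing a full per-parameter moment, which is the usual motivation for offering it as an alternative here.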