boris committed on
Commit
bab75aa
1 Parent(s): 0df810d

fix: comments

Browse files
Files changed (1) hide show
  1. dev/seq2seq/run_seq2seq_flax.py +2 -2
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -605,8 +605,8 @@ def main():
605
  # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
606
  optimizer = optax.adafactor(
607
  learning_rate=learning_rate_fn,
608
- # weight_decay_rate=training_args.weight_decay,
609
- # weight_decay_mask=decay_mask_fn,
610
  )
611
  else:
612
  optimizer = optax.adamw(
 
605
  # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
606
  optimizer = optax.adafactor(
607
  learning_rate=learning_rate_fn,
608
+ weight_decay_rate=training_args.weight_decay,
609
+ weight_decay_mask=decay_mask_fn,
610
  )
611
  else:
612
  optimizer = optax.adamw(