boris committed on
Commit 5a3211f
1 Parent(s): 7aa2f4b

feat: no decay option

seq2seq/run_seq2seq_flax.py CHANGED
@@ -162,6 +162,9 @@ class DataTrainingArguments:
             "than this will be truncated, sequences shorter will be padded."
         },
     )
+    no_decay: bool = field(
+        default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
+    )
     max_target_length: Optional[int] = field(
         default=OUTPUT_LENGTH,
         metadata={
@@ -332,12 +335,14 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
 
 def create_learning_rate_fn(
-    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
 ) -> Callable[[int], jnp.array]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+    if no_decay:
+        return warmup_fn
     decay_fn = optax.linear_schedule(
         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
     )
@@ -610,6 +615,7 @@ def main():
         training_args.num_train_epochs,
         training_args.warmup_steps,
         training_args.learning_rate,
+        data_args.no_decay
     )
 
     # We use Optax's "masking" functionality to not apply weight decay
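For readers skimming the diff, here is a minimal sketch (not part of the commit) of what the new no_decay flag changes in the schedule returned by create_learning_rate_fn. The step counts and learning rate below are illustrative, and the warmup-plus-decay case is assumed to combine the two pieces with optax.join_schedules, as in the standard Flax seq2seq setup.

# Sketch only: illustrative values, not taken from the training config.
import optax

learning_rate, num_warmup_steps, num_train_steps = 3e-4, 100, 1000

warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
    init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)

# With --no_decay, create_learning_rate_fn now returns warmup_fn alone.
# optax.linear_schedule holds its end value once transition_steps is reached,
# so the learning rate ramps up over 100 steps and then stays at 3e-4.
schedule_no_decay = warmup_fn

# Without the flag, warmup and decay are joined (optax.join_schedules assumed
# here), so the rate decays linearly to 0 after warmup.
schedule_with_decay = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])

for step in (0, 50, 100, 550, 1000):
    print(step, float(schedule_no_decay(step)), float(schedule_with_decay(step)))

At step 550, for example, the no-decay schedule still returns 3e-4, while the decaying one is halfway to zero at 1.5e-4.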
seq2seq/sweep.yaml CHANGED
@@ -37,6 +37,7 @@ command:
   - 56
   - "--preprocessing_num_workers"
   - 80
+  - "--no_decay"
   - "--do_train"
   - "--do_eval"
   - ${args}
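Since --no_decay is passed as a bare flag in the sweep command, it is worth sketching how it reaches the new dataclass field. This is a minimal sketch assuming the run script parses DataTrainingArguments with transformers.HfArgumentParser (the usual pattern in these Flax scripts); the dataclass is trimmed to the one field added by this commit.

# Sketch only: trimmed dataclass, HfArgumentParser-based CLI parsing assumed.
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class DataTrainingArguments:
    # Only the field added by this commit; the real class has many more.
    no_decay: bool = field(
        default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
    )

parser = HfArgumentParser(DataTrainingArguments)
(data_args,) = parser.parse_args_into_dataclasses(["--no_decay"])
print(data_args.no_decay)  # True: the bare flag flips the default of False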