boris committed
Commit b29bab7
Parents: f0a53ac 5a3211f

Merge pull request #21 from borisdayma/feat-no_decay

seq2seq/run_seq2seq_flax.py CHANGED
@@ -162,6 +162,9 @@ class DataTrainingArguments:
             "than this will be truncated, sequences shorter will be padded."
         },
     )
+    no_decay: bool = field(
+        default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
+    )
     max_target_length: Optional[int] = field(
         default=OUTPUT_LENGTH,
         metadata={
@@ -338,12 +341,14 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
 
 def create_learning_rate_fn(
-    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
 ) -> Callable[[int], jnp.array]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+    if no_decay:
+        return warmup_fn
     decay_fn = optax.linear_schedule(
         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
     )
@@ -616,6 +621,7 @@ def main():
         training_args.num_train_epochs,
         training_args.warmup_steps,
         training_args.learning_rate,
+        data_args.no_decay
     )
 
     # We use Optax's "masking" functionality to not apply weight decay
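For context, the patched schedule builder ends up reading roughly like the sketch below. This is not the file itself: the closing optax.join_schedules line is not shown in the hunks above and is assumed from the pattern used in the upstream Flax seq2seq examples; the comments are added here for illustration.

from typing import Callable

import jax.numpy as jnp
import optax


def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
) -> Callable[[int], jnp.array]:
    """Returns a linear warmup, linear_decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    if no_decay:
        # optax.linear_schedule holds its end_value once transition_steps is
        # reached, so returning the warmup schedule alone gives a constant
        # learning rate after warmup.
        return warmup_fn
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    # Assumed continuation (standard pattern in the Flax examples, not part of this diff):
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])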
seq2seq/sweep.yaml CHANGED
@@ -37,6 +37,7 @@ command:
   - 56
   - "--preprocessing_num_workers"
   - 80
+  - "--no_decay"
   - "--do_train"
   - "--do_eval"
   - ${args}
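Assuming the script parses its dataclasses with HfArgumentParser, as the upstream Flax examples do, the bare --no_decay flag in the sweep command sets the boolean to True for every run, selecting the constant-after-warmup schedule. A quick check of the two shapes, using the sketch above with made-up numbers (dataset size, batch size, epochs, warmup, and peak rate are illustrative only):

lr_decay = create_learning_rate_fn(10_000, 32, 3, 100, 3e-4, no_decay=False)
lr_const = create_learning_rate_fn(10_000, 32, 3, 100, 3e-4, no_decay=True)
print(lr_decay(100), lr_const(100))  # end of warmup: both at the peak rate 3e-4
print(lr_decay(500), lr_const(500))  # decayed below 3e-4 vs. still 3e-4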