boris committed
Commit 3ddf1c5
Parents: 97a008e, b29bab7

Merge branch 'main'

seq2seq/run_seq2seq_flax.py CHANGED
@@ -161,6 +161,9 @@ class DataTrainingArguments:
             "than this will be truncated, sequences shorter will be padded."
         },
     )
+    no_decay: bool = field(
+        default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
+    )
     max_target_length: Optional[int] = field(
         default=OUTPUT_LENGTH,
         metadata={
@@ -324,12 +327,14 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
 
 
 def create_learning_rate_fn(
-    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
 ) -> Callable[[int], jnp.array]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+    if no_decay:
+        return warmup_fn
     decay_fn = optax.linear_schedule(
         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
     )
@@ -595,6 +600,7 @@ def main():
         training_args.num_train_epochs,
         training_args.warmup_steps,
         training_args.learning_rate,
+        data_args.no_decay
     )
 
     # We use Optax's "masking" functionality to not apply weight decay
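With this change, passing --no_decay makes create_learning_rate_fn return only the warmup schedule; optax.linear_schedule clamps at its end value, so the learning rate ramps up during warmup and then holds at learning_rate instead of decaying to zero. Below is a minimal sketch of the resulting behavior; the tail of create_learning_rate_fn is not visible in the hunk, so the optax.join_schedules call in the decay branch is an assumption based on the usual Flax seq2seq pattern, not code shown in this commit.

    import optax

    def make_schedule(learning_rate, num_warmup_steps, num_train_steps, no_decay):
        # Linear ramp from 0 to learning_rate over num_warmup_steps.
        warmup_fn = optax.linear_schedule(
            init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
        )
        if no_decay:
            # linear_schedule holds end_value after transition_steps, so the
            # learning rate stays flat at its peak for the rest of training.
            return warmup_fn
        decay_fn = optax.linear_schedule(
            init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
        )
        # Assumed join of the two phases (not visible in the hunk above).
        return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])

    flat = make_schedule(3e-4, num_warmup_steps=100, num_train_steps=1000, no_decay=True)
    print(flat(50), flat(100), flat(999))  # mid-warmup, peak, still at peak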
seq2seq/sweep.yaml CHANGED
@@ -37,6 +37,7 @@ command:
 - 56
 - "--preprocessing_num_workers"
 - 80
+- "--no_decay"
 - "--do_train"
 - "--do_eval"
 - ${args}
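The bare "--no_decay" entry works in the sweep command because HfArgumentParser turns a bool dataclass field with a False default into a flag that becomes True when present on the command line. A standalone sketch of that behavior, using a stripped-down stand-in for the real DataTrainingArguments class in run_seq2seq_flax.py:

    from dataclasses import dataclass, field
    from transformers import HfArgumentParser

    # Stand-in holding just the field added by this commit.
    @dataclass
    class DataTrainingArguments:
        no_decay: bool = field(
            default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
        )

    (data_args,) = HfArgumentParser(DataTrainingArguments).parse_args_into_dataclasses(["--no_decay"])
    print(data_args.no_decay)  # True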