Merge pull request #21 from borisdayma/feat-no_decay

Files changed:
- seq2seq/run_seq2seq_flax.py (+7 -1)
- seq2seq/sweep.yaml (+1 -0)

seq2seq/run_seq2seq_flax.py (CHANGED)
@@ -162,6 +162,9 @@ class DataTrainingArguments:
                 "than this will be truncated, sequences shorter will be padded."
             },
         )
+    no_decay: bool = field(
+        default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
+    )
     max_target_length: Optional[int] = field(
         default=OUTPUT_LENGTH,
         metadata={
@@ -338,12 +341,14 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
 
 def create_learning_rate_fn(
-    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
 ) -> Callable[[int], jnp.array]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+    if no_decay:
+        return warmup_fn
     decay_fn = optax.linear_schedule(
         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
     )
@@ -616,6 +621,7 @@ def main():
         training_args.num_train_epochs,
         training_args.warmup_steps,
         training_args.learning_rate,
+        data_args.no_decay
     )
 
     # We use Optax's "masking" functionality to not apply weight decay
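
The change itself is small: create_learning_rate_fn gains a no_decay argument and returns the warmup schedule directly when it is set, so the learning rate ramps up linearly and then stays flat instead of decaying back to zero. The standalone sketch below illustrates that behaviour with optax; the join_schedules call is an assumption about how the warmup and decay schedules are combined outside this hunk, and make_lr_fn plus the step and learning-rate values are made up for illustration.

import optax


def make_lr_fn(num_train_steps, num_warmup_steps, learning_rate, no_decay):
    # Mirrors the diff: linear warmup, then either hold the peak value or decay to 0.
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    if no_decay:
        # optax.linear_schedule holds its end_value once transition_steps is reached,
        # so returning the warmup schedule alone gives a constant LR after warmup.
        return warmup_fn
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0.0, transition_steps=num_train_steps - num_warmup_steps
    )
    # Assumed combination (not shown in the hunk): warmup followed by linear decay.
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])


flat = make_lr_fn(num_train_steps=1000, num_warmup_steps=100, learning_rate=3e-4, no_decay=True)
decayed = make_lr_fn(num_train_steps=1000, num_warmup_steps=100, learning_rate=3e-4, no_decay=False)
for step in (0, 50, 100, 500, 1000):
    print(step, float(flat(step)), float(decayed(step)))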

seq2seq/sweep.yaml (CHANGED)

@@ -37,6 +37,7 @@ command:
 - 56
 - "--preprocessing_num_workers"
 - 80
+- "--no_decay"
 - "--do_train"
 - "--do_eval"
 - ${args}
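
On the sweep side, a single extra "--no_decay" entry in the command list is enough, because transformers' HfArgumentParser exposes a bool dataclass field with default=False as a flag that becomes True when passed without a value. A minimal standalone sketch, trimmed to just the new field (the class below is not the full DataTrainingArguments):

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:
    # Only the field added in this PR; the real class has many more arguments.
    no_decay: bool = field(
        default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
    )


parser = HfArgumentParser(DataTrainingArguments)
(data_args,) = parser.parse_args_into_dataclasses(args=["--no_decay"])
print(data_args.no_decay)  # True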