feat: no decay option
Files changed:
- seq2seq/run_seq2seq_flax.py (+7 -1)
- seq2seq/sweep.yaml (+1 -0)
seq2seq/run_seq2seq_flax.py  CHANGED

@@ -162,6 +162,9 @@ class DataTrainingArguments:
             "than this will be truncated, sequences shorter will be padded."
         },
     )
+    no_decay: bool = field(
+        default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
+    )
     max_target_length: Optional[int] = field(
         default=OUTPUT_LENGTH,
         metadata={

@@ -332,12 +335,14 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
 
 def create_learning_rate_fn(
-    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
 ) -> Callable[[int], jnp.array]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+    if no_decay:
+        return warmup_fn
     decay_fn = optax.linear_schedule(
         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
     )

@@ -610,6 +615,7 @@ def main():
         training_args.num_train_epochs,
         training_args.warmup_steps,
         training_args.learning_rate,
+        data_args.no_decay
     )
 
     # We use Optax's "masking" functionality to not apply weight decay
seq2seq/sweep.yaml  CHANGED

@@ -37,6 +37,7 @@ command:
   - 56
   - "--preprocessing_num_workers"
   - 80
+  - "--no_decay"
   - "--do_train"
   - "--do_eval"
   - ${args}