Merge branch 'main'

Files changed:
- seq2seq/run_seq2seq_flax.py  +7 -1
- seq2seq/sweep.yaml  +1 -0
seq2seq/run_seq2seq_flax.py  (CHANGED)

@@ -161,6 +161,9 @@ class DataTrainingArguments:
             "than this will be truncated, sequences shorter will be padded."
         },
     )
+    no_decay: bool = field(
+        default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
+    )
     max_target_length: Optional[int] = field(
         default=OUTPUT_LENGTH,
         metadata={
@@ -324,12 +327,14 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf


 def create_learning_rate_fn(
-    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
 ) -> Callable[[int], jnp.array]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+    if no_decay:
+        return warmup_fn
     decay_fn = optax.linear_schedule(
         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
     )
@@ -595,6 +600,7 @@ def main():
         training_args.num_train_epochs,
         training_args.warmup_steps,
         training_args.learning_rate,
+        data_args.no_decay
     )

     # We use Optax's "masking" functionality to not apply weight decay
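For context, here is a sketch of how create_learning_rate_fn reads once this patch is applied. The closing lines of the function fall outside the hunk above, so the optax.join_schedules ending is an assumption based on the standard Flax seq2seq example rather than something shown in this diff; everything else follows the lines in the hunk.

    from typing import Callable

    import jax.numpy as jnp
    import optax


    def create_learning_rate_fn(
        train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
    ) -> Callable[[int], jnp.array]:
        """Returns a linear warmup, linear_decay learning rate function."""
        steps_per_epoch = train_ds_size // train_batch_size
        num_train_steps = steps_per_epoch * num_train_epochs
        # Linear warmup from 0 to the peak learning rate over num_warmup_steps.
        warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
        if no_decay:
            # New behaviour: after warmup, linear_schedule holds its end_value,
            # so the learning rate stays at the peak instead of decaying.
            return warmup_fn
        # Linear decay from the peak learning rate down to 0 over the remaining steps.
        decay_fn = optax.linear_schedule(
            init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
        )
        # Assumed ending (not part of the hunk): concatenate warmup and decay at the warmup boundary.
        schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
        return schedule_fn


    # Example (illustrative values): with no_decay=True the rate plateaus at the
    # peak value once warmup is done, at step 500 and at step 5000 alike.
    lr_fn = create_learning_rate_fn(10000, 64, 3, 500, 5e-5, no_decay=True)
    print(lr_fn(500), lr_fn(5000))

With no_decay set, the schedule simply plateaus at learning_rate after num_warmup_steps; without it, the behaviour is unchanged from before the patch.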
seq2seq/sweep.yaml  (CHANGED)

@@ -37,6 +37,7 @@ command:
   - 56
   - "--preprocessing_num_workers"
   - 80
+  - "--no_decay"
   - "--do_train"
   - "--do_eval"
   - ${args}
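As a usage note: the new --no_decay entry in the sweep command reaches the script because the training script parses its arguments into dataclasses (as the Flax seq2seq example scripts do with transformers' HfArgumentParser), which exposes a bool field with default=False as an optional flag. A minimal, stripped-down sketch under that assumption; the real DataTrainingArguments has many more fields:

    from dataclasses import dataclass, field

    from transformers import HfArgumentParser


    @dataclass
    class DataTrainingArguments:
        # Only the field added by this commit is shown here.
        no_decay: bool = field(
            default=False, metadata={"help": "Whether to use decay in the learning rate scheduler."}
        )


    parser = HfArgumentParser(DataTrainingArguments)
    # Passing the bare flag, as the sweep command now does, flips the field to True.
    (data_args,) = parser.parse_args_into_dataclasses(args=["--no_decay"])
    print(data_args.no_decay)  # True

data_args.no_decay is then forwarded to create_learning_rate_fn, as shown in the last hunk of run_seq2seq_flax.py above.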