boris committed on
Commit
87fac28
1 Parent(s): b7d8724

feat: simplify parameters

Files changed (1)
  1. dev/seq2seq/run_seq2seq_flax.py +20 -47
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -151,7 +151,7 @@ class DataTrainingArguments:
             "than this will be truncated, sequences shorter will be padded."
         },
     )
-    no_decay: bool = field(
+    use_decay: bool = field(
         default=False,
         metadata={"help": "Whether to use decay in the learning rate scheduler."},
     )
@@ -170,18 +170,16 @@ class DataTrainingArguments:
         },
     )
     preprocessing_num_workers: Optional[int] = field(
-        default=80, # ensure we have the same datasets cached data and avoid using too much space
-        metadata={"help": "The number of processes to use for the preprocessing."},
-    )
-    source_prefix: Optional[str] = field(
         default=None,
         metadata={
-            "help": "A prefix to add before every source text (useful for T5 models)."
+            "help": "The number of processes to use for the preprocessing. Not used in streaming mode."
         },
     )
     overwrite_cache: bool = field(
         default=False,
-        metadata={"help": "Overwrite the cached training and evaluation sets"},
+        metadata={
+            "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
+        },
     )
     log_interval: Optional[int] = field(
         default=40,
@@ -189,41 +187,16 @@ class DataTrainingArguments:
     )
     log_model: bool = field(
         default=False,
-        metadata={"help": "Overwrite the cached training and evaluation sets"},
+        metadata={"help": "Log frequency for model"},
     )
     save_model_steps: Optional[int] = field(
-        default=5000, # about once every 1.5h in our experiments
-        metadata={
-            "help": "For logging the model more frequently. Used only when `log_model` is set."
-        },
+        default=5000,
+        metadata={"help": "For saving/logging the model more frequently"},
     )
 
     def __post_init__(self):
         if self.dataset_repo_or_path is None:
             raise ValueError("Need a dataset repository or path.")
-        if self.train_file is None or self.validation_file is None:
-            raise ValueError("Need training/validation file.")
-        else:
-            if self.train_file is not None:
-                extension = self.train_file.split(".")[-1]
-                assert extension in [
-                    "tsv",
-                    "csv",
-                    "json",
-                    "jsonl",
-                ], "`train_file` should be a tsv, csv or json file."
-            if self.validation_file is not None:
-                extension = self.validation_file.split(".")[-1]
-                assert extension in [
-                    "tsv",
-                    "csv",
-                    "json",
-                    "jsonl",
-                ], "`validation_file` should be a tsv, csv or json file."
-        if self.streaming and (self.len_train is None or self.len_eval is None):
-            raise ValueError(
-                "Streaming requires providing length of training and validation datasets"
-            )
 
 
 class TrainState(train_state.TrainState):
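These dataclass fields are turned into CLI options by transformers' HfArgumentParser, so the rename also changes the flag from --no_decay to --use_decay. A minimal sketch of that mapping, using only a subset of the fields touched by this commit (the parser calls are standard transformers API; the example argv is made up):

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:
    # Subset of the fields from this diff, with the new names/defaults.
    use_decay: bool = field(
        default=False,
        metadata={"help": "Whether to use decay in the learning rate scheduler."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing. Not used in streaming mode."},
    )
    save_model_steps: Optional[int] = field(
        default=5000,
        metadata={"help": "For saving/logging the model more frequently"},
    )


parser = HfArgumentParser(DataTrainingArguments)
# Equivalent to: python run_seq2seq_flax.py --use_decay --save_model_steps 2000
(data_args,) = parser.parse_args_into_dataclasses(
    args=["--use_decay", "--save_model_steps", "2000"]
)
print(data_args.use_decay, data_args.save_model_steps)  # True 2000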
@@ -291,7 +264,7 @@ def create_learning_rate_fn(
     num_train_epochs: int,
     num_warmup_steps: int,
     learning_rate: float,
-    no_decay: bool,
+    use_decay: bool,
 ) -> Callable[[int], jnp.array]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
@@ -299,7 +272,7 @@
     warmup_fn = optax.linear_schedule(
         init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
     )
-    if no_decay:
+    if not use_decay:
         return warmup_fn
     decay_fn = optax.linear_schedule(
         init_value=learning_rate,
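For reference, a minimal sketch of the schedule this function produces, assuming (not shown in the hunk) that warmup and decay are combined with optax.join_schedules at the warmup boundary; the step counts and learning rate below are illustrative:

import optax

learning_rate, num_warmup_steps, total_steps = 5e-5, 1_000, 10_000  # illustrative values
use_decay = True

warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
)
if not use_decay:
    # linear_schedule clips at end_value, so the rate stays constant after warmup
    schedule_fn = warmup_fn
else:
    decay_fn = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0.0,
        transition_steps=total_steps - num_warmup_steps,
    )
    schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
    )

print(schedule_fn(0), schedule_fn(num_warmup_steps), schedule_fn(total_steps - 1))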
@@ -372,10 +345,13 @@ def main():
     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
     # (the dataset will be downloaded automatically from the datasets Hub).
     #
-    data_files = {
-        "train": data_args.train_file,
-        "validation": data_args.validation_file,
-    }
+    if data_args.train_file is not None or data_args.validation_file is not None:
+        data_files = {
+            "train": data_args.train_file,
+            "validation": data_args.validation_file,
+        }
+    else:
+        data_files = None
     dataset = load_dataset(
         data_args.dataset_repo_or_path,
         data_files=data_files,
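With data_files=None, datasets simply falls back to whatever files and splits the dataset repository itself declares. A quick sketch of the two call patterns (repository name and file names are hypothetical):

from datasets import load_dataset

repo = "user/captions-dataset"  # hypothetical dataset repo

# Explicit files inside the repo, as before this commit:
dataset = load_dataset(repo, data_files={"train": "train.jsonl", "validation": "valid.jsonl"})

# No explicit files: use whatever the repo provides.
dataset = load_dataset(repo, data_files=None)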
@@ -449,8 +425,6 @@ def main():
     print(f"TPUs: {jax.device_count()}")
     assert jax.device_count() == 8, "TPUs in use, please check running processes"
 
-    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
-
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
 
@@ -475,7 +449,6 @@ def main():
 
     def preprocess_function(examples):
         inputs = examples[text_column]
-        inputs = [prefix + inp for inp in inputs] if prefix else inputs
         # Setting padding="max_length" as we need fixed length inputs for jitted functions
         model_inputs = tokenizer(
             inputs,
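The fixed-length padding matters because jitted JAX functions retrace and recompile for every new input shape. A minimal sketch of the tokenizer call pattern, assuming a BART-family tokenizer and an illustrative max length of 64:

import numpy as np
from transformers import AutoTokenizer

# Assumption: a BART-family tokenizer stands in for the project's tokenizer.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

inputs = ["a photo of a cat", "an astronaut riding a horse"]
model_inputs = tokenizer(
    inputs,
    max_length=64,         # illustrative; the script uses data_args.max_source_length
    padding="max_length",  # fixed shapes -> no retracing under jax.jit
    truncation=True,
    return_tensors="np",
)
print(model_inputs["input_ids"].shape)  # (2, 64) regardless of the raw text lengths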
@@ -617,7 +590,7 @@ def main():
         training_args.num_train_epochs,
         training_args.warmup_steps,
         training_args.learning_rate,
-        data_args.no_decay,
+        data_args.use_decay,
     )
 
     # We use Optax's "masking" functionality to not apply weight decay
@@ -625,8 +598,6 @@ def main():
     # mask boolean with the same structure as the parameters.
     # The mask is True for parameters that should be decayed.
     # Note that this mask is specifically adapted for FlaxBart.
-    # For FlaxT5, one should correct the layer norm parameter naming
-    # accordingly - see `run_t5_mlm_flax.py` e.g.
     def decay_mask_fn(params):
         flat_params = traverse_util.flatten_dict(params)
         layer_norm_params = [
@@ -649,6 +620,8 @@ def main():
         # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
         optimizer = optax.adafactor(
             learning_rate=learning_rate_fn,
+            weight_decay_rate=training_args.weight_decay,
+            weight_decay_mask=decay_mask_fn
         )
     else:
         optimizer = optax.adamw(
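A minimal sketch of the masked weight decay pattern these last two hunks rely on. The optax keyword arguments are the ones added in this diff; the toy parameter tree and the simplified mask logic below are stand-ins for the real FlaxBart params and the script's decay_mask_fn:

import jax.numpy as jnp
import optax
from flax import traverse_util

# Toy parameter tree; the real code passes the FlaxBart params here.
params = {
    "encoder": {
        "dense": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros(4)},
        "self_attn_layer_norm": {"scale": jnp.ones(4), "bias": jnp.zeros(4)},
    }
}

def decay_mask_fn(params):
    # True for parameters that should be decayed; layer norms and biases are excluded.
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = {
        path: (path[-1] != "bias" and "layer_norm" not in path[-2])
        for path in flat_params
    }
    return traverse_util.unflatten_dict(flat_mask)

weight_decay = 1e-2  # illustrative value for training_args.weight_decay
optimizer = optax.adafactor(
    learning_rate=1e-3,
    weight_decay_rate=weight_decay,
    weight_decay_mask=decay_mask_fn,
)
opt_state = optimizer.init(params)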
 