boris committed
Commit
85748ef
1 Parent(s): bab75aa

feat: use custom TrainingArguments

Files changed (1)
  1. dev/seq2seq/run_seq2seq_flax.py +114 -29
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -44,7 +44,6 @@ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from transformers import (
     AutoTokenizer,
     HfArgumentParser,
-    TrainingArguments,
 )
 from transformers.models.bart.modeling_flax_bart import BartConfig
 
@@ -93,12 +92,6 @@ class ModelArguments:
             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
         },
     )
-    from_checkpoint: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Loads a pretrained wandb checkpoint. Use artifact reference."
-        },
-    )
 
 
 @dataclass
@@ -143,10 +136,6 @@ class DataTrainingArguments:
             "than this will be truncated, sequences shorter will be padded."
         },
     )
-    use_decay: bool = field(
-        default=False,
-        metadata={"help": "Whether to use decay in the learning rate scheduler."},
-    )
     max_train_samples: Optional[int] = field(
         default=None,
         metadata={
@@ -173,18 +162,116 @@ class DataTrainingArguments:
             "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
         },
     )
-    log_interval: Optional[int] = field(
-        default=40,
-        metadata={"help": "Log frequency for metrics"},
+
+    def __post_init__(self):
+        if self.dataset_repo_or_path is None:
+            raise ValueError("Need a dataset repository or path.")
+
+
+@dataclass
+class TrainingArguments:
+    """
+    Arguments pertaining to training parameters.
+    """
+
+    output_dir: str = field(
+        metadata={
+            "help": "The output directory where the model predictions and checkpoints will be written."
+        },
+    )
+    overwrite_output_dir: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Overwrite the content of the output directory. "
+                "Use this to continue training if output_dir points to a checkpoint directory."
+            )
+        },
+    )
+
+    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
+    do_eval: bool = field(
+        default=False, metadata={"help": "Whether to run eval on the dev set."}
+    )
+
+    per_device_train_batch_size: int = field(
+        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
+    )
+    per_device_eval_batch_size: int = field(
+        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
+    )
+
+    gradient_accumulation_steps: int = field(
+        default=1,
+        metadata={
+            "help": "Number of updates steps to accumulate before performing a backward/update pass."
+        },
+    )
+
+    learning_rate: float = field(
+        default=5e-5, metadata={"help": "The initial learning rate."}
+    )
+    adafactor: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to replace AdamW by Adafactor."},
+    )
+    weight_decay: float = field(
+        default=None, metadata={"help": "Weight decay if we apply some."}
+    )
+    adam_beta1: float = field(
+        default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}
+    )
+    adam_beta2: float = field(
+        default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}
+    )
+    adam_epsilon: float = field(
+        default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}
+    )
+    max_grad_norm: float = field(
+        default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
+    )
+    use_decay: bool = field(
+        default=False,
+        metadata={"help": "Whether to use decay in the learning rate scheduler."},
+    )
+
+    num_train_epochs: float = field(
+        default=3.0, metadata={"help": "Total number of training epochs to perform."}
+    )
+    warmup_steps: int = field(
+        default=0, metadata={"help": "Linear warmup over warmup_steps."}
+    )
+
+    logging_steps: int = field(
+        default=40, metadata={"help": "Log every X updates steps."}
+    )
+    eval_steps: int = field(
+        default=400, metadata={"help": "Run an evaluation every X steps."}
+    )
+    save_steps: int = field(
+        default=4000, metadata={"help": "Save checkpoint every X updates steps."}
     )
     log_model: bool = field(
         default=False,
-        metadata={"help": "Log frequency for model"},
+        metadata={"help": "Log model to wandb at `save_steps` frequency."},
     )
 
-    def __post_init__(self):
-        if self.dataset_repo_or_path is None:
-            raise ValueError("Need a dataset repository or path.")
+    seed: int = field(
+        default=42,
+        metadata={"help": "Random seed that will be set at the beginning of training."},
+    )
+
+    push_to_hub: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether or not to upload the trained model to the model hub after training."
+        },
+    )
+
+    resume_from_wandb_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={"help": "The reference to a wandb artifact for resuming training."},
+    )
 
 
 class TrainState(train_state.TrainState):
@@ -291,10 +378,7 @@ def wandb_log(metrics, step=None, prefix=None):
 
 
 def main():
-    # See all possible arguments in src/transformers/training_args.py
-    # or by passing the --help flag to this script.
-    # We now keep distinct sets of args, for a cleaner separation of concerns.
-
+    # See all possible arguments by passing the --help flag to this script.
     parser = HfArgumentParser(
         (ModelArguments, DataTrainingArguments, TrainingArguments)
     )
@@ -358,8 +442,8 @@ def main():
         config=parser.parse_args(),
     )
 
-    if model_args.from_checkpoint is not None:
-        artifact = wandb.run.use_artifact(model_args.from_checkpoint)
+    if training_args.resume_from_wandb_checkpoint is not None:
+        artifact = wandb.run.use_artifact(training_args.resume_from_wandb_checkpoint)
         artifact_dir = artifact.download()
 
         # load model
@@ -574,7 +658,7 @@ def main():
     learning_rate_fn = create_learning_rate_fn(
         training_args.warmup_steps,
         training_args.learning_rate,
-        data_args.use_decay,
+        training_args.use_decay,
         num_train_steps,
     )
 
@@ -607,6 +691,7 @@ def main():
             learning_rate=learning_rate_fn,
            weight_decay_rate=training_args.weight_decay,
            weight_decay_mask=decay_mask_fn,
+            clipping_threshold=training_args.max_grad_norm,
         )
     else:
         optimizer = optax.adamw(
@@ -631,7 +716,7 @@ def main():
         tx=optimizer,
         dropout_rng=dropout_rng,
     )
-    if model_args.from_checkpoint is not None:
+    if training_args.resume_from_wandb_checkpoint is not None:
         # restore optimizer state and other parameters
         # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
         state = state.restore_state(artifact_dir)
@@ -771,7 +856,7 @@ def main():
         with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
             f.write(to_bytes(opt_state))
         state_dict = {
-            k: unreplicate(getattr(state, k))
+            k: jax.device_get(unreplicate(getattr(state, k))).item()
             for k in ["step", "epoch", "train_time", "train_samples"]
         }
         with (Path(training_args.output_dir) / "training_state.json").open(
@@ -783,7 +868,7 @@ def main():
         )
 
         # save to W&B
-        if data_args.log_model:
+        if training_args.log_model:
             # save some space
             c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
             c.cleanup(wandb.util.from_human_size("10GB"))
@@ -866,7 +951,7 @@ def main():
         )
         step = unreplicate(state.step)
 
-        if step % data_args.log_interval == 0 and jax.process_index() == 0:
+        if step % training_args.logging_steps == 0 and jax.process_index() == 0:
             # log metrics
             wandb_log(unreplicate(train_metric), step=step, prefix="train")
             # log state parameters
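
Note on the mechanism this commit relies on: `HfArgumentParser` builds a command-line interface from arbitrary dataclasses, which is what lets the script stop importing `TrainingArguments` from `transformers` and define its own. The sketch below is illustrative only, not code from this repository; it reuses a few field names from the diff, and the file name and flag values in the comments are assumptions.

# parse_args_sketch.py (hypothetical file name) - minimal sketch of how
# HfArgumentParser turns a custom dataclass into CLI flags.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class TrainingArguments:
    # Field names mirror the new dataclass in the diff; the full set is omitted here.
    output_dir: str = field(
        metadata={"help": "The output directory where checkpoints will be written."}
    )
    learning_rate: float = field(
        default=5e-5, metadata={"help": "The initial learning rate."}
    )
    logging_steps: int = field(
        default=40, metadata={"help": "Log every X updates steps."}
    )
    resume_from_wandb_checkpoint: Optional[str] = field(
        default=None,
        metadata={"help": "The reference to a wandb artifact for resuming training."},
    )


if __name__ == "__main__":
    parser = HfArgumentParser((TrainingArguments,))
    # Example invocation (assumed):
    #   python parse_args_sketch.py --output_dir output --logging_steps 20
    (training_args,) = parser.parse_args_into_dataclasses()
    print(training_args)

Because the parsed dataclass is no longer `transformers.TrainingArguments`, only the fields it declares exist as flags, which is consistent with the diff renaming `log_interval` to `logging_steps` and moving `use_decay` and `log_model` under the training arguments.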