boris committed
Commit 648e404
Parents (2): 5e244d0 19d68bb

Merge pull request #24 from borisdayma/feat--log-model-frequently

Files changed (1):
  1. seq2seq/run_seq2seq_flax.py +44 -29
seq2seq/run_seq2seq_flax.py CHANGED
@@ -84,7 +84,7 @@ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
 OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
 BOS_TOKEN_ID = 16384
-BASE_MODEL = 'facebook/bart-large'
+BASE_MODEL = 'facebook/bart-large-cnn' # we currently have issues with bart-large


 @dataclass
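
Note: BASE_MODEL names the pretrained checkpoint the script starts from. A minimal sketch of how such a constant is typically consumed, assuming transformers' Flax BART classes (the actual loading code sits outside this hunk):

from transformers import BartTokenizer, FlaxBartForConditionalGeneration

# both calls accept a hub id (as here) or a local path
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
model = FlaxBartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')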
@@ -231,6 +231,12 @@ class DataTrainingArguments:
     log_model: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    save_model_steps: Optional[int] = field(
+        default=3000, # about once every hour in our experiments
+        metadata={
+            "help": "For logging the model more frequently. Used only when `log_model` is set."
+        },
+    )

     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
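
Aside: the help string on log_model ("Overwrite the cached training and evaluation sets") looks copy-pasted from another field in the original source. The new save_model_steps flag is picked up through the script's argument parsing; a minimal sketch, assuming the dataclasses are fed to transformers' HfArgumentParser as in the upstream Flax examples:

from transformers import HfArgumentParser

# ModelArguments / TrainingArguments are the script's other argument dataclasses
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# e.g. --log_model --save_model_steps 1500 overrides the 3000-step default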
@@ -340,7 +346,7 @@ def wandb_log(metrics, step=None, prefix=None):
     if jax.process_index() == 0:
         log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
         if step is not None:
-            log_metrics = {**log_metrics, 'train/step': step}
+            log_metrics['train/step'] = step
         wandb.log(log_metrics)
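
The old and new forms produce the same payload; the in-place assignment simply skips rebuilding the dict, which is safe here because log_metrics is a fresh local dict. In isolation:

log_metrics = {'train/loss': 0.5}
log_metrics['train/step'] = 100                      # new form
# log_metrics = {**log_metrics, 'train/step': 100}   # old form, same result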
 
@@ -773,6 +779,38 @@ def main():

         return eval_metrics

+    def run_save_model(step, epoch, eval_metrics=None):
+        if jax.process_index() == 0:
+            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+
+            # save model locally
+            model.save_pretrained(
+                training_args.output_dir,
+                params=params,
+            )
+
+            # save to W&B
+            if data_args.log_model:
+                metadata = {'step': step, 'epoch': epoch}
+                if eval_metrics is not None:
+                    metadata['eval/loss'] = eval_metrics['loss']
+                artifact = wandb.Artifact(
+                    name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
+                )
+                artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
+                artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
+                wandb.run.log_artifact(artifact)
+
+            # save to the hub
+            if training_args.push_to_hub:
+                model.save_pretrained(
+                    training_args.output_dir,
+                    params=params,
+                    push_to_hub=training_args.push_to_hub,
+                    commit_message=f"Saving weights and logs of epoch {epoch+1}",
+                    temp_dir=True # avoid issues with being in a repository
+                )
+
     for epoch in epochs:
         # ======================== Training ================================
         train_start = time.time()
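
Checkpoints logged by run_save_model can be retrieved later through the public W&B API; a sketch with placeholder entity/project names:

import wandb

api = wandb.Api()
# 'my-entity/my-project' and '<run_id>' are placeholders
artifact = api.artifact('my-entity/my-project/model-<run_id>:latest', type='bart_model')
artifact_dir = artifact.download()  # holds flax_model.msgpack and config.json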
@@ -795,6 +833,9 @@ def main():

             if global_step % training_args.eval_steps == 0:
                 run_evaluation()
+
+            if global_step % data_args.save_model_steps == 0:
+                run_save_model(global_step, epoch)

         # log final train metrics
         wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
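
One caveat: save_model_steps is declared Optional[int], so a run configured with save_model_steps=None would crash on the modulo. A defensive variant, not part of this diff:

# guard against save_model_steps being None before taking the modulo
if data_args.save_model_steps and global_step % data_args.save_model_steps == 0:
    run_save_model(global_step, epoch)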
@@ -809,34 +850,8 @@ def main():
         eval_metrics = run_evaluation()

         # save checkpoint after each epoch and push checkpoint to the hub
-        if jax.process_index() == 0:
-            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-
-            # save model locally
-            model.save_pretrained(
-                training_args.output_dir,
-                params=params,
-            )
-
-            # save to W&B
-            if data_args.log_model:
-                metadata = {'epoch': epoch+1, 'eval/loss': eval_metrics['loss']}
-                artifact = wandb.Artifact(
-                    name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
-                )
-                artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
-                artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
-                wandb.run.log_artifact(artifact)
+        run_save_model(global_step, epoch, eval_metrics)

-            # save to the hub
-            if training_args.push_to_hub:
-                model.save_pretrained(
-                    training_args.output_dir,
-                    params=params,
-                    push_to_hub=training_args.push_to_hub,
-                    commit_message=f"Saving weights and logs of epoch {epoch+1}",
-                    temp_dir=True # avoid issues with being in a repository
-                )

         # ======================== Prediction loop ==============================
         if training_args.do_predict:
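
Since run_save_model writes a standard save_pretrained directory, the checkpoint can be reloaded directly; a sketch assuming Flax BART classes and an output_dir of ./output:

from transformers import FlaxBartForConditionalGeneration

# './output' is a placeholder for training_args.output_dir
model = FlaxBartForConditionalGeneration.from_pretrained('./output')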
 