boris committed
Commit: d449092
Parent: 283adc6

fix: define function before it is used
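In the previous revision, `run_save_model` was defined inside `main()` *after* the epoch loop that calls it, so the first call hits a local name that is not yet bound. A minimal sketch of that failure mode (hypothetical toy code, not taken from `run_seq2seq_flax.py`):

def main():
    for epoch in range(2):
        # Fails on the first iteration: run_save_model is a local name of
        # main(), but its `def` statement below has not executed yet, so
        # Python raises UnboundLocalError (a subclass of NameError).
        run_save_model(epoch)

    def run_save_model(epoch):
        print(f"saving checkpoint for epoch {epoch}")

try:
    main()
except NameError as err:
    print(err)  # e.g. local variable 'run_save_model' referenced before assignment

Moving the `def` above the loop, as this commit does, binds the name before the first call.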

Files changed (1):
  seq2seq/run_seq2seq_flax.py  +32 −31
seq2seq/run_seq2seq_flax.py CHANGED
@@ -779,6 +779,38 @@ def main():
 
         return eval_metrics
 
+    def run_save_model(step, epoch, eval_metrics=None):
+        if jax.process_index() == 0:
+            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+
+            # save model locally
+            model.save_pretrained(
+                training_args.output_dir,
+                params=params,
+            )
+
+            # save to W&B
+            if data_args.log_model:
+                metadata = {'epoch': epoch+1, 'eval/loss': eval_metrics['loss']}
+                if eval_metrics is not None:
+                    metadata['eval/loss'] = eval_metrics['loss']
+                artifact = wandb.Artifact(
+                    name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
+                )
+                artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
+                artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
+                wandb.run.log_artifact(artifact)
+
+            # save to the hub
+            if training_args.push_to_hub:
+                model.save_pretrained(
+                    training_args.output_dir,
+                    params=params,
+                    push_to_hub=training_args.push_to_hub,
+                    commit_message=f"Saving weights and logs of epoch {epoch+1}",
+                    temp_dir=True  # avoid issues with being in a repository
+                )
+
     for epoch in epochs:
         # ======================== Training ================================
         train_start = time.time()
@@ -820,37 +852,6 @@ def main():
         # save checkpoint after each epoch and push checkpoint to the hub
         run_save_model(global_step, epoch, eval_metrics)
 
-    def run_save_model(step, epoch, eval_metrics=None):
-        if jax.process_index() == 0:
-            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
-
-            # save model locally
-            model.save_pretrained(
-                training_args.output_dir,
-                params=params,
-            )
-
-            # save to W&B
-            if data_args.log_model:
-                metadata = {'epoch': epoch+1, 'eval/loss': eval_metrics['loss']}
-                if eval_metrics is not None:
-                    metadata['eval/loss'] = eval_metrics['loss']
-                artifact = wandb.Artifact(
-                    name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
-                )
-                artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
-                artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
-                wandb.run.log_artifact(artifact)
-
-            # save to the hub
-            if training_args.push_to_hub:
-                model.save_pretrained(
-                    training_args.output_dir,
-                    params=params,
-                    push_to_hub=training_args.push_to_hub,
-                    commit_message=f"Saving weights and logs of epoch {epoch+1}",
-                    temp_dir=True  # avoid issues with being in a repository
-                )
 
     # ======================== Prediction loop ==============================
     if training_args.do_predict: