boris commited on
Commit
6e89e9e
1 Parent(s): 63249ac

fix: output directory must exist

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +6 -6
seq2seq/run_seq2seq_flax.py CHANGED
@@ -783,6 +783,12 @@ def main():
783
  if jax.process_index() == 0:
784
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
785
 
 
 
 
 
 
 
786
  # save state
787
  state = unreplicate(state)
788
  with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
@@ -790,12 +796,6 @@ def main():
790
  with (Path(training_args.output_dir) / 'training_state.json').open('wb') as f:
791
  json.dump({'step': state.step.item()}, f)
792
 
793
- # save model locally
794
- model.save_pretrained(
795
- training_args.output_dir,
796
- params=params,
797
- )
798
-
799
  # save to W&B
800
  if data_args.log_model:
801
  metadata = {'step': step, 'epoch': epoch}
 
783
  if jax.process_index() == 0:
784
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
785
 
786
+ # save model locally
787
+ model.save_pretrained(
788
+ training_args.output_dir,
789
+ params=params,
790
+ )
791
+
792
  # save state
793
  state = unreplicate(state)
794
  with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
 
796
  with (Path(training_args.output_dir) / 'training_state.json').open('wb') as f:
797
  json.dump({'step': state.step.item()}, f)
798
 
 
 
 
 
 
 
799
  # save to W&B
800
  if data_args.log_model:
801
  metadata = {'step': step, 'epoch': epoch}