boris commited on
Commit
28f08be
2 Parent(s): ad6ad64 aecf3a7

Merge branch 'add-tokenizer-save' into feat-model

Browse files

Former-commit-id: 2cfaef4a020f43332a8f33b6a9bd8221ec9fae34

Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +9 -1
seq2seq/run_seq2seq_flax.py CHANGED
@@ -818,13 +818,16 @@ def main():
818
  params=params,
819
  )
820
 
 
 
 
821
  # save state
822
  state = unreplicate(state)
823
  with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
824
  f.write(to_bytes(state.opt_state))
825
  with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
826
  json.dump({'step': state.step.item()}, f)
827
-
828
  # save to W&B
829
  if data_args.log_model:
830
  metadata = {'step': step, 'epoch': epoch}
@@ -834,6 +837,11 @@ def main():
834
  name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
835
  )
836
  artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
 
 
 
 
 
837
  artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
838
  artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
839
  artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
 
818
  params=params,
819
  )
820
 
821
+ # save tokenizer
822
+ tokenizer.save_pretrained(training_args.output_dir)
823
+
824
  # save state
825
  state = unreplicate(state)
826
  with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
827
  f.write(to_bytes(state.opt_state))
828
  with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
829
  json.dump({'step': state.step.item()}, f)
830
+
831
  # save to W&B
832
  if data_args.log_model:
833
  metadata = {'step': step, 'epoch': epoch}
 
837
  name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
838
  )
839
  artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
840
+ artifact.add_file(str(Path(training_args.output_dir) / 'tokenizer_config.json'))
841
+ artifact.add_file(str(Path(training_args.output_dir) / 'special_tokens_map.json'))
842
+ artifact.add_file(str(Path(training_args.output_dir) / 'vocab.json'))
843
+ artifact.add_file(str(Path(training_args.output_dir) / 'added_tokens.json'))
844
+ artifact.add_file(str(Path(training_args.output_dir) / 'merges.txt'))
845
  artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
846
  artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
847
  artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))