tmabraham committed on
Commit
aecf3a7
1 Parent(s): 6567fd7

add tokenizer save to wandb:

Browse files

Former-commit-id: 36b4af0d456410a4c2996d1476525e91205d3d1c

Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +9 -1
seq2seq/run_seq2seq_flax.py CHANGED
@@ -811,13 +811,16 @@ def main():
811
  params=params,
812
  )
813
 
 
 
 
814
  # save state
815
  state = unreplicate(state)
816
  with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
817
  f.write(to_bytes(state.opt_state))
818
  with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
819
  json.dump({'step': state.step.item()}, f)
820
-
821
  # save to W&B
822
  if data_args.log_model:
823
  metadata = {'step': step, 'epoch': epoch}
@@ -827,6 +830,11 @@ def main():
827
  name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
828
  )
829
  artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
 
 
 
 
 
830
  artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
831
  artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
832
  artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
 
811
  params=params,
812
  )
813
 
814
+ # save tokenizer
815
+ tokenizer.save_pretrained(training_args.output_dir)
816
+
817
  # save state
818
  state = unreplicate(state)
819
  with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
820
  f.write(to_bytes(state.opt_state))
821
  with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
822
  json.dump({'step': state.step.item()}, f)
823
+
824
  # save to W&B
825
  if data_args.log_model:
826
  metadata = {'step': step, 'epoch': epoch}
 
830
  name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
831
  )
832
  artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
833
+ artifact.add_file(str(Path(training_args.output_dir) / 'tokenizer_config.json'))
834
+ artifact.add_file(str(Path(training_args.output_dir) / 'special_tokens_map.json'))
835
+ artifact.add_file(str(Path(training_args.output_dir) / 'vocab.json'))
836
+ artifact.add_file(str(Path(training_args.output_dir) / 'added_tokens.json'))
837
+ artifact.add_file(str(Path(training_args.output_dir) / 'merges.txt'))
838
  artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
839
  artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
840
  artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))