versae committed
Commit 1dc4fb8
Parent: 685d91d
Files changed (2)
  1. images/bertin.png +0 -0
  2. run_mlm_flax_stream.py +55 -3
images/bertin.png CHANGED
run_mlm_flax_stream.py CHANGED
@@ -25,6 +25,7 @@ import json
 import os
 import shutil
 import sys
+import tempfile
 import time
 from collections import defaultdict
 from dataclasses import dataclass, field
@@ -60,6 +61,8 @@ from transformers import (
     TrainingArguments,
     is_tensorboard_available,
     set_seed,
+    FlaxRobertaForMaskedLM,
+    RobertaForMaskedLM,
 )
 
 
@@ -376,6 +379,27 @@ def rotate_checkpoints(path, max_checkpoints=5):
         os.remove(path_to_delete)
 
 
+def to_f32(t):
+    return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
+
+
+def convert(output_dir, destination_dir="./"):
+    shutil.copyfile(Path(output_dir) / "flax_model.msgpack", destination_dir)
+    shutil.copyfile(Path(output_dir) / "config.json", destination_dir)
+    # Save tokenizer files alongside the copied config.json
+    tokenizer = AutoTokenizer.from_pretrained(destination_dir)
+    tokenizer.save_pretrained(destination_dir)
+
+    # Temporarily save the bfloat16 Flax model as float32
+    tmp = tempfile.mkdtemp()
+    flax_model = FlaxRobertaForMaskedLM.from_pretrained(destination_dir)
+    flax_model.params = to_f32(flax_model.params)
+    flax_model.save_pretrained(tmp)
+    # Convert the float32 Flax model to PyTorch
+    model = RobertaForMaskedLM.from_pretrained(tmp, from_flax=True)
+    model.save_pretrained(destination_dir, save_config=False)
+
+
 if __name__ == "__main__":
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
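A note on the two helpers added above: `to_f32` walks the parameter pytree with `jax.tree_map` and upcasts only the bfloat16 leaves, which is what makes the later Flax-to-PyTorch conversion workable (numpy, which the conversion passes through, has no native bfloat16). A minimal standalone sketch of the same pattern, assuming only `jax` is installed:

```python
import jax
import jax.numpy as jnp

def to_f32(t):
    # Upcast bfloat16 leaves to float32; leave other dtypes untouched
    return jax.tree_map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t
    )

# Toy parameter pytree mixing bfloat16 and float32 leaves
params = {
    "dense": {
        "kernel": jnp.ones((2, 2), dtype=jnp.bfloat16),
        "bias": jnp.zeros(2, dtype=jnp.float32),
    }
}
print(jax.tree_map(lambda x: x.dtype.name, to_f32(params)))
# {'dense': {'bias': 'float32', 'kernel': 'float32'}}
```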
@@ -749,7 +773,8 @@ if __name__ == "__main__":
749
  eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
750
 
751
  # Update progress bar
752
- steps.desc = f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
 
753
 
754
  if has_tensorboard and jax.process_index() == 0:
755
  write_eval_metric(summary_writer, eval_metrics, step)
@@ -762,8 +787,7 @@ if __name__ == "__main__":
                 model.save_pretrained(
                     training_args.output_dir,
                     params=params,
-                    push_to_hub=training_args.push_to_hub,
-                    commit_message=f"Saving weights and logs of step {step + 1}",
+                    push_to_hub=False,
                 )
                 save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
                 checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step}"
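Two review notes on the save flow this commit introduces, offered as suggestions rather than as fixes applied to the diff: first, `shutil.copyfile` requires a file path as its destination, so `convert` passing `destination_dir` directly will raise `IsADirectoryError`; second, `last_desc` (reused as the commit message in the hunk below) is only assigned inside the evaluation branch, so a save step that fires before the first evaluation would raise `NameError`. A hypothetical hardening sketch; the `copy_into` helper is not part of the committed script:

```python
import shutil
from pathlib import Path

def copy_into(src, dest_dir):
    # shutil.copyfile needs a file path as its destination, so build one
    # inside dest_dir instead of passing the directory itself
    src = Path(src)
    shutil.copyfile(src, Path(dest_dir) / src.name)

# Before the training loop, so commit_message=last_desc cannot raise
# NameError when a save step precedes the first evaluation:
last_desc = ""
```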
@@ -774,6 +798,34 @@ if __name__ == "__main__":
                     Path(training_args.output_dir) / "checkpoints",
                     max_checkpoints=training_args.save_total_limit
                 )
+                convert(training_args.output_dir, "./")
+                model.save_pretrained(
+                    training_args.output_dir,
+                    params=params,
+                    push_to_hub=training_args.push_to_hub,
+                    commit_message=last_desc,
+                )
 
         # update tqdm bar
         steps.update(1)
+
+    if jax.process_index() == 0:
+        logger.info(f"Saving checkpoint at {step} steps")
+        params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+        model.save_pretrained(
+            training_args.output_dir,
+            params=params,
+            push_to_hub=False,
+        )
+        save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
+        checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step}"
+        checkpoints_dir.mkdir(parents=True, exist_ok=True)
+        model.save_pretrained(checkpoints_dir, params=params)
+        save_checkpoint_files(state, data_collator, training_args, checkpoints_dir)
+        convert(training_args.output_dir, "./")
+        model.save_pretrained(
+            training_args.output_dir,
+            params=params,
+            push_to_hub=training_args.push_to_hub,
+            commit_message=last_desc,
+        )
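Once `convert` has run, the destination directory holds both the float32 Flax weights and the PyTorch export, which makes a quick equivalence check possible. A sketch of such a check, assuming the conversion wrote to `./` as in the diff and that `transformers`, `torch`, and `jax` are installed; the tolerance is a guess, not a value from the commit:

```python
import numpy as np
import torch
from transformers import AutoTokenizer, FlaxRobertaForMaskedLM, RobertaForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("./")
flax_model = FlaxRobertaForMaskedLM.from_pretrained("./")
pt_model = RobertaForMaskedLM.from_pretrained("./")
pt_model.eval()

text = "Hola mundo"
flax_logits = flax_model(**tokenizer(text, return_tensors="np")).logits
with torch.no_grad():
    pt_logits = pt_model(**tokenizer(text, return_tensors="pt")).logits.numpy()

# Exact equality is too strict after a bfloat16 -> float32 round trip
print(np.allclose(np.asarray(flax_logits), pt_logits, atol=1e-3))
```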