New logo

Changed files:
- images/bertin.png (+0 -0)
- run_mlm_flax_stream.py (+55 -3)

images/bertin.png CHANGED

run_mlm_flax_stream.py CHANGED
@@ -25,6 +25,7 @@ import json
 import os
 import shutil
 import sys
+import tempfile
 import time
 from collections import defaultdict
 from dataclasses import dataclass, field
@@ -60,6 +61,8 @@ from transformers import (
     TrainingArguments,
     is_tensorboard_available,
     set_seed,
+    FlaxRobertaForMaskedLM,
+    RobertaForMaskedLM,
 )


@@ -376,6 +379,27 @@ def rotate_checkpoints(path, max_checkpoints=5):
         os.remove(path_to_delete)


+def to_f32(t):
+    return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
+
+
+def convert(output_dir, destination_dir="./"):
+    shutil.copyfile(Path(output_dir) / "flax_model.msgpack", destination_dir)
+    shutil.copyfile(Path(output_dir) / "config.json", destination_dir)
+    # Saving extra files from config.json and tokenizer.json files
+    tokenizer = AutoTokenizer.from_pretrained(destination_dir)
+    tokenizer.save_pretrained(destination_dir)
+
+    # Temporary saving bfloat16 Flax model into float32
+    tmp = tempfile.mkdtemp()
+    flax_model = FlaxRobertaForMaskedLM.from_pretrained(destination_dir)
+    flax_model.params = to_f32(flax_model.params)
+    flax_model.save_pretrained(tmp)
+    # Converting float32 Flax to PyTorch
+    model = RobertaForMaskedLM.from_pretrained(tmp, from_flax=True)
+    model.save_pretrained(destination_dir, save_config=False)
+
+
 if __name__ == "__main__":
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
@@ -749,7 +773,8 @@ if __name__ == "__main__":
             eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)

             # Update progress bar
-            steps.desc = f"Step... ({step
+            steps.desc = f"Step... ({step}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
+            last_desc = steps.desc

             if has_tensorboard and jax.process_index() == 0:
                 write_eval_metric(summary_writer, eval_metrics, step)
@@ -762,8 +787,7 @@ if __name__ == "__main__":
             model.save_pretrained(
                 training_args.output_dir,
                 params=params,
-                push_to_hub=
-                commit_message=f"Saving weights and logs of step {step + 1}",
+                push_to_hub=False,
             )
             save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
             checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step}"
@@ -774,6 +798,34 @@ if __name__ == "__main__":
                 Path(training_args.output_dir) / "checkpoints",
                 max_checkpoints=training_args.save_total_limit
             )
+            convert(training_args.output_dir, "./")
+            model.save_pretrained(
+                training_args.output_dir,
+                params=params,
+                push_to_hub=training_args.push_to_hub,
+                commit_message=last_desc,
+            )

         # update tqdm bar
         steps.update(1)
+
+    if jax.process_index() == 0:
+        logger.info(f"Saving checkpoint at {step} steps")
+        params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+        model.save_pretrained(
+            training_args.output_dir,
+            params=params,
+            push_to_hub=False,
+        )
+        save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
+        checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step}"
+        checkpoints_dir.mkdir(parents=True, exist_ok=True)
+        model.save_pretrained(checkpoints_dir, params=params)
+        save_checkpoint_files(state, data_collator, training_args, checkpoints_dir)
+        convert(training_args.output_dir, "./")
+        model.save_pretrained(
+            training_args.output_dir,
+            params=params,
+            push_to_hub=training_args.push_to_hub,
+            commit_message=last_desc,
+        )
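
Note on the new helpers: to_f32 exists only so that the bfloat16 Flax parameters can be re-saved as a float32 checkpoint before RobertaForMaskedLM.from_pretrained(tmp, from_flax=True) produces the PyTorch weights. A minimal, self-contained sketch of that cast (not part of the commit; the toy parameter names are made up, and jax.tree_util.tree_map is used here in place of the script's older jax.tree_map alias):

import jax
import jax.numpy as jnp

def to_f32(t):
    # Cast every bfloat16 leaf of a parameter pytree to float32; leave other dtypes untouched
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t
    )

# Toy stand-in for flax_model.params (hypothetical names and shapes)
params = {
    "embeddings": jnp.ones((2, 4), dtype=jnp.bfloat16),
    "layer_norm": {"scale": jnp.ones((4,), dtype=jnp.float32)},
}
print(jax.tree_util.tree_map(lambda x: x.dtype, to_f32(params)))
# Expected: every leaf reports float32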