test
Browse files
events.out.tfevents.1672834909.t1v-n-0853dee6-w-3.740708.0.v2
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fa375387021b8ed7a3b3fd7ca6197ff9cba91079608f8279052917b0fca28325
|
| 3 |
+
size 7637
|
run_mlm_flax_stream.py
CHANGED
|
@@ -42,6 +42,7 @@ import optax
|
|
| 42 |
from flax import jax_utils, traverse_util
|
| 43 |
from flax.training import train_state
|
| 44 |
from flax.training.common_utils import get_metrics, onehot, shard
|
|
|
|
| 45 |
from transformers import (
|
| 46 |
CONFIG_MAPPING,
|
| 47 |
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
|
@@ -660,10 +661,17 @@ if __name__ == "__main__":
|
|
| 660 |
model_pt.push_to_hub(training_args.hub_model_id,commit_message=f"Weights for torch of step {step+1}")
|
| 661 |
model_flax.push_to_hub(training_args.hub_model_id,commit_message=f"Weights for flax of step {step+1}")
|
| 662 |
|
| 663 |
-
print(f"Saving weights and logs of step {step+1}. \nThe result is saved to {training_args.output_folder} by worker {jax.process_index()}."
|
| 664 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 665 |
breakpoint()
|
| 666 |
-
model.
|
|
|
|
| 667 |
training_args.output_dir,
|
| 668 |
params=params
|
| 669 |
)
|
|
|
|
| 42 |
from flax import jax_utils, traverse_util
|
| 43 |
from flax.training import train_state
|
| 44 |
from flax.training.common_utils import get_metrics, onehot, shard
|
| 45 |
+
from huggingface_hub import Repository
|
| 46 |
from transformers import (
|
| 47 |
CONFIG_MAPPING,
|
| 48 |
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
|
|
|
| 661 |
model_pt.push_to_hub(training_args.hub_model_id,commit_message=f"Weights for torch of step {step+1}")
|
| 662 |
model_flax.push_to_hub(training_args.hub_model_id,commit_message=f"Weights for flax of step {step+1}")
|
| 663 |
|
| 664 |
+
print(f"Saving weights and logs of step {step+1}. \nThe result is saved to {training_args.output_folder} by worker {jax.process_index()}."
|
| 665 |
+
|
| 666 |
+
#Also push the relevant events files
|
| 667 |
+
repo = Repository(local_dir=training_args.output_dir)
|
| 668 |
+
except: repo.git_pull()
|
| 669 |
+
repo.git_add("*.*")
|
| 670 |
+
repo.git_commit(commit_message="Pushing some additional files")
|
| 671 |
+
repo.git_push()
|
| 672 |
breakpoint()
|
| 673 |
+
model.save_prdired(i
|
| 674 |
+
|
| 675 |
training_args.output_dir,
|
| 676 |
params=params
|
| 677 |
)
|