pere commited on
Commit
bcb4f0f
·
1 Parent(s): b6b7395
events.out.tfevents.1672834909.t1v-n-0853dee6-w-3.740708.0.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa375387021b8ed7a3b3fd7ca6197ff9cba91079608f8279052917b0fca28325
3
+ size 7637
run_mlm_flax_stream.py CHANGED
@@ -42,6 +42,7 @@ import optax
42
  from flax import jax_utils, traverse_util
43
  from flax.training import train_state
44
  from flax.training.common_utils import get_metrics, onehot, shard
 
45
  from transformers import (
46
  CONFIG_MAPPING,
47
  FLAX_MODEL_FOR_MASKED_LM_MAPPING,
@@ -660,10 +661,17 @@ if __name__ == "__main__":
660
  model_pt.push_to_hub(training_args.hub_model_id,commit_message=f"Weights for torch of step {step+1}")
661
  model_flax.push_to_hub(training_args.hub_model_id,commit_message=f"Weights for flax of step {step+1}")
662
 
663
- print(f"Saving weights and logs of step {step+1}. \nThe result is saved to {training_args.output_folder} by worker {jax.process_index()}.")
664
- except:
 
 
 
 
 
 
665
  breakpoint()
666
- model.save_pretrained(
 
667
  training_args.output_dir,
668
  params=params
669
  )
 
42
  from flax import jax_utils, traverse_util
43
  from flax.training import train_state
44
  from flax.training.common_utils import get_metrics, onehot, shard
45
+ from huggingface_hub import Repository
46
  from transformers import (
47
  CONFIG_MAPPING,
48
  FLAX_MODEL_FOR_MASKED_LM_MAPPING,
 
661
  model_pt.push_to_hub(training_args.hub_model_id,commit_message=f"Weights for torch of step {step+1}")
662
  model_flax.push_to_hub(training_args.hub_model_id,commit_message=f"Weights for flax of step {step+1}")
663
 
664
+ print(f"Saving weights and logs of step {step+1}. \nThe result is saved to {training_args.output_folder} by worker {jax.process_index()}."
665
+
666
+ #Also push the relevant events files
667
+ repo = Repository(local_dir=training_args.output_dir)
668
+ except: repo.git_pull()
669
+ repo.git_add("*.*")
670
+ repo.git_commit(commit_message="Pushing some additional files")
671
+ repo.git_push()
672
  breakpoint()
673
+ model.save_prdired(i
674
+
675
  training_args.output_dir,
676
  params=params
677
  )