test
Browse files- run_mlm_flax_stream.py +4 -0
run_mlm_flax_stream.py
CHANGED
|
@@ -660,6 +660,10 @@ if __name__ == "__main__":
|
|
| 660 |
|
| 661 |
model_pt.push_to_hub(training_args.hub_model_id,commit_message=f"Weights for torch of step {step+1}")
|
| 662 |
model_flax.push_to_hub(training_args.hub_model_id,commit_message=f"Weights for flax of step {step+1}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 663 |
|
| 664 |
print(f"Saving weights and logs of step {step+1}. \nThe result is saved to {training_args.output_dir} by worker {jax.process_index()}.")
|
| 665 |
|
|
|
|
| 660 |
|
| 661 |
model_pt.push_to_hub(training_args.hub_model_id,commit_message=f"Weights for torch of step {step+1}")
|
| 662 |
model_flax.push_to_hub(training_args.hub_model_id,commit_message=f"Weights for flax of step {step+1}")
|
| 663 |
+
|
| 664 |
+
#Delete the models to free memory
|
| 665 |
+
del(model_pt)
|
| 666 |
+
del(model_flax)
|
| 667 |
|
| 668 |
print(f"Saving weights and logs of step {step+1}. \nThe result is saved to {training_args.output_dir} by worker {jax.process_index()}.")
|
| 669 |
|