Changed print to logger
Browse files- run_mlm_flax_stream.py +1 -1
run_mlm_flax_stream.py
CHANGED
@@ -700,7 +700,7 @@ if __name__ == "__main__":
|
|
700 |
|
701 |
# save checkpoint after eval_steps
|
702 |
if step % training_args.save_steps == 0 and step > 0 and jax.process_index() == 0:
|
703 |
-
|
704 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
705 |
model.save_pretrained(
|
706 |
training_args.output_dir,
|
|
|
700 |
|
701 |
# save checkpoint after eval_steps
|
702 |
if step % training_args.save_steps == 0 and step > 0 and jax.process_index() == 0:
|
703 |
+
logger.info(f"Saving checkpoint at {step + 1} steps")
|
704 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
705 |
model.save_pretrained(
|
706 |
training_args.output_dir,
|