Automatic Speech Recognition
Transformers
4 languages
whisper
whisper-event
Generated from Trainer
Inference Endpoints
marinone94 commited on
Commit
98dfb11
1 Parent(s): a9f9b4a

avoid pushing checkpoints

Browse files
run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -797,6 +797,9 @@ def main():
797
  )
798
  logger.info("*** Trainer initialized ***")
799
 
 
 
 
800
  # 12. Training
801
  if training_args.do_train:
802
  logger.info("*** Train ***")
@@ -812,10 +815,7 @@ def main():
812
  # We don't want to push the model to the hub now
813
  # so we temporarily set to false the push_to_hub attribute
814
  # and then reset it to the original value
815
- orig_push_to_hub = trainer.args.push_to_hub
816
- trainer.args.push_to_hub = False
817
  trainer.save_model() # Saves the feature extractor too for easy upload
818
- trainer.args.push_to_hub = orig_push_to_hub
819
  logger.info("*** Model saved ***")
820
  metrics = train_result.metrics
821
  if data_args.max_train_samples:
@@ -909,7 +909,7 @@ def main():
909
  notify_me(recipient=RECIPIENT_ADDRESS,
910
  message=f"Training complete! {train_results = } {eval_results = }")
911
 
912
-
913
  if training_args.push_to_hub:
914
  logger.info("*** Pushing to hub ***")
915
  trainer.push_to_hub(**kwargs)
 
797
  )
798
  logger.info("*** Trainer initialized ***")
799
 
800
+ orig_push_to_hub = trainer.args.push_to_hub
801
+ trainer.args.push_to_hub = False
802
+
803
  # 12. Training
804
  if training_args.do_train:
805
  logger.info("*** Train ***")
 
815
  # We don't want to push the model to the hub now
816
  # so we temporarily set to false the push_to_hub attribute
817
  # and then reset it to the original value
 
 
818
  trainer.save_model() # Saves the feature extractor too for easy upload
 
819
  logger.info("*** Model saved ***")
820
  metrics = train_result.metrics
821
  if data_args.max_train_samples:
 
909
  notify_me(recipient=RECIPIENT_ADDRESS,
910
  message=f"Training complete! {train_results = } {eval_results = }")
911
 
912
+ trainer.args.push_to_hub = orig_push_to_hub
913
  if training_args.push_to_hub:
914
  logger.info("*** Pushing to hub ***")
915
  trainer.push_to_hub(**kwargs)