marinone94 commited on
Commit
ed53c37
1 Parent(s): 2786d7b
run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -337,6 +337,7 @@ def main():
337
  # See all possible arguments in src/transformers/training_args.py
338
  # or by passing the --help flag to this script.
339
  # We now keep distinct sets of args, for a cleaner separation of concerns.
 
340
  parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
341
 
342
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
@@ -351,6 +352,7 @@ def main():
351
  send_example_telemetry("run_speech_recognition_seq2seq_streaming", model_args, data_args)
352
 
353
  # 2. Setup logging
 
354
  logging.basicConfig(
355
  format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
356
  datefmt="%m/%d/%Y %H:%M:%S",
@@ -396,6 +398,7 @@ def main():
396
  set_seed(training_args.seed)
397
 
398
  # 4. Load dataset
 
399
  raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
400
 
401
  if training_args.do_train:
@@ -433,7 +436,7 @@ def main():
433
  )
434
 
435
  # 5. Load pretrained model, tokenizer, and feature extractor
436
- #
437
  # Distributed training:
438
  # The .from_pretrained methods guarantee that only one local process can concurrently
439
  config = AutoConfig.from_pretrained(
@@ -483,6 +486,7 @@ def main():
483
  tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
484
 
485
  # 6. Resample speech dataset if necessary
 
486
  dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
487
  if dataset_sampling_rate != feature_extractor.sampling_rate:
488
  raw_datasets = raw_datasets.cast_column(
@@ -491,6 +495,7 @@ def main():
491
 
492
  # 7. Preprocessing the datasets.
493
  # We need to read the audio files as arrays and tokenize the targets.
 
494
  max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
495
  min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
496
  audio_column_name = data_args.audio_column_name
@@ -554,6 +559,7 @@ def main():
554
  )
555
 
556
  # 8. Load Metric
 
557
  metric = evaluate.load("wer")
558
  do_normalize_eval = data_args.do_normalize_eval
559
 
@@ -578,6 +584,7 @@ def main():
578
  return {"wer": wer}
579
 
580
  # 9. Create a single speech processor
 
581
  if is_main_process(training_args.local_rank):
582
  # save feature extractor, tokenizer and config
583
  feature_extractor.save_pretrained(training_args.output_dir)
@@ -595,6 +602,7 @@ def main():
595
  # 11. Configure Trainer
596
  # Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
597
  # Only required for streaming: Trainer automatically shuffles non-streaming datasets
 
598
  class ShuffleCallback(TrainerCallback):
599
  def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
600
  if isinstance(train_dataloader.dataset, IterableDatasetShard):
@@ -603,6 +611,7 @@ def main():
603
  train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
604
 
605
  # Initialize Trainer
 
606
  trainer = Seq2SeqTrainer(
607
  model=model,
608
  args=training_args,
@@ -616,6 +625,7 @@ def main():
616
 
617
  # 12. Training
618
  if training_args.do_train:
 
619
  checkpoint = None
620
  if training_args.resume_from_checkpoint is not None:
621
  checkpoint = training_args.resume_from_checkpoint
@@ -663,14 +673,18 @@ def main():
663
  if model_args.model_index_name is not None:
664
  kwargs["model_name"] = model_args.model_index_name
665
 
 
666
  if training_args.push_to_hub:
667
  trainer.push_to_hub(**kwargs)
668
  else:
669
  trainer.create_model_card(**kwargs)
670
 
671
  # Training complete notification
 
672
  notify_me(recipient="marinone94@gmail.com", message=json.dumps(kwargs, indent=4))
673
 
 
 
674
  return results
675
 
676
 
 
337
  # See all possible arguments in src/transformers/training_args.py
338
  # or by passing the --help flag to this script.
339
  # We now keep distinct sets of args, for a cleaner separation of concerns.
340
+ logger.info("*** Parse args ***")
341
  parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
342
 
343
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
 
352
  send_example_telemetry("run_speech_recognition_seq2seq_streaming", model_args, data_args)
353
 
354
  # 2. Setup logging
355
+ logger.info("*** Setup logging ***")
356
  logging.basicConfig(
357
  format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
358
  datefmt="%m/%d/%Y %H:%M:%S",
 
398
  set_seed(training_args.seed)
399
 
400
  # 4. Load dataset
401
+ logger.info("*** Load dataset ***")
402
  raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
403
 
404
  if training_args.do_train:
 
436
  )
437
 
438
  # 5. Load pretrained model, tokenizer, and feature extractor
439
+ logger.info("*** Load pretrained model, tokenizer, and feature extractor ***")
440
  # Distributed training:
441
  # The .from_pretrained methods guarantee that only one local process can concurrently
442
  config = AutoConfig.from_pretrained(
 
486
  tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
487
 
488
  # 6. Resample speech dataset if necessary
489
+ logger.info("*** Resample dataset ***")
490
  dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
491
  if dataset_sampling_rate != feature_extractor.sampling_rate:
492
  raw_datasets = raw_datasets.cast_column(
 
495
 
496
  # 7. Preprocessing the datasets.
497
  # We need to read the audio files as arrays and tokenize the targets.
498
+ logger.info("*** Preprocess dataset ***")
499
  max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
500
  min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
501
  audio_column_name = data_args.audio_column_name
 
559
  )
560
 
561
  # 8. Load Metric
562
+ logger.info("*** Load metric ***")
563
  metric = evaluate.load("wer")
564
  do_normalize_eval = data_args.do_normalize_eval
565
 
 
584
  return {"wer": wer}
585
 
586
  # 9. Create a single speech processor
587
+ logger.info("*** Init processor ***")
588
  if is_main_process(training_args.local_rank):
589
  # save feature extractor, tokenizer and config
590
  feature_extractor.save_pretrained(training_args.output_dir)
 
602
  # 11. Configure Trainer
603
  # Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
604
  # Only required for streaming: Trainer automatically shuffles non-streaming datasets
605
+ logger.info("*** Set shuffle callback ***")
606
  class ShuffleCallback(TrainerCallback):
607
  def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
608
  if isinstance(train_dataloader.dataset, IterableDatasetShard):
 
611
  train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
612
 
613
  # Initialize Trainer
614
+ logger.info("*** Init trainer ***")
615
  trainer = Seq2SeqTrainer(
616
  model=model,
617
  args=training_args,
 
625
 
626
  # 12. Training
627
  if training_args.do_train:
628
+ logger.info("*** Train ***")
629
  checkpoint = None
630
  if training_args.resume_from_checkpoint is not None:
631
  checkpoint = training_args.resume_from_checkpoint
 
673
  if model_args.model_index_name is not None:
674
  kwargs["model_name"] = model_args.model_index_name
675
 
676
+ logger.info("*** Pushing to hub ***")
677
  if training_args.push_to_hub:
678
  trainer.push_to_hub(**kwargs)
679
  else:
680
  trainer.create_model_card(**kwargs)
681
 
682
  # Training complete notification
683
+ logger.info("*** Sending notification ***")
684
  notify_me(recipient="marinone94@gmail.com", message=json.dumps(kwargs, indent=4))
685
 
686
+ logger.info("*** Training complete!!! ***")
687
+
688
  return results
689
 
690