marinone94 commited on
Commit
fbda210
1 Parent(s): 8829a08

clean code. add logs. log audio correctly

Browse files
Files changed (1) hide show
  1. run_speech_recognition_ctc.py +91 -46
run_speech_recognition_ctc.py CHANGED
@@ -22,7 +22,6 @@ TODO:
22
  """
23
 
24
  import datetime
25
- import functools
26
  import json
27
  import logging
28
  import os
@@ -34,7 +33,6 @@ from typing import Dict, List, Optional, Union
34
 
35
  import datasets
36
  import numpy as np
37
- import pandas as pd
38
  import torch
39
  import wandb
40
  from datasets import DatasetDict, concatenate_datasets, load_dataset, load_metric
@@ -382,9 +380,11 @@ def log_to_wandb(training_args):
382
  wandb.login()
383
  training_args.report_to = ["wandb"]
384
  training_args.run_name = run_name
 
385
  except Exception as e:
386
  logger.warning(f"\nFailed logging in to wandb: {e}\nThis experiment will not be logged.\n")
387
 
 
388
 
389
  def detect_last_checkpoint(training_args):
390
 
@@ -417,7 +417,7 @@ def log_small_sumary(training_args):
417
  logger.info("Training/evaluation parameters %s", training_args)
418
 
419
 
420
- def load_dataset(training_args, data_args):
421
 
422
  raw_datasets = DatasetDict()
423
 
@@ -470,7 +470,7 @@ def load_dataset(training_args, data_args):
470
  return raw_datasets
471
 
472
 
473
- def clean_dataset(raw_datasets, training_args, data_args):
474
 
475
  chars_to_ignore_regex = (
476
  f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
@@ -528,7 +528,7 @@ def clean_dataset(raw_datasets, training_args, data_args):
528
  return raw_datasets
529
 
530
 
531
- def create_tokenizer_kwargs(raw_datasets, training_args, model_args, data_args, config):
532
 
533
  tokenizer_name_or_path = model_args.tokenizer_name_or_path
534
  tokenizer_kwargs = {}
@@ -546,7 +546,7 @@ def create_tokenizer_kwargs(raw_datasets, training_args, model_args, data_args,
546
  if not os.path.isfile(vocab_file):
547
  os.makedirs(tokenizer_name_or_path, exist_ok=True)
548
  vocab_dict = create_vocabulary_from_data(
549
- raw_datasets,
550
  word_delimiter_token=data_args.word_delimiter_token,
551
  unk_token=data_args.unk_token,
552
  pad_token=data_args.pad_token,
@@ -566,17 +566,22 @@ def create_tokenizer_kwargs(raw_datasets, training_args, model_args, data_args,
566
  "word_delimiter_token": data_args.word_delimiter_token,
567
  }
568
 
569
- return tokenizer_kwargs
570
 
571
 
572
- def vectorize_dataset(raw_datasets, feature_extractor, tokenizer, training_args, data_args):
573
 
574
  # make sure that dataset decodes audio with correct sampling rate
575
- dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
576
  if dataset_sampling_rate != feature_extractor.sampling_rate:
577
- raw_datasets = raw_datasets.cast_column(
578
  data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
579
  )
 
 
 
 
 
580
 
581
  # derive max & min input length for sample rate & max duration
582
  max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
@@ -606,15 +611,15 @@ def vectorize_dataset(raw_datasets, feature_extractor, tokenizer, training_args,
606
 
607
  with training_args.main_process_first(desc="dataset map preprocessing"):
608
  vectorized_datasets = DatasetDict()
609
- vectorized_datasets["train"] = raw_datasets["train"].map(
610
  prepare_dataset,
611
- remove_columns=raw_datasets["train"].column_names,
612
  num_proc=data_args.preprocessing_num_workers,
613
  desc="preprocess datasets",
614
  )
615
- vectorized_datasets["eval"] = raw_datasets["eval"].map(
616
  prepare_dataset,
617
- remove_columns=raw_datasets["eval"].column_names,
618
  num_proc=data_args.preprocessing_num_workers,
619
  desc="preprocess datasets",
620
  )
@@ -628,30 +633,57 @@ def vectorize_dataset(raw_datasets, feature_extractor, tokenizer, training_args,
628
  num_proc=data_args.preprocessing_num_workers,
629
  input_columns=["input_length"],
630
  )
 
 
 
 
 
 
 
 
631
 
632
 
633
- def log_dataset_sample_on_wandb(vectorized_datasets, audio_column_name):
 
 
 
 
 
 
 
634
 
635
- pd_train = vectorized_datasets["train"].select(range(10)).to_pandas()
636
- pd_eval = vectorized_datasets["eval"].select(range(10)).to_pandas()
 
 
 
 
 
 
 
 
 
 
637
 
638
  dict_log = {}
639
- for i, audio in pd_train[audio_column_name]:
640
  dict_log[f"Training sample {i}"] = wandb.Audio(
641
- audio["array"],
642
- audio_rate=audio["sampling_rate"]
643
  )
644
- for i, audio in pd_eval[audio_column_name]:
 
 
645
  dict_log[f"Eval sample {i}"] = wandb.Audio(
646
- audio["array"],
647
- audio_rate=audio["sampling_rate"]
648
  )
 
 
649
 
650
- wandb.log({
651
- "Training samples": pd_train.drop(labels=audio_column_name, axis=1),
652
- "Eval samples": pd_eval.drop(labels=audio_column_name, axis=1),
653
- "Audio samples": dict_log
654
- })
655
 
656
 
657
  def prepare_training(
@@ -671,11 +703,6 @@ def prepare_training(
671
  if data_args.dataset_seed is not None:
672
  vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(seed=data_args.dataset_seed)
673
 
674
- log_dataset_sample_on_wandb(
675
- vectorized_datasets=vectorized_datasets,
676
- audio_column_name=data_args.audio_column_name
677
- )
678
-
679
  # for large datasets it is advised to run the preprocessing on a
680
  # single machine first with ``args.preprocessing_only`` since there will mostly likely
681
  # be a timeout when running the script in distributed mode.
@@ -722,7 +749,7 @@ def prepare_training(
722
  data_collator = DataCollatorCTCWithPadding(processor=processor)
723
 
724
  # Initialize Trainer
725
- return Trainer(
726
  model=model,
727
  data_collator=data_collator,
728
  args=training_args,
@@ -731,6 +758,7 @@ def prepare_training(
731
  eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
732
  tokenizer=feature_extractor,
733
  )
 
734
 
735
 
736
  def do_training(
@@ -786,7 +814,7 @@ def do_eval(
786
  return trainer
787
 
788
 
789
- def log_results(trainer, training_args, model_args, data_args):
790
 
791
  config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
792
  kwargs = {
@@ -806,6 +834,7 @@ def log_results(trainer, training_args, model_args, data_args):
806
 
807
 
808
  def inst_model_tokenizer_feature_extractor(
 
809
  tokenizer_kwargs,
810
  training_args,
811
  model_args,
@@ -815,7 +844,7 @@ def inst_model_tokenizer_feature_extractor(
815
 
816
  # load tokenizer
817
  tokenizer = AutoTokenizer.from_pretrained(
818
- model_args.tokenizer_name_or_path,
819
  use_auth_token=data_args.use_auth_token,
820
  **tokenizer_kwargs,
821
  )
@@ -874,67 +903,78 @@ def main():
874
  else:
875
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
876
 
877
- # 1. Set logging
878
  set_log_config_and_level(local_rank=training_args.local_rank)
879
  training_args = log_to_wandb(training_args=training_args)
880
  log_small_sumary(training_args=training_args)
 
881
 
882
  # 2. Set random seed
883
  set_seed(training_args.seed)
 
884
 
885
- # 3. First, let's load the dataset
886
- raw_datasets = load_dataset(training_args=training_args, data_args=data_args)
 
887
 
888
  # 4. We remove some special characters from the datasets
889
  # that make training complicated and do not help in transcribing the speech
890
  # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
891
  # that could be easily picked up by the model
892
- raw_datasets = clean_dataset(
893
  raw_datasets=raw_datasets,
894
  training_args=training_args,
895
  data_args=data_args
896
  )
 
897
 
898
  # 5. Next, let's load the config as we might need it to create the tokenizer
899
  config = AutoConfig.from_pretrained(
900
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
 
 
901
  )
 
902
 
903
  # 6. Next, if no tokenizer file is defined,
904
  # we create the vocabulary of the model by extracting all unique characters from
905
  # the training and evaluation datasets
906
  # We need to make sure that only first rank saves vocabulary
907
  # make sure all processes wait until vocab is created
908
- tokenizer_kwargs = create_tokenizer_kwargs(
909
- raw_datasets=raw_datasets,
910
  training_args=training_args,
911
  model_args=model_args,
912
  data_args=data_args,
913
  config=config
914
  )
 
915
 
916
  # 7. Now we can instantiate the feature extractor, tokenizer and model
917
  # Note for distributed training, the .from_pretrained methods guarantee that only
918
  # one local process can concurrently download model & vocab.
919
  model, tokenizer, feature_extractor, config = inst_model_tokenizer_feature_extractor(
 
920
  tokenizer_kwargs=tokenizer_kwargs,
921
  training_args=training_args,
922
  model_args=model_args,
923
  data_args=data_args,
924
  config=config
925
  )
 
926
 
927
  # 8. Now we preprocess the datasets including loading the audio, resampling and normalization
928
  # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
929
  # so that we just need to set the correct target sampling rate and normalize the input
930
  # via the `feature_extractor`
931
  vectorized_datasets = vectorize_dataset(
932
- raw_datasets=raw_datasets,
933
  feature_extractor=feature_extractor,
934
  tokenizer=tokenizer,
935
  training_args=training_args,
936
  data_args=data_args
937
  )
 
938
 
939
  # 9. Next, we can prepare the training.
940
  # Let's use word error rate (WER) as our evaluation metric,
@@ -948,9 +988,11 @@ def main():
948
  data_args=data_args,
949
  config=config
950
  )
 
951
 
952
  # 10. Train model
953
  last_checkpoint = detect_last_checkpoint(training_args=training_args)
 
954
  if training_args.do_train:
955
  trainer = do_training(
956
  trainer=trainer,
@@ -959,6 +1001,7 @@ def main():
959
  model_args=model_args,
960
  data_args=data_args
961
  )
 
962
 
963
  # 11. Eval model
964
  if training_args.do_eval:
@@ -967,15 +1010,17 @@ def main():
967
  vectorized_datasets=vectorized_datasets,
968
  data_args=data_args
969
  )
 
970
 
971
  # 12. Push to hub and update model card
972
- log_results(
973
  trainer=trainer,
974
  training_args=training_args,
975
  model_args=model_args,
976
  data_args=data_args
977
  )
978
-
 
979
 
980
  if __name__ == "__main__":
981
  main()
22
  """
23
 
24
  import datetime
 
25
  import json
26
  import logging
27
  import os
33
 
34
  import datasets
35
  import numpy as np
 
36
  import torch
37
  import wandb
38
  from datasets import DatasetDict, concatenate_datasets, load_dataset, load_metric
380
  wandb.login()
381
  training_args.report_to = ["wandb"]
382
  training_args.run_name = run_name
383
+ wandb.init()
384
  except Exception as e:
385
  logger.warning(f"\nFailed logging in to wandb: {e}\nThis experiment will not be logged.\n")
386
 
387
+ return training_args
388
 
389
  def detect_last_checkpoint(training_args):
390
 
417
  logger.info("Training/evaluation parameters %s", training_args)
418
 
419
 
420
+ def load_datasets(training_args, data_args):
421
 
422
  raw_datasets = DatasetDict()
423
 
470
  return raw_datasets
471
 
472
 
473
+ def clean_datasets(raw_datasets, training_args, data_args):
474
 
475
  chars_to_ignore_regex = (
476
  f'[{"".join(data_args.chars_to_ignore)}]' if data_args.chars_to_ignore is not None else None
528
  return raw_datasets
529
 
530
 
531
+ def create_tokenizer_args(cleaned_datasets, training_args, model_args, data_args, config):
532
 
533
  tokenizer_name_or_path = model_args.tokenizer_name_or_path
534
  tokenizer_kwargs = {}
546
  if not os.path.isfile(vocab_file):
547
  os.makedirs(tokenizer_name_or_path, exist_ok=True)
548
  vocab_dict = create_vocabulary_from_data(
549
+ cleaned_datasets,
550
  word_delimiter_token=data_args.word_delimiter_token,
551
  unk_token=data_args.unk_token,
552
  pad_token=data_args.pad_token,
566
  "word_delimiter_token": data_args.word_delimiter_token,
567
  }
568
 
569
+ return tokenizer_name_or_path, tokenizer_kwargs
570
 
571
 
572
+ def vectorize_dataset(cleaned_datasets, feature_extractor, tokenizer, training_args, data_args):
573
 
574
  # make sure that dataset decodes audio with correct sampling rate
575
+ dataset_sampling_rate = next(iter(cleaned_datasets.values())).features[data_args.audio_column_name].sampling_rate
576
  if dataset_sampling_rate != feature_extractor.sampling_rate:
577
+ cleaned_datasets = cleaned_datasets.cast_column(
578
  data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
579
  )
580
+
581
+ log_metadata_on_wandb(
582
+ cleaned_datasets=cleaned_datasets,
583
+ audio_column_name=data_args.audio_column_name
584
+ )
585
 
586
  # derive max & min input length for sample rate & max duration
587
  max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
611
 
612
  with training_args.main_process_first(desc="dataset map preprocessing"):
613
  vectorized_datasets = DatasetDict()
614
+ vectorized_datasets["train"] = cleaned_datasets["train"].map(
615
  prepare_dataset,
616
+ remove_columns=cleaned_datasets["train"].column_names,
617
  num_proc=data_args.preprocessing_num_workers,
618
  desc="preprocess datasets",
619
  )
620
+ vectorized_datasets["eval"] = cleaned_datasets["eval"].map(
621
  prepare_dataset,
622
+ remove_columns=cleaned_datasets["eval"].column_names,
623
  num_proc=data_args.preprocessing_num_workers,
624
  desc="preprocess datasets",
625
  )
633
  num_proc=data_args.preprocessing_num_workers,
634
  input_columns=["input_length"],
635
  )
636
+
637
+ log_audio_on_wandb(
638
+ vectorized_datasets=vectorized_datasets,
639
+ audio_column_name="input_values",
640
+ sampling_rate=feature_extractor.sampling_rate
641
+ )
642
+
643
+ return vectorized_datasets
644
 
645
 
646
+ def log_metadata_on_wandb(
647
+ cleaned_datasets,
648
+ audio_column_name,
649
+ max_samples=10
650
+ ):
651
+
652
+ pd_train = cleaned_datasets["train"].select(range(max_samples)).to_pandas()
653
+ pd_eval = cleaned_datasets["eval"].select(range(max_samples)).to_pandas()
654
 
655
+ wandb.log({
656
+ "Training samples": pd_train.drop(labels=audio_column_name, axis=1),
657
+ "Eval samples": pd_eval.drop(labels=audio_column_name, axis=1),
658
+ })
659
+
660
+
661
+ def log_audio_on_wandb(
662
+ vectorized_datasets,
663
+ audio_column_name,
664
+ sampling_rate,
665
+ max_samples=10
666
+ ):
667
 
668
  dict_log = {}
669
+ for i, array in enumerate(vectorized_datasets["train"][audio_column_name]):
670
  dict_log[f"Training sample {i}"] = wandb.Audio(
671
+ array,
672
+ sample_rate=sampling_rate
673
  )
674
+ if i+1 == max_samples:
675
+ break
676
+ for i, array in enumerate(vectorized_datasets["eval"][audio_column_name]):
677
  dict_log[f"Eval sample {i}"] = wandb.Audio(
678
+ array,
679
+ sample_rate=sampling_rate
680
  )
681
+ if i+1 == max_samples:
682
+ break
683
 
684
+ print("\nLogging audio to wandb...\n")
685
+ wandb.log({"Audio samples": dict_log})
686
+ print("\nLogged audio to wandb...\n")
 
 
687
 
688
 
689
  def prepare_training(
703
  if data_args.dataset_seed is not None:
704
  vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(seed=data_args.dataset_seed)
705
 
 
 
 
 
 
706
  # for large datasets it is advised to run the preprocessing on a
707
  # single machine first with ``args.preprocessing_only`` since there will mostly likely
708
  # be a timeout when running the script in distributed mode.
749
  data_collator = DataCollatorCTCWithPadding(processor=processor)
750
 
751
  # Initialize Trainer
752
+ trainer = Trainer(
753
  model=model,
754
  data_collator=data_collator,
755
  args=training_args,
758
  eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
759
  tokenizer=feature_extractor,
760
  )
761
+ return trainer
762
 
763
 
764
  def do_training(
814
  return trainer
815
 
816
 
817
+ def log_and_push_results(trainer, training_args, model_args, data_args):
818
 
819
  config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
820
  kwargs = {
834
 
835
 
836
  def inst_model_tokenizer_feature_extractor(
837
+ tokenizer_name_or_path,
838
  tokenizer_kwargs,
839
  training_args,
840
  model_args,
844
 
845
  # load tokenizer
846
  tokenizer = AutoTokenizer.from_pretrained(
847
+ tokenizer_name_or_path,
848
  use_auth_token=data_args.use_auth_token,
849
  **tokenizer_kwargs,
850
  )
903
  else:
904
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
905
 
906
+ # 1. Set logs
907
  set_log_config_and_level(local_rank=training_args.local_rank)
908
  training_args = log_to_wandb(training_args=training_args)
909
  log_small_sumary(training_args=training_args)
910
+ logger.info("Logs set\n")
911
 
912
  # 2. Set random seed
913
  set_seed(training_args.seed)
914
+ logger.info("Seed set\n")
915
 
916
+ # 3. First, let's load the datasets
917
+ raw_datasets = load_datasets(training_args=training_args, data_args=data_args)
918
+ logger.info("Dataset loaded\n")
919
 
920
  # 4. We remove some special characters from the datasets
921
  # that make training complicated and do not help in transcribing the speech
922
  # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
923
  # that could be easily picked up by the model
924
+ cleaned_datasets = clean_datasets(
925
  raw_datasets=raw_datasets,
926
  training_args=training_args,
927
  data_args=data_args
928
  )
929
+ logger.info("Dataset cleaned\n")
930
 
931
  # 5. Next, let's load the config as we might need it to create the tokenizer
932
  config = AutoConfig.from_pretrained(
933
+ model_args.model_name_or_path,
934
+ cache_dir=model_args.cache_dir,
935
+ use_auth_token=data_args.use_auth_token
936
  )
937
+ logger.info("Config loaded\n")
938
 
939
  # 6. Next, if no tokenizer file is defined,
940
  # we create the vocabulary of the model by extracting all unique characters from
941
  # the training and evaluation datasets
942
  # We need to make sure that only first rank saves vocabulary
943
  # make sure all processes wait until vocab is created
944
+ tokenizer_name_or_path, tokenizer_kwargs = create_tokenizer_args(
945
+ cleaned_datasets=cleaned_datasets,
946
  training_args=training_args,
947
  model_args=model_args,
948
  data_args=data_args,
949
  config=config
950
  )
951
+ logger.info("Tokenizer args loaded\n")
952
 
953
  # 7. Now we can instantiate the feature extractor, tokenizer and model
954
  # Note for distributed training, the .from_pretrained methods guarantee that only
955
  # one local process can concurrently download model & vocab.
956
  model, tokenizer, feature_extractor, config = inst_model_tokenizer_feature_extractor(
957
+ tokenizer_name_or_path=tokenizer_name_or_path,
958
  tokenizer_kwargs=tokenizer_kwargs,
959
  training_args=training_args,
960
  model_args=model_args,
961
  data_args=data_args,
962
  config=config
963
  )
964
+ logger.info("Model, tokenizer, feature_extractor and config loaded\n")
965
 
966
  # 8. Now we preprocess the datasets including loading the audio, resampling and normalization
967
  # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
968
  # so that we just need to set the correct target sampling rate and normalize the input
969
  # via the `feature_extractor`
970
  vectorized_datasets = vectorize_dataset(
971
+ cleaned_datasets=cleaned_datasets,
972
  feature_extractor=feature_extractor,
973
  tokenizer=tokenizer,
974
  training_args=training_args,
975
  data_args=data_args
976
  )
977
+ logger.info("Dataset vectorized\n")
978
 
979
  # 9. Next, we can prepare the training.
980
  # Let's use word error rate (WER) as our evaluation metric,
988
  data_args=data_args,
989
  config=config
990
  )
991
+ logger.info("Trainer instantiated\n")
992
 
993
  # 10. Train model
994
  last_checkpoint = detect_last_checkpoint(training_args=training_args)
995
+ logger.info("Last checkpoint detected\n")
996
  if training_args.do_train:
997
  trainer = do_training(
998
  trainer=trainer,
1001
  model_args=model_args,
1002
  data_args=data_args
1003
  )
1004
+ logger.info("Training completed\n")
1005
 
1006
  # 11. Eval model
1007
  if training_args.do_eval:
1010
  vectorized_datasets=vectorized_datasets,
1011
  data_args=data_args
1012
  )
1013
+ logger.info("Eval completed\n")
1014
 
1015
  # 12. Push to hub and update model card
1016
+ log_and_push_results(
1017
  trainer=trainer,
1018
  training_args=training_args,
1019
  model_args=model_args,
1020
  data_args=data_args
1021
  )
1022
+ logger.info("Results logged\n")
1023
+
1024
 
1025
  if __name__ == "__main__":
1026
  main()