marinone94 commited on
Commit
23bb45c
1 Parent(s): 6d4cdd4

allowing multiple datasets

Browse files
run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -328,24 +328,28 @@ def notify_me(recipient, message=None):
328
  smtp_obj.quit()
329
 
330
 
331
- def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
332
  """
333
  Utility function to load a dataset in streaming mode. For datasets with multiple splits,
334
  each split is loaded individually and then splits combined by taking alternating examples from
335
  each (interleaving).
336
  """
337
- if "+" in split:
338
  # load multiple splits separated by the `+` symbol with streaming mode
339
- dataset_splits = [
340
- load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
341
- for split_name in split.split("+")
342
- ]
 
 
 
 
343
  # interleave multiple splits to form one dataset
344
  interleaved_dataset = interleave_datasets(dataset_splits)
345
  return interleaved_dataset
346
  else:
347
  # load a single split *with* streaming mode
348
- dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=streaming, **kwargs)
349
  return dataset
350
 
351
 
@@ -652,14 +656,22 @@ def main():
652
  elif last_checkpoint is not None:
653
  checkpoint = last_checkpoint
654
  train_result = trainer.train(resume_from_checkpoint=checkpoint)
 
 
655
  trainer.save_model() # Saves the feature extractor too for easy upload
656
-
657
  metrics = train_result.metrics
658
  if data_args.max_train_samples:
659
  metrics["train_samples"] = data_args.max_train_samples
 
660
  trainer.log_metrics("train", metrics)
 
 
661
  trainer.save_metrics("train", metrics)
 
 
662
  trainer.save_state()
 
663
 
664
  # 13. Evaluation
665
  results = {}
@@ -670,13 +682,18 @@ def main():
670
  max_length=training_args.generation_max_length,
671
  num_beams=training_args.generation_num_beams,
672
  )
 
673
  if data_args.max_eval_samples:
674
  metrics["eval_samples"] = data_args.max_eval_samples
675
-
676
  trainer.log_metrics("eval", metrics)
 
 
677
  trainer.save_metrics("eval", metrics)
 
678
 
679
  # 14. Write Training Stats
 
680
  kwargs = {
681
  "finetuned_from": model_args.model_name_or_path,
682
  "tasks": "automatic-speech-recognition",
@@ -693,11 +710,14 @@ def main():
693
  if model_args.model_index_name is not None:
694
  kwargs["model_name"] = model_args.model_index_name
695
 
696
- logger.info("*** Pushing to hub ***")
697
  if training_args.push_to_hub:
 
698
  trainer.push_to_hub(**kwargs)
 
699
  else:
 
700
  trainer.create_model_card(**kwargs)
 
701
 
702
  # Training complete notification
703
  logger.info("*** Sending notification ***")
 
328
  smtp_obj.quit()
329
 
330
 
331
def load_maybe_streaming_dataset(dataset_names, dataset_config_names, split="train", streaming=True, **kwargs):
    """
    Load one or more datasets, optionally in streaming mode.

    Multiple datasets are passed as comma-separated lists of names, configs and
    splits (the three lists must be aligned). Within a single dataset, multiple
    splits are joined with the `+` symbol. All loaded splits are combined by
    taking alternating examples from each (interleaving).

    Args:
        dataset_names (str): dataset name, or comma-separated dataset names.
        dataset_config_names (str): config name(s), comma-separated and aligned
            with `dataset_names`. An empty entry (e.g. ``"a,,b"``) means the
            dataset has no config and is passed to `load_dataset` as ``None``.
        split (str): split name(s); `+`-joined within a dataset,
            comma-separated across datasets.
        streaming (bool): whether to load in streaming mode.
        **kwargs: forwarded to `datasets.load_dataset`.

    Returns:
        A single dataset, or an interleaved dataset when multiple
        datasets/splits are requested.

    Raises:
        ValueError: if the comma-separated lists have mismatching lengths.
    """
    if "," in dataset_names or "+" in split:
        names = [n.strip() for n in dataset_names.split(",")]
        configs = [c.strip() for c in dataset_config_names.split(",")]
        splits = [s.strip() for s in split.split(",")]
        # zip() would silently drop trailing entries on a length mismatch,
        # which would quietly train on the wrong data — fail loudly instead.
        if not (len(names) == len(configs) == len(splits)):
            raise ValueError(
                f"Mismatched lengths: {len(names)} dataset names, "
                f"{len(configs)} configs, {len(splits)} split specs."
            )

        dataset_splits = []
        for dataset_name, dataset_config_name, split_names in zip(names, configs, splits):
            for split_name in split_names.split("+"):
                dataset = load_dataset(
                    dataset_name,
                    # An empty config entry means "no config" -> None, not "".
                    dataset_config_name or None,
                    # strip() tolerates stray spaces such as "train+test, train"
                    split=split_name.strip(),
                    streaming=streaming,
                    **kwargs,
                )
                dataset_splits.append(dataset)

        # interleave all loaded splits/datasets into one dataset
        return interleave_datasets(dataset_splits)
    else:
        # single dataset, single split: load it directly (streaming per flag)
        return load_dataset(dataset_names, dataset_config_names, split=split, streaming=streaming, **kwargs)
354
 
355
 
 
656
  elif last_checkpoint is not None:
657
  checkpoint = last_checkpoint
658
  train_result = trainer.train(resume_from_checkpoint=checkpoint)
659
+ logger.info("*** Training completed ***")
660
+ logger.info("*** Saving model ***")
661
  trainer.save_model() # Saves the feature extractor too for easy upload
662
+ logger.info("*** Model saved ***")
663
  metrics = train_result.metrics
664
  if data_args.max_train_samples:
665
  metrics["train_samples"] = data_args.max_train_samples
666
+ logger.info("*** Logging metrics ***")
667
  trainer.log_metrics("train", metrics)
668
+ logger.info("*** Metrics logged ***")
669
+ logger.info("*** Saving metrics ***")
670
  trainer.save_metrics("train", metrics)
671
+ logger.info("*** Metrics saved ***")
672
+ logger.info("*** Saving state ***")
673
  trainer.save_state()
674
+ logger.info("*** State saved ***")
675
 
676
  # 13. Evaluation
677
  results = {}
 
682
  max_length=training_args.generation_max_length,
683
  num_beams=training_args.generation_num_beams,
684
  )
685
+ logger.info("*** Evaluation done ***")
686
  if data_args.max_eval_samples:
687
  metrics["eval_samples"] = data_args.max_eval_samples
688
+ logger.info("*** Logging metrics ***")
689
  trainer.log_metrics("eval", metrics)
690
+ logger.info("*** Metrics logged ***")
691
+ logger.info("*** Saving metrics ***")
692
  trainer.save_metrics("eval", metrics)
693
+ logger.info("*** Metrics saved ***")
694
 
695
  # 14. Write Training Stats
696
+ logger.info("*** Writing training stats ***")
697
  kwargs = {
698
  "finetuned_from": model_args.model_name_or_path,
699
  "tasks": "automatic-speech-recognition",
 
710
  if model_args.model_index_name is not None:
711
  kwargs["model_name"] = model_args.model_index_name
712
 
 
713
  if training_args.push_to_hub:
714
+ logger.info("*** Pushing to hub ***")
715
  trainer.push_to_hub(**kwargs)
716
+ logger.info("*** Pushed to hub ***")
717
  else:
718
+ logger.info("*** Creating model card ***")
719
  trainer.create_model_card(**kwargs)
720
+ logger.info("*** Model card created ***")
721
 
722
  # Training complete notification
723
  logger.info("*** Sending notification ***")
test_run_nordic.sh ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Smoke-test run: fine-tune openai/whisper-tiny on several interleaved Nordic
# speech datasets in streaming mode.
#
# Usage: ./test_run_nordic.sh [path-prefix/]
#   $1 is an optional directory prefix for the training script and MUST end
#   with a trailing slash when given (the original call concatenated it
#   directly). Quoted to survive paths containing spaces.
python "${1}run_speech_recognition_seq2seq_streaming.py" \
	--model_name_or_path="openai/whisper-tiny" \
	--dataset_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,NbAiLab/NST,arpelarpe/nota,NbAiLab/NPSC" \
	--dataset_config_name="sv-SE,da,nn-NO,,no-distant,,16k_mp3_nynorsk" \
	--language="swedish" \
	--train_split_name="train+validation,train+validation,train+validation,train,train+test,train,train+validation" \
	--eval_split_name="test" \
	--model_index_name="Whisper Tiny Swedish" \
	--max_train_samples="64" \
	--max_eval_samples="32" \
	--max_steps="5000" \
	--output_dir="./" \
	--per_device_train_batch_size="8" \
	--per_device_eval_batch_size="4" \
	--logging_steps="25" \
	--learning_rate="1e-5" \
	--warmup_steps="500" \
	--evaluation_strategy="steps" \
	--eval_steps="1000" \
	--save_strategy="steps" \
	--save_steps="1000" \
	--generation_max_length="225" \
	--length_column_name="input_length" \
	--max_duration_in_seconds="30" \
	--text_column_name="sentence" \
	--freeze_feature_encoder="False" \
	--report_to="wandb" \
	--metric_for_best_model="wer" \
	--greater_is_better="False" \
	--load_best_model_at_end \
	--gradient_checkpointing \
	--overwrite_output_dir \
	--do_train \
	--do_eval \
	--predict_with_generate \
	--do_normalize_eval \
	--streaming \
	--use_auth_token \
	--push_to_hub