marinone94 commited on
Commit
a057c82
1 Parent(s): f850b55

debug decoding

Browse files
run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -432,36 +432,6 @@ def load_maybe_streaming_dataset(
432
  return dataset
433
 
434
 
435
- def load_common_voice_like_dataset(
436
- dataset_name,
437
- config,
438
- split,
439
- audio_column_name=None,
440
- sampling_rate=None,
441
- streaming=True,
442
- use_auth_token=False
443
- ):
444
-
445
- """
446
- Utility function to load the Common Voice dataset.
447
- """
448
- dataset = load_dataset(
449
- dataset_name,
450
- config,
451
- split=split,
452
- streaming=streaming,
453
- use_auth_token=use_auth_token,
454
- )
455
- if audio_column_name is not None and sampling_rate is not None:
456
- dataset = dataset.cast_column(
457
- audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
458
- )
459
- return dataset
460
-
461
-
462
- # def load_nst_nbailab(config, split, )
463
-
464
-
465
  def main():
466
  # 1. Parse input arguments
467
  # See all possible arguments in src/transformers/training_args.py
@@ -476,8 +446,6 @@ def main():
476
  model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
477
  else:
478
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
479
- training_args.do_train = True
480
- training_args.do_eval = True
481
 
482
  # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
483
  # information sent is the one passed as arguments along with your Python/PyTorch versions.
@@ -541,6 +509,9 @@ def main():
541
  logger.info("*** Load dataset ***")
542
  raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
543
 
 
 
 
544
  if training_args.do_train:
545
  raw_datasets["train"] = load_maybe_streaming_dataset(
546
  data_args.dataset_train_name,
@@ -807,10 +778,31 @@ def main():
807
  trainer.save_state()
808
  logger.info("*** State saved ***")
809
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
810
  # 13. Evaluation
811
  results = {}
812
  if training_args.do_eval:
813
  logger.info("*** Evaluate ***")
 
814
  metrics = trainer.evaluate(
815
  metric_key_prefix="eval",
816
  max_length=training_args.generation_max_length,
 
432
  return dataset
433
 
434
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
  def main():
436
  # 1. Parse input arguments
437
  # See all possible arguments in src/transformers/training_args.py
 
446
  model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
447
  else:
448
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
 
449
 
450
  # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
451
  # information sent is the one passed as arguments along with your Python/PyTorch versions.
 
509
  logger.info("*** Load dataset ***")
510
  raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
511
 
512
+ if len(data_args.language_eval.split(",")) > 1:
513
+ raise ValueError("Implementation does not support multiple language evaluation.")
514
+
515
  if training_args.do_train:
516
  raw_datasets["train"] = load_maybe_streaming_dataset(
517
  data_args.dataset_train_name,
 
778
  trainer.save_state()
779
  logger.info("*** State saved ***")
780
 
781
+ # Run a test prediction to check outputs
782
+ predictions = trainer.predict(
783
+ test_dataset=vectorized_datasets["test"].shuffle(seed=training_args.seed).select(range(5)),
784
+ metric_key_prefix="test",
785
+ max_length=training_args.generation_max_length,
786
+ num_beams=training_args.generation_num_beams,
787
+ )
788
+ logger.info("*** Test prediction done ***")
789
+ decoded_preds = processor.batch_decode(predictions.predictions)
790
+ labels = processor.batch_decode(predictions.label_ids)
791
+ pred_labels = [f"Prediction: {pred}\nLabel: {label}\n" for pred, label in zip(decoded_preds, labels)]
792
+ logger.info("Before setting language and task")
793
+ logger.info(f"{pred_labels}")
794
+ trainer.data_collator.processor.tokenizer.set_prefix_tokens(language=data_args.language_eval, task=data_args.task)
795
+ decoded_preds = processor.batch_decode(predictions.predictions)
796
+ labels = processor.batch_decode(predictions.label_ids)
797
+ pred_labels = [f"Prediction: {pred}\nLabel: {label}\n" for pred, label in zip(decoded_preds, labels)]
798
+ logger.info("After setting language and task")
799
+ logger.info(f"{pred_labels}")
800
+
801
  # 13. Evaluation
802
  results = {}
803
  if training_args.do_eval:
804
  logger.info("*** Evaluate ***")
805
+
806
  metrics = trainer.evaluate(
807
  metric_key_prefix="eval",
808
  max_length=training_args.generation_max_length,
test_run_nordic.sh CHANGED
@@ -4,9 +4,9 @@ python $1run_speech_recognition_seq2seq_streaming.py \
4
  --dataset_train_config_name="sv-SE,da,nn-NO,nst,no-distant,16K_mp3_nynorsk,sv_se,da_dk,nb_no" \
5
  --language_train="sv,da,no,sv,no,no,sv,da,no" \
6
  --train_split_name="train+validation,train+validation,train+validation,train,train+test,train+validation,train+validation,train+validation,train+validation" \
7
- --dataset_eval_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0" \
8
- --dataset_eval_config_name="sv-SE,da,nn-NO" \
9
- --language_eval="sv,da,no" \
10
  --eval_split_name="test" \
11
  --model_index_name="Whisper Tiny Nordic" \
12
  --max_train_samples="64" \
 
4
  --dataset_train_config_name="sv-SE,da,nn-NO,nst,no-distant,16K_mp3_nynorsk,sv_se,da_dk,nb_no" \
5
  --language_train="sv,da,no,sv,no,no,sv,da,no" \
6
  --train_split_name="train+validation,train+validation,train+validation,train,train+test,train+validation,train+validation,train+validation,train+validation" \
7
+ --dataset_eval_name="mozilla-foundation/common_voice_11_0" \
8
+ --dataset_eval_config_name="sv-SE" \
9
+ --language_eval="sv" \
10
  --eval_split_name="test" \
11
  --model_index_name="Whisper Tiny Nordic" \
12
  --max_train_samples="64" \