marinone94
committed on
Commit
•
a057c82
1
Parent(s):
f850b55
debug decoding
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -432,36 +432,6 @@ def load_maybe_streaming_dataset(
|
|
432 |
return dataset
|
433 |
|
434 |
|
435 |
-
def load_common_voice_like_dataset(
    dataset_name,
    config,
    split,
    audio_column_name=None,
    sampling_rate=None,
    streaming=True,
    use_auth_token=False
):
    """Load a Common Voice-style dataset, optionally recasting its audio column.

    Args:
        dataset_name: Hub identifier of the dataset to load.
        config: Dataset configuration name (e.g. a language code).
        split: Split specification forwarded to ``load_dataset``.
        audio_column_name: Name of the audio column; the column is only
            recast when both this and ``sampling_rate`` are provided.
        sampling_rate: Target sampling rate for the audio column.
        streaming: Whether to stream the dataset instead of downloading it.
        use_auth_token: Auth-token flag forwarded to ``load_dataset``.

    Returns:
        The loaded dataset, with its audio column recast when requested.
    """
    ds = load_dataset(
        dataset_name,
        config,
        split=split,
        streaming=streaming,
        use_auth_token=use_auth_token,
    )
    # Recast audio only when the caller specified both the column and the rate.
    should_recast = audio_column_name is not None and sampling_rate is not None
    if should_recast:
        audio_feature = datasets.features.Audio(sampling_rate=sampling_rate)
        ds = ds.cast_column(audio_column_name, audio_feature)
    return ds
|
460 |
-
|
461 |
-
|
462 |
-
# def load_nst_nbailab(config, split, )
|
463 |
-
|
464 |
-
|
465 |
def main():
|
466 |
# 1. Parse input arguments
|
467 |
# See all possible arguments in src/transformers/training_args.py
|
@@ -476,8 +446,6 @@ def main():
|
|
476 |
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
477 |
else:
|
478 |
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
479 |
-
training_args.do_train = True
|
480 |
-
training_args.do_eval = True
|
481 |
|
482 |
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
483 |
# information sent is the one passed as arguments along with your Python/PyTorch versions.
|
@@ -541,6 +509,9 @@ def main():
|
|
541 |
logger.info("*** Load dataset ***")
|
542 |
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
|
543 |
|
|
|
|
|
|
|
544 |
if training_args.do_train:
|
545 |
raw_datasets["train"] = load_maybe_streaming_dataset(
|
546 |
data_args.dataset_train_name,
|
@@ -807,10 +778,31 @@ def main():
|
|
807 |
trainer.save_state()
|
808 |
logger.info("*** State saved ***")
|
809 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
810 |
# 13. Evaluation
|
811 |
results = {}
|
812 |
if training_args.do_eval:
|
813 |
logger.info("*** Evaluate ***")
|
|
|
814 |
metrics = trainer.evaluate(
|
815 |
metric_key_prefix="eval",
|
816 |
max_length=training_args.generation_max_length,
|
|
|
432 |
return dataset
|
433 |
|
434 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
435 |
def main():
|
436 |
# 1. Parse input arguments
|
437 |
# See all possible arguments in src/transformers/training_args.py
|
|
|
446 |
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
447 |
else:
|
448 |
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
|
|
|
|
449 |
|
450 |
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
451 |
# information sent is the one passed as arguments along with your Python/PyTorch versions.
|
|
|
509 |
logger.info("*** Load dataset ***")
|
510 |
raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
|
511 |
|
512 |
+
# Only a single evaluation language is supported for now; fail fast otherwise.
eval_languages = data_args.language_eval.split(",")
if len(eval_languages) > 1:
    raise ValueError("Implementation does not support multiple language evaluation.")
|
514 |
+
|
515 |
if training_args.do_train:
|
516 |
raw_datasets["train"] = load_maybe_streaming_dataset(
|
517 |
data_args.dataset_train_name,
|
|
|
778 |
trainer.save_state()
|
779 |
logger.info("*** State saved ***")
|
780 |
|
781 |
+
# Run a test prediction on a small shuffled sample to inspect decoded outputs.
# BUG FIX: the previous version rebound `predictions` to the decoded string
# list (`predictions = processor.batch_decode(predictions.predictions)`) and
# then read `.label_ids` / `.predictions` off that list, which raises
# AttributeError. Keep the PredictionOutput in its own variable instead.
prediction_output = trainer.predict(
    test_dataset=vectorized_datasets["test"].shuffle(seed=training_args.seed).select(range(5)),
    metric_key_prefix="test",
    max_length=training_args.generation_max_length,
    num_beams=training_args.generation_num_beams,
)
logger.info("*** Test prediction done ***")
predictions = processor.batch_decode(prediction_output.predictions)
labels = processor.batch_decode(prediction_output.label_ids)
pred_labels = [f"Prediction: {pred}\nLabel: {label}\n" for pred, label in zip(predictions, labels)]
logger.info("Before setting language and task")
logger.info(f"{pred_labels}")
# NOTE(review): set_prefix_tokens changes how new labels are ENCODED; it is
# not expected to alter batch_decode output of already-generated ids —
# confirm this debug comparison is actually meaningful.
trainer.data_collator.processor.tokenizer.set_prefix_tokens(language=data_args.language_eval, task=data_args.task)
predictions = processor.batch_decode(prediction_output.predictions)
labels = processor.batch_decode(prediction_output.label_ids)
pred_labels = [f"Prediction: {pred}\nLabel: {label}\n" for pred, label in zip(predictions, labels)]
logger.info("After setting language and task")
logger.info(f"{pred_labels}")
|
800 |
+
|
801 |
# 13. Evaluation
|
802 |
results = {}
|
803 |
if training_args.do_eval:
|
804 |
logger.info("*** Evaluate ***")
|
805 |
+
|
806 |
metrics = trainer.evaluate(
|
807 |
metric_key_prefix="eval",
|
808 |
max_length=training_args.generation_max_length,
|
test_run_nordic.sh
CHANGED
@@ -4,9 +4,9 @@ python $1run_speech_recognition_seq2seq_streaming.py \
|
|
4 |
--dataset_train_config_name="sv-SE,da,nn-NO,nst,no-distant,16K_mp3_nynorsk,sv_se,da_dk,nb_no" \
|
5 |
--language_train="sv,da,no,sv,no,no,sv,da,no" \
|
6 |
--train_split_name="train+validation,train+validation,train+validation,train,train+test,train+validation,train+validation,train+validation,train+validation" \
|
7 |
-
--dataset_eval_name="mozilla-foundation/common_voice_11_0
|
8 |
-
--dataset_eval_config_name="sv-SE
|
9 |
-
--language_eval="sv
|
10 |
--eval_split_name="test" \
|
11 |
--model_index_name="Whisper Tiny Nordic" \
|
12 |
--max_train_samples="64" \
|
|
|
4 |
--dataset_train_config_name="sv-SE,da,nn-NO,nst,no-distant,16K_mp3_nynorsk,sv_se,da_dk,nb_no" \
|
5 |
--language_train="sv,da,no,sv,no,no,sv,da,no" \
|
6 |
--train_split_name="train+validation,train+validation,train+validation,train,train+test,train+validation,train+validation,train+validation,train+validation" \
|
7 |
+
--dataset_eval_name="mozilla-foundation/common_voice_11_0" \
|
8 |
+
--dataset_eval_config_name="sv-SE" \
|
9 |
+
--language_eval="sv" \
|
10 |
--eval_split_name="test" \
|
11 |
--model_index_name="Whisper Tiny Nordic" \
|
12 |
--max_train_samples="64" \
|