marinone94
committed on
Commit
•
23bb45c
1
Parent(s):
6d4cdd4
allowing multiple datasets
Browse files- run_speech_recognition_seq2seq_streaming.py +30 -10
- test_run_nordic.sh +39 -0
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -328,24 +328,28 @@ def notify_me(recipient, message=None):
|
|
328 |
smtp_obj.quit()
|
329 |
|
330 |
|
331 |
-
def load_maybe_streaming_dataset(
|
332 |
"""
|
333 |
Utility function to load a dataset in streaming mode. For datasets with multiple splits,
|
334 |
each split is loaded individually and then splits combined by taking alternating examples from
|
335 |
each (interleaving).
|
336 |
"""
|
337 |
-
if "+" in split:
|
338 |
# load multiple splits separated by the `+` symbol with streaming mode
|
339 |
-
dataset_splits = [
|
340 |
-
|
341 |
-
|
342 |
-
|
|
|
|
|
|
|
|
|
343 |
# interleave multiple splits to form one dataset
|
344 |
interleaved_dataset = interleave_datasets(dataset_splits)
|
345 |
return interleaved_dataset
|
346 |
else:
|
347 |
# load a single split *with* streaming mode
|
348 |
-
dataset = load_dataset(
|
349 |
return dataset
|
350 |
|
351 |
|
@@ -652,14 +656,22 @@ def main():
|
|
652 |
elif last_checkpoint is not None:
|
653 |
checkpoint = last_checkpoint
|
654 |
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
|
|
|
|
655 |
trainer.save_model() # Saves the feature extractor too for easy upload
|
656 |
-
|
657 |
metrics = train_result.metrics
|
658 |
if data_args.max_train_samples:
|
659 |
metrics["train_samples"] = data_args.max_train_samples
|
|
|
660 |
trainer.log_metrics("train", metrics)
|
|
|
|
|
661 |
trainer.save_metrics("train", metrics)
|
|
|
|
|
662 |
trainer.save_state()
|
|
|
663 |
|
664 |
# 13. Evaluation
|
665 |
results = {}
|
@@ -670,13 +682,18 @@ def main():
|
|
670 |
max_length=training_args.generation_max_length,
|
671 |
num_beams=training_args.generation_num_beams,
|
672 |
)
|
|
|
673 |
if data_args.max_eval_samples:
|
674 |
metrics["eval_samples"] = data_args.max_eval_samples
|
675 |
-
|
676 |
trainer.log_metrics("eval", metrics)
|
|
|
|
|
677 |
trainer.save_metrics("eval", metrics)
|
|
|
678 |
|
679 |
# 14. Write Training Stats
|
|
|
680 |
kwargs = {
|
681 |
"finetuned_from": model_args.model_name_or_path,
|
682 |
"tasks": "automatic-speech-recognition",
|
@@ -693,11 +710,14 @@ def main():
|
|
693 |
if model_args.model_index_name is not None:
|
694 |
kwargs["model_name"] = model_args.model_index_name
|
695 |
|
696 |
-
logger.info("*** Pushing to hub ***")
|
697 |
if training_args.push_to_hub:
|
|
|
698 |
trainer.push_to_hub(**kwargs)
|
|
|
699 |
else:
|
|
|
700 |
trainer.create_model_card(**kwargs)
|
|
|
701 |
|
702 |
# Training complete notification
|
703 |
logger.info("*** Sending notification ***")
|
|
|
328 |
smtp_obj.quit()
|
329 |
|
330 |
|
331 |
def load_maybe_streaming_dataset(dataset_names, dataset_config_names, split="train", streaming=True, **kwargs):
    """
    Utility function to load one or more datasets, optionally in streaming mode.

    Multiple datasets are passed as a comma-separated list in ``dataset_names``
    (with matching comma-separated entries in ``dataset_config_names`` and
    ``split``); multiple splits of the same dataset are joined with ``+``.
    All resulting splits are combined by taking alternating examples from
    each (interleaving).

    Args:
        dataset_names: dataset name, or comma-separated list of names.
        dataset_config_names: config name(s), comma-separated, aligned with
            ``dataset_names``; an empty entry means the dataset has no config.
        split: split name(s); comma-separated per dataset, with several splits
            of one dataset joined by ``+`` (e.g. ``"train+validation,test"``).
        streaming: whether to load the dataset(s) in streaming mode.
        **kwargs: forwarded to ``load_dataset`` (e.g. ``use_auth_token``).

    Returns:
        A single dataset, or an interleaved dataset when several datasets or
        splits were requested.

    Raises:
        ValueError: if the comma-separated name/config/split lists do not all
            have the same number of entries.
    """
    if "," in dataset_names or "+" in split:
        # Tolerate stray whitespace around entries (e.g. "train+test, train").
        names = [name.strip() for name in dataset_names.split(",")]
        configs = [config.strip() for config in dataset_config_names.split(",")]
        split_groups = [group.strip() for group in split.split(",")]
        # zip() would silently drop trailing entries on a length mismatch,
        # so validate explicitly and fail loudly instead.
        if not (len(names) == len(configs) == len(split_groups)):
            raise ValueError(
                f"Mismatched number of datasets ({len(names)}), configs ({len(configs)}) "
                f"and split groups ({len(split_groups)})"
            )
        dataset_splits = []
        for dataset_name, dataset_config_name, split_names in zip(names, configs, split_groups):
            for split_name in split_names.split("+"):
                # An empty config entry (e.g. "a,,b") means "no config" -> None.
                dataset = load_dataset(
                    dataset_name,
                    dataset_config_name or None,
                    split=split_name.strip(),
                    streaming=streaming,
                    **kwargs,
                )
                dataset_splits.append(dataset)

        # interleave multiple splits to form one dataset
        interleaved_dataset = interleave_datasets(dataset_splits)
        return interleaved_dataset
    else:
        # load a single split *with* streaming mode
        dataset = load_dataset(dataset_names, dataset_config_names, split=split, streaming=streaming, **kwargs)
        return dataset
|
354 |
|
355 |
|
|
|
656 |
elif last_checkpoint is not None:
|
657 |
checkpoint = last_checkpoint
|
658 |
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
659 |
+
logger.info("*** Training completed ***")
|
660 |
+
logger.info("*** Saving model ***")
|
661 |
trainer.save_model() # Saves the feature extractor too for easy upload
|
662 |
+
logger.info("*** Model saved ***")
|
663 |
metrics = train_result.metrics
|
664 |
if data_args.max_train_samples:
|
665 |
metrics["train_samples"] = data_args.max_train_samples
|
666 |
+
logger.info("*** Logging metrics ***")
|
667 |
trainer.log_metrics("train", metrics)
|
668 |
+
logger.info("*** Metrics logged ***")
|
669 |
+
logger.info("*** Saving metrics ***")
|
670 |
trainer.save_metrics("train", metrics)
|
671 |
+
logger.info("*** Metrics saved ***")
|
672 |
+
logger.info("*** Saving state ***")
|
673 |
trainer.save_state()
|
674 |
+
logger.info("*** State saved ***")
|
675 |
|
676 |
# 13. Evaluation
|
677 |
results = {}
|
|
|
682 |
max_length=training_args.generation_max_length,
|
683 |
num_beams=training_args.generation_num_beams,
|
684 |
)
|
685 |
+
logger.info("*** Evaluation done ***")
|
686 |
if data_args.max_eval_samples:
|
687 |
metrics["eval_samples"] = data_args.max_eval_samples
|
688 |
+
logger.info("*** Logging metrics ***")
|
689 |
trainer.log_metrics("eval", metrics)
|
690 |
+
logger.info("*** Metrics logged ***")
|
691 |
+
logger.info("*** Saving metrics ***")
|
692 |
trainer.save_metrics("eval", metrics)
|
693 |
+
logger.info("*** Metrics saved ***")
|
694 |
|
695 |
# 14. Write Training Stats
|
696 |
+
logger.info("*** Writing training stats ***")
|
697 |
kwargs = {
|
698 |
"finetuned_from": model_args.model_name_or_path,
|
699 |
"tasks": "automatic-speech-recognition",
|
|
|
710 |
if model_args.model_index_name is not None:
|
711 |
kwargs["model_name"] = model_args.model_index_name
|
712 |
|
|
|
713 |
if training_args.push_to_hub:
|
714 |
+
logger.info("*** Pushing to hub ***")
|
715 |
trainer.push_to_hub(**kwargs)
|
716 |
+
logger.info("*** Pushed to hub ***")
|
717 |
else:
|
718 |
+
logger.info("*** Creating model card ***")
|
719 |
trainer.create_model_card(**kwargs)
|
720 |
+
logger.info("*** Model card created ***")
|
721 |
|
722 |
# Training complete notification
|
723 |
logger.info("*** Sending notification ***")
|
test_run_nordic.sh
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env bash
# Fine-tune openai/whisper-tiny on a mix of Nordic speech datasets in streaming mode.
#
# Usage: ./test_run_nordic.sh [path-prefix/]
#   $1 is an optional directory prefix (including trailing slash) for
#   run_speech_recognition_seq2seq_streaming.py; quoted so paths with spaces work.
#
# Note: --dataset_name, --dataset_config_name and --train_split_name are aligned
# comma-separated lists (empty config entries mean "no config"); splits of one
# dataset are joined with '+'.
python "${1}run_speech_recognition_seq2seq_streaming.py" \
	--model_name_or_path="openai/whisper-tiny" \
	--dataset_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,NbAiLab/NST,arpelarpe/nota,NbAiLab/NPSC" \
	--dataset_config_name="sv-SE,da,nn-NO,,no-distant,,16k_mp3_nynorsk" \
	--language="swedish" \
	--train_split_name="train+validation,train+validation,train+validation,train,train+test,train,train+validation" \
	--eval_split_name="test" \
	--model_index_name="Whisper Tiny Swedish" \
	--max_train_samples="64" \
	--max_eval_samples="32" \
	--max_steps="5000" \
	--output_dir="./" \
	--per_device_train_batch_size="8" \
	--per_device_eval_batch_size="4" \
	--logging_steps="25" \
	--learning_rate="1e-5" \
	--warmup_steps="500" \
	--evaluation_strategy="steps" \
	--eval_steps="1000" \
	--save_strategy="steps" \
	--save_steps="1000" \
	--generation_max_length="225" \
	--length_column_name="input_length" \
	--max_duration_in_seconds="30" \
	--text_column_name="sentence" \
	--freeze_feature_encoder="False" \
	--report_to="wandb" \
	--metric_for_best_model="wer" \
	--greater_is_better="False" \
	--load_best_model_at_end \
	--gradient_checkpointing \
	--overwrite_output_dir \
	--do_train \
	--do_eval \
	--predict_with_generate \
	--do_normalize_eval \
	--streaming \
	--use_auth_token \
	--push_to_hub