marinone94
commited on
Commit
·
833e02b
1
Parent(s):
96a2519
fix script
Browse files
run_speech_recognition_seq2seq_streaming.py
CHANGED
@@ -165,10 +165,16 @@ class DataTrainingArguments:
|
|
165 |
Arguments pertaining to what data we are going to input our model for training and eval.
|
166 |
"""
|
167 |
|
168 |
-
|
169 |
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
170 |
)
|
171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
173 |
)
|
174 |
text_column: Optional[str] = field(
|
@@ -529,17 +535,17 @@ def main():
|
|
529 |
)
|
530 |
|
531 |
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
|
532 |
-
|
533 |
if data_args.audio_column_name not in raw_datasets_features:
|
534 |
raise ValueError(
|
535 |
-
f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.
|
536 |
"Make sure to set `--audio_column_name` to the correct audio column - one of "
|
537 |
f"{', '.join(raw_datasets_features)}."
|
538 |
)
|
539 |
|
540 |
-
if
|
541 |
raise ValueError(
|
542 |
-
f"--text_column_name {
|
543 |
"Make sure to set `--text_column_name` to the correct text column - one of "
|
544 |
f"{', '.join(raw_datasets_features)}."
|
545 |
)
|
@@ -600,7 +606,6 @@ def main():
|
|
600 |
max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
|
601 |
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
602 |
audio_column_name = data_args.audio_column_name
|
603 |
-
text_column_name = data_args.text_column_name
|
604 |
model_input_name = feature_extractor.model_input_names[0]
|
605 |
do_lower_case = data_args.do_lower_case
|
606 |
do_remove_punctuation = data_args.do_remove_punctuation
|
@@ -761,13 +766,13 @@ def main():
|
|
761 |
"tasks": "automatic-speech-recognition",
|
762 |
"tags": "whisper-event",
|
763 |
}
|
764 |
-
if data_args.
|
765 |
-
kwargs["dataset_tags"] = data_args.
|
766 |
if data_args.dataset_config_name is not None:
|
767 |
-
kwargs["dataset"] = f"{data_args.
|
768 |
else:
|
769 |
-
kwargs["dataset"] = data_args.
|
770 |
-
if "common_voice" in data_args.
|
771 |
kwargs["language"] = data_args.dataset_config_name[:2]
|
772 |
if model_args.model_index_name is not None:
|
773 |
kwargs["model_name"] = model_args.model_index_name
|
|
|
165 |
Arguments pertaining to what data we are going to input our model for training and eval.
|
166 |
"""
|
167 |
|
168 |
+
dataset_train_name: str = field(
|
169 |
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
170 |
)
|
171 |
+
dataset_train_config_name: Optional[str] = field(
|
172 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
173 |
+
)
|
174 |
+
dataset_eval_name: str = field(
|
175 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
176 |
+
)
|
177 |
+
dataset_eval_config_name: Optional[str] = field(
|
178 |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
179 |
)
|
180 |
text_column: Optional[str] = field(
|
|
|
535 |
)
|
536 |
|
537 |
raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
|
538 |
+
text_column_name = data_args.text_column_name.split(",")[0]
|
539 |
if data_args.audio_column_name not in raw_datasets_features:
|
540 |
raise ValueError(
|
541 |
+
f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_train_name}'. "
|
542 |
"Make sure to set `--audio_column_name` to the correct audio column - one of "
|
543 |
f"{', '.join(raw_datasets_features)}."
|
544 |
)
|
545 |
|
546 |
+
if text_column_name not in raw_datasets_features:
|
547 |
raise ValueError(
|
548 |
+
f"--text_column_name {text_column_name} not found in dataset '{data_args.dataset_train_name}'. "
|
549 |
"Make sure to set `--text_column_name` to the correct text column - one of "
|
550 |
f"{', '.join(raw_datasets_features)}."
|
551 |
)
|
|
|
606 |
max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
|
607 |
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
608 |
audio_column_name = data_args.audio_column_name
|
|
|
609 |
model_input_name = feature_extractor.model_input_names[0]
|
610 |
do_lower_case = data_args.do_lower_case
|
611 |
do_remove_punctuation = data_args.do_remove_punctuation
|
|
|
766 |
"tasks": "automatic-speech-recognition",
|
767 |
"tags": "whisper-event",
|
768 |
}
|
769 |
+
if data_args.dataset_train_name is not None:
|
770 |
+
kwargs["dataset_tags"] = data_args.dataset_train_name
|
771 |
if data_args.dataset_config_name is not None:
|
772 |
+
kwargs["dataset"] = f"{data_args.dataset_train_name} {data_args.dataset_config_name}"
|
773 |
else:
|
774 |
+
kwargs["dataset"] = data_args.dataset_train_name
|
775 |
+
if "common_voice" in data_args.dataset_train_name:
|
776 |
kwargs["language"] = data_args.dataset_config_name[:2]
|
777 |
if model_args.model_index_name is not None:
|
778 |
kwargs["model_name"] = model_args.model_index_name
|