Rolv-Arild
committed on
Commit
•
8dfd77e
1
Parent(s):
7a4ad37
Change text column
Browse files

Adjust warmup steps
- run.sh +4 -5
- run_speech_recognition_seq2seq.py +6 -2
run.sh
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
python run_speech_recognition_seq2seq.py \
|
2 |
--dataset_name="NbAiLab/NPSC" \
|
3 |
-
|
4 |
-
|
5 |
--model_name_or_path="./" \
|
6 |
--output_dir="./" \
|
7 |
--preprocessing_num_workers="16" \
|
@@ -12,7 +12,7 @@ python run_speech_recognition_seq2seq.py \
|
|
12 |
--per_device_eval_batch_size="8" \
|
13 |
--gradient_accumulation_steps="8" \
|
14 |
--learning_rate="3e-4" \
|
15 |
-
--warmup_steps="
|
16 |
--evaluation_strategy="steps" \
|
17 |
--text_column_name="text" \
|
18 |
--save_steps="400" \
|
@@ -28,5 +28,4 @@ python run_speech_recognition_seq2seq.py \
|
|
28 |
--generation_num_beams="1" \
|
29 |
--do_train --do_eval \
|
30 |
--do_lower_case \
|
31 |
-
|
32 |
-
--push_to_hub
|
|
|
1 |
python run_speech_recognition_seq2seq.py \
|
2 |
--dataset_name="NbAiLab/NPSC" \
|
3 |
+
--dataset_config_name="16K_mp3" \
|
4 |
+
--data_cache_dir="/mnt/lv_ai_1_ficino/rolvb/cache" \
|
5 |
--model_name_or_path="./" \
|
6 |
--output_dir="./" \
|
7 |
--preprocessing_num_workers="16" \
|
|
|
12 |
--per_device_eval_batch_size="8" \
|
13 |
--gradient_accumulation_steps="8" \
|
14 |
--learning_rate="3e-4" \
|
15 |
+
--warmup_steps="1000" \
|
16 |
--evaluation_strategy="steps" \
|
17 |
--text_column_name="text" \
|
18 |
--save_steps="400" \
|
|
|
28 |
--generation_num_beams="1" \
|
29 |
--do_train --do_eval \
|
30 |
--do_lower_case \
|
31 |
+
--push_to_hub
|
|
run_speech_recognition_seq2seq.py
CHANGED
@@ -21,6 +21,7 @@ Fine-tuning the library models for sequence to sequence speech recognition.
|
|
21 |
|
22 |
import logging
|
23 |
import os
|
|
|
24 |
import sys
|
25 |
from dataclasses import dataclass, field
|
26 |
from typing import Any, Dict, List, Optional, Union
|
@@ -355,7 +356,6 @@ def main():
|
|
355 |
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
356 |
audio_column_name = data_args.audio_column_name
|
357 |
num_workers = data_args.preprocessing_num_workers
|
358 |
-
text_column_name = data_args.text_column_name
|
359 |
model_input_name = feature_extractor.model_input_names[0]
|
360 |
do_lower_case = data_args.do_lower_case
|
361 |
|
@@ -373,8 +373,12 @@ def main():
|
|
373 |
batch[model_input_name] = inputs.input_values[0]
|
374 |
batch["input_length"] = len(batch["input_values"])
|
375 |
|
|
|
376 |
# process targets
|
377 |
input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
|
|
|
|
|
|
|
378 |
batch["labels"] = tokenizer(input_str).input_ids
|
379 |
return batch
|
380 |
|
@@ -389,7 +393,7 @@ def main():
|
|
389 |
# filter data that is shorter than min_input_length or longer than
|
390 |
# max_input_length
|
391 |
def is_audio_in_length_range(length):
|
392 |
-
return
|
393 |
|
394 |
vectorized_datasets = vectorized_datasets.filter(
|
395 |
is_audio_in_length_range,
|
|
|
21 |
|
22 |
import logging
|
23 |
import os
|
24 |
+
import re
|
25 |
import sys
|
26 |
from dataclasses import dataclass, field
|
27 |
from typing import Any, Dict, List, Optional, Union
|
|
|
356 |
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
357 |
audio_column_name = data_args.audio_column_name
|
358 |
num_workers = data_args.preprocessing_num_workers
|
|
|
359 |
model_input_name = feature_extractor.model_input_names[0]
|
360 |
do_lower_case = data_args.do_lower_case
|
361 |
|
|
|
373 |
batch[model_input_name] = inputs.input_values[0]
|
374 |
batch["input_length"] = len(batch["input_values"])
|
375 |
|
376 |
+
text_column_name = "transsentence_text" if batch["sentence_language_code"] == "nn-NO" else "normsentence_text"
|
377 |
# process targets
|
378 |
input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
|
379 |
+
|
380 |
+
input_str = re.sub(r"<\*?(ee|qq|mm|inaudible)>", "", input_str, re.IGNORECASE)
|
381 |
+
|
382 |
batch["labels"] = tokenizer(input_str).input_ids
|
383 |
return batch
|
384 |
|
|
|
393 |
# filter data that is shorter than min_input_length or longer than
|
394 |
# max_input_length
|
395 |
def is_audio_in_length_range(length):
|
396 |
+
return min_input_length < length < max_input_length
|
397 |
|
398 |
vectorized_datasets = vectorized_datasets.filter(
|
399 |
is_audio_in_length_range,
|