Rolv-Arild committed on
Commit
8dfd77e
1 Parent(s): 7a4ad37

Change text column

Browse files

Adjust warmup steps

Files changed (2) hide show
  1. run.sh +4 -5
  2. run_speech_recognition_seq2seq.py +6 -2
run.sh CHANGED
@@ -1,7 +1,7 @@
1
  python run_speech_recognition_seq2seq.py \
2
  --dataset_name="NbAiLab/NPSC" \
3
- --dataset_config_name="16K_mp3" \
4
- --data_cache_dir="/mnt/lv_ai_1_ficino/rolvb/cache" \
5
  --model_name_or_path="./" \
6
  --output_dir="./" \
7
  --preprocessing_num_workers="16" \
@@ -12,7 +12,7 @@ python run_speech_recognition_seq2seq.py \
12
  --per_device_eval_batch_size="8" \
13
  --gradient_accumulation_steps="8" \
14
  --learning_rate="3e-4" \
15
- --warmup_steps="400" \
16
  --evaluation_strategy="steps" \
17
  --text_column_name="text" \
18
  --save_steps="400" \
@@ -28,5 +28,4 @@ python run_speech_recognition_seq2seq.py \
28
  --generation_num_beams="1" \
29
  --do_train --do_eval \
30
  --do_lower_case \
31
- --preprocessing_num_workers="8" \
32
- --push_to_hub
 
1
  python run_speech_recognition_seq2seq.py \
2
  --dataset_name="NbAiLab/NPSC" \
3
+ --dataset_config_name="16K_mp3" \
4
+ --data_cache_dir="/mnt/lv_ai_1_ficino/rolvb/cache" \
5
  --model_name_or_path="./" \
6
  --output_dir="./" \
7
  --preprocessing_num_workers="16" \
 
12
  --per_device_eval_batch_size="8" \
13
  --gradient_accumulation_steps="8" \
14
  --learning_rate="3e-4" \
15
+ --warmup_steps="1000" \
16
  --evaluation_strategy="steps" \
17
  --text_column_name="text" \
18
  --save_steps="400" \
 
28
  --generation_num_beams="1" \
29
  --do_train --do_eval \
30
  --do_lower_case \
31
+ --push_to_hub
 
run_speech_recognition_seq2seq.py CHANGED
@@ -21,6 +21,7 @@ Fine-tuning the library models for sequence to sequence speech recognition.
21
 
22
  import logging
23
  import os
 
24
  import sys
25
  from dataclasses import dataclass, field
26
  from typing import Any, Dict, List, Optional, Union
@@ -355,7 +356,6 @@ def main():
355
  min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
356
  audio_column_name = data_args.audio_column_name
357
  num_workers = data_args.preprocessing_num_workers
358
- text_column_name = data_args.text_column_name
359
  model_input_name = feature_extractor.model_input_names[0]
360
  do_lower_case = data_args.do_lower_case
361
 
@@ -373,8 +373,12 @@ def main():
373
  batch[model_input_name] = inputs.input_values[0]
374
  batch["input_length"] = len(batch["input_values"])
375
 
 
376
  # process targets
377
  input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
 
 
 
378
  batch["labels"] = tokenizer(input_str).input_ids
379
  return batch
380
 
@@ -389,7 +393,7 @@ def main():
389
  # filter data that is shorter than min_input_length or longer than
390
  # max_input_length
391
  def is_audio_in_length_range(length):
392
- return length > min_input_length and length < max_input_length
393
 
394
  vectorized_datasets = vectorized_datasets.filter(
395
  is_audio_in_length_range,
 
21
 
22
  import logging
23
  import os
24
+ import re
25
  import sys
26
  from dataclasses import dataclass, field
27
  from typing import Any, Dict, List, Optional, Union
 
356
  min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
357
  audio_column_name = data_args.audio_column_name
358
  num_workers = data_args.preprocessing_num_workers
 
359
  model_input_name = feature_extractor.model_input_names[0]
360
  do_lower_case = data_args.do_lower_case
361
 
 
373
  batch[model_input_name] = inputs.input_values[0]
374
  batch["input_length"] = len(batch["input_values"])
375
 
376
+ text_column_name = "transsentence_text" if batch["sentence_language_code"] == "nn-NO" else "normsentence_text"
377
  # process targets
378
  input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
379
+
380
+ input_str = re.sub(r"<\*?(ee|qq|mm|inaudible)>", "", input_str, flags=re.IGNORECASE)
381
+
382
  batch["labels"] = tokenizer(input_str).input_ids
383
  return batch
384
 
 
393
  # filter data that is shorter than min_input_length or longer than
394
  # max_input_length
395
  def is_audio_in_length_range(length):
396
+ return min_input_length < length < max_input_length
397
 
398
  vectorized_datasets = vectorized_datasets.filter(
399
  is_audio_in_length_range,