marinone94 committed
Commit 4e5c598 (1 parent: fb5ea5a)

fix training script

Files changed (4):
  1. added_tokens.json +0 -1
  2. run.sh +2 -2
  3. run_speech_recognition_ctc.py +13 -5
  4. vocab.json +1 -1
added_tokens.json DELETED
@@ -1 +0,0 @@
-{"<s>": 33, "</s>": 34}
run.sh CHANGED
@@ -1,8 +1,8 @@
 python run_speech_recognition_ctc.py \
-    --dataset_name="mozilla-foundation/common_voice_7_0,marinone94/nst_sv" \
+    --dataset_name="mozilla-foundation/common_voice_8_0,marinone94/nst_sv" \
     --model_name_or_path="KBLab/wav2vec2-large-voxrex" \
     --dataset_config_name="sv-SE,distant_channel" \
-    --train_split_name="None,train" \
+    --train_split_name="train+validation,train" \
     --eval_split_name="test,None" \
     --output_dir="./" \
     --overwrite_output_dir \
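The comma-separated values above suggest the script pairs dataset names, configs, and splits positionally, with the literal string "None" marking a dataset that contributes nothing to that split (so common_voice_8_0 now supplies train+validation for training, while nst_sv supplies its train split). A minimal sketch of that pairing, assuming the standard datasets API; the names and control flow here are illustrative, not the script's exact code:

from datasets import concatenate_datasets, load_dataset

names = "mozilla-foundation/common_voice_8_0,marinone94/nst_sv".split(",")
configs = "sv-SE,distant_channel".split(",")
train_splits = "train+validation,train".split(",")

parts = []
for name, config, split in zip(names, configs, train_splits):
    if split != "None":  # "None" means this dataset adds no data to the split
        # "train+validation" is the datasets split syntax for merging two splits
        parts.append(load_dataset(name, config, split=split, use_auth_token=True))

# Concatenation assumes a shared schema; the script itself notes it assumes
# all datasets follow the Common Voice format.
train_dataset = concatenate_datasets(parts)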
run_speech_recognition_ctc.py CHANGED
@@ -28,6 +28,7 @@ from typing import Dict, List, Optional, Union
 
 import datasets
 import numpy as np
+import pandas as pd
 import torch
 import wandb
 from datasets import DatasetDict, concatenate_datasets, load_dataset, load_metric
@@ -376,7 +377,7 @@ def main():
         wandb.login()
         training_args.report_to = ["wandb"]
         training_args.run_name = run_name
-        wandb.init()
+        # wandb.init()
     except:
         pass
 
@@ -480,6 +481,11 @@ def main():
         other_columns_train = [col for col in raw_datasets["train"].column_names if col not in min_columns_train]
         raw_datasets["train"].remove_columns(other_columns_train)
 
+        # pd_train_head = raw_datasets["train"].select(range(10)).to_pandas()
+        # pd_train_tail = raw_datasets["train"].select(range(raw_datasets["train"].num_rows-10, raw_datasets["train"].num_rows)).to_pandas()
+        # pd_train = pd.concat([pd_train_head, pd_train_tail])
+        # print(pd_train["audio"])
+
     if training_args.do_eval:
         # Multiple datasets might need to be loaded from HF
         # It assumes they all follow the common voice format
@@ -520,6 +526,11 @@ def main():
         other_columns_eval = [col for col in raw_datasets["eval"].column_names if col not in min_columns_eval]
         raw_datasets["eval"].remove_columns(other_columns_eval)
 
+        # pd_eval_head = raw_datasets["eval"].select(range(10)).to_pandas()
+        # pd_eval_tail = raw_datasets["eval"].select(range(raw_datasets["eval"].num_rows-10, raw_datasets["eval"].num_rows)).to_pandas()
+        # pd_eval = pd.concat([pd_eval_head, pd_eval_tail])
+        # print(pd_eval["audio"])
+
     # 2. We remove some special characters from the datasets
     # that make training complicated and do not help in transcribing the speech
     # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
@@ -755,15 +766,12 @@ def main():
     if data_args.dataset_seed is not None:
         vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(seed=data_args.dataset_seed)
 
-    # Log sample of datasets
+    # TODO: Log sample of datasets in the right way (see wandb docs)
    pd_train = vectorized_datasets["train"].select(range(10)).to_pandas()
     pd_eval = vectorized_datasets["eval"].select(range(10)).to_pandas()
     # wandb.log({"train_sample": pd_train})
     # wandb.log({"eval_sample": pd_eval})
 
-    print(pd_train)
-    print(pd_eval)
-
     # for large datasets it is advised to run the preprocessing on a
     # single machine first with ``args.preprocessing_only`` since there will mostly likely
     # be a timeout when running the script in distributed mode.
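Two notes on the wandb changes. Commenting out wandb.init() is consistent with letting the Trainer's WandbCallback initialize the run itself once training_args.report_to includes "wandb", so an explicit early init is redundant. For the new TODO about logging dataset samples "in the right way", wandb Tables accept a pandas DataFrame directly; a possible sketch (my suggestion, not the repo's code; the column filtering assumes the vectorized rows carry a long input_values array):

import wandb

# wandb.Table wraps a DataFrame and renders as an inspectable table in the UI;
# this would replace the commented-out wandb.log({...: pd_train}) calls.
# Long float arrays such as input_values are dropped first, since Tables are
# meant for small, human-readable samples.
cols = [c for c in pd_train.columns if c != "input_values"]
wandb.log({"train_sample": wandb.Table(dataframe=pd_train[cols])})
wandb.log({"eval_sample": wandb.Table(dataframe=pd_eval[cols])})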
vocab.json CHANGED
@@ -1 +1 @@
-{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16, "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24, "y": 25, "z": 26, "ä": 27, "å": 28, "ô": 29, "ö": 30, "|": 0, "[UNK]": 31, "[PAD]": 32}
+{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16, "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24, "y": 25, "z": 26, "\u00e4": 27, "\u00e5": 28, "\u00f6": 29, "|": 0, "[UNK]": 30, "[PAD]": 31}