jonatasgrosman committed on
Commit 7852deb
1 Parent(s): de62d73

update model

Files changed (5)
  1. README.md +54 -41
  2. config.json +9 -9
  3. pytorch_model.bin +2 -2
  4. special_tokens_map.json +1 -1
  5. vocab.json +1 -1
README.md CHANGED
@@ -4,6 +4,7 @@ datasets:
  - common_voice
  metrics:
  - wer
+ - cer
  tags:
  - audio
  - automatic-speech-recognition
@@ -23,53 +24,68 @@ model-index:
  metrics:
  - name: Test WER
    type: wer
- value: 62.39
+ value: 41.60
+ - name: Test CER
+ type: cer
+ value: 8.23
  ---

  # Wav2Vec2-Large-XLSR-53-Finnish

- Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Finnish using the [Common Voice](https://huggingface.co/datasets/common_voice).
+ Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Finnish using the [Common Voice](https://huggingface.co/datasets/common_voice) and [CSS10](https://github.com/Kyubyong/css10).
  When using this model, make sure that your speech input is sampled at 16kHz.

+ The script used for training can be found here: https://github.com/jonatasgrosman/wav2vec2-sprint
+
  ## Usage

  The model can be used directly (without a language model) as follows:

  ```python
  import torch
- import torchaudio
+ import librosa
  from datasets import load_dataset
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

  LANG_ID = "fi"
  MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-finnish"
+ SAMPLES = 5

- test_dataset = load_dataset("common_voice", LANG_ID, split="test[:2%]")
+ test_dataset = load_dataset("common_voice", LANG_ID, split=f"test[:{SAMPLES}]")

  processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)

- resampler = torchaudio.transforms.Resample(48_000, 16_000)
-
  # Preprocessing the datasets.
  # We need to read the audio files as arrays
  def speech_file_to_array_fn(batch):
-     speech_array, sampling_rate = torchaudio.load(batch["path"])
-     batch["speech"] = resampler(speech_array).squeeze().numpy()
-     return batch
+     speech_array, sampling_rate = librosa.load(batch["path"], sr=16_000)
+     batch["speech"] = speech_array
+     batch["sentence"] = batch["sentence"].upper()
+     return batch

  test_dataset = test_dataset.map(speech_file_to_array_fn)
- inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)
+ inputs = processor(test_dataset["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

  with torch.no_grad():
      logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits

  predicted_ids = torch.argmax(logits, dim=-1)
+ predicted_sentences = processor.batch_decode(predicted_ids)

- print("Prediction:", processor.batch_decode(predicted_ids))
- print("Reference:", test_dataset["sentence"][:2])
+ for i, predicted_sentence in enumerate(predicted_sentences):
+     print("-" * 100)
+     print("Reference:", test_dataset[i]["sentence"])
+     print("Prediction:", predicted_sentence)
  ```

+ | Reference | Prediction |
+ | ------------- | ------------- |
+ | MYSTEERIMIES OLI OPPINUT MORAALINSA TARUISTA, ELOKUVISTA JA PELEISTÄ. | MYSTEERIMIES OLI OPPINUT MORALINSA TARUISTA ELOKUVISTA JA PELEISTÄ |
+ | ÄÄNESTIN MIETINNÖN PUOLESTA! | ÄÄNESTIN MIETINNÖN PUOLESTA |
+ | VAIN TUNTIA AIKAISEMMIN OLIMME MIEHENI KANSSA TUNTENEET SUURINTA ILOA. | PAIN TUNTIA AIKAISEMMIN OLIN MIEHENI KANSSA TUNTENEET SUURINTA ILAA |
+ | ENSIMMÄISELLE MIEHELLE SAI KOLME LASTA. | ENSIMMÄISELLE MIEHELLE SAI KOLME LASTA |
+ | ÄÄNESTIN MIETINNÖN PUOLESTA, SILLÄ POHJIMMILTAAN SIINÄ VASTUSTETAAN TÄTÄ SUUNTAUSTA. | ÄÄNESTIN MIETINNÖN PUOLESTA SILLÄ POHJIMMILTAAN SIINÄ VASTOTTETAAN TÄTÄ SUUNTAUSTA |

  ## Evaluation

@@ -77,45 +93,38 @@ The model can be evaluated as follows on the Finnish test data of Common Voice.

  ```python
  import torch
- import torchaudio
+ import re
+ import librosa
+ import warnings
  from datasets import load_dataset, load_metric
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
- import re
- import homoglyphs as hg

  LANG_ID = "fi"
  MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-finnish"
  DEVICE = "cuda"

- CHARS_TO_IGNORE = [",", "?", ".", "!", "-", ";", ":", '""', "%", "'", '"', "�", "·", "", "¿", "¡", "~", "՞", "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》"]
- CURRENCY_SYMBOLS = ["$", "£", "", "¥", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "฿", "¢"]
+ CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ":", '""', "%", '"', "�", "ʿ", "·", "", "~", "՞",
+                    "؟", "،", "", "", "«", "»", "", "", "", "", "", "", "", "", "", "(", ")", "[", "]",
+                    "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。"]

  test_dataset = load_dataset("common_voice", LANG_ID, split="test")
- wer = load_metric("wer")

- unk_regex = None
- if LANG_ID in hg.Languages.get_all():
-     # creating regex to match language specific non valid characters
-     alphabet = list(hg.Languages.get_alphabet([LANG_ID]))
-     valid_chars = alphabet + CURRENCY_SYMBOLS
-     unk_regex = "[^"+re.escape("".join(valid_chars))+"\\s\\d]"
+ wer = load_metric("wer.py") # https://github.com/jonatasgrosman/wav2vec2-sprint/blob/main/wer.py
+ cer = load_metric("cer.py") # https://github.com/jonatasgrosman/wav2vec2-sprint/blob/main/cer.py

- chars_to_ignore_regex = f'[{re.escape("".join(CHARS_TO_IGNORE))}]'
+ chars_to_ignore_regex = f"[{re.escape(''.join(CHARS_TO_IGNORE))}]"

  processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
  model.to(DEVICE)

- resampler = torchaudio.transforms.Resample(48_000, 16_000)
-
  # Preprocessing the datasets.
  # We need to read the audio files as arrays
  def speech_file_to_array_fn(batch):
-     batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower()
-     if unk_regex is not None:
-         batch["sentence"] = re.sub(unk_regex, "[UNK]", batch["sentence"])
-     speech_array, sampling_rate = torchaudio.load(batch["path"])
-     batch["speech"] = resampler(speech_array).squeeze().numpy()
+     with warnings.catch_warnings():
+         warnings.simplefilter("ignore")
+         speech_array, sampling_rate = librosa.load(batch["path"], sr=16_000)
+     batch["speech"] = speech_array
+     batch["sentence"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).upper()
      return batch

  test_dataset = test_dataset.map(speech_file_to_array_fn)
@@ -123,18 +132,22 @@ test_dataset = test_dataset.map(speech_file_to_array_fn)
  # Preprocessing the datasets.
  # We need to read the audio files as arrays
  def evaluate(batch):
      inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

      with torch.no_grad():
          logits = model(inputs.input_values.to(DEVICE), attention_mask=inputs.attention_mask.to(DEVICE)).logits

      pred_ids = torch.argmax(logits, dim=-1)
      batch["pred_strings"] = processor.batch_decode(pred_ids)
      return batch

  result = test_dataset.map(evaluate, batched=True, batch_size=8)

- print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
+ print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"], chunk_size=1000)))
+ print("CER: {:2f}".format(100 * cer.compute(predictions=result["pred_strings"], references=result["sentence"], chunk_size=1000)))
  ```

- **Test Result**: 62.39%
+ **Test Result**:
+
+ - WER: 41.60%
+ - CER: 8.23%
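The reported WER/CER above come from the card's custom `wer.py`/`cer.py` metric scripts. For a rough sanity check outside those scripts, the third-party `jiwer` package computes the same error rates; the sketch below is illustrative only (not part of this commit), assumes `jiwer` is installed, and reuses the `result` dataset produced by the evaluation script.

```python
# Sanity-check WER/CER with jiwer instead of the custom metric scripts.
# Assumes `result` was produced by the evaluation script above:
#   result["pred_strings"] = model predictions, result["sentence"] = references.
import jiwer

references = result["sentence"]
predictions = result["pred_strings"]

print("WER: {:.2f}%".format(100 * jiwer.wer(references, predictions)))
print("CER: {:.2f}%".format(100 * jiwer.cer(references, predictions)))
```

Small differences versus the card's numbers are possible, since the custom metrics chunk the corpus (`chunk_size=1000`) while `jiwer` scores it in one pass.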
config.json CHANGED
@@ -1,11 +1,11 @@
  {
- "_name_or_path": "../models-fi/wav2vec2-large-xlsr-fi-sweep/checkpoint-500",
- "activation_dropout": 0.0,
+ "_name_or_path": "facebook/wav2vec2-large-xlsr-53",
+ "activation_dropout": 0.05,
  "apply_spec_augment": true,
  "architectures": [
    "Wav2Vec2ForCTC"
  ],
- "attention_dropout": 0.2,
+ "attention_dropout": 0.1,
  "bos_token_id": 1,
  "conv_bias": true,
  "conv_dim": [
@@ -42,16 +42,16 @@
  "feat_extract_activation": "gelu",
  "feat_extract_dropout": 0.0,
  "feat_extract_norm": "layer",
- "feat_proj_dropout": 0.2,
+ "feat_proj_dropout": 0.05,
  "final_dropout": 0.0,
  "gradient_checkpointing": true,
  "hidden_act": "gelu",
- "hidden_dropout": 0.0,
+ "hidden_dropout": 0.05,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-05,
- "layerdrop": 0.0,
+ "layerdrop": 0.05,
  "mask_channel_length": 10,
  "mask_channel_min_space": 1,
  "mask_channel_other": 0.0,
@@ -62,7 +62,7 @@
  "mask_time_length": 10,
  "mask_time_min_space": 1,
  "mask_time_other": 0.0,
- "mask_time_prob": 0.2,
+ "mask_time_prob": 0.05,
  "mask_time_selection": "static",
  "model_type": "wav2vec2",
  "num_attention_heads": 16,
@@ -70,7 +70,7 @@
  "num_conv_pos_embeddings": 128,
  "num_feat_extract_layers": 7,
  "num_hidden_layers": 24,
- "pad_token_id": 29,
+ "pad_token_id": 0,
  "transformers_version": "4.5.0.dev0",
- "vocab_size": 30
+ "vocab_size": 34
  }
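The config changes above (lower dropout and masking probabilities, `pad_token_id` 0, `vocab_size` 34) can be read back from the published checkpoint. A minimal sketch, assuming only `transformers` is installed; the inline values are the ones introduced by this commit and may differ if the repository is updated later:

```python
from transformers import Wav2Vec2Config

# Load the config shipped with this model and print the fields changed in this commit.
config = Wav2Vec2Config.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-finnish")

print(config.attention_dropout)  # 0.1 in this commit
print(config.mask_time_prob)     # 0.05 in this commit
print(config.pad_token_id)       # 0 in this commit
print(config.vocab_size)         # 34 in this commit
```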
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:9067ba6beee4c6ad2692651ff18d9e95c522143d8fa38e5f155af833118e156d
- size 1262056855
+ oid sha256:b3293144121790976a21ddd565d25aef7024c94309d9638e12c4e77106eb5ac2
+ size 1262073239
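The weights file is stored as a Git LFS pointer, so the `oid sha256` above is the checksum of the actual binary. A hedged sketch for checking a downloaded copy against it, assuming the `huggingface_hub` package (not part of this commit); the digest only matches while this commit's weights are the latest ones on the Hub:

```python
import hashlib
from huggingface_hub import hf_hub_download

# Download (or reuse from cache) the weights and hash them.
path = hf_hub_download("jonatasgrosman/wav2vec2-large-xlsr-53-finnish", "pytorch_model.bin")

sha256 = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha256.update(chunk)

# Compare against the LFS pointer's oid sha256 shown above.
print(sha256.hexdigest())
```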
special_tokens_map.json CHANGED
@@ -1 +1 @@
- {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "pad_token": "[PAD]"}
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>"}
vocab.json CHANGED
@@ -1 +1 @@
- {"r": 0, "h": 1, "k": 2, "g": 3, "u": 4, "m": 5, "t": 6, "z": 7, "s": 8, "i": 9, "ö": 10, "v": 11, "l": 12, "q": 13, "b": 14, "e": 15, "p": 16, "y": 17, "f": 18, "d": 19, "ä": 21, "j": 22, "x": 23, "a": 24, "c": 25, "n": 26, "o": 27, "|": 20, "[UNK]": 28, "[PAD]": 29}
+ {"<pad>": 0, "<s>": 1, "</s>": 2, "<unk>": 3, "|": 4, "J": 5, "Q": 6, "B": 7, "X": 8, "I": 9, "D": 10, "R": 11, "U": 12, "-": 13, "K": 14, "T": 15, "L": 17, "V": 18, "Ä": 19, "A": 20, "F": 21, "S": 22, "'": 23, "G": 24, "N": 25, "Y": 26, "M": 27, "C": 28, "E": 29, "Ö": 30, "O": 31, "H": 32, "P": 33, "Z": 34}