Update README.md
Browse files
README.md
CHANGED
@@ -27,10 +27,10 @@ model-index:
|
|
27 |
metrics:
|
28 |
- name: Test WER
|
29 |
type: wer
|
30 |
-
value:
|
31 |
- name: Test CER
|
32 |
type: cer
|
33 |
-
value: 1.
|
34 |
- task:
|
35 |
name: Speech Recognition
|
36 |
type: automatic-speech-recognition
|
@@ -41,10 +41,10 @@ model-index:
|
|
41 |
metrics:
|
42 |
- name: Test WER
|
43 |
type: wer
|
44 |
-
value: 15.
|
45 |
- name: Test CER
|
46 |
type: cer
|
47 |
-
value: 4.
|
48 |
---
|
49 |
|
50 |
# Wav2Vec2-Large-Ru-Golos
|
@@ -66,7 +66,7 @@ import torch
|
|
66 |
processor = Wav2Vec2Processor.from_pretrained("bond005/wav2vec2-large-ru-golos")
|
67 |
model = Wav2Vec2ForCTC.from_pretrained("bond005/wav2vec2-large-ru-golos")
|
68 |
|
69 |
-
# load test part of Golos dataset and read first soundfile
|
70 |
ds = load_dataset("bond005/sberdevices_golos_10h_crowd", split="test")
|
71 |
|
72 |
# tokenize
|
@@ -81,6 +81,79 @@ transcription = processor.batch_decode(predicted_ids)[0]
|
|
81 |
print(transcription)
|
82 |
```
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
## Citation
|
85 |
If you want to cite this model you can use this:
|
86 |
|
|
|
27 |
metrics:
|
28 |
- name: Test WER
|
29 |
type: wer
|
30 |
+
value: 5.860
|
31 |
- name: Test CER
|
32 |
type: cer
|
33 |
+
value: 1.228
|
34 |
- task:
|
35 |
name: Speech Recognition
|
36 |
type: automatic-speech-recognition
|
|
|
41 |
metrics:
|
42 |
- name: Test WER
|
43 |
type: wer
|
44 |
+
value: 15.330
|
45 |
- name: Test CER
|
46 |
type: cer
|
47 |
+
value: 4.299
|
48 |
---
|
49 |
|
50 |
# Wav2Vec2-Large-Ru-Golos
|
|
|
66 |
processor = Wav2Vec2Processor.from_pretrained("bond005/wav2vec2-large-ru-golos")
|
67 |
model = Wav2Vec2ForCTC.from_pretrained("bond005/wav2vec2-large-ru-golos")
|
68 |
|
69 |
+
# load the test part of Golos dataset and read first soundfile
|
70 |
ds = load_dataset("bond005/sberdevices_golos_10h_crowd", split="test")
|
71 |
|
72 |
# tokenize
|
|
|
81 |
print(transcription)
|
82 |
```
|
83 |
|
84 |
+
## Evaluation
|
85 |
+
|
86 |
+
This code snippet shows how to evaluate **bond005/wav2vec2-large-ru-golos** on Golos dataset's "crowd" and "farfield" test data.
|
87 |
+
|
88 |
+
```python
|
89 |
+
from datasets import load_dataset
|
90 |
+
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
91 |
+
import torch
|
92 |
+
from jiwer import wer, cer # we need word error rate (WER) and character error rate (CER)
|
93 |
+
|
94 |
+
# load the test part of Golos Crowd and remove samples with empty "true" transcriptions
|
95 |
+
golos_crowd_test = load_dataset("bond005/sberdevices_golos_10h_crowd", split="test")
|
96 |
+
golos_crowd_test = golos_crowd_test.filter(
|
97 |
+
lambda it1: (it1["transcription"] is not None) and (len(it1["transcription"].strip()) > 0)
|
98 |
+
)
|
99 |
+
|
100 |
+
# load the test part of Golos Farfield and remove samples with empty "true" transcriptions
|
101 |
+
golos_farfield_test = load_dataset("bond005/sberdevices_golos_100h_farfield", split="test")
|
102 |
+
golos_farfield_test = golos_farfield_test.filter(
|
103 |
+
lambda it2: (it2["transcription"] is not None) and (len(it2["transcription"].strip()) > 0)
|
104 |
+
)
|
105 |
+
|
106 |
+
# load model and processor
|
107 |
+
model = Wav2Vec2ForCTC.from_pretrained("bond005/wav2vec2-large-ru-golos").to("cuda")
|
108 |
+
processor = Wav2Vec2Processor.from_pretrained("bond005/wav2vec2-large-ru-golos")
|
109 |
+
|
110 |
+
# recognize one sound
|
111 |
+
def map_to_pred(batch):
|
112 |
+
# tokenize and vectorize
|
113 |
+
processed = processor(
|
114 |
+
batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"],
|
115 |
+
return_tensors="pt", padding="longest"
|
116 |
+
)
|
117 |
+
input_values = processed.input_values.to("cuda")
|
118 |
+
attention_mask = processed.attention_mask.to("cuda")
|
119 |
+
|
120 |
+
# recognize
|
121 |
+
with torch.no_grad():
|
122 |
+
logits = model(input_values, attention_mask=attention_mask).logits
|
123 |
+
predicted_ids = torch.argmax(logits, dim=-1)
|
124 |
+
|
125 |
+
# decode
|
126 |
+
transcription = processor.batch_decode(predicted_ids)
|
127 |
+
batch["text"] = transcription[0]
|
128 |
+
return batch
|
129 |
+
|
130 |
+
# calculate WER and CER on the crowd domain
|
131 |
+
crowd_result = golos_crowd_test.map(map_to_pred, remove_columns=["audio"])
|
132 |
+
crowd_wer = wer(crowd_result["transcription"], crowd_result["text"])
|
133 |
+
crowd_cer = cer(crowd_result["transcription"], crowd_result["text"])
|
134 |
+
print("Word error rate on the Crowd domain:", crowd_wer)
|
135 |
+
print("Character error rate on the Crowd domain:", crowd_cer)
|
136 |
+
|
137 |
+
# calculate WER and CER on the farfield domain
|
138 |
+
farfield_result = golos_farfield_test.map(map_to_pred, remove_columns=["audio"])
|
139 |
+
farfield_wer = wer(farfield_result["transcription"], farfield_result["text"])
|
140 |
+
farfield_cer = cer(farfield_result["transcription"], farfield_result["text"])
|
141 |
+
print("Word error rate on the Farfield domain:", farfield_wer)
|
142 |
+
print("Character error rate on the Farfield domain:", farfield_cer)
|
143 |
+
```
|
144 |
+
|
145 |
+
*Result (WER, %)*:
|
146 |
+
|
147 |
+
| "crowd" | "farfield" |
|
148 |
+
|---------|------------|
|
149 |
+
| 5.860 | 15.330 |
|
150 |
+
|
151 |
+
*Result (CER, %)*:
|
152 |
+
|
153 |
+
| "crowd" | "farfield" |
|
154 |
+
|---------|------------|
|
155 |
+
| 1.228 | 4.299 |
|
156 |
+
|
157 |
## Citation
|
158 |
If you want to cite this model you can use this:
|
159 |
|