bond005 commited on
Commit
2eed5b6
1 Parent(s): c8ed624

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +78 -5
README.md CHANGED
@@ -27,10 +27,10 @@ model-index:
27
  metrics:
28
  - name: Test WER
29
  type: wer
30
- value: 6.358
31
  - name: Test CER
32
  type: cer
33
- value: 1.711
34
  - task:
35
  name: Speech Recognition
36
  type: automatic-speech-recognition
@@ -41,10 +41,10 @@ model-index:
41
  metrics:
42
  - name: Test WER
43
  type: wer
44
- value: 15.402
45
  - name: Test CER
46
  type: cer
47
- value: 4.315
48
  ---
49
 
50
  # Wav2Vec2-Large-Ru-Golos
@@ -66,7 +66,7 @@ import torch
66
  processor = Wav2Vec2Processor.from_pretrained("bond005/wav2vec2-large-ru-golos")
67
  model = Wav2Vec2ForCTC.from_pretrained("bond005/wav2vec2-large-ru-golos")
68
 
69
- # load test part of Golos dataset and read first soundfile
70
  ds = load_dataset("bond005/sberdevices_golos_10h_crowd", split="test")
71
 
72
  # tokenize
@@ -81,6 +81,79 @@ transcription = processor.batch_decode(predicted_ids)[0]
81
  print(transcription)
82
  ```
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  ## Citation
85
  If you want to cite this model you can use this:
86
 
 
27
  metrics:
28
  - name: Test WER
29
  type: wer
30
+ value: 5.860
31
  - name: Test CER
32
  type: cer
33
+ value: 1.228
34
  - task:
35
  name: Speech Recognition
36
  type: automatic-speech-recognition
 
41
  metrics:
42
  - name: Test WER
43
  type: wer
44
+ value: 15.330
45
  - name: Test CER
46
  type: cer
47
+ value: 4.299
48
  ---
49
 
50
  # Wav2Vec2-Large-Ru-Golos
 
66
  processor = Wav2Vec2Processor.from_pretrained("bond005/wav2vec2-large-ru-golos")
67
  model = Wav2Vec2ForCTC.from_pretrained("bond005/wav2vec2-large-ru-golos")
68
 
69
+ # load the test part of Golos dataset and read first soundfile
70
  ds = load_dataset("bond005/sberdevices_golos_10h_crowd", split="test")
71
 
72
  # tokenize
 
81
  print(transcription)
82
  ```
83
 
84
+ ## Evaluation
85
+
86
+ This code snippet shows how to evaluate **bond005/wav2vec2-large-ru-golos** on Golos dataset's "crowd" and "farfield" test data.
87
+
88
+ ```python
89
+ from datasets import load_dataset
90
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
91
+ import torch
92
+ from jiwer import wer, cer # we need word error rate (WER) and character error rate (CER)
93
+
94
+ # load the test part of Golos Crowd and remove samples with empty "true" transcriptions
95
+ golos_crowd_test = load_dataset("bond005/sberdevices_golos_10h_crowd", split="test")
96
+ golos_crowd_test = golos_crowd_test.filter(
97
+ lambda it1: (it1["transcription"] is not None) and (len(it1["transcription"].strip()) > 0)
98
+ )
99
+
100
+ # load the test part of Golos Farfield and remove sampels with empty "true" transcriptions
101
+ golos_farfield_test = load_dataset("bond005/sberdevices_golos_100h_farfield", split="test")
102
+ golos_farfield_test = golos_farfield_test.filter(
103
+ lambda it2: (it2["transcription"] is not None) and (len(it2["transcription"].strip()) > 0)
104
+ )
105
+
106
+ # load model and tokenizer
107
+ model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to("cuda")
108
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
109
+
110
+ # recognize one sound
111
+ def map_to_pred(batch):
112
+ # tokenize and vectorize
113
+ processed = processor(
114
+ batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"],
115
+ return_tensors="pt", padding="longest"
116
+ )
117
+ input_values = processed.input_values.to("cuda")
118
+ attention_mask = processed.attention_mask.to("cuda")
119
+
120
+ # recognize
121
+ with torch.no_grad():
122
+ logits = model(input_values, attention_mask=attention_mask).logits
123
+ predicted_ids = torch.argmax(logits, dim=-1)
124
+
125
+ # decode
126
+ transcription = processor.batch_decode(predicted_ids)
127
+ batch["text"] = transcription[0]
128
+ return batch
129
+
130
+ # calculate WER and CER on the crowd domain
131
+ crowd_result = golos_crowd_test.map(map_to_pred, remove_columns=["audio"])
132
+ crowd_wer = wer(crowd_result["transcription"], crowd_result["text"])
133
+ crowd_cer = cer(crowd_result["transcription"], crowd_result["text"])
134
+ print("Word error rate on the Crowd domain:", crowd_wer)
135
+ print("Character error rate on the Crowd domain:", crowd_cer)
136
+
137
+ # calculate WER and CER on the farfield domain
138
+ farfield_result = golos_farfield_test.map(map_to_pred, remove_columns=["audio"])
139
+ farfield_wer = wer(farfield_result["transcription"], farfield_result["text"])
140
+ farfield_cer = cer(farfield_result["transcription"], farfield_result["text"])
141
+ print("Word error rate on the Farfield domain:", farfield_wer)
142
+ print("Character error rate on the Farfield domain:", farfield_cer)
143
+ ```
144
+
145
+ *Result (WER, %)*:
146
+
147
+ | "crowd" | "farfield" |
148
+ |---------|------------|
149
+ | 5.860 | 15.330 |
150
+
151
+ *Result (CER, %)*:
152
+
153
+ | "crowd" | "farfield" |
154
+ |---------|------------|
155
+ | 1.228 | 4.299 |
156
+
157
  ## Citation
158
  If you want to cite this model you can use this:
159