patrickvonplaten committed on
Commit
f84b430
1 Parent(s): db24822

Create README.md

  1. README.md +54 -0
README.md ADDED
---
language: en
datasets:
- librispeech_asr
tags:
- audio
- automatic-speech-recognition
license: apache-2.0
---

TODO: [To be filled]

## Evaluation on LibriSpeech Test

The following script shows how to evaluate this model on the [LibriSpeech](https://huggingface.co/datasets/librispeech_asr) *"clean"* and *"other"* test sets.

```python
from datasets import load_dataset
from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
import soundfile as sf
from jiwer import wer

librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")  # change to "other" for the other test set

model = Speech2TextForConditionalGeneration.from_pretrained("valhalla/s2t_librispeech_medium").to("cuda")
processor = Speech2TextProcessor.from_pretrained("valhalla/s2t_librispeech_medium", do_upper_case=True)

def map_to_array(batch):
    # Read the raw waveform from the audio file on disk
    speech, _ = sf.read(batch["file"])
    batch["speech"] = speech
    return batch

librispeech_eval = librispeech_eval.map(map_to_array)

def map_to_pred(batch):
    # Extract padded input features and the matching attention mask for the batch
    features = processor(batch["speech"], sampling_rate=16000, padding=True, return_tensors="pt")
    input_features = features.input_features.to("cuda")
    attention_mask = features.attention_mask.to("cuda")

    # Generate token ids and decode them into transcriptions
    gen_tokens = model.generate(input_features, attention_mask=attention_mask)
    batch["transcription"] = processor.batch_decode(gen_tokens, skip_special_tokens=True)
    return batch

result = librispeech_eval.map(map_to_pred, batched=True, batch_size=8, remove_columns=["speech"])

print("WER:", wer(result["text"], result["transcription"]))
```
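
The script batches eight utterances at a time with `padding=True` and passes an `attention_mask` to `generate`. As a rough illustration of why both are needed, the sketch below pads variable-length sequences to a common length and builds a mask that marks real positions with 1 and padding with 0. `pad_batch` is a hypothetical helper written for this example, not part of `transformers`, and it works on plain lists, whereas the actual feature extraction pads frames of audio features.

```python
def pad_batch(sequences, pad_value=0.0):
    """Right-pad sequences to the longest length; return (padded, mask)."""
    max_len = max(len(seq) for seq in sequences)
    padded, mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        # Real values first, then padding up to the common length
        padded.append(list(seq) + [pad_value] * n_pad)
        # 1 marks a real position, 0 marks padding the model should ignore
        mask.append([1] * len(seq) + [0] * n_pad)
    return padded, mask


padded, mask = pad_batch([[0.1, 0.2, 0.3], [0.5]])
print(padded)  # [[0.1, 0.2, 0.3], [0.5, 0.0, 0.0]]
print(mask)    # [[1, 1, 1], [1, 0, 0]]
```

Without the mask, the model would treat the padded positions of the shorter utterances as real input.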

*Result (WER, in %)*:

| "clean" | "other" |
|---|---|
| 3.5 | 7.8 |
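
For readers unfamiliar with the metric, word error rate is the number of word-level substitutions, deletions, and insertions needed to turn the hypothesis into the reference, divided by the number of reference words. The following is a minimal pure-Python sketch of that definition. `word_error_rate` is a hypothetical helper for illustration and is not how `jiwer` is implemented internally. It returns a fraction, so it is multiplied by 100 here to match the scale the table values appear to use.

```python
def word_error_rate(references, hypotheses):
    """Total word-level edit distance divided by total reference words."""
    total_edits = 0
    total_words = 0
    for ref, hyp in zip(references, hypotheses):
        ref_words, hyp_words = ref.split(), hyp.split()
        # Levenshtein distance over words, keeping only the previous row
        prev = list(range(len(hyp_words) + 1))
        for i, r in enumerate(ref_words, start=1):
            curr = [i]
            for j, h in enumerate(hyp_words, start=1):
                cost = 0 if r == h else 1
                curr.append(min(prev[j] + 1,          # deletion
                                curr[j - 1] + 1,      # insertion
                                prev[j - 1] + cost))  # match or substitution
            prev = curr
        total_edits += prev[-1]
        total_words += len(ref_words)
    return total_edits / total_words


# Made-up sentences: one deleted word and one substituted word over 8 reference words
refs = ["THE CAT SAT ON THE MAT", "HELLO WORLD"]
hyps = ["THE CAT SAT ON MAT", "HELLO WORD"]
print(f"WER: {100 * word_error_rate(refs, hyps):.1f}%")  # WER: 25.0%
```

The references are upper-case because LibriSpeech transcripts are upper-case, which is presumably why the evaluation script sets `do_upper_case=True`.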