Commit
•
bb4003f
1
Parent(s):
89d22e9
Fix examples: input_ids -> input_features
Browse filesModel expects args of `input_features` not `input_ids`: https://github.com/huggingface/transformers/blob/fc95386ea12fc11942cc7f2a4f99ef9602d774ef/src/transformers/models/speech_to_text/modeling_speech_to_text.py#L1298
README.md
CHANGED
@@ -101,7 +101,7 @@ input_features = processor(
|
|
101 |
sampling_rate=16_000,
|
102 |
return_tensors="pt"
|
103 |
).input_features # Batch size 1
|
104 |
-
generated_ids = model.generate(
|
105 |
|
106 |
transcription = processor.batch_decode(generated_ids)
|
107 |
```
|
@@ -112,27 +112,26 @@ The following script shows how to evaluate this model on the [LibriSpeech](https
|
|
112 |
*"clean"* and *"other"* test dataset.
|
113 |
|
114 |
```python
|
115 |
-
from datasets import load_dataset
|
|
|
116 |
from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
|
117 |
|
118 |
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test") # change to "other" for other test dataset
|
119 |
-
wer =
|
120 |
|
121 |
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr").to("cuda")
|
122 |
processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr", do_upper_case=True)
|
123 |
|
124 |
-
librispeech_eval = librispeech_eval.map(map_to_array)
|
125 |
-
|
126 |
def map_to_pred(batch):
|
127 |
features = processor(batch["audio"]["array"], sampling_rate=16000, padding=True, return_tensors="pt")
|
128 |
input_features = features.input_features.to("cuda")
|
129 |
attention_mask = features.attention_mask.to("cuda")
|
130 |
|
131 |
-
gen_tokens = model.generate(
|
132 |
batch["transcription"] = processor.batch_decode(gen_tokens, skip_special_tokens=True)
|
133 |
return batch
|
134 |
|
135 |
-
result = librispeech_eval.map(map_to_pred,
|
136 |
|
137 |
print("WER:", wer(predictions=result["transcription"], references=result["text"]))
|
138 |
```
|
|
|
101 |
sampling_rate=16_000,
|
102 |
return_tensors="pt"
|
103 |
).input_features # Batch size 1
|
104 |
+
generated_ids = model.generate(input_features=input_features)
|
105 |
|
106 |
transcription = processor.batch_decode(generated_ids)
|
107 |
```
|
|
|
112 |
*"clean"* and *"other"* test dataset.
|
113 |
|
114 |
```python
|
115 |
+
from datasets import load_dataset
|
116 |
+
from evaluate import load
|
117 |
from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
|
118 |
|
119 |
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test") # change to "other" for other test dataset
|
120 |
+
wer = load("wer")
|
121 |
|
122 |
model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr").to("cuda")
|
123 |
processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr", do_upper_case=True)
|
124 |
|
|
|
|
|
125 |
def map_to_pred(batch):
|
126 |
features = processor(batch["audio"]["array"], sampling_rate=16000, padding=True, return_tensors="pt")
|
127 |
input_features = features.input_features.to("cuda")
|
128 |
attention_mask = features.attention_mask.to("cuda")
|
129 |
|
130 |
+
gen_tokens = model.generate(input_features=input_features, attention_mask=attention_mask)
|
131 |
batch["transcription"] = processor.batch_decode(gen_tokens, skip_special_tokens=True)
|
132 |
return batch
|
133 |
|
134 |
+
result = librispeech_eval.map(map_to_pred, remove_columns=["audio"])
|
135 |
|
136 |
print("WER:", wer(predictions=result["transcription"], references=result["text"]))
|
137 |
```
|