patrickvonplaten sanchit-gandhi HF staff committed on
Commit
8893a00
1 Parent(s): 782ffeb

Fix examples: input_ids -> input_features (#1)

Browse files

- Fix examples: input_ids -> input_features (ab7755c91c785dcd7ab71428a8f8a22c8a5cd3fa)


Co-authored-by: Sanchit Gandhi <sanchit-gandhi@users.noreply.huggingface.co>

Files changed (1) hide show
  1. README.md +9 -15
README.md CHANGED
@@ -72,7 +72,7 @@ input_features = processor(
72
  sampling_rate=16_000,
73
  return_tensors="pt"
74
  ).input_features # Batch size 1
75
- generated_ids = model.generate(input_ids=input_features)
76
 
77
  transcription = processor.batch_decode(generated_ids)
78
  ```
@@ -83,35 +83,29 @@ The following script shows how to evaluate this model on the [LibriSpeech](https
83
  *"clean"* and *"other"* test dataset.
84
 
85
  ```python
86
- from datasets import load_dataset, load_metric
 
87
  from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
88
- import soundfile as sf
89
 
90
  librispeech_eval = load_dataset("librispeech_asr", "clean", split="test") # change to "other" for other test dataset
91
- wer = load_metric("wer")
92
 
93
  model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-librispeech-asr").to("cuda")
94
  processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-librispeech-asr", do_upper_case=True)
95
 
96
- def map_to_array(batch):
97
- speech, _ = sf.read(batch["file"])
98
- batch["speech"] = speech
99
- return batch
100
-
101
- librispeech_eval = librispeech_eval.map(map_to_array)
102
 
103
  def map_to_pred(batch):
104
- features = processor(batch["speech"], sampling_rate=16000, padding=True, return_tensors="pt")
105
  input_features = features.input_features.to("cuda")
106
  attention_mask = features.attention_mask.to("cuda")
107
 
108
- gen_tokens = model.generate(input_ids=input_features, attention_mask=attention_mask)
109
- batch["transcription"] = processor.batch_decode(gen_tokens, skip_special_tokens=True)
110
  return batch
111
 
112
- result = librispeech_eval.map(map_to_pred, batched=True, batch_size=8, remove_columns=["speech"])
113
 
114
- print("WER:", wer(predictions=result["transcription"], references=result["text"]))
115
  ```
116
 
117
  *Result (WER)*:
 
72
  sampling_rate=16_000,
73
  return_tensors="pt"
74
  ).input_features # Batch size 1
75
+ generated_ids = model.generate(input_features=input_features)
76
 
77
  transcription = processor.batch_decode(generated_ids)
78
  ```
 
83
  *"clean"* and *"other"* test dataset.
84
 
85
  ```python
86
+ from datasets import load_dataset
87
+ from evaluate import load
88
  from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
 
89
 
90
  librispeech_eval = load_dataset("librispeech_asr", "clean", split="test") # change to "other" for other test dataset
91
+ wer = load("wer")
92
 
93
  model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-librispeech-asr").to("cuda")
94
  processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-librispeech-asr", do_upper_case=True)
95
 
 
 
 
 
 
 
96
 
97
  def map_to_pred(batch):
98
+ features = processor(batch["audio"]["array"], sampling_rate=16000, padding=True, return_tensors="pt")
99
  input_features = features.input_features.to("cuda")
100
  attention_mask = features.attention_mask.to("cuda")
101
 
102
+ gen_tokens = model.generate(input_features=input_features, attention_mask=attention_mask)
103
+ batch["transcription"] = processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]
104
  return batch
105
 
106
+ result = librispeech_eval.map(map_to_pred, remove_columns=["audio"])
107
 
108
+ print("WER:", wer.compute(predictions=result["transcription"], references=result["text"]))
109
  ```
110
 
111
  *Result (WER)*: