patrickvonplaten commited on
Commit
75ecb48
1 Parent(s): 71ed25f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +7 -4
README.md CHANGED
@@ -76,7 +76,7 @@ from jiwer import wer
76
  librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
77
 
78
  model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to("cuda")
79
- tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h-lv60-self")
80
 
81
  def map_to_array(batch):
82
  speech, _ = sf.read(batch["file"])
@@ -86,9 +86,12 @@ def map_to_array(batch):
86
  librispeech_eval = librispeech_eval.map(map_to_array)
87
 
88
  def map_to_pred(batch):
89
- input_values = tokenizer(batch["speech"], return_tensors="pt", padding="longest").input_values
 
 
 
90
  with torch.no_grad():
91
- logits = model(input_values.to("cuda")).logits
92
 
93
  predicted_ids = torch.argmax(logits, dim=-1)
94
  transcription = tokenizer.batch_decode(predicted_ids)
@@ -104,4 +107,4 @@ print("WER:", wer(result["text"], result["transcription"]))
104
 
105
  | "clean" | "other" |
106
  |---|---|
107
- | 2.2 | 5.2 |
 
76
  librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
77
 
78
  model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to("cuda")
79
+ tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
80
 
81
  def map_to_array(batch):
82
  speech, _ = sf.read(batch["file"])
 
86
  librispeech_eval = librispeech_eval.map(map_to_array)
87
 
88
  def map_to_pred(batch):
89
+ inputs = tokenizer(batch["speech"], return_tensors="pt", padding="longest")
90
+ input_values = inputs.input_values.to("cuda")
91
+ attention_mask = inputs.attention_mask.to("cuda")
92
+
93
  with torch.no_grad():
94
+ logits = model(input_values, attention_mask=attention_mask).logits
95
 
96
  predicted_ids = torch.argmax(logits, dim=-1)
97
  transcription = tokenizer.batch_decode(predicted_ids)
 
107
 
108
  | "clean" | "other" |
109
  |---|---|
110
+ | 1.9 | 3.9 |