patrickvonplaten commited on
Commit
8558f54
1 Parent(s): 7ef592a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -2
README.md CHANGED
@@ -27,6 +27,7 @@ In a nutshell: This PR adds a new Wav2Vec2WithLMProcessor class as drop-in repla
27
  The only change from the existing ASR pipeline will be:
28
 
29
  ```diff
 
30
  -from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
31
  +from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM
32
  from datasets import load_dataset
@@ -41,11 +42,12 @@ model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-
41
 
42
  input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
43
 
44
- logits = model(input_values).logits
 
45
 
46
  -prediction_ids = torch.argmax(logits, dim=-1)
47
  -transcription = processor.batch_decode(prediction_ids)
48
- +transcription = processor.batch_decode(logits).text
49
 
50
  print(transcription)
51
  ```
 
27
  The only change from the existing ASR pipeline will be:
28
 
29
  ```diff
30
+ import torch
31
  -from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
32
  +from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM
33
  from datasets import load_dataset
 
42
 
43
  input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
44
 
45
+ with torch.no_grad():
46
+ logits = model(input_values).logits
47
 
48
  -prediction_ids = torch.argmax(logits, dim=-1)
49
  -transcription = processor.batch_decode(prediction_ids)
50
+ +transcription = processor.batch_decode(logits.cpu().numpy()).text
51
 
52
  print(transcription)
53
  ```