patrickvonplaten commited on
Commit
2502577
1 Parent(s): 43b9e58

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +7 -4
README.md CHANGED
@@ -27,7 +27,8 @@ In a nutshell: This PR adds a new Wav2Vec2WithLMProcessor class as drop-in repla
27
  The only change from the existing ASR pipeline will be:
28
 
29
  ```diff
30
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 
31
  from datasets import load_dataset
32
 
33
  ds = load_dataset("common_voice", "es", split="test", streaming=True)
@@ -35,14 +36,16 @@ ds = load_dataset("common_voice", "es", split="test", streaming=True)
35
  sample = next(iter(ds))
36
 
37
  model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
38
- processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
 
39
 
40
  input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
41
 
42
  logits = model(input_values).logits
43
- prediction_ids = torch.argmax(logits, dim=-1)
44
 
45
- transcription = processor.batch_decode(prediction_ids)
 
 
46
 
47
  print(transcription)
48
  ```
 
27
  The only change from the existing ASR pipeline will be:
28
 
29
  ```diff
30
+ -from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
31
+ +from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM
32
  from datasets import load_dataset
33
 
34
  ds = load_dataset("common_voice", "es", split="test", streaming=True)
 
36
  sample = next(iter(ds))
37
 
38
  model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
39
+ -processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
40
+ +processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
41
 
42
  input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
43
 
44
  logits = model(input_values).logits
 
45
 
46
+ -prediction_ids = torch.argmax(logits, dim=-1)
47
+ -transcription = processor.batch_decode(prediction_ids)
48
+ +transcription = processor.batch_decode(logits)
49
 
50
  print(transcription)
51
  ```