patrickvonplaten
committed on
Commit
•
2502577
1
Parent(s):
43b9e58
Update README.md
Browse files
README.md
CHANGED
@@ -27,7 +27,8 @@ In a nutshell: This PR adds a new Wav2Vec2ProcessorWithLM class as drop-in replacement
|
|
27 |
The only change from the existing ASR pipeline will be:
|
28 |
|
29 |
```diff
|
30 |
-
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
|
|
31 |
from datasets import load_dataset
|
32 |
|
33 |
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
@@ -35,14 +36,16 @@ ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
|
35 |
sample = next(iter(ds))
|
36 |
|
37 |
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
38 |
-
processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
|
|
39 |
|
40 |
input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
|
41 |
|
42 |
logits = model(input_values).logits
|
43 |
-
prediction_ids = torch.argmax(logits, dim=-1)
|
44 |
|
45 |
-
|
|
|
|
|
46 |
|
47 |
print(transcription)
|
48 |
```
|
|
|
27 |
The only change from the existing ASR pipeline will be:
|
28 |
|
29 |
```diff
|
30 |
+
-from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
31 |
+
+from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM
|
32 |
from datasets import load_dataset
|
33 |
|
34 |
ds = load_dataset("common_voice", "es", split="test", streaming=True)
|
|
|
36 |
sample = next(iter(ds))
|
37 |
|
38 |
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
39 |
+
-processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
40 |
+
+processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm")
|
41 |
|
42 |
input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
|
43 |
|
44 |
logits = model(input_values).logits
|
|
|
45 |
|
46 |
+
-prediction_ids = torch.argmax(logits, dim=-1)
|
47 |
+
-transcription = processor.batch_decode(prediction_ids)
|
48 |
+
+transcription = processor.batch_decode(logits)
|
49 |
|
50 |
print(transcription)
|
51 |
```
|