Merge branch 'main' of hf.co:alvanlii/wav2vec2-BERT-cantonese
Browse files
README.md
CHANGED
@@ -41,7 +41,28 @@ bert_asr = pipeline(
)
text = pipe(file)["text"]
```

## Training Hyperparameters
- learning_rate: 1e-4
)
text = pipe(file)["text"]
```

or

```
import torch
import soundfile as sf
from transformers import AutoModelForCTC, Wav2Vec2BertProcessor

model_name = "alvanlii/wav2vec2-BERT-cantonese"

asr_model = AutoModelForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2BertProcessor.from_pretrained(model_name)

audio_input, _ = sf.read(file)

inputs = processor([audio_input], sampling_rate=16_000).input_features
features = torch.tensor(inputs)

with torch.no_grad():
    logits = asr_model(features).logits

predicted_ids = torch.argmax(logits, dim=-1)
predictions = processor.batch_decode(predicted_ids)
```

## Training Hyperparameters
- learning_rate: 1e-4