import sys

sys.path.append("..")  # make the `finetuning` package (one directory up) importable

import gradio
import torch
import torchaudio
from transformers import (
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForPreTraining,
)

from finetuning.wav2vec2 import SpeechRecognizer


def load_model(ckpt_path: str):
    model_name = "nguyenvulebinh/wav2vec2-base-vietnamese-250h"

    # Pretrained Vietnamese wav2vec 2.0 backbone plus the matching CTC
    # tokenizer and feature extractor from the Hugging Face Hub.
    wav2vec2 = Wav2Vec2ForPreTraining.from_pretrained(model_name)
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_name)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)

    # Restore the fine-tuned weights on CPU; the components above are passed
    # through to the module's constructor.
    model = SpeechRecognizer.load_from_checkpoint(
        ckpt_path,
        wav2vec2=wav2vec2,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        map_location="cpu",
    )
    return model


model = load_model("checkpoints/last.ckpt")
model.eval()  # inference mode: disables dropout and other training-only behaviour
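
# Optional: run inference on a GPU when one is available. This sketch assumes
# SpeechRecognizer is an ordinary torch.nn.Module/LightningModule, so .to()
# works; transcribe() would then also need to move its input to the same device.
# if torch.cuda.is_available():
#     model = model.to("cuda")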


def transcribe(audio):
    # With type="numpy", Gradio delivers a (sample_rate, waveform) pair where
    # the waveform is int16 PCM of shape (samples,) or (samples, channels).
    sample_rate, waveform = audio
    if waveform.ndim == 2:
        waveform = waveform[:, 0]  # keep only the first channel

    # Build a float batch of shape (1, samples) and resample to the 16 kHz rate
    # wav2vec 2.0 expects. Values keep the int16 scale here; if SpeechRecognizer
    # expects audio in [-1, 1], divide by 32768 first.
    waveform = torch.from_numpy(waveform).float().unsqueeze(0)
    waveform = torchaudio.functional.resample(waveform, sample_rate, 16_000)

    with torch.no_grad():  # in case predict() does not disable autograd itself
        transcript = model.predict(waveform)[0]

    return transcript
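
# Quick sanity check without the UI ("samples/example.wav" is a hypothetical
# path; assumes a file torchaudio can decode):
# waveform, sr = torchaudio.load("samples/example.wav")  # (channels, samples)
# waveform = torchaudio.functional.resample(waveform[:1, :], sr, 16_000)
# print(model.predict(waveform)[0])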


demo = gradio.Interface(
    fn=transcribe,
    # Gradio 3.x spelling; Gradio 4.x renamed this to sources=["microphone"].
    inputs=gradio.Audio(source="microphone", type="numpy"),
    outputs="textbox",
)
demo.launch()
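
# launch() also accepts server_name="0.0.0.0" to serve beyond localhost (e.g.
# from inside a container) and share=True for a temporary public Gradio link.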