# Provenance (scraped file-listing header, kept as a comment so the script parses):
# author: bel32123 — commit b615647 — "Update demo to use MultitaskASRModel"
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from speechbrain.pretrained import GraphemeToPhoneme
import datasets
import os
import torchaudio
from wav2vecasr.MispronounciationDetector import MispronounciationDetector
from wav2vecasr.PhonemeASRModel import Wav2Vec2PhonemeASRModel, Wav2Vec2OptimisedPhonemeASRModel, MultitaskPhonemeASRModel
import jiwer
import re
# Load sample data: a single Arctic utterance (wav + reference transcript).
audio_path = os.path.join(os.getcwd(), "data", "arctic_a0003.wav")
transcript_path = os.path.join(os.getcwd(), "data", "arctic_a0003.txt")
audio, org_sr = torchaudio.load(audio_path)
# The ASR model expects 16 kHz input; resample from the file's native rate.
audio = torchaudio.functional.resample(audio, orig_freq=org_sr, new_freq=16000)
# Drop the channel dimension: (1, num_samples) -> (num_samples,).
# NOTE(review): assumes the wav is mono — confirm for other inputs.
audio = audio.view(audio.shape[1])
audio = audio.to("cpu")
# The context manager closes the file; no explicit f.close() needed.
# Specify encoding so the read does not depend on the platform default.
with open(transcript_path, encoding="utf-8") as f:
    text = f.read()
print("Done loading sample data")
# Load processors and models.
device = "cpu"
path = os.path.join(os.getcwd(), "model", "multitask_best_ctc.pt")
vocab_path = os.path.join(os.getcwd(), "model", "vocab")
# Multitask phoneme ASR model restored from a local checkpoint + vocab.
asr_model = MultitaskPhonemeASRModel(path, vocab_path, device)
# Grapheme-to-phoneme model used to phonemize the reference transcript.
g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
# Pass `device` through instead of re-hardcoding "cpu" so changing the
# target device requires editing only one line.
mispronounciation_detector = MispronounciationDetector(asr_model, g2p, device)
print("Done loading models and processors")
# Predict: run mispronunciation detection on the sample utterance and
# report the reference/hypothesis alignments, errors, and phoneme error rate.
raw_info = mispronounciation_detector.detect(audio, text)
for field in ("ref", "hyp", "phoneme_errors"):
    print(raw_info[field])
print(f"PER: {raw_info['per']}\n")