"""Gradio demo that scores an utterance for voice quality.

For one uploaded audio file the app reports a predicted MOS, an ASR
transcription, the word error rate against a user-supplied reference, the
predicted phoneme sequence, and a phonemes-per-minute speaking rate.
"""
import pdb  # noqa: F401  (unused here; kept because other tooling may rely on it)
from random import sample  # noqa: F401  (unused here; kept for the same reason)

import gradio as gr
import jiwer
import torch
import torch.nn as nn
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, pipeline

import lightning_module

# Sample rate (Hz) that audio is resampled to before the MOS and phoneme models.
TARGET_SAMPLE_RATE = 16_000

# ASR part
# Use the first GPU when one is available, otherwise run on CPU (-1) instead
# of crashing at start-up on CPU-only hosts.
p = pipeline(
    "automatic-speech-recognition",
    model="KevinGeng/whipser_medium_en_PAL300_step25",
    device=0 if torch.cuda.is_available() else -1,
)

# WER part
# Text normalisation applied to both reference and hypothesis before scoring.
transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
])

# WPM part
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")


class ChangeSampleRate(nn.Module):
    """Resample a waveform by linear interpolation between neighbouring samples."""

    def __init__(self, input_rate: int, output_rate: int):
        """Store the source and target sample rates (in Hz)."""
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Return ``wav`` resampled from ``input_rate`` to ``output_rate``.

        The input is flattened to ``(wav.size(0), -1)``; the result has shape
        ``(wav.size(0), new_length)``.
        """
        # Original note: only accepts 1-channel waveform input.
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        # Fractional positions in the input that each output sample maps to.
        indices = torch.arange(new_length, device=wav.device) * (
            self.input_rate / self.output_rate
        )
        lower = indices.long()
        # Clamp so the last output sample does not read past the end.
        upper = (lower + 1).clamp(max=wav.size(-1) - 1)
        frac = indices.fmod(1.0)
        return wav[:, lower] * (1.0 - frac).unsqueeze(0) + wav[:, upper] * frac.unsqueeze(0)


model = lightning_module.BaselineLightningModule.load_from_checkpoint(
    "epoch=3-step=7459.ckpt"
).eval()


def calc_mos(audio_path, ref):
    """Evaluate one audio file.

    Args:
        audio_path: Path of the audio file to evaluate.
        ref: Reference transcript used for the word error rate. May be empty.

    Returns:
        A tuple ``(predicted_mos, transcription, wer, phonemes, ppm)``.
        ``wer`` is the string ``"N/A"`` when no reference text was given, and
        ``ppm`` is ``0.0`` when voice activity detection leaves no audio.
    """
    wav, sr = torchaudio.load(audio_path, channels_first=True)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)  # Mono channel
    out_wavs = ChangeSampleRate(sr, TARGET_SAMPLE_RATE)(wav)

    # ASR
    trans = p(audio_path)["text"]

    # WER
    # jiwer rejects an empty reference, so report "N/A" rather than failing
    # the whole request when the user left the reference box blank.
    if ref and ref.strip():
        wer = jiwer.wer(
            ref,
            trans,
            truth_transform=transformation,
            hypothesis_transform=transformation,
        )
    else:
        wer = "N/A"

    # MOS
    # NOTE(review): domain 0 and judge_id 288 are fixed conditioning inputs
    # for the checkpoint; their meaning is defined in lightning_module.
    batch = {
        'wav': out_wavs,
        'domains': torch.tensor([0]),
        'judge_id': torch.tensor([288]),
    }
    with torch.no_grad():
        output = model(batch)
    # NOTE(review): "*2 + 3" presumably maps the model's normalised output
    # back onto the 1-5 MOS scale -- confirm against the training code.
    predic_mos = output.mean(dim=1).squeeze().detach().numpy()*2 + 3

    # Phonemes per minute (PPM)
    with torch.no_grad():
        logits = phoneme_model(out_wavs).logits
    phone_predicted_ids = torch.argmax(logits, dim=-1)
    phone_transcription = processor.batch_decode(phone_predicted_ids)
    # split() without an argument ignores empty tokens, so an empty
    # transcription counts as zero phonemes rather than one.
    lst_phonemes = phone_transcription[0].split()
    wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
    voiced_seconds = wav_vad.shape[-1] / sr
    # VAD can trim the whole clip when it finds no speech; avoid dividing by 0.
    if voiced_seconds > 0:
        ppm = len(lst_phonemes) / voiced_seconds * 60
    else:
        ppm = 0.0

    return predic_mos, trans, wer, phone_transcription, ppm


# NOTE(review): the text below mentions "wav2vec-960" for ASR, while the
# pipeline above loads a Whisper-named checkpoint -- confirm and update.
description = """
MOS prediction demo using UTMOS-strong w/o phoneme encoder model, which is trained on the main track dataset.
This demo only accepts .wav format. Best at 16 kHz sampling rate.
Paper is available [here](https://arxiv.org/abs/2204.02152)
Add ASR based on wav2vec-960, currently only English available.
Add WER interface.
"""

iface = gr.Interface(
    fn=calc_mos,
    inputs=[
        gr.Audio(type='filepath', label="Audio to evaluate"),
        gr.Textbox(placeholder="Input reference here (Don't keep this empty)", label="Reference"),
    ],
    outputs=[
        gr.Textbox(placeholder="Naturalness evaluation, ranged 1 to 5, the higher the better.", label="Predicted MOS"),
        gr.Textbox(placeholder="Hypothesis", label="Hypothesis"),
        gr.Textbox(placeholder="Word Error Rate: Only valid when Reference is given", label = "WER"),
        gr.Textbox(placeholder="Predicted Phonemes", label="Predicted Phonemes"),
        gr.Textbox(placeholder="Speaking Rate, Phonemes per minutes", label="PPM"),
    ],
    title="Laronix's Voice Quality Checking System Demo",
    description=description,
    allow_flagging="auto",
)
iface.launch()