KevinGeng's picture
Shar true
051f75a
raw
history blame
5.58 kB
from random import sample
import gradio as gr
import torchaudio
import torch
import torch.nn as nn
import lightning_module
import pdb
import jiwer
from local.convert_metrics import nat2avaMOS, WER2INTELI
from local.indicator_plot import Intelligibility_Plot, Naturalness_Plot
from local.pitch_contour import draw_spec_db_pitch
# ASR part
import csv
# Raise the csv module's per-field size limit to 100,000,000 characters.
# NOTE(review): presumably needed because some CSV read/written at runtime
# (e.g. Gradio's flagging log, allow_flagging="auto") can hold very large
# fields — TODO confirm.
csv.field_size_limit(100000000)
from transformers import pipeline
# Speech-to-text pipeline; calc_mos calls p(audio_path)["text"] to get the
# hypothesis transcript that is compared with the reference for WER.
# No model is specified, so the library's default checkpoint for this task is used.
p = pipeline("automatic-speech-recognition")
# WER part
# Text normalisation applied to BOTH the reference and the ASR hypothesis
# before computing WER. Punctuation is removed because the reference typed by
# the user (see `examples`) contains commas, periods and quotes that the ASR
# hypothesis may not reproduce; without this, "time," vs "time" is scored as a
# word error and inflates WER. Applying the same transform to both sides keeps
# contractions consistent ("don't" -> "dont" on each side).
transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemovePunctuation(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.Strip(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
])
# WPM part
# NOTE: the rate computed in calc_mos is phonemes per minute (PPM), not words.
# Wav2Vec2 CTC phoneme recogniser (an espeak-phoneme checkpoint, judging by its
# name); calc_mos feeds it the 16 kHz waveform, takes the argmax over the
# logits and decodes the ids with `processor.batch_decode`.
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
# Alternative phoneme model, kept for reference (currently unused):
# phoneme_model = pipeline(model="vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
class ChangeSampleRate(nn.Module):
    """Resample a waveform by linear interpolation between neighbouring samples.

    No anti-aliasing filter is applied; each output sample is a weighted mix of
    the two input samples surrounding its source position.

    Args:
        input_rate: sample rate of the incoming waveform, in Hz.
        output_rate: desired sample rate, in Hz.
    """

    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Resample ``wav`` from ``input_rate`` to ``output_rate``.

        Args:
            wav: waveform whose first dimension is kept and whose remaining
                dimensions are flattened into the time axis. Intended for
                1-channel input, i.e. shape ``(1, time)``.

        Returns:
            Tensor of shape ``(wav.size(0), time * output_rate // input_rate)``
            with the same dtype as ``wav``.
        """
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        n_in = wav.size(-1)
        new_length = n_in * self.output_rate // self.input_rate
        # Source positions are computed in float64: in float32 they lose
        # sub-sample precision (and can mis-round the integer part) once they
        # exceed 2**24 input samples, i.e. after ~6 minutes of 44.1 kHz audio.
        positions = torch.arange(new_length, dtype=torch.float64) * (self.input_rate / self.output_rate)
        lower = positions.long()
        upper = (lower + 1).clamp(max=n_in - 1)
        # Cast the interpolation weights back to the waveform's dtype so the
        # float64 positions do not promote the output dtype.
        frac = positions.frac().to(wav.dtype).unsqueeze(0)
        return wav[:, lower] * (1. - frac) + wav[:, upper] * frac
# Naturalness (MOS) predictor restored from a local Lightning checkpoint and
# put in evaluation mode via .eval(); calc_mos runs it under torch.no_grad().
model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()
def _intelligibility(audio_path, ref):
    """Transcribe ``audio_path`` and score the transcript against ``ref``.

    Returns:
        ``(score, figure, transcript)`` where ``score`` is ``WER2INTELI`` of the
        word error rate expressed as a percentage.
    """
    trans = p(audio_path)["text"]
    wer = jiwer.wer(ref, trans, truth_transform=transformation, hypothesis_transform=transformation)
    score = WER2INTELI(wer * 100)
    return score, Intelligibility_Plot(score), trans


def _naturalness(wav_16k):
    """Predict a naturalness MOS for ``wav_16k`` and convert it with ``nat2avaMOS``.

    Returns:
        ``(score, figure)``.
    """
    batch = {
        'wav': wav_16k,
        'domains': torch.tensor([0]),
        'judge_id': torch.tensor([288]),
    }
    with torch.no_grad():
        output = model(batch)
    # Affine rescale of the averaged model output (x * 2 + 3) before conversion.
    predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3
    score = nat2avaMOS(predic_mos)
    return score, Naturalness_Plot(score)


def _speaking_rate(wav, sr, wav_16k):
    """Estimate phonemes per minute over the voiced part of ``wav``.

    Returns:
        ``(phone_transcription, ppm)`` where ``phone_transcription`` is the list
        returned by ``processor.batch_decode``.
    """
    with torch.no_grad():
        logits = phoneme_model(wav_16k).logits
    phone_predicted_ids = torch.argmax(logits, dim=-1)
    phone_transcription = processor.batch_decode(phone_predicted_ids)
    # str.split() drops empty tokens; split(" ") counted an empty decode
    # (silence) or repeated spaces as phonemes.
    n_phonemes = len(phone_transcription[0].split())

    # torchaudio's vad only trims silence from the FRONT of the signal, so run
    # it again on the reversed signal to trim trailing silence as well.
    voiced = torchaudio.functional.vad(wav, sample_rate=sr)
    if voiced.shape[-1] > 0:
        voiced = torchaudio.functional.vad(voiced.flip(-1), sample_rate=sr).flip(-1)
    voiced_seconds = voiced.shape[-1] / sr
    # No voiced audio detected -> report 0 instead of raising ZeroDivisionError.
    ppm = n_phonemes / voiced_seconds * 60 if voiced_seconds > 0 else 0.0
    return phone_transcription, ppm


def calc_mos(audio_path, ref):
    """Analyse one recording and build every output of the Gradio interface.

    Args:
        audio_path: path to the audio file (Gradio ``type='filepath'``).
        ref: reference text the speaker was reading.

    Returns:
        ``(naturalness_score, naturalness_fig, intelligibility_score,
        intelligibility_fig, asr_transcript, phone_transcription, ppm,
        f0_db_fig)``, in the order of the interface's ``outputs``.

    Raises:
        ValueError: if no audio was supplied or ``ref`` is empty/blank.
    """
    if audio_path is None:
        raise ValueError("No audio was provided.")
    if not ref or not ref.strip():
        raise ValueError("The reference text must not be empty.")

    wav, sr = torchaudio.load(audio_path, channels_first=True)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)  # Mono channel
    # The MOS and phoneme models receive the waveform resampled to 16 kHz.
    wav_16k = ChangeSampleRate(sr, 16_000)(wav)

    inteli_score, int_fig, trans = _intelligibility(audio_path, ref)
    ava_mos, mos_fig = _naturalness(wav_16k)
    phone_transcription, ppm = _speaking_rate(wav, sr, wav_16k)
    # draw f0 and db analysis plot
    f0_db_fig = draw_spec_db_pitch(audio_path, save_fig_path=None)
    return ava_mos, mos_fig, inteli_score, int_fig, trans, phone_transcription, ppm, f0_db_fig
# Markdown shown at the top of the Gradio interface. Read explicitly as UTF-8
# so startup does not depend on the platform's default locale encoding.
with open("local/description.md", encoding="utf-8") as f:
    description = f.read()
# Example inputs listed under the interface, as (audio file, reference text) pairs.
_EXAMPLE_PAIRS = (
    (
        "local/Julianna_Set1_Author_01.wav",
        "Once upon a time, there was a young rat named Arthur who couldn't make up his mind.",
    ),
    (
        "local/Patient_Arthur_set1_002_noisy.wav",
        "Whenever the other rats asked him if he would like to go hunting with them, he would answer in a soft voice, 'I don't know.'",
    ),
)
# Gradio's `examples` argument receives one list per example row.
examples = [list(pair) for pair in _EXAMPLE_PAIRS]
# Gradio UI. The eight outputs map positionally onto calc_mos's return tuple;
# the numeric score textboxes, phoneme transcription and speaking rate are
# kept (hidden with visible=False) so that mapping stays intact.
iface = gr.Interface(
    fn=calc_mos,
    inputs=[
        gr.Audio(type='filepath', label="Audio to evaluate"),
        gr.Textbox(placeholder="Input reference here (Don't keep this empty)", label="Reference"),
    ],
    outputs=[
        # Hidden numeric naturalness score; its label now states the same
        # 1-5 range as its placeholder and the plot below.
        gr.Textbox(placeholder="Naturalness Score, ranged from 1 to 5, the higher the better.", label="Naturalness Score, ranged from 1 to 5, the higher the better.", visible=False),
        gr.Plot(label="Naturalness Score, ranged from 1 to 5, the higher the better.", show_label=True, container=True),
        gr.Textbox(placeholder="Intelligibility Score", label = "Intelligibility Score, range from 0 to 100, the higher the better", visible=False),
        gr.Plot(label="Intelligibility Score, range from 0 to 100, the higher the better", show_label=True, container=True),
        gr.Textbox(placeholder="Hypothesis", label="Hypothesis"),
        gr.Textbox(placeholder="Predicted Phonemes", label="Predicted Phonemes", visible=False),
        gr.Textbox(placeholder="Speaking Rate, Phonemes per minutes", label="Speaking Rate, Phonemes per minutes", visible=False),
        gr.Plot(label="Pitch Contour and dB Analysis", show_label=True, container=True),
    ],
    title="Speech Analysis by Laronix AI",
    description=description,
    allow_flagging="auto",
    examples=examples,
)
# The phoneme transcription and speaking-rate (PPM) outputs are currently
# hidden in the interface (visible=False) but are still computed by calc_mos.
# Launch with a share link and password protection; `auth` is passed as a
# [username, password] pair.
# NOTE(review): these credentials are hard-coded in source; consider loading
# them from environment variables / deployment secrets instead.
iface.launch(share=True, auth=['Laronix', 'LaronixSLP'], auth_message="Authentication Required, ask kevin@laronix.com for password.\n Thanks for your cooperation!")