# KevinGeng's picture
# default concurrency limit
# c17f86f
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from transformers import pipeline
import pandas as pd
from random import sample
import gradio as gr
import torchaudio
import torch
import torch.nn as nn
import lightning_module
import pdb
import jiwer
from local.convert_metrics import nat2avaMOS, WER2INTELI
from local.indicator_plot import Intelligibility_Plot, Naturalness_Plot
from local.pitch_contour import draw_spec_db_pitch, draw_intensity, draw_pitch
import parselmouth
# Report part
from local.pdf_generator import MedicalReport
# from local.pdf_generator import generate_report
from datetime import datetime
from time import time
import io
# ASR part
import csv
# Raise the csv module's per-field size cap to 100,000,000 characters
# (presumably so very large text fields can be parsed — TODO confirm which reader needs it).
csv.field_size_limit(100_000_000)

# ASR part: speech-to-text pipeline. No model is pinned, so transformers
# picks its default checkpoint for this task.
p = pipeline(task="automatic-speech-recognition")

# WER part: normalisation applied to both reference and hypothesis before
# scoring (lower-case, whitespace -> single spaces, split into words).
_wer_steps = [
    jiwer.ToLowerCase(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
]
transformation = jiwer.Compose(_wer_steps)

# WPM part: phoneme processor; the phoneme model that would use it is disabled below.
_PHONEME_CHECKPOINT = "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
processor = Wav2Vec2Processor.from_pretrained(_PHONEME_CHECKPOINT)
# phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
# phoneme_model = pipeline(model="vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
class ChangeSampleRate(nn.Module):
    """Resample waveforms by linear interpolation between neighbouring samples.

    Args:
        input_rate: sample rate of the waveform passed to ``forward``.
        output_rate: sample rate of the waveform returned by ``forward``.
    """

    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Resample ``wav`` from ``input_rate`` to ``output_rate``.

        The first dimension of ``wav`` is kept and every remaining dimension
        is flattened into the time axis, so callers should pass
        ``(channels, time)``; this file passes a single (mono) channel.

        Returns:
            Tensor of shape ``(wav.size(0), T * output_rate // input_rate)``
            where ``T`` is the flattened time length. It is on the same
            device as ``wav``.
        """
        in_rate, out_rate = int(self.input_rate), int(self.output_rate)
        wav = wav.reshape(wav.size(0), -1)  # reshape also accepts non-contiguous input
        length = wav.size(-1)
        new_length = length * out_rate // in_rate

        # Output sample i reads source position i * in_rate / out_rate.
        # Keep that position as an exact integer numerator: a float32 position
        # loses its fractional part once it exceeds 2**24 samples, which would
        # corrupt both the index and the interpolation weight on long audio.
        numer = torch.arange(new_length, device=wav.device) * in_rate
        lower = torch.div(numer, out_rate, rounding_mode="floor")
        upper = (lower + 1).clamp(max=length - 1)  # last sample has no right neighbour
        frac = (numer % out_rate).to(torch.get_default_dtype()) / out_rate

        return wav[:, lower] * (1.0 - frac) + wav[:, upper] * frac
# Naturalness (MOS) predictor restored from a Lightning checkpoint in the
# working directory, then switched to evaluation mode.
_MOS_CHECKPOINT = "epoch=3-step=7459.ckpt"
model = lightning_module.BaselineLightningModule.load_from_checkpoint(_MOS_CHECKPOINT)
model.eval()
# TODO: add JavaScript code for the front-right logo.

# Inputs
ref = gr.Textbox(placeholder="Input reference here (Don't keep this empty)",
                 label="Reference", container=True, show_label=True, lines=3, min_width=50, type="text")
exp_id = gr.Textbox(placeholder="ID", label="ID", visible=False)
input_audio = gr.Audio(label="Audio to evaluate", sources=None,
                       interactive=True, type="filepath", container=True)

# Sub input: pitch range used to filter outliers in the pitch analysis
pitch_uplimit = gr.Slider(minimum=0, maximum=1000, step=50,
                          value=1000, label="Pitch Upper Limit", interactive=True)
pitch_downlimit = gr.Slider(value=0, minimum=0, maximum=1000,
                            step=50, label="Pitch Lower Limit", interactive=True)
pitch_update_button = gr.Button(
    value="Generate Pitch Plot", elem_classes="primary", variant="primary")

# Outputs
nat_score = gr.Textbox(placeholder="Naturalness Score, ranged from 1 to 5, the higher the better.",
                       label="Naturalness Score, ranged from 1 to 5, the higher the better.", visible=False)
nat_plot = gr.Plot(
    label="Naturalness Score, ranged from 1 to 5, the higher the better.", show_label=False)
int_score = gr.Textbox(placeholder="Intelligibility Score",
                       label="Intelligibility Score, range from 0 to 100, the higher the better", visible=False)
int_plot = gr.Plot(
    label="Intelligibility Score, range from 0 to 100, the higher the better", show_label=False)
Mean_DB = gr.Label(label="Mean Decibel", show_label=False)
# Shows the mean pitch, so it needs its own label (not "Mean Decibel").
Mean_Pitch = gr.Label(label="Mean Pitch", show_label=False)
hyp = gr.Textbox(placeholder="Hypothesis", label="Hypothesis",
                 show_label=False, min_width=50, lines=3)
phonemes = gr.Textbox(placeholder="Predicted Phonemes",
                      label="Predicted Phonemes", visible=False)
spk_rate = gr.Textbox(placeholder="Speaking Rate, Phonemes per minutes",
                      label="Speaking Rate, Phonemes per minutes", visible=False)
pitch_plot = gr.Plot(label="Pitch Contour and dB Analysis",
                     show_label=False, container=True)
patient_name = gr.Textbox(placeholder="Patient Name",
                          label="Patient Name", type="text")
slp_name = gr.Textbox(placeholder="SLP Name", label="SLP Name", type="text")

# Report
output_report = gr.File(interactive=False, type="filepath", visible=True)
report_trigger = gr.DownloadButton(
    label="Download Report", value="report.pdf", elem_classes="primary", variant="primary", visible=True)
# If you adjust pitch_plot, reflect modification in report
pdf_update_button = gr.Button(value="Update Pitch in Report")

# Analysis
run_all = gr.Button(value="Submit", elem_classes="primary", variant="primary")
def report_change():
    """Return a hidden ``gr.File`` component whose value is ``report.pdf``."""
    hidden_report = gr.File(value="report.pdf", visible=False)
    return hidden_report
def test_audio(audio_path):
    """Load an audio file, downmix it to mono and resample it to 16 kHz.

    Args:
        audio_path: path to an audio file readable by ``torchaudio.load``.

    Returns:
        The resampled waveform produced by ``ChangeSampleRate`` (channel
        dimension first, one channel).

    Raises:
        ValueError: if ``audio_path`` is None.
    """
    # Fail fast with a clear message; merely evaluating
    # `audio_path is not None` inside try/except can never raise.
    if audio_path is None:
        raise ValueError("Audio file is required.")
    wav, sr = torchaudio.load(audio_path, channels_first=True)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)  # Mono channel
    return ChangeSampleRate(sr, 16_000)(wav)
def pitch_update(audio_path, pitch_downlimit, pitch_uplimit):
    """Redraw the "Pitch Contour and dB Analysis" figure for a new pitch range.

    Returns the figure built by ``draw_spec_db_pitch``; nothing is saved to disk.
    """
    return draw_spec_db_pitch(
        audio_path,
        low_limit=pitch_downlimit,
        high_limit=pitch_uplimit,
        save_fig_path=None,
    )
# Characters replaced in user-supplied names before they become part of the
# PDF file name: path separators (which could escape the working directory)
# and NUL.
_FILENAME_UNSAFE = str.maketrans({"/": "_", "\\": "_", "\x00": "_"})


def _report_filename(patient_name, slp_name, timestamp):
    """Return ``<patient>_<slp>_report_<timestamp>.pdf`` with path separators neutralised."""
    patient = str(patient_name).translate(_FILENAME_UNSAFE)
    slp = str(slp_name).translate(_FILENAME_UNSAFE)
    return f"{patient}_{slp}_report_{timestamp}.pdf"


def _figure_to_jpg_bytes(fig):
    """Render an object exposing ``savefig(file, format=...)`` to JPEG bytes in memory."""
    with io.BytesIO() as buffer:
        fig.savefig(buffer, format="jpg")
        return buffer.getvalue()


def _mean_voiced_pitch(pitch, low_limit, high_limit):
    """Mean of pitch values strictly inside (low_limit, high_limit), ignoring zeros.

    Zero entries are presumably unvoiced frames. The result is rounded to
    2 decimals, and is nan when no value survives the filter.
    """
    in_range = pitch[(pitch > low_limit) & (pitch < high_limit)]
    voiced = in_range[in_range != 0]
    return round(voiced.mean(), 2)


def calc_mos(_, audio_path, ref, Patient_Name, SLP_Name, pitch_downlimit, pitch_uplimit):
    """Run the full speech analysis for one recording and write a PDF report.

    Args:
        _: unused (wired to the hidden ``exp_id`` textbox).
        audio_path: path to the recording to analyse.
        ref: reference sentence the speaker was asked to read.
        Patient_Name, SLP_Name: names printed in the report and used in its
            file name; None or empty values become "Undefined".
        pitch_downlimit, pitch_uplimit: pitch values outside this open
            interval are ignored for the mean pitch; also passed to the plot.

    Returns:
        Tuple, in UI output order:
        - naturalness score and naturalness figure;
        - intelligibility score and intelligibility figure;
        - ASR transcript;
        - phoneme placeholder and speaking-rate placeholder;
        - pitch/dB figure;
        - PDF file name;
        - mean pitch (str) and mean intensity (str).

    Raises:
        ValueError: if ``audio_path`` is None or ``ref`` is empty/blank.
    """
    from datetime import datetime, timezone

    if audio_path is None:
        raise ValueError("Audio file is required.")
    # An empty reference makes WER undefined; reject it before running ASR.
    if not (ref and ref.strip()):
        raise ValueError("Reference text is required.")

    # Load, downmix to mono and resample to 16 kHz for the MOS model.
    wav, sr = torchaudio.load(audio_path, channels_first=True)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)  # Mono channel
    out_wavs = ChangeSampleRate(sr, 16_000)(wav)

    # ASR transcript, WER against the reference, converted to intelligibility.
    trans = p(audio_path)["text"]
    wer = jiwer.wer(ref, trans, truth_transform=transformation,
                    hypothesis_transform=transformation)
    INTELI_score = WER2INTELI(wer * 100)
    INT_fig = Intelligibility_Plot(INTELI_score)

    # Naturalness (MOS)
    batch = {
        "wav": out_wavs,
        "domains": torch.tensor([0]),
        "judge_id": torch.tensor([288]),
    }
    with torch.no_grad():
        output = model(batch)
    # `* 2 + 3` presumably maps the model's [-1, 1] output onto the 1-5 MOS scale — TODO confirm.
    predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3
    AVA_MOS = round(nat2avaMOS(predic_mos), 1)
    MOS_fig = Naturalness_Plot(AVA_MOS)

    # Pitch contour / dB plot: one figure for the UI, a JPEG copy for the PDF.
    f0_db_fig = draw_spec_db_pitch(audio_path, low_limit=pitch_downlimit,
                                   high_limit=pitch_uplimit, save_fig_path=None)
    pitch_plot_for_pdf = _figure_to_jpg_bytes(f0_db_fig)

    # Mean pitch and mean intensity via parselmouth.
    snd = parselmouth.Sound(audio_path)
    pitch = snd.to_pitch().selected_array["frequency"]
    pitch_mean = _mean_voiced_pitch(pitch, pitch_downlimit, pitch_uplimit)
    intensity = snd.to_intensity().values
    intensity_mean = round(intensity[intensity > 20].mean(), 2)  # values <= 20 treated as silence

    # Phoneme transcription / speaking rate are disabled; keep placeholder outputs.
    phone_transcription = "None"
    ppm = 100

    current_time_with_UTC = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')

    # Gradio textboxes deliver "" when left empty, so test falsiness, not just None.
    if not Patient_Name:
        Patient_Name = "Undefined"
    if not SLP_Name:
        SLP_Name = "Undefined"

    report = MedicalReport(str(Patient_Name), current_time_with_UTC, SLP_Name,
                           MOS_fig.to_image("jpg"), INT_fig.to_image("jpg"),
                           ref=ref, hyp=trans, pitch_plot=pitch_plot_for_pdf,
                           mean_pitch=pitch_mean, mean_decibel=intensity_mean)
    pdf_name = _report_filename(Patient_Name, SLP_Name, current_time_with_UTC)
    report.generate_report(pdf_name)

    return (round(AVA_MOS, 1), MOS_fig, round(INTELI_score, 1), INT_fig, trans,
            phone_transcription, ppm, f0_db_fig, pdf_name, str(pitch_mean), str(intensity_mean))
# HTML description text (its header row in the Blocks UI is commented out).
# Read as UTF-8 explicitly so the result does not depend on the platform locale.
with open("local/description.html", encoding="utf-8") as f:
    description = f.read()
# Example rows for the examples widget, whose inputs are
# [exp_id, input_audio, ref]: only the reference sentence is pre-filled.
_EXAMPLE_SENTENCES = (
    "Once upon a time, there was a young rat named Arthur who couldn't make up his mind.",
    "Whenever the other rats asked him if he would like to go hunting with them, he would answer in a soft voice, 'I don't know.'",
    "When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow.",
    "The rainbow is a division of white light into many beautiful colors.",
    "These take the shape of a long round arch, with its path high above, and its two ends apparently beyond the horizon.",
    "There is, according to legend, a boiling pot of gold at one end.",
    "People look, but no one ever finds it.",
    "When a man looks for something beyond his reach, his friends say he is looking for the pot of gold at the end of the rainbow.",
    "Throughout the centuries people have explained the rainbow in various ways.",
    "Some have accepted it as a miracle without physical explanation.",
)
examples = [[None, None, sentence] for sentence in _EXAMPLE_SENTENCES]
def change_button_visibility(list_of_conponents):
    """Return True when every component in ``list_of_conponents`` has a non-None ``value``.

    An empty iterable yields True. The misspelled parameter name is kept for
    keyword callers.
    """
    return all(comp.value is not None for comp in list_of_conponents)
# "Submit" and "Update Pitch in Report" both rerun calc_mos with the same
# inputs and refresh the same outputs; share the wiring so they cannot drift.
analysis_inputs = [exp_id, input_audio, ref, patient_name, slp_name, pitch_downlimit, pitch_uplimit]
analysis_outputs = [nat_score, nat_plot, int_score, int_plot, hyp, phonemes, spk_rate,
                    pitch_plot, output_report, Mean_Pitch, Mean_DB]

with gr.Blocks(theme="KevinGeng/Laronix_ver2") as demo:
    # with gr.Row():
    #     gr.HTML(description, elem_classes="description")

    # Top row: session details and reference text (left), recording and actions (right).
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('''
# 1. Enter session details.
''')
            with gr.Row():
                patient_name.render()
                slp_name.render()
            gr.Markdown('''
# 2. Choose a sentence to read.
Copy & paste or type your sentence in the "Reference" field. \n
Alternatively, you can also choose one from the "Examples" list below.
''')
            ref.render()
            # Reuse the module-level `examples` list rather than a duplicated literal.
            gr.Examples(examples=examples, inputs=[exp_id, input_audio, ref],
                        examples_per_page=5, label="Examples")
            exp_id.render()
        with gr.Column(scale=1):
            gr.Markdown('''
# 3. Record or Upload audio.
Either click on the 🎤 microphone icon, then hit “Record” and ask your patient to read the sentence in the “Reference” field.\n
Alternatively, you can also upload an audio file of the reference sentence that you recorded before.
''')
            input_audio.render()
            with gr.Row():
                clear = gr.ClearButton(
                    components=[ref, input_audio, exp_id], value="Clear All")
                run_all.render()
            run_all.click(fn=calc_mos, inputs=analysis_inputs, outputs=analysis_outputs,
                          scroll_to_output=True)

    # Results section: filled in by calc_mos once "Submit" is clicked.
    with gr.Row():
        gr.HTML("<hr>")
    with gr.Row():
        gr.Markdown('''
# 4. Results.
''')
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('''
## Naturalness:
Naturalness score is ranged from 1 to 5, the higher the better.
''')
            nat_score.render()
            nat_plot.render()
        with gr.Column(scale=1):
            gr.Markdown('''
## Intelligibility:
Intelligibility score is ranged from 0 to 100, the higher the better.
''')
            int_score.render()
            int_plot.render()
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("## Average fundamental frequency")
                    Mean_Pitch.render()
                with gr.Column(scale=1):
                    gr.Markdown("## Average intensity (dB)")
                    Mean_DB.render()
            # Step 3 must name the actual button label ("Update Pitch in Report").
            gr.Markdown("## Pitch Contour and dB Analysis\n" +
                        "If you want to change the pitch range to remove the outliers:\n" +
                        "1. Select the proper pitch range.\n" +
                        "2. Click 'Generate Pitch Plot'.\n" +
                        "3. Click 'Update Pitch in Report'.\n")
            with gr.Row():
                pitch_downlimit.render()
                pitch_uplimit.render()
            pitch_update_button.render()  # TODO: reflect result in final report
            pitch_update_button.click(fn=pitch_update,
                                      inputs=[input_audio, pitch_downlimit, pitch_uplimit],
                                      outputs=[pitch_plot])
            pdf_update_button.render()
            pdf_update_button.click(fn=calc_mos, inputs=analysis_inputs, outputs=analysis_outputs,
                                    scroll_to_output=True)
            pitch_plot.render()
        with gr.Column(scale=1):
            gr.Markdown("## Hypothesis")
            hyp.render()
            phonemes.render()
            spk_rate.render()
            gr.Markdown("## PDF Report")
            output_report.render()
            # report_trigger.render()
# y = calc_mos(None, "test.wav", ref="xxx", Patient_Name=None, SLP_Name=None)
# x = update_pdf("xxx", "xxx", nat_plot, int_plot, pitch_plot, "1", "1", "xxx", "xxx")
# Queue settings: default per-event concurrency limit of 10, at most 50 queued events.
_queue_settings = {"default_concurrency_limit": 10, "max_size": 50}
demo.queue(**_queue_settings)
# Public share link; files under "local" are allowed to be served.
# NOTE(review): auth_func (defined further down) is not passed to launch(), so
# the app runs without login — confirm whether authentication was intended.
demo.launch(share=True, allowed_paths=["local"])
# iface = gr.Interface(
# fn=calc_mos,
# inputs=[gr.Textbox(placeholder="ID", label="ID", visible=False),
# gr.Audio(type='filepath', label="Audio to evaluate"),
# gr.Textbox(placeholder="Input reference here (Don't keep this empty)", label="Reference")],
# outputs=[gr.Textbox(placeholder="Naturalness Score, ranged from 1 to 5, the higher the better.", label="Naturalness Score, ranged from 1 to 5, the higher the better.", visible=False),
# gr.Plot(label="Naturalness Score, ranged from 1 to 5, the higher the better.", show_label=True, container=True),
# gr.Textbox(placeholder="Intelligibility Score", label = "Intelligibility Score, range from 0 to 100, the higher the better", visible=False),
# gr.Plot(label="Intelligibility Score, range from 0 to 100, the higher the better", show_label=True, container=True),
# gr.Textbox(placeholder="Hypothesis", label="Hypothesis"),
# gr.Textbox(placeholder="Predicted Phonemes", label="Predicted Phonemes", visible=False),
# gr.Textbox(placeholder="Speaking Rate, Phonemes per minutes", label="Speaking Rate, Phonemes per minutes", visible=False),
# gr.Plot(label="Pitch Contour and dB Analysis", show_label=True, container=True)],
# title="Speech Analysis by Laronix AI",
# description=description,
# allow_flagging="auto",
# examples=examples,
# cache_examples=False,
# )
# # Currently remove PPM and Phonemes
# # add password to protect the interface
# # read the account and password
# Read "username,password" rows (no header) as strings: otherwise pandas parses
# numeric usernames/passwords as ints, which never equal the str login input.
# Empty cells still become NaN.
df = pd.read_csv('./local/auth.info',
                 names=['username', 'password'], header=None, dtype=str)
# {username: [password]}
auth_info_dict = df.set_index('username').T.to_dict('list')
# # authenticator for username and password
def auth_func(username, password):
    """Return True when ``password`` matches the stored password for ``username``.

    Credentials come from the module-level ``auth_info_dict``
    (``{username: [password]}``). Missing (NaN) stored passwords never match.
    Stored values are compared as strings, so numerically-parsed entries still work.
    The comparison is constant-time so response timing does not reveal how much
    of the password matched.
    """
    import hmac

    if username not in auth_info_dict:
        return False
    stored = auth_info_dict[username][0]
    if pd.isna(stored):
        return False
    return hmac.compare_digest(str(password).encode("utf-8"), str(stored).encode("utf-8"))
# iface.queue(default_concurrency_limit=5, max_size=10)