Upload app.py
app.py
ADDED
@@ -0,0 +1,102 @@
import gradio as gr
import torchaudio
import torch
import torch.nn as nn
import jiwer

import lightning_module

# ASR part
from transformers import pipeline
p = pipeline("automatic-speech-recognition")

# WER part: normalize both texts before scoring
transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    jiwer.RemoveWhiteSpace(replace_by_space=True),
    jiwer.RemoveMultipleSpaces(),
    jiwer.ReduceToListOfListOfWords(word_delimiter=" ")
])
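
# Illustrative only (kept as a comment): with this normalization, e.g.
#   jiwer.wer("Hello world", "hello word",
#             truth_transform=transformation,
#             hypothesis_transform=transformation)
# returns 0.5 -- one substituted word out of two reference words.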

# Phoneme recognition part (used for phonemes per minute)
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")

class ChangeSampleRate(nn.Module):
    """Resample a 1-channel waveform by linear interpolation."""

    def __init__(self, input_rate: int, output_rate: int):
        super().__init__()
        self.output_rate = output_rate
        self.input_rate = input_rate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        # Only accepts 1-channel waveform input
        wav = wav.view(wav.size(0), -1)
        new_length = wav.size(-1) * self.output_rate // self.input_rate
        # Fractional source position for every output sample
        indices = torch.arange(new_length) * (self.input_rate / self.output_rate)
        round_down = wav[:, indices.long()]
        round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
        # Linearly interpolate between the two neighbouring source samples
        output = round_down * (1. - indices.fmod(1.)).unsqueeze(0) + round_up * indices.fmod(1.).unsqueeze(0)
        return output
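
# Example (illustrative): downsampling 48 kHz -> 16 kHz maps output index i
# to source position 3 * i, so
#   ChangeSampleRate(48_000, 16_000)(torch.arange(6.).view(1, -1))
# yields tensor([[0., 3.]]) -- every third sample, with no fractional part.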

model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()

def calc_mos(audio_path, ref):
    wav, sr = torchaudio.load(audio_path)
    osr = 16_000  # the MOS predictor expects 16 kHz input
    csr = ChangeSampleRate(sr, osr)
    out_wavs = csr(wav)
    # ASR
    trans = p(audio_path)["text"]
    # WER
    wer = jiwer.wer(ref, trans, truth_transform=transformation, hypothesis_transform=transformation)
    # MOS
    batch = {
        'wav': out_wavs,
        'domains': torch.tensor([0]),
        'judge_id': torch.tensor([288])
    }
    with torch.no_grad():
        output = model(batch)
    # Map the normalized model output back to the 1-5 MOS scale
    predicted_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3
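    # Illustrative: a raw output of 0.25 maps to 0.25 * 2 + 3 = 3.5 MOS,
    # assuming targets were normalized as (score - 3) / 2 during training.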
    # Phonemes per minute (PPM)
    with torch.no_grad():
        logits = phoneme_model(out_wavs).logits
    phone_predicted_ids = torch.argmax(logits, dim=-1)
    phone_transcription = processor.batch_decode(phone_predicted_ids)
    lst_phonemes = phone_transcription[0].split(" ")
    # The VAD-trimmed length approximates the voiced duration in seconds
    wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
    ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60
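    # Illustrative: 150 recognized phonemes over 10 s of voiced audio gives
    # 150 / 10 * 60 = 900 phonemes per minute.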

    return predicted_mos, trans, wer, phone_transcription, ppm

description = """
MOS prediction demo using the UTMOS-strong model (without the phoneme encoder), trained on the main-track dataset.
This demo only accepts .wav input and works best at a 16 kHz sampling rate.

The paper is available [here](https://arxiv.org/abs/2204.02152).

ASR is based on wav2vec2 (960 h of English data), so only English is currently supported.
A WER readout against the reference text is also included.
"""

iface = gr.Interface(
    fn=calc_mos,
    inputs=[gr.Audio(source="microphone", type='filepath', label="Audio to evaluate"),
            gr.Textbox(value="Once upon a time there was a young rat named Author who couldn’t make up his mind.",
                       placeholder="Input reference here",
                       label="Reference")],
    outputs=[gr.Textbox(placeholder="Predicted MOS", label="Predicted MOS"),
             gr.Textbox(placeholder="Hypothesis", label="Hypothesis"),
             gr.Textbox(placeholder="Word Error Rate", label="WER"),
             gr.Textbox(placeholder="Predicted Phonemes", label="Predicted Phonemes"),
             gr.Textbox(placeholder="Phonemes per minute", label="PPM")],
    title="Laronix's Voice Quality Checking System Demo",
    description=description,
    allow_flagging="auto",
)
iface.launch()
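
For a quick sanity check outside the Gradio UI, the scoring function can be called directly. This is a minimal sketch, assuming the file above is importable as app.py, the checkpoint epoch=3-step=7459.ckpt sits alongside it, and sample.wav is a placeholder path to a local mono recording:

# sanity_check.py -- hypothetical smoke test for calc_mos
from app import calc_mos

mos, hypothesis, wer, phonemes, ppm = calc_mos(
    "sample.wav",  # placeholder: any local .wav file
    "Once upon a time there was a young rat named Author who couldn’t make up his mind.",
)
print(f"Predicted MOS: {float(mos):.2f}")
print(f"ASR hypothesis: {hypothesis}")
print(f"WER: {wer:.2%}  PPM: {ppm:.0f}")

Note that importing app.py as written also builds the interface and calls iface.launch(); guarding that call with `if __name__ == "__main__":` would make the module import-safe for this kind of test.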