KevinGeng commited on
Commit
67beffe
·
1 Parent(s): 3725f04

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from random import sample
3
+ import gradio as gr
4
+ import torchaudio
5
+ import torch
6
+ import torch.nn as nn
7
+ import lightning_module
8
+ import pdb
9
+ import jiwer
10
+
11
+ # ASR part
12
+ from transformers import pipeline
13
+ p = pipeline("automatic-speech-recognition")
14
+
15
+ # WER part
16
+ transformation = jiwer.Compose([
17
+ jiwer.ToLowerCase(),
18
+ jiwer.RemoveWhiteSpace(replace_by_space=True),
19
+ jiwer.RemoveMultipleSpaces(),
20
+ jiwer.ReduceToListOfListOfWords(word_delimiter=" ")
21
+ ])
22
+
23
+ # WPM part
24
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
25
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
26
+ phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
27
+ # phoneme_model = pipeline(model="facebook/wav2vec2-xlsr-53-espeak-cv-ft")
28
+ class ChangeSampleRate(nn.Module):
29
+ def __init__(self, input_rate: int, output_rate: int):
30
+ super().__init__()
31
+ self.output_rate = output_rate
32
+ self.input_rate = input_rate
33
+
34
+ def forward(self, wav: torch.tensor) -> torch.tensor:
35
+ # Only accepts 1-channel waveform input
36
+ wav = wav.view(wav.size(0), -1)
37
+ new_length = wav.size(-1) * self.output_rate // self.input_rate
38
+ indices = (torch.arange(new_length) * (self.input_rate / self.output_rate))
39
+ round_down = wav[:, indices.long()]
40
+ round_up = wav[:, (indices.long() + 1).clamp(max=wav.size(-1) - 1)]
41
+ output = round_down * (1. - indices.fmod(1.)).unsqueeze(0) + round_up * indices.fmod(1.).unsqueeze(0)
42
+ return output
43
+
44
+ model = lightning_module.BaselineLightningModule.load_from_checkpoint("epoch=3-step=7459.ckpt").eval()
45
+
46
+ def calc_mos(audio_path, ref):
47
+ wav, sr = torchaudio.load(audio_path)
48
+ osr = 16_000
49
+ batch = wav.unsqueeze(0).repeat(10, 1, 1)
50
+ csr = ChangeSampleRate(sr, osr)
51
+ out_wavs = csr(wav)
52
+ # ASR
53
+ trans = p(audio_path)["text"]
54
+ # WER
55
+ wer = jiwer.wer(ref, trans, truth_transform=transformation, hypothesis_transform=transformation)
56
+ # MOS
57
+ batch = {
58
+ 'wav': out_wavs,
59
+ 'domains': torch.tensor([0]),
60
+ 'judge_id': torch.tensor([288])
61
+ }
62
+ with torch.no_grad():
63
+ output = model(batch)
64
+ predic_mos = output.mean(dim=1).squeeze().detach().numpy()*2 + 3
65
+ # Phonemes per minute (PPM)
66
+ with torch.no_grad():
67
+ logits = phoneme_model(out_wavs).logits
68
+ phone_predicted_ids = torch.argmax(logits, dim=-1)
69
+ phone_transcription = processor.batch_decode(phone_predicted_ids)
70
+ lst_phonemes = phone_transcription[0].split(" ")
71
+ wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
72
+ ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60
73
+
74
+ return predic_mos, trans, wer, phone_transcription, ppm
75
+
76
+ description ="""
77
+ MOS prediction demo using UTMOS-strong w/o phoneme encoder model, which is trained on the main track dataset.
78
+ This demo only accepts .wav format. Best at 16 kHz sampling rate.
79
+
80
+ Paper is available [here](https://arxiv.org/abs/2204.02152)
81
+
82
+ Add ASR based on wav2vec-960, currently only English available.
83
+ Add WER interface.
84
+ """
85
+
86
+
87
+ iface = gr.Interface(
88
+ fn=calc_mos,
89
+ inputs=[gr.Audio(source="microphone", type='filepath', label="Audio to evaluate"),
90
+ gr.Textbox(value="Once upon a time there was a young rat named Author who couldn’t make up his mind.",
91
+ placeholder="Input referance here",
92
+ label="Referance")],
93
+ outputs=[gr.Textbox(placeholder="Predicted MOS", label="Predicted MOS"),
94
+ gr.Textbox(placeholder="Hypothesis", label="Hypothesis"),
95
+ gr.Textbox(placeholder="Word Error Rate", label = "WER"),
96
+ gr.Textbox(placeholder="Predicted Phonemes", label="Predicted Phonemes"),
97
+ gr.Textbox(placeholder="Phonemes per minutes", label="PPM")],
98
+ title="Laronix's Voice Quality Checking System Demo",
99
+ description=description,
100
+ allow_flagging="auto",
101
+ )
102
+ iface.launch()