Spaces:
Runtime error
Runtime error
add loose threshold/ remove speed limitation
Browse files- .gitignore +2 -1
- app.py +22 -6
- config/Arthur.yaml +2 -5
- local/check_data.py +15 -1
- local/indicator_plot.py +97 -0
.gitignore
CHANGED
@@ -540,4 +540,5 @@ user/
|
|
540 |
|
541 |
.vscode
|
542 |
|
543 |
-
!data/Patient_sil_trim_16k_normed_5_snr_40/*
|
|
|
|
540 |
|
541 |
.vscode
|
542 |
|
543 |
+
!data/Patient_sil_trim_16k_normed_5_snr_40/*
|
544 |
+
downloads
|
app.py
CHANGED
@@ -178,6 +178,13 @@ def calc_mos(_, audio_path, id, ref, pre_ppm, fig=None):
|
|
178 |
truth_transform=transformation,
|
179 |
hypothesis_transform=transformation,
|
180 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
# MOS
|
182 |
batch = {
|
183 |
"wav": out_wavs,
|
@@ -187,7 +194,12 @@ def calc_mos(_, audio_path, id, ref, pre_ppm, fig=None):
|
|
187 |
with torch.no_grad():
|
188 |
output = model(batch)
|
189 |
predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3
|
190 |
-
|
|
|
|
|
|
|
|
|
|
|
191 |
# Phonemes per minute (PPM)
|
192 |
with torch.no_grad():
|
193 |
logits = phoneme_model(out_wavs).logits
|
@@ -204,6 +216,10 @@ def calc_mos(_, audio_path, id, ref, pre_ppm, fig=None):
|
|
204 |
fig_h = plot_UV(wav_vad.numpy().squeeze(), a_h, sr=sr)
|
205 |
ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60
|
206 |
|
|
|
|
|
|
|
|
|
207 |
error_msg = "!!! ERROR MESSAGE !!!\n"
|
208 |
if audio_path == _ or audio_path == None:
|
209 |
error_msg += "ERROR: Fail recording, Please start from the beginning again."
|
@@ -216,11 +232,11 @@ def calc_mos(_, audio_path, id, ref, pre_ppm, fig=None):
|
|
216 |
ppm,
|
217 |
error_msg,
|
218 |
)
|
219 |
-
if ppm >= float(pre_ppm) + float(config["thre"]["maxppm"]):
|
220 |
-
|
221 |
-
elif ppm <= float(pre_ppm) - float(config["thre"]["minppm"]):
|
222 |
-
|
223 |
-
|
224 |
error_msg += "ERROR: Naturalness is too low, Please try again.\n"
|
225 |
elif wer >= float(config["thre"]["WER"]):
|
226 |
error_msg += "ERROR: Intelligibility is too low, Please try again\n"
|
|
|
178 |
truth_transform=transformation,
|
179 |
hypothesis_transform=transformation,
|
180 |
)
|
181 |
+
|
182 |
+
# round to 1 decimal
|
183 |
+
wer = np.round(wer, 1)
|
184 |
+
|
185 |
+
# WER convert to Intellibility score
|
186 |
+
INTELI_score = WER2INTELI(wer*100)
|
187 |
+
|
188 |
# MOS
|
189 |
batch = {
|
190 |
"wav": out_wavs,
|
|
|
194 |
with torch.no_grad():
|
195 |
output = model(batch)
|
196 |
predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3
|
197 |
+
|
198 |
+
# round to 1 decimal
|
199 |
+
predic_mos = np.round(predic_mos, 1)
|
200 |
+
|
201 |
+
# MOS to AVA MOS
|
202 |
+
AVA_MOS = nat2avaMOS(predic_mos)
|
203 |
# Phonemes per minute (PPM)
|
204 |
with torch.no_grad():
|
205 |
logits = phoneme_model(out_wavs).logits
|
|
|
216 |
fig_h = plot_UV(wav_vad.numpy().squeeze(), a_h, sr=sr)
|
217 |
ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60
|
218 |
|
219 |
+
|
220 |
+
ppm = np.round(ppm, 1)
|
221 |
+
|
222 |
+
|
223 |
error_msg = "!!! ERROR MESSAGE !!!\n"
|
224 |
if audio_path == _ or audio_path == None:
|
225 |
error_msg += "ERROR: Fail recording, Please start from the beginning again."
|
|
|
232 |
ppm,
|
233 |
error_msg,
|
234 |
)
|
235 |
+
# if ppm >= float(pre_ppm) + float(config["thre"]["maxppm"]):
|
236 |
+
# error_msg += "ERROR: Please speak slower.\n"
|
237 |
+
# elif ppm <= float(pre_ppm) - float(config["thre"]["minppm"]):
|
238 |
+
# error_msg += "ERROR: Please speak faster.\n"
|
239 |
+
if predic_mos <= float(config["thre"]["AUTOMOS"]):
|
240 |
error_msg += "ERROR: Naturalness is too low, Please try again.\n"
|
241 |
elif wer >= float(config["thre"]["WER"]):
|
242 |
error_msg += "ERROR: Intelligibility is too low, Please try again\n"
|
config/Arthur.yaml
CHANGED
@@ -3,10 +3,7 @@ ref_txt: data/Arthur_the_rat.txt
|
|
3 |
ref_feature: data/Patient_sil_trim_16k_normed_5_snr_40/Arthur_the_rat.csv
|
4 |
ref_wavs: data/Patient_sil_trim_16k_normed_5_snr_40/Arthur_the_rat
|
5 |
thre:
|
6 |
-
minppm:
|
7 |
-
maxppm:
|
8 |
WER: 0.5
|
9 |
AUTOMOS: 2.0
|
10 |
-
auth:
|
11 |
-
username: Kath
|
12 |
-
password: Kath
|
|
|
3 |
ref_feature: data/Patient_sil_trim_16k_normed_5_snr_40/Arthur_the_rat.csv
|
4 |
ref_wavs: data/Patient_sil_trim_16k_normed_5_snr_40/Arthur_the_rat
|
5 |
thre:
|
6 |
+
minppm: 0
|
7 |
+
maxppm: 1000
|
8 |
WER: 0.5
|
9 |
AUTOMOS: 2.0
|
|
|
|
|
|
local/check_data.py
CHANGED
@@ -27,9 +27,23 @@ import io
|
|
27 |
import sys
|
28 |
|
29 |
file_id = sys.argv[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
# "1YjON2ObGM826KaaqF-sKM7CO0tAtzWGg"
|
31 |
# Get the file's metadata
|
32 |
-
|
|
|
33 |
|
34 |
request = service.files().get_media(fileId=file_id)
|
35 |
with open(file['name'], 'wb') as file_obj:
|
|
|
27 |
import sys
|
28 |
|
29 |
file_id = sys.argv[1]
|
30 |
+
if file_id == "all":
|
31 |
+
results = service.files().list().execute()
|
32 |
+
files = results.get('files', [])
|
33 |
+
# download all files
|
34 |
+
for file in files:
|
35 |
+
request = service.files().get_media(fileId=file['id'])
|
36 |
+
with open("download/" + file['name'], 'wb') as file_obj:
|
37 |
+
downloader = MediaIoBaseDownload(file_obj, request)
|
38 |
+
done = False
|
39 |
+
while not done:
|
40 |
+
status, done = downloader.next_chunk()
|
41 |
+
print(f"Download {int(status.progress() * 100)}%.")
|
42 |
+
|
43 |
# "1YjON2ObGM826KaaqF-sKM7CO0tAtzWGg"
|
44 |
# Get the file's metadata
|
45 |
+
else:
|
46 |
+
file = service.files().get(fileId=file_id).execute()
|
47 |
|
48 |
request = service.files().get_media(fileId=file_id)
|
49 |
with open(file['name'], 'wb') as file_obj:
|
local/indicator_plot.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import plotly.graph_objects as go
|
2 |
+
|
3 |
+
def Intelligibility_Plot(Int_Score, fair_thre=30, good_thre = 70, Upper=100, Lower=0):
|
4 |
+
'''
|
5 |
+
Int_Score: a float number between 0 and 100
|
6 |
+
Upper: the upper bound of the plot
|
7 |
+
Lower: the lower bound of the plot
|
8 |
+
'''
|
9 |
+
# Assert Nat_Score is a float number between 0 and 100
|
10 |
+
assert isinstance(Int_Score, float|int)
|
11 |
+
assert Int_Score >= Lower
|
12 |
+
assert Int_Score <= Upper
|
13 |
+
# Indicator plot with different colors, under fair_threshold the plot is red, then yellow, then green
|
14 |
+
# Design 1: Show bar in different colors refer to the threshold
|
15 |
+
|
16 |
+
color = "#75DA99"
|
17 |
+
if Int_Score <= fair_thre:
|
18 |
+
color = "#F2ADA0"
|
19 |
+
elif Int_Score <= good_thre:
|
20 |
+
color = "#e8ee89"
|
21 |
+
else:
|
22 |
+
color = "#75DA99"
|
23 |
+
|
24 |
+
fig = go.Figure(go.Indicator(
|
25 |
+
mode="number+gauge",
|
26 |
+
gauge={'shape': "bullet",
|
27 |
+
'axis':{'range': [Lower, Upper]},
|
28 |
+
'bgcolor': 'white',
|
29 |
+
'bar': {'color': color},
|
30 |
+
},
|
31 |
+
value=Int_Score,
|
32 |
+
domain = {'x': [0, 1], 'y': [0, 1]},
|
33 |
+
)
|
34 |
+
)
|
35 |
+
# # Design 2: Show all thresholds in the background
|
36 |
+
# fig = go.Figure(go.Indicator(
|
37 |
+
# mode = "number+gauge",
|
38 |
+
# gauge = {'shape': "bullet",
|
39 |
+
# 'axis': {'range': [Lower, Upper]},
|
40 |
+
# 'bgcolor': 'white',
|
41 |
+
# 'steps': [
|
42 |
+
# {'range': [Lower, fair_thre], 'color': "#F2ADA0"},
|
43 |
+
# {'range': [fair_thre, good_thre], 'color': "#e8ee89"},
|
44 |
+
# {'range': [good_thre, Upper], 'color': " #75DA99"}],
|
45 |
+
# 'bar': {'color': "grey"},
|
46 |
+
# },
|
47 |
+
# value = Int_Score,
|
48 |
+
# domain = {'x': [0, 1], 'y': [0, 1]},
|
49 |
+
# )
|
50 |
+
# )
|
51 |
+
fig.update_layout(height=300, width=1000)
|
52 |
+
return fig
|
53 |
+
|
54 |
+
|
55 |
+
def Naturalness_Plot(Nat_Score, fair_thre=2, good_thre = 4, Upper=5, Lower=1.0):
|
56 |
+
'''
|
57 |
+
Int_Score: a float number between 0 and 100
|
58 |
+
Upper: the upper bound of the plot
|
59 |
+
Lower: the lower bound of the plot
|
60 |
+
'''
|
61 |
+
# Assert Nat_Score is a float number between 0 and 100
|
62 |
+
assert isinstance(Nat_Score, float|int)
|
63 |
+
assert Nat_Score >= Lower
|
64 |
+
assert Nat_Score <= Upper
|
65 |
+
|
66 |
+
# Indicator plot with different colors, under fair_threshold the plot is red, then yellow, then green
|
67 |
+
|
68 |
+
color = "#75DA99"
|
69 |
+
if Nat_Score <= fair_thre:
|
70 |
+
color = "#F2ADA0"
|
71 |
+
elif Nat_Score <= good_thre:
|
72 |
+
color = "#e8ee89"
|
73 |
+
else:
|
74 |
+
color = "#75DA99"
|
75 |
+
|
76 |
+
fig = go.Figure(go.Indicator(
|
77 |
+
mode="number+gauge",
|
78 |
+
gauge={'shape': "bullet",
|
79 |
+
'axis':{'range': [Lower, Upper]},
|
80 |
+
'bgcolor': 'white',
|
81 |
+
'bar': {'color': color},
|
82 |
+
},
|
83 |
+
value=Nat_Score,
|
84 |
+
domain = {'x': [0, 1], 'y': [0, 1]},
|
85 |
+
)
|
86 |
+
)
|
87 |
+
|
88 |
+
fig.update_layout(height=300, width=1000)
|
89 |
+
return fig
|
90 |
+
|
91 |
+
# test case Intelligibility_Plot
|
92 |
+
x = Intelligibility_Plot(10)
|
93 |
+
x.show()
|
94 |
+
x = Intelligibility_Plot(50)
|
95 |
+
x.show()
|
96 |
+
x = Intelligibility_Plot(90)
|
97 |
+
x.show()
|