Spaces:
Running
on
Zero
Running
on
Zero
kAIto47802
committed on
Commit
·
8537948
1
Parent(s):
a18d920
Fix and add quick option
Browse files
app.py
CHANGED
@@ -24,24 +24,25 @@ cfg.config = "fusion_stage3"
|
|
24 |
cfg.print_config = False
|
25 |
cfg.data_config = None
|
26 |
cfg.phase = "inference"
|
27 |
-
cfg.weight = None
|
28 |
cfg.num_workers = 1
|
29 |
|
30 |
@spaces.GPU
|
31 |
@torch.inference_mode()
|
32 |
-
def predict_mos(audio_path: str, domain: str) -> float:
|
33 |
data = pd.DataFrame({"file_path": [audio_path]})
|
34 |
data["dataset"] = domain
|
35 |
-
data[
|
36 |
-
|
37 |
preds = 0.0
|
38 |
for fold in range(5):
|
39 |
cfg.now_fold = fold
|
|
|
40 |
model = get_model(cfg, device).eval()
|
41 |
for _ in range(5):
|
42 |
test_dataset = get_dataset(cfg, data, "test")
|
43 |
p = model(*[torch.tensor(t).unsqueeze(0).to(device) for t in test_dataset[0][:-1]])
|
44 |
-
preds += p.cpu().numpy()[0]
|
|
|
|
|
45 |
preds /= 25.0
|
46 |
return preds
|
47 |
|
@@ -65,12 +66,22 @@ with gr.Blocks() as demo:
|
|
65 |
"blizzard2011",
|
66 |
],
|
67 |
label="Data-domain ID for the MOS prediction",
|
68 |
-
value="sarulab"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
)
|
70 |
submit = gr.Button(value="Submit")
|
71 |
|
72 |
with gr.Column():
|
73 |
output = gr.Textbox(label="Predicted MOS", type="text")
|
74 |
-
submit.click(fn=predict_mos, inputs=[audio, domain], outputs=[output])
|
75 |
|
76 |
demo.queue().launch()
|
|
|
24 |
cfg.print_config = False
|
25 |
cfg.data_config = None
|
26 |
cfg.phase = "inference"
|
|
|
27 |
cfg.num_workers = 1
|
28 |
|
29 |
@spaces.GPU
|
30 |
@torch.inference_mode()
|
31 |
+
def predict_mos(audio_path: str, domain: str, quick: bool) -> float:
|
32 |
data = pd.DataFrame({"file_path": [audio_path]})
|
33 |
data["dataset"] = domain
|
34 |
+
data["mos"] = 0
|
|
|
35 |
preds = 0.0
|
36 |
for fold in range(5):
|
37 |
cfg.now_fold = fold
|
38 |
+
cfg.weight = f"models/fusion_stage3/fold{fold}_s42_best_model.pth"
|
39 |
model = get_model(cfg, device).eval()
|
40 |
for _ in range(5):
|
41 |
test_dataset = get_dataset(cfg, data, "test")
|
42 |
p = model(*[torch.tensor(t).unsqueeze(0).to(device) for t in test_dataset[0][:-1]])
|
43 |
+
preds += p.cpu().numpy()[0][0]
|
44 |
+
if quick:
|
45 |
+
return preds
|
46 |
preds /= 25.0
|
47 |
return preds
|
48 |
|
|
|
66 |
"blizzard2011",
|
67 |
],
|
68 |
label="Data-domain ID for the MOS prediction",
|
69 |
+
value="sarulab",
|
70 |
+
)
|
71 |
+
quick = gr.Checkbox(
|
72 |
+
label="Quick prediction",
|
73 |
+
value=True,
|
74 |
+
info=(
|
75 |
+
"UTMOSv2 makes predictions repeatedly for five randomly selected frames "
|
76 |
+
"of the input speech waveform for all five folds. "
|
77 |
+
"To make quick predictions by reducing this to a single repetition, "
|
78 |
+
"check this checkbox:",
|
79 |
+
),
|
80 |
)
|
81 |
submit = gr.Button(value="Submit")
|
82 |
|
83 |
with gr.Column():
|
84 |
output = gr.Textbox(label="Predicted MOS", type="text")
|
85 |
+
submit.click(fn=predict_mos, inputs=[audio, domain, quick], outputs=[output])
|
86 |
|
87 |
demo.queue().launch()
|