kAIto47802 committed on
Commit
8537948
·
1 Parent(s): a18d920

Fix and add quick option

Browse files
Files changed (1) hide show
  1. app.py +18 -7
app.py CHANGED
@@ -24,24 +24,25 @@ cfg.config = "fusion_stage3"
24
  cfg.print_config = False
25
  cfg.data_config = None
26
  cfg.phase = "inference"
27
- cfg.weight = None
28
  cfg.num_workers = 1
29
 
30
  @spaces.GPU
31
  @torch.inference_mode()
32
- def predict_mos(audio_path: str, domain: str) -> float:
33
  data = pd.DataFrame({"file_path": [audio_path]})
34
  data["dataset"] = domain
35
- data['mos'] = 0
36
-
37
  preds = 0.0
38
  for fold in range(5):
39
  cfg.now_fold = fold
 
40
  model = get_model(cfg, device).eval()
41
  for _ in range(5):
42
  test_dataset = get_dataset(cfg, data, "test")
43
  p = model(*[torch.tensor(t).unsqueeze(0).to(device) for t in test_dataset[0][:-1]])
44
- preds += p.cpu().numpy()[0]
 
 
45
  preds /= 25.0
46
  return preds
47
 
@@ -65,12 +66,22 @@ with gr.Blocks() as demo:
65
  "blizzard2011",
66
  ],
67
  label="Data-domain ID for the MOS prediction",
68
- value="sarulab"
 
 
 
 
 
 
 
 
 
 
69
  )
70
  submit = gr.Button(value="Submit")
71
 
72
  with gr.Column():
73
  output = gr.Textbox(label="Predicted MOS", type="text")
74
- submit.click(fn=predict_mos, inputs=[audio, domain], outputs=[output])
75
 
76
  demo.queue().launch()
 
24
  cfg.print_config = False
25
  cfg.data_config = None
26
  cfg.phase = "inference"
 
27
  cfg.num_workers = 1
28
 
29
  @spaces.GPU
30
  @torch.inference_mode()
31
+ def predict_mos(audio_path: str, domain: str, quick: bool) -> float:
32
  data = pd.DataFrame({"file_path": [audio_path]})
33
  data["dataset"] = domain
34
+ data["mos"] = 0
 
35
  preds = 0.0
36
  for fold in range(5):
37
  cfg.now_fold = fold
38
+ cfg.weight = f"models/fusion_stage3/fold{fold}_s42_best_model.pth"
39
  model = get_model(cfg, device).eval()
40
  for _ in range(5):
41
  test_dataset = get_dataset(cfg, data, "test")
42
  p = model(*[torch.tensor(t).unsqueeze(0).to(device) for t in test_dataset[0][:-1]])
43
+ preds += p.cpu().numpy()[0][0]
44
+ if quick:
45
+ return preds
46
  preds /= 25.0
47
  return preds
48
 
 
66
  "blizzard2011",
67
  ],
68
  label="Data-domain ID for the MOS prediction",
69
+ value="sarulab",
70
+ )
71
+ quick = gr.Checkbox(
72
+ label="Quick prediction",
73
+ value=True,
74
+ info=(
75
+ "UTMOSv2 makes predictions repeatedly for five randomly selected frames "
76
+ "of the input speech waveform for all five folds. "
77
+ "To make quick predictions by reducing this to a single repetition, "
78
+ "check this checkbox:",
79
+ ),
80
  )
81
  submit = gr.Button(value="Submit")
82
 
83
  with gr.Column():
84
  output = gr.Textbox(label="Predicted MOS", type="text")
85
+ submit.click(fn=predict_mos, inputs=[audio, domain, quick], outputs=[output])
86
 
87
  demo.queue().launch()