KevinGeng commited on
Commit
5f636d7
1 Parent(s): bf74f19

change Arthur.config to lower standarad

Browse files
Files changed (4) hide show
  1. app.py +68 -51
  2. config/Arthur.yaml +4 -4
  3. local/check_data.py +8 -7
  4. local/convert_metrics.py +73 -0
app.py CHANGED
@@ -19,7 +19,7 @@ from transformers import pipeline
19
  import librosa
20
  import librosa.display
21
  import matplotlib.pyplot as plt
22
-
23
 
24
  # Google cloud service
25
  from googleapiclient.discovery import build
@@ -229,8 +229,14 @@ def calc_mos(_, audio_path, id, ref, pre_ppm, fig=None):
229
  "GOOD JOB! Please 【Save the Recording】.\nYou can start recording the next sample."
230
  )
231
 
232
- # Google Drive saving # TODO
233
- click_google_saving(audio_path)
 
 
 
 
 
 
234
 
235
  return (
236
  fig_h,
@@ -240,6 +246,7 @@ def calc_mos(_, audio_path, id, ref, pre_ppm, fig=None):
240
  phone_transcription,
241
  ppm,
242
  error_msg,
 
243
  )
244
 
245
  with open("src/description.html", "r", encoding="utf-8") as f:
@@ -345,7 +352,6 @@ def generate_now_time_wav():
345
 
346
  # Add google drive cloud saving
347
  def click_google_saving(audio_file,
348
-
349
  ):
350
  # reference_id,
351
  # reference_textbox,
@@ -361,8 +367,7 @@ def click_google_saving(audio_file,
361
 
362
  request = service.files().create(
363
  media_body=media,
364
- body={'name': name,
365
- }
366
  )
367
  # 'reference_id': reference_id,
368
  # "reference_textbox": reference_textbox,
@@ -373,6 +378,8 @@ def click_google_saving(audio_file,
373
  # "ppm": ppm,
374
  # "msg": msg
375
  response = request.execute()
 
 
376
  # return response.get('id')
377
 
378
 
@@ -404,18 +411,19 @@ with gr.Blocks(css=css, theme=my_theme) as demo:
404
  interactive=True,
405
  elem_classes="ref_text",
406
  )
407
- with gr.Accordion("Input for Development", open=False):
408
- reference_id = gr.Textbox(
409
- value="ID",
410
- placeholder="Utter ID",
411
- label="Reference_ID",
412
- visible=True,
413
- )
414
- reference_PPM = gr.Textbox(
415
- placeholder="Pneumatic Voice's PPM",
416
- label="Ref PPM",
417
- visible=True,
418
- )
 
419
  with gr.Row():
420
  b = gr.Button(value="1.Submit", variant="primary", elem_classes="submit")
421
 
@@ -440,42 +448,48 @@ with gr.Blocks(css=css, theme=my_theme) as demo:
440
  interactive=False,
441
  elem_classes="message",
442
  )
443
- with gr.Accordion("Output for Development", open=False):
444
- wav_plot = gr.Plot(PlaceHolder="Wav/Pause Plot", label="wav_pause_plot", visible=True)
445
 
446
- predict_mos = gr.Textbox(
447
- placeholder="Predicted MOS",
448
- label="Predicted MOS",
449
- visible=True,
450
- )
451
 
452
- hyp = gr.Textbox(placeholder="Hypothesis", label="Hypothesis", visible=True)
453
 
454
- wer = gr.Textbox(placeholder="Word Error Rate", label="WER", visible=True)
455
 
456
- predict_pho = gr.Textbox(
457
- placeholder="Predicted Phonemes",
458
- label="Predicted Phonemes",
459
- visible=True,
460
- )
461
 
462
- ppm = gr.Textbox(
463
- placeholder="Phonemes per minutes",
464
- label="PPM",
465
- visible=True,
466
- )
467
- outputs = [
468
- wav_plot,
469
- predict_mos,
470
- hyp,
471
- wer,
472
- predict_pho,
473
- ppm,
474
- msg,
475
- ]
 
 
 
 
 
 
476
 
477
- # b = gr.Button("Submit")
478
- b.click(fn=calc_mos, inputs=inputs, outputs=outputs, api_name="Submit")
479
 
480
  # Logger
481
  callback.setup(
@@ -488,10 +502,11 @@ with gr.Blocks(css=css, theme=my_theme) as demo:
488
  hyp,
489
  wer,
490
  ppm,
491
- msg],
 
492
  flagging_dir="./exp/%s" % config["exp_id"],
493
  )
494
-
495
  with gr.Row():
496
  b2 = gr.Button("2. Save the Recording", variant="primary", elem_id="save")
497
  js_confirmed_saving = "(x) => confirm('Recording Saved!')"
@@ -508,6 +523,7 @@ with gr.Blocks(css=css, theme=my_theme) as demo:
508
  wer,
509
  ppm,
510
  msg,
 
511
  ],
512
  outputs=None,
513
  preprocess=False,
@@ -527,9 +543,10 @@ with gr.Blocks(css=css, theme=my_theme) as demo:
527
  wer,
528
  ppm,
529
  msg,
 
530
  ],
531
  value="3.Clear All",
532
  elem_id="clear",
533
  )
534
 
535
- demo.launch(share=True)
 
19
  import librosa
20
  import librosa.display
21
  import matplotlib.pyplot as plt
22
+ from local.convert_metrics import nat2avaMOS, WER2INTELI
23
 
24
  # Google cloud service
25
  from googleapiclient.discovery import build
 
229
  "GOOD JOB! Please 【Save the Recording】.\nYou can start recording the next sample."
230
  )
231
 
232
+ # Google Drive saving
233
+ saved_google_id = None
234
+ if error_msg == ("GOOD JOB! Please 【Save the Recording】.\nYou can start recording the next sample."):
235
+ saved_google_id = click_google_saving(audio_path)
236
+ # TODO: add saved_google_id to the csv file
237
+ ## else:
238
+ ## TODO: clear all output as start recording again
239
+ ## print("Saving Failed")
240
 
241
  return (
242
  fig_h,
 
246
  phone_transcription,
247
  ppm,
248
  error_msg,
249
+ saved_google_id,
250
  )
251
 
252
  with open("src/description.html", "r", encoding="utf-8") as f:
 
352
 
353
  # Add google drive cloud saving
354
  def click_google_saving(audio_file,
 
355
  ):
356
  # reference_id,
357
  # reference_textbox,
 
367
 
368
  request = service.files().create(
369
  media_body=media,
370
+ body={'name': name, }
 
371
  )
372
  # 'reference_id': reference_id,
373
  # "reference_textbox": reference_textbox,
 
378
  # "ppm": ppm,
379
  # "msg": msg
380
  response = request.execute()
381
+ # get saved file id
382
+ return response.get('id')
383
  # return response.get('id')
384
 
385
 
 
411
  interactive=True,
412
  elem_classes="ref_text",
413
  )
414
+ with gr.Row():
415
+ with gr.Accordion("Input for Development", open=False):
416
+ reference_id = gr.Textbox(
417
+ value="ID",
418
+ placeholder="Utter ID",
419
+ label="Reference_ID",
420
+ visible=True,
421
+ )
422
+ reference_PPM = gr.Textbox(
423
+ placeholder="Pneumatic Voice's PPM",
424
+ label="Ref PPM",
425
+ visible=True,
426
+ )
427
  with gr.Row():
428
  b = gr.Button(value="1.Submit", variant="primary", elem_classes="submit")
429
 
 
448
  interactive=False,
449
  elem_classes="message",
450
  )
451
+ with gr.Accordion("Output for Development", open=False):
452
+ wav_plot = gr.Plot(PlaceHolder="Wav/Pause Plot", label="wav_pause_plot", visible=True)
453
 
454
+ predict_mos = gr.Textbox(
455
+ placeholder="Predicted MOS",
456
+ label="Predicted MOS",
457
+ visible=True,
458
+ )
459
 
460
+ hyp = gr.Textbox(placeholder="Hypothesis", label="Hypothesis", visible=True)
461
 
462
+ wer = gr.Textbox(placeholder="Word Error Rate", label="WER", visible=True)
463
 
464
+ predict_pho = gr.Textbox(
465
+ placeholder="Predicted Phonemes",
466
+ label="Predicted Phonemes",
467
+ visible=True,
468
+ )
469
 
470
+ ppm = gr.Textbox(
471
+ placeholder="Phonemes per minutes",
472
+ label="PPM",
473
+ visible=True,
474
+ )
475
+ saved_google_drive_id = gr.Textbox(
476
+ placeholder="Saved Google Drive ID",
477
+ label="Saved Google Drive ID",
478
+ visible=True,
479
+ )
480
+ outputs = [
481
+ wav_plot,
482
+ predict_mos,
483
+ hyp,
484
+ wer,
485
+ predict_pho,
486
+ ppm,
487
+ msg,
488
+ saved_google_drive_id
489
+ ]
490
 
491
+ # b = gr.Button("Submit")
492
+ b.click(fn=calc_mos, inputs=inputs, outputs=outputs, api_name="Submit")
493
 
494
  # Logger
495
  callback.setup(
 
502
  hyp,
503
  wer,
504
  ppm,
505
+ msg,
506
+ saved_google_drive_id],
507
  flagging_dir="./exp/%s" % config["exp_id"],
508
  )
509
+ # Saving the Recording to CSV Logger (TO BE DELETED)
510
  with gr.Row():
511
  b2 = gr.Button("2. Save the Recording", variant="primary", elem_id="save")
512
  js_confirmed_saving = "(x) => confirm('Recording Saved!')"
 
523
  wer,
524
  ppm,
525
  msg,
526
+ saved_google_drive_id
527
  ],
528
  outputs=None,
529
  preprocess=False,
 
543
  wer,
544
  ppm,
545
  msg,
546
+ saved_google_drive_id
547
  ],
548
  value="3.Clear All",
549
  elem_id="clear",
550
  )
551
 
552
+ demo.launch(share=False)
config/Arthur.yaml CHANGED
@@ -3,10 +3,10 @@ ref_txt: data/Arthur_the_rat.txt
3
  ref_feature: data/Patient_sil_trim_16k_normed_5_snr_40/Arthur_the_rat.csv
4
  ref_wavs: data/Patient_sil_trim_16k_normed_5_snr_40/Arthur_the_rat
5
  thre:
6
- minppm: 100
7
- maxppm: 100
8
- WER: 0.1
9
- AUTOMOS: 4.0
10
  auth:
11
  username: Kath
12
  password: Kath
 
3
  ref_feature: data/Patient_sil_trim_16k_normed_5_snr_40/Arthur_the_rat.csv
4
  ref_wavs: data/Patient_sil_trim_16k_normed_5_snr_40/Arthur_the_rat
5
  thre:
6
+ minppm: 300
7
+ maxppm: 300
8
+ WER: 0.5
9
+ AUTOMOS: 2.0
10
  auth:
11
  username: Kath
12
  password: Kath
local/check_data.py CHANGED
@@ -2,9 +2,11 @@ from googleapiclient.discovery import build
2
  from google.oauth2 import service_account
3
  from googleapiclient.http import MediaFileUpload
4
  import pdb
5
- pdb.set_trace()
6
 
7
  import gradio as gr
 
 
 
8
 
9
  # 来自Google Cloud控制台的JSON凭据文件
10
  credentials_file = "./src/peerless-window-254907-b386b71c0d99.json"
@@ -22,13 +24,14 @@ files = results.get('files', [])
22
  print(files)
23
  from googleapiclient.http import MediaIoBaseDownload
24
  import io
 
25
 
26
- file_id = "1YjON2ObGM826KaaqF-sKM7CO0tAtzWGg"
 
27
  # Get the file's metadata
28
  file = service.files().get(fileId=file_id).execute()
29
 
30
- pdb.set_trace()
31
- request = service.files().get_media(fileId="1YjON2ObGM826KaaqF-sKM7CO0tAtzWGg")
32
  with open(file['name'], 'wb') as file_obj:
33
  downloader = MediaIoBaseDownload(file_obj, request)
34
  done = False
@@ -38,6 +41,4 @@ with open(file['name'], 'wb') as file_obj:
38
 
39
  print(f"Downloaded: {file['name']}")
40
 
41
- pdb.set_trace()
42
-
43
- # print('文件ID:%s' % response.get('id'))
 
2
  from google.oauth2 import service_account
3
  from googleapiclient.http import MediaFileUpload
4
  import pdb
 
5
 
6
  import gradio as gr
7
+ '''
8
+ Usage: python loacl/checkdata.py <google_file_id>
9
+ '''
10
 
11
  # 来自Google Cloud控制台的JSON凭据文件
12
  credentials_file = "./src/peerless-window-254907-b386b71c0d99.json"
 
24
  print(files)
25
  from googleapiclient.http import MediaIoBaseDownload
26
  import io
27
+ import sys
28
 
29
+ file_id = sys.argv[1]
30
+ # "1YjON2ObGM826KaaqF-sKM7CO0tAtzWGg"
31
  # Get the file's metadata
32
  file = service.files().get(fileId=file_id).execute()
33
 
34
+ request = service.files().get_media(fileId=file_id)
 
35
  with open(file['name'], 'wb') as file_obj:
36
  downloader = MediaIoBaseDownload(file_obj, request)
37
  done = False
 
41
 
42
  print(f"Downloaded: {file['name']}")
43
 
44
+ pdb.set_trace()
 
 
local/convert_metrics.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+
4
+ # Natural MOS to AVA MOS
5
+
6
+ def linear_function(x):
7
+ m = (4 - 1) / (1.5 - 1)
8
+ b = 1 - m * 1
9
+ return m * x + b
10
+
11
+ def quadratic_function(x):
12
+ return -0.0816 * (x - 5) ** 2 + 5
13
+
14
+ # Natural MOS to AVA MOS
15
+ def nat2avaMOS(x):
16
+ if x <= 1.5:
17
+ return linear_function(x)
18
+ elif x >1.5 and x <= 5:
19
+ return quadratic_function(x)
20
+
21
+ # Word error rate to Intellibility Score (X is percentage)
22
+ def WER2INTELI(x):
23
+ if x <= 10:
24
+ return 100
25
+ elif x <= 100:
26
+ slope = (30 - 100) / (100 - 10)
27
+ intercept = 100 - slope * 10
28
+ return slope * x + intercept
29
+ else:
30
+ return 100 * np.exp(-0.01 * (x - 100))
31
+
32
+ # 生成 x 值
33
+ # x = np.linspace(0, 200, 400) # 从0到200生成400个点
34
+
35
+ # 计算对应的 y 值
36
+ # y = [WER2INTELI(xi) for xi in x]
37
+
38
+
39
+ # plt.plot(x, y)
40
+ # plt.xlabel('x')
41
+ # plt.ylabel('f(x)')
42
+ # plt.title('Custom Function')
43
+ # plt.grid(True)
44
+ # plt.show()
45
+
46
+ # 生成 x 值的范围
47
+ x1 = np.linspace(1, 1.5, 100)
48
+ x2 = np.linspace(1.5, 5, 100)
49
+
50
+ # 计算对应的 y 值
51
+ y1 = linear_function(x1)
52
+ y2 = quadratic_function(x2)
53
+
54
+ # 绘制线性部分
55
+ plt.plot(x1, y1, label='Linear Function (1 <= x <= 1.5)')
56
+
57
+ # 绘制二次部分
58
+ plt.plot(x2, y2, label='Quadratic Function (1.5 <= x <= 5)')
59
+
60
+ # 添加标签和标题
61
+ plt.xlabel('Natural Mean Opinion Score')
62
+ plt.ylabel('AVA Mean Opinion Score')
63
+ plt.title('nat2avaMOS')
64
+
65
+ # 添加图例
66
+ plt.legend()
67
+
68
+ # 显示图形
69
+ plt.grid(True)
70
+
71
+ # 显示图像
72
+ # plt.savefig("./local/nat2avaMOS.png")
73
+ # plt.savefig("./local/WER2INT.png")