HoneyTian committed on
Commit
459dab4
·
1 Parent(s): 5dfbac5
examples/download_wav/step_1_download_wav.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+
7
+ import pandas as pd
8
+ import requests
9
+ from tqdm import tqdm
10
+
11
+ from project_settings import project_path
12
+
13
+
14
+ def get_args():
15
+ parser = argparse.ArgumentParser()
16
+
17
+ parser.add_argument(
18
+ "--excel_file_dir",
19
+ default=(project_path / "examples/download_wav").as_posix(),
20
+ type=str
21
+ )
22
+ parser.add_argument(
23
+ "--start_date",
24
+ default="2022-04-10 00:00:00",
25
+ type=str
26
+ )
27
+ parser.add_argument(
28
+ "--end_date",
29
+ default="2026-04-21 00:00:00",
30
+ type=str
31
+ )
32
+ parser.add_argument(
33
+ "--output_dir",
34
+ default=(project_path / "data/calling/358/wav_2ch").as_posix(),
35
+ type=str
36
+ )
37
+ args = parser.parse_args()
38
+ return args
39
+
40
+
41
+ excel_file_str = """
42
+ AIAgent-CallLog-20250929100824.xlsx
43
+ AIAgent-CallLog-20250929134959.xlsx
44
+ AIAgent-CallLog-20250929135030.xlsx
45
+ AIAgent-CallLog-20250929135052.xlsx
46
+ AIAgent-CallLog-20250929135122.xlsx
47
+ AIAgent-CallLog-20250929135134.xlsx
48
+ AIAgent-CallLog-20250929135209.xlsx
49
+ AIAgent-CallLog-20250929135219.xlsx
50
+ AIAgent-CallLog-20250929135247.xlsx
51
+ AIAgent-CallLog-20250929135300.xlsx
52
+ AIAgent-CallLog-20250929135311.xlsx
53
+ AIAgent-CallLog-20250929135335.xlsx
54
+ AIAgent-CallLog-20250929135344.xlsx
55
+ AIAgent-CallLog-20250929135355.xlsx
56
+ AIAgent-CallLog-20250929135443.xlsx
57
+ AIAgent-CallLog-20250929135452.xlsx
58
+ AIAgent-CallLog-20250929135501.xlsx
59
+ AIAgent-CallLog-20250929135537.xlsx
60
+ AIAgent-CallLog-20250929135544.xlsx
61
+ AIAgent-CallLog-20250929135554.xlsx
62
+ AIAgent-CallLog-20250929135630.xlsx
63
+ AIAgent-CallLog-20250929135701.xlsx
64
+ AIAgent-CallLog-20250929135710.xlsx
65
+ AIAgent-CallLog-20250929135716.xlsx
66
+ AIAgent-CallLog-20250929135755.xlsx
67
+ AIAgent-CallLog-20250929135800.xlsx
68
+ AIAgent-CallLog-20250929135809.xlsx
69
+ AIAgent-CallLog-20250929135842.xlsx
70
+ AIAgent-CallLog-20250929135849.xlsx
71
+ AIAgent-CallLog-20250929135858.xlsx
72
+ AIAgent-CallLog-20250929135909.xlsx
73
+ """
74
+
75
+
76
+ def main():
77
+ args = get_args()
78
+
79
+ format_str = "%Y-%m-%d %H:%M:%S"
80
+
81
+ start_date = datetime.strptime(args.start_date, format_str)
82
+ end_date = datetime.strptime(args.end_date, format_str)
83
+
84
+ excel_file_dir = Path(args.excel_file_dir)
85
+ output_dir = Path(args.output_dir)
86
+ output_dir.mkdir(parents=True, exist_ok=True)
87
+
88
+ print(f"start_date: {start_date}")
89
+ print(f"end_date: {end_date}")
90
+
91
+ # finished
92
+ finished = set()
93
+ for filename in output_dir.glob("*.wav"):
94
+ call_id = filename.stem
95
+ finished.add(call_id)
96
+
97
+ splits = excel_file_str.split("\n")
98
+ for row in splits:
99
+ name = str(row).strip()
100
+ if len(name) == 0:
101
+ continue
102
+ excel_file = excel_file_dir / name
103
+
104
+ df = pd.read_excel(excel_file.as_posix())
105
+ for i, row in tqdm(df.iterrows()):
106
+ call_date = row["Attempt time"]
107
+ call_id = row["Call ID"]
108
+ record_url = row["Recording file"]
109
+ if pd.isna(record_url):
110
+ continue
111
+
112
+ if call_id in finished:
113
+ continue
114
+ finished.add(call_id)
115
+
116
+ call_date = datetime.strptime(str(call_date), format_str)
117
+
118
+ if not start_date < call_date < end_date:
119
+ continue
120
+
121
+ call_date_str = call_date.strftime("%Y%m%d")
122
+ # record_url = f"https://phl-01.obs.ap-southeast-3.myhuaweicloud.com/{call_date_str}/21964/{call_id}.wav"
123
+ # record_url = f"https://nxai-hk-1259196162.cos.ap-hongkong.myqcloud.com/{call_date_str}/3101/{call_id}.wav"
124
+ # print(record_url)
125
+ try:
126
+ resp = requests.get(
127
+ url=record_url,
128
+ )
129
+ except (TimeoutError, requests.exceptions.ConnectionError):
130
+ continue
131
+ except Exception as e:
132
+ print(e)
133
+ continue
134
+
135
+ if resp.status_code == 404:
136
+ continue
137
+ if resp.status_code != 200:
138
+ raise AssertionError("status_code: {}; text: {}".format(resp.status_code, resp.text))
139
+
140
+ filename = output_dir / "{}.wav".format(call_id)
141
+ with open(filename.as_posix(), "wb") as f:
142
+ f.write(resp.content)
143
+
144
+ return
145
+
146
+
147
+ if __name__ == "__main__":
148
+ main()
examples/download_wav/step_2_to_1ch.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import time
7
+
8
+ from scipy.io import wavfile
9
+ from tqdm import tqdm
10
+
11
+ from project_settings import project_path
12
+
13
+
14
+ def get_args():
15
+ parser = argparse.ArgumentParser()
16
+
17
+ parser.add_argument(
18
+ "--audio_dir",
19
+ default=(project_path / "data/calling/358/wav_2ch").as_posix(),
20
+ type=str
21
+ )
22
+ parser.add_argument(
23
+ "--output_dir",
24
+ default=(project_path / "data/calling/358/wav_1ch").as_posix(),
25
+ type=str
26
+ )
27
+ args = parser.parse_args()
28
+ return args
29
+
30
+
31
+ def main():
32
+ args = get_args()
33
+
34
+ audio_dir = Path(args.audio_dir)
35
+ output_dir = Path(args.output_dir)
36
+ output_dir.mkdir(parents=True, exist_ok=True)
37
+
38
+ finished = set()
39
+ for filename in tqdm(list(output_dir.glob("*.wav"))):
40
+ splits = filename.stem.split("_")
41
+ call_id = splits[3]
42
+ finished.add(call_id)
43
+ print(f"finished count: {len(finished)}")
44
+
45
+ for filename in tqdm(list(audio_dir.glob("*.wav"))):
46
+ call_id = filename.stem
47
+
48
+ if call_id in finished:
49
+ os.remove(filename.as_posix())
50
+ continue
51
+ finished.add(call_id)
52
+
53
+ try:
54
+ sample_rate, signal = wavfile.read(filename.as_posix())
55
+ except UnboundLocalError as error:
56
+ print(f"wavfile read failed. error type: {type(error)}, text: {str(error)}, filename: {filename.as_posix()}")
57
+ raise error
58
+ if sample_rate != 8000:
59
+ raise AssertionError
60
+
61
+ signal = signal[:, 0]
62
+
63
+ to_filename = output_dir / f"active_media_r_{call_id}_fi-FI_none.wav"
64
+ try:
65
+ wavfile.write(
66
+ to_filename.as_posix(),
67
+ sample_rate,
68
+ signal
69
+ )
70
+ os.remove(filename.as_posix())
71
+ except OSError as error:
72
+ print(f"wavfile write failed. error type: {type(error)}, text: {str(error)}, filename: {filename.as_posix()}")
73
+ raise error
74
+
75
+ return
76
+
77
+
78
+ if __name__ == "__main__":
79
+ main()
examples/download_wav/step_3_split_two_second_wav.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ from pathlib import Path
5
+ import time
6
+
7
+ from scipy.io import wavfile
8
+ from tqdm import tqdm
9
+
10
+ from project_settings import project_path
11
+
12
+
13
+ def get_args():
14
+ parser = argparse.ArgumentParser()
15
+
16
+ parser.add_argument(
17
+ "--audio_dir",
18
+ default=(project_path / "data/calling/358/wav_2ch").as_posix(),
19
+ type=str
20
+ )
21
+ parser.add_argument(
22
+ "--output_dir",
23
+ default=(project_path / "data/calling/358/wav_segmented").as_posix(),
24
+ type=str
25
+ )
26
+ parser.add_argument(
27
+ "--first_n_seconds",
28
+ default=8,
29
+ type=int
30
+ )
31
+ args = parser.parse_args()
32
+ return args
33
+
34
+
35
+ def main():
36
+ args = get_args()
37
+
38
+ audio_dir = Path(args.audio_dir)
39
+ output_dir = Path(args.output_dir)
40
+ output_dir.mkdir(parents=True, exist_ok=True)
41
+
42
+ for filename in tqdm(list(audio_dir.glob("*.wav"))):
43
+ call_id = filename.stem
44
+ sample_rate, signal = wavfile.read(filename.as_posix())
45
+ if sample_rate != 8000:
46
+ raise AssertionError
47
+
48
+ signal = signal[:, 0]
49
+ signal_length = len(signal) - sample_rate * 2
50
+ if signal_length <= 0:
51
+ continue
52
+
53
+ for begin in range(0, signal_length, sample_rate * 2):
54
+ if begin >= sample_rate * args.first_n_seconds:
55
+ break
56
+ end = begin + sample_rate * 2
57
+ sub_signal = signal[begin: end]
58
+
59
+ ts = int(time.time() * 1000)
60
+ to_filename = output_dir / "{}_fi-FI_none_{}.wav".format(call_id, ts)
61
+ wavfile.write(
62
+ to_filename.as_posix(),
63
+ sample_rate,
64
+ sub_signal
65
+ )
66
+ return
67
+
68
+
69
+ if __name__ == "__main__":
70
+ main()
examples/download_wav/step_3_split_two_second_wav_by_vad.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ from pathlib import Path
5
+ import time
6
+
7
+ import numpy as np
8
+ from scipy.io import wavfile
9
+ from tqdm import tqdm
10
+
11
+ from project_settings import project_path
12
+ from toolbox.webrtcvad.vad import WebRTCVad
13
+
14
+
15
+ def get_args():
16
+ parser = argparse.ArgumentParser()
17
+
18
+ parser.add_argument(
19
+ "--audio_dir",
20
+ default=(project_path / "data/calling/63/wav_2ch").as_posix(),
21
+ type=str
22
+ )
23
+ parser.add_argument(
24
+ "--output_dir",
25
+ default=(project_path / "data/calling/63/wav_segmented2").as_posix(),
26
+ type=str
27
+ )
28
+ parser.add_argument(
29
+ "--first_n_seconds",
30
+ default=10,
31
+ type=int
32
+ )
33
+ parser.add_argument(
34
+ "--sample_rate",
35
+ default=8000,
36
+ type=int
37
+ )
38
+ args = parser.parse_args()
39
+ return args
40
+
41
+
42
+ def main():
43
+ args = get_args()
44
+
45
+ audio_dir = Path(args.audio_dir)
46
+ output_dir = Path(args.output_dir)
47
+ output_dir.mkdir(parents=True, exist_ok=True)
48
+
49
+ for filename in tqdm(list(audio_dir.glob("*.wav"))):
50
+ call_id = filename.stem
51
+ sample_rate, signal = wavfile.read(filename.as_posix())
52
+ if sample_rate != 8000:
53
+ raise AssertionError
54
+
55
+ signal = signal[:, 0]
56
+ signal = signal[:int(args.first_n_seconds * args.sample_rate)]
57
+
58
+ signal_length = len(signal) - sample_rate * 2
59
+ if signal_length <= 0:
60
+ continue
61
+
62
+ # vad
63
+ w_vad = WebRTCVad(sample_rate=args.sample_rate)
64
+ vad_segments = list()
65
+ segments = w_vad.vad(signal)
66
+ vad_segments += segments
67
+ segments = w_vad.last_vad_segments()
68
+ vad_segments += segments
69
+
70
+ for start, end in vad_segments:
71
+ if end - start < 0.01:
72
+ continue
73
+ start = max(0, start-0.4)
74
+ from_idx = int(start * sample_rate)
75
+ to_idx = int(end * sample_rate)
76
+ segment_signal = signal[from_idx: to_idx]
77
+ segment_signal_length = len(segment_signal)
78
+
79
+ min_inputs_length = 2 * sample_rate
80
+ for idx in range(0, segment_signal_length, min_inputs_length):
81
+ sub_signal = segment_signal[idx: idx + min_inputs_length]
82
+ sub_signal_length = len(sub_signal)
83
+ if sub_signal_length < min_inputs_length:
84
+ pad_length = min_inputs_length - sub_signal_length
85
+ # pad = np.zeros(shape=(pad_length,), dtype=np.int16)
86
+ pad = 0 + 25 * np.random.randn(pad_length)
87
+ pad = np.array(pad, dtype=np.int16)
88
+ sub_signal = np.concatenate([sub_signal, pad])
89
+
90
+ ts = int(time.time() * 1000)
91
+ to_filename = output_dir / f"{call_id}_en-PH_kxob7p6suuye_{ts}.wav"
92
+ wavfile.write(
93
+ filename=to_filename.as_posix(),
94
+ rate=sample_rate,
95
+ data=sub_signal
96
+ )
97
+ return
98
+
99
+
100
+ if __name__ == "__main__":
101
+ main()
examples/online_model_test/step_3_make_test.py CHANGED
@@ -15,12 +15,12 @@ def get_args():
15
 
16
  parser.add_argument(
17
  "--src_dir",
18
- default=r"D:\Users\tianx\HuggingDatasets\international_voice\data\calling\886",
19
  type=str,
20
  )
21
  parser.add_argument(
22
  "--tgt_dir",
23
- default=r"D:\Users\tianx\HuggingDatasets\international_voice\data\voice_test_examples\886\96",
24
  type=str,
25
  )
26
  parser.add_argument(
@@ -37,6 +37,7 @@ def main():
37
 
38
  src_dir = Path(args.src_dir)
39
  tgt_dir = Path(args.tgt_dir)
 
40
 
41
  client = Client("http://10.75.27.247:7861/")
42
 
 
15
 
16
  parser.add_argument(
17
  "--src_dir",
18
+ default=r"D:\Users\tianx\HuggingDatasets\international_voice\data\calling\65\voicemail",
19
  type=str,
20
  )
21
  parser.add_argument(
22
  "--tgt_dir",
23
+ default=r"D:\Users\tianx\HuggingDatasets\international_voice\data\voice_test_examples\65\95",
24
  type=str,
25
  )
26
  parser.add_argument(
 
37
 
38
  src_dir = Path(args.src_dir)
39
  tgt_dir = Path(args.tgt_dir)
40
+ tgt_dir.mkdir(parents=True, exist_ok=True)
41
 
42
  client = Client("http://10.75.27.247:7861/")
43
 
main.py CHANGED
@@ -38,7 +38,7 @@ from project_settings import environment, project_path
38
  from toolbox.torch.utils.data.vocabulary import Vocabulary
39
  from tabs.cls_tab import get_cls_tab
40
  from tabs.split_tab import get_split_tab
41
- from tabs.voicemail_tab import get_voicemail_tab
42
  from tabs.shell_tab import get_shell_tab
43
 
44
 
@@ -135,7 +135,7 @@ def main():
135
  examples_dir=args.examples_dir,
136
  trained_model_dir=args.trained_model_dir,
137
  )
138
- _ = get_voicemail_tab(
139
  examples_dir=args.examples_dir,
140
  trained_model_dir=args.trained_model_dir,
141
  )
 
38
  from toolbox.torch.utils.data.vocabulary import Vocabulary
39
  from tabs.cls_tab import get_cls_tab
40
  from tabs.split_tab import get_split_tab
41
+ from tabs.event_tab import get_event_tab
42
  from tabs.shell_tab import get_shell_tab
43
 
44
 
 
135
  examples_dir=args.examples_dir,
136
  trained_model_dir=args.trained_model_dir,
137
  )
138
+ _ = get_event_tab(
139
  examples_dir=args.examples_dir,
140
  trained_model_dir=args.trained_model_dir,
141
  )
tabs/{voicemail_tab.py → event_tab.py} RENAMED
@@ -1,7 +1,7 @@
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  import json
4
- from functools import lru_cache
5
  from pathlib import Path
6
  import shutil
7
  import tempfile
@@ -43,9 +43,11 @@ def load_model(model_file: Path):
43
  return d
44
 
45
 
46
- def when_click_voicemail_button(audio_t,
47
- model_name: str,
48
- ground_true: str) -> Tuple[str, float]:
 
 
49
 
50
  sample_rate, signal = audio_t
51
 
@@ -58,16 +60,19 @@ def when_click_voicemail_button(audio_t,
58
  inputs = signal / (1 << 15)
59
  inputs = torch.tensor(inputs, dtype=torch.float32)
60
  inputs = torch.unsqueeze(inputs, dim=0)
 
 
 
 
 
61
 
62
  outputs = list()
63
  with torch.no_grad():
64
- for idx in range(0, 5):
65
- begin = idx * int(sample_rate*2)
66
- end = begin + int(sample_rate*2)
67
  sub_inputs = inputs[:, begin:end]
68
- if sub_inputs.shape[-1] < sample_rate:
69
- # raise AssertionError(f"audio duration less than: {sample_rate}")
70
- continue
71
 
72
  logits = model.forward(sub_inputs)
73
  probs = torch.nn.functional.softmax(logits, dim=-1)
@@ -90,56 +95,82 @@ def when_click_voicemail_button(audio_t,
90
  return outputs
91
 
92
 
93
- def get_voicemail_tab(examples_dir: str, trained_model_dir: str):
94
- voicemail_examples_dir = Path(examples_dir)
95
- voicemail_trained_model_dir = Path(trained_model_dir)
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  # models
98
- voicemail_model_choices = list()
99
- for filename in voicemail_trained_model_dir.glob("*.zip"):
100
  model_name = filename.stem
101
  if model_name == "examples":
102
  continue
103
- voicemail_model_choices.append(model_name)
104
- model_choices = list(sorted(voicemail_model_choices))
 
 
 
 
 
 
 
105
 
106
  # examples zip
107
- voicemail_example_zip_file = voicemail_trained_model_dir / "examples.zip"
108
- with zipfile.ZipFile(voicemail_example_zip_file.as_posix(), "r") as f_zip:
109
- out_root = voicemail_examples_dir
110
  if out_root.exists():
111
  shutil.rmtree(out_root.as_posix())
112
  out_root.mkdir(parents=True, exist_ok=True)
113
  f_zip.extractall(path=out_root)
114
 
115
  # examples
116
- voicemail_examples = list()
117
- for filename in voicemail_examples_dir.glob("**/*/*.wav"):
118
  label = filename.parts[-2]
119
- voicemail_examples.append([
120
  filename.as_posix(),
121
  model_choices[0],
122
  label
123
  ])
124
 
125
- with gr.TabItem("voicemail"):
126
  with gr.Row():
127
  with gr.Column(scale=3):
128
- voicemail_audio = gr.Audio(label="audio")
129
  with gr.Row():
130
- with gr.Column(scale=3):
131
- voicemail_model_name = gr.Dropdown(choices=model_choices, value=model_choices[0], label="model_name")
132
- with gr.Column(scale=3):
133
- voicemail_ground_true = gr.Textbox(label="ground_true")
134
-
135
- voicemail_button = gr.Button("run", variant="primary")
 
136
  with gr.Column(scale=3):
137
- voicemail_outputs = gr.Textbox(label="outputs")
 
 
 
 
 
 
138
 
139
- voicemail_button.click(
140
- when_click_voicemail_button,
141
- inputs=[voicemail_audio, voicemail_model_name, voicemail_ground_true],
142
- outputs=[voicemail_outputs],
143
  )
144
 
145
  return locals()
 
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  import json
4
+ from functools import lru_cache, partial
5
  from pathlib import Path
6
  import shutil
7
  import tempfile
 
43
  return d
44
 
45
 
46
+ def when_click_event_button(audio_t,
47
+ model_name: str, target_label: str,
48
+ win_size: float, win_step: float,
49
+ max_duration: float
50
+ ) -> Tuple[str, float]:
51
 
52
  sample_rate, signal = audio_t
53
 
 
60
  inputs = signal / (1 << 15)
61
  inputs = torch.tensor(inputs, dtype=torch.float32)
62
  inputs = torch.unsqueeze(inputs, dim=0)
63
+ # inputs shape: (1, num_samples)
64
+
65
+ win_size = int(win_size * sample_rate)
66
+ win_step = int(win_step * sample_rate)
67
+ max_duration = int(max_duration * sample_rate)
68
 
69
  outputs = list()
70
  with torch.no_grad():
71
+ for begin in range(0, (max_duration-win_size+1), win_step):
72
+ end = begin + win_size
 
73
  sub_inputs = inputs[:, begin:end]
74
+ if sub_inputs.shape[-1] < win_size:
75
+ break
 
76
 
77
  logits = model.forward(sub_inputs)
78
  probs = torch.nn.functional.softmax(logits, dim=-1)
 
95
  return outputs
96
 
97
 
98
+ def when_model_name_change(model_name: str, event_trained_model_dir: Path):
99
+ m = load_model(
100
+ model_file=(event_trained_model_dir / f"{model_name}.zip")
101
+ )
102
+ token_to_index: dict = m["vocabulary"].get_token_to_index_vocabulary(namespace="labels")
103
+ label_choices = list(token_to_index.keys())
104
+
105
+ split_label = gr.Dropdown(choices=label_choices, value=label_choices[0], label="label")
106
+
107
+ return split_label
108
+
109
+
110
+ def get_event_tab(examples_dir: str, trained_model_dir: str):
111
+ event_examples_dir = Path(examples_dir)
112
+ event_trained_model_dir = Path(trained_model_dir)
113
 
114
  # models
115
+ event_model_choices = list()
116
+ for filename in event_trained_model_dir.glob("*.zip"):
117
  model_name = filename.stem
118
  if model_name == "examples":
119
  continue
120
+ event_model_choices.append(model_name)
121
+ model_choices = list(sorted(event_model_choices))
122
+
123
+ # model_labels_choices
124
+ m = load_model(
125
+ model_file=(event_trained_model_dir / f"{model_choices[0]}.zip")
126
+ )
127
+ token_to_index = m["vocabulary"].get_token_to_index_vocabulary(namespace="labels")
128
+ model_labels_choices = list(token_to_index.keys())
129
 
130
  # examples zip
131
+ event_example_zip_file = event_trained_model_dir / "examples.zip"
132
+ with zipfile.ZipFile(event_example_zip_file.as_posix(), "r") as f_zip:
133
+ out_root = event_examples_dir
134
  if out_root.exists():
135
  shutil.rmtree(out_root.as_posix())
136
  out_root.mkdir(parents=True, exist_ok=True)
137
  f_zip.extractall(path=out_root)
138
 
139
  # examples
140
+ event_examples = list()
141
+ for filename in event_examples_dir.glob("**/*/*.wav"):
142
  label = filename.parts[-2]
143
+ event_examples.append([
144
  filename.as_posix(),
145
  model_choices[0],
146
  label
147
  ])
148
 
149
+ with gr.TabItem("event"):
150
  with gr.Row():
151
  with gr.Column(scale=3):
152
+ event_audio = gr.Audio(label="audio")
153
  with gr.Row():
154
+ event_model_name = gr.Dropdown(choices=model_choices, value=model_choices[0], label="model_name")
155
+ event_label = gr.Dropdown(choices=model_labels_choices, value=model_labels_choices[0], label="label")
156
+ with gr.Row():
157
+ event_win_size = gr.Number(value=2.0, minimum=0, maximum=5, step=0.05, label="win_size")
158
+ event_win_step = gr.Number(value=2.0, minimum=0, maximum=5, step=0.05, label="win_step")
159
+ event_max_duration = gr.Number(value=8, minimum=0, maximum=15, step=1, label="max_duration")
160
+ event_button = gr.Button("run", variant="primary")
161
  with gr.Column(scale=3):
162
+ event_outputs = gr.Textbox(label="outputs")
163
+
164
+ event_model_name.change(
165
+ partial(when_model_name_change, event_trained_model_dir=event_trained_model_dir),
166
+ inputs=[event_model_name],
167
+ outputs=[event_label],
168
+ )
169
 
170
+ event_button.click(
171
+ when_click_event_button,
172
+ inputs=[event_audio, event_model_name, event_label, event_win_size, event_win_step, event_max_duration],
173
+ outputs=[event_outputs],
174
  )
175
 
176
  return locals()