update
- .gitignore +1 -0
- examples/online_model_test/step_1_predict.py +216 -0
- examples/online_model_test/step_2_audio_filter.py +43 -0
- examples/online_model_test/step_3_make_test.py +74 -0
- main.py +26 -35
- requirements.txt +1 -0
- tabs/{split_tabs.py → split_tab.py} +0 -0
- tabs/voicemail_tab.py +149 -0
.gitignore
CHANGED
@@ -17,3 +17,4 @@
 
 #**/*.wav
 **/*.xlsx
+**/*.onnx
examples/online_model_test/step_1_predict.py
ADDED
@@ -0,0 +1,216 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import glob
import json
import os
from pathlib import Path
from tqdm import tqdm

import librosa
import numpy as np
import onnxruntime as ort
import pandas as pd
import torch
import torchaudio


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--audio_dir",
        default=r"D:\Users\tianx\HuggingDatasets\international_voice\data\sea-idn\audio_lib_hkg_1\audio_lib_hkg_1\zh-TW",
        type=str,
    )
    parser.add_argument("--onnx_model_file", default="zh-TW.onnx", type=str)
    parser.add_argument("--target_duration", default=8.0, type=float)

    parser.add_argument("--output_file", default="zh_tw_predict.xlsx", type=str)

    args = parser.parse_args()
    return args


class OnlineModelConfig(object):
    def __init__(self,
                 sample_rate: int = 8000,
                 n_fft: int = 1024,
                 hop_size: int = 512,
                 n_mels: int = 80,
                 f_min: float = 10.0,
                 f_max: float = 3800.0,
                 ):
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.n_mels = n_mels
        self.f_min = f_min
        self.f_max = f_max


class OnlineModelInference(object):
    def __init__(self,
                 model_path: str,
                 ):
        self.model_path = model_path

        providers = [
            "CUDAExecutionProvider", "CPUExecutionProvider"
        ] if torch.cuda.is_available() else [
            "CPUExecutionProvider"
        ]
        self.session = ort.InferenceSession(self.model_path, providers=providers)

        self.config = OnlineModelConfig()

        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.config.sample_rate,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_size,
            n_mels=self.config.n_mels,
            f_min=self.config.f_min,
            f_max=self.config.f_max,
            window_fn=torch.hamming_window
        )

    def predict_by_ndarray(self,
                           sub_signal: np.ndarray,
                           h: np.ndarray = None,
                           c: np.ndarray = None,
                           ):
        # sub_signal, shape: [num_samples,]
        sub_signal = torch.tensor(sub_signal, dtype=torch.float32)

        sub_signal = sub_signal.unsqueeze(0)
        # sub_signal, shape: [1, num_samples]
        mel_spec = self.mel_transform.forward(sub_signal)
        # mel_spec, shape: [1, n_mels, n_frames]
        mel_spec = torch.transpose(mel_spec, dim0=1, dim1=2)
        # mel_spec, shape: [1, n_frames, n_mels]

        h = torch.tensor(h) if h is not None else None
        c = torch.tensor(c) if c is not None else None
        label, prob, h, c = self.predict_by_mel_spec(mel_spec, h=h, c=c)
        # h, c: torch.Tensor
        h = h.numpy()
        c = c.numpy()
        return label, prob, h, c

    def predict_by_mel_spec(self,
                            mel_spec: torch.Tensor,
                            h: torch.Tensor = None,
                            c: torch.Tensor = None,
                            ):
        # mel_spec, shape: [1, n_frames, n_mels]

        if h is None:
            h = np.zeros((3, 1, 64), dtype=np.float32)  # 3-layer LSTM, batch size 1, hidden size 64
        else:
            h = h.numpy()
        if c is None:
            c = np.zeros((3, 1, 64), dtype=np.float32)  # 3-layer LSTM, batch size 1, hidden size 64
        else:
            c = c.numpy()

        mel_spec_np = mel_spec.numpy()
        outputs = self.session.run(
            input_feed={
                "input": mel_spec_np,
                "h": h,
                "c": c
            },
            output_names=[
                "output", "h_out", "c_out"
            ],
        )
        logits, h, c = outputs
        # logits, np.ndarray, shape: [b, num_labels]
        # h, c: np.ndarray
        h = torch.tensor(h)
        c = torch.tensor(c)

        probs = torch.softmax(torch.tensor(logits), dim=1)
        max_prob, predicted_label_index = torch.max(probs, dim=1)

        label = self.get_label_by_index(predicted_label_index.item())
        prob = max_prob.item()
        return label, prob, h, c

    @staticmethod
    def get_label_by_index(index: int):
        label_map = {
            0: "voice",
            1: "voicemail",
            2: "mute",
            3: "noise"
        }
        result = label_map[index]
        return result


def main():
    args = get_args()

    audio_dir = Path(args.audio_dir)

    model = OnlineModelInference(model_path=args.onnx_model_file)

    result = list()
    for filename in tqdm(audio_dir.glob("**/active_media_r_*.wav")):
        splits = filename.stem.split("_")
        call_id = splits[3]
        language = splits[4]
        scene_id = splits[5]

        signal, sample_rate = librosa.load(filename.as_posix(), sr=8000)
        duration = librosa.get_duration(y=signal, sr=sample_rate)
        signal_length = len(signal)
        if signal_length == 0:
            continue

        target_duration = args.target_duration * sample_rate
        target_duration = int(target_duration)

        predict_result = list()
        h = None
        c = None
        for begin in range(0, target_duration, sample_rate*2):
            end = begin + sample_rate*2
            sub_signal = signal[begin: end]
            if len(sub_signal) == 0:
                break
            label, prob, h, c = model.predict_by_ndarray(sub_signal, h=h, c=c)
            predict_result.append({
                "label": label,
                "prob": prob,
            })
        label_list = [p["label"] for p in predict_result]
        predict_result_ = json.dumps(predict_result, ensure_ascii=False, indent=4)
        label2 = predict_result[0]["label"]
        prob2 = predict_result[0]["prob"]

        ground_truth_ = "voicemail" if any([l == "voicemail" for l in label_list]) else "else"
        flag = 1 if label2 == "voicemail" else 0

        row = {
            "call_id": call_id,
            "language": language,
            "scene_id": scene_id,
            "filename": filename.as_posix(),
            "duration": duration,
            "predict_result": predict_result_,
            "label2": label2,
            "prob2": prob2,
            "ground_truth_": ground_truth_,
            "flag": flag,
        }
        result.append(row)

    result = pd.DataFrame(result)
    result.to_excel(args.output_file, index=False)

    return


if __name__ == "__main__":
    main()
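For reference, a minimal usage sketch of the OnlineModelInference class above. The model and wav paths are placeholders; it assumes an exported zh-TW.onnx with the input/output names the script expects and an 8 kHz mono recording, and it carries the LSTM state (h, c) across consecutive 2-second windows the same way main() does:

# Usage sketch; run from examples/online_model_test/, paths are placeholders.
import librosa
from step_1_predict import OnlineModelInference

model = OnlineModelInference(model_path="zh-TW.onnx")
signal, sample_rate = librosa.load("active_media_r_example.wav", sr=8000)

h, c = None, None
for begin in range(0, len(signal), sample_rate * 2):
    sub_signal = signal[begin: begin + sample_rate * 2]
    if len(sub_signal) == 0:
        break
    # h and c come back as numpy arrays and are fed into the next window.
    label, prob, h, c = model.predict_by_ndarray(sub_signal, h=h, c=c)
    print(label, round(prob, 4))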
examples/online_model_test/step_2_audio_filter.py
ADDED
@@ -0,0 +1,43 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from pathlib import Path
import shutil

import pandas as pd


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--predict_file", default="zh_tw_predict.xlsx", type=str)
    parser.add_argument(
        "--output_dir",
        default=r"D:\Users\tianx\HuggingDatasets\international_voice\data\calling\886",
        type=str,
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    output_dir = Path(args.output_dir)

    df = pd.read_excel(args.predict_file)
    for i, row in df.iterrows():
        filename = row["filename"]
        ground_truth_ = row["ground_truth_"]

        if ground_truth_ == "voicemail":
            shutil.copy(
                filename,
                output_dir.as_posix()
            )

    return


if __name__ == "__main__":
    main()
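A quick sanity check one might run before the copy step, a sketch that only assumes the ground_truth_ column written by step_1_predict.py:

import pandas as pd

df = pd.read_excel("zh_tw_predict.xlsx")
# Rows labelled "voicemail" are the ones step_2 copies; the rest are skipped.
print(df["ground_truth_"].value_counts())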
examples/online_model_test/step_3_make_test.py
ADDED
@@ -0,0 +1,74 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from pathlib import Path
import shutil

from gradio_client import Client, handle_file
import librosa
import pandas as pd
from tqdm import tqdm


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--src_dir",
        default=r"D:\Users\tianx\HuggingDatasets\international_voice\data\calling\886",
        type=str,
    )
    parser.add_argument(
        "--tgt_dir",
        default=r"D:\Users\tianx\HuggingDatasets\international_voice\data\voice_test_examples\886\96",
        type=str,
    )
    parser.add_argument(
        "--early_media_file",
        default=r"D:\Users\tianx\HuggingDatasets\international_voice\data\voice_test_examples\886\97\early_media_ba95fafd-8e2f-488f-8e5a-9bada95e24fb.wav",
        type=str,
    )
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    src_dir = Path(args.src_dir)
    tgt_dir = Path(args.tgt_dir)

    client = Client("http://10.75.27.247:7861/")

    for filename in tqdm(src_dir.glob("*.wav")):
        splits = filename.stem.split("_")
        call_id = splits[3]

        filename_ = filename.as_posix()
        y, sr = librosa.load(filename_)
        duration = librosa.get_duration(y=y, sr=sr)
        if duration < 20:
            filename_, _ = client.predict(
                audio_t=handle_file(filename_),
                pad_seconds=20,
                pad_mode="repeat",
                api_name="/when_click_pad_audio"
            )

        active_media_file = tgt_dir / f"active_media_{call_id}.wav"
        early_media_file = tgt_dir / f"early_media_{call_id}.wav"

        shutil.copy(
            filename_,
            active_media_file.as_posix(),
        )
        shutil.copy(
            args.early_media_file,
            early_media_file.as_posix(),
        )

    return


if __name__ == "__main__":
    main()
main.py
CHANGED
@@ -1,5 +1,25 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
+"""
+docker build -t cc_audio_8:v20250828_1343 .
+docker stop cc_audio_8_7864 && docker rm cc_audio_8_7864
+docker run -itd \
+--name cc_audio_8_7864 \
+--restart=always \
+--network host \
+-e server_port=7865 \
+cc_audio_8:v20250828_1343 /bin/bash
+
+docker run -itd \
+--name cc_audio_8_7864 \
+--network host \
+--gpus all \
+--privileged \
+--ipc=host \
+python:3.12 /bin/bash
+
+nohup python3 main.py --server_port 7864 --hf_token hf_coRVvzwA****jLmZHwJobEX &
+"""
 import argparse
 from functools import lru_cache
 from pathlib import Path
@@ -17,11 +37,11 @@ import torch
 from project_settings import environment, project_path
 from toolbox.torch.utils.data.vocabulary import Vocabulary
 from tabs.cls_tab import get_cls_tab
-from tabs.
+from tabs.split_tab import get_split_tab
+from tabs.voicemail_tab import get_voicemail_tab
 from tabs.shell_tab import get_shell_tab
 
 
-
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -83,39 +103,6 @@ def load_model(model_file: Path):
     return d
 
 
-def click_button(audio: np.ndarray,
-                 model_name: str,
-                 ground_true: str) -> Tuple[str, float]:
-
-    sample_rate, signal = audio
-
-    model_file = "trained_models/{}.zip".format(model_name)
-    model_file = Path(model_file)
-    d = load_model(model_file)
-
-    model = d["model"]
-    vocabulary = d["vocabulary"]
-
-    inputs = signal / (1 << 15)
-    inputs = torch.tensor(inputs, dtype=torch.float32)
-    inputs = torch.unsqueeze(inputs, dim=0)
-
-    with torch.no_grad():
-        logits = model.forward(inputs)
-        probs = torch.nn.functional.softmax(logits, dim=-1)
-        label_idx = torch.argmax(probs, dim=-1)
-
-        label_idx = label_idx.cpu()
-        probs = probs.cpu()
-
-        label_idx = label_idx.numpy()[0]
-        prob = probs.numpy()[0][label_idx]
-
-    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
-
-    return label_str, round(prob, 4)
-
-
 def main():
     args = get_args()
 
@@ -148,6 +135,10 @@ def main():
             examples_dir=args.examples_dir,
             trained_model_dir=args.trained_model_dir,
         )
+        _ = get_voicemail_tab(
+            examples_dir=args.examples_dir,
+            trained_model_dir=args.trained_model_dir,
+        )
         _ = get_split_tab(
             examples_dir=args.examples_dir,
             trained_model_dir=args.trained_model_dir,
requirements.txt
CHANGED
@@ -12,3 +12,4 @@ evaluate
 gradio
 python-dotenv
 numpy
+onnxruntime
tabs/{split_tabs.py → split_tab.py}
RENAMED
File without changes
tabs/voicemail_tab.py
ADDED
@@ -0,0 +1,149 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import json
from functools import lru_cache
from pathlib import Path
import shutil
import tempfile
import zipfile
from typing import Tuple

import gradio as gr
import torch

from project_settings import project_path
from toolbox.torch.utils.data.vocabulary import Vocabulary


@lru_cache(maxsize=100)
def load_model(model_file: Path):
    with zipfile.ZipFile(model_file, "r") as f_zip:
        out_root = Path(tempfile.gettempdir()) / "cc_audio_8"
        if out_root.exists():
            shutil.rmtree(out_root.as_posix())
        out_root.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=out_root)

    tgt_path = out_root / model_file.stem
    jit_model_file = tgt_path / "trace_model.zip"
    vocab_path = tgt_path / "vocabulary"

    vocabulary = Vocabulary.from_files(vocab_path.as_posix())

    with open(jit_model_file.as_posix(), "rb") as f:
        model = torch.jit.load(f)
        model.eval()

    shutil.rmtree(tgt_path)

    d = {
        "model": model,
        "vocabulary": vocabulary
    }
    return d


def when_click_voicemail_button(audio_t,
                                model_name: str,
                                ground_true: str) -> Tuple[str, float]:

    sample_rate, signal = audio_t

    model_file = project_path / f"trained_models/{model_name}.zip"
    d = load_model(model_file)

    model = d["model"]
    vocabulary = d["vocabulary"]

    inputs = signal / (1 << 15)
    inputs = torch.tensor(inputs, dtype=torch.float32)
    inputs = torch.unsqueeze(inputs, dim=0)

    num_samples = inputs.shape[-1]

    outputs = list()
    with torch.no_grad():
        for begin in range(0, num_samples, sample_rate*2):
            end = begin + int(sample_rate*2)
            sub_inputs = inputs[:, begin:end]
            if sub_inputs.shape[-1] < sample_rate:
                raise AssertionError(f"audio duration less than: {sample_rate}")

            logits = model.forward(sub_inputs)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            label_idx = torch.argmax(probs, dim=-1)

            label_idx = label_idx.cpu()
            probs = probs.cpu()

            label_idx = label_idx.numpy()[0]
            prob = probs.numpy()[0][label_idx]
            prob: float = round(float(prob), 4)

            label_str: str = vocabulary.get_token_from_index(label_idx, namespace="labels")

            outputs.append({
                "label": label_str,
                "prob": prob,
            })
    outputs = json.dumps(outputs, ensure_ascii=False, indent=4)
    return outputs


def get_voicemail_tab(examples_dir: str, trained_model_dir: str):
    voicemail_examples_dir = Path(examples_dir)
    voicemail_trained_model_dir = Path(trained_model_dir)

    # models
    voicemail_model_choices = list()
    for filename in voicemail_trained_model_dir.glob("*.zip"):
        model_name = filename.stem
        if model_name == "examples":
            continue
        voicemail_model_choices.append(model_name)
    model_choices = list(sorted(voicemail_model_choices))

    # examples zip
    voicemail_example_zip_file = voicemail_trained_model_dir / "examples.zip"
    with zipfile.ZipFile(voicemail_example_zip_file.as_posix(), "r") as f_zip:
        out_root = voicemail_examples_dir
        if out_root.exists():
            shutil.rmtree(out_root.as_posix())
        out_root.mkdir(parents=True, exist_ok=True)
        f_zip.extractall(path=out_root)

    # examples
    voicemail_examples = list()
    for filename in voicemail_examples_dir.glob("**/*/*.wav"):
        label = filename.parts[-2]
        voicemail_examples.append([
            filename.as_posix(),
            model_choices[0],
            label
        ])

    with gr.TabItem("voicemail"):
        with gr.Row():
            with gr.Column(scale=3):
                voicemail_audio = gr.Audio(label="audio")
                with gr.Row():
                    with gr.Column(scale=3):
                        voicemail_model_name = gr.Dropdown(choices=model_choices, value=model_choices[0], label="model_name")
                    with gr.Column(scale=3):
                        voicemail_ground_true = gr.Textbox(label="ground_true")

                voicemail_button = gr.Button("run", variant="primary")
            with gr.Column(scale=3):
                voicemail_outputs = gr.Textbox(label="outputs")

        voicemail_button.click(
            when_click_voicemail_button,
            inputs=[voicemail_audio, voicemail_model_name, voicemail_ground_true],
            outputs=[voicemail_outputs],
        )

    return locals()


if __name__ == "__main__":
    pass
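A small sketch of calling when_click_voicemail_button outside Gradio for a quick local check. The wav path and model name are placeholders; it assumes a matching trained_models/<model_name>.zip exists, reads int16 samples so the 1 << 15 scaling above applies, and expects a recording long enough that no trailing window falls under one second:

import numpy as np
from scipy.io import wavfile

from tabs.voicemail_tab import when_click_voicemail_button

# int16 samples, matching what gr.Audio passes to the click handler.
sample_rate, signal = wavfile.read("example_voicemail.wav")
outputs = when_click_voicemail_button(
    audio_t=(sample_rate, np.asarray(signal)),
    model_name="voicemail-cls",  # placeholder; must match a trained_models/<name>.zip
    ground_true="voicemail",
)
print(outputs)  # JSON list of {"label", "prob"} per 2-second window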