# Source: voicemail/main.py — Hugging Face Space by HoneyTian (commit 095d784, "update")
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
from pathlib import Path
import platform
import shutil
import tempfile
import zipfile
import gradio as gr
import numpy as np
import torch
from project_settings import environment, project_path
from toolbox.torch.utils.data.vocabulary import Vocabulary
def get_args():
    """Parse command-line options for the demo: example audio directory,
    directory of zipped trained models, and the server port."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--examples_dir",
        default=(project_path / "data/examples").as_posix(),
        type=str,
    )
    arg_parser.add_argument(
        "--trained_model_dir",
        default=(project_path / "trained_models").as_posix(),
        type=str,
    )
    arg_parser.add_argument(
        "--server_port",
        default=environment.get("server_port", 7860),
        type=int,
    )
    return arg_parser.parse_args()
def load_model(zip_file: Path) -> dict:
    """Load a TorchScript model and its vocabulary from a zip archive.

    The archive is expected to contain a top-level directory named after the
    zip stem, holding a "cnn_voicemail.pth" TorchScript file and a
    "vocabulary" directory readable by ``Vocabulary.from_files``.

    :param zip_file: path to the model archive.
    :return: dict with keys "model" (torch.jit.ScriptModule) and
        "vocabulary" (Vocabulary).
    """
    model_name = zip_file.stem
    # Extract into a unique temporary directory so that concurrent or
    # repeated loads cannot clash with each other, and so that the extracted
    # files are cleaned up even if model/vocabulary loading raises.
    # (The original extracted into a shared /tmp/cnn_voicemail directory and
    # only removed it on the success path.)
    with tempfile.TemporaryDirectory(prefix="cnn_voicemail_") as tmp_dir:
        with zipfile.ZipFile(zip_file, "r") as f_zip:
            f_zip.extractall(path=tmp_dir)
        tgt_path = Path(tmp_dir) / model_name
        pth_path = tgt_path / "cnn_voicemail.pth"
        vocab_path = tgt_path / "vocabulary"
        with open(pth_path.as_posix(), "rb") as f:
            # Load on CPU; the demo does inference only.
            model = torch.jit.load(f, map_location="cpu")
        vocabulary = Vocabulary.from_files(vocab_path.as_posix())
    return {
        "model": model,
        "vocabulary": vocabulary,
    }
def main():
    """Build and launch the Gradio demo for CNN-based voicemail detection."""
    args = get_args()

    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    # Example rows: directory layout is <examples_dir>/<language>/<label>/<file>.wav
    examples = list()
    for filename in examples_dir.glob("*/*/*.wav"):
        language = filename.parts[-3]
        label = filename.parts[-2]
        examples.append([
            filename.as_posix(),
            language,
            label
        ])

    # Load one model per language. The zip stem encodes the language:
    # 4 underscore-separated parts -> plain language code (e.g. "en"),
    # otherwise a region-qualified code (e.g. "zh-TW") built from the last
    # two meaningful parts.
    language_to_model = dict()
    for filename in sorted(trained_model_dir.glob("*.zip")):
        splits = filename.stem.split("_")
        if len(splits) == 4:
            language = splits[-2]
        else:
            language = "{}-{}".format(splits[-3], splits[-2].upper())

        d = load_model(filename)
        language_to_model[language] = d

    # Click event.
    def click_button(audio: np.ndarray,
                     language: str,
                     ground_true: str) -> tuple:
        """Classify one audio clip with the model of the selected language.

        ``ground_true`` is accepted (Gradio passes every input component)
        but not used for prediction. Returns (label, probability).
        """
        sample_rate, signal = audio

        d = language_to_model[language]
        model = d["model"]
        vocabulary = d["vocabulary"]

        # Normalize int16 PCM samples to [-1, 1).
        # NOTE(review): assumes 16-bit input audio — confirm Gradio config.
        inputs = signal / (1 << 15)
        inputs = torch.tensor(inputs, dtype=torch.float32)
        inputs = torch.unsqueeze(inputs, dim=0)

        outputs = model(inputs)
        probs = outputs["probs"]
        argmax = torch.argmax(probs, dim=-1)
        probs = probs.tolist()[0]
        argmax = argmax.tolist()[0]

        label = vocabulary.get_token_from_index(argmax, namespace="labels")
        prob = probs[argmax]
        return label, round(prob, 4)

    # ui
    brief_description = """
## 语音信箱识别

基于 CNN 的语音信箱音频分类.

考虑到语音信箱的音频是比较固定的, 所以采用了基于 CNN 的方法, 以建模上下文依赖关系.
"""

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)
        with gr.Row():
            with gr.Column(scale=3):
                c_audio = gr.Audio(label="audio")
                with gr.Row():
                    with gr.Column(scale=3):
                        # Dropdown expects a list of choices, not a dict_keys view.
                        c_language = gr.Dropdown(choices=list(language_to_model.keys()), label="language")
                    with gr.Column(scale=3):
                        c_ground_true = gr.Textbox(label="ground_true")
                c_button = gr.Button("run", variant="primary")
            with gr.Column(scale=3):
                c_label = gr.Textbox(label="label")
                c_probability = gr.Number(label="probability")
        gr.Examples(
            examples,
            inputs=[c_audio, c_language, c_ground_true],
            outputs=[c_label, c_probability],
            fn=click_button,
            examples_per_page=5,
        )
        c_button.click(
            click_button,
            inputs=[c_audio, c_language, c_ground_true],
            outputs=[c_label, c_probability],
        )

    blocks.queue().launch(
        # Sharing is disabled on every platform (the original conditional
        # evaluated to False on both branches).
        share=False,
        # Bind only to localhost on Windows dev machines; expose on all
        # interfaces when deployed (e.g. inside a container).
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port
    )
    return
# Script entry point: build the UI and start the server.
if __name__ == "__main__":
    main()