HoneyTian's picture
update
1204717
raw
history blame
No virus
4.55 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from functools import lru_cache
import json
from pathlib import Path
import platform
import shutil
import tempfile
import zipfile
import gradio as gr
import numpy as np
import torch
from project_settings import environment, project_path
from toolbox.torch.utils.data.vocabulary import Vocabulary
def get_args():
    """Parse the command-line options of the demo.

    Options:
        --examples_dir: directory scanned for example ``*.wav`` files.
        --trained_model_dir: directory scanned for ``*.zip`` model archives.
        --server_port: port the Gradio server listens on; defaults to the
            ``server_port`` entry of ``environment`` or 7860.

    Returns:
        argparse.Namespace holding the three options above.
    """
    cli = argparse.ArgumentParser()
    # (flag, default, type) for every option, registered in this order.
    option_specs = (
        ("--examples_dir", (project_path / "data/examples").as_posix(), str),
        ("--trained_model_dir", (project_path / "trained_models").as_posix(), str),
        ("--server_port", environment.get("server_port", 7860), int),
    )
    for flag, default_value, value_type in option_specs:
        cli.add_argument(flag, default=default_value, type=value_type)
    return cli.parse_args()
@lru_cache(maxsize=100)
def load_model(model_file: Path):
    """Load a packaged TorchScript classifier together with its vocabulary.

    The archive must contain a top-level directory named after the archive's
    stem, holding ``trace_model.zip`` (loaded with ``torch.jit.load``) and a
    ``vocabulary`` directory (loaded with ``Vocabulary.from_files``).

    Results are memoised per argument by ``lru_cache`` (up to 100 entries),
    so each archive is unpacked at most once per process.

    Args:
        model_file: path to the model ``.zip`` archive (``str`` also accepted).

    Returns:
        dict with ``"model"`` (the TorchScript module, switched to eval mode)
        and ``"vocabulary"`` (the loaded ``Vocabulary``).
    """
    model_file = Path(model_file)

    # Unpack into a private directory that is always removed on exit, even if
    # loading raises. Concurrent loads therefore never delete each other's
    # files, and nothing is left behind in the system temp dir.
    with tempfile.TemporaryDirectory(prefix="vm_sound_classification_") as tmp_dir:
        with zipfile.ZipFile(model_file, "r") as f_zip:
            f_zip.extractall(path=tmp_dir)

        tgt_path = Path(tmp_dir) / model_file.stem
        jit_model_file = tgt_path / "trace_model.zip"
        vocab_path = tgt_path / "vocabulary"

        # Both artefacts are read fully here, before the directory is removed.
        vocabulary = Vocabulary.from_files(vocab_path.as_posix())
        with open(jit_model_file.as_posix(), "rb") as f:
            model = torch.jit.load(f)

    model.eval()
    return {
        "model": model,
        "vocabulary": vocabulary,
    }
def click_button(audio: tuple,
                 model_name: str,
                 ground_true: str) -> tuple:
    """Classify one audio clip with the selected model (Gradio callback).

    Args:
        audio: ``(sample_rate, signal)`` pair produced by ``gr.Audio``.
            ``signal`` is a numpy array of shape ``(samples,)`` or
            ``(samples, channels)``. The value is ``None`` when no clip has
            been given.
        model_name: stem of a ``trained_models/<model_name>.zip`` archive.
        ground_true: label typed by the user for comparison in the UI; it is
            not used for prediction.

    Returns:
        ``(label, probability)``: the arg-max token from the vocabulary's
        ``"labels"`` namespace and its softmax probability rounded to 4
        decimals.

    Raises:
        gr.Error: if no audio was provided or ``model_name`` is not a bare
            file stem.
    """
    if audio is None:
        raise gr.Error("no audio provided.")
    # NOTE(review): the sample rate is ignored; presumably clips must already
    # be at the rate the model was trained on -- confirm.
    _, signal = audio
    signal = np.asarray(signal)

    # model_name is client-controlled (any value can be sent through the
    # Gradio API), so refuse anything that is not a plain file stem.
    if not model_name or Path(model_name).name != model_name:
        raise gr.Error("invalid model name: {}".format(model_name))

    # NOTE(review): this path is relative to the working directory and does
    # not follow --trained_model_dir; confirm the two agree when deploying.
    d = load_model(Path("trained_models/{}.zip".format(model_name)))
    model = d["model"]
    vocabulary = d["vocabulary"]

    if np.issubdtype(signal.dtype, np.floating):
        # Floating-point audio is passed through unscaled (assumed to be
        # already in [-1, 1]).
        inputs = signal.astype(np.float64)
    else:
        # Map integer PCM onto [-1, 1). For int16 the offset is 0 and the
        # scale is 1 << 15, i.e. the classic division by 32768.
        info = np.iinfo(signal.dtype)
        offset = (int(info.min) + int(info.max) + 1) / 2
        half_range = (int(info.max) - int(info.min) + 1) / 2
        inputs = (signal.astype(np.float64) - offset) / half_range

    if inputs.ndim == 2:
        # Multi-channel (samples, channels): average to mono so the batch
        # below has the same (1, samples) layout as mono input.
        inputs = inputs.mean(axis=1)

    batch = torch.unsqueeze(torch.tensor(inputs, dtype=torch.float32), dim=0)
    with torch.no_grad():
        logits = model.forward(batch)
        probs = torch.nn.functional.softmax(logits, dim=-1)

    probs = probs.cpu().numpy()[0]
    label_idx = int(probs.argmax())
    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
    # Return a plain Python float rather than a numpy scalar.
    return label_str, round(float(probs[label_idx]), 4)
def _list_model_choices(trained_model_dir: Path) -> list:
    """Return the sorted stems of every ``*.zip`` archive in ``trained_model_dir``."""
    return sorted(filename.stem for filename in trained_model_dir.glob("*.zip"))


def _list_examples(examples_dir: Path, default_model: str) -> list:
    """Build ``gr.Examples`` rows ``[wav_path, model_name, label]``.

    Files are matched with ``examples_dir.glob("*/*/*.wav")``. The label is
    the name of each wav file's parent directory. Rows are sorted by path.
    """
    return [
        [filename.as_posix(), default_model, filename.parts[-2]]
        for filename in sorted(examples_dir.glob("*/*/*.wav"))
    ]


def main():
    """Parse CLI options, build the Gradio UI and start the server.

    Raises:
        SystemExit: if no ``*.zip`` model archive is found, because the UI
            cannot offer a default model without one.
    """
    args = get_args()

    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    # models
    model_choices = _list_model_choices(trained_model_dir)
    if not model_choices:
        raise SystemExit(
            "no *.zip model found in {}".format(trained_model_dir.as_posix())
        )

    # examples
    examples = _list_examples(examples_dir, model_choices[0])

    brief_description = "\n国际语音智能外呼系统, 电话声音分类.\n"

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)

        with gr.Row():
            with gr.Column(scale=3):
                c_audio = gr.Audio(label="audio")
                with gr.Row():
                    with gr.Column(scale=3):
                        c_model_name = gr.Dropdown(choices=model_choices, value=model_choices[0], label="model_name")
                    with gr.Column(scale=3):
                        c_ground_true = gr.Textbox(label="ground_true")

                c_button = gr.Button("run", variant="primary")
            with gr.Column(scale=3):
                c_label = gr.Textbox(label="label")
                c_probability = gr.Number(label="probability")

        # An examples panel is only meaningful when example files exist.
        if examples:
            gr.Examples(
                examples,
                inputs=[c_audio, c_model_name, c_ground_true],
                outputs=[c_label, c_probability],
                fn=click_button,
                examples_per_page=5,
            )

        c_button.click(
            click_button,
            inputs=[c_audio, c_model_name, c_ground_true],
            outputs=[c_label, c_probability],
        )

    blocks.queue().launch(
        share=False,
        # Windows hosts bind to loopback only; elsewhere listen on all interfaces.
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port,
    )
if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) exits with
    # status 0 exactly as falling off the end of the script would.
    raise SystemExit(main())