Spaces:
Sleeping
Sleeping
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
import argparse | |
from functools import lru_cache | |
import json | |
from pathlib import Path | |
import platform | |
import shutil | |
import tempfile | |
import zipfile | |
import gradio as gr | |
import numpy as np | |
import torch | |
from project_settings import environment, project_path | |
from toolbox.torch.utils.data.vocabulary import Vocabulary | |
def get_args():
    """Parse the command-line options of the demo server.

    Returns:
        argparse.Namespace with ``examples_dir`` (str), ``trained_model_dir``
        (str) and ``server_port`` (int). Directory defaults live under
        ``project_path``; the port defaults to the ``server_port`` entry of
        ``environment`` or 7860.
    """
    default_examples_dir = (project_path / "data/examples").as_posix()
    default_model_dir = (project_path / "trained_models").as_posix()
    default_port = environment.get("server_port", 7860)

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--examples_dir", default=default_examples_dir, type=str)
    arg_parser.add_argument("--trained_model_dir", default=default_model_dir, type=str)
    arg_parser.add_argument("--server_port", default=default_port, type=int)
    return arg_parser.parse_args()
@lru_cache(maxsize=8)
def load_model(model_file: Path):
    """Unpack a packaged model archive and load its TorchScript model and vocabulary.

    The archive must contain a top-level directory named after the archive's
    stem, holding ``trace_model.zip`` (loaded with ``torch.jit.load``) and a
    ``vocabulary`` directory (loaded with ``Vocabulary.from_files``).

    Results are memoised per ``model_file`` so repeated requests for the same
    model do not re-extract and re-load the archive. The returned dict is
    therefore shared between callers and must not be mutated. A model archive
    replaced on disk is only picked up after a restart.

    Args:
        model_file: path to the ``.zip`` archive.

    Returns:
        dict with keys ``"model"`` (the loaded module, already in eval mode)
        and ``"vocabulary"``.
    """
    # A fresh, uniquely named directory per call: concurrent loads cannot
    # delete each other's files, and the whole tree is removed on exit even
    # when extraction or loading raises.
    with tempfile.TemporaryDirectory(prefix="vm_sound_classification_") as tmp_dir:
        out_root = Path(tmp_dir)
        with zipfile.ZipFile(model_file, "r") as f_zip:
            f_zip.extractall(path=out_root)

        tgt_path = out_root / model_file.stem
        jit_model_file = tgt_path / "trace_model.zip"
        vocab_path = tgt_path / "vocabulary"

        # Both objects must be fully read before the directory is deleted.
        vocabulary = Vocabulary.from_files(vocab_path.as_posix())
        with open(jit_model_file.as_posix(), "rb") as f:
            model = torch.jit.load(f)
        model.eval()

    return {
        "model": model,
        "vocabulary": vocabulary,
    }
def click_button(audio: "tuple[int, np.ndarray]",
                 model_name: str,
                 ground_true: str) -> "tuple[str, float]":
    """Classify one audio clip with the selected model.

    Args:
        audio: ``(sample_rate, signal)`` pair produced by the ``gr.Audio``
            component, or ``None`` when the user has not supplied any audio.
        model_name: stem of a ``.zip`` archive under ``trained_models/``
            (resolved relative to the current working directory).
        ground_true: expected label; displayed in the UI only, not used here.

    Returns:
        ``(label, probability)``: the arg-max label from the vocabulary's
        ``"labels"`` namespace and its softmax probability rounded to 4 places.

    Raises:
        ValueError: if ``audio`` is ``None`` or ``model_name`` is not a bare
            file name.
    """
    if audio is None:
        raise ValueError("no audio provided, please record or upload a clip first.")
    # model_name arrives from the browser; accept only a bare file stem so a
    # crafted value such as "../x" cannot reach outside the model directory.
    if not model_name or Path(model_name).name != model_name:
        raise ValueError("invalid model_name: {!r}".format(model_name))

    # NOTE(review): sample_rate is unused and no resampling is done; confirm
    # the model accepts whatever rate gr.Audio delivers.
    sample_rate, signal = audio
    signal = np.asarray(signal)
    if signal.ndim == 2:
        # gr.Audio returns multi-channel audio as (samples, channels); average
        # the channels so the model always receives a single waveform.
        signal = signal.mean(axis=1)

    # NOTE(review): dividing by 2**15 assumes int16 PCM samples -> [-1, 1).
    inputs = torch.tensor(signal / (1 << 15), dtype=torch.float32)
    inputs = torch.unsqueeze(inputs, dim=0)  # leading axis of size 1 (batch)

    # NOTE(review): this path ignores --trained_model_dir; it only matches the
    # dropdown when the app runs from the project root with the default dir.
    model_file = Path("trained_models/{}.zip".format(model_name))
    d = load_model(model_file)
    model = d["model"]
    vocabulary = d["vocabulary"]

    with torch.no_grad():
        logits = model.forward(inputs)
        probs = torch.nn.functional.softmax(logits, dim=-1)

    probs = probs.cpu().numpy()[0]
    label_idx = int(probs.argmax())
    # float(): hand Gradio a plain Python number rather than np.float32.
    prob = round(float(probs[label_idx]), 4)

    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
    return label_str, prob
def main():
    """Build the Gradio demo UI and serve it.

    Every ``*.zip`` in ``--trained_model_dir`` becomes a selectable model, and
    every ``<examples_dir>/*/<label>/*.wav`` file becomes an example whose
    expected label is its parent directory name. The server binds to
    ``127.0.0.1`` on Windows and ``0.0.0.0`` elsewhere, on ``--server_port``.

    Raises:
        SystemExit: if ``--trained_model_dir`` contains no ``*.zip`` model.
    """
    args = get_args()

    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    # models: sorted so the dropdown order and the default model do not depend
    # on filesystem enumeration order.
    model_choices = sorted(filename.stem for filename in trained_model_dir.glob("*.zip"))
    if not model_choices:
        # Fail with a clear message instead of an IndexError on model_choices[0].
        raise SystemExit("no *.zip model found in {}".format(trained_model_dir.as_posix()))
    default_model = model_choices[0]

    # examples: the parent directory name of each wav file is its label.
    examples = [
        [filename.as_posix(), default_model, filename.parts[-2]]
        for filename in sorted(examples_dir.glob("*/*/*.wav"))
    ]

    # Markdown header shown above the UI ("international intelligent voice
    # outbound-call system, phone sound classification").
    brief_description = """
国际语音智能外呼系统, 电话声音分类.
"""

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)

        with gr.Row():
            with gr.Column(scale=3):
                c_audio = gr.Audio(label="audio")
                with gr.Row():
                    with gr.Column(scale=3):
                        c_model_name = gr.Dropdown(choices=model_choices, value=default_model, label="model_name")
                    with gr.Column(scale=3):
                        c_ground_true = gr.Textbox(label="ground_true")

                c_button = gr.Button("run", variant="primary")
            with gr.Column(scale=3):
                c_label = gr.Textbox(label="label")
                c_probability = gr.Number(label="probability")

        # Only render the examples table when there is something to show.
        if examples:
            gr.Examples(
                examples,
                inputs=[c_audio, c_model_name, c_ground_true],
                outputs=[c_label, c_probability],
                fn=click_button,
                examples_per_page=5,
            )

        c_button.click(
            click_button,
            inputs=[c_audio, c_model_name, c_ground_true],
            outputs=[c_label, c_probability],
        )

    is_windows = platform.system() == "Windows"
    blocks.queue().launch(
        # Both branches of the original platform conditional were False.
        share=False,
        server_name="127.0.0.1" if is_windows else "0.0.0.0",
        server_port=args.server_port,
    )
if __name__ == "__main__":
    # Run the demo when this file is executed as a script; main() returns
    # None, so the process exits with status 0 exactly as before.
    raise SystemExit(main())