#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Gradio demo for telephone sound classification with TorchScript models.

Each ``*.zip`` in ``--trained_model_dir`` is a model archive expected to hold
``<archive stem>/trace_model.zip`` (loaded with ``torch.jit.load``) and
``<archive stem>/vocabulary`` (loaded with ``Vocabulary.from_files``).
"""
import argparse
from functools import lru_cache
import json
from pathlib import Path
import platform
import shutil
import tempfile
import zipfile

import gradio as gr
import numpy as np
import torch

from project_settings import environment, project_path
from toolbox.torch.utils.data.vocabulary import Vocabulary


def get_args():
    """Parse command-line options for the demo server."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--examples_dir",
        default=(project_path / "data/examples").as_posix(),
        type=str
    )
    parser.add_argument(
        "--trained_model_dir",
        default=(project_path / "trained_models").as_posix(),
        type=str
    )
    parser.add_argument(
        "--server_port",
        default=environment.get("server_port", 7860),
        type=int
    )
    args = parser.parse_args()
    return args


@lru_cache(maxsize=100)
def load_model(model_file: Path):
    """Load a TorchScript model and its label vocabulary from a model archive.

    Results are memoized per ``model_file`` path, so each archive is unpacked
    at most once per process (up to 100 distinct archives).

    :param model_file: path to a ``.zip`` archive containing
        ``<stem>/trace_model.zip`` and ``<stem>/vocabulary``.
    :return: dict with keys ``"model"`` (TorchScript module in eval mode) and
        ``"vocabulary"`` (``Vocabulary``).
    """
    # Use a private, auto-cleaned directory per load. The previous fixed
    # directory was rmtree'd on every call, which could delete files out from
    # under a concurrent load and left extracted data behind.
    with tempfile.TemporaryDirectory(prefix="vm_sound_classification_") as tmp_dir:
        out_root = Path(tmp_dir)
        with zipfile.ZipFile(model_file, "r") as f_zip:
            f_zip.extractall(path=out_root)

        tgt_path = out_root / model_file.stem
        jit_model_file = tgt_path / "trace_model.zip"
        vocab_path = tgt_path / "vocabulary"

        # Both loads must finish before the temporary directory is removed.
        vocabulary = Vocabulary.from_files(vocab_path.as_posix())
        with open(jit_model_file.as_posix(), "rb") as f:
            model = torch.jit.load(f)
        model.eval()

    return {
        "model": model,
        "vocabulary": vocabulary,
    }


def _to_model_input(signal: np.ndarray) -> torch.Tensor:
    """Convert a raw audio array into a ``(1, num_samples)`` float32 tensor."""
    signal = np.asarray(signal)
    inputs = signal.astype(np.float32)
    # Integer PCM (e.g. int16) is scaled to roughly [-1, 1) as before. Float
    # audio is assumed to be normalized already and is left unscaled.
    if np.issubdtype(signal.dtype, np.integer):
        inputs = inputs / (1 << 15)
    # Multi-channel audio: assumes layout (num_samples, num_channels), as
    # gradio's numpy audio uses. Downmix to mono -- TODO confirm for the model.
    if inputs.ndim == 2:
        inputs = inputs.mean(axis=1)
    inputs = torch.tensor(inputs, dtype=torch.float32)
    return torch.unsqueeze(inputs, dim=0)


def click_button(audio: tuple, model_name: str, ground_true: str,
                 *, trained_model_dir=None) -> tuple:
    """Classify one audio clip with the selected model.

    :param audio: ``(sample_rate, signal)`` tuple as produced by ``gr.Audio``.
        ``sample_rate`` is not used; the clip is assumed to already be at the
        rate the model was trained on (TODO confirm).
    :param model_name: stem of the model archive to load.
    :param ground_true: reference label shown in the UI; not used for inference.
    :param trained_model_dir: directory that holds ``<model_name>.zip``.
        Defaults to the cwd-relative ``trained_models`` directory.
    :return: ``(label, probability)`` with the probability rounded to 4 places.
    :raises ValueError: if no audio was provided.
    """
    if audio is None:
        raise ValueError("no audio provided; please upload or record a clip first.")
    _sample_rate, signal = audio

    if trained_model_dir is None:
        trained_model_dir = "trained_models"
    model_file = Path(trained_model_dir) / "{}.zip".format(model_name)

    d = load_model(model_file)
    model = d["model"]
    vocabulary = d["vocabulary"]

    inputs = _to_model_input(signal)

    with torch.no_grad():
        logits = model(inputs)
        probs = torch.nn.functional.softmax(logits, dim=-1)

    probs = probs.cpu().numpy()[0]
    label_idx = int(np.argmax(probs))
    prob = float(probs[label_idx])

    label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
    return label_str, round(prob, 4)


def main():
    """Build the Gradio UI and start the demo server."""
    args = get_args()

    examples_dir = Path(args.examples_dir)
    trained_model_dir = Path(args.trained_model_dir)

    # models: one choice per archive; sorted so the default choice is stable
    model_choices = sorted(filename.stem for filename in trained_model_dir.glob("*.zip"))
    if not model_choices:
        raise SystemExit("no model archive (*.zip) found in {}".format(trained_model_dir.as_posix()))

    # examples: <examples_dir>/<any>/<label>/<file>.wav -> label is the parent dir name
    examples = [
        [filename.as_posix(), model_choices[0], filename.parts[-2]]
        for filename in sorted(examples_dir.glob("*/*/*.wav"))
    ]

    def predict(audio, model_name, ground_true):
        # Bind --trained_model_dir so the UI loads models from the configured location.
        return click_button(
            audio, model_name, ground_true,
            trained_model_dir=trained_model_dir.as_posix(),
        )

    # ui header text (Chinese): "International voice intelligent outbound-call
    # system, telephone sound classification."
    brief_description = """
    国际语音智能外呼系统, 电话声音分类.
    """

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)

        with gr.Row():
            with gr.Column(scale=3):
                c_audio = gr.Audio(label="audio")
                with gr.Row():
                    with gr.Column(scale=3):
                        c_model_name = gr.Dropdown(choices=model_choices, value=model_choices[0], label="model_name")
                    with gr.Column(scale=3):
                        c_ground_true = gr.Textbox(label="ground_true")

                c_button = gr.Button("run", variant="primary")
            with gr.Column(scale=3):
                c_label = gr.Textbox(label="label")
                c_probability = gr.Number(label="probability")

        gr.Examples(
            examples,
            inputs=[c_audio, c_model_name, c_ground_true],
            outputs=[c_label, c_probability],
            fn=predict,
            examples_per_page=5,
        )

        c_button.click(
            predict,
            inputs=[c_audio, c_model_name, c_ground_true],
            outputs=[c_label, c_probability],
        )

    is_windows = platform.system() == "Windows"
    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if is_windows else "0.0.0.0",
        server_port=args.server_port
    )


if __name__ == "__main__":
    main()