vad_go / main.py
HoneyTian's picture
update
5a74d6f
raw
history blame
6.54 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import logging
import json
from pathlib import Path
import platform
import re
from typing import Tuple
from project_settings import project_path, log_directory
import log
# NOTE(review): log.setup() runs in the middle of the import block, before
# gradio and toolbox are imported — presumably so logging is configured before
# those modules create their own loggers; confirm before reordering imports.
log.setup(log_directory=log_directory)
import gradio as gr
from toolbox.os.command import Command
# Logger the VAD runners use to record which file they are processing.
main_logger = logging.getLogger("main")
def get_args():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with a single attribute, ``example_wav_dir`` (str):
        the directory scanned for ``*.wav`` example files. It defaults to
        ``<project_path>/data/examples`` in POSIX form.
    """
    default_example_dir = (project_path / "data/examples").as_posix()

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--example_wav_dir",
        type=str,
        default=default_example_dir,
    )
    return arg_parser.parse_args()
def process_uploaded_file(
    vad_engine: str,
    filename: str,
    silence_time: float = 0.3,
    longest_activate: float = 3.0,
    speech_pad_time: float = 0.03,
    threshold: float = 0.5,
) -> Tuple[str, str]:
    """Dispatch a VAD request from the UI to the selected engine.

    Args:
        vad_engine: ``"nx_vad"`` or ``"silero_vad"``; anything else is rejected.
        filename: path of the audio file to analyse.
        silence_time: passed to both engines.
        longest_activate: passed to ``run_nx_vad`` only.
        speech_pad_time: passed to ``run_silero_vad`` only.
        threshold: passed to ``run_silero_vad`` only.

    Returns:
        The ``(vad_timestamps, raw_vad_result)`` pair produced by the engine.
        For an unknown engine, ``("vad engine invalid: <engine>", "")`` is
        returned instead of raising, so the message shows up in the UI.
    """
    if vad_engine == "silero_vad":
        return run_silero_vad(filename, silence_time, speech_pad_time, threshold)
    if vad_engine == "nx_vad":
        return run_nx_vad(filename, silence_time, longest_activate)
    return f"vad engine invalid: {vad_engine}", ""
# One line break in the nx_vad output. "\r\n" is tried first so Windows line
# endings count as a single separator; lone "\n" or "\r" are also accepted.
_NX_EOL = r"(?:\r\n|[\r\n])"

# One speech segment in the nx_vad output, one token per line:
#   <start> VadFlagPrepare <t> VadFlagSpeaking
#   (<t> VadFlagPause <t> VadFlagSpeaking)*   <- any number of pauses
#   <end> VadFlagNoSpeech
# Only the first and the last timestamps are captured.
_NX_VAD_SEGMENT_PATTERN = re.compile(
    rf"(\d+){_NX_EOL}VadFlagPrepare{_NX_EOL}\d+{_NX_EOL}VadFlagSpeaking"
    rf"(?:{_NX_EOL}\d+{_NX_EOL}VadFlagPause{_NX_EOL}\d+{_NX_EOL}VadFlagSpeaking)*"
    rf"{_NX_EOL}(\d+){_NX_EOL}VadFlagNoSpeech"
)


def parse_nx_vad_output(raw_vad_result: str) -> list:
    """Extract ``(start, end)`` integer timestamp pairs from raw nx_vad output.

    The units are whatever the binary prints (presumably milliseconds, which
    matches the silero_vad runner's seconds*1000 conversion — confirm).
    Returns an empty list when no complete segment is found.
    """
    return [
        (int(start), int(end))
        for start, end in _NX_VAD_SEGMENT_PATTERN.findall(raw_vad_result)
    ]


def run_nx_vad(filename: str, silence_time: float = 0.3, longest_activate: float = 3.0) -> Tuple[str, str]:
    """Run the ``vad_bins/nx_vad`` binary on *filename* and parse its output.

    Args:
        filename: audio file to analyse; normalised to a POSIX-style path.
        silence_time: value for the binary's ``--silence_time`` flag.
        longest_activate: value for the binary's ``--longest_activate`` flag.

    Returns:
        ``(vad_timestamps, raw_vad_result)``. ``vad_timestamps`` is a JSON array
        (indent=2) of ``[start, end]`` integer pairs; ``raw_vad_result`` is the
        binary's unmodified output.
    """
    filename = Path(filename).as_posix()
    main_logger.info("do nx vad: {}".format(filename))

    # NOTE(review): the filename is interpolated unquoted into a single command
    # string, so paths with spaces or shell metacharacters break the call. If
    # Command.popen runs this through a shell, quote it (e.g. shlex.quote) —
    # confirm how Command.popen executes commands first.
    cmd = "vad_bins/nx_vad --filename {} --silence_time {} --longest_activate {}".format(
        filename, silence_time, longest_activate
    )
    raw_vad_result = Command.popen(cmd)

    segments = parse_nx_vad_output(raw_vad_result)
    vad_timestamps = json.dumps(segments, ensure_ascii=False, indent=2)
    return vad_timestamps, raw_vad_result
# silero_vad reports a segment as "speech starts at <seconds>s" followed, after
# any intervening text (re.DOTALL lets ".*?" cross lines), by
# "speech ends at <seconds>s".
_SILERO_VAD_SEGMENT_PATTERN = re.compile(
    "speech starts at (.+?)s[\r\n].*?speech ends at (.+?)s",
    flags=re.DOTALL,
)


def _seconds_to_ms(seconds: str) -> int:
    """Convert a decimal-seconds string to the nearest whole millisecond."""
    # round() rather than int(): the float product can land just below the
    # exact value (1.001 * 1000 == 1000.999...), which int() would truncate.
    return round(float(seconds) * 1000)


def parse_silero_vad_output(raw_vad_result: str) -> list:
    """Extract ``(start_ms, end_ms)`` integer pairs from raw silero_vad output.

    Returns an empty list when no start/end pair is found.
    """
    return [
        (_seconds_to_ms(start), _seconds_to_ms(end))
        for start, end in _SILERO_VAD_SEGMENT_PATTERN.findall(raw_vad_result)
    ]


def run_silero_vad(filename: str,
                   silence_time: float = 0.3,
                   speech_pad_time: float = 0.03,
                   threshold: float = 0.5
                   ) -> Tuple[str, str]:
    """Run the ``vad_bins/silero_vad`` binary on *filename* and parse its output.

    Args:
        filename: audio file to analyse; normalised to a POSIX-style path.
        silence_time: value for the binary's ``--silence_time`` flag.
        speech_pad_time: value for the binary's ``--speech_pad_time`` flag.
        threshold: value for the binary's ``--threshold`` flag.

    Returns:
        ``(vad_timestamps, raw_vad_result)``. ``vad_timestamps`` is a JSON array
        (indent=2) of ``[start_ms, end_ms]`` integer pairs; the binary reports
        seconds, which are converted here. ``raw_vad_result`` is the binary's
        unmodified output.
    """
    filename = Path(filename).as_posix()
    main_logger.info("do silero vad: {}".format(filename))

    # NOTE(review): the filename is interpolated unquoted into a single command
    # string, so paths with spaces or shell metacharacters break the call. If
    # Command.popen runs this through a shell, quote it (e.g. shlex.quote) —
    # confirm how Command.popen executes commands first.
    cmd = "vad_bins/silero_vad --filename {} --silence_time {} --speech_pad_time {} --threshold {}".format(
        filename, silence_time, speech_pad_time, threshold
    )
    raw_vad_result = Command.popen(cmd)

    segments = parse_silero_vad_output(raw_vad_result)
    vad_timestamps = json.dumps(segments, ensure_ascii=False, indent=2)
    return vad_timestamps, raw_vad_result
def shell(cmd: str):
    """Pass *cmd* unchanged to ``Command.popen`` and return whatever it returns.

    This is the handler behind the UI's "shell" tab. NOTE(review): *cmd* is
    taken straight from a textbox with no validation — presumably
    Command.popen executes it on the host, which makes this arbitrary command
    execution; confirm and restrict if the app is reachable by others.
    """
    output = Command.popen(cmd)
    return output
def build_examples(example_wav_dir: Path) -> list:
    """Build the gr.Examples rows, one per ``*.wav`` file in *example_wav_dir*.

    Each row supplies a value for every input of the upload tab, in the order
    main() passes them to gr.Examples: vad_engine, filename, silence_time,
    longest_activate, speech_pad_time, threshold. The numeric values equal the
    sliders' defaults.
    """
    # sorted(): Path.glob yields entries in arbitrary, filesystem-dependent
    # order; sorting keeps the examples table stable between runs.
    return [
        ["nx_vad", wav_file.as_posix(), 0.3, 3.0, 0.03, 0.5]
        for wav_file in sorted(example_wav_dir.glob("*.wav"))
    ]


def main():
    """Build and launch the Gradio VAD demo.

    The app has two tabs:
      * "Upload from disk": runs nx_vad or silero_vad on an uploaded file via
        process_uploaded_file and shows the parsed timestamps next to the raw
        binary output.
      * "shell": sends an arbitrary command string to ``shell``.

    It is served on port 7860, bound to 127.0.0.1 on Windows and 0.0.0.0
    elsewhere.
    """
    args = get_args()

    # Markdown title: "VAD implemented in Go."
    title = "## GO语言实现的VAD."

    examples = build_examples(Path(args.example_wav_dir))

    with gr.Blocks() as blocks:
        gr.Markdown(value=title)
        with gr.Tabs():
            with gr.TabItem("Upload from disk"):
                uploaded_file = gr.Audio(
                    sources=["upload"],
                    type="filepath",
                    label="Upload from disk",
                )
                with gr.Row():
                    uploaded_vad_engine = gr.Dropdown(choices=["nx_vad", "silero_vad"], value="nx_vad", label="vad_engine")
                    uploaded_silence_time = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="silence_time")
                    uploaded_longest_activate = gr.Slider(minimum=0.0, maximum=20.0, value=3.0, step=0.1, label="longest_activate")
                    uploaded_speech_pad_time = gr.Slider(minimum=0.00, maximum=0.50, value=0.03, step=0.01, label="speech_pad_time")
                    uploaded_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.1, label="threshold")
                upload_button = gr.Button("Run VAD", variant="primary")
                with gr.Row():
                    uploaded_vad_timestamps = gr.Textbox(label="vad_timestamps")
                    uploaded_raw_vad_result = gr.Textbox(label="raw_vad_result")

                # The order must match process_uploaded_file's parameters and
                # the columns produced by build_examples.
                vad_inputs = [
                    uploaded_vad_engine,
                    uploaded_file,
                    uploaded_silence_time,
                    uploaded_longest_activate,
                    uploaded_speech_pad_time,
                    uploaded_threshold,
                ]
                vad_outputs = [
                    uploaded_vad_timestamps,
                    uploaded_raw_vad_result,
                ]

                gr.Examples(
                    examples=examples,
                    inputs=vad_inputs,
                    outputs=vad_outputs,
                    fn=process_uploaded_file,
                )
                upload_button.click(
                    process_uploaded_file,
                    inputs=vad_inputs,
                    outputs=vad_outputs,
                )
            with gr.TabItem("shell"):
                # NOTE(review): this tab forwards free-form user text to
                # shell() -> Command.popen, and on non-Windows hosts the server
                # binds to 0.0.0.0 with no authentication. Anyone who can reach
                # port 7860 can presumably run commands on the host; consider
                # removing this tab or gating it.
                shell_text = gr.Textbox(label="cmd")
                shell_button = gr.Button("run", variant="primary")
                shell_output = gr.Textbox(label="output")
                shell_button.click(
                    shell,
                    inputs=[shell_text],
                    outputs=[shell_output],
                )

    is_windows = platform.system() == "Windows"
    blocks.queue().launch(
        # Public share links are disabled on every platform.
        share=False,
        server_name="127.0.0.1" if is_windows else "0.0.0.0",
        server_port=7860,
        show_error=True,
    )
if __name__ == "__main__":
    # Script entry point: build and launch the Gradio app.
    main()