#!/usr/bin/python3
# -*- coding: utf-8 -*-
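"""
Gradio demo for voice activity detection (VAD) with WebRTC VAD.

Upload an audio clip, run the WebRTCVad wrapper from toolbox.webrtcvad.vad
over it, and get back a waveform plot with the detected speech segments
marked, plus the segment list itself.
"""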
import argparse
import json
import platform
from typing import Optional, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from project_settings import project_path, temp_directory
from toolbox.webrtcvad.vad import WebRTCVad


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--webrtcvad_examples_file",
        default=(project_path / "webrtcvad_examples.json").as_posix(),
        type=str
    )
    args = parser.parse_args()
    return args
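

# module-level holder for the most recent WebRTCVad instance; it is rebuilt
# on every button click so that parameter changes from the UI take effect.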
webrtcvad: Optional[WebRTCVad] = None


def click_webrtcvad_button(audio: Tuple[int, np.ndarray],
                           agg: int = 3,
                           frame_duration_ms: int = 30,
                           padding_duration_ms: int = 300,
                           silence_duration_threshold: float = 0.3,
                           ):
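    """
    Run WebRTC VAD on the uploaded audio and plot the detected segments.
    Parameter semantics below are inferred from the wrapper's argument names.

    :param audio: (sample_rate, signal) tuple as produced by gr.Audio.
    :param agg: aggressiveness mode (0-3 in WebRTC VAD); higher filters out more non-speech.
    :param frame_duration_ms: analysis frame length; WebRTC VAD supports 10, 20 or 30 ms.
    :param padding_duration_ms: padding window used when smoothing segment boundaries.
    :param silence_duration_threshold: minimum silence (seconds) that splits two segments.
    :return: (waveform image with segment markers, list of (start, end) segments).
    """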
    global webrtcvad

    sample_rate, signal = audio

    webrtcvad = WebRTCVad(agg=int(agg),
                          frame_duration_ms=frame_duration_ms,
                          padding_duration_ms=padding_duration_ms,
                          silence_duration_threshold=silence_duration_threshold,
                          sample_rate=sample_rate,
                          )

    vad_segments = list()
    segments = webrtcvad.vad(signal)
    vad_segments += segments
    # collect any segments still buffered inside the padding window
    segments = webrtcvad.last_vad_segments()
    vad_segments += segments

    # plot the int16 waveform normalized to [-1, 1], with segment markers
    time = np.arange(0, len(signal)) / sample_rate
    plt.figure(figsize=(12, 5))
    plt.plot(time, signal / 32768, color='b')
    for start, end in vad_segments:
        plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='start')  # mark segment start
        plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='end')  # mark segment end

    temp_image_file = temp_directory / "temp.jpg"
    plt.savefig(temp_image_file)
    plt.close()  # release the figure so repeated clicks don't accumulate memory
    image = Image.open(temp_image_file)

    return image, vad_segments


def main():
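    """Build the Gradio UI, wire up the examples and click event, and launch the demo."""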
    args = get_args()

    brief_description = """
## Voice Activity Detection
    """

    # examples
    with open(args.webrtcvad_examples_file, "r", encoding="utf-8") as f:
        webrtcvad_examples = json.load(f)

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)

        with gr.Row():
            with gr.Column(scale=5):
                with gr.Tabs():
                    with gr.TabItem("webrtcvad"):
                        gr.Markdown(value="")

                        with gr.Row():
                            with gr.Column(scale=1):
                                webrtcvad_wav = gr.Audio(label="wav")
                                with gr.Row():
                                    webrtcvad_agg = gr.Dropdown(choices=[1, 2, 3], value=3, label="agg")
                                    webrtcvad_frame_duration_ms = gr.Slider(minimum=0, maximum=100, value=30, label="frame_duration_ms")
                                with gr.Row():
                                    webrtcvad_padding_duration_ms = gr.Slider(minimum=0, maximum=1000, value=300, label="padding_duration_ms")
                                    webrtcvad_silence_duration_threshold = gr.Slider(minimum=0, maximum=1.0, value=0.3, step=0.1, label="silence_duration_threshold")
                                webrtcvad_button = gr.Button("retrieval", variant="primary")
                            with gr.Column(scale=1):
                                webrtcvad_image = gr.Image(label="image", height=300, width=720, show_label=False)
                                webrtcvad_end_points = gr.TextArea(label="end_points", max_lines=35)

                        gr.Examples(
                            examples=webrtcvad_examples,
                            inputs=[
                                webrtcvad_wav, webrtcvad_agg, webrtcvad_frame_duration_ms,
                                webrtcvad_padding_duration_ms, webrtcvad_silence_duration_threshold
                            ],
                            outputs=[webrtcvad_image, webrtcvad_end_points],
                            fn=click_webrtcvad_button
                        )

                        # click event
                        webrtcvad_button.click(
                            click_webrtcvad_button,
                            inputs=[
                                webrtcvad_wav, webrtcvad_agg, webrtcvad_frame_duration_ms,
                                webrtcvad_padding_duration_ms, webrtcvad_silence_duration_threshold
                            ],
                            outputs=[webrtcvad_image, webrtcvad_end_points],
                        )

    blocks.queue().launch(
        # note: this expression evaluates to False on every platform;
        # set share=True to create a public share link.
        share=False if platform.system() == "Windows" else False
    )
    return


if __name__ == "__main__":
main()