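"""Gradio demo: voice activity detection (VAD) with WebRTC VAD."""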
import argparse
import json
import platform
from typing import Optional, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from project_settings import project_path, temp_directory
from toolbox.webrtcvad.vad import WebRTCVad


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--webrtcvad_examples_file",
        default=(project_path / "webrtcvad_examples.json").as_posix(),
        type=str
    )
    args = parser.parse_args()
    return args


# Shared VAD instance; rebuilt on every button click so parameter changes take effect.
webrtcvad: Optional[WebRTCVad] = None


# Gradio callback: run WebRTC VAD on the uploaded audio, mark the detected
# speech segments on the waveform, and return the plot together with the segments.
def click_webrtcvad_button(audio: Tuple[int, np.ndarray],
                           agg: int = 3,
                           frame_duration_ms: int = 30,
                           padding_duration_ms: int = 300,
                           silence_duration_threshold: float = 0.3,
                           ):
    global webrtcvad

    sample_rate, signal = audio

    webrtcvad = WebRTCVad(agg=int(agg),
                          frame_duration_ms=frame_duration_ms,
                          padding_duration_ms=padding_duration_ms,
                          silence_duration_threshold=silence_duration_threshold,
                          sample_rate=sample_rate,
                          )

    # Collect the segments produced while processing the signal, then flush
    # whatever the detector is still holding at the end of the audio.
    vad_segments = list()
    segments = webrtcvad.vad(signal)
    vad_segments += segments
    segments = webrtcvad.last_vad_segments()
    vad_segments += segments

    # Plot the waveform (int16 samples normalized to [-1, 1]) and mark each
    # segment with a green start line and a red end line.
    time = np.arange(0, len(signal)) / sample_rate
    plt.figure(figsize=(12, 5))
    plt.plot(time, signal / 32768, color='b')
    for start, end in vad_segments:
        plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='start endpoint')
        plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='end endpoint')

    temp_image_file = temp_directory / "temp.jpg"
    plt.savefig(temp_image_file)
    # Close the figure so repeated clicks do not accumulate open matplotlib figures.
    plt.close()
    image = Image.open(open(temp_image_file, "rb"))

    return image, vad_segments
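# Build the Gradio UI: a webrtcvad tab with the audio input, the VAD parameters,
# and the rendered waveform plus the detected end points.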
def main():
    args = get_args()

    brief_description = """
    ## Voice Activity Detection
    """

    with open(args.webrtcvad_examples_file, "r", encoding="utf-8") as f:
        webrtcvad_examples = json.load(f)

    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)

        with gr.Row():
            with gr.Column(scale=5):
                with gr.Tabs():
                    with gr.TabItem("webrtcvad"):
                        gr.Markdown(value="")

                        with gr.Row():
                            with gr.Column(scale=1):
                                webrtcvad_wav = gr.Audio(label="wav")

                                with gr.Row():
                                    webrtcvad_agg = gr.Dropdown(choices=[1, 2, 3], value=3, label="agg")
                                    webrtcvad_frame_duration_ms = gr.Slider(minimum=0, maximum=100, value=30, label="frame_duration_ms")

                                with gr.Row():
                                    webrtcvad_padding_duration_ms = gr.Slider(minimum=0, maximum=1000, value=300, label="padding_duration_ms")
                                    webrtcvad_silence_duration_threshold = gr.Slider(minimum=0, maximum=1.0, value=0.3, step=0.1, label="silence_duration_threshold")

                                webrtcvad_button = gr.Button("retrieval", variant="primary")

                            with gr.Column(scale=1):
                                webrtcvad_image = gr.Image(label="image", height=300, width=720, show_label=False)
                                webrtcvad_end_points = gr.TextArea(label="end_points", max_lines=35)

                        gr.Examples(
                            examples=webrtcvad_examples,
                            inputs=[
                                webrtcvad_wav, webrtcvad_agg, webrtcvad_frame_duration_ms,
                                webrtcvad_padding_duration_ms, webrtcvad_silence_duration_threshold
                            ],
                            outputs=[webrtcvad_image, webrtcvad_end_points],
                            fn=click_webrtcvad_button
                        )

        webrtcvad_button.click(
            click_webrtcvad_button,
            inputs=[
                webrtcvad_wav, webrtcvad_agg, webrtcvad_frame_duration_ms,
                webrtcvad_padding_duration_ms, webrtcvad_silence_duration_threshold
            ],
            outputs=[webrtcvad_image, webrtcvad_end_points],
        )

    blocks.queue().launch(
        share=False if platform.system() == "Windows" else False
    )
    return


if __name__ == "__main__":
    main()