#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Endpoint detection with webrtcvad: split a wav file into voiced segments,
merge segments separated by short silences, then print the resulting
[start, end] timestamps and plot them on the waveform.
"""
import argparse
import collections
import contextlib
import wave

import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import webrtcvad

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav_file",
        default=(project_path / "data/3300999628164249998.wav").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--agg",
        default=3,
        type=int,
        help="The level of aggressiveness of the VAD: [0-3]"
    )
    parser.add_argument(
        "--frame_duration_ms",
        default=30,
        type=int,
    )
    parser.add_argument(
        "--silence_duration_threshold",
        default=0.3,
        type=float,
        help="minimum silence duration, in seconds."
    )
    args = parser.parse_args()
    return args


def read_wave(path):
    """Read a mono, 16-bit PCM wav file and return (pcm_bytes, sample_rate)."""
    with contextlib.closing(wave.open(path, 'rb')) as wf:
        num_channels = wf.getnchannels()
        assert num_channels == 1
        sample_width = wf.getsampwidth()
        assert sample_width == 2
        sample_rate = wf.getframerate()
        assert sample_rate in (8000, 16000, 32000, 48000)
        pcm_data = wf.readframes(wf.getnframes())
        return pcm_data, sample_rate


class Frame(object):
    """A frame of audio bytes together with its timestamp and duration (seconds)."""

    def __init__(self, audio_bytes, timestamp, duration):
        self.audio_bytes = audio_bytes
        self.timestamp = timestamp
        self.duration = duration


def frame_generator(frame_duration_ms, audio, sample_rate):
    """Yield successive frames of `frame_duration_ms` from 16-bit PCM `audio` bytes."""
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)    # bytes per frame (2 bytes per sample)
    offset = 0
    timestamp = 0.0
    duration = (float(n) / sample_rate) / 2.0
    while offset + n < len(audio):
        yield Frame(audio[offset:offset + n], timestamp, duration)
        timestamp += duration
        offset += n


def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
    """Collect frames into voiced segments.

    A ring buffer holding `padding_duration_ms` worth of frames is used to
    trigger (enter a voiced segment when more than 90% of buffered frames are
    speech) and de-trigger (leave it when more than 90% are non-speech).
    Each yielded item is [pcm_bytes, start_timestamp, end_timestamp].
    """
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    triggered = False

    voiced_frames = []
    for frame in frames:
        is_speech = vad.is_speech(frame.audio_bytes, sample_rate)

        if not triggered:
            ring_buffer.append((frame, is_speech))
            num_voiced = len([f for f, speech in ring_buffer if speech])
            if num_voiced > 0.9 * ring_buffer.maxlen:
                triggered = True
                for f, _ in ring_buffer:
                    voiced_frames.append(f)
                ring_buffer.clear()
        else:
            voiced_frames.append(frame)
            ring_buffer.append((frame, is_speech))
            num_unvoiced = len([f for f, speech in ring_buffer if not speech])
            if num_unvoiced > 0.9 * ring_buffer.maxlen:
                triggered = False
                yield [b''.join([f.audio_bytes for f in voiced_frames]),
                       voiced_frames[0].timestamp,
                       voiced_frames[-1].timestamp]
                ring_buffer.clear()
                voiced_frames = []

    if voiced_frames:
        yield [b''.join([f.audio_bytes for f in voiced_frames]),
               voiced_frames[0].timestamp,
               voiced_frames[-1].timestamp]


def main():
    args = get_args()

    vad = webrtcvad.Vad(mode=args.agg)

    audio_pcm_data, sample_rate = read_wave(args.wav_file)
    _, audio_data = wavfile.read(args.wav_file)
    # audio_data_ = bytes(audio_data)

    frames = frame_generator(
        frame_duration_ms=args.frame_duration_ms,
        audio=audio_pcm_data,
        sample_rate=sample_rate
    )
    frames = list(frames)

    segments = vad_collector(sample_rate, args.frame_duration_ms, 300, vad, frames)
    segments = list(segments)

    # Merge voiced segments separated by less than
    # `silence_duration_threshold` seconds of silence.
    vad_segments = list()

    timestamp_start = 0.0
    timestamp_end = 0.0

    last_i = len(segments) - 1
    for i, segment in enumerate(segments):
        start = round(segment[1], 4)
        end = round(segment[2], 4)

        flag_first = i == 0
        flag_last = i == last_i

        if flag_first:
            timestamp_start = start
            timestamp_end = end
            continue

        if timestamp_start:
            sil_duration = start - timestamp_end
            if sil_duration > args.silence_duration_threshold:
                vad_segments.append([timestamp_start, timestamp_end])
                timestamp_start = start
                timestamp_end = end
            else:
                timestamp_end = end
            # Always flush the pending segment on the last voiced segment,
            # not only when it follows a long enough silence.
            if flag_last:
                vad_segments.append([timestamp_start, timestamp_end])

    print(vad_segments)

    # Plot the waveform and mark the detected speech segment boundaries.
    time = np.arange(0, len(audio_data)) / sample_rate
    plt.figure(figsize=(12, 5))
    plt.plot(time, audio_data / 32768, color='b')    # normalize int16 samples to [-1, 1)
    for start, end in vad_segments:
        plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='start endpoint')    # mark the start endpoint
        plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='end endpoint')    # mark the end endpoint

    plt.show()
    return


if __name__ == '__main__':
    main()
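
# Example invocation (a sketch; the file name `vad_example.py` is hypothetical, and
# the wav must be mono 16-bit PCM at 8/16/32/48 kHz, as asserted in read_wave):
#
#   python3 vad_example.py --wav_file /path/to/audio.wav --agg 3 \
#       --frame_duration_ms 30 --silence_duration_threshold 0.3
#
# Prints the merged [start, end] speech segments in seconds and shows the waveform plot.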