# NOTE(review): the six lines below are Hugging Face page chrome captured when
# this file was scraped; kept as comments so the script parses as Python.
# qgyd2021's picture
# update
# fdbda89
# raw
# history blame
# 5.12 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import collections
import contextlib
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import wave
import webrtcvad
from project_settings import project_path
def get_args():
    """Parse command-line arguments for the VAD demo.

    Returns:
        argparse.Namespace with ``wav_file``, ``agg``, ``frame_duration_ms``
        and ``silence_duration_threshold``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav_file",
        default=(project_path / "data/3300999628164249998.wav").as_posix(),
        type=str,
        help="path of the mono 16-bit PCM wav file to segment."
    )
    parser.add_argument(
        "--agg",
        default=3,
        type=int,
        # FIX: removed the stray trailing quote from the original help text.
        help="The level of aggressiveness of the VAD: [0-3]"
    )
    parser.add_argument(
        "--frame_duration_ms",
        default=30,
        type=int,
        help="frame length fed to the VAD, in milliseconds."
    )
    parser.add_argument(
        "--silence_duration_threshold",
        default=0.3,
        type=float,
        help="minimum silence duration, in seconds."
    )
    args = parser.parse_args()
    return args
def read_wave(path):
    """Load a wav file as raw PCM bytes.

    Args:
        path: filesystem path of the wav file.

    Returns:
        Tuple of (pcm_bytes, sample_rate).

    The file must be mono, 16-bit, at a rate webrtcvad accepts
    (8/16/32/48 kHz); anything else trips an assertion.
    """
    with contextlib.closing(wave.open(path, 'rb')) as reader:
        assert reader.getnchannels() == 1
        assert reader.getsampwidth() == 2
        rate = reader.getframerate()
        assert rate in (8000, 16000, 32000, 48000)
        data = reader.readframes(reader.getnframes())
    return data, rate
class Frame(object):
    """A fixed-duration slice of PCM audio.

    Attributes:
        audio_bytes: raw PCM payload of the frame.
        timestamp: start time of the frame, in seconds.
        duration: length of the frame, in seconds.
    """
    def __init__(self, audio_bytes, timestamp, duration):
        self.audio_bytes = audio_bytes
        self.timestamp = timestamp
        self.duration = duration


def frame_generator(frame_duration_ms, audio, sample_rate):
    """Yield consecutive whole Frames cut from 16-bit mono PCM bytes.

    Args:
        frame_duration_ms: frame length in milliseconds.
        audio: raw 16-bit mono PCM bytes.
        sample_rate: sample rate in Hz.

    A trailing partial frame (shorter than frame_duration_ms) is dropped.
    """
    # Frame size in bytes: samples per frame times 2 bytes per 16-bit sample.
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    offset = 0
    timestamp = 0.0
    duration = (float(n) / sample_rate) / 2.0
    # FIX: use <= so the final complete frame is not dropped when the audio
    # length is an exact multiple of the frame size (original used <).
    while offset + n <= len(audio):
        yield Frame(audio[offset:offset + n], timestamp, duration)
        timestamp += duration
        offset += n
def vad_collector(sample_rate, frame_duration_ms,
                  padding_duration_ms, vad, frames):
    """Group frames into speech segments using a padded trigger window.

    Args:
        sample_rate: audio sample rate in Hz.
        frame_duration_ms: duration of each frame in milliseconds.
        padding_duration_ms: size of the look-back window in milliseconds.
        vad: object exposing ``is_speech(frame_bytes, sample_rate)``.
        frames: iterable of Frame objects.

    Yields:
        ``[joined_pcm_bytes, start_timestamp, end_timestamp]`` for each
        detected speech segment. Triggering (and de-triggering) happens once
        more than 90% of the window is voiced (respectively unvoiced).
    """
    window = collections.deque(maxlen=int(padding_duration_ms / frame_duration_ms))
    in_speech = False
    collected = []
    for frame in frames:
        speech = vad.is_speech(frame.audio_bytes, sample_rate)
        if not in_speech:
            window.append((frame, speech))
            voiced_count = sum(1 for _, s in window if s)
            if voiced_count > 0.9 * window.maxlen:
                # Enough voiced frames: enter speech and replay the window.
                in_speech = True
                collected.extend(f for f, _ in window)
                window.clear()
        else:
            collected.append(frame)
            window.append((frame, speech))
            unvoiced_count = sum(1 for _, s in window if not s)
            if unvoiced_count > 0.9 * window.maxlen:
                # Enough silence: close and emit the current segment.
                in_speech = False
                yield [b''.join(f.audio_bytes for f in collected),
                       collected[0].timestamp, collected[-1].timestamp]
                window.clear()
                collected = []
    # Flush whatever speech was still open when the frames ran out.
    if collected:
        yield [b''.join(f.audio_bytes for f in collected),
               collected[0].timestamp, collected[-1].timestamp]
def main():
    """Run webrtcvad over a wav file, merge speech segments separated by
    short silences, print the merged segments and plot them on the waveform.
    """
    args = get_args()

    vad = webrtcvad.Vad(mode=args.agg)

    # Raw PCM bytes for the VAD, numpy sample array for plotting.
    audio_pcm_data, sample_rate = read_wave(args.wav_file)
    _, audio_data = wavfile.read(args.wav_file)

    frames = list(frame_generator(
        frame_duration_ms=args.frame_duration_ms,
        audio=audio_pcm_data, sample_rate=sample_rate
    ))

    segments = list(vad_collector(sample_rate, args.frame_duration_ms, 300, vad, frames))

    # Merge adjacent speech segments whose silence gap is below the threshold.
    vad_segments = list()
    timestamp_start = 0.0
    timestamp_end = 0.0
    for i, segment in enumerate(segments):
        start = round(segment[1], 4)
        end = round(segment[2], 4)
        if i == 0:
            timestamp_start = start
            timestamp_end = end
            continue
        # FIX: the original guarded this body with `if timestamp_start:`,
        # which silently skipped all merging whenever the first speech
        # segment started at exactly 0.0 seconds.
        sil_duration = start - timestamp_end
        if sil_duration > args.silence_duration_threshold:
            vad_segments.append([timestamp_start, timestamp_end])
            timestamp_start = start
            timestamp_end = end
        else:
            timestamp_end = end
    # FIX: always flush the pending segment. The original only emitted the
    # final segment when the last silence gap exceeded the threshold, so the
    # tail (and the single-segment case) was dropped.
    if segments:
        vad_segments.append([timestamp_start, timestamp_end])
    print(vad_segments)

    time = np.arange(0, len(audio_data)) / sample_rate
    plt.figure(figsize=(12, 5))
    # Normalize 16-bit PCM samples to [-1, 1) for display.
    plt.plot(time, audio_data / 32768, color='b')
    for start, end in vad_segments:
        plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='开始端点')  # mark segment start
        plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='结束端点')  # mark segment end
    plt.show()
    return
# Script entry point.
if __name__ == '__main__':
    main()