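"""
Frame-level voice activity detection (VAD) on a mono 16-bit PCM WAV file
using webrtcvad.

The raw PCM stream is split into fixed-duration frames, webrtcvad labels each
frame as voiced or unvoiced, a padded ring buffer turns the frame decisions
into voiced chunks, and chunks separated by less than
--silence_duration_threshold seconds of silence are merged into
[start, end] segments that are printed and plotted over the waveform.
"""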
import argparse
import collections
import contextlib
import wave

import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import webrtcvad

from project_settings import project_path


def get_args():
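    """Parse the command-line arguments of the demo."""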
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav_file",
        default=(project_path / "data/3300999628164249998.wav").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--agg",
        default=3,
        type=int,
        help="The aggressiveness level of the VAD: [0-3].",
    )
    parser.add_argument(
        "--frame_duration_ms",
        default=30,
        type=int,
        help="Frame duration in milliseconds; webrtcvad accepts 10, 20 or 30 ms frames.",
    )
    parser.add_argument(
        "--silence_duration_threshold",
        default=0.3,
        type=float,
        help="Minimum silence duration, in seconds.",
    )
    args = parser.parse_args()
    return args


def read_wave(path):
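    """Read a mono, 16-bit PCM WAV file.

    Returns the raw PCM bytes and the sample rate; the asserts enforce the
    formats that webrtcvad can handle (mono, 16-bit, 8/16/32/48 kHz).
    """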
    with contextlib.closing(wave.open(path, 'rb')) as wf:
        num_channels = wf.getnchannels()
        assert num_channels == 1
        sample_width = wf.getsampwidth()
        assert sample_width == 2
        sample_rate = wf.getframerate()
        assert sample_rate in (8000, 16000, 32000, 48000)
        pcm_data = wf.readframes(wf.getnframes())
        return pcm_data, sample_rate


class Frame(object):
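    """A frame of raw PCM audio bytes with its start timestamp and duration in seconds."""
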
    def __init__(self, audio_bytes, timestamp, duration):
        self.audio_bytes = audio_bytes
        self.timestamp = timestamp
        self.duration = duration


def frame_generator(frame_duration_ms, audio, sample_rate):
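    """Slice the raw PCM byte string into consecutive frames of frame_duration_ms.

    n is the frame size in bytes (2 bytes per 16-bit sample); timestamp and
    duration are expressed in seconds.
    """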
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    offset = 0
    timestamp = 0.0
    duration = (float(n) / sample_rate) / 2.0
    while offset + n < len(audio):
        yield Frame(audio[offset:offset + n], timestamp, duration)
        timestamp += duration
        offset += n


def vad_collector(sample_rate, frame_duration_ms,
                  padding_duration_ms, vad, frames):
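    """Collect voiced frames into chunks using a padded ring buffer.

    While untriggered, incoming frames are kept in a ring buffer that spans
    padding_duration_ms; when more than 90% of the buffered frames are voiced,
    the collector triggers and starts accumulating frames. Once more than 90%
    of the buffered frames are unvoiced, it detriggers and yields
    [pcm_bytes, timestamp of the first frame, timestamp of the last frame]
    for the accumulated chunk.
    """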
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    triggered = False

    voiced_frames = []
    for frame in frames:
        is_speech = vad.is_speech(frame.audio_bytes, sample_rate)

        if not triggered:
            ring_buffer.append((frame, is_speech))
            num_voiced = len([f for f, speech in ring_buffer if speech])

            # Trigger when most of the padding window is voiced.
            if num_voiced > 0.9 * ring_buffer.maxlen:
                triggered = True

                for f, _ in ring_buffer:
                    voiced_frames.append(f)
                ring_buffer.clear()
        else:
            voiced_frames.append(frame)
            ring_buffer.append((frame, is_speech))
            num_unvoiced = len([f for f, speech in ring_buffer if not speech])

            # Detrigger when most of the padding window is unvoiced and emit the chunk.
            if num_unvoiced > 0.9 * ring_buffer.maxlen:
                triggered = False
                yield [b''.join([f.audio_bytes for f in voiced_frames]),
                       voiced_frames[0].timestamp, voiced_frames[-1].timestamp]
                ring_buffer.clear()
                voiced_frames = []

    # Emit whatever is still accumulated when the stream ends while triggered.
    if voiced_frames:
        yield [b''.join([f.audio_bytes for f in voiced_frames]),
               voiced_frames[0].timestamp, voiced_frames[-1].timestamp]


def main():
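    """Run VAD on the input WAV file, print the detected segments and plot them."""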
    args = get_args()

    vad = webrtcvad.Vad(mode=args.agg)

    # Raw PCM bytes for webrtcvad, int16 samples for plotting.
    audio_pcm_data, sample_rate = read_wave(args.wav_file)
    _, audio_data = wavfile.read(args.wav_file)

    frames = frame_generator(
        frame_duration_ms=args.frame_duration_ms,
        audio=audio_pcm_data, sample_rate=sample_rate
    )
    frames = list(frames)

    # 300 ms of padding for the vad_collector ring buffer.
    segments = vad_collector(sample_rate, args.frame_duration_ms, 300, vad, frames)
    segments = list(segments)

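    # Merge voiced chunks that are separated by less than
    # --silence_duration_threshold seconds of silence into [start, end] segments.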
    vad_segments = list()
    timestamp_start = 0.0
    timestamp_end = 0.0

    last_i = len(segments) - 1
    for i, segment in enumerate(segments):
        start = round(segment[1], 4)
        end = round(segment[2], 4)

        flag_first = i == 0
        flag_last = i == last_i

        if flag_first:
            # Open the first candidate segment.
            timestamp_start = start
            timestamp_end = end
        else:
            sil_duration = start - timestamp_end
            if sil_duration > args.silence_duration_threshold:
                # Long enough gap: close the current segment and open a new one.
                vad_segments.append([timestamp_start, timestamp_end])
                timestamp_start = start
                timestamp_end = end
            else:
                # Short gap: extend the current segment.
                timestamp_end = end

        if flag_last:
            # Flush the segment that is still open after the last chunk.
            vad_segments.append([timestamp_start, timestamp_end])

    print(vad_segments)

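    # Plot the waveform (normalized from int16 to [-1, 1)) and mark the start
    # and end of every detected segment with vertical dashed lines.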
    time = np.arange(0, len(audio_data)) / sample_rate

    plt.figure(figsize=(12, 5))
    plt.plot(time, audio_data / 32768, color='b')

    for start, end in vad_segments:
        plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='segment start')
        plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='segment end')

    plt.show()
    return


if __name__ == '__main__':
    main()