#!/usr/bin/python3
# -*- coding: utf-8 -*-
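"""
Voice activity detection on a mono, 16-bit PCM WAV file using webrtcvad.

The script splits the audio into fixed-length frames, collects voiced frames
into raw segments, merges segments separated by less than
--silence_duration_threshold seconds of silence, prints the resulting
[start, end] timestamps, and plots them over the waveform.
"""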
import argparse
import collections
import contextlib
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import wave
import webrtcvad
from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav_file",
        default=(project_path / "data/3300999628164249998.wav").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--agg",
        default=3,
        type=int,
        help="The level of aggressiveness of the VAD: [0-3]"
    )
    parser.add_argument(
        "--frame_duration_ms",
        default=30,
        type=int,
    )
    parser.add_argument(
        "--silence_duration_threshold",
        default=0.3,
        type=float,
        help="Minimum silence duration, in seconds."
    )
    args = parser.parse_args()
    return args

def read_wave(path):
    """Read a mono, 16-bit PCM WAV file and return (pcm_data, sample_rate)."""
    with contextlib.closing(wave.open(path, 'rb')) as wf:
        num_channels = wf.getnchannels()
        assert num_channels == 1
        sample_width = wf.getsampwidth()
        assert sample_width == 2
        sample_rate = wf.getframerate()
        assert sample_rate in (8000, 16000, 32000, 48000)
        pcm_data = wf.readframes(wf.getnframes())
        return pcm_data, sample_rate

class Frame(object):
    def __init__(self, audio_bytes, timestamp, duration):
        self.audio_bytes = audio_bytes
        self.timestamp = timestamp
        self.duration = duration

def frame_generator(frame_duration_ms, audio, sample_rate):
    """Yield successive frames of `frame_duration_ms` from 16-bit PCM audio bytes."""
    # Frame size in bytes: 2 bytes per sample (16-bit PCM).
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    offset = 0
    timestamp = 0.0
    duration = (float(n) / sample_rate) / 2.0
    while offset + n < len(audio):
        yield Frame(audio[offset:offset + n], timestamp, duration)
        timestamp += duration
        offset += n

def vad_collector(sample_rate, frame_duration_ms,
                  padding_duration_ms, vad, frames):
    """Collect voiced frames into segments of [audio_bytes, start_time, end_time]."""
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    triggered = False

    voiced_frames = []
    for frame in frames:
        is_speech = vad.is_speech(frame.audio_bytes, sample_rate)

        if not triggered:
            ring_buffer.append((frame, is_speech))
            num_voiced = len([f for f, speech in ring_buffer if speech])
            # Start a segment once more than 90% of the buffered frames are voiced.
            if num_voiced > 0.9 * ring_buffer.maxlen:
                triggered = True
                for f, _ in ring_buffer:
                    voiced_frames.append(f)
                ring_buffer.clear()
        else:
            voiced_frames.append(frame)
            ring_buffer.append((frame, is_speech))
            num_unvoiced = len([f for f, speech in ring_buffer if not speech])
            # End the segment once more than 90% of the buffered frames are unvoiced.
            if num_unvoiced > 0.9 * ring_buffer.maxlen:
                triggered = False
                yield [b''.join([f.audio_bytes for f in voiced_frames]),
                       voiced_frames[0].timestamp, voiced_frames[-1].timestamp]
                ring_buffer.clear()
                voiced_frames = []

    if voiced_frames:
        yield [b''.join([f.audio_bytes for f in voiced_frames]),
               voiced_frames[0].timestamp, voiced_frames[-1].timestamp]

def main():
    args = get_args()

    vad = webrtcvad.Vad(mode=args.agg)

    audio_pcm_data, sample_rate = read_wave(args.wav_file)
    _, audio_data = wavfile.read(args.wav_file)
    # audio_data_ = bytes(audio_data)

    frames = frame_generator(
        frame_duration_ms=args.frame_duration_ms,
        audio=audio_pcm_data, sample_rate=sample_rate
    )
    frames = list(frames)

    segments = vad_collector(sample_rate, args.frame_duration_ms, 300, vad, frames)
    segments = list(segments)

    # Merge raw segments whose silence gap is shorter than the threshold.
    vad_segments = list()
    timestamp_start = 0.0
    timestamp_end = 0.0

    last_i = len(segments) - 1
    for i, segment in enumerate(segments):
        start = round(segment[1], 4)
        end = round(segment[2], 4)

        flag_first = i == 0
        flag_last = i == last_i

        if flag_first:
            timestamp_start = start
            timestamp_end = end
            continue

        if timestamp_start:
            sil_duration = start - timestamp_end
            if sil_duration > args.silence_duration_threshold:
                vad_segments.append([timestamp_start, timestamp_end])
                timestamp_start = start
                timestamp_end = end
                if flag_last:
                    vad_segments.append([timestamp_start, timestamp_end])
            else:
                timestamp_end = end
                # Also emit the final merged segment when the last raw segment is absorbed.
                if flag_last:
                    vad_segments.append([timestamp_start, timestamp_end])

    print(vad_segments)

    time = np.arange(0, len(audio_data)) / sample_rate
    plt.figure(figsize=(12, 5))
    plt.plot(time, audio_data / 32768, color='b')
    for start, end in vad_segments:
        # Mark the start and end endpoint of each detected segment.
        plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='start endpoint')
        plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='end endpoint')
    plt.show()
    return

if __name__ == '__main__':
    main()
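# Example invocation (the script file name below is assumed; use the actual one):
#   python3 vad_webrtcvad.py \
#       --wav_file data/3300999628164249998.wav \
#       --agg 3 \
#       --frame_duration_ms 30 \
#       --silence_duration_threshold 0.3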