|
|
|
|
|
import argparse |
|
import collections |
|
from typing import List |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from scipy.io import wavfile |
|
import webrtcvad |
|
|
|
from project_settings import project_path |
|
|
|
|
|
class Frame:
    """One slice of an audio signal plus its position in the stream.

    Attributes are stored exactly as given:
        signal: samples belonging to this frame.
        timestamp: start time of the frame.
        duration: length of the frame.
    """

    def __init__(self, signal: np.ndarray, timestamp, duration):
        self.signal, self.timestamp, self.duration = signal, timestamp, duration
|
|
|
|
|
class WebRTCVad(object):
    """Streaming voice-activity detector built on ``webrtcvad.Vad``.

    Audio is pushed chunk by chunk through :meth:`vad`, which returns the
    ``[start, end]`` time spans (in seconds) that are already known to be
    complete. Once the stream is over, :meth:`last_vad_segments` flushes
    whatever is still pending. State (leftover samples, running timestamp,
    ring buffer, pending span) is kept on the instance between calls.
    """

    def __init__(self,
                 agg: int = 3,
                 frame_duration_ms: int = 30,
                 padding_duration_ms: int = 300,
                 silence_duration_threshold: float = 0.3,
                 sample_rate: int = 8000
                 ):
        """
        :param agg: aggressiveness mode handed to ``webrtcvad.Vad``.
        :param frame_duration_ms: length of one analysis frame, in milliseconds.
        :param padding_duration_ms: length of the decision window, in
            milliseconds; it is converted to a whole number of frames.
        :param silence_duration_threshold: gap, in seconds, above which two
            consecutive voiced segments are reported as separate spans.
        :param sample_rate: sample rate of the incoming signal, in Hz.
        """
        self.agg = agg
        self.frame_duration_ms = frame_duration_ms
        self.padding_duration_ms = padding_duration_ms
        self.silence_duration_threshold = silence_duration_threshold
        self.sample_rate = sample_rate

        self._vad = webrtcvad.Vad(mode=agg)

        # Framing state.
        self.frame_length = int(sample_rate * (frame_duration_ms / 1000.0))
        self.frame_timestamp = 0.0
        # Samples left over from the previous chunk (less than one frame).
        self.signal_cache = None

        # Segment detection state.
        self.num_padding_frames = int(padding_duration_ms / frame_duration_ms)
        self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
        self.triggered = False
        self.voiced_frames: List[Frame] = list()
        self.segments = list()

        # Span merging state.
        self.is_first_segment = True
        self.timestamp_start = 0.0
        self.timestamp_end = 0.0

    def signal_to_frames(self, signal: np.ndarray):
        """Cut ``signal`` into consecutive frames of ``frame_length`` samples.

        Every frame is stamped with the running stream time, which is advanced
        by one frame duration per frame (so timestamps continue across calls).

        :param signal: samples to split; the last frame is shorter when the
            length is not a multiple of ``frame_length``.
        :return: list of :class:`Frame`.
        """
        frames = list()
        duration = float(self.frame_length) / self.sample_rate

        for offset in range(0, len(signal), self.frame_length):
            sub_signal = signal[offset:offset + self.frame_length]

            frames.append(Frame(sub_signal, self.frame_timestamp, duration))
            self.frame_timestamp += duration
        return frames

    def _voiced_segment(self) -> list:
        """Build ``[samples, start, end]`` from the frames collected so far.

        ``end`` is the timestamp (i.e. the start time) of the last frame.
        """
        return [
            np.concatenate([f.signal for f in self.voiced_frames]),
            self.voiced_frames[0].timestamp,
            self.voiced_frames[-1].timestamp,
        ]

    def segments_generator(self, signal: np.ndarray):
        """Yield the voiced segments completed by this chunk of audio.

        Samples that do not fill a whole frame are cached and prepended to the
        next chunk. The detector switches on when more than 90% of the ring
        buffer capacity is voiced and off again when more than 90% is unvoiced.

        :param signal: next chunk of samples.
        :return: generator of ``[samples, start, end]`` lists.
        """
        if self.signal_cache is not None:
            signal = np.concatenate([self.signal_cache, signal])

        # Keep the incomplete trailing frame for the next call.
        rest = len(signal) % self.frame_length
        if rest == 0:
            self.signal_cache = None
            signal_ = signal
        else:
            self.signal_cache = signal[-rest:]
            signal_ = signal[:-rest]

        threshold = 0.9 * self.ring_buffer.maxlen

        for frame in self.signal_to_frames(signal_):
            # tobytes() also works for views that are not C-contiguous.
            audio_bytes = frame.signal.tobytes()
            is_speech = self._vad.is_speech(audio_bytes, self.sample_rate)

            if not self.triggered:
                self.ring_buffer.append((frame, is_speech))
                num_voiced = sum(1 for _, speech in self.ring_buffer if speech)

                if num_voiced > threshold:
                    self.triggered = True
                    # The buffered frames are the onset of the segment.
                    self.voiced_frames.extend(f for f, _ in self.ring_buffer)
                    self.ring_buffer.clear()
            else:
                self.voiced_frames.append(frame)
                self.ring_buffer.append((frame, is_speech))
                num_unvoiced = sum(1 for _, speech in self.ring_buffer if not speech)

                if num_unvoiced > threshold:
                    self.triggered = False
                    yield self._voiced_segment()
                    self.ring_buffer.clear()
                    self.voiced_frames = []

    def vad_segments_generator(self, segments_generator):
        """Merge voiced segments separated by short silences into spans.

        The current span is kept in ``timestamp_start`` / ``timestamp_end``.
        A span is only yielded once a later segment starts more than
        ``silence_duration_threshold`` seconds after it ended; otherwise the
        span is extended to cover the new segment.

        :param segments_generator: iterable of ``[samples, start, end]``.
        :return: generator of ``[start, end]`` lists, rounded to 4 decimals.
        """
        for segment in segments_generator:
            start = round(segment[1], 4)
            end = round(segment[2], 4)

            if self.is_first_segment:
                self.timestamp_start = start
                self.timestamp_end = end
                self.is_first_segment = False
                continue

            # NOTE: do not test ``timestamp_start`` for truthiness here; a
            # span that begins at 0.0 seconds is perfectly valid.
            sil_duration = start - self.timestamp_end
            if sil_duration > self.silence_duration_threshold:
                yield [self.timestamp_start, self.timestamp_end]

                self.timestamp_start = start
                self.timestamp_end = end
            else:
                self.timestamp_end = end

    def vad(self, signal: np.ndarray) -> List[list]:
        """Process one chunk of audio and return the spans it completed.

        :param signal: next chunk of samples.
        :return: list of ``[start, end]`` spans; the span still being built
            is not included (see :meth:`last_vad_segments`).
        """
        segments = self.segments_generator(signal)
        return list(self.vad_segments_generator(segments))

    def last_vad_segments(self) -> List[list]:
        """Flush the detector at the end of the stream.

        Frames of a segment that is still open are turned into a final
        segment, merged like any other, and the pending span is appended.

        :return: list of remaining ``[start, end]`` spans; empty when no
            voiced segment was ever detected.
        """
        segments = [self._voiced_segment()] if self.voiced_frames else []

        vad_segments = list(self.vad_segments_generator(segments))

        if self.is_first_segment:
            # Nothing was detected at all: there is no pending span to report.
            return vad_segments
        return vad_segments + [[self.timestamp_start, self.timestamp_end]]
|
|
|
|
|
def get_args():
    """Parse the command-line options of the demo script.

    :return: ``argparse.Namespace`` with ``wav_file``, ``agg``,
        ``frame_duration_ms`` and ``silence_duration_threshold``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav_file",
        default=(project_path / "data/3300999628164249998.wav").as_posix(),
        type=str,
        help="path of the wav file to analyse."
    )
    parser.add_argument(
        "--agg",
        default=3,
        type=int,
        # Reject values outside the range the help text advertises.
        choices=[0, 1, 2, 3],
        help="The level of aggressiveness of the VAD: [0-3]"
    )
    parser.add_argument(
        "--frame_duration_ms",
        default=30,
        type=int,
        help="duration of one analysis frame, in milliseconds."
    )
    parser.add_argument(
        "--silence_duration_threshold",
        default=0.3,
        type=float,
        help="minimum silence duration, in seconds."
    )
    args = parser.parse_args()
    return args
|
|
|
|
|
# Sample rate (Hz) the demo expects; main() refuses wav files recorded at any other rate.
SAMPLE_RATE = 8000
|
|
|
|
|
def main():
    """Run VAD over a wav file, print the detected spans and plot them.

    :raises AssertionError: when the wav file is not sampled at ``SAMPLE_RATE``.
    """
    args = get_args()

    # Forward the command-line options; their defaults match WebRTCVad's own.
    w_vad = WebRTCVad(
        agg=args.agg,
        frame_duration_ms=args.frame_duration_ms,
        silence_duration_threshold=args.silence_duration_threshold,
        sample_rate=SAMPLE_RATE,
    )

    sample_rate, signal = wavfile.read(args.wav_file)
    if SAMPLE_RATE != sample_rate:
        raise AssertionError(
            f"expected a wav file sampled at {SAMPLE_RATE} Hz, got {sample_rate} Hz"
        )

    vad_segments = list()

    # Spans completed while processing the signal.
    segments = w_vad.vad(signal)
    vad_segments += segments
    for segment in segments:
        print(segment)

    # Spans still pending once the whole signal has been consumed.
    segments = w_vad.last_vad_segments()
    vad_segments += segments
    for segment in segments:
        print(segment)

    # Plot the waveform with a green line at each span start and a red one at each end.
    time = np.arange(0, len(signal)) / sample_rate
    plt.figure(figsize=(12, 5))
    plt.plot(time, signal / 32768, color='b')
    for start, end in vad_segments:
        plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='开始端点')
        plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='结束端点')

    plt.show()
|
|
|
|
|
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|