qgyd2021's picture
update
fdbda89
raw
history blame
7.21 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import collections
from typing import List
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import webrtcvad
from project_settings import project_path
class Frame(object):
    """A slice of audio samples together with its position on the stream timeline.

    :param signal: the samples that make up this frame.
    :param timestamp: start time of the frame, in seconds from the start of the stream.
    :param duration: length of the frame, in seconds.
    """

    def __init__(self, signal: np.ndarray, timestamp, duration):
        self.signal, self.timestamp, self.duration = signal, timestamp, duration
class WebRTCVad(object):
    """Streaming voice activity detector built on top of ``webrtcvad.Vad``.

    Audio may be fed in arbitrary-sized chunks through :meth:`vad`; samples
    that do not fill a whole frame are cached and prepended to the next chunk.
    When the stream is finished, call :meth:`last_vad_segments` to flush
    whatever is still pending.

    Detection happens in two stages:

    1. :meth:`segments_generator` splits the signal into frames, asks webrtcvad
       whether each frame is speech and smooths the decisions with a ring
       buffer holding ``padding_duration_ms`` worth of frames: a segment opens
       when more than 90% of the buffered frames are voiced and closes when
       more than 90% are unvoiced.
    2. :meth:`vad_segments_generator` merges consecutive segments separated by
       at most ``silence_duration_threshold`` seconds into ``[start, end]``
       pairs, in seconds.
    """

    def __init__(self,
                 agg: int = 3,
                 frame_duration_ms: int = 30,
                 padding_duration_ms: int = 300,
                 silence_duration_threshold: float = 0.3,
                 sample_rate: int = 8000
                 ):
        """
        :param agg: aggressiveness mode handed to ``webrtcvad.Vad``.
        :param frame_duration_ms: length of each frame passed to webrtcvad, in
            ms (webrtcvad documents 10, 20 and 30 ms as the supported values).
        :param padding_duration_ms: length of the smoothing ring buffer, in ms.
        :param silence_duration_threshold: a gap longer than this (seconds)
            between two speech segments starts a new VAD segment.
        :param sample_rate: sample rate of the signal, in Hz.
        :raises ValueError: if ``padding_duration_ms`` is shorter than one
            frame; the ring buffer would then hold no frames and speech could
            never be detected.
        """
        num_padding_frames = int(padding_duration_ms / frame_duration_ms)
        if num_padding_frames < 1:
            raise ValueError(
                f"padding_duration_ms ({padding_duration_ms}) must be at least "
                f"frame_duration_ms ({frame_duration_ms})"
            )

        self.agg = agg
        self.frame_duration_ms = frame_duration_ms
        self.padding_duration_ms = padding_duration_ms
        self.silence_duration_threshold = silence_duration_threshold
        self.sample_rate = sample_rate

        self._vad = webrtcvad.Vad(mode=agg)

        # frames
        self.frame_length = int(sample_rate * (frame_duration_ms / 1000.0))  # samples per frame
        self.frame_timestamp = 0.0  # start time (s) of the next frame to be produced
        self.signal_cache = None  # trailing samples that did not fill a whole frame

        # segments
        self.num_padding_frames = num_padding_frames
        self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
        self.triggered = False  # True while inside a speech segment
        self.voiced_frames: List[Frame] = list()
        self.segments = list()  # not used within this class; kept for compatibility

        # vad segments: the merged [start, end] pair currently being extended
        self.is_first_segment = True
        self.timestamp_start = 0.0
        self.timestamp_end = 0.0

    def signal_to_frames(self, signal: np.ndarray):
        """Split *signal* into consecutive frames of ``frame_length`` samples.

        Timestamps continue from where the previous call stopped, so frames of
        successive chunks share one continuous timeline.
        """
        frames = list()
        duration = float(self.frame_length) / self.sample_rate
        for offset in range(0, len(signal), self.frame_length):
            sub_signal = signal[offset:offset + self.frame_length]
            frames.append(Frame(sub_signal, self.frame_timestamp, duration))
            self.frame_timestamp += duration
        return frames

    def _make_segment(self) -> list:
        """Build ``[samples, start, end]`` from the currently buffered voiced frames.

        ``end`` is the end time of the last frame (its timestamp plus its
        duration), so the segment covers every sample it contains.
        """
        first, last = self.voiced_frames[0], self.voiced_frames[-1]
        return [
            np.concatenate([f.signal for f in self.voiced_frames]),
            first.timestamp,
            last.timestamp + last.duration,
        ]

    def segments_generator(self, signal: np.ndarray):
        """Yield ``[samples, start, end]`` for every speech segment that closes within *signal*.

        ``start``/``end`` are in seconds. A segment still open when *signal*
        runs out stays buffered in ``voiced_frames`` until a later call closes
        it or :meth:`last_vad_segments` flushes it.
        """
        # Prepend the tail cached by the previous call and hold back any new
        # tail that does not fill a whole frame.
        if self.signal_cache is not None:
            signal = np.concatenate([self.signal_cache, signal])
        rest = len(signal) % self.frame_length
        if rest == 0:
            self.signal_cache = None
            signal_ = signal
        else:
            self.signal_cache = signal[-rest:]
            signal_ = signal[:-rest]

        threshold = 0.9 * self.ring_buffer.maxlen
        for frame in self.signal_to_frames(signal_):
            # webrtcvad consumes the raw sample bytes of the frame.
            is_speech = self._vad.is_speech(frame.signal.tobytes(), self.sample_rate)
            if not self.triggered:
                self.ring_buffer.append((frame, is_speech))
                num_voiced = sum(1 for _, speech in self.ring_buffer if speech)
                if num_voiced > threshold:
                    # Speech onset: keep the buffered lead-in frames in the segment.
                    self.triggered = True
                    self.voiced_frames.extend(f for f, _ in self.ring_buffer)
                    self.ring_buffer.clear()
            else:
                self.voiced_frames.append(frame)
                self.ring_buffer.append((frame, is_speech))
                num_unvoiced = sum(1 for _, speech in self.ring_buffer if not speech)
                if num_unvoiced > threshold:
                    self.triggered = False
                    yield self._make_segment()
                    self.ring_buffer.clear()
                    self.voiced_frames = []

    def vad_segments_generator(self, segments_generator):
        """Merge speech segments separated by short silences into ``[start, end]`` pairs.

        A pair is yielded once the next segment starts more than
        ``silence_duration_threshold`` seconds after the pair's end. The pair
        still being extended lives in ``timestamp_start``/``timestamp_end``
        across calls.
        """
        for segment in segments_generator:
            start = round(segment[1], 4)
            end = round(segment[2], 4)
            if self.is_first_segment:
                self.timestamp_start = start
                self.timestamp_end = end
                self.is_first_segment = False
                continue
            # NOTE: no truthiness test on timestamp_start here; a first segment
            # starting at 0.0 used to disable all further merging.
            sil_duration = start - self.timestamp_end
            if sil_duration > self.silence_duration_threshold:
                yield [self.timestamp_start, self.timestamp_end]
                self.timestamp_start = start
            self.timestamp_end = end

    def vad(self, signal: np.ndarray) -> List[list]:
        """Feed *signal* and return the ``[start, end]`` VAD segments completed by it."""
        return list(self.vad_segments_generator(self.segments_generator(signal)))

    def last_vad_segments(self) -> List[list]:
        """Flush pending state at the end of the stream and return the remaining VAD segments.

        Returns an empty list when no speech was detected at all.
        """
        segments = [self._make_segment()] if self.voiced_frames else []
        vad_segments = list(self.vad_segments_generator(segments))
        if not self.is_first_segment:
            # Emit the pair that was still being extended.
            vad_segments.append([self.timestamp_start, self.timestamp_end])
        return vad_segments
def get_args():
    """Parse the command-line options of the VAD demo."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav_file",
        default=(project_path / "data/3300999628164249998.wav").as_posix(),
        type=str,
        help="path of the wav file to run the VAD on.",
    )
    parser.add_argument(
        "--agg",
        default=3,
        type=int,
        # Reject invalid modes here instead of failing later inside webrtcvad.
        choices=[0, 1, 2, 3],
        help="The level of aggressiveness of the VAD: [0-3]",
    )
    parser.add_argument(
        "--frame_duration_ms",
        default=30,
        type=int,
        # The frame lengths webrtcvad documents as supported.
        choices=[10, 20, 30],
        help="length of each analysed frame, in milliseconds.",
    )
    parser.add_argument(
        "--silence_duration_threshold",
        default=0.3,
        type=float,
        help="minimum silence duration, in seconds.",
    )
    return parser.parse_args()
# Sample rate (Hz) that main() requires of the input wav file and hands to WebRTCVad.
SAMPLE_RATE: int = 8_000
def main():
    """Run the VAD over a wav file, print the detected segments and plot them.

    :raises ValueError: if the file's sample rate differs from ``SAMPLE_RATE``.
    """
    args = get_args()

    sample_rate, signal = wavfile.read(args.wav_file)
    if sample_rate != SAMPLE_RATE:
        raise ValueError(
            f"expected a {SAMPLE_RATE} Hz wav file, got {sample_rate} Hz: {args.wav_file}"
        )

    # Pass the command-line options through to the detector.
    w_vad = WebRTCVad(
        agg=args.agg,
        frame_duration_ms=args.frame_duration_ms,
        silence_duration_threshold=args.silence_duration_threshold,
        sample_rate=SAMPLE_RATE,
    )

    # Segments completed while streaming, then the one still open at the end.
    vad_segments = w_vad.vad(signal) + w_vad.last_vad_segments()
    for segment in vad_segments:
        print(segment)

    # plot
    times = np.arange(len(signal)) / sample_rate
    plt.figure(figsize=(12, 5))
    # Normalise to roughly [-1, 1); assumes 16-bit PCM samples.
    plt.plot(times, signal / 32768, color='b')
    for start, end in vad_segments:
        plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='开始端点')  # mark segment start
        plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='结束端点')  # mark segment end
    plt.show()
if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    main()