#!/usr/bin/python3
# -*- coding: utf-8 -*-
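"""
Streaming voice activity detection (VAD) demo.

A frame-level classifier (WebRTC VAD, Silero VAD, or a custom CNN) scores every
frame, and a ring-buffer state machine merges those scores into speech segments.
"""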
import argparse
import collections
import os
from typing import List
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import torch
import webrtcvad
from project_settings import project_path
class FrameVoiceClassifier(object):
def predict(self, chunk: np.ndarray) -> float:
raise NotImplementedError
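# webrtcvad gives a binary speech / non-speech decision per frame, mapped here to 1.0 / 0.0.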
class WebRTCVoiceClassifier(FrameVoiceClassifier):
def __init__(self,
agg: int = 3,
sample_rate: int = 8000
):
self.agg = agg
self.sample_rate = sample_rate
self.model = webrtcvad.Vad(mode=agg)
def predict(self, chunk: np.ndarray) -> float:
if chunk.dtype != np.int16:
raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
audio_bytes = bytes(chunk)
is_speech = self.model.is_speech(audio_bytes, self.sample_rate)
return 1.0 if is_speech else 0.0
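# Silero VAD loaded as a TorchScript model; it returns a per-chunk speech probability.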
class SileroVoiceClassifier(FrameVoiceClassifier):
def __init__(self,
model_path: str,
sample_rate: int = 8000):
self.model_path = model_path
self.sample_rate = sample_rate
with open(self.model_path, "rb") as f:
model = torch.jit.load(f, map_location="cpu")
self.model = model
self.model.reset_states()
def predict(self, chunk: np.ndarray) -> float:
        # the model needs chunks of at least sample_rate / 31.25 samples (256 samples, i.e. 32 ms, at 8000 Hz)
        if self.sample_rate / len(chunk) > 31.25:
            raise AssertionError("chunk has {} samples, fewer than the required {}".format(len(chunk), self.sample_rate / 31.25))
if chunk.dtype != np.int16:
raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
chunk = chunk / 32768
chunk = torch.tensor(chunk, dtype=torch.float32)
speech_prob = self.model(chunk, self.sample_rate).item()
return float(speech_prob)
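# Custom CNN voicemail classifier; the class at index 2 of its output probs is used as the voice probability.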
class CallVoiceClassifier(FrameVoiceClassifier):
def __init__(self,
model_path: str,
sample_rate: int = 8000):
self.model_path = model_path
self.sample_rate = sample_rate
self.model = torch.jit.load(os.path.join(model_path, "cnn_voicemail.pth"))
def predict(self, chunk: np.ndarray) -> float:
if chunk.dtype != np.int16:
raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
chunk = chunk / 32768
inputs = torch.tensor(chunk, dtype=torch.float32)
inputs = torch.unsqueeze(inputs, dim=0)
try:
outputs = self.model(inputs)
except RuntimeError as e:
print(inputs.shape)
raise e
probs = outputs["probs"]
voice_prob = probs[0][2]
return float(voice_prob)
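# One frame of signal together with the timestamp (in seconds) of its first sample.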
class Frame(object):
def __init__(self, signal: np.ndarray, timestamp_s: float):
self.signal = signal
self.timestamp_s = timestamp_s
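# Streaming VAD driven by a frame classifier:
#   - start_ring_rate / end_ring_rate: fraction of the ring buffer that must be voiced
#     to enter, or unvoiced to leave, the triggered (speaking) state,
#   - padding_length_ms: span of the ring buffer,
#   - max_silence_length_ms: gaps shorter than this are merged into one segment,
#   - min_speech_length_s / max_speech_length_s: segments shorter than the minimum are
#     dropped, segments longer than the maximum are split.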
class Vad(object):
def __init__(self,
model: FrameVoiceClassifier,
start_ring_rate: float = 0.5,
end_ring_rate: float = 0.5,
frame_length_ms: int = 30,
frame_step_ms: int = 30,
padding_length_ms: int = 300,
max_silence_length_ms: int = 300,
max_speech_length_s: float = 2.0,
min_speech_length_s: float = 0.3,
sample_rate: int = 8000
):
self.model = model
self.start_ring_rate = start_ring_rate
self.end_ring_rate = end_ring_rate
self.frame_length_ms = frame_length_ms
self.padding_length_ms = padding_length_ms
self.max_silence_length_ms = max_silence_length_ms
self.max_speech_length_s = max_speech_length_s
self.min_speech_length_s = min_speech_length_s
self.sample_rate = sample_rate
# frames
self.frame_length = int(sample_rate * (frame_length_ms / 1000.0))
self.frame_step = int(sample_rate * (frame_step_ms / 1000.0))
self.frame_timestamp_s = 0.0
self.signal_cache = np.zeros(shape=(self.frame_length,), dtype=np.int16)
# self.signal_cache = None
# segments
self.num_padding_frames = int(padding_length_ms / frame_step_ms)
self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
self.triggered = False
self.voiced_frames: List[Frame] = list()
self.segments = list()
# vad segments
self.is_first_segment = True
self.timestamp_start_s = 0.0
self.timestamp_end_s = 0.0
# speech probs
self.speech_probs: List[float] = list()
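    # Slice the signal into frames of frame_length samples, advancing frame_step samples
    # each time, while tracking the absolute timestamp of every frame.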
def signal_to_frames(self, signal: np.ndarray):
frames = list()
l = len(signal)
duration_s = float(self.frame_step) / self.sample_rate
for offset in range(0, l - self.frame_length + 1, self.frame_step):
sub_signal = signal[offset:offset+self.frame_length]
frame = Frame(sub_signal, self.frame_timestamp_s)
self.frame_timestamp_s += duration_s
frames.append(frame)
return frames
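    # Consume one streamed chunk: prepend the cached remainder from the previous call,
    # score every frame, and yield raw voiced segments as [signal, start_s, end_s].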
def segments_generator(self, signal: np.ndarray):
# signal rounding
if self.signal_cache is not None:
signal = np.concatenate([self.signal_cache, signal])
# rest
rest = (len(signal) - self.frame_length) % self.frame_step
if rest == 0:
self.signal_cache = None
signal_ = signal
else:
self.signal_cache = signal[-rest:]
signal_ = signal[:-rest]
# frames
frames = self.signal_to_frames(signal_)
for frame in frames:
speech_prob = self.model.predict(frame.signal)
self.speech_probs.append(speech_prob)
if not self.triggered:
self.ring_buffer.append((frame, speech_prob))
num_voiced = sum([p for _, p in self.ring_buffer])
if num_voiced > self.start_ring_rate * self.ring_buffer.maxlen:
self.triggered = True
for f, _ in self.ring_buffer:
self.voiced_frames.append(f)
continue
self.voiced_frames.append(frame)
self.ring_buffer.append((frame, speech_prob))
num_voiced = sum([p for _, p in self.ring_buffer])
if num_voiced < self.end_ring_rate * self.ring_buffer.maxlen:
segment = [
np.concatenate([f.signal for f in self.voiced_frames]),
self.voiced_frames[0].timestamp_s,
self.voiced_frames[-1].timestamp_s,
]
yield segment
self.triggered = False
self.ring_buffer.clear()
self.voiced_frames = []
continue
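    # Turn raw segments into final [start_s, end_s] pairs: merge segments separated by a
    # short silence, split speech longer than max_speech_length_s, and drop speech shorter
    # than min_speech_length_s.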
def vad_segments_generator(self, segments_generator):
segments = list(segments_generator)
for i, segment in enumerate(segments):
start = round(segment[1], 4)
end = round(segment[2], 4)
            # the very first segment only initializes the pending [start, end] window
            if self.is_first_segment:
                self.is_first_segment = False
                self.timestamp_start_s = start
                self.timestamp_end_s = end
                continue
if self.timestamp_end_s - self.timestamp_start_s > self.max_speech_length_s:
end_ = self.timestamp_start_s + self.max_speech_length_s
vad_segment = [self.timestamp_start_s, end_]
yield vad_segment
self.timestamp_start_s = end_
silence_length_ms = (start - self.timestamp_end_s) * 1000
if silence_length_ms < self.max_silence_length_ms:
self.timestamp_end_s = end
continue
if self.timestamp_end_s - self.timestamp_start_s < self.min_speech_length_s:
self.timestamp_start_s = start
self.timestamp_end_s = end
continue
vad_segment = [self.timestamp_start_s, self.timestamp_end_s]
yield vad_segment
self.timestamp_start_s = start
self.timestamp_end_s = end
def vad(self, signal: np.ndarray) -> List[list]:
segments = self.segments_generator(signal)
vad_segments = self.vad_segments_generator(segments)
vad_segments = list(vad_segments)
return vad_segments
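    # Flush whatever is still buffered once the input stream has ended.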
def last_vad_segments(self) -> List[list]:
# last segments
if len(self.voiced_frames) == 0:
segments = []
else:
segment = [
np.concatenate([f.signal for f in self.voiced_frames]),
self.voiced_frames[0].timestamp_s,
self.voiced_frames[-1].timestamp_s
]
segments = [segment]
# last vad segments
vad_segments = self.vad_segments_generator(segments)
vad_segments = list(vad_segments)
        # flush the pending [start, end] window, if it is non-empty
        if self.timestamp_end_s - self.timestamp_start_s > 1e-5:
vad_segments = vad_segments + [[self.timestamp_start_s, self.timestamp_end_s]]
return vad_segments
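# Expand the per-frame speech probabilities to per-sample values so they can be plotted
# against the waveform.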
def process_speech_probs(signal: np.ndarray, speech_probs: List[float], frame_step: int) -> np.ndarray:
speech_probs_ = list()
for p in speech_probs[1:]:
speech_probs_.extend([p] * frame_step)
pad = (signal.shape[0] - len(speech_probs_))
speech_probs_ = speech_probs_ + [0.0] * pad
speech_probs_ = np.array(speech_probs_, dtype=np.float32)
if len(speech_probs_) != len(signal):
        raise AssertionError("speech probs length {} does not match signal length {}".format(len(speech_probs_), len(signal)))
return speech_probs_
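# Plot the normalized waveform, the per-sample speech probability, and the segment boundaries.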
def make_visualization(signal: np.ndarray, speech_probs, sample_rate: int, vad_segments: list):
time = np.arange(0, len(signal)) / sample_rate
plt.figure(figsize=(12, 5))
plt.plot(time, signal / 32768, color='b')
plt.plot(time, speech_probs, color='gray')
for start, end in vad_segments:
        plt.axvline(x=start, ymin=0.15, ymax=0.85, color="g", linestyle="--", label="speech start")
        plt.axvline(x=end, ymin=0.15, ymax=0.85, color="r", linestyle="--", label="speech end")
plt.show()
return
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--wav_file",
default=(project_path / "data/early_media/62/3300999628999191096.wav").as_posix(),
type=str,
)
parser.add_argument(
"--model_path",
default=(project_path / "pretrained_models/silero_vad/silero_vad.jit").as_posix(),
type=str,
)
args = parser.parse_args()
return args
SAMPLE_RATE = 8000
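# Demo: read a wav file, run the streaming VAD over it, print the segments, and plot the result.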
def main():
args = get_args()
sample_rate, signal = wavfile.read(args.wav_file)
if SAMPLE_RATE != sample_rate:
        raise AssertionError("expected sample rate {}, but the wav file has {}".format(SAMPLE_RATE, sample_rate))
# model = SileroVoiceClassifier(model_path=args.model_path, sample_rate=SAMPLE_RATE)
# model = WebRTCVoiceClassifier(agg=1, sample_rate=SAMPLE_RATE)
model = CallVoiceClassifier(model_path=(project_path / "trained_models/cnn_voicemail_common_20231130").as_posix())
vad = Vad(model=model,
start_ring_rate=0.2,
end_ring_rate=0.1,
frame_length_ms=300,
frame_step_ms=30,
padding_length_ms=300,
max_silence_length_ms=300,
sample_rate=SAMPLE_RATE,
)
print(vad)
vad_segments = list()
segments = vad.vad(signal)
vad_segments += segments
for segment in segments:
print(segment)
# last vad segment
segments = vad.last_vad_segments()
vad_segments += segments
for segment in segments:
print(segment)
print(vad.speech_probs)
print(len(vad.speech_probs))
# speech_probs
speech_probs = process_speech_probs(
signal=signal,
speech_probs=vad.speech_probs,
frame_step=vad.frame_step,
)
# plot
make_visualization(signal, speech_probs, SAMPLE_RATE, vad_segments)
return
if __name__ == '__main__':
main()