"""
Frame-level voice activity detection (VAD): a Vad state machine scores frames with a
pluggable FrameVoiceClassifier (WebRTC, Silero, or the cnn_voicemail model) and merges
the voiced frames into speech segments.
"""
import argparse
import collections
import os
from typing import List

import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import torch
import webrtcvad

from project_settings import project_path


class FrameVoiceClassifier(object):
    """Base class for frame classifiers: predict() maps one int16 frame to a speech probability in [0.0, 1.0]."""

    def predict(self, chunk: np.ndarray) -> float:
        raise NotImplementedError


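# Illustrative sketch only (not part of the original script): a dependency-free
# classifier showing the predict() contract that Vad expects from FrameVoiceClassifier
# subclasses. The 0.02 RMS threshold is an arbitrary placeholder value.
class EnergyVoiceClassifier(FrameVoiceClassifier):
    def __init__(self, threshold: float = 0.02):
        self.threshold = threshold

    def predict(self, chunk: np.ndarray) -> float:
        if chunk.dtype != np.int16:
            raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
        # root-mean-square energy of the normalized frame
        rms = float(np.sqrt(np.mean(np.square(chunk / 32768))))
        return 1.0 if rms > self.threshold else 0.0

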
class WebRTCVoiceClassifier(FrameVoiceClassifier):
    """Frame classifier backed by webrtcvad; returns a hard 0.0 / 1.0 decision per frame."""

    def __init__(self,
                 agg: int = 3,
                 sample_rate: int = 8000
                 ):
        self.agg = agg
        self.sample_rate = sample_rate

        self.model = webrtcvad.Vad(mode=agg)

    def predict(self, chunk: np.ndarray) -> float:
        if chunk.dtype != np.int16:
            raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))

        # webrtcvad expects raw 16-bit PCM bytes
        audio_bytes = chunk.tobytes()
        is_speech = self.model.is_speech(audio_bytes, self.sample_rate)
        return 1.0 if is_speech else 0.0


class SileroVoiceClassifier(FrameVoiceClassifier):
    """Frame classifier backed by a TorchScript Silero VAD model; returns a speech probability."""

    def __init__(self,
                 model_path: str,
                 sample_rate: int = 8000):
        self.model_path = model_path
        self.sample_rate = sample_rate

        with open(self.model_path, "rb") as f:
            model = torch.jit.load(f, map_location="cpu")
        self.model = model
        self.model.reset_states()

    def predict(self, chunk: np.ndarray) -> float:
        # Silero requires chunks of at least sample_rate / 31.25 samples (32 ms).
        if self.sample_rate / len(chunk) > 31.25:
            raise AssertionError("chunk samples number {} is less than {}".format(len(chunk), self.sample_rate / 31.25))
        if chunk.dtype != np.int16:
            raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))

        # normalize int16 samples to [-1, 1] floats
        chunk = chunk / 32768
        chunk = torch.tensor(chunk, dtype=torch.float32)
        speech_prob = self.model(chunk, self.sample_rate).item()
        return float(speech_prob)


class CallVoiceClassifier(FrameVoiceClassifier):
    """Frame classifier backed by the cnn_voicemail TorchScript model; returns the probability of the voice class."""

    def __init__(self,
                 model_path: str,
                 sample_rate: int = 8000):
        self.model_path = model_path
        self.sample_rate = sample_rate

        self.model = torch.jit.load(os.path.join(model_path, "cnn_voicemail.pth"))

    def predict(self, chunk: np.ndarray) -> float:
        if chunk.dtype != np.int16:
            raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))

        # normalize int16 samples to [-1, 1] floats and add a batch dimension
        chunk = chunk / 32768
        inputs = torch.tensor(chunk, dtype=torch.float32)
        inputs = torch.unsqueeze(inputs, dim=0)

        try:
            outputs = self.model(inputs)
        except RuntimeError as e:
            print(inputs.shape)
            raise e

        probs = outputs["probs"]
        voice_prob = probs[0][2]
        return float(voice_prob)


class Frame(object):
    def __init__(self, signal: np.ndarray, timestamp_s: float):
        self.signal = signal
        self.timestamp_s = timestamp_s


class Vad(object):
    """Streaming voice activity detector: scores frames with a FrameVoiceClassifier,
    gates them with a ring buffer, and merges the voiced runs into speech segments."""

    def __init__(self,
                 model: FrameVoiceClassifier,
                 start_ring_rate: float = 0.5,
                 end_ring_rate: float = 0.5,
                 frame_length_ms: int = 30,
                 frame_step_ms: int = 30,
                 padding_length_ms: int = 300,
                 max_silence_length_ms: int = 300,
                 max_speech_length_s: float = 2.0,
                 min_speech_length_s: float = 0.3,
                 sample_rate: int = 8000
                 ):
        self.model = model
        self.start_ring_rate = start_ring_rate
        self.end_ring_rate = end_ring_rate
        self.frame_length_ms = frame_length_ms
        self.padding_length_ms = padding_length_ms
        self.max_silence_length_ms = max_silence_length_ms
        self.max_speech_length_s = max_speech_length_s
        self.min_speech_length_s = min_speech_length_s
        self.sample_rate = sample_rate

        # framing state
        self.frame_length = int(sample_rate * (frame_length_ms / 1000.0))
        self.frame_step = int(sample_rate * (frame_step_ms / 1000.0))
        self.frame_timestamp_s = 0.0
        self.signal_cache = np.zeros(shape=(self.frame_length,), dtype=np.int16)

        # trigger state: the ring buffer holds the last num_padding_frames (frame, prob) pairs
        self.num_padding_frames = int(padding_length_ms / frame_step_ms)
        self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
        self.triggered = False
        self.voiced_frames: List[Frame] = list()
        self.segments = list()

        # segment-merging state
        self.is_first_segment = True
        self.timestamp_start_s = 0.0
        self.timestamp_end_s = 0.0

        self.speech_probs: List[float] = list()

    def signal_to_frames(self, signal: np.ndarray):
        frames = list()

        signal_length = len(signal)
        duration_s = float(self.frame_step) / self.sample_rate

        for offset in range(0, signal_length - self.frame_length + 1, self.frame_step):
            sub_signal = signal[offset:offset + self.frame_length]
            frame = Frame(sub_signal, self.frame_timestamp_s)
            self.frame_timestamp_s += duration_s

            frames.append(frame)
        return frames

    def segments_generator(self, signal: np.ndarray):
        # prepend whatever was left over from the previous call
        if self.signal_cache is not None:
            signal = np.concatenate([self.signal_cache, signal])

        # keep the tail that does not fill a whole frame step for the next call
        rest = (len(signal) - self.frame_length) % self.frame_step
        if rest == 0:
            self.signal_cache = None
            signal_ = signal
        else:
            self.signal_cache = signal[-rest:]
            signal_ = signal[:-rest]

        frames = self.signal_to_frames(signal_)

        for frame in frames:
            speech_prob = self.model.predict(frame.signal)
            self.speech_probs.append(speech_prob)

            if not self.triggered:
                self.ring_buffer.append((frame, speech_prob))
                num_voiced = sum([p for _, p in self.ring_buffer])

                # enough recent speech probability mass: open a voiced run
                if num_voiced > self.start_ring_rate * self.ring_buffer.maxlen:
                    self.triggered = True

                    for f, _ in self.ring_buffer:
                        self.voiced_frames.append(f)
                continue

            self.voiced_frames.append(frame)
            self.ring_buffer.append((frame, speech_prob))
            num_voiced = sum([p for _, p in self.ring_buffer])

            # too little recent speech probability mass: close the voiced run
            if num_voiced < self.end_ring_rate * self.ring_buffer.maxlen:
                segment = [
                    np.concatenate([f.signal for f in self.voiced_frames]),
                    self.voiced_frames[0].timestamp_s,
                    self.voiced_frames[-1].timestamp_s,
                ]
                yield segment
                self.triggered = False
                self.ring_buffer.clear()
                self.voiced_frames = []
                continue

    def vad_segments_generator(self, segments_generator):
        segments = list(segments_generator)

        for segment in segments:
            start = round(segment[1], 4)
            end = round(segment[2], 4)

            # the first segment only initializes the pending [start, end] span
            if self.is_first_segment:
                self.is_first_segment = False
                self.timestamp_start_s = start
                self.timestamp_end_s = end
                continue

            # split overly long speech at max_speech_length_s
            if self.timestamp_end_s - self.timestamp_start_s > self.max_speech_length_s:
                end_ = self.timestamp_start_s + self.max_speech_length_s
                vad_segment = [self.timestamp_start_s, end_]
                yield vad_segment
                self.timestamp_start_s = end_

            # merge segments separated by a short enough silence
            silence_length_ms = (start - self.timestamp_end_s) * 1000
            if silence_length_ms < self.max_silence_length_ms:
                self.timestamp_end_s = end
                continue

            # drop pending speech that is too short and start over
            if self.timestamp_end_s - self.timestamp_start_s < self.min_speech_length_s:
                self.timestamp_start_s = start
                self.timestamp_end_s = end
                continue

            vad_segment = [self.timestamp_start_s, self.timestamp_end_s]
            yield vad_segment
            self.timestamp_start_s = start
            self.timestamp_end_s = end

    def vad(self, signal: np.ndarray) -> List[list]:
        segments = self.segments_generator(signal)
        vad_segments = self.vad_segments_generator(segments)
        vad_segments = list(vad_segments)
        return vad_segments

    def last_vad_segments(self) -> List[list]:
        # flush any frames still buffered in an open voiced run
        if len(self.voiced_frames) == 0:
            segments = []
        else:
            segment = [
                np.concatenate([f.signal for f in self.voiced_frames]),
                self.voiced_frames[0].timestamp_s,
                self.voiced_frames[-1].timestamp_s
            ]
            segments = [segment]

        vad_segments = self.vad_segments_generator(segments)
        vad_segments = list(vad_segments)

        # emit the pending segment that vad_segments_generator has not yielded yet
        if self.timestamp_end_s - self.timestamp_start_s > 1e-5:
            vad_segments = vad_segments + [[self.timestamp_start_s, self.timestamp_end_s]]
        return vad_segments


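# Illustrative sketch only (not part of the original pipeline): because Vad keeps its
# signal cache, ring buffer, and trigger state across calls, it can also be fed
# fixed-size chunks in a streaming loop. The WebRTC classifier, the 8000 Hz sample
# rate, and the 2400-sample (300 ms) chunk size are assumptions made for the example;
# `signal` is expected to be int16 PCM as elsewhere in this script.
def streaming_vad_example(signal: np.ndarray, chunk_size: int = 2400) -> List[list]:
    vad = Vad(model=WebRTCVoiceClassifier(sample_rate=8000), sample_rate=8000)

    vad_segments = list()
    for offset in range(0, len(signal), chunk_size):
        # each call consumes one chunk and yields any segments completed so far
        vad_segments += vad.vad(signal[offset:offset + chunk_size])
    # flush the segment that may still be open at the end of the signal
    vad_segments += vad.last_vad_segments()
    return vad_segments

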
def process_speech_probs(signal: np.ndarray, speech_probs: List[float], frame_step: int) -> np.ndarray:
    # expand per-frame probabilities to per-sample values so they can be plotted against the signal
    speech_probs_ = list()
    for p in speech_probs[1:]:
        speech_probs_.extend([p] * frame_step)

    pad = (signal.shape[0] - len(speech_probs_))
    speech_probs_ = speech_probs_ + [0.0] * pad
    speech_probs_ = np.array(speech_probs_, dtype=np.float32)

    if len(speech_probs_) != len(signal):
        raise AssertionError
    return speech_probs_


def make_visualization(signal: np.ndarray, speech_probs, sample_rate: int, vad_segments: list):
    time = np.arange(0, len(signal)) / sample_rate
    plt.figure(figsize=(12, 5))
    plt.plot(time, signal / 32768, color='b')
    plt.plot(time, speech_probs, color='gray')
    for start, end in vad_segments:
        plt.axvline(x=start, ymin=0.15, ymax=0.85, color="g", linestyle="--", label="segment start")
        plt.axvline(x=end, ymin=0.15, ymax=0.85, color="r", linestyle="--", label="segment end")

    plt.show()
    return


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav_file",
        default=(project_path / "data/early_media/62/3300999628999191096.wav").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--model_path",
        default=(project_path / "pretrained_models/silero_vad/silero_vad.jit").as_posix(),
        type=str,
    )
    args = parser.parse_args()
    return args


SAMPLE_RATE = 8000


def main():
    args = get_args()

    sample_rate, signal = wavfile.read(args.wav_file)
    if SAMPLE_RATE != sample_rate:
        raise AssertionError

    model = CallVoiceClassifier(model_path=(project_path / "trained_models/cnn_voicemail_common_20231130").as_posix())

    vad = Vad(model=model,
              start_ring_rate=0.2,
              end_ring_rate=0.1,
              frame_length_ms=300,
              frame_step_ms=30,
              padding_length_ms=300,
              max_silence_length_ms=300,
              sample_rate=SAMPLE_RATE,
              )
    print(vad)

    vad_segments = list()

    # process the whole signal, then flush the final pending segment
    segments = vad.vad(signal)
    vad_segments += segments
    for segment in segments:
        print(segment)

    segments = vad.last_vad_segments()
    vad_segments += segments
    for segment in segments:
        print(segment)

    print(vad.speech_probs)
    print(len(vad.speech_probs))

    speech_probs = process_speech_probs(
        signal=signal,
        speech_probs=vad.speech_probs,
        frame_step=vad.frame_step,
    )

    make_visualization(signal, speech_probs, SAMPLE_RATE, vad_segments)
    return


if __name__ == '__main__':
    main()