#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import collections
import os
from typing import List

import matplotlib.pyplot as plt
import numpy as np
from scipy.io import wavfile
import torch
import webrtcvad

from project_settings import project_path


class FrameVoiceClassifier(object):
    def predict(self, chunk: np.ndarray) -> float:
        raise NotImplementedError


class WebRTCVoiceClassifier(FrameVoiceClassifier):
    def __init__(self,
                 agg: int = 3,
                 sample_rate: int = 8000
                 ):
        self.agg = agg
        self.sample_rate = sample_rate

        self.model = webrtcvad.Vad(mode=agg)

    def predict(self, chunk: np.ndarray) -> float:
        if chunk.dtype != np.int16:
            raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
        audio_bytes = bytes(chunk)
        is_speech = self.model.is_speech(audio_bytes, self.sample_rate)
        return 1.0 if is_speech else 0.0


class SileroVoiceClassifier(FrameVoiceClassifier):
    def __init__(self, model_path: str, sample_rate: int = 8000):
        self.model_path = model_path
        self.sample_rate = sample_rate

        with open(self.model_path, "rb") as f:
            model = torch.jit.load(f, map_location="cpu")
        self.model = model
        self.model.reset_states()

    def predict(self, chunk: np.ndarray) -> float:
        if self.sample_rate / len(chunk) > 31.25:
            raise AssertionError("chunk samples number {} is less than {}".format(len(chunk), self.sample_rate / 31.25))
        if chunk.dtype != np.int16:
            raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
        chunk = chunk / 32768
        chunk = torch.tensor(chunk, dtype=torch.float32)
        speech_prob = self.model(chunk, self.sample_rate).item()
        return float(speech_prob)


class CallVoiceClassifier(FrameVoiceClassifier):
    def __init__(self, model_path: str, sample_rate: int = 8000):
        self.model_path = model_path
        self.sample_rate = sample_rate

        self.model = torch.jit.load(os.path.join(model_path, "cnn_voicemail.pth"))

    def predict(self, chunk: np.ndarray) -> float:
        if chunk.dtype != np.int16:
            raise AssertionError("signal dtype should be np.int16, instead of {}".format(chunk.dtype))
        chunk = chunk / 32768
        inputs = torch.tensor(chunk, dtype=torch.float32)
        inputs = torch.unsqueeze(inputs, dim=0)
        try:
            outputs = self.model(inputs)
        except RuntimeError as e:
            print(inputs.shape)
            raise e
        probs = outputs["probs"]
        voice_prob = probs[0][2]
        return float(voice_prob)


class Frame(object):
    def __init__(self, signal: np.ndarray, timestamp_s: float):
        self.signal = signal
        self.timestamp_s = timestamp_s
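

# Illustrative usage sketch (hypothetical helper, not called anywhere in this script):
# every FrameVoiceClassifier scores one fixed-size int16 chunk at a time and returns a
# speech probability in [0, 1]. The 240-sample chunk below assumes a 30 ms frame at
# 8000 Hz, one of the frame lengths webrtcvad accepts.
def _frame_classifier_example() -> float:
    classifier = WebRTCVoiceClassifier(agg=3, sample_rate=8000)
    chunk = np.zeros(shape=(240,), dtype=np.int16)  # 30 ms of silence at 8000 Hz
    return classifier.predict(chunk)  # typically 0.0 (non-speech) for a silent chunk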


class Vad(object):
    def __init__(self,
                 model: FrameVoiceClassifier,
                 start_ring_rate: float = 0.5,
                 end_ring_rate: float = 0.5,
                 frame_length_ms: int = 30,
                 frame_step_ms: int = 30,
                 padding_length_ms: int = 300,
                 max_silence_length_ms: int = 300,
                 max_speech_length_s: float = 2.0,
                 min_speech_length_s: float = 0.3,
                 sample_rate: int = 8000
                 ):
        self.model = model
        self.start_ring_rate = start_ring_rate
        self.end_ring_rate = end_ring_rate
        self.frame_length_ms = frame_length_ms
        self.padding_length_ms = padding_length_ms
        self.max_silence_length_ms = max_silence_length_ms
        self.max_speech_length_s = max_speech_length_s
        self.min_speech_length_s = min_speech_length_s
        self.sample_rate = sample_rate

        # frames
        self.frame_length = int(sample_rate * (frame_length_ms / 1000.0))
        self.frame_step = int(sample_rate * (frame_step_ms / 1000.0))
        self.frame_timestamp_s = 0.0
        self.signal_cache = np.zeros(shape=(self.frame_length,), dtype=np.int16)
        # self.signal_cache = None

        # segments
        self.num_padding_frames = int(padding_length_ms / frame_step_ms)
        self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
        self.triggered = False
        self.voiced_frames: List[Frame] = list()
        self.segments = list()

        # vad segments
        self.is_first_segment = True
        self.timestamp_start_s = 0.0
        self.timestamp_end_s = 0.0

        # speech probs
        self.speech_probs: List[float] = list()

    def signal_to_frames(self, signal: np.ndarray):
        frames = list()

        l = len(signal)
        duration_s = float(self.frame_step) / self.sample_rate

        for offset in range(0, l - self.frame_length + 1, self.frame_step):
            sub_signal = signal[offset:offset + self.frame_length]

            frame = Frame(sub_signal, self.frame_timestamp_s)
            self.frame_timestamp_s += duration_s

            frames.append(frame)
        return frames

    def segments_generator(self, signal: np.ndarray):
        # signal rounding
        if self.signal_cache is not None:
            signal = np.concatenate([self.signal_cache, signal])

        # rest
        rest = (len(signal) - self.frame_length) % self.frame_step
        if rest == 0:
            self.signal_cache = None
            signal_ = signal
        else:
            self.signal_cache = signal[-rest:]
            signal_ = signal[:-rest]

        # frames
        frames = self.signal_to_frames(signal_)

        for frame in frames:
            speech_prob = self.model.predict(frame.signal)
            self.speech_probs.append(speech_prob)

            if not self.triggered:
                self.ring_buffer.append((frame, speech_prob))
                num_voiced = sum([p for _, p in self.ring_buffer])

                if num_voiced > self.start_ring_rate * self.ring_buffer.maxlen:
                    self.triggered = True
                    for f, _ in self.ring_buffer:
                        self.voiced_frames.append(f)
                continue

            self.voiced_frames.append(frame)
            self.ring_buffer.append((frame, speech_prob))
            num_voiced = sum([p for _, p in self.ring_buffer])

            if num_voiced < self.end_ring_rate * self.ring_buffer.maxlen:
                segment = [
                    np.concatenate([f.signal for f in self.voiced_frames]),
                    self.voiced_frames[0].timestamp_s,
                    self.voiced_frames[-1].timestamp_s,
                ]
                yield segment
                self.triggered = False
                self.ring_buffer.clear()
                self.voiced_frames = []
            continue

    def vad_segments_generator(self, segments_generator):
        segments = list(segments_generator)

        for i, segment in enumerate(segments):
            start = round(segment[1], 4)
            end = round(segment[2], 4)

            # first segment seen: only remember its boundaries
            if self.is_first_segment:
                self.is_first_segment = False
                self.timestamp_start_s = start
                self.timestamp_end_s = end
                continue

            if self.timestamp_end_s - self.timestamp_start_s > self.max_speech_length_s:
                end_ = self.timestamp_start_s + self.max_speech_length_s
                vad_segment = [self.timestamp_start_s, end_]
                yield vad_segment
                self.timestamp_start_s = end_

            silence_length_ms = (start - self.timestamp_end_s) * 1000
            if silence_length_ms < self.max_silence_length_ms:
                self.timestamp_end_s = end
                continue

            if self.timestamp_end_s - self.timestamp_start_s < self.min_speech_length_s:
                self.timestamp_start_s = start
                self.timestamp_end_s = end
                continue

            vad_segment = [self.timestamp_start_s, self.timestamp_end_s]
            yield vad_segment
            self.timestamp_start_s = start
            self.timestamp_end_s = end

    def vad(self, signal: np.ndarray) -> List[list]:
        segments = self.segments_generator(signal)
        vad_segments = self.vad_segments_generator(segments)
        vad_segments = list(vad_segments)
        return vad_segments

    def last_vad_segments(self) -> List[list]:
        # last segments
        if len(self.voiced_frames) == 0:
            segments = []
        else:
            segment = [
                np.concatenate([f.signal for f in self.voiced_frames]),
                self.voiced_frames[0].timestamp_s,
                self.voiced_frames[-1].timestamp_s
            ]
            segments = [segment]

        # last vad segments
        vad_segments = self.vad_segments_generator(segments)
        vad_segments = list(vad_segments)

        # flush the still-pending segment, if any
        if self.timestamp_end_s > 1e-5:
            vad_segments = vad_segments + [[self.timestamp_start_s, self.timestamp_end_s]]
        return vad_segments
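

# Illustrative streaming sketch (hypothetical helper, not called by main): the Vad
# object keeps its trigger state, signal cache, and timestamps across calls, so a long
# int16 recording can be fed in small chunks and the segments collected incrementally.
# The one-second chunk size and the WebRTC classifier are assumptions for this example.
def _streaming_vad_example(signal: np.ndarray, sample_rate: int = 8000) -> List[list]:
    model = WebRTCVoiceClassifier(agg=3, sample_rate=sample_rate)
    vad = Vad(model=model, frame_length_ms=30, frame_step_ms=30, sample_rate=sample_rate)

    vad_segments = list()
    chunk_size = sample_rate  # feed one second of audio per call
    for offset in range(0, len(signal), chunk_size):
        vad_segments += vad.vad(signal[offset:offset + chunk_size])

    # flush whatever is still buffered after the last chunk
    vad_segments += vad.last_vad_segments()
    return vad_segments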


def process_speech_probs(signal: np.ndarray, speech_probs: List[float], frame_step: int) -> np.ndarray:
    speech_probs_ = list()
    for p in speech_probs[1:]:
        speech_probs_.extend([p] * frame_step)

    pad = (signal.shape[0] - len(speech_probs_))
    speech_probs_ = speech_probs_ + [0.0] * pad

    speech_probs_ = np.array(speech_probs_, dtype=np.float32)

    if len(speech_probs_) != len(signal):
        raise AssertionError
    return speech_probs_


def make_visualization(signal: np.ndarray, speech_probs, sample_rate: int, vad_segments: list):
    time = np.arange(0, len(signal)) / sample_rate
    plt.figure(figsize=(12, 5))
    plt.plot(time, signal / 32768, color='b')
    plt.plot(time, speech_probs, color='gray')
    for start, end in vad_segments:
        plt.axvline(x=start, ymin=0.15, ymax=0.85, color="g", linestyle="--", label="start endpoint")
        plt.axvline(x=end, ymin=0.15, ymax=0.85, color="r", linestyle="--", label="end endpoint")
    plt.show()
    return


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav_file",
        default=(project_path / "data/early_media/62/3300999628999191096.wav").as_posix(),
        type=str,
    )
    parser.add_argument(
        "--model_path",
        default=(project_path / "pretrained_models/silero_vad/silero_vad.jit").as_posix(),
        type=str,
    )
    args = parser.parse_args()
    return args


SAMPLE_RATE = 8000


def main():
    args = get_args()

    sample_rate, signal = wavfile.read(args.wav_file)
    if SAMPLE_RATE != sample_rate:
        raise AssertionError

    # model = SileroVoiceClassifier(model_path=args.model_path, sample_rate=SAMPLE_RATE)
    # model = WebRTCVoiceClassifier(agg=1, sample_rate=SAMPLE_RATE)
    model = CallVoiceClassifier(model_path=(project_path / "trained_models/cnn_voicemail_common_20231130").as_posix())

    vad = Vad(model=model,
              start_ring_rate=0.2,
              end_ring_rate=0.1,
              frame_length_ms=300,
              frame_step_ms=30,
              padding_length_ms=300,
              max_silence_length_ms=300,
              sample_rate=SAMPLE_RATE,
              )
    print(vad)

    vad_segments = list()
    segments = vad.vad(signal)
    vad_segments += segments
    for segment in segments:
        print(segment)

    # last vad segment
    segments = vad.last_vad_segments()
    vad_segments += segments
    for segment in segments:
        print(segment)

    print(vad.speech_probs)
    print(len(vad.speech_probs))

    # speech_probs
    speech_probs = process_speech_probs(
        signal=signal,
        speech_probs=vad.speech_probs,
        frame_step=vad.frame_step,
    )

    # plot
    make_visualization(signal, speech_probs, SAMPLE_RATE, vad_segments)
    return


if __name__ == '__main__':
    main()
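

# Example invocation (illustrative; the script filename is an assumption). The wav file
# must be sampled at 8000 Hz to pass the SAMPLE_RATE check in main(). Note that main()
# currently hard-codes the CallVoiceClassifier model, so --model_path is only used by
# the commented-out SileroVoiceClassifier:
#
#   python3 vad.py --wav_file /path/to/audio_8k.wav --model_path /path/to/silero_vad.jit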