import warnings
from typing import Callable, List, Optional

import pandas as pd
import torch
import torch.nn.functional as F
import torchaudio
from matplotlib import pyplot as plt


def get_speech_probs(audio: torch.Tensor,
                     sampling_rate: int = 16000,
                     window_size_samples: int = 512,
                     progress_tracking_callback: Optional[Callable[[float], None]] = None) -> List[float]:
    """Run the module-level VAD ``model`` over ``audio`` window by window.

    Args:
        audio: One-dimensional audio signal. Non-tensor input is converted
            with ``torch.Tensor``; leading singleton dimensions are squeezed.
        sampling_rate: Sampling rate of ``audio``. A rate above 16000 that is
            an exact multiple of 16000 is reduced to 16000 by keeping every
            ``sampling_rate // 16000``-th sample.
        window_size_samples: Number of samples fed to the model per call.
        progress_tracking_callback: Optional callable invoked after every
            window with the percentage (0-100) of samples processed so far.

    Returns:
        One speech probability per window, in window order.

    Raises:
        TypeError: If ``audio`` cannot be converted to a tensor.
        ValueError: If ``audio`` still has more than one dimension after
            squeezing.
    """
    if not torch.is_tensor(audio):
        try:
            audio = torch.Tensor(audio)
        except Exception as err:
            raise TypeError("Audio cannot be casted to tensor. Cast it manually") from err

    if len(audio.shape) > 1:
        # Drop leading singleton dimensions, e.g. (1, N) -> (N,).
        for _ in range(len(audio.shape)):
            audio = audio.squeeze(0)
        if len(audio.shape) > 1:
            raise ValueError("More than one dimension in audio. Are you trying to process audio with 2 channels?")

    if sampling_rate > 16000 and (sampling_rate % 16000 == 0):
        # Decimate by plain striding so the model is called at 16000.
        step = sampling_rate // 16000
        sampling_rate = 16000
        audio = audio[::step]
        warnings.warn('Sampling rate is a multiply of 16000, casting to 16000 manually!')

    if sampling_rate == 8000 and window_size_samples > 768:
        warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 768 for 8000 sample rate!')
    if window_size_samples not in [256, 512, 768, 1024, 1536]:
        warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sampling_rate\n - [256, 512, 768] for 8000 sampling_rate')

    # `model` is the module-level VAD model; reset its state before a new audio.
    model.reset_states()

    audio_length_samples = len(audio)

    speech_probs = []
    for current_start_sample in range(0, audio_length_samples, window_size_samples):
        chunk = audio[current_start_sample: current_start_sample + window_size_samples]
        if len(chunk) < window_size_samples:
            # Right-pad the final, shorter window up to the full window size.
            chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
        speech_probs.append(model(chunk, sampling_rate).item())

        if progress_tracking_callback:
            # Clamp so the last (padded) window reports exactly 100 %.
            progress = min(current_start_sample + window_size_samples, audio_length_samples)
            progress_tracking_callback((progress / audio_length_samples) * 100)
    return speech_probs


def probs2speech_timestamps(speech_probs, audio_length_samples,
                            threshold: float = 0.5,
                            sampling_rate: int = 16000,
                            min_speech_duration_ms: int = 250,
                            max_speech_duration_s: float = float('inf'),
                            min_silence_duration_ms: int = 100,
                            window_size_samples: int = 512,
                            speech_pad_ms: int = 30,
                            return_seconds: bool = True,
                            rounding: int = 1,):
    """Turn per-window speech probabilities into padded speech segments.

    Args:
        speech_probs: Iterable of speech probabilities, one per window of
            ``window_size_samples`` samples.
        audio_length_samples: Length of the analysed audio in samples, in the
            same units as the window positions (``window_size_samples * i``).
            NOTE(review): for inputs above 16 kHz this is taken to be the
            length after decimation to 16 kHz - confirm against call sites.
        threshold: Probability at or above which a window starts/continues
            speech. Speech ends only below ``threshold - 0.15``.
        sampling_rate: Sampling rate of the original audio. A rate above
            16000 that is an exact multiple of 16000 is treated as 16000
            internally and sample results are scaled back on return.
        min_speech_duration_ms: Segments not longer than this are discarded.
        max_speech_duration_s: Segments longer than this are split.
        min_silence_duration_ms: Silence needed to close a segment.
        window_size_samples: Window size used to compute ``speech_probs``.
        speech_pad_ms: Padding added to both sides of every segment.
        return_seconds: Return times in seconds instead of sample indices.
        rounding: Number of decimals kept when ``return_seconds`` is true.

    Returns:
        List of ``{'start': ..., 'end': ...}`` dicts in chronological order.
    """
    # Window positions are expressed at the rate the probabilities were
    # computed at, so normalise the rate the same way get_speech_probs does.
    if sampling_rate > 16000 and sampling_rate % 16000 == 0:
        step = sampling_rate // 16000
        sampling_rate = 16000
    else:
        step = 1

    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
    max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
    # Silence longer than 98 ms is remembered as a split point for segments
    # that later exceed max_speech_samples.
    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000

    triggered = False
    speeches = []
    current_speech = {}
    neg_threshold = threshold - 0.15
    temp_end = 0  # candidate segment end while tolerating a short silence
    prev_end = next_start = 0  # candidate split limits for over-long segments

    for i, speech_prob in enumerate(speech_probs):
        if (speech_prob >= threshold) and temp_end:
            # Speech resumed: drop the candidate end.
            temp_end = 0
            if next_start < prev_end:
                next_start = window_size_samples * i

        if (speech_prob >= threshold) and not triggered:
            triggered = True
            current_speech['start'] = window_size_samples * i
            continue

        if triggered and (window_size_samples * i) - current_speech['start'] > max_speech_samples:
            if prev_end:
                # Split at the last sufficiently long silence.
                current_speech['end'] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                if next_start < prev_end:
                    # Still inside that silence: no segment is open.
                    triggered = False
                else:
                    current_speech['start'] = next_start
                prev_end = next_start = temp_end = 0
            else:
                # No usable silence seen: cut at the current window.
                current_speech['end'] = window_size_samples * i
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

        if (speech_prob < neg_threshold) and triggered:
            if not temp_end:
                temp_end = window_size_samples * i
            if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:
                prev_end = temp_end
            if (window_size_samples * i) - temp_end < min_silence_samples:
                continue
            else:
                current_speech['end'] = temp_end
                if (current_speech['end'] - current_speech['start']) > min_speech_samples:
                    speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

    # Close a segment that is still open when the audio ends.
    if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
        current_speech['end'] = audio_length_samples
        speeches.append(current_speech)

    # Pad every segment; neighbours closer than two pads share the gap.
    for i, speech in enumerate(speeches):
        if i == 0:
            speech['start'] = int(max(0, speech['start'] - speech_pad_samples))
        if i != len(speeches) - 1:
            silence_duration = speeches[i+1]['start'] - speech['end']
            if silence_duration < 2 * speech_pad_samples:
                speech['end'] += int(silence_duration // 2)
                speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - silence_duration // 2))
            else:
                speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))
                speeches[i+1]['start'] = int(max(0, speeches[i+1]['start'] - speech_pad_samples))
        else:
            speech['end'] = int(min(audio_length_samples, speech['end'] + speech_pad_samples))

    if return_seconds:
        for speech_dict in speeches:
            speech_dict['start'] = round(speech_dict['start'] / sampling_rate, rounding)
            speech_dict['end'] = round(speech_dict['end'] / sampling_rate, rounding)
    elif step > 1:
        # Map sample indices back to the caller's original sampling rate.
        for speech_dict in speeches:
            speech_dict['start'] *= step
            speech_dict['end'] *= step
    return speeches


def make_visualization(probs, step):
    """Plot speech probabilities as an area chart and return the figure.

    Args:
        probs: Sequence of speech probabilities, one per analysis window.
        step: Spacing between consecutive probabilities on the x axis. The
            axis is labelled "seconds", so this is presumably the window
            duration in seconds - confirm against the caller.

    Returns:
        The newly created 16x8 matplotlib figure containing the plot.
    """
    fig, ax = plt.subplots(figsize=(16, 8))

    # x position of every probability: 0, step, 2 * step, ...
    timeline = [x * step for x in range(len(probs))]
    pd.DataFrame({'probs': probs}, index=timeline).plot(
        ax=ax,
        kind='area', ylim=[0, 1.05], xlim=[0, len(probs) * step],
        xlabel='seconds',
        ylabel='speech probability',
        colormap='tab20')
    return fig


# Restrict PyTorch to a single intra-op thread.
torch.set_num_threads(1)

# Forwarded as the `onnx` flag to the hub entry point below.
USE_ONNX = True

# Load the Silero VAD model and its helper utilities through torch.hub.
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_vad',
                              onnx=USE_ONNX)

# Only the third helper is used by this module; the rest are discarded.
(_,
 _, read_audio,
 *_) = utils