# Source: https://github.com/snakers4/silero-vad
#
# Copyright (c) 2024 snakers4
#
# This code is from a MIT-licensed repository. The full license text is available at the root of the source repository.
#
# Note: This code has been modified to fit the context of this repository.

import librosa
import numpy as np
import torch

# Diarization segments shorter than this many seconds are emitted as-is,
# without further VAD-based splitting.
VAD_THRESHOLD = 20
# Silero-VAD expects 16 kHz audio.
SAMPLING_RATE = 16000


class SileroVAD:
    """
    Voice Activity Detection (VAD) using Silero-VAD.
    """

    def __init__(self, local=False, model="silero_vad", device=torch.device("cpu")):
        """
        Initialize the VAD object.

        Args:
            local (bool, optional): Whether to load the model from a local checkout. Defaults to False.
            model (str, optional): The VAD model name to load. Defaults to "silero_vad".
            device (torch.device, optional): The device to run the model on. Defaults to 'cpu'.

        Returns:
            None

        Raises:
            RuntimeError: If loading the model fails.
        """
        try:
            vad_model, utils = torch.hub.load(
                repo_or_dir="snakers4/silero-vad" if not local else "vad/silero-vad",
                model=model,
                force_reload=False,
                onnx=True,
                source="github" if not local else "local",
            )
            self.vad_model = vad_model
            (get_speech_timestamps, _, _, _, _) = utils
            self.get_speech_timestamps = get_speech_timestamps
        except Exception as e:
            raise RuntimeError(f"Failed to load VAD model: {e}")

    def segment_speech(self, audio_segment, start_time, end_time, sampling_rate):
        """
        Segment speech from an audio segment and return a list of timestamps.

        Args:
            audio_segment (np.ndarray): The audio segment to be segmented.
            start_time (int): The start time of the audio segment in frames.
            end_time (int): The end time of the audio segment in frames.
            sampling_rate (int): The sampling rate of the audio segment.

        Returns:
            list: A list of [start, end] timestamps of speech segments, in frames.

        Raises:
            ValueError: If the audio segment is invalid.
        """
        if audio_segment is None or not isinstance(audio_segment, (np.ndarray, list)):
            raise ValueError("Invalid audio segment")

        speech_timestamps = self.get_speech_timestamps(
            audio_segment, self.vad_model, sampling_rate=sampling_rate
        )

        # Shift the detected timestamps so they are relative to the full recording,
        # not to the start of this slice.
        adjusted_timestamps = [
            (ts["start"] + start_time, ts["end"] + start_time)
            for ts in speech_timestamps
        ]
        if not adjusted_timestamps:
            return []

        # Silence gap (in frames) between each pair of consecutive speech timestamps.
        intervals = [
            end[0] - start[1]
            for start, end in zip(adjusted_timestamps[:-1], adjusted_timestamps[1:])
        ]

        segments = []

        def split_timestamps(start_index, end_index):
            # Keep the span whole if it is a single timestamp or spans less than
            # 20 seconds; otherwise split recursively at the widest silence gap.
            if (
                start_index == end_index
                or adjusted_timestamps[end_index][1]
                - adjusted_timestamps[start_index][0]
                < 20 * sampling_rate
            ):
                segments.append([start_index, end_index])
            else:
                if not intervals[start_index:end_index]:
                    return
                max_interval_index = intervals[start_index:end_index].index(
                    max(intervals[start_index:end_index])
                )
                split_index = start_index + max_interval_index
                split_timestamps(start_index, split_index)
                split_timestamps(split_index + 1, end_index)

        split_timestamps(0, len(adjusted_timestamps) - 1)

        # Merge each group of timestamps back into a single [start, end] span.
        merged_timestamps = [
            [adjusted_timestamps[start][0], adjusted_timestamps[end][1]]
            for start, end in segments
        ]
        return merged_timestamps
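
    # Illustrative usage sketch (not from the original source): assuming `wav` is a
    # 16 kHz mono numpy array and `vad = SileroVAD()` has been constructed,
    #     vad.segment_speech(wav, start_time=0, end_time=len(wav), sampling_rate=16000)
    # would return something like [[480, 152320], [168960, 310400]]: start/end sample
    # indices of speech chunks, split at the widest silences so that each chunk stays
    # under roughly 20 seconds where possible.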

    def vad(self, speakerdia, audio):
        """
        Process the audio based on the given speaker diarization dataframe.

        Args:
            speakerdia (pd.DataFrame): The diarization dataframe containing start, end, and speaker info.
            audio (dict): A dictionary containing the audio waveform and sample rate.

        Returns:
            list: A list of dictionaries containing processed audio segments with start, end, and speaker.
        """
        sampling_rate = audio["sample_rate"]
        audio_data = audio["waveform"]

        out = []
        last_end = 0
        speakers_seen = set()
        count_id = 0

        for index, row in speakerdia.iterrows():
            start = float(row["start"])
            end = float(row["end"])
            if end <= last_end:
                # Skip rows fully contained in a segment that was already emitted.
                continue
            last_end = end

            start_frame = int(start * sampling_rate)
            end_frame = int(end * sampling_rate)
            if row["speaker"] not in speakers_seen:
                speakers_seen.add(row["speaker"])

            if end - start <= VAD_THRESHOLD:
                # Short segments are emitted as-is, without further VAD splitting.
                out.append(
                    {
                        "index": str(count_id).zfill(5),
                        "start": start,  # in seconds
                        "end": end,
                        "speaker": row["speaker"],  # same for all sub-segments
                    }
                )
                count_id += 1
                continue

            # Resample the slice to 16 kHz, the rate Silero-VAD expects.
            temp_audio = audio_data[start_frame:end_frame]
            temp_audio_resampled = librosa.resample(
                temp_audio, orig_sr=sampling_rate, target_sr=SAMPLING_RATE
            )

            for start_frame_sub, end_frame_sub in self.segment_speech(
                temp_audio_resampled,
                int(start * SAMPLING_RATE),
                int(end * SAMPLING_RATE),
                SAMPLING_RATE,
            ):
                out.append(
                    {
                        "index": str(count_id).zfill(5),
                        "start": start_frame_sub / SAMPLING_RATE,  # in seconds
                        "end": end_frame_sub / SAMPLING_RATE,
                        "speaker": row["speaker"],  # same for all sub-segments
                    }
                )
                count_id += 1

        return out
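

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original source).
# It assumes a diarization result with "start"/"end"/"speaker" columns (the
# values below are hypothetical) and a 24 kHz mono waveform loaded with
# librosa; the file path "example.wav" is a placeholder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import pandas as pd

    waveform, sr = librosa.load("example.wav", sr=24000, mono=True)
    speakerdia = pd.DataFrame(
        [
            {"start": 0.0, "end": 12.5, "speaker": "SPEAKER_00"},
            {"start": 12.5, "end": 47.3, "speaker": "SPEAKER_01"},
        ]
    )

    vad = SileroVAD()
    segments = vad.vad(speakerdia, {"waveform": waveform, "sample_rate": sr})
    for seg in segments:
        print(seg["index"], seg["speaker"], f"{seg['start']:.2f}-{seg['end']:.2f}s")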