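"""Voice activity detection built on Silero-VAD.

Refines diarized speech regions: short segments are kept whole, and
segments longer than VAD_THRESHOLD seconds are split recursively at the
widest silence gap found by the model.
"""
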
import librosa
import numpy as np
import torch

# Diarized segments no longer than this many seconds are kept whole;
# longer ones are split further by the VAD.
VAD_THRESHOLD = 20
# Sampling rate (Hz) the Silero VAD model is run at.
SAMPLING_RATE = 16000


class SileroVAD:
    """
    Voice Activity Detection (VAD) using Silero-VAD.
    """

    def __init__(self, local=False, model="silero_vad", device=torch.device("cpu")):
""" |
|
Initialize the VAD object. |
|
|
|
Args: |
|
local (bool, optional): Whether to load the model locally. Defaults to False. |
|
model (str, optional): The VAD model name to load. Defaults to "silero_vad". |
|
device (torch.device, optional): The device to run the model on. Defaults to 'cpu'. |
|
|
|
Returns: |
|
None |
|
|
|
Raises: |
|
RuntimeError: If loading the model fails. |
|
""" |
|
try: |
|
vad_model, utils = torch.hub.load( |
|
repo_or_dir="snakers4/silero-vad" if not local else "vad/silero-vad", |
|
model=model, |
|
force_reload=False, |
|
onnx=True, |
|
source="github" if not local else "local", |
|
) |
|
self.vad_model = vad_model |
|
(get_speech_timestamps, _, _, _, _) = utils |
|
self.get_speech_timestamps = get_speech_timestamps |
|
except Exception as e: |
|
raise RuntimeError(f"Failed to load VAD model: {e}") |
|
|
|
    def segment_speech(self, audio_segment, start_time, end_time, sampling_rate):
        """
        Segment speech from an audio segment and return a list of timestamps.

        Args:
            audio_segment (np.ndarray): The audio segment to be segmented.
            start_time (int): The start time of the audio segment in samples.
            end_time (int): The end time of the audio segment in samples (currently unused).
            sampling_rate (int): The sampling rate of the audio segment.

        Returns:
            list: A list of [start, end] pairs marking speech segments, in samples.

        Raises:
            ValueError: If the audio segment is invalid.
        """
        if audio_segment is None or not isinstance(audio_segment, (np.ndarray, list)):
            raise ValueError("Invalid audio segment")

        speech_timestamps = self.get_speech_timestamps(
            audio_segment, self.vad_model, sampling_rate=sampling_rate
        )

        # Shift Silero's segment-relative timestamps to absolute positions.
        adjusted_timestamps = [
            (ts["start"] + start_time, ts["end"] + start_time)
            for ts in speech_timestamps
        ]
        if not adjusted_timestamps:
            return []

        # Silence gap (in samples) between each pair of consecutive speech segments.
        intervals = [
            end[0] - start[1]
            for start, end in zip(adjusted_timestamps[:-1], adjusted_timestamps[1:])
        ]

        segments = []

        def split_timestamps(start_index, end_index):
            # Keep the run whole if it is a single segment or already spans
            # less than VAD_THRESHOLD seconds; otherwise split recursively
            # at the widest silence gap.
            if (
                start_index == end_index
                or adjusted_timestamps[end_index][1]
                - adjusted_timestamps[start_index][0]
                < VAD_THRESHOLD * sampling_rate
            ):
                segments.append([start_index, end_index])
            else:
                if not intervals[start_index:end_index]:
                    return
                max_interval_index = intervals[start_index:end_index].index(
                    max(intervals[start_index:end_index])
                )
                split_index = start_index + max_interval_index
                split_timestamps(start_index, split_index)
                split_timestamps(split_index + 1, end_index)

        split_timestamps(0, len(adjusted_timestamps) - 1)

        # Merge each run of segment indices back into a single [start, end] span.
        merged_timestamps = [
            [adjusted_timestamps[start][0], adjusted_timestamps[end][1]]
            for start, end in segments
        ]
        return merged_timestamps

    def vad(self, speakerdia, audio):
        """
        Process the audio based on the given speaker diarization dataframe.

        Args:
            speakerdia (pd.DataFrame): The diarization dataframe containing start, end, and speaker info.
            audio (dict): A dictionary containing the audio waveform and sample rate.

        Returns:
            list: A list of dictionaries containing processed audio segments with start, end (in seconds), and speaker.
        """
        sampling_rate = audio["sample_rate"]
        audio_data = audio["waveform"]

        out = []
        last_end = 0
        speakers_seen = set()
        count_id = 0

        for _, row in speakerdia.iterrows():
            start = float(row["start"])
            end = float(row["end"])

            # Skip rows fully contained in an already-processed span.
            if end <= last_end:
                continue
            last_end = end

            start_frame = int(start * sampling_rate)
            end_frame = int(end * sampling_rate)
            if row["speaker"] not in speakers_seen:
                speakers_seen.add(row["speaker"])

            # Short segments are emitted as-is, without further VAD splitting.
            if end - start <= VAD_THRESHOLD:
                out.append(
                    {
                        "index": str(count_id).zfill(5),
                        "start": start,
                        "end": end,
                        "speaker": row["speaker"],
                    }
                )
                count_id += 1
                continue

            temp_audio = audio_data[start_frame:end_frame]

            # Resample to the rate the VAD model expects.
            temp_audio_resampled = librosa.resample(
                temp_audio, orig_sr=sampling_rate, target_sr=SAMPLING_RATE
            )

            for start_frame_sub, end_frame_sub in self.segment_speech(
                temp_audio_resampled,
                int(start * SAMPLING_RATE),
                int(end * SAMPLING_RATE),
                SAMPLING_RATE,
            ):
                out.append(
                    {
                        "index": str(count_id).zfill(5),
                        "start": start_frame_sub / SAMPLING_RATE,
                        "end": end_frame_sub / SAMPLING_RATE,
                        "speaker": row["speaker"],
                    }
                )
                count_id += 1

        return out
|
|
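if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Assumptions:
    # a mono float32 waveform, a diarization DataFrame with "start"/"end"/
    # "speaker" columns in seconds, and the extra dependencies pandas and
    # soundfile. The file path and rows below are illustrative only.
    import pandas as pd
    import soundfile as sf

    wav, sr = sf.read("example.wav", dtype="float32")  # hypothetical input file
    diarization = pd.DataFrame(
        [
            {"start": 0.0, "end": 12.5, "speaker": "SPEAKER_00"},
            {"start": 12.5, "end": 48.0, "speaker": "SPEAKER_01"},
        ]
    )

    vad = SileroVAD()
    segments = vad.vad(diarization, {"waveform": wav, "sample_rate": sr})
    for seg in segments:
        print(seg["index"], seg["speaker"], f'{seg["start"]:.2f}-{seg["end"]:.2f}')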