|
import torch
|
|
from speechbrain.inference.interfaces import Pretrained
|
|
import torchaudio
|
|
|
|
|
|
def merge_overlapping_segments(segments):
    """
    Merge neighbouring segments that share a label and overlap or touch.

    Only consecutive entries are compared, so the input is expected to be
    ordered by start time. A segment whose start is at or before the end of
    the last kept segment, and whose label matches it, is folded into that
    segment (the end becomes the larger of the two ends). Every other segment
    is kept unchanged.

    Args:
        segments (list of tuples): (start, end, label) segments.

    Returns:
        list of tuples: The merged segments. The first input segment is kept
        as-is unless later segments extend it.
    """
    if not segments:
        return []

    result = [segments[0]]
    for seg in segments[1:]:
        last = result[-1]
        # Overlap/touch is checked first; the label only matters when they touch.
        if seg[0] <= last[1] and seg[2] == last[2]:
            result[-1] = (last[0], max(last[1], seg[1]), last[2])
            continue
        result.append(seg)
    return result
|
|
|
|
|
|
def refine_transitions(aggregated_predictions, max_gap=1.0):
    """
    Snap segment starts onto the preceding segment's end when they are close.

    For every segment after the first, if its start is at most ``max_gap``
    seconds after the previous segment's end, the start is moved to that end.
    This also applies when the segments overlap, because the gap is then
    negative. The comparison uses the previous segment as given in the input,
    not its refined version. Segment ends and labels are never changed.

    Args:
        aggregated_predictions (list of tuples): (start, end, label) segments,
            expected to be ordered by start time.
        max_gap (float): Largest gap, in the same time units as the segments
            (seconds in this file), that is closed by snapping. Defaults to 1.0.

    Returns:
        list of tuples: Predictions with adjusted start times. The first segment
        is returned unchanged.
    """
    if not aggregated_predictions:
        return []

    refined_predictions = [aggregated_predictions[0]]
    for (_, prev_end, _), (current_start, current_end, current_label) in zip(
        aggregated_predictions, aggregated_predictions[1:]
    ):
        if current_start - prev_end <= max_gap:
            # Never snap past this segment's own end. If the previous segment
            # fully contains this one, snapping to prev_end would create an
            # inverted (start > end) segment.
            new_start = min(prev_end, current_end)
        else:
            new_start = current_start
        refined_predictions.append((new_start, current_end, current_label))

    return refined_predictions
|
|
|
|
|
|
def refine_transitions_with_confidence(aggregated_predictions, segment_confidences):
    """
    Refines transitions between segments based on confidence levels.

    Each segment after the first is compared with the last refined segment.
    The confidences used are ``segment_confidences[i - 1]`` and
    ``segment_confidences[i]``, i.e. those of the original input segments.

    * Different labels: if the current segment is more confident, the boundary
      is placed at its start; otherwise at the previous segment's end. Both
      segments are cut or extended to meet at that point, so any overlap or
      gap between them is removed.
    * Same label: if the current segment is more confident, it is merged into
      the previous one. Otherwise it is kept as a separate segment.

    Args:
        aggregated_predictions (list of tuples): (start, end, label) segments,
            expected to be ordered by start time.
        segment_confidences (list of float): One confidence score per segment,
            index-aligned with ``aggregated_predictions``.

    Returns:
        list of tuples: Refined (start, end, label) segments.

    Raises:
        ValueError: If there are at least two segments and fewer confidences
            than segments.
    """
    num_segments = len(aggregated_predictions)
    if num_segments == 0:
        return []
    if num_segments > 1 and len(segment_confidences) < num_segments:
        raise ValueError(
            f"Expected {num_segments} confidences, got {len(segment_confidences)}."
        )

    refined_predictions = [aggregated_predictions[0]]
    for i in range(1, num_segments):
        current_start, current_end, current_label = aggregated_predictions[i]
        # Plain unpacking works for tuple- and list-shaped segments alike.
        prev_start, prev_end, prev_label = refined_predictions[-1]
        prev_confidence = segment_confidences[i - 1]
        current_confidence = segment_confidences[i]

        if current_label != prev_label:
            if prev_confidence < current_confidence:
                transition_point = current_start
            else:
                transition_point = prev_end
            refined_predictions[-1] = (prev_start, transition_point, prev_label)
            refined_predictions.append((transition_point, current_end, current_label))
        elif prev_confidence < current_confidence:
            # Merge same-label segments. Keep the furthest end so a previous
            # segment that extends past the current one is not shortened.
            refined_predictions[-1] = (prev_start, max(prev_end, current_end), current_label)
        else:
            refined_predictions.append((current_start, current_end, current_label))

    return refined_predictions
|
|
|
|
|
|
def aggregate_segments_with_overlap(segment_predictions):
    """
    Aggregates overlapping segments into single segments based on speaker labels.

    Consecutive segments are folded into one while they share the running
    segment's label and start at or before its current end. The merged end is
    the largest end seen, and the start is that of the first segment in the
    run. Any other segment starts a new aggregated segment.

    Args:
        segment_predictions (list of tuples): (start, end, label) segments,
            expected to be ordered by start time.

    Returns:
        list of tuples: Aggregated (start, end, label) segments. Returns an
        empty list for empty input.
    """
    if not segment_predictions:
        return []

    aggregated_predictions = []
    last_start, last_end, last_label = segment_predictions[0]

    for start, end, label in segment_predictions[1:]:
        if label == last_label and start <= last_end:
            last_end = max(last_end, end)
        else:
            aggregated_predictions.append((last_start, last_end, last_label))
            last_start, last_end, last_label = start, end, label

    aggregated_predictions.append((last_start, last_end, last_label))

    # No second merge pass is needed: every pair of consecutive outputs already
    # differs in label or is separated by a gap, so merging same-label
    # overlapping neighbours again would change nothing.
    return aggregated_predictions
|
|
|
|
|
|
class SpeakerCounter(Pretrained):
    """
    A class for counting speakers in an audio file, built upon the SpeechBrain Pretrained class.

    Audio is split into fixed-length, overlapping windows. Each window is
    embedded and classified, and the per-window labels are aggregated into
    time segments whose boundaries are refined using the classifier scores.
    """

    MODULES_NEEDED = [
        "compute_features",
        "mean_var_norm",
        "embedding_model",
        "classifier",
    ]

    def __init__(self, *args, **kwargs):
        """
        Initialize the SpeakerCounter with standard and custom parameters.

        Args:
            *args: Positional arguments forwarded to ``Pretrained``.
            **kwargs: Keyword arguments forwarded to ``Pretrained``.
        """
        super().__init__(*args, **kwargs)
        # Target sample rate for all audio fed to the models, taken from hparams.
        self.sample_rate = self.hparams.sample_rate

    def resample_waveform(self, waveform, orig_sample_rate):
        """
        Resamples the input waveform to ``self.sample_rate``.

        Args:
            waveform (Tensor): The input waveform tensor.
            orig_sample_rate (int): The original sample rate of the waveform.

        Returns:
            Tensor: The resampled waveform, or the input unchanged if the rates
            already match.
        """
        if orig_sample_rate != self.sample_rate:
            resample_transform = torchaudio.transforms.Resample(
                orig_freq=orig_sample_rate, new_freq=self.sample_rate
            )
            waveform = resample_transform(waveform)
        return waveform

    def encode_batch(self, wavs, wav_lens=None):
        """
        Encodes a batch of waveforms into embeddings using the loaded models.

        Args:
            wavs (Tensor): Batch of waveforms. A 1-D tensor is treated as a
                batch of one.
            wav_lens (Tensor, optional): Relative lengths of the waveforms, used
                for normalization. Defaults to all ones (full length).

        Returns:
            Tensor: Batch of embeddings as produced by ``embedding_model``.
        """
        if len(wavs.shape) == 1:
            wavs = wavs.unsqueeze(0)

        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()

        feats = self.mods.compute_features(wavs)
        feats = self.mods.mean_var_norm(feats, wav_lens)
        embeddings = self.mods.embedding_model(feats, wav_lens)
        return embeddings

    def classify_batch(self, wavs, wav_lens=None):
        """
        Classifies a batch of waveforms.

        Args:
            wavs (Tensor): Batch of waveforms (see ``encode_batch``).
            wav_lens (Tensor, optional): Relative lengths of the waveforms.

        Returns:
            tuple: ``(out_prob, score, index)``. ``out_prob`` is the classifier
            output with dimension 1 squeezed. ``score`` and ``index`` are the
            max value and argmax over its last dimension.
        """
        emb = self.encode_batch(wavs, wav_lens)
        out_prob = self.mods.classifier(emb).squeeze(1)
        score, index = torch.max(out_prob, dim=-1)
        return out_prob, score, index

    def create_segments(self, waveform, segment_length, overlap):
        """
        Creates segments from a single waveform for batch processing.

        Windows are ``segment_length`` seconds long and advance by
        ``segment_length - overlap`` seconds. Trailing audio that does not fill
        a whole window is not covered. Audio shorter than one window is
        returned as a single segment.

        Args:
            waveform (Tensor): Input waveform tensor indexed as
                ``[channel, sample]``.
            segment_length (float): Length of each segment in seconds.
            overlap (float): Overlap between segments in seconds.

        Returns:
            tuple: (segments, segment_times) where segments is a list of tensors,
            and segment_times is a list of (start, end) times in seconds.

        Raises:
            ValueError: If the window is shorter than one sample, or the overlap
                is not shorter than the window (which would give a zero or
                negative step).
        """
        num_samples = waveform.shape[1]
        segment_samples = int(segment_length * self.sample_rate)
        overlap_samples = int(overlap * self.sample_rate)
        step_samples = segment_samples - overlap_samples

        if segment_samples <= 0:
            raise ValueError("segment_length must cover at least one sample.")
        if step_samples <= 0:
            raise ValueError("overlap must be shorter than segment_length.")

        if 0 < num_samples < segment_samples:
            # Too short for a full window: classify the whole clip once rather
            # than returning no segments at all.
            return [waveform], [(0.0, num_samples / self.sample_rate)]

        segments = []
        segment_times = []
        for start in range(0, num_samples - segment_samples + 1, step_samples):
            end = start + segment_samples
            segments.append(waveform[:, start:end])
            segment_times.append((start / self.sample_rate, end / self.sample_rate))

        return segments, segment_times

    @staticmethod
    def _window_group_confidences(window_predictions, window_scores):
        """
        Averages window scores over each run of windows that becomes one
        aggregated segment.

        The grouping rule (same label as the running group, starting no later
        than the group's current end) mirrors ``aggregate_segments_with_overlap``.
        The returned list is therefore index-aligned with that function's output.
        """
        confidences = []
        group_scores = []
        group_label = None
        group_end = None
        for (start, end, label), score in zip(window_predictions, window_scores):
            if group_scores and label == group_label and start <= group_end:
                group_scores.append(score)
                group_end = max(group_end, end)
            else:
                if group_scores:
                    confidences.append(sum(group_scores) / len(group_scores))
                group_scores = [score]
                group_label = label
                group_end = end
        if group_scores:
            confidences.append(sum(group_scores) / len(group_scores))
        return confidences

    def classify_file(self, path, segment_length=2.0, overlap=1.47,
                      output_file="sample_segment_predictions.txt"):
        """
        Processes an audio file to classify and count speakers within segments.

        The file is downmixed to mono and resampled, then cut into overlapping
        windows. Each window is classified, and the labels are aggregated into
        segments. Boundaries between segments are then refined using the mean
        classifier score of the windows behind each segment.

        Args:
            path (str): Path to the audio file.
            segment_length (float): Length of each segment in seconds.
            overlap (float): Overlap between segments in seconds.
            output_file (str or None): File the per-segment results are written
                to. Pass ``None`` to skip writing.

        Returns:
            list of tuples: The final (start, end, label) segments. Each result
            line is also printed and, unless ``output_file`` is None, written
            to ``output_file``.
        """
        waveform, osr = torchaudio.load(path)
        # torchaudio.load returns (channels, samples). Each window is classified
        # as a single-item batch, so multi-channel audio is downmixed first.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = self.resample_waveform(waveform, osr)

        segments, segment_times = self.create_segments(waveform, segment_length, overlap)

        segment_predictions = []
        segment_scores = []
        for segment, (start_time, end_time) in zip(segments, segment_times):
            rel_length = torch.tensor([1.0])
            _, score, index = self.classify_batch(segment, rel_length)
            segment_predictions.append((start_time, end_time, index.item()))
            segment_scores.append(score.item())

        if segment_predictions:
            aggregated_predictions = aggregate_segments_with_overlap(segment_predictions)
            confidences = self._window_group_confidences(segment_predictions, segment_scores)
            if len(confidences) != len(aggregated_predictions):
                raise RuntimeError(
                    "Segment confidences are not aligned with the aggregated segments."
                )
            # NOTE: previously refine_transitions' output was passed here in place
            # of the confidences. The confidence step now gets real per-segment
            # scores (mean window score; possibly log-probabilities depending on
            # the classifier — confirm).
            preds = refine_transitions_with_confidence(aggregated_predictions, confidences)
        else:
            preds = []

        lines = []
        for start_time, end_time, prediction in preds:
            speaker_text = "no speech" if str(prediction) == "0" else (
                "1 speaker" if str(prediction) == "1" else f"{prediction} speakers")
            line = f"{start_time:.2f}-{end_time:.2f} has {speaker_text}"
            print(line)
            lines.append(line)

        if output_file is not None:
            with open(output_file, "w", encoding="utf-8") as file:
                file.writelines(f"{line}\n" for line in lines)

        return preds

    def forward(self, wavs, wav_lens=None):
        """
        Forward pass for classifying audio using preloaded modules.

        Args:
            wavs (str or Tensor): Path to an audio file, or a batch of waveforms.
            wav_lens (Tensor, optional): Relative lengths of the input
                waveforms. Ignored when ``wavs`` is a path.

        Returns:
            For a path: the segments returned by ``classify_file`` (run with its
            default settings). For tensors: ``(out_prob, score, index)`` from
            ``classify_batch``.
        """
        if isinstance(wavs, str):
            return self.classify_file(wavs)
        return self.classify_batch(wavs, wav_lens)
|
|
|