Commit 4279f1f
Parent: 89e3c59
SpeakerCounter
SpeakerCounter.py  ADDED  (+267 -0)
@@ -0,0 +1,267 @@
import torch
import torchaudio

from speechbrain.inference.interfaces import Pretrained


def merge_overlapping_segments(segments):
    """
    Merges segments that overlap or are contiguous, ensuring each speaker segment is represented once.

    Args:
        segments (list of tuples): List of (start, end, label) tuples, sorted by start time.

    Returns:
        list of tuples: Merged list of segments.
    """
    if not segments:
        return []
    merged = [segments[0]]
    for current in segments[1:]:
        prev = merged[-1]
        if current[0] <= prev[1]:
            # Overlapping or contiguous with the same label: fuse into one segment.
            if current[2] == prev[2]:
                merged[-1] = (prev[0], max(prev[1], current[1]), prev[2])
            else:
                merged.append(current)
        else:
            merged.append(current)
    return merged

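# A worked example with hypothetical values: merging
#   [(0.0, 2.0, 1), (1.5, 3.5, 1), (3.5, 5.0, 2)]
# yields [(0.0, 3.5, 1), (3.5, 5.0, 2)]. The first two segments overlap and share
# label 1, so they fuse; the label-2 segment keeps its own entry because its label
# differs, even though it is contiguous with the fused segment.
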
def refine_transitions(aggregated_predictions):
    """
    Refines transitions between speaker segments to enhance accuracy.

    Args:
        aggregated_predictions (list of tuples): The aggregated predictions with potential overlaps.

    Returns:
        list of tuples: Predictions with adjusted transitions.
    """
    refined_predictions = []
    for i in range(len(aggregated_predictions)):
        if i == 0:
            refined_predictions.append(aggregated_predictions[i])
            continue

        current_start, current_end, current_label = aggregated_predictions[i]
        prev_start, prev_end, prev_label = aggregated_predictions[i - 1]

        # Close gaps of at most one second by snapping the start to the previous end.
        if current_start - prev_end <= 1.0:
            new_start = prev_end
        else:
            new_start = current_start

        refined_predictions.append((new_start, current_end, current_label))

    return refined_predictions

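# A worked example with hypothetical values: [(0.0, 2.0, 1), (2.5, 4.0, 2)] has a
# 0.5 s gap, so the second segment is snapped back to (2.0, 4.0, 2); if it instead
# started at 3.5 the gap would exceed 1.0 s and it would be left unchanged.
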
def refine_transitions_with_confidence(aggregated_predictions, segment_confidences):
    """
    Refines transitions between segments based on confidence levels.

    Args:
        aggregated_predictions (list of tuples): Initial aggregated predictions.
        segment_confidences (list of float): Confidence scores, one per segment.

    Returns:
        list of tuples: Refined segment predictions.
    """
    refined_predictions = []
    for i in range(len(aggregated_predictions)):
        if i == 0:
            refined_predictions.append(aggregated_predictions[i])
            continue

        current_start, current_end, current_label = aggregated_predictions[i]
        prev_start, prev_end, prev_label = refined_predictions[-1]
        prev_confidence = segment_confidences[i - 1]
        current_confidence = segment_confidences[i]

        if current_label != prev_label:
            # Place the boundary inside the less confident of the two segments.
            if prev_confidence < current_confidence:
                transition_point = current_start
            else:
                transition_point = prev_end
            refined_predictions[-1] = (prev_start, transition_point, prev_label)
            refined_predictions.append((transition_point, current_end, current_label))
        else:
            # Same label: extend the previous segment when the new evidence is stronger.
            if prev_confidence < current_confidence:
                refined_predictions[-1] = (prev_start, current_end, current_label)
            else:
                refined_predictions.append((current_start, current_end, current_label))

    return refined_predictions

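# A worked example with hypothetical values: [(0.0, 2.5, 1), (2.0, 4.0, 2)] with
# confidences [0.6, 0.9] trims the weaker first segment, giving
# [(0.0, 2.0, 1), (2.0, 4.0, 2)]; with confidences [0.9, 0.6] the boundary moves
# to 2.5 instead, giving [(0.0, 2.5, 1), (2.5, 4.0, 2)].
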
def aggregate_segments_with_overlap(segment_predictions):
    """
    Aggregates overlapping segments into single segments based on speaker labels.

    Args:
        segment_predictions (list of tuples): List of (start, end, label) tuples.

    Returns:
        list of tuples: Aggregated segments.
    """
    # Guard against empty input (e.g. audio shorter than one segment window).
    if not segment_predictions:
        return []

    aggregated_predictions = []
    last_start, last_end, last_label = segment_predictions[0]

    for start, end, label in segment_predictions[1:]:
        if label == last_label and start <= last_end:
            last_end = max(last_end, end)
        else:
            aggregated_predictions.append((last_start, last_end, last_label))
            last_start, last_end, last_label = start, end, label

    aggregated_predictions.append((last_start, last_end, last_label))

    return merge_overlapping_segments(aggregated_predictions)

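# A worked example with hypothetical values: overlapping windows predicted as
#   [(0.0, 2.0, 1), (0.53, 2.53, 1), (1.06, 3.06, 2)]
# aggregate to [(0.0, 2.53, 1), (1.06, 3.06, 2)]: the first two share a label and
# overlap, so they fuse; the label change starts a new segment even though it
# overlaps the fused one.
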
class SpeakerCounter(Pretrained):
    """
    A class for counting speakers in an audio file, built upon the SpeechBrain Pretrained class.

    This class integrates several preprocessing and prediction modules to handle speaker diarization tasks.
    """

    # Modules the loaded hparams must provide for inference.
    MODULES_NEEDED = [
        "compute_features",
        "mean_var_norm",
        "embedding_model",
        "classifier",
    ]

    def __init__(self, *args, **kwargs):
        """
        Initialize the SpeakerCounter with standard and custom parameters.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, **kwargs)
        self.sample_rate = self.hparams.sample_rate

    def resample_waveform(self, waveform, orig_sample_rate):
        """
        Resamples the input waveform to the target sample rate specified in the object.

        Args:
            waveform (Tensor): The input waveform tensor.
            orig_sample_rate (int): The original sample rate of the waveform.

        Returns:
            Tensor: The resampled waveform.
        """
        if orig_sample_rate != self.sample_rate:
            resample_transform = torchaudio.transforms.Resample(
                orig_freq=orig_sample_rate, new_freq=self.sample_rate
            )
            waveform = resample_transform(waveform)
        return waveform

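    # A shape example, assuming hparams set a 16 kHz target rate: a ten-second
    # 44.1 kHz stereo tensor of shape (2, 441000) comes back as (2, 160000), while
    # audio already at the target rate passes through unchanged.
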
    def encode_batch(self, wavs, wav_lens=None):
        """
        Encodes a batch of waveforms into embeddings using the loaded models.

        Args:
            wavs (Tensor): Batch of waveforms.
            wav_lens (Tensor, optional): Relative lengths of the waveforms for normalization.

        Returns:
            Tensor: Batch of embeddings.
        """
        if len(wavs.shape) == 1:
            wavs = wavs.unsqueeze(0)

        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()

        # Compute features, normalize them, then embed.
        feats = self.mods.compute_features(wavs)
        feats = self.mods.mean_var_norm(feats, wav_lens)
        embeddings = self.mods.embedding_model(feats, wav_lens)
        return embeddings

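    # Assumed shapes, for orientation only: with ECAPA-style SpeechBrain embedding
    # models, a (batch, time) wav tensor typically yields embeddings of shape
    # (batch, 1, embedding_dim); the exact shape depends on the loaded embedding_model.
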
    def create_segments(self, waveform, segment_length, overlap):
        """
        Creates segments from a single waveform for batch processing.

        Args:
            waveform (Tensor): Input waveform tensor of shape (channels, samples).
            segment_length (float): Length of each segment in seconds.
            overlap (float): Overlap between consecutive segments in seconds.

        Returns:
            tuple: (segments, segment_times) where segments is a list of tensors and
            segment_times is a list of (start, end) times in seconds.
        """
        num_samples = waveform.shape[1]
        segment_samples = int(segment_length * self.sample_rate)
        overlap_samples = int(overlap * self.sample_rate)
        step_samples = segment_samples - overlap_samples
        segments = []
        segment_times = []

        # Slide a fixed-size window across the waveform; any tail shorter than a
        # full segment is dropped.
        for start in range(0, num_samples - segment_samples + 1, step_samples):
            end = start + segment_samples
            segments.append(waveform[:, start:end])
            segment_times.append((start / self.sample_rate, end / self.sample_rate))

        return segments, segment_times

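    # Window arithmetic with the classify_file defaults, assuming a 16 kHz target
    # rate: segment_length=2.0 gives segment_samples=32000 and overlap=1.47 gives
    # overlap_samples=23520, so the hop is 8480 samples (0.53 s); a 10 s file
    # yields floor((160000 - 32000) / 8480) + 1 = 16 windows.
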
    def classify_file(self, path, segment_length=2.0, overlap=1.47):
        """
        Processes an audio file to classify and count speakers within segments.
        Utilizes multiple stages of processing to handle overlapping speech and transitions.

        Args:
            path (str): Path to the audio file.
            segment_length (float): Length of each segment in seconds.
            overlap (float): Overlap between segments in seconds.

        Outputs:
            Writes the number of speakers in each segment to a text file.
        """
        waveform, orig_sample_rate = torchaudio.load(path)
        waveform = self.resample_waveform(waveform, orig_sample_rate)

        segments, segment_times = self.create_segments(waveform, segment_length, overlap)
        segment_predictions = []
        segment_confidences = []

        for segment, (start_time, end_time) in zip(segments, segment_times):
            rel_length = torch.tensor([1.0])
            emb = self.encode_batch(segment, rel_length)
            out_prob = self.mods.classifier(emb).squeeze(1)
            score, index = torch.max(out_prob, dim=-1)
            segment_predictions.append((start_time, end_time, index.item()))
            segment_confidences.append(score.item())

        # The confidence-based pass needs one confidence score per segment it
        # compares, so it runs on the raw window predictions. (The original call
        # passed the transition-refined tuples where confidence floats were
        # expected, which was a bug.)
        confident_predictions = refine_transitions_with_confidence(
            segment_predictions, segment_confidences
        )
        aggregated_predictions = aggregate_segments_with_overlap(confident_predictions)
        preds = refine_transitions(aggregated_predictions)

        # Write one line per segment to a fixed output file.
        with open("sample_segment_predictions.txt", "w") as file:
            for start_time, end_time, prediction in preds:
                if prediction == 0:
                    speaker_text = "no speech"
                elif prediction == 1:
                    speaker_text = "1 speaker"
                else:
                    speaker_text = f"{prediction} speakers"
                print(f"{start_time:.2f}-{end_time:.2f} has {speaker_text}")
                file.write(f"{start_time:.2f}-{end_time:.2f} has {speaker_text}\n")

    def forward(self, wavs, wav_lens=None):
        """
        Forward pass for classifying audio using the preloaded modules.

        Args:
            wavs (str): Path to the audio file to classify.
            wav_lens (Tensor, optional): Unused; kept for interface compatibility.

        Returns:
            Output from the classify_file method.
        """
        # classify_file takes (path, segment_length, overlap); the original call
        # passed wav_lens as segment_length, which would misconfigure the windowing.
        return self.classify_file(wavs)
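

# A minimal usage sketch. The model source below is hypothetical; point it at
# wherever this model's hyperparams.yaml and checkpoint actually live.
if __name__ == "__main__":
    counter = SpeakerCounter.from_hparams(
        source="path/to/speaker-counter-model",  # hypothetical location
        savedir="pretrained_models/speaker-counter",
    )
    counter.classify_file("sample.wav", segment_length=2.0, overlap=1.47)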