speechbrain
English
Supradeepdan committed on
Commit
4279f1f
1 Parent(s): 89e3c59

SpeakerCounter

Files changed (1)
  1. SpeakerCounter.py +267 -0
SpeakerCounter.py ADDED
@@ -0,0 +1,267 @@
import torch
from speechbrain.inference.interfaces import Pretrained
import torchaudio


def merge_overlapping_segments(segments):
    """
    Merges segments that overlap or are contiguous, ensuring each speaker segment is represented once.

    Args:
        segments (list of tuples): List of tuples representing (start, end, label) of segments.

    Returns:
        list of tuples: Merged list of segments.
    """
    if not segments:
        return []
    merged = [segments[0]]
    for current in segments[1:]:
        prev = merged[-1]
        if current[0] <= prev[1]:
            if current[2] == prev[2]:
                merged[-1] = (prev[0], max(prev[1], current[1]), prev[2])
            else:
                merged.append(current)
        else:
            merged.append(current)
    return merged

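
# Illustrative example: overlapping or contiguous windows with the same label are
# merged into one segment, while an overlapping window with a different label is
# kept as its own segment.
#
#     merge_overlapping_segments([(0.0, 2.0, 1), (1.5, 3.5, 1), (4.0, 6.0, 2)])
#     # -> [(0.0, 3.5, 1), (4.0, 6.0, 2)]
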

def refine_transitions(aggregated_predictions):
    """
    Refines transitions between speaker segments by closing short gaps.

    If two consecutive segments are separated by at most 1.0 second, the start of
    the later segment is snapped to the end of the previous one.

    Args:
        aggregated_predictions (list of tuples): The aggregated predictions with potential overlaps.

    Returns:
        list of tuples: Predictions with adjusted transitions.
    """
    refined_predictions = []
    for i in range(len(aggregated_predictions)):
        if i == 0:
            refined_predictions.append(aggregated_predictions[i])
            continue

        current_start, current_end, current_label = aggregated_predictions[i]
        prev_start, prev_end, prev_label = aggregated_predictions[i - 1]

        if current_start - prev_end <= 1.0:
            new_start = prev_end
        else:
            new_start = current_start

        refined_predictions.append((new_start, current_end, current_label))

    return refined_predictions

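
# Illustrative example: a 0.5 s gap between segments is closed by snapping the
# second segment's start to the previous segment's end (gaps larger than 1.0 s
# are left untouched).
#
#     refine_transitions([(0.0, 2.0, 1), (2.5, 4.0, 2)])
#     # -> [(0.0, 2.0, 1), (2.0, 4.0, 2)]
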
60
+ def refine_transitions_with_confidence(aggregated_predictions, segment_confidences):
61
+ """
62
+ Refines transitions between segments based on confidence levels.
63
+
64
+ Args:
65
+ aggregated_predictions (list of tuples): Initial aggregated predictions.
66
+ segment_confidences (list of float): Confidence scores corresponding to each segment.
67
+
68
+ Returns:
69
+ list of tuples: Refined segment predictions.
70
+ """
71
+ refined_predictions = []
72
+ for i in range(len(aggregated_predictions)):
73
+ if i == 0:
74
+ refined_predictions.append(aggregated_predictions[i])
75
+ continue
76
+
77
+ current_start, current_end, current_label = aggregated_predictions[i]
78
+ prev_start, prev_end, prev_label, prev_confidence = refined_predictions[-1] + (segment_confidences[i - 1],)
79
+
80
+ current_confidence = segment_confidences[i]
81
+
82
+ if current_label != prev_label:
83
+ if prev_confidence < current_confidence:
84
+ transition_point = current_start
85
+ else:
86
+ transition_point = prev_end
87
+ refined_predictions[-1] = (prev_start, transition_point, prev_label)
88
+ refined_predictions.append((transition_point, current_end, current_label))
89
+ else:
90
+ if prev_confidence < current_confidence:
91
+ refined_predictions[-1] = (prev_start, current_end, current_label)
92
+ else:
93
+ refined_predictions.append((current_start, current_end, current_label))
94
+
95
+ return refined_predictions
96
+
97
+
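
# Illustrative example: where the label changes, the boundary is taken from the
# more confident side; here the second segment (confidence 0.9 vs 0.6) keeps its
# own start time.
#
#     refine_transitions_with_confidence(
#         [(0.0, 2.0, 1), (1.5, 4.0, 2)], segment_confidences=[0.6, 0.9]
#     )
#     # -> [(0.0, 1.5, 1), (1.5, 4.0, 2)]
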
98
+ def aggregate_segments_with_overlap(segment_predictions):
99
+ """
100
+ Aggregates overlapping segments into single segments based on speaker labels.
101
+
102
+ Args:
103
+ segment_predictions (list of tuples): List of tuples representing (start, end, label) of segments.
104
+
105
+ Returns:
106
+ list of tuples: Aggregated segments.
107
+ """
108
+ aggregated_predictions = []
109
+ last_start, last_end, last_label = segment_predictions[0]
110
+
111
+ for start, end, label in segment_predictions[1:]:
112
+ if label == last_label and start <= last_end:
113
+ last_end = max(last_end, end)
114
+ else:
115
+ aggregated_predictions.append((last_start, last_end, last_label))
116
+ last_start, last_end, last_label = start, end, label
117
+
118
+ aggregated_predictions.append((last_start, last_end, last_label))
119
+
120
+ merged = merge_overlapping_segments(aggregated_predictions)
121
+ return merged
122
+
123
+
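
# Illustrative example: consecutive same-label windows are collapsed into one
# segment, while an overlapping window with a different label remains separate.
#
#     aggregate_segments_with_overlap(
#         [(0.0, 2.0, 1), (0.53, 2.53, 1), (1.06, 3.06, 2)]
#     )
#     # -> [(0.0, 2.53, 1), (1.06, 3.06, 2)]
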
124
+ class SpeakerCounter(Pretrained):
125
+ """
126
+ A class for counting speakers in an audio file, built upon the SpeechBrain Pretrained class.
127
+ This class integrates several preprocessing and prediction modules to handle speaker diarization tasks.
128
+ """
129
+
130
+ def __init__(self, *args, **kwargs):
131
+ """
132
+ Initialize the SpeakerCounter with standard and custom parameters.
133
+ Args:
134
+ *args: Variable length argument list.
135
+ **kwargs: Arbitrary keyword arguments.
136
+ """
137
+ super().__init__(*args, **kwargs)
138
+ self.sample_rate = self.hparams.sample_rate
139
+
140
+ MODULES_NEEDED = [
141
+ "compute_features",
142
+ "mean_var_norm",
143
+ "embedding_model",
144
+ "classifier",
145
+ ]
146
+
    def resample_waveform(self, waveform, orig_sample_rate):
        """
        Resamples the input waveform to the target sample rate specified in the object.

        Args:
            waveform (Tensor): The input waveform tensor.
            orig_sample_rate (int): The original sample rate of the waveform.

        Returns:
            Tensor: The resampled waveform.
        """
        if orig_sample_rate != self.sample_rate:
            resample_transform = torchaudio.transforms.Resample(
                orig_freq=orig_sample_rate, new_freq=self.sample_rate
            )
            waveform = resample_transform(waveform)
        return waveform

    def encode_batch(self, wavs, wav_lens=None):
        """
        Encodes a batch of waveforms into embeddings using the loaded models.

        Args:
            wavs (Tensor): Batch of waveforms.
            wav_lens (Tensor, optional): Lengths of the waveforms for normalization.

        Returns:
            Tensor: Batch of embeddings.
        """
        if len(wavs.shape) == 1:
            wavs = wavs.unsqueeze(0)

        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()

        # Computing features and embeddings
        feats = self.mods.compute_features(wavs)
        feats = self.mods.mean_var_norm(feats, wav_lens)
        embeddings = self.mods.embedding_model(feats, wav_lens)
        return embeddings

    def create_segments(self, waveform, segment_length, overlap):
        """
        Creates segments from a single waveform for batch processing.

        Args:
            waveform (Tensor): Input waveform tensor.
            segment_length (float): Length of each segment in seconds.
            overlap (float): Overlap between segments in seconds.

        Returns:
            tuple: (segments, segment_times) where segments is a list of tensors, and
            segment_times is a list of (start, end) times in seconds.
        """
        num_samples = waveform.shape[1]
        segment_samples = int(segment_length * self.sample_rate)
        overlap_samples = int(overlap * self.sample_rate)
        step_samples = segment_samples - overlap_samples
        segments = []
        segment_times = []

        for start in range(0, num_samples - segment_samples + 1, step_samples):
            end = start + segment_samples
            segments.append(waveform[:, start:end])
            start_time = start / self.sample_rate
            end_time = end / self.sample_rate
            segment_times.append((start_time, end_time))

        return segments, segment_times

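    # Illustrative windowing math (assuming a 16 kHz sample rate from hparams):
    # with segment_length=2.0 and overlap=1.47, each window spans 32000 samples
    # and the window advances by roughly 0.53 s (about 8480 samples) per step,
    # so consecutive windows overlap heavily and the same frame is scored several times.
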
    def classify_file(self, path, segment_length=2.0, overlap=1.47):
        """
        Processes an audio file to classify and count speakers within segments.
        Utilizes multiple stages of processing to handle overlapping speech and transitions.

        Args:
            path (str): Path to the audio file.
            segment_length (float): Length of each segment in seconds.
            overlap (float): Overlap between segments in seconds.

        Outputs:
            Writes the number of speakers in each segment to a text file.
        """
        waveform, osr = torchaudio.load(path)
        waveform = self.resample_waveform(waveform, osr)

        segments, segment_times = self.create_segments(waveform, segment_length, overlap)
        segment_predictions = []
        segment_confidences = []

        for segment, (start_time, end_time) in zip(segments, segment_times):
            rel_length = torch.tensor([1.0])
            emb = self.encode_batch(segment, rel_length)
            out_prob = self.mods.classifier(emb).squeeze(1)
            score, index = torch.max(out_prob, dim=-1)
            segment_predictions.append((start_time, end_time, index.item()))
            segment_confidences.append(score.item())

        aggregated_predictions = aggregate_segments_with_overlap(segment_predictions)
        refined_predictions = refine_transitions(aggregated_predictions)
        # Confidence scores are collected per analysis window and used here as an
        # approximate per-segment confidence for the aggregated predictions.
        preds = refine_transitions_with_confidence(refined_predictions, segment_confidences)

        with open("sample_segment_predictions.txt", "w") as file:
            for start_time, end_time, prediction in preds:
                if prediction == 0:
                    speaker_text = "no speech"
                elif prediction == 1:
                    speaker_text = "1 speaker"
                else:
                    speaker_text = f"{prediction} speakers"
                print(f"{start_time:.2f}-{end_time:.2f} has {speaker_text}")
                file.write(f"{start_time:.2f}-{end_time:.2f} has {speaker_text}\n")

    def forward(self, wavs, wav_lens=None):
        """
        Forward pass that runs the full file-level classification pipeline.

        Args:
            wavs (str): Path to the audio file to classify (forwarded to classify_file).
            wav_lens (Tensor, optional): Unused; kept for interface compatibility.

        Returns:
            Output from the classify_file method.
        """
        return self.classify_file(wavs)
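

# Minimal usage sketch (assumptions: the accompanying hyperparams.yaml defines
# compute_features, mean_var_norm, embedding_model, classifier and sample_rate,
# and the Hugging Face source id below is only a placeholder):
#
#     counter = SpeakerCounter.from_hparams(
#         source="Supradeepdan/SpeakerCounter",
#         savedir="pretrained_speaker_counter",
#     )
#     counter.classify_file("example.wav")
#     # Per-segment speaker counts are printed and written to
#     # sample_segment_predictions.txt in the working directory.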