agorlanov committed
Commit
3ff6c9f
1 Parent(s): 9bec5d4
Files changed (5)
  1. README.md +3 -3
  2. app.py +25 -16
  3. main_pipeline.py +28 -37
  4. utils/denoise_pipeline.py +0 -1
  5. utils/diarization_pipeline.py +651 -9
README.md CHANGED
@@ -10,9 +10,8 @@ pinned: false
 ---
 
 # How inference:
-1) [huggingface](https://huggingface.co/spaces/deepkotix/denoise_and_diarization)
-2) [telegram bot](https://t.me/diarizarion_bot)
-3) run local inference:
+1) [huggingface](https://huggingface.co/spaces/speechmaster/denoise_and_diarization)
+2) run local inference:
 1) GUI:
 `python app.py`
 2) Inference local:
@@ -67,4 +66,5 @@ How i can improve (i have experience in it):
 
 How to improve besides what's on top:
 + delete overlap speech using asr
++ delete overlap speech using overlap detection
 
app.py CHANGED
@@ -1,11 +1,13 @@
-import os
+import argparse
 
-import gradio as gr
-
-from main_pipeline import main_pipeline
+import torch
 from scipy.io.wavfile import write
 
-title = "Audio_denoise and speaker diarization"
+from main_pipeline import CleaningPipeline
+
+import gradio as gr
+
+title = "Audio denoising and speaker diarization "
 
 example_list = [
     ["dialog.mp3"]
@@ -13,25 +15,32 @@ example_list = [
 
 
 def app_pipeline(audio):
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    cleaning_pipeline = CleaningPipeline(device)
+
     audio_path = 'test.wav'
     write(audio_path, audio[0], audio[1])
-    denoised_audio_path, result_diarization = main_pipeline(audio_path)
+    result = cleaning_pipeline(audio_path)
+    if result != []:
+        return result
 
-    return [denoised_audio_path]+result_diarization + [None] * (9 - len(result_diarization))
 
-
-iface = gr.Interface(
+app = gr.Interface(
     app_pipeline,
-    gr.Audio(type="numpy", label="Input"),
-    [gr.Audio(visible=True, label='denoised_audio' if i == 0 else f'speaker{i}') for i in range(10)],
+    gr.Audio(type="numpy", label="Input_audio"),
+    [gr.Audio(visible=True, label='denoised_audio' if i == 0 else f'speaker{i}') for i in range(20)],
     title=title,
     examples=example_list,
     cache_examples=False,
 
 )
 
-if 'PORT' in os.environ.keys():
-    iface.launch(enable_queue=True, auth=(os.environ["login"], os.environ["password"]), server_name="0.0.0.0",
-                 server_port=int(os.environ["PORT"]))
-else:
-    iface.launch(enable_queue=True)
+# if 'PORT' in os.environ.keys():
+#     app.launch(enable_queue=True, auth=(os.environ["login"], os.environ["password"]), server_name="0.0.0.0",
+#                server_port=int(os.environ["PORT"]))
+# else:
+#     app.launch(enable_queue=True)
+
+app.launch(debug=True, share=True, inline=False, enable_queue=True, max_threads=1,
+           server_name="0.0.0.0",
+           server_port=1234)
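
Note (editor's sketch, not part of the commit): the interface now declares 20 `gr.Audio` outputs, so `app_pipeline` must always hand back a 20-element list, with unused speaker slots padded with `None`. The values below are hypothetical; only the padding arithmetic mirrors the new `main_pipeline.py`.

```python
# Hypothetical result of one run: a denoised file plus two speaker files.
denoised_audio_path = "denoise.wav"                 # slot 0 -> 'denoised_audio'
speaker_paths = ["denoise_0.wav", "denoise_1.wav"]  # slots 1..2 -> 'speaker1', 'speaker2'
count_speakers = len(speaker_paths)

# Same padding formula as CleaningPipeline.__call__ in main_pipeline.py.
outputs = [denoised_audio_path] + speaker_paths + [None] * (19 - count_speakers)
assert len(outputs) == 20  # one value per gr.Audio output component
```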
main_pipeline.py CHANGED
@@ -2,60 +2,51 @@ import argparse
 import librosa
 import torch
 import os
+
+from scipy.io.wavfile import write
 from tqdm import tqdm
 
 from utils.denoise_pipeline import denoise
-from utils.diarization_pipeline import diarization
+from utils.diarization_pipeline import DiarizationPipeline
+
 import numpy as np
 
 import pandas as pd
 import soundfile as sf
 
 
-def filter_small_speech(segments):
-    segments['duration'] = segments.end - segments.start
-    durs = segments.groupby('label').sum()
-    labels = durs[durs['duration'] / durs.sum()['duration'] > 0.015].index
-    return segments[segments.label.isin(labels)]
-
-
-def save_speaker_audios(segments, denoised_audio_path, out_folder='out', out_f=48000):
-    signal, sr = librosa.load(denoised_audio_path, sr=out_f, mono=True)
-    os.makedirs(out_folder, exist_ok=True)
-    out_wav_paths = []
+class CleaningPipeline:
+    def __init__(self, device):
+        """
+        Cleaning audio pipeline. Contains:
+        - denoising
+        - diarization
+        """
 
-    segments = pd.DataFrame(segments)
-    segments = filter_small_speech(segments)
+        self.device = device
+        self.denoiser = denoise
+        self.diarization = DiarizationPipeline(device)
 
-    for label in set(segments.label):
-        temp_df = segments[segments.label == label]
-        output_signal = []
-        for _, r in temp_df.iterrows():
-            start = int(r["start"] * out_f)
-            end = int(r["end"] * out_f)
-            output_signal.append(signal[start:end])
+    def __call__(self, input_audio_path: str):
+        denoised_audio_path = self.denoiser(input_audio_path, self.device)
+        result_diarization = self.diarization(denoised_audio_path)
 
-        out_wav_path = f'{out_folder}/{label}.wav'
-        sf.write(out_wav_path, np.concatenate(output_signal), out_f, 'PCM_24')
-        out_wav_paths.append(out_wav_path)
+        if result_diarization != {}:
+            output_diar_audio_paths = result_diarization['output_diar_audio_paths']
+            count_speakers = result_diarization['count_speakers']
+            return [denoised_audio_path] + output_diar_audio_paths + [None] * (19 - count_speakers)
 
-    return out_wav_paths[:10]
-
-
-def main_pipeline(audio_path, out_folder='out'):
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-    denoised_audio_path = denoise(audio_path, device)
-    segments = diarization(denoised_audio_path)
-    denoised_audio_paths = save_speaker_audios(segments, denoised_audio_path, out_folder)
-    return denoised_audio_path, denoised_audio_paths
+        else:
+            return []
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--audio-path', default='dialog.mp3', help='Path to audio')
+    parser.add_argument('--device', default='cpu', help='Path to audio')
     parser.add_argument('--out-folder-path', default='out', help='Path to result folder')
     opt = parser.parse_args()
-
-    for _ in tqdm(range(10)):
-        main_pipeline(audio_path=opt.audio_path, out_folder=opt.out_folder_path)
+    cleaning_pipeline = CleaningPipeline('cuda:0')
+    cleaning_pipeline(input_audio_path=opt.audio_path)
+    # for _ in tqdm(range(10)):
+    #     main_pipeline(audio_path=opt.audio_path, device=opt.device)
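
For reference, a minimal usage sketch of the new `CleaningPipeline` API introduced above (assumptions: a local `dialog.mp3` and that the denoiser/diarization weights can be downloaded on first use):

```python
import torch
from main_pipeline import CleaningPipeline

device = 'cuda' if torch.cuda.is_available() else 'cpu'
pipeline = CleaningPipeline(device)

# Returns [denoised_path, speaker wavs..., None padding], or [] if the VAD finds no speech.
result = pipeline(input_audio_path='dialog.mp3')
if result:
    denoised_path = result[0]
    speaker_paths = [p for p in result[1:] if p is not None]
    print(denoised_path, speaker_paths)
```

Note that the `__main__` block still hard-codes `CleaningPipeline('cuda:0')` rather than using the new `--device` argument it adds.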
utils/denoise_pipeline.py CHANGED
@@ -5,7 +5,6 @@ from demucs.pretrained import get_model
 from scipy.io.wavfile import write
 
 demucs_model = get_model('cfa93e08')
-# demucs_model = get_model('htdemucs')
 
 
 def denoise(filename: str, device: str, out_filename='denoise.wav') -> str:
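
The only change here is dropping the commented-out `htdemucs` line; `denoise()` itself is untouched. A minimal call sketch, based only on the signature and model shown above (assumes a local input file and CPU inference):

```python
from utils.denoise_pipeline import denoise

# Runs the pretrained Demucs model 'cfa93e08' over the input and, per the
# default out_filename, writes the result to 'denoise.wav', returning its path.
denoised_path = denoise('dialog.mp3', device='cpu')
print(denoised_path)
```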
utils/diarization_pipeline.py CHANGED
@@ -1,25 +1,667 @@
-from simple_diarizer.diarizer import Diarizer
+from copy import deepcopy
+from os.path import basename, splitext
+
+import librosa
+import numpy as np
+import pandas as pd
+import soundfile as sf
+import torch
+import torchaudio
+from scipy.ndimage import gaussian_filter
+from sklearn.cluster import AgglomerativeClustering, KMeans, SpectralClustering
+from sklearn.metrics import pairwise_distances
+from speechbrain.pretrained import EncoderClassifier
+
+
+def similarity_matrix(embeds, metric="cosine"):
+    return pairwise_distances(embeds, metric=metric)
+
+
+def cluster_AHC(embeds, n_clusters=None, threshold=None, metric="cosine", **kwargs):
+    """
+    Cluster embeds using Agglomerative Hierarchical Clustering
+    """
+    if n_clusters is None:
+        assert threshold, "If num_clusters is not defined, threshold must be defined"
+
+    S = similarity_matrix(embeds, metric=metric)
+
+    if n_clusters is None:
+        cluster_model = AgglomerativeClustering(
+            n_clusters=None,
+            affinity="precomputed",
+            linkage="average",
+            compute_full_tree=True,
+            distance_threshold=threshold,
+        )
+
+        return cluster_model.fit_predict(S)
+    else:
+        cluster_model = AgglomerativeClustering(
+            n_clusters=n_clusters, affinity="precomputed", linkage="average"
+        )
+
+        return cluster_model.fit_predict(S)
+
+
+##########################################
+# Spectral clustering
+# A lot of these methods are lifted from
+# https://github.com/wq2012/SpectralCluster
+##########################################
+
+
+def cluster_SC(embeds, n_clusters=None, threshold=None, enhance_sim=True, **kwargs):
+    """
+    Cluster embeds using Spectral Clustering
+    """
+    if n_clusters is None:
+        assert threshold, "If num_clusters is not defined, threshold must be defined"
+
+    S = compute_affinity_matrix(embeds)
+    if enhance_sim:
+        S = sim_enhancement(S)
+
+    if n_clusters is None:
+        (eigenvalues, eigenvectors) = compute_sorted_eigenvectors(S)
+        # Get number of clusters.
+        k = compute_number_of_clusters(eigenvalues, 100, threshold)
+
+        # Get spectral embeddings.
+        spectral_embeddings = eigenvectors[:, :k]
+
+        # Run K-Means++ on spectral embeddings.
+        # Note: The correct way should be using a K-Means implementation
+        # that supports customized distance measure such as cosine distance.
+        # This implemention from scikit-learn does NOT, which is inconsistent
+        # with the paper.
+        kmeans_clusterer = KMeans(
+            n_clusters=k, init="k-means++", max_iter=300, random_state=0
+        )
+        labels = kmeans_clusterer.fit_predict(spectral_embeddings)
+        return labels
+    else:
+        cluster_model = SpectralClustering(
+            n_clusters=n_clusters, affinity="precomputed"
+        )
+
+        return cluster_model.fit_predict(S)
+
+
+def diagonal_fill(A):
+    """
+    Sets the diagonal elemnts of the matrix to the max of each row
+    """
+    np.fill_diagonal(A, 0.0)
+    A[np.diag_indices(A.shape[0])] = np.max(A, axis=1)
+    return A
+
+
+def gaussian_blur(A, sigma=1.0):
+    """
+    Does a gaussian blur on the affinity matrix
+    """
+    return gaussian_filter(A, sigma=sigma)
+
+
+def row_threshold_mult(A, p=0.95, mult=0.01):
+    """
+    For each row multiply elements smaller than the row's p'th percentile by mult
+    """
+    percentiles = np.percentile(A, p * 100, axis=1)
+    mask = A < percentiles[:, np.newaxis]
+
+    A = (mask * mult * A) + (~mask * A)
+    return A
+
+
+def symmetrization(A):
+    """
+    Symmeterization: Y_{i,j} = max(S_{ij}, S_{ji})
+    """
+    return np.maximum(A, A.T)
+
+
+def diffusion(A):
+    """
+    Diffusion: Y <- YY^T
+    """
+    return np.dot(A, A.T)
+
+
+def row_max_norm(A):
+    """
+    Row-wise max normalization: S_{ij} = Y_{ij} / max_k(Y_{ik})
+    """
+    maxes = np.amax(A, axis=1)
+    return A / maxes
+
+
+def sim_enhancement(A):
+    func_order = [
+        diagonal_fill,
+        gaussian_blur,
+        row_threshold_mult,
+        symmetrization,
+        diffusion,
+        row_max_norm,
+    ]
+    for f in func_order:
+        A = f(A)
+    return A
+
+
+def compute_affinity_matrix(X):
+    """Compute the affinity matrix from data.
+    Note that the range of affinity is [0,1].
+    Args:
+        X: numpy array of shape (n_samples, n_features)
+    Returns:
+        affinity: numpy array of shape (n_samples, n_samples)
+    """
+    # Normalize the data.
+    l2_norms = np.linalg.norm(X, axis=1)
+    X_normalized = X / l2_norms[:, None]
+    # Compute cosine similarities. Range is [-1,1].
+    cosine_similarities = np.matmul(X_normalized, np.transpose(X_normalized))
+    # Compute the affinity. Range is [0,1].
+    # Note that this step is not mentioned in the paper!
+    affinity = (cosine_similarities + 1.0) / 2.0
+    return affinity
+
+
+def compute_sorted_eigenvectors(A):
+    """Sort eigenvectors by the real part of eigenvalues.
+    Args:
+        A: the matrix to perform eigen analysis with shape (M, M)
+    Returns:
+        w: sorted eigenvalues of shape (M,)
+        v: sorted eigenvectors, where v[;, i] corresponds to ith largest
+           eigenvalue
+    """
+    # Eigen decomposition.
+    eigenvalues, eigenvectors = np.linalg.eig(A)
+    eigenvalues = eigenvalues.real
+    eigenvectors = eigenvectors.real
+    # Sort from largest to smallest.
+    index_array = np.argsort(-eigenvalues)
+    # Re-order.
+    w = eigenvalues[index_array]
+    v = eigenvectors[:, index_array]
+    return w, v
+
+
+def compute_number_of_clusters(eigenvalues, max_clusters=None, stop_eigenvalue=1e-2):
+    """Compute number of clusters using EigenGap principle.
+    Args:
+        eigenvalues: sorted eigenvalues of the affinity matrix
+        max_clusters: max number of clusters allowed
+        stop_eigenvalue: we do not look at eigen values smaller than this
+    Returns:
+        number of clusters as an integer
+    """
+    max_delta = 0
+    max_delta_index = 0
+    range_end = len(eigenvalues)
+    if max_clusters and max_clusters + 1 < range_end:
+        range_end = max_clusters + 1
+    for i in range(1, range_end):
+        if eigenvalues[i - 1] < stop_eigenvalue:
+            break
+        delta = eigenvalues[i - 1] / eigenvalues[i]
+        if delta > max_delta:
+            max_delta = delta
+            max_delta_index = i
+    return max_delta_index
+
+
+class Diarizer:
+    def __init__(
+            self, device='cuda:0', embed_model="xvec", cluster_method="sc", window=1.5, period=0.75
+    ):
+        self.device = device
+        assert embed_model in [
+            "xvec",
+            "ecapa",
+        ], "Only xvec and ecapa are supported options"
+        assert cluster_method in [
+            "ahc",
+            "sc",
+        ], "Only ahc and sc in the supported clustering options"
+
+        if cluster_method == "ahc":
+            self.cluster = cluster_AHC
+        if cluster_method == "sc":
+            self.cluster = cluster_SC
+
+        self.vad_model, self.get_speech_ts = self.setup_VAD()
+
+        self.run_opts = ({"device": self.device})
+
+        if embed_model == "ecapa":
+            self.embed_model = EncoderClassifier.from_hparams(
+                source="speechbrain/spkrec-ecapa-voxceleb",
+                savedir="pretrained_models/spkrec-ecapa-voxceleb",
+                run_opts=self.run_opts,
+            )
+
+        self.window = window
+        self.period = period
+
+    def setup_VAD(self):
+        model, utils = torch.hub.load(
+            repo_or_dir="snakers4/silero-vad", model="silero_vad"
+        )
+        # force_reload=True)
+
+        get_speech_ts = utils[0]
+        return model, get_speech_ts
+
+    def vad(self, signal):
+        """
+        Runs the VAD model on the signal
+        """
+        return self.get_speech_ts(signal.to(self.device), self.vad_model.to(self.device))
+
+    def windowed_embeds(self, signal, fs, window=1.5, period=0.75):
+        """
+        Calculates embeddings for windows across the signal
+
+        window: length of the window, in seconds
+        period: jump of the window, in seconds
+
+        returns: embeddings, segment info
+        """
+        len_window = int(window * fs)
+        len_period = int(period * fs)
+        len_signal = signal.shape[1]
+
+        # Get the windowed segments
+        segments = []
+        start = 0
+        while start + len_window < len_signal:
+            segments.append([start, start + len_window])
+            start += len_period
+
+        segments.append([start, len_signal - 1])
+        embeds = []
+
+        with torch.no_grad():
+            for i, j in segments:
+                signal_seg = signal[:, i:j]
+                seg_embed = self.embed_model.encode_batch(signal_seg)
+                embeds.append(seg_embed.squeeze(0).squeeze(0).cpu().numpy())
+
+        embeds = np.array(embeds)
+        return embeds, np.array(segments)
+
+    def recording_embeds(self, signal, fs, speech_ts):
+        """
+        Takes signal and VAD output (speech_ts) and produces windowed embeddings
+
+        returns: embeddings, segment info
+        """
+        all_embeds = []
+        all_segments = []
+
+        for utt in speech_ts:
+            start = utt["start"]
+            end = utt["end"]
+
+            utt_signal = signal[:, start:end]
+            utt_embeds, utt_segments = self.windowed_embeds(
+                utt_signal, fs, self.window, self.period
+            )
+            all_embeds.append(utt_embeds)
+            all_segments.append(utt_segments + start)
+
+        all_embeds = np.concatenate(all_embeds, axis=0)
+        all_segments = np.concatenate(all_segments, axis=0)
+        return all_embeds, all_segments
+
+    @staticmethod
+    def join_segments(cluster_labels, segments, tolerance=5):
+        """
+        Joins up same speaker segments, resolves overlap conflicts
+
+        Uses the midpoint for overlap conflicts
+        tolerance allows for very minimally separated segments to be combined
+        (in samples)
+        """
+        assert len(cluster_labels) == len(segments)
+
+        new_segments = [
+            {"start": segments[0][0], "end": segments[0][1], "label": cluster_labels[0]}
+        ]
+
+        for l, seg in zip(cluster_labels[1:], segments[1:]):
+            start = seg[0]
+            end = seg[1]
+
+            protoseg = {"start": seg[0], "end": seg[1], "label": l}
+
+            if start <= new_segments[-1]["end"]:
+                # If segments overlap
+                if l == new_segments[-1]["label"]:
+                    # If overlapping segment has same label
+                    new_segments[-1]["end"] = end
+                else:
+                    # If overlapping segment has diff label
+                    # Resolve by setting new start to midpoint
+                    # And setting last segment end to midpoint
+                    overlap = new_segments[-1]["end"] - start
+                    midpoint = start + overlap // 2
+                    new_segments[-1]["end"] = midpoint
+                    protoseg["start"] = midpoint
+                    new_segments.append(protoseg)
+            else:
+                # If there's no overlap just append
+                new_segments.append(protoseg)
+
+        return new_segments
+
+    @staticmethod
+    def make_output_seconds(cleaned_segments, fs):
+        """
+        Convert cleaned segments to readable format in seconds
+        """
+        for seg in cleaned_segments:
+            seg["start_sample"] = seg["start"]
+            seg["end_sample"] = seg["end"]
+            seg["start"] = seg["start"] / fs
+            seg["end"] = seg["end"] / fs
+        return cleaned_segments
+
+    def diarize(
+            self,
+            wav_file,
+            num_speakers=2,
+            threshold=None,
+            silence_tolerance=0.2,
+            enhance_sim=True,
+            extra_info=False,
+            outfile=None,
+    ):
+        """
+        Diarize a 16khz mono wav file, produces list of segments
+
+        Inputs:
+            wav_file (path): Path to input audio file
+            num_speakers (int) or NoneType: Number of speakers to cluster to
+            threshold (float) or NoneType: Threshold to cluster to if
+                num_speakers is not defined
+            silence_tolerance (float): Same speaker segments which are close enough together
+                by silence_tolerance will be joined into a single segment
+            enhance_sim (bool): Whether or not to perform affinity matrix enhancement
+                during spectral clustering
+                If self.cluster_method is 'ahc' this option does nothing.
+            extra_info (bool): Whether or not to return the embeddings and raw segments
+                in addition to segments
+            outfile (path): If specified will output an RTTM file
+
+        Outputs:
+            If extra_info is False:
+                segments (list): List of dicts with segment information
+                    {
+                        'start': Start time of segment in seconds,
+                        'start_sample': Starting index of segment,
+                        'end': End time of segment in seconds,
+                        'end_sample' Ending index of segment,
+                        'label': Cluster label of segment
+                    }
+            If extra_info is True:
+                dict: { 'segments': segments (list): List of dicts with segment information
+                    {
+                        'start': Start time of segment in seconds,
+                        'start_sample': Starting index of segment,
+                        'end': End time of segment in seconds,
+                        'end_sample' Ending index of segment,
+                        'label': Cluster label of segment
+                    },
+                    'embeds': embeddings (np.array): Array of embeddings, each row corresponds to a segment,
+                    'segments': segments (list): indexes for start and end frame for each embed in embeds,
+                    'cluster_labels': cluster_labels (list): cluster label for each embed in embeds
+                }
+
+        Uses AHC/SC to cluster
+        """
+
+        signal, fs = torchaudio.load(wav_file)
+        if len(signal) == 2:
+            signal = signal[:1, :]
+        if fs != 16000:
+            signal = torchaudio.functional.resample(signal, fs, 16000)
+            fs = 16000
+
+        speech_ts = self.vad(signal[0])
+        if len(speech_ts) >= 1:
+
+            embeds, segments = self.recording_embeds(signal, fs, speech_ts)
+            if len(embeds) > 1:
+                cluster_labels = self.cluster(
+                    embeds,
+                    n_clusters=num_speakers,
+                    threshold=threshold,
+                    enhance_sim=enhance_sim,
+                )
+            else:
+                cluster_labels = np.zeros(len(embeds), dtype=int)
+            cleaned_segments = self.join_segments(cluster_labels, segments)
+            cleaned_segments = self.make_output_seconds(cleaned_segments, fs)
+            cleaned_segments = self.join_samespeaker_segments(
+                cleaned_segments, silence_tolerance=silence_tolerance
+            )
+            if outfile:
+                self.rttm_output(cleaned_segments, splitext(basename(wav_file))[0], outfile=outfile)
+
+            if not extra_info:
+                return cleaned_segments
+            else:
+                return {"clean_segments": cleaned_segments,
+                        "embeds": embeds,
+                        "segments": segments,
+                        "cluster_labels": cluster_labels}
+        else:
+            print("Couldn't find any speech during VAD")
+            return {}
+
+    @staticmethod
+    def rttm_output(segments, recname, outfile=None):
+        assert outfile, "Please specify an outfile"
+        rttm_line = "SPEAKER {} 0 {} {} <NA> <NA> {} <NA> <NA>\n"
+        with open(outfile, "w") as fp:
+            for seg in segments:
+                start = seg["start"]
+                offset = seg["end"] - seg["start"]
+                label = seg["label"]
+                line = rttm_line.format(recname, start, offset, label)
+                fp.write(line)
+
+    @staticmethod
+    def join_samespeaker_segments(segments, silence_tolerance=0.5):
+        """
+        Join up segments that belong to the same speaker,
+        even if there is a duration of silence in between them.
+
+        If the silence is greater than silence_tolerance, does not join up
+        """
+        new_segments = [segments[0]]
+
+        for seg in segments[1:]:
+            if seg["label"] == new_segments[-1]["label"]:
+                if new_segments[-1]["end"] + silence_tolerance >= seg["start"]:
+                    new_segments[-1]["end"] = seg["end"]
+                    new_segments[-1]["end_sample"] = seg["end_sample"]
+                else:
+                    new_segments.append(seg)
+            else:
+                new_segments.append(seg)
+        return new_segments
+
+    @staticmethod
+    def match_diarization_to_transcript(segments, text_segments):
+        """
+        Match the output of .diarize to word segments
+        """
+
+        text_starts, text_ends, text_segs = [], [], []
+        for s in text_segments:
+            text_starts.append(s["start"])
+            text_ends.append(s["end"])
+            text_segs.append(s["text"])
+
+        text_starts = np.array(text_starts)
+        text_ends = np.array(text_ends)
+        text_segs = np.array(text_segs)
+
+        # Get the earliest start from either diar output or asr output
+        earliest_start = np.min([text_starts[0], segments[0]["start"]])
+
+        worded_segments = segments.copy()
+        worded_segments[0]["start"] = earliest_start
+        cutoffs = []
+
+        for seg in worded_segments:
+            end_idx = np.searchsorted(text_ends, seg["end"], side="left") - 1
+            cutoffs.append(end_idx)
+
+        indexes = [[0, cutoffs[0]]]
+        for c in cutoffs[1:]:
+            indexes.append([indexes[-1][-1], c])
+
+        indexes[-1][-1] = len(text_segs)
+
+        final_segments = []
+
+        for i, seg in enumerate(worded_segments):
+            s_idx, e_idx = indexes[i]
+            words = text_segs[s_idx:e_idx]
+            newseg = deepcopy(seg)
+            newseg["words"] = " ".join(words)
+            final_segments.append(newseg)
+
+        return final_segments
+
+    def match_diarization_to_transcript_ctm(self, segments, ctm_file):
+        """
+        Match the output of .diarize to a ctm file produced by asr
+        """
+        ctm_df = pd.read_csv(
+            ctm_file,
+            delimiter=" ",
+            names=["utt", "channel", "start", "offset", "word", "confidence"],
+        )
+        ctm_df["end"] = ctm_df["start"] + ctm_df["offset"]
+
+        starts = ctm_df["start"].values
+        ends = ctm_df["end"].values
+        words = ctm_df["word"].values
+
+        # Get the earliest start from either diar output or asr output
+        earliest_start = np.min([ctm_df["start"].values[0], segments[0]["start"]])
+
+        worded_segments = self.join_samespeaker_segments(segments)
+        worded_segments[0]["start"] = earliest_start
+        cutoffs = []
+
+        for seg in worded_segments:
+            end_idx = np.searchsorted(ctm_df["end"].values, seg["end"], side="left") - 1
+            cutoffs.append(end_idx)
+
+        indexes = [[0, cutoffs[0]]]
+        for c in cutoffs[1:]:
+            indexes.append([indexes[-1][-1], c])
+
+        indexes[-1][-1] = len(words)
+
+        final_segments = []
+
+        for i, seg in enumerate(worded_segments):
+            s_idx, e_idx = indexes[i]
+            words = ctm_df["word"].values[s_idx:e_idx]
+            seg["words"] = " ".join(words)
+            if len(words) >= 1:
+                final_segments.append(seg)
+            else:
+                print(
+                    "Removed segment between {} and {} as no words were matched".format(
+                        seg["start"], seg["end"]
+                    )
+                )
+
+        return final_segments
+
+    @staticmethod
+    def nice_text_output(worded_segments, outfile):
+        with open(outfile, "w") as fp:
+            for seg in worded_segments:
+                fp.write(
+                    "[{} to {}] Speaker {}: \n".format(
+                        round(seg["start"], 2), round(seg["end"], 2), seg["label"]
+                    )
+                )
+                fp.write("{}\n\n".format(seg["words"]))
 
 
 class DiarizationPipeline:
-    def __init__(self, ):
+    def __init__(self, device=None):
         super(DiarizationPipeline, self).__init__()
         self.diar = Diarizer(
+            device=device,
             embed_model='ecapa',  # supported types: ['xvec', 'ecapa']
             cluster_method='ahc',  # supported types: ['ahc', 'sc']
            window=1,  # size of window to extract embeddings (in seconds)
            period=0.1  # hop of window (in seconds)
        )
 
-    def __call__(self, wav_file):
-        segments = self.diar.diarize(wav_file,
-                                     num_speakers=None,
-                                     threshold=9e-1, )
-
-        return segments
-
-
-diarization = DiarizationPipeline()
+    def save_speaker_audios(self, segments: list, audio_path: str):
+        """
+
+        :param segments: result diarization timestamps
+        :param audio_path:
+        :return: out_wav_paths: list of audio paths
+        """
+        signal, sr = librosa.load(audio_path, sr=None, mono=True)
+        out_wav_paths = []
+
+        segments = pd.DataFrame(segments)
+        segments = self.filter_small_speech(segments)
+
+        sort_labels = segments.groupby(['label'])['duration'].sum().nlargest(len(set(segments.label))).index
+
+        for indx, label in enumerate(sort_labels):
+            temp_df = segments[segments.label == label]
+            output_signal = []
+            for _, r in temp_df.iterrows():
+                start = int(r["start"] * sr)
+                end = int(r["end"] * sr)
+                output_signal.append(signal[start:end])
+            out_wav_path = audio_path.replace('.wav', f'_{indx}.wav')
+            sf.write(out_wav_path, np.concatenate(output_signal), sr)
+            out_wav_paths.append(out_wav_path)
+
+        return out_wav_paths
+
+    def filter_small_speech(self, segments):
+        segments['duration'] = segments.end - segments.start
+        durs = segments.groupby('label').sum()
+        labels = durs[durs['duration'] / durs.sum()['duration'] > 0.015].index
+        return segments[segments.label.isin(labels)]
+
+    def __call__(self, input_wav_path: str) -> dict:
+
+        segments = self.diar.diarize(input_wav_path,
+                                     num_speakers=None,
+                                     threshold=9e-1, )
+        if segments != {}:
+            output_wav_paths = self.save_speaker_audios(segments, input_wav_path)
+            return {'count_speakers': max([i['label'] for i in segments]) + 1, 'diarization_segments': segments,
+                    'output_diar_audio_paths': output_wav_paths}
+        else:
+            return {}
+
 
 if __name__ == '__main__':
-    diarization('../converted.wav')
+    diarization = DiarizationPipeline(device='cuda:0')
+
+    diarization('../dialog.mp3')
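
This file now vendors the VAD, embedding and clustering code (the header comment credits wq2012/SpectralCluster, and the class mirrors the previously imported simple_diarizer Diarizer) instead of importing it, and `DiarizationPipeline.__call__` returns a dict rather than raw segments. A minimal usage sketch, assuming a denoised `.wav` such as the one produced by `utils/denoise_pipeline.py`; the per-speaker files are written next to the input:

```python
from utils.diarization_pipeline import DiarizationPipeline

diarization = DiarizationPipeline(device='cpu')

# Returns {} when the VAD finds no speech; otherwise a dict with
# 'count_speakers', 'diarization_segments' and 'output_diar_audio_paths'.
result = diarization('denoise.wav')
if result:
    print(result['count_speakers'])
    for path in result['output_diar_audio_paths']:
        print(path)
```

One caveat worth flagging: `save_speaker_audios` builds the per-speaker filenames with `audio_path.replace('.wav', ...)`, so an input that is not a `.wav` file would reuse the same output path for every speaker.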