Spaces: Runtime error

agorlanov committed
Commit 3ff6c9f
Parent(s): 9bec5d4

train_fix

Files changed:
- README.md +3 -3
- app.py +25 -16
- main_pipeline.py +28 -37
- utils/denoise_pipeline.py +0 -1
- utils/diarization_pipeline.py +651 -9
README.md CHANGED
@@ -10,9 +10,8 @@ pinned: false
 ---
 
 # How inference:
-1) [huggingface](https://huggingface.co/spaces/
-2)
-3) run local inference:
+1) [huggingface](https://huggingface.co/spaces/speechmaster/denoise_and_diarization)
+2) run local inference:
 1) GUI:
 `python app.py`
 2) Inference local:
@@ -67,4 +66,5 @@ How i can improve (i have experience in it):
 
 How to improve besides what's on top:
 + delete overlap speech using asr
++ delete overlap speech using overlap detection
 
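Editor's note: for context, the "run local inference" path this README points to boils down to the snippet below. This is a sketch based on the app.py and main_pipeline.py changes later in this same commit, not an interface documented in the README itself.

```python
# Minimal local run, assuming the Space's dependencies are installed.
# CleaningPipeline is introduced in main_pipeline.py further down in this commit.
import torch

from main_pipeline import CleaningPipeline

device = 'cuda' if torch.cuda.is_available() else 'cpu'
pipeline = CleaningPipeline(device)

# Returns [denoised_path, speaker_0_path, speaker_1_path, ...] padded with None,
# or [] if the diarizer found no speech in the file.
outputs = pipeline('dialog.mp3')
print(outputs)
```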
app.py CHANGED
@@ -1,11 +1,13 @@
-import
+import argparse
 
-import
-
-from main_pipeline import main_pipeline
+import torch
 from scipy.io.wavfile import write
 
-
+from main_pipeline import CleaningPipeline
+
+import gradio as gr
+
+title = "Audio denoising and speaker diarization "
 
 example_list = [
     ["dialog.mp3"]
@@ -13,25 +15,32 @@ example_list = [
 
 
 def app_pipeline(audio):
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    cleaning_pipeline = CleaningPipeline(device)
+
     audio_path = 'test.wav'
     write(audio_path, audio[0], audio[1])
-
+    result = cleaning_pipeline(audio_path)
+    if result != []:
+        return result
 
-    return [denoised_audio_path]+result_diarization + [None] * (9 - len(result_diarization))
 
-
-iface = gr.Interface(
+app = gr.Interface(
     app_pipeline,
-    gr.Audio(type="numpy", label="
-    [gr.Audio(visible=True, label='denoised_audio' if i == 0 else f'speaker{i}') for i in range(
+    gr.Audio(type="numpy", label="Input_audio"),
+    [gr.Audio(visible=True, label='denoised_audio' if i == 0 else f'speaker{i}') for i in range(20)],
     title=title,
     examples=example_list,
     cache_examples=False,
 
 )
 
-if 'PORT' in os.environ.keys():
-
-
-else:
-
+# if 'PORT' in os.environ.keys():
+#     app.launch(enable_queue=True, auth=(os.environ["login"], os.environ["password"]), server_name="0.0.0.0",
+#                server_port=int(os.environ["PORT"]))
+# else:
+#     app.launch(enable_queue=True)
+
+app.launch(debug=True, share=True, inline=False, enable_queue=True, max_threads=1,
+           server_name="0.0.0.0",
+           server_port=1234)
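Editor's note: the new Gradio wiring declares a fixed list of 20 audio outputs, so app_pipeline is expected to hand back one item per slot (denoised track first, then per-speaker tracks, padded with None). A minimal sketch of that contract outside Gradio, assuming the CleaningPipeline from main_pipeline.py; this snippet is not part of the commit.

```python
# Sketch of the output contract app_pipeline relies on.
from main_pipeline import CleaningPipeline

pipeline = CleaningPipeline('cpu')
outputs = pipeline('test.wav')

if outputs:
    # First slot is the denoised file, the rest are per-speaker files or None padding.
    denoised_path = outputs[0]
    speaker_paths = [p for p in outputs[1:] if p is not None]
    print(denoised_path, speaker_paths)
else:
    print('No speech detected by the VAD/diarizer.')
```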
main_pipeline.py CHANGED
@@ -2,60 +2,51 @@ import argparse
 import librosa
 import torch
 import os
+
+from scipy.io.wavfile import write
 from tqdm import tqdm
 
 from utils.denoise_pipeline import denoise
-from utils.diarization_pipeline import
+from utils.diarization_pipeline import DiarizationPipeline
+
 import numpy as np
 
 import pandas as pd
 import soundfile as sf
 
 
-
-
-
-
-def save_speaker_audios(segments, denoised_audio_path, out_folder='out', out_f=48000):
-    signal, sr = librosa.load(denoised_audio_path, sr=out_f, mono=True)
-    os.makedirs(out_folder, exist_ok=True)
-    out_wav_paths = []
-
-
-
-
-    for _, r in temp_df.iterrows():
-        start = int(r["start"] * out_f)
-        end = int(r["end"] * out_f)
-        output_signal.append(signal[start:end])
-
-
-
-
-def main_pipeline(audio_path, out_folder='out'):
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-    denoised_audio_path = denoise(audio_path, device)
-    segments = diarization(denoised_audio_path)
-    denoised_audio_paths = save_speaker_audios(segments, denoised_audio_path, out_folder)
-    return denoised_audio_path, denoised_audio_paths
+class CleaningPipeline:
+    def __init__(self, device):
+        """
+        Cleaning audio pipeline. Contains:
+        - denoising
+        - diarization
+        """
+
+        self.device = device
+        self.denoiser = denoise
+        self.diarization = DiarizationPipeline(device)
+
+    def __call__(self, input_audio_path: str):
+        denoised_audio_path = self.denoiser(input_audio_path, self.device)
+        result_diarization = self.diarization(denoised_audio_path)
+
+        if result_diarization != {}:
+            output_diar_audio_paths = result_diarization['output_diar_audio_paths']
+            count_speakers = result_diarization['count_speakers']
+            return [denoised_audio_path] + output_diar_audio_paths + [None] * (19 - count_speakers)
+        else:
+            return []
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--audio-path', default='dialog.mp3', help='Path to audio')
+    parser.add_argument('--device', default='cpu', help='Path to audio')
     parser.add_argument('--out-folder-path', default='out', help='Path to result folder')
     opt = parser.parse_args()
-
-
-
+    cleaning_pipeline = CleaningPipeline('cuda:0')
+    cleaning_pipeline(input_audio_path=opt.audio_path)
+    # for _ in tqdm(range(10)):
+    #     main_pipeline(audio_path=opt.audio_path, device=opt.device)
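Editor's note: in the new __main__ block a --device argument is parsed but the pipeline is then constructed with a hard-coded 'cuda:0'. A small sketch of what wiring the flag through might look like; this is a suggestion, not part of the commit.

```python
# Hypothetical variant of the __main__ block that honours the --device flag.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--audio-path', default='dialog.mp3', help='Path to audio')
    parser.add_argument('--device', default='cpu', help='Torch device, e.g. cpu or cuda:0')
    parser.add_argument('--out-folder-path', default='out', help='Path to result folder')
    opt = parser.parse_args()

    cleaning_pipeline = CleaningPipeline(opt.device)
    cleaning_pipeline(input_audio_path=opt.audio_path)
```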
utils/denoise_pipeline.py CHANGED
@@ -5,7 +5,6 @@ from demucs.pretrained import get_model
 from scipy.io.wavfile import write
 
 demucs_model = get_model('cfa93e08')
-# demucs_model = get_model('htdemucs')
 
 
 def denoise(filename: str, device: str, out_filename='denoise.wav') -> str:
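Editor's note: the denoise() helper this file exposes is called elsewhere in the commit as a plain function. A minimal standalone call would look roughly like this (a sketch, assuming the demucs weights download succeeds); it is not part of the diff.

```python
# Sketch: standalone use of the denoiser.
from utils.denoise_pipeline import denoise

denoised_path = denoise('dialog.mp3', device='cpu')  # writes denoise.wav by default
print(denoised_path)
```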
utils/diarization_pipeline.py CHANGED
@@ -1,25 +1,667 @@
-from
+from copy import deepcopy
+from os.path import basename, splitext
+
+import librosa
+import numpy as np
+import pandas as pd
+import soundfile as sf
+import torch
+import torchaudio
+from scipy.ndimage import gaussian_filter
+from sklearn.cluster import AgglomerativeClustering, KMeans, SpectralClustering
+from sklearn.metrics import pairwise_distances
+from speechbrain.pretrained import EncoderClassifier
+
+
+def similarity_matrix(embeds, metric="cosine"):
+    return pairwise_distances(embeds, metric=metric)
+
+
+def cluster_AHC(embeds, n_clusters=None, threshold=None, metric="cosine", **kwargs):
+    """
+    Cluster embeds using Agglomerative Hierarchical Clustering
+    """
+    if n_clusters is None:
+        assert threshold, "If num_clusters is not defined, threshold must be defined"
+
+    S = similarity_matrix(embeds, metric=metric)
+
+    if n_clusters is None:
+        cluster_model = AgglomerativeClustering(
+            n_clusters=None,
+            affinity="precomputed",
+            linkage="average",
+            compute_full_tree=True,
+            distance_threshold=threshold,
+        )
+
+        return cluster_model.fit_predict(S)
+    else:
+        cluster_model = AgglomerativeClustering(
+            n_clusters=n_clusters, affinity="precomputed", linkage="average"
+        )
+
+        return cluster_model.fit_predict(S)
+
+
+##########################################
+# Spectral clustering
+# A lot of these methods are lifted from
+# https://github.com/wq2012/SpectralCluster
+##########################################
+
+
+def cluster_SC(embeds, n_clusters=None, threshold=None, enhance_sim=True, **kwargs):
+    """
+    Cluster embeds using Spectral Clustering
+    """
+    if n_clusters is None:
+        assert threshold, "If num_clusters is not defined, threshold must be defined"
+
+    S = compute_affinity_matrix(embeds)
+    if enhance_sim:
+        S = sim_enhancement(S)
+
+    if n_clusters is None:
+        (eigenvalues, eigenvectors) = compute_sorted_eigenvectors(S)
+        # Get number of clusters.
+        k = compute_number_of_clusters(eigenvalues, 100, threshold)
+
+        # Get spectral embeddings.
+        spectral_embeddings = eigenvectors[:, :k]
+
+        # Run K-Means++ on spectral embeddings.
+        # Note: The correct way should be using a K-Means implementation
+        # that supports customized distance measure such as cosine distance.
+        # This implementation from scikit-learn does NOT, which is inconsistent
+        # with the paper.
+        kmeans_clusterer = KMeans(
+            n_clusters=k, init="k-means++", max_iter=300, random_state=0
+        )
+        labels = kmeans_clusterer.fit_predict(spectral_embeddings)
+        return labels
+    else:
+        cluster_model = SpectralClustering(
+            n_clusters=n_clusters, affinity="precomputed"
+        )
+
+        return cluster_model.fit_predict(S)
+
+
+def diagonal_fill(A):
+    """
+    Sets the diagonal elements of the matrix to the max of each row
+    """
+    np.fill_diagonal(A, 0.0)
+    A[np.diag_indices(A.shape[0])] = np.max(A, axis=1)
+    return A
+
+
+def gaussian_blur(A, sigma=1.0):
+    """
+    Does a gaussian blur on the affinity matrix
+    """
+    return gaussian_filter(A, sigma=sigma)
+
+
+def row_threshold_mult(A, p=0.95, mult=0.01):
+    """
+    For each row multiply elements smaller than the row's p'th percentile by mult
+    """
+    percentiles = np.percentile(A, p * 100, axis=1)
+    mask = A < percentiles[:, np.newaxis]
+
+    A = (mask * mult * A) + (~mask * A)
+    return A
+
+
+def symmetrization(A):
+    """
+    Symmetrization: Y_{i,j} = max(S_{ij}, S_{ji})
+    """
+    return np.maximum(A, A.T)
+
+
+def diffusion(A):
+    """
+    Diffusion: Y <- YY^T
+    """
+    return np.dot(A, A.T)
+
+
+def row_max_norm(A):
+    """
+    Row-wise max normalization: S_{ij} = Y_{ij} / max_k(Y_{ik})
+    """
+    maxes = np.amax(A, axis=1)
+    return A / maxes
+
+
+def sim_enhancement(A):
+    func_order = [
+        diagonal_fill,
+        gaussian_blur,
+        row_threshold_mult,
+        symmetrization,
+        diffusion,
+        row_max_norm,
+    ]
+    for f in func_order:
+        A = f(A)
+    return A
+
+
+def compute_affinity_matrix(X):
+    """Compute the affinity matrix from data.
+    Note that the range of affinity is [0,1].
+    Args:
+        X: numpy array of shape (n_samples, n_features)
+    Returns:
+        affinity: numpy array of shape (n_samples, n_samples)
+    """
+    # Normalize the data.
+    l2_norms = np.linalg.norm(X, axis=1)
+    X_normalized = X / l2_norms[:, None]
+    # Compute cosine similarities. Range is [-1,1].
+    cosine_similarities = np.matmul(X_normalized, np.transpose(X_normalized))
+    # Compute the affinity. Range is [0,1].
+    # Note that this step is not mentioned in the paper!
+    affinity = (cosine_similarities + 1.0) / 2.0
+    return affinity
+
+
+def compute_sorted_eigenvectors(A):
+    """Sort eigenvectors by the real part of eigenvalues.
+    Args:
+        A: the matrix to perform eigen analysis with shape (M, M)
+    Returns:
+        w: sorted eigenvalues of shape (M,)
+        v: sorted eigenvectors, where v[:, i] corresponds to ith largest
+           eigenvalue
+    """
+    # Eigen decomposition.
+    eigenvalues, eigenvectors = np.linalg.eig(A)
+    eigenvalues = eigenvalues.real
+    eigenvectors = eigenvectors.real
+    # Sort from largest to smallest.
+    index_array = np.argsort(-eigenvalues)
+    # Re-order.
+    w = eigenvalues[index_array]
+    v = eigenvectors[:, index_array]
+    return w, v
+
+
+def compute_number_of_clusters(eigenvalues, max_clusters=None, stop_eigenvalue=1e-2):
+    """Compute number of clusters using EigenGap principle.
+    Args:
+        eigenvalues: sorted eigenvalues of the affinity matrix
+        max_clusters: max number of clusters allowed
+        stop_eigenvalue: we do not look at eigen values smaller than this
+    Returns:
+        number of clusters as an integer
+    """
+    max_delta = 0
+    max_delta_index = 0
+    range_end = len(eigenvalues)
+    if max_clusters and max_clusters + 1 < range_end:
+        range_end = max_clusters + 1
+    for i in range(1, range_end):
+        if eigenvalues[i - 1] < stop_eigenvalue:
+            break
+        delta = eigenvalues[i - 1] / eigenvalues[i]
+        if delta > max_delta:
+            max_delta = delta
+            max_delta_index = i
+    return max_delta_index
+
+
class Diarizer:
|
219 |
+
def __init__(
|
220 |
+
self, device='cuda:0', embed_model="xvec", cluster_method="sc", window=1.5, period=0.75
|
221 |
+
):
|
222 |
+
self.device = device
|
223 |
+
assert embed_model in [
|
224 |
+
"xvec",
|
225 |
+
"ecapa",
|
226 |
+
], "Only xvec and ecapa are supported options"
|
227 |
+
assert cluster_method in [
|
228 |
+
"ahc",
|
229 |
+
"sc",
|
230 |
+
], "Only ahc and sc in the supported clustering options"
|
231 |
+
|
232 |
+
if cluster_method == "ahc":
|
233 |
+
self.cluster = cluster_AHC
|
234 |
+
if cluster_method == "sc":
|
235 |
+
self.cluster = cluster_SC
|
236 |
+
|
237 |
+
self.vad_model, self.get_speech_ts = self.setup_VAD()
|
238 |
+
|
239 |
+
self.run_opts = ({"device": self.device})
|
240 |
+
|
241 |
+
if embed_model == "ecapa":
|
242 |
+
self.embed_model = EncoderClassifier.from_hparams(
|
243 |
+
source="speechbrain/spkrec-ecapa-voxceleb",
|
244 |
+
savedir="pretrained_models/spkrec-ecapa-voxceleb",
|
245 |
+
run_opts=self.run_opts,
|
246 |
+
)
|
247 |
+
|
248 |
+
self.window = window
|
249 |
+
self.period = period
|
250 |
+
|
251 |
+
def setup_VAD(self):
|
252 |
+
model, utils = torch.hub.load(
|
253 |
+
repo_or_dir="snakers4/silero-vad", model="silero_vad"
|
254 |
+
)
|
255 |
+
# force_reload=True)
|
256 |
+
|
257 |
+
get_speech_ts = utils[0]
|
258 |
+
return model, get_speech_ts
|
259 |
+
|
260 |
+
def vad(self, signal):
|
261 |
+
"""
|
262 |
+
Runs the VAD model on the signal
|
263 |
+
"""
|
264 |
+
return self.get_speech_ts(signal.to(self.device), self.vad_model.to(self.device))
|
265 |
+
|
266 |
+
def windowed_embeds(self, signal, fs, window=1.5, period=0.75):
|
267 |
+
"""
|
268 |
+
Calculates embeddings for windows across the signal
|
269 |
+
|
270 |
+
window: length of the window, in seconds
|
271 |
+
period: jump of the window, in seconds
|
272 |
+
|
273 |
+
returns: embeddings, segment info
|
274 |
+
"""
|
275 |
+
len_window = int(window * fs)
|
276 |
+
len_period = int(period * fs)
|
277 |
+
len_signal = signal.shape[1]
|
278 |
+
|
279 |
+
# Get the windowed segments
|
280 |
+
segments = []
|
281 |
+
start = 0
|
282 |
+
while start + len_window < len_signal:
|
283 |
+
segments.append([start, start + len_window])
|
284 |
+
start += len_period
|
285 |
+
|
286 |
+
segments.append([start, len_signal - 1])
|
287 |
+
embeds = []
|
288 |
+
|
289 |
+
with torch.no_grad():
|
290 |
+
for i, j in segments:
|
291 |
+
signal_seg = signal[:, i:j]
|
292 |
+
seg_embed = self.embed_model.encode_batch(signal_seg)
|
293 |
+
embeds.append(seg_embed.squeeze(0).squeeze(0).cpu().numpy())
|
294 |
+
|
295 |
+
embeds = np.array(embeds)
|
296 |
+
return embeds, np.array(segments)
|
297 |
+
|
298 |
+
def recording_embeds(self, signal, fs, speech_ts):
|
299 |
+
"""
|
300 |
+
Takes signal and VAD output (speech_ts) and produces windowed embeddings
|
301 |
+
|
302 |
+
returns: embeddings, segment info
|
303 |
+
"""
|
304 |
+
all_embeds = []
|
305 |
+
all_segments = []
|
306 |
+
|
307 |
+
for utt in speech_ts:
|
308 |
+
start = utt["start"]
|
309 |
+
end = utt["end"]
|
310 |
+
|
311 |
+
utt_signal = signal[:, start:end]
|
312 |
+
utt_embeds, utt_segments = self.windowed_embeds(
|
313 |
+
utt_signal, fs, self.window, self.period
|
314 |
+
)
|
315 |
+
all_embeds.append(utt_embeds)
|
316 |
+
all_segments.append(utt_segments + start)
|
317 |
+
|
318 |
+
all_embeds = np.concatenate(all_embeds, axis=0)
|
319 |
+
all_segments = np.concatenate(all_segments, axis=0)
|
320 |
+
return all_embeds, all_segments
|
321 |
+
|
322 |
+
@staticmethod
|
323 |
+
def join_segments(cluster_labels, segments, tolerance=5):
|
324 |
+
"""
|
325 |
+
Joins up same speaker segments, resolves overlap conflicts
|
326 |
+
|
327 |
+
Uses the midpoint for overlap conflicts
|
328 |
+
tolerance allows for very minimally separated segments to be combined
|
329 |
+
(in samples)
|
330 |
+
"""
|
331 |
+
assert len(cluster_labels) == len(segments)
|
332 |
+
|
333 |
+
new_segments = [
|
334 |
+
{"start": segments[0][0], "end": segments[0][1], "label": cluster_labels[0]}
|
335 |
+
]
|
336 |
+
|
337 |
+
for l, seg in zip(cluster_labels[1:], segments[1:]):
|
338 |
+
start = seg[0]
|
339 |
+
end = seg[1]
|
340 |
+
|
341 |
+
protoseg = {"start": seg[0], "end": seg[1], "label": l}
|
342 |
+
|
343 |
+
if start <= new_segments[-1]["end"]:
|
344 |
+
# If segments overlap
|
345 |
+
if l == new_segments[-1]["label"]:
|
346 |
+
# If overlapping segment has same label
|
347 |
+
new_segments[-1]["end"] = end
|
348 |
+
else:
|
349 |
+
# If overlapping segment has diff label
|
350 |
+
# Resolve by setting new start to midpoint
|
351 |
+
# And setting last segment end to midpoint
|
352 |
+
overlap = new_segments[-1]["end"] - start
|
353 |
+
midpoint = start + overlap // 2
|
354 |
+
new_segments[-1]["end"] = midpoint
|
355 |
+
protoseg["start"] = midpoint
|
356 |
+
new_segments.append(protoseg)
|
357 |
+
else:
|
358 |
+
# If there's no overlap just append
|
359 |
+
new_segments.append(protoseg)
|
360 |
+
|
361 |
+
return new_segments
|
362 |
+
|
363 |
+
@staticmethod
|
364 |
+
def make_output_seconds(cleaned_segments, fs):
|
365 |
+
"""
|
366 |
+
Convert cleaned segments to readable format in seconds
|
367 |
+
"""
|
368 |
+
for seg in cleaned_segments:
|
369 |
+
seg["start_sample"] = seg["start"]
|
370 |
+
seg["end_sample"] = seg["end"]
|
371 |
+
seg["start"] = seg["start"] / fs
|
372 |
+
seg["end"] = seg["end"] / fs
|
373 |
+
return cleaned_segments
|
374 |
+
|
375 |
+
def diarize(
|
376 |
+
self,
|
377 |
+
wav_file,
|
378 |
+
num_speakers=2,
|
379 |
+
threshold=None,
|
380 |
+
silence_tolerance=0.2,
|
381 |
+
enhance_sim=True,
|
382 |
+
extra_info=False,
|
383 |
+
outfile=None,
|
384 |
+
):
|
385 |
+
"""
|
386 |
+
Diarize a 16khz mono wav file, produces list of segments
|
387 |
+
|
388 |
+
Inputs:
|
389 |
+
wav_file (path): Path to input audio file
|
390 |
+
num_speakers (int) or NoneType: Number of speakers to cluster to
|
391 |
+
threshold (float) or NoneType: Threshold to cluster to if
|
392 |
+
num_speakers is not defined
|
393 |
+
silence_tolerance (float): Same speaker segments which are close enough together
|
394 |
+
by silence_tolerance will be joined into a single segment
|
395 |
+
enhance_sim (bool): Whether or not to perform affinity matrix enhancement
|
396 |
+
during spectral clustering
|
397 |
+
If self.cluster_method is 'ahc' this option does nothing.
|
398 |
+
extra_info (bool): Whether or not to return the embeddings and raw segments
|
399 |
+
in addition to segments
|
400 |
+
outfile (path): If specified will output an RTTM file
|
401 |
+
|
402 |
+
Outputs:
|
403 |
+
If extra_info is False:
|
404 |
+
segments (list): List of dicts with segment information
|
405 |
+
{
|
406 |
+
'start': Start time of segment in seconds,
|
407 |
+
'start_sample': Starting index of segment,
|
408 |
+
'end': End time of segment in seconds,
|
409 |
+
'end_sample' Ending index of segment,
|
410 |
+
'label': Cluster label of segment
|
411 |
+
}
|
412 |
+
If extra_info is True:
|
413 |
+
dict: { 'segments': segments (list): List of dicts with segment information
|
414 |
+
{
|
415 |
+
'start': Start time of segment in seconds,
|
416 |
+
'start_sample': Starting index of segment,
|
417 |
+
'end': End time of segment in seconds,
|
418 |
+
'end_sample' Ending index of segment,
|
419 |
+
'label': Cluster label of segment
|
420 |
+
},
|
421 |
+
'embeds': embeddings (np.array): Array of embeddings, each row corresponds to a segment,
|
422 |
+
'segments': segments (list): indexes for start and end frame for each embed in embeds,
|
423 |
+
'cluster_labels': cluster_labels (list): cluster label for each embed in embeds
|
424 |
+
}
|
425 |
+
|
426 |
+
Uses AHC/SC to cluster
|
427 |
+
"""
|
428 |
+
|
429 |
+
signal, fs = torchaudio.load(wav_file)
|
430 |
+
if len(signal) == 2:
|
431 |
+
signal = signal[:1, :]
|
432 |
+
if fs != 16000:
|
433 |
+
signal = torchaudio.functional.resample(signal, fs, 16000)
|
434 |
+
fs = 16000
|
435 |
+
|
436 |
+
speech_ts = self.vad(signal[0])
|
437 |
+
if len(speech_ts) >= 1:
|
438 |
+
|
439 |
+
embeds, segments = self.recording_embeds(signal, fs, speech_ts)
|
440 |
+
if len(embeds) > 1:
|
441 |
+
cluster_labels = self.cluster(
|
442 |
+
embeds,
|
443 |
+
n_clusters=num_speakers,
|
444 |
+
threshold=threshold,
|
445 |
+
enhance_sim=enhance_sim,
|
446 |
+
)
|
447 |
+
else:
|
448 |
+
cluster_labels = np.zeros(len(embeds), dtype=int)
|
449 |
+
cleaned_segments = self.join_segments(cluster_labels, segments)
|
450 |
+
cleaned_segments = self.make_output_seconds(cleaned_segments, fs)
|
451 |
+
cleaned_segments = self.join_samespeaker_segments(
|
452 |
+
cleaned_segments, silence_tolerance=silence_tolerance
|
453 |
+
)
|
454 |
+
if outfile:
|
455 |
+
self.rttm_output(cleaned_segments, splitext(basename(wav_file))[0], outfile=outfile)
|
456 |
+
|
457 |
+
if not extra_info:
|
458 |
+
return cleaned_segments
|
459 |
+
else:
|
460 |
+
return {"clean_segments": cleaned_segments,
|
461 |
+
"embeds": embeds,
|
462 |
+
"segments": segments,
|
463 |
+
"cluster_labels": cluster_labels}
|
464 |
+
else:
|
465 |
+
print("Couldn't find any speech during VAD")
|
466 |
+
return {}
|
467 |
+
|
468 |
+
@staticmethod
|
469 |
+
def rttm_output(segments, recname, outfile=None):
|
470 |
+
assert outfile, "Please specify an outfile"
|
471 |
+
rttm_line = "SPEAKER {} 0 {} {} <NA> <NA> {} <NA> <NA>\n"
|
472 |
+
with open(outfile, "w") as fp:
|
473 |
+
for seg in segments:
|
474 |
+
start = seg["start"]
|
475 |
+
offset = seg["end"] - seg["start"]
|
476 |
+
label = seg["label"]
|
477 |
+
line = rttm_line.format(recname, start, offset, label)
|
478 |
+
fp.write(line)
|
479 |
+
|
480 |
+
@staticmethod
|
481 |
+
def join_samespeaker_segments(segments, silence_tolerance=0.5):
|
482 |
+
"""
|
483 |
+
Join up segments that belong to the same speaker,
|
484 |
+
even if there is a duration of silence in between them.
|
485 |
+
|
486 |
+
If the silence is greater than silence_tolerance, does not join up
|
487 |
+
"""
|
488 |
+
new_segments = [segments[0]]
|
489 |
+
|
490 |
+
for seg in segments[1:]:
|
491 |
+
if seg["label"] == new_segments[-1]["label"]:
|
492 |
+
if new_segments[-1]["end"] + silence_tolerance >= seg["start"]:
|
493 |
+
new_segments[-1]["end"] = seg["end"]
|
494 |
+
new_segments[-1]["end_sample"] = seg["end_sample"]
|
495 |
+
else:
|
496 |
+
new_segments.append(seg)
|
497 |
+
else:
|
498 |
+
new_segments.append(seg)
|
499 |
+
return new_segments
|
500 |
+
|
501 |
+
@staticmethod
|
502 |
+
def match_diarization_to_transcript(segments, text_segments):
|
503 |
+
"""
|
504 |
+
Match the output of .diarize to word segments
|
505 |
+
"""
|
506 |
+
|
507 |
+
text_starts, text_ends, text_segs = [], [], []
|
508 |
+
for s in text_segments:
|
509 |
+
text_starts.append(s["start"])
|
510 |
+
text_ends.append(s["end"])
|
511 |
+
text_segs.append(s["text"])
|
512 |
+
|
513 |
+
text_starts = np.array(text_starts)
|
514 |
+
text_ends = np.array(text_ends)
|
515 |
+
text_segs = np.array(text_segs)
|
516 |
+
|
517 |
+
# Get the earliest start from either diar output or asr output
|
518 |
+
earliest_start = np.min([text_starts[0], segments[0]["start"]])
|
519 |
+
|
520 |
+
worded_segments = segments.copy()
|
521 |
+
worded_segments[0]["start"] = earliest_start
|
522 |
+
cutoffs = []
|
523 |
+
|
524 |
+
for seg in worded_segments:
|
525 |
+
end_idx = np.searchsorted(text_ends, seg["end"], side="left") - 1
|
526 |
+
cutoffs.append(end_idx)
|
527 |
+
|
528 |
+
indexes = [[0, cutoffs[0]]]
|
529 |
+
for c in cutoffs[1:]:
|
530 |
+
indexes.append([indexes[-1][-1], c])
|
531 |
+
|
532 |
+
indexes[-1][-1] = len(text_segs)
|
533 |
+
|
534 |
+
final_segments = []
|
535 |
+
|
536 |
+
for i, seg in enumerate(worded_segments):
|
537 |
+
s_idx, e_idx = indexes[i]
|
538 |
+
words = text_segs[s_idx:e_idx]
|
539 |
+
newseg = deepcopy(seg)
|
540 |
+
newseg["words"] = " ".join(words)
|
541 |
+
final_segments.append(newseg)
|
542 |
+
|
543 |
+
return final_segments
|
544 |
+
|
545 |
+
def match_diarization_to_transcript_ctm(self, segments, ctm_file):
|
546 |
+
"""
|
547 |
+
Match the output of .diarize to a ctm file produced by asr
|
548 |
+
"""
|
549 |
+
ctm_df = pd.read_csv(
|
550 |
+
ctm_file,
|
551 |
+
delimiter=" ",
|
552 |
+
names=["utt", "channel", "start", "offset", "word", "confidence"],
|
553 |
+
)
|
554 |
+
ctm_df["end"] = ctm_df["start"] + ctm_df["offset"]
|
555 |
+
|
556 |
+
starts = ctm_df["start"].values
|
557 |
+
ends = ctm_df["end"].values
|
558 |
+
words = ctm_df["word"].values
|
559 |
+
|
560 |
+
# Get the earliest start from either diar output or asr output
|
561 |
+
earliest_start = np.min([ctm_df["start"].values[0], segments[0]["start"]])
|
562 |
+
|
563 |
+
worded_segments = self.join_samespeaker_segments(segments)
|
564 |
+
worded_segments[0]["start"] = earliest_start
|
565 |
+
cutoffs = []
|
566 |
+
|
567 |
+
for seg in worded_segments:
|
568 |
+
end_idx = np.searchsorted(ctm_df["end"].values, seg["end"], side="left") - 1
|
569 |
+
cutoffs.append(end_idx)
|
570 |
+
|
571 |
+
indexes = [[0, cutoffs[0]]]
|
572 |
+
for c in cutoffs[1:]:
|
573 |
+
indexes.append([indexes[-1][-1], c])
|
574 |
+
|
575 |
+
indexes[-1][-1] = len(words)
|
576 |
+
|
577 |
+
final_segments = []
|
578 |
+
|
579 |
+
for i, seg in enumerate(worded_segments):
|
580 |
+
s_idx, e_idx = indexes[i]
|
581 |
+
words = ctm_df["word"].values[s_idx:e_idx]
|
582 |
+
seg["words"] = " ".join(words)
|
583 |
+
if len(words) >= 1:
|
584 |
+
final_segments.append(seg)
|
585 |
+
else:
|
586 |
+
print(
|
587 |
+
"Removed segment between {} and {} as no words were matched".format(
|
588 |
+
seg["start"], seg["end"]
|
589 |
+
)
|
590 |
+
)
|
591 |
+
|
592 |
+
return final_segments
|
593 |
+
|
594 |
+
@staticmethod
|
595 |
+
def nice_text_output(worded_segments, outfile):
|
596 |
+
with open(outfile, "w") as fp:
|
597 |
+
for seg in worded_segments:
|
598 |
+
fp.write(
|
599 |
+
"[{} to {}] Speaker {}: \n".format(
|
600 |
+
round(seg["start"], 2), round(seg["end"], 2), seg["label"]
|
601 |
+
)
|
602 |
+
)
|
603 |
+
fp.write("{}\n\n".format(seg["words"]))
|
 
 
 class DiarizationPipeline:
-    def __init__(self, ):
+    def __init__(self, device=None):
         super(DiarizationPipeline, self).__init__()
         self.diar = Diarizer(
+            device=device,
             embed_model='ecapa',  # supported types: ['xvec', 'ecapa']
             cluster_method='ahc',  # supported types: ['ahc', 'sc']
             window=1,  # size of window to extract embeddings (in seconds)
             period=0.1  # hop of window (in seconds)
         )
 
-    def
-
-diarization = DiarizationPipeline()
+    def save_speaker_audios(self, segments: list, audio_path: str):
+        """
+        :param segments: result diarization timestamps
+        :param audio_path:
+        :return: out_wav_paths: list of audio paths
+        """
+        signal, sr = librosa.load(audio_path, sr=None, mono=True)
+        out_wav_paths = []
+
+        segments = pd.DataFrame(segments)
+        segments = self.filter_small_speech(segments)
+
+        sort_labels = segments.groupby(['label'])['duration'].sum().nlargest(len(set(segments.label))).index
+
+        for indx, label in enumerate(sort_labels):
+            temp_df = segments[segments.label == label]
+            output_signal = []
+            for _, r in temp_df.iterrows():
+                start = int(r["start"] * sr)
+                end = int(r["end"] * sr)
+                output_signal.append(signal[start:end])
+            out_wav_path = audio_path.replace('.wav', f'_{indx}.wav')
+            sf.write(out_wav_path, np.concatenate(output_signal), sr)
+            out_wav_paths.append(out_wav_path)
+
+        return out_wav_paths
+
+    def filter_small_speech(self, segments):
+        segments['duration'] = segments.end - segments.start
+        durs = segments.groupby('label').sum()
+        labels = durs[durs['duration'] / durs.sum()['duration'] > 0.015].index
+        return segments[segments.label.isin(labels)]
+
+    def __call__(self, input_wav_path: str) -> dict:
+
+        segments = self.diar.diarize(input_wav_path,
+                                     num_speakers=None,
+                                     threshold=9e-1, )
+        if segments != {}:
+            output_wav_paths = self.save_speaker_audios(segments, input_wav_path)
+            return {'count_speakers': max([i['label'] for i in segments]) + 1, 'diarization_segments': segments,
+                    'output_diar_audio_paths': output_wav_paths}
+        else:
+            return {}
 
 
 if __name__ == '__main__':
-    diarization('
+    diarization = DiarizationPipeline(device='cuda:0')
+
+    diarization('../dialog.mp3')
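Editor's note: to tie the new module together, Diarizer runs silero VAD, windowed ECAPA embeddings, clustering (AHC, or SC with optional affinity-matrix enhancement, picking the speaker count via the EigenGap heuristic when none is given), and segment merging; DiarizationPipeline wraps it and writes one wav per detected speaker. A minimal usage sketch against a mono wav (assumptions: the speechbrain and silero-vad downloads succeed and the file contains speech); this snippet is not itself part of the diff.

```python
# Sketch: using the classes added in this commit.
from utils.diarization_pipeline import Diarizer, DiarizationPipeline

# Low-level: raw segments plus an RTTM file.
diar = Diarizer(device='cpu', embed_model='ecapa', cluster_method='ahc', window=1, period=0.1)
segments = diar.diarize('denoise.wav', num_speakers=None, threshold=0.9, outfile='denoise.rttm')

# High-level: per-speaker wav files written next to the input file.
pipeline = DiarizationPipeline(device='cpu')
result = pipeline('denoise.wav')
if result:
    print(result['count_speakers'], result['output_diar_audio_paths'])
```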