update
- Dockerfile +2 -1
- __pycache__/create_DER.cpython-311.pyc +0 -0
- __pycache__/load_dataset.cpython-311.pyc +0 -0
- __pycache__/mix_sae.cpython-311.pyc +0 -0
- __pycache__/train_mix_sae.cpython-311.pyc +0 -0
- app.py +2 -0
- app_test.py +191 -0
- create_DER.py +232 -0
- load_dataset.py +114 -0
- mix_sae.py +672 -0
- segment_process.py +156 -0
- train_mix_sae.py +343 -0
- whisper/__pycache__/__init__.cpython-311.pyc +0 -0
- whisper/__pycache__/audio.cpython-311.pyc +0 -0
- whisper/__pycache__/decoding.cpython-311.pyc +0 -0
- whisper/__pycache__/model.cpython-311.pyc +0 -0
- whisper/__pycache__/timing.cpython-311.pyc +0 -0
- whisper/__pycache__/tokenizer.cpython-311.pyc +0 -0
- whisper/__pycache__/transcribe.cpython-311.pyc +0 -0
- whisper/__pycache__/utils.cpython-311.pyc +0 -0
- whisper/__pycache__/version.cpython-311.pyc +0 -0
Dockerfile
CHANGED
@@ -36,7 +36,8 @@ RUN pip install streamlit --timeout 500
 RUN pip install ffmpeg-python --timeout 1000
 RUN pip install toml
 RUN pip install librosa
-
+RUN pip install pandas
+RUN pip install pyannote-audio
 # RUN pip uninstall ffmpeg --yes
 # RUN pip uninstall ffmpeg-python --yes
__pycache__/create_DER.cpython-311.pyc
ADDED
Binary file (8.37 kB)

__pycache__/load_dataset.cpython-311.pyc
ADDED
Binary file (7.01 kB)

__pycache__/mix_sae.cpython-311.pyc
ADDED
Binary file (33 kB)

__pycache__/train_mix_sae.cpython-311.pyc
ADDED
Binary file (8.73 kB)
app.py
CHANGED
@@ -7,6 +7,8 @@ import whisper
 from sklearn.cluster import AgglomerativeClustering
 import torch
 import librosa
+from mix_sae import *
+
 
 UPLOAD_FOLDER = "./uploads"
app_test.py
ADDED
@@ -0,0 +1,191 @@
import streamlit as st
import numpy as np
import os
import whisper
from sklearn.cluster import AgglomerativeClustering
import torch
import librosa
from torch.utils.data import DataLoader
from mix_sae import MoESparseAutoencodersCL
from load_dataset import AutoEncoderDataset
import argparse


UPLOAD_FOLDER = "./uploads"
parser = argparse.ArgumentParser(description='Deep Clustering Network')
parser.add_argument('--input_dim', type=int, default=384,
                    help='input dimension')
# Model parameters
parser.add_argument('--lr', type=float, default=1e-3,
                    help='learning rate (default: 1e-3)')
parser.add_argument('--wd', type=float, default=1e-4,
                    help='weight decay (default: 1e-4)')
parser.add_argument('--batch-size', type=int, default=16,
                    help='input batch size for training')
parser.add_argument('--lamda', type=float, default=1,
                    help='coefficient of the reconstruction loss')
parser.add_argument('--beta', type=float, default=1,
                    help=('coefficient of the regularization term on '
                          'clustering'))
parser.add_argument('--hidden-dims', default=[256, 128, 64, 32],
                    help='hidden layer dimensions')
parser.add_argument('--latent_dim', type=int, default=2,
                    help='latent space dimension')
parser.add_argument('--n-clusters', type=int, default=2,
                    help='number of clusters in the latent space')
parser.add_argument('--input-dim', type=int, default=384,
                    help='input dimension')
parser.add_argument('--n-classes', type=int, default=2,
                    help='output dimension')
parser.add_argument('--pretrain_epochs', type=int, default=30,
                    help='expert pretraining epochs')
parser.add_argument('--pretrain_epochs_main', type=int, default=30,
                    help='main autoencoder pretraining epochs')
parser.add_argument('--pretrain', type=bool, default=True,
                    help='whether to use pre-training')
parser.add_argument('--main_train_epochs', type=int, default=5,
                    help='main training epochs')
parser.add_argument('--rho', type=float, default=0.2,
                    help='sparsity target rho')
parser.add_argument('--sparsity_param', type=float, default=0.1,
                    help='sparsity constraint weight')
parser.add_argument('--cl_loss_param', type=float, default=0.05,
                    help='classification loss weight')
args = parser.parse_args()


def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ['wav']


def process_wav(audio_file, speaker_number, model_type, run_device='cpu', sr=16000):
    embedding_dims = {"tiny": 384, 'small': 768, 'base': 512, 'medium': 1024}
    # ---- get results from whisper model
    whisper_model = whisper.load_model(model_type, run_device)
    wp_results = whisper_model.transcribe(audio_file)
    for ide in range(len(wp_results['segments'])):
        del wp_results['segments'][ide]['seek']
        del wp_results['segments'][ide]['tokens']
        del wp_results['segments'][ide]['compression_ratio']
        del wp_results['segments'][ide]['temperature']
        del wp_results['segments'][ide]['avg_logprob']
        del wp_results['segments'][ide]['no_speech_prob']

    # ---- solve each segment
    segments = wp_results["segments"]

    # >= 2 sentences
    if len(segments) > 1:
        embeddings = np.zeros(shape=(len(segments), embedding_dims[model_type]))

        for i, segment in enumerate(segments):
            start = int(segment["start"] * sr)
            end = int(segment["end"] * sr)

            # Extract a segment
            audio = audio_file[start: end]
            mel = whisper.log_mel_spectrogram(audio).to(whisper_model.device)

            # --- this code creates the correct shape of mel spectrogram
            while True:
                nF, nT = np.shape(mel)
                if nT > 3000:
                    mel = mel[:, 0:3000]
                    break
                else:
                    mel = torch.cat((mel, mel), -1)
            mel = torch.unsqueeze(mel, 0)
            wp_emb = whisper_model.embed_audio(mel)
            # print(np.shape(wp_emb))

            emb_1d = np.mean(wp_emb.cpu().detach().numpy(), axis=0)
            emb_1d = np.mean(emb_1d, axis=0)
            # print(np.shape(emb_1d))
            # exit()
            embeddings[i] = emb_1d
        embeddings = np.array(embeddings, dtype="f")
        train_loader = AutoEncoderDataset(embeddings)
        train_loader = DataLoader(train_loader, batch_size=args.batch_size, shuffle=False)

        moe_cl = MoESparseAutoencodersCL(args=args)
        mixture_moe_cl, full_latent_X = moe_cl.pretraining(train_loader)
        pre_label = moe_cl.psedo_label
        mixture_moe_cl = moe_cl.main_training(train_loader)
        moe_cl_pred = moe_cl.get_final_cluster(train_loader)
        # --- clustering spk emb
        # clustering = AgglomerativeClustering(speaker_number, compute_distances=True).fit(embeddings)
        # labels = clustering.labels_

        for i in range(len(segments)):
            wp_results['segments'][i]["speaker"] = 'SPEAKER ' + str(pre_label[i] + 1)

    # only one sentence
    else:
        wp_results['segments'][0]["speaker"] = 'SPEAKER 1'

    return wp_results


def main():

    title_style = """
    <style>
    .title {
        text-align: center;
        font-size: 40px;
    }
    </style>
    """
    st.markdown(
        title_style,
        unsafe_allow_html=True
    )
    title = """
    <h1 class = "title" >Telephone Calls Speaker Diarization</h1>
    </div>
    """
    st.markdown(title,
                unsafe_allow_html=True)
    # st.title("Speaker Diarization")

    # Get user inputs
    file = st.file_uploader("Upload a WAV file:", type=["wav"])
    num_speakers = st.number_input("Number of speakers:", min_value=2, max_value=2)

    model_list = ['tiny', 'small', 'base', 'medium']
    model_type = st.selectbox("Select model type: ", model_list)

    # Display the result
    st.write("Your uploaded wav file: ")
    st.audio(file, format='audio/wav')
    if st.button("Submit"):
        if file is not None:

            # Read audio file using librosa
            audio_file, _ = librosa.load(file, sr=16000)

            # Process the uploaded file using the AI model
            wp_results = process_wav(audio_file, num_speakers, model_type)

            # Write result:
            st.write("Segments:")
            for seg in wp_results['segments']:
                seg['start'] = np.round(seg['start'], 1)
                seg['end'] = np.round(seg['end'], 1)
                st.write(seg)
            st.write("Language: ", wp_results['language'])
            st.write("Full text:")
            st.write(wp_results['text'])
        else:
            print("Error")
    st.write("\n\n---\n\n")
    st.write("Built with Docker and Streamlit")
    st.link_button("Paper link: https://arxiv.org/abs/2407.01963", "https://arxiv.org/abs/2407.01963")
    return


if __name__ == "__main__":
    main()
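A quick way to exercise process_wav outside the Streamlit UI is a short driver script. This is only a sketch and not part of the commit: sample.wav is a hypothetical local file, and importing app_test also runs its module-level argparse, so call the driver with no extra CLI arguments.

# offline_check.py - hypothetical smoke test for process_wav (not part of this commit)
import librosa
from app_test import process_wav

audio, _ = librosa.load("sample.wav", sr=16000)                 # assumed local test file
results = process_wav(audio, speaker_number=2, model_type="tiny")
for seg in results["segments"]:
    print(seg["speaker"], seg["start"], seg["end"], seg["text"])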
create_DER.py
ADDED
@@ -0,0 +1,232 @@
import numpy as np
import pandas as pd
import argparse
import os
import simpleder
from pyannote.metrics.diarization import DiarizationErrorRate
from pyannote.core import Segment, Annotation


parser = argparse.ArgumentParser()
parser.add_argument("--label_dir", type=str, default="./datasets/spanish/human_label", help="rttm label dir")

opt = parser.parse_args()
LABEL_DIR = opt.label_dir
column = ['speaker', 'file_name', 'number', 'start', 'duration', 'na1', 'na2', 'label', 'na3', 'na4']


def createDER(label_path, sample_dir, prediction, window_length=0.5, overlap=0.0):
    """
    Extract series from label and prediction for calculating DER
    """

    # df = pd.read_csv(label_path, delimiter=' ', header=None, usecols=column, names=column)
    # ref = []

    # prev_end = 0
    # # Assign label
    # for row in df.iterrows():
    #     row_item = row[1]
    #     start = np.round(row_item['start'], 2)
    #     end = np.round(row_item['start'] + row_item['duration'], 2)
    #     # Avoid overlap
    #     if start < prev_end:
    #         start = prev_end
    #     # Avoid error label
    #     if start > end:
    #         continue
    #     ref.append((row_item['label'], start, end))
    #     prev_end = end

    df = pd.read_csv(label_path, delimiter=' ', header=None, usecols=column, names=column)
    refer = Annotation(uri='label')

    # Assign label
    prev_end = 0
    for row in df.iterrows():
        row_item = row[1]
        start = np.round(row_item['start'], 2)
        end = np.round(row_item['start'] + row_item['duration'], 2)
        # Avoid overlap
        if start < prev_end:
            start = prev_end
        # Avoid error label
        if start > end:
            continue
        refer[Segment(start, end)] = row_item['label']
        prev_end = end
    print("******EXTRACT LABEL DONE***********")

    # assert len(os.listdir(sample_dir)) == len(prediction)
    segment_list = sorted(os.listdir(sample_dir), key=lambda x: float(x.split("_")[-2]))

    # Create index mapping to store start-end index of consecutive segments
    index_mapping = {}
    start_index = 0
    current_value = prediction[0]

    for i in range(1, len(prediction)):
        if prediction[i] != current_value:
            index_mapping[(start_index, i - 1)] = current_value
            start_index = i
            current_value = prediction[i]

    # Handle the last consecutive sequence
    index_mapping[(start_index, len(prediction) - 1)] = current_value

    # Assign label to consecutive segments
    hyp = []
    for key, value in index_mapping.items():
        start_index = key[0]
        end_index = key[1]
        speaker_label = "spk0{}".format(value)
        if overlap != 0:
            start_time = np.round(overlap * start_index, 2)
            if start_index == end_index:
                end_time = np.round(start_time + window_length, 2)
            else:
                end_time = np.round(overlap * end_index + window_length, 2)
        # Non-overlap
        else:
            start_time = np.round(window_length * start_index, 2)
            if start_index == end_index:
                end_time = np.round(start_time + window_length, 2)
            else:
                end_time = np.round((end_index + 1) * window_length, 2)

        hyp.append((speaker_label, start_time, end_time))

    hypo = Annotation(uri='hypo')
    for item in hyp:
        hypo[Segment(item[1], item[2])] = item[0]

    print("******EXTRACT HYP DONE***********")

    return refer, hypo


def create_DER_pyannote(label_path, pyannote_label_path):
    # df = pd.read_csv(label_path, delimiter=' ', header=None, usecols=column, names=column)

    # ref = []
    # prev_end = 0
    # # Assign label
    # for row in df.iterrows():
    #     row_item = row[1]
    #     start = np.round(row_item['start'], 2)
    #     end = np.round(row_item['start'] + row_item['duration'], 2)
    #     # Avoid overlap
    #     if start < prev_end:
    #         start = prev_end
    #     # Avoid error label
    #     if start > end:
    #         continue
    #     ref.append((row_item['label'], start, end))
    #     prev_end = end
    df = pd.read_csv(label_path, delimiter=' ', header=None, usecols=column, names=column)
    refer = Annotation(uri='label')

    # Assign label
    for row in df.iterrows():
        row_item = row[1]
        start = np.round(row_item['start'], 2)
        end = np.round(row_item['start'] + row_item['duration'], 2)
        # # Avoid overlap
        # if start < prev_end:
        #     start = prev_end
        # # Avoid error label
        # if start > end:
        #     continue
        refer[Segment(start, end)] = row_item['label']
        # ref.append((row_item['label'], start, end))
        # prev_end = end

    print("******EXTRACT LABEL DONE***********")

    df = pd.read_csv(pyannote_label_path, delimiter=' ', header=None, usecols=column, names=column)
    print(df)
    pyannote_ref = []
    prev_end = 0
    # Assign label
    for row in df.iterrows():
        row_item = row[1]
        start = np.round(row_item['start'], 2)
        end = np.round(row_item['start'] + row_item['duration'], 2)
        # Avoid overlap
        if start < prev_end:
            start = prev_end
        # Avoid error label
        if start > end:
            continue
        pyannote_ref.append((row_item['label'], start, end))
        prev_end = end
    print("******EXTRACT PYANNOTE LABEL DONE***********")
    return refer, pyannote_ref


def create_pyannote_timeline(label_path, pyannote_label_path):
    df = pd.read_csv(label_path, delimiter=' ', header=None, usecols=column, names=column)
    refer = Annotation(uri='label')
    # ref = []
    # prev_end = 0
    # Assign label
    for row in df.iterrows():
        row_item = row[1]
        start = np.round(row_item['start'], 2)
        end = np.round(row_item['start'] + row_item['duration'], 2)
        # # Avoid overlap
        # if start < prev_end:
        #     start = prev_end
        # # Avoid error label
        # if start > end:
        #     continue
        refer[Segment(start, end)] = row_item['label']
        # ref.append((row_item['label'], start, end))
        # prev_end = end

    print("******EXTRACT LABEL DONE***********")

    df = pd.read_csv(pyannote_label_path, delimiter=' ', header=None, usecols=column, names=column)
    py_refer = Annotation(uri='py_label')
    ref = []
    # prev_end = 0
    # Assign label
    for row in df.iterrows():
        row_item = row[1]
        start = np.round(row_item['start'], 2)
        end = np.round(row_item['start'] + row_item['duration'], 2)
        # # Avoid overlap
        # if start < prev_end:
        #     start = prev_end
        # # Avoid error label
        # if start > end:
        #     continue
        py_refer[Segment(start, end)] = row_item['label']
        # ref.append((row_item['label'], start, end))
        # prev_end = end

    print("******EXTRACT PY LABEL DONE***********")

    return refer, py_refer


if __name__ == "__main__":
    label_dir = "datasets/spanish/human_label"
    py_label_dir = "datasets/spanish/label"
    label_list = sorted(os.listdir(label_dir))[:11]
    py_label_list = sorted(os.listdir(py_label_dir))[:11]
    with open("./compare_py_label_PY.txt", "w") as file:
        for label, py_label in zip(label_list, py_label_list):
            label_path = label_dir + "/" + label
            py_label_path = py_label_dir + '/' + py_label
            ref, py_ref = create_pyannote_timeline(label_path=label_path, pyannote_label_path=py_label_path)
            der = DiarizationErrorRate(collar=0.0, skip_overlap=False)
            error = der(ref, py_ref)
            file.write(str(label) + "PYANNOTE err:" + str(error) + "\n")
    file.close()
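For reference, the pyannote.metrics call that the __main__ block relies on operates on a pair of Annotation objects. A minimal, self-contained sketch of that computation (the segment boundaries below are made up):

# der_example.py - minimal DER computation sketch (hypothetical timestamps)
from pyannote.core import Segment, Annotation
from pyannote.metrics.diarization import DiarizationErrorRate

reference = Annotation(uri='label')
reference[Segment(0.0, 2.0)] = 'spk01'
reference[Segment(2.0, 4.0)] = 'spk02'

hypothesis = Annotation(uri='hypo')
hypothesis[Segment(0.0, 2.5)] = 'spk01'
hypothesis[Segment(2.5, 4.0)] = 'spk02'

metric = DiarizationErrorRate(collar=0.0, skip_overlap=False)
print(metric(reference, hypothesis))  # fraction of reference speech that is wrongly attributed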
load_dataset.py
ADDED
@@ -0,0 +1,114 @@
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import os


class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, sample_dir, embed_dim=384, train=False, time=2):
        self.sample_dir = sample_dir
        self.n_segments = len(os.listdir(self.sample_dir))
        self.data = np.zeros((self.n_segments, embed_dim))

        # Sort segments based on start_time
        self.sorted_segments = sorted(os.listdir(sample_dir), key=lambda x: float(x.split("_")[-2]))

        # Assign segments
        for idx, segment_npy in enumerate(self.sorted_segments):
            segment_path = self.sample_dir + "/" + segment_npy
            segment_embed = np.load(segment_path)
            self.data[idx] = segment_embed

        if train:
            for _ in range(time):
                self.data = np.concatenate((self.data, self.data), axis=0)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = torch.from_numpy(self.data[idx]).float()
        return sample


class AutoEncoderDataset(torch.utils.data.Dataset):
    """
    Create dataset from a predefined tensor for each autoencoder in the MoE
    """
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


import argparse
parser = argparse.ArgumentParser(description='Deep Clustering Network')

# Dataset parameters
parser.add_argument('--dir', default='./datasets/spanish/',
                    help='dataset directory')
parser.add_argument('--input_dim', type=int, default=384,
                    help='input dimension')
parser.add_argument('--n-classes', type=int, default=2,
                    help='output dimension')

# Training parameters
parser.add_argument('--lr', type=float, default=1e-3,
                    help='learning rate (default: 1e-3)')
parser.add_argument('--wd', type=float, default=1e-4,
                    help='weight decay (default: 1e-4)')
parser.add_argument('--batch-size', type=int, default=16,
                    help='input batch size for training')
parser.add_argument('--epoch', type=int, default=50,
                    help='number of epochs to train')
parser.add_argument('--pre-epoch', type=int, default=100,
                    help='number of pre-train epochs')
parser.add_argument('--pretrain', type=bool, default=True,
                    help='whether to use pre-training')

# Model parameters
parser.add_argument('--lamda', type=float, default=1,
                    help='coefficient of the reconstruction loss')
parser.add_argument('--beta', type=float, default=1,
                    help=('coefficient of the regularization term on '
                          'clustering'))
parser.add_argument('--hidden-dims', default=[256, 128, 64, 32, 16],
                    help='hidden layer dimensions')
parser.add_argument('--latent_dim', type=int, default=2,
                    help='latent space dimension')
parser.add_argument('--n-clusters', type=int, default=2,
                    help='number of clusters in the latent space')

parser.add_argument('--n_1Dconv', type=int, default=4,
                    help='n_1dconv')
parser.add_argument('--kernel_size', default=[7, 5, 3, 3],
                    help='kernel_size')
parser.add_argument('--stride', type=int, default=1,
                    help='stride')
parser.add_argument('--num_blocks', type=int, default=4,
                    help='num_blocks')
parser.add_argument('--channels', type=int, default=[128, 64, 32, 16],
                    help='channels')

# Utility parameters
parser.add_argument('--n-jobs', type=int, default=1,
                    help='number of jobs to run in parallel')
parser.add_argument('--log-interval', type=int, default=20,
                    help=('how many batches to wait before logging the '
                          'training status'))
parser.add_argument("--window_length", type=float, default=0.4, help="window length")
parser.add_argument("--overlap", type=float, default=0, help="overlap")


args = parser.parse_args()

if __name__ == "__main__":
    # Example usage:
    sample_dir = "datasets/spanish/segments/0096_[cut_193sec].wav"
    dataset = CustomDataset(sample_dir=sample_dir, train=False)
    # dataset = CustomDataset(sample_dir=sample_dir, train=False)
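AutoEncoderDataset is the adapter app_test.py uses to batch precomputed embeddings. A minimal usage sketch, with a random matrix standing in for real Whisper segment embeddings (note that importing load_dataset also runs its module-level argparse, so run this with no extra CLI arguments):

# dataset_example.py - sketch only; the random matrix replaces real segment embeddings
import numpy as np
from torch.utils.data import DataLoader
from load_dataset import AutoEncoderDataset

embeddings = np.random.randn(20, 384).astype("f")   # 20 segments, tiny-model embedding size
loader = DataLoader(AutoEncoderDataset(embeddings), batch_size=16, shuffle=False)
for batch in loader:
    print(batch.shape)   # torch.Size([16, 384]) then torch.Size([4, 384])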
mix_sae.py
ADDED
@@ -0,0 +1,672 @@
import torch.nn as nn
from collections import OrderedDict
import torch
import argparse
import torch.nn.init as init
import numpy as np
from sklearn.cluster import KMeans, SpectralClustering
from load_dataset import AutoEncoderDataset
from torch.utils.data import DataLoader
from load_dataset import *
import torch.nn.functional as F
import matplotlib.pyplot as plt
# ---------------------- VERSION 3: SPARSE AUTOENCODER WITH KL PENALTY AND ENTROPY LOSS --------------------
import random

random.seed(10)

parser = argparse.ArgumentParser(description='Deep Clustering Network')
parser.add_argument('--input_dim', type=int, default=384,
                    help='input dimension')
# Model parameters
parser.add_argument('--lr', type=float, default=1e-3,
                    help='learning rate (default: 1e-3)')
parser.add_argument('--wd', type=float, default=1e-4,
                    help='weight decay (default: 1e-4)')
parser.add_argument('--batch-size', type=int, default=16,
                    help='input batch size for training')
parser.add_argument('--lamda', type=float, default=1,
                    help='coefficient of the reconstruction loss')
parser.add_argument('--beta', type=float, default=1,
                    help=('coefficient of the regularization term on '
                          'clustering'))
parser.add_argument('--hidden-dims', default=[256, 128, 64, 32],
                    help='hidden layer dimensions')
parser.add_argument('--latent_dim', type=int, default=2,
                    help='latent space dimension')
parser.add_argument('--n-clusters', type=int, default=2,
                    help='number of clusters in the latent space')
parser.add_argument('--input-dim', type=int, default=384,
                    help='input dimension')
parser.add_argument('--n-classes', type=int, default=2,
                    help='output dimension')
parser.add_argument('--pretrain_epochs', type=int, default=80,
                    help='expert pretraining epochs')
parser.add_argument('--pretrain_epochs_main', type=int, default=80,
                    help='main autoencoder pretraining epochs')
parser.add_argument('--pretrain', type=bool, default=True,
                    help='whether to use pre-training')
parser.add_argument('--main_train_epochs', type=int, default=80,
                    help='main training epochs')
parser.add_argument('--rho', type=float, default=0.2,
                    help='sparsity target rho')
parser.add_argument('--sparsity_param', type=float, default=0.1,
                    help='sparsity constraint weight')
parser.add_argument('--cl_loss_param', type=float, default=0.05,
                    help='classification loss weight')
args = parser.parse_args()


class AutoEncoder(nn.Module):

    def __init__(self, args):
        super(AutoEncoder, self).__init__()
        self.args = args
        self.input_dim = args.input_dim
        self.output_dim = self.input_dim
        self.hidden_dims = args.hidden_dims
        self.hidden_dims.append(args.latent_dim)
        self.dims_list = (args.hidden_dims +
                          args.hidden_dims[:-1][::-1])  # mirrored structure
        self.n_layers = len(self.dims_list)
        self.latent_dim = args.latent_dim
        self.n_clusters = args.n_clusters
        self.RHO = args.rho

        # Validation check
        assert self.n_layers % 2 > 0
        assert self.dims_list[self.n_layers // 2] == self.latent_dim

        # Encoder Network
        layers = OrderedDict()
        for idx, hidden_dim in enumerate(self.hidden_dims):
            if idx == 0:
                layers.update(
                    {
                        'linear0': nn.Linear(self.input_dim, hidden_dim),
                        # 'linear0': CustomDense(self.input_dim, hidden_dim),
                        # 'activation0': nn.LeakyReLU()
                        # 'activation0': nn.ReLU()
                    }
                )
            else:
                layers.update(
                    {
                        'linear{}'.format(idx): nn.Linear(
                            self.hidden_dims[idx-1], hidden_dim),
                        # 'linear{}'.format(idx): CustomDense(self.hidden_dims[idx-1], hidden_dim),
                        # 'activation{}'.format(idx): nn.ELU(),
                        'activation{}'.format(idx): nn.LeakyReLU(),
                        # 'dropout{}'.format(idx): nn.Dropout(0.5),
                        'bn{}'.format(idx): nn.BatchNorm1d(
                            self.hidden_dims[idx]),
                    }
                )
        self.encoder = nn.Sequential(layers)

        # Decoder Network
        layers = OrderedDict()
        tmp_hidden_dims = self.hidden_dims[::-1]
        for idx, hidden_dim in enumerate(tmp_hidden_dims):
            if idx == len(tmp_hidden_dims) - 1:
                layers.update(
                    {
                        'linear{}'.format(idx): nn.Linear(
                            hidden_dim, self.output_dim),
                        # 'activation{}'.format(idx): nn.ReLU()
                        # 'activation{}'.format(idx): nn.LeakyReLU(),
                        # 'activation{}'.format(idx): nn.ELU(),
                        # 'linear{}'.format(idx): CustomDense(hidden_dim, self.output_dim),
                    }
                )
            else:
                layers.update(
                    {
                        'linear{}'.format(idx): nn.Linear(
                            hidden_dim, tmp_hidden_dims[idx+1]),
                        # 'linear{}'.format(idx): CustomDense(
                        #     hidden_dim, tmp_hidden_dims[idx+1]),
                        # 'activation{}'.format(idx): nn.ELU(),
                        'activation{}'.format(idx): nn.LeakyReLU(),
                        # 'dropout{}'.format(idx): nn.Dropout(0.5),
                        'bn{}'.format(idx): nn.BatchNorm1d(
                            tmp_hidden_dims[idx+1]),
                    }
                )
        self.decoder = nn.Sequential(layers)
        # Apply Xavier weight initialization to all linear layers
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight)
                init.constant_(m.bias, 0)  # Initialize biases to 0

    def __repr__(self):
        repr_str = '[Structure]: {}-'.format(self.input_dim)
        for idx, dim in enumerate(self.dims_list):
            repr_str += '{}-'.format(dim)
        repr_str += str(self.output_dim) + '\n'
        repr_str += '[n_layers]: {}'.format(self.n_layers) + '\n'
        repr_str += '[n_clusters]: {}'.format(self.n_clusters) + '\n'
        repr_str += '[input_dims]: {}'.format(self.input_dim)
        return repr_str

    def __str__(self):
        return self.__repr__()

    def forward(self, X, latent=False):
        output = self.encoder(X)
        if latent:
            return output
        return self.decoder(output)


class VAE(nn.Module):
    def __init__(self, args):
        super(VAE, self).__init__()
        self.args = args
        self.input_dim = args.input_dim
        self.output_dim = self.input_dim
        self.hidden_dims = args.hidden_dims
        self.latent_dim = args.latent_dim
        self.n_clusters = args.n_clusters

        # Encoder Network
        layers = OrderedDict()
        for idx, hidden_dim in enumerate(self.hidden_dims):
            if idx == 0:
                layers.update(
                    {
                        'linear0': nn.Linear(self.input_dim, hidden_dim),
                    }
                )
            else:
                layers.update(
                    {
                        'linear{}'.format(idx): nn.Linear(
                            self.hidden_dims[idx-1], hidden_dim),
                        'activation{}'.format(idx): nn.ReLU(),
                        'bn{}'.format(idx): nn.BatchNorm1d(
                            self.hidden_dims[idx])
                    }
                )
        self.encoder = nn.Sequential(layers)

        # Decoder Network
        layers = OrderedDict()
        tmp_hidden_dims = self.hidden_dims[::-1]
        for idx, hidden_dim in enumerate(tmp_hidden_dims):
            if idx == len(tmp_hidden_dims) - 1:
                layers.update(
                    {
                        'linear{}'.format(idx): nn.Linear(
                            hidden_dim, self.output_dim),
                    }
                )
            else:
                layers.update(
                    {
                        'linear{}'.format(idx): nn.Linear(
                            hidden_dim, tmp_hidden_dims[idx+1]),
                        'activation{}'.format(idx): nn.ReLU(),
                        'bn{}'.format(idx): nn.BatchNorm1d(
                            tmp_hidden_dims[idx+1])
                    }
                )
        self.decoder = nn.Sequential(layers)
        self.fc_mu = nn.Linear(self.hidden_dims[-1], self.latent_dim)
        self.fc_var = nn.Linear(self.hidden_dims[-1], self.latent_dim)

        self.decode_input_linear = nn.Linear(self.latent_dim, self.hidden_dims[-1])
        # Apply Xavier weight initialization to all linear layers
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight)
                init.constant_(m.bias, 0)  # Initialize biases to 0

    def encode(self, x):
        x = self.encoder(x)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)

        return [mu, log_var]

    def decode(self, x):
        x = self.decode_input_linear(x)
        x = self.decoder(x)
        return x

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick to sample from N(mu, var) using N(0, 1).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def loss_function(self, x_hat, x, mu, log_var, kld_weight=1):
        """
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        """

        rec_loss = torch.nn.functional.mse_loss(x_hat, x)

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = rec_loss + kld_weight * kld_loss

        return [loss, rec_loss.detach(), -kld_loss.detach()]

    def forward(self, x):
        """
        Forward VAE
        Return: [output, input, mu, var]
        """
        # Encoder
        mu, log_var = self.encode(x)
        # Sample
        z = self.reparameterize(mu, log_var)
        # Decoder
        output = self.decode(z)

        return [output, x, mu, log_var]


class ClusterNet(nn.Module):

    def __init__(self, input_dim, hidden_dims=[128], n_clusters=2):
        """Gating network that maps an input embedding to soft cluster probabilities."""
        super(ClusterNet, self).__init__()
        layers = []
        for i in range(len(hidden_dims)):
            if i == 0:
                layers.append(nn.Linear(input_dim, hidden_dims[i]))
                layers.append(nn.LeakyReLU())
                # layers.append(nn.Dropout(0.5))
                # layers.append(nn.BatchNorm1d(hidden_dims[i]))
            else:
                layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i]))
                layers.append(nn.LeakyReLU())
                # layers.append(nn.Dropout(0.5))
                # layers.append(nn.BatchNorm1d(hidden_dims[i]))
        # Last layer
        layers.append(nn.Sequential(
            nn.Flatten(),
            nn.Linear(hidden_dims[-1], n_clusters),
            nn.Softmax(dim=1),
        ))

        self.layers = nn.Sequential(*layers)
        # Apply Xavier weight initialization to all linear layers
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_normal_(m.weight)
                init.constant_(m.bias, 0)  # Initialize biases to 0

    def forward(self, x):
        """Extract the feature vectors."""
        features = x
        for layer in self.layers:
            features = layer(features)
        return features


class MoESparseAutoencodersCL(nn.Module):
    """
    Mixture of Expert DNN-Autoencoder
    """
    def __init__(self, args):
        super(MoESparseAutoencodersCL, self).__init__()
        self.args = args
        self.input_dim = args.input_dim
        self.output_dim = self.input_dim
        self.hidden_dims = args.hidden_dims
        self.latent_dim = args.latent_dim
        self.n_clusters = args.n_clusters
        self.pretrain_epochs = args.pretrain_epochs
        self.pretrain_epochs_main = args.pretrain_epochs_main
        self.main_train_epochs = args.main_train_epochs
        self.device = "cpu"
        # Define main autoencoder for pretraining
        self.main_autoencoder = AutoEncoder(args=args)
        self.RHO = args.rho
        self.BETA = args.sparsity_param
        self.psedo_label = None
        # Clustering algorithm for pre-training
        self.cluster_algo = None
        self.cl_loss_param = args.cl_loss_param

        # Define autoencoder experts in the mixture
        self.moe = {}
        for i in range(self.n_clusters):
            self.moe[i] = AutoEncoder(args)
        # Add cluster net (gating network) to moe
        self.moe['cluster_net'] = ClusterNet(input_dim=self.input_dim, n_clusters=self.n_clusters)

    def kl_divergence(self, rho, rho_hat):
        rho_hat = torch.mean(F.sigmoid(rho_hat), 1)  # sigmoid because we need the probability distributions
        rho = torch.tensor([rho] * len(rho_hat)).to(self.device)
        return torch.sum(rho * torch.log(rho/rho_hat) + (1 - rho) * torch.log((1 - rho)/(1 - rho_hat)))

    # define the sparse loss function
    def sparse_loss(self, rho, X, model):
        values = X
        loss = 0
        model_children = list(model.children())
        for i in range(len(model_children)):
            values = model_children[i](values)
            loss += self.kl_divergence(rho, values)
        return loss / X.shape[0]

    def batchwise_entropy_loss(self, cluster_outputs):
        """
        Calculate batch-wise entropy loss
        """
        X = torch.mean(cluster_outputs, axis=0)
        return torch.special.entr(X).sum()

    def loss_function(self, expert_outputs, cluster_net_outputs, X, psedo_label):
        """
        Compute loss function in a batch
        Loss = L - Beta * Entropy(cluster_net_outputs)
        L = -log [p_i * exp (-(xhat_i - x_i) ** 2)]
        """

        # Create one-hot psedo label
        # print("Expert output", expert_outputs)
        encoded_arr = np.zeros((len(psedo_label), self.n_clusters), dtype=float)
        for i in range(len(psedo_label)):
            encoded_arr[i][psedo_label[i]] = 1
        # print("Cluster network output", cluster_net_outputs)
        # print("Encoded arr:", encoded_arr)
        # Cross entropy loss
        entropy_criterion = nn.CrossEntropyLoss()
        entropy_loss = entropy_criterion(cluster_net_outputs, torch.tensor(psedo_label, dtype=torch.long))
        # print("Entropy loss", entropy_loss)

        # MOE reconstruction loss
        loss = 0
        for i in range(self.n_clusters):
            mse = -((expert_outputs[i] - X)**2).mean(axis=1)
            loss += cluster_net_outputs[:, i] * torch.exp(mse)

        moe_loss = -torch.log(loss).sum()
        # print('MOE loss', moe_loss)
        return moe_loss - self.cl_loss_param * entropy_loss
        # return moe_loss

    def train_one_autoencoder(self, autoencoder, optimizer, criterion, data_loader, number_of_epochs, sparsity, rho, name='main', verbose=False):
        """
        Train one autoencoder
        """
        print('Training %s ...' % (name))
        for epoch in range(number_of_epochs):

            running_loss = 0.0
            autoencoder.train()
            for batch_index, (data) in enumerate(data_loader):
                batch_size = data.size()[0]
                # Duplicate if batch has one sample (handle one-sample err)
                if batch_size == 1:
                    data = torch.cat([data, data], dim=0)
                    batch_size = 2

                data = data.to(self.device).view(batch_size, -1)
                # Get decoder output
                rec_X = autoencoder(data)

                if sparsity:
                    # Get latent
                    reg_loss = criterion(data, rec_X)
                    sparse_loss = self.sparse_loss(rho=rho, X=data, model=autoencoder)
                    loss = reg_loss + self.BETA * sparse_loss
                    # if batch_index & 100 == 0:
                    #     print("Reg-loss: {} , Sparse-loss: {}".format(reg_loss, sparse_loss))
                else:
                    loss = criterion(data, rec_X)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                running_loss += loss.data.numpy()
                if batch_index % 200 == 0 and verbose:
                    print('epoch %d loss: %.5f batch: %d' % (epoch, running_loss/((batch_index + 1)), (batch_index + 1)*batch_size))
                if batch_index != 0 and batch_index % 1000 == 0:
                    break
        print('Done training %s' % (name))

    def pretraining(self, dataloader):
        """
        Pretraining step
        1) Train a single main_autoencoder on the entire dataset
        2) Cluster the embedding space after training to get labels for the cluster net
        3) Train the i-th autoencoder on the samples assigned to cluster i
        """
        # --------- Training main_autoencoder ---------------
        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(self.main_autoencoder.parameters(), lr=args.lr, weight_decay=args.wd)

        self.train_one_autoencoder(autoencoder=self.main_autoencoder, optimizer=optimizer,
                                   criterion=criterion, data_loader=dataloader,
                                   number_of_epochs=self.pretrain_epochs_main, name="main_autoencoder",
                                   verbose=True,
                                   sparsity=False,
                                   rho=self.RHO
                                   )

        # ---------- Clustering --------------------------
        print("------Clustering---------")

        # Get latent X
        batch_X = []
        for batch_idx, (data) in enumerate(dataloader):
            batch_size = data.size()[0]
            # Duplicate if batch has one sample
            if batch_size == 1:
                data = torch.cat([data, data], dim=0)
                batch_size = 2
            data = data.to(self.device).view(batch_size, -1)
            latent_X = self.main_autoencoder(data, latent=True)
            print("BATCH LATENT X", latent_X)
            batch_X.append(latent_X.detach().cpu().numpy())
        full_latent_X = np.vstack(batch_X)

        # Clustering
        # self.cluster_algo = AgglomerativeClustering(n_clusters=self.n_clusters).fit(full_latent_X)
        # print("Cluster algo", self.cluster_algo)
        # self.cluster_algo.fit(full_latent_X)

        # self.cluster_algo = KMeans(n_clusters=self.n_clusters, n_init=self.n_clusters, init="k-means++", random_state=42).fit(full_latent_X)
        self.cluster_algo = SpectralClustering(n_clusters=self.n_clusters, random_state=42).fit(full_latent_X)
        self.psedo_label = self.cluster_algo.labels_
        print("Done clustering!")
        print("Original label:", self.psedo_label)
        # tsne = TSNE(n_components=2, random_state=42)
        # X_tsne = tsne.fit_transform(full_latent_X)
        # colors = ['black', 'red']
        # for i in np.unique(self.cluster_algo.labels_):
        #     plt.scatter(X_tsne[self.cluster_algo.labels_ == i, 0], X_tsne[self.cluster_algo.labels_ == i, 1], color=colors[i], label=str(i))
        # plt.xlabel('t-SNE feature 1')
        # plt.ylabel('t-SNE feature 2')
        # plt.legend()
        # plt.show()

        # --------- Training each autoencoder expert with predefined cluster labels ---------------
        for i in range(self.n_clusters):
            # Get full dataset through batch loop
            dataset = []
            for batch_idx, (data) in enumerate(dataloader):
                batch_size = data.size()[0]
                # Duplicate if batch has one sample
                if batch_size == 1:
                    data = torch.cat([data, data], dim=0)
                    batch_size = 2
                dataset.append(data.detach().cpu().numpy())
            dataset = np.vstack(dataset)

            # Extract data for specific expert i
            data_expert_i = dataset[self.cluster_algo.labels_ == i]
            data_expert_i = AutoEncoderDataset(data=data_expert_i)
            dataset_expert_i = DataLoader(data_expert_i, batch_size=args.batch_size, shuffle=False)
            optimizer = torch.optim.Adam(self.moe[i].parameters(), lr=args.lr, weight_decay=args.wd)
            criterion = nn.MSELoss()
            # Train expert_i
            self.train_one_autoencoder(autoencoder=self.moe[i], optimizer=optimizer,
                                       criterion=criterion, data_loader=dataset_expert_i,
                                       number_of_epochs=self.pretrain_epochs, name="Expert {}".format(i),
                                       verbose=True, sparsity=True, rho=self.RHO)

        print("Done Pretraining step !")

        return self.moe, full_latent_X

    def get_expert_outputs(self, X, latent=False):
        """
        Get outputs of the experts for a batch
        Return: List of outputs, one per expert
        """
        output = []
        for i in range(self.n_clusters):
            if latent:
                output_expert_i = self.moe[i](X, latent=True)
            else:
                output_expert_i = self.moe[i](X)
            output.append(output_expert_i)
        return output

    def main_training(self, dataloader, name="MOE", verbose=True):
        """
        Main training to optimize loss function L = -log [p_i * exp (-(xhat_i - x_i) ** 2)]
        """
        print('Training %s ...' % (name))

        # Add parameters
        params = list(self.moe['cluster_net'].parameters())
        for i in range(self.n_clusters):
            params += list(self.moe[i].parameters())
            self.moe[i].train()

        optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.wd)

        for epoch in range(self.main_train_epochs):
            running_loss = 0.0
            self.moe['cluster_net'].train()

            for batch_index, (data) in enumerate(dataloader):
                batch_size = data.size()[0]
                # Duplicate if batch has one sample (the last batch)
                if batch_size == 1:
                    data = torch.cat([data, data], dim=0)
                    batch_size = 2
                # Get psedo-label
                psedo_label = self.psedo_label[batch_index: batch_index + batch_size]

                # Get decoder output
                expert_outputs = self.get_expert_outputs(data)
                # Get latent output
                # latent_outputs = self.get_expert_outputs(data, latent=True)

                # # Concatenate k latent outputs
                # latent_tensor = latent_outputs[0]
                # for i in range(1, len(latent_outputs)):
                #     latent_tensor = torch.hstack((latent_tensor, latent_outputs[i]))
                # # if batch_index % 100 == 0:
                # #     print("Latent tensor", latent_tensor)

                clustering_net_outputs = self.moe['cluster_net'](data)

                # print("Cluster net output", clustering_net_outputs)
                loss = self.loss_function(expert_outputs=expert_outputs,
                                          cluster_net_outputs=clustering_net_outputs, X=data,
                                          psedo_label=psedo_label)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                running_loss += loss.data.numpy()
                if batch_index % 100 == 0 and verbose:
                    print('epoch %d loss: %.5f batch: %d' % (epoch, running_loss/((batch_index + 1)), (batch_index + 1)*batch_size))
                if batch_index != 0 and batch_index % 1000 == 0:
                    break

            # Update psedo label
            if epoch != 0 and epoch % 10 == 0:
                self.psedo_label = self.get_final_cluster(dataloader)
                print("Updated psedo label!")
                print("######################################")
                print("New psedo label: ", self.psedo_label)
                # self.cl_loss_param = self.cl_loss_param * 2

        print("Done main training!")
        return self.moe

    def get_final_cluster(self, test_loader):
        """
        Assign the final cluster for each sample based on cluster_net
        """
        # Convert to eval mode
        for i in range(self.n_clusters):
            self.moe[i].eval()
        self.moe['cluster_net'].eval()
        total_pred = []
        for batch_idx, (data) in enumerate(test_loader):
            batch_size = data.size()[0]
            data = data.view(batch_size, -1).to(self.device)
            # Get the hard assignment label
            with torch.no_grad():
                # # Get latent output
                # latent_outputs = self.get_expert_outputs(data, latent=True)

                # # Concatenate k latent outputs
                # latent_tensor = latent_outputs[0]
                # for i in range(1, len(latent_outputs)):
                #     latent_tensor = torch.hstack((latent_tensor, latent_outputs[i]))

                cluster_pred = self.moe['cluster_net'](data)
                cluster_pred = cluster_pred.cpu().numpy()
                batch_pred = np.argmax(cluster_pred, axis=1)
                total_pred.append(batch_pred)
        total_pred = np.concatenate(total_pred, axis=0)
        return total_pred


if __name__ == "__main__":
    # sample_dir = "da_datasets/da_spanish/segments_0.2/1569_[cut_127sec].wav"
    # dataset = CustomDataset(sample_dir=sample_dir, train=False)
    # train_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
    moe = MoESparseAutoencodersCL(args)
    # mixture = moe.pretraining(train_loader)
    # mixture = moe.main_training(train_loader)
    # pred = moe.get_final_cluster(train_loader)
    # print(pred)
    total_params = sum(p.numel() for p in moe.parameters() if p.requires_grad)
    print(total_params)
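The intended training flow mirrors what app_test.py does: pretrain the main autoencoder, cluster its latent space to get pseudo-labels, pretrain the experts, then jointly train the experts and the gating ClusterNet. A rough end-to-end sketch using the module-level defaults; the random embeddings stand in for real Whisper segment embeddings, and the default epoch counts make this slow:

# pipeline_example.py - sketch of the MoE clustering pipeline; not part of this commit
import numpy as np
from torch.utils.data import DataLoader
from load_dataset import AutoEncoderDataset
from mix_sae import MoESparseAutoencodersCL, args

embeddings = np.random.randn(40, args.input_dim).astype("f")   # placeholder embeddings
loader = DataLoader(AutoEncoderDataset(embeddings), batch_size=args.batch_size, shuffle=False)

moe = MoESparseAutoencodersCL(args=args)
moe.pretraining(loader)            # main AE + spectral clustering + per-expert pretraining
moe.main_training(loader)          # joint training of experts and the gating network
labels = moe.get_final_cluster(loader)
print(labels)                      # one cluster index per segment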
segment_process.py
ADDED
@@ -0,0 +1,156 @@
+ import os
+ import subprocess
+ import numpy as np
+ import whisper
+ import torch
+ import argparse
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data_dir", type=str, default="./samples", help="data wav dir")
+ parser.add_argument("--segment_dir", type=str, default="./segments_0.4", help="segment dir")
+ parser.add_argument("--model_type", type=str, default="tiny", help="model type")
+ parser.add_argument("--run_device", type=str, default="cpu", help="run device")
+ parser.add_argument("--window_length", type=float, default=0.4, help="window length")
+ parser.add_argument("--overlap", type=float, default=0, help="overlap")
+
+
+ # Define
+ opt = parser.parse_args()
+ DATA_DIR = opt.data_dir
+ SEGMENT_DIR = opt.segment_dir
+ model_type = opt.model_type
+ run_device = opt.run_device
+ window_length = opt.window_length
+ overlap = opt.overlap
+
+
+ if not os.path.exists(SEGMENT_DIR):
+     os.makedirs(SEGMENT_DIR)
+
+ # Load model
+ whisper_model = whisper.load_model(model_type, run_device)
+ embedding_dims = {"tiny": 384, 'small': 384, 'base': 512, 'medium': 1024}
+
+
+
+
+ def extract_segment(input_file, output_file, start_time, end_time):
+     """
+     Extract one segment given start_time and end_time
+     input_file: input .wav file
+     output_file: extracted .wav segment
+     start_time, end_time: start-end time of the segment
+     """
+     # split_file_name = f'./{input_file}_segment_{start_time}_{end_time}.wav'
+     cmd = 'ffmpeg -i ' + input_file + ' -acodec copy -ss ' + str(start_time) + ' -to ' + str(end_time) + ' ' + output_file
+     os.system(cmd)
+
+
+
+ def split_audio_with_ffmpeg(input_file, output_dir, segment_length=window_length, overlap=overlap):
+     """
+     Extract all segments from original audio
+     """
+     input_filename = input_file.split("/")[-1]
+     if not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+
+     # duration = float(subprocess.check_output(['ffprobe', '-i', input_file, '-show_entries', 'format=duration', '-v', 'quiet', '-of', 'csv=%s' % ("p=0")]))
+     duration = float(subprocess.check_output(['ffprobe', '-i', input_file, '-show_entries', 'format=duration', '-v', 'quiet', '-of', 'csv=%s' % ("p=0")]))
+
+     start_time = 0
+     last_flag = False
+     while start_time < duration and last_flag == False:
+         end_time = np.round(min(start_time + segment_length, duration), 2)
+         # Cover the last segment
+         if end_time + segment_length > duration:
+             end_time = duration
+             last_flag = True
+         output_file = os.path.join(output_dir, f"{input_filename}_segment_{start_time}_{end_time}.wav")
+         extract_segment(input_file, output_file, start_time, end_time)
+         start_time += segment_length - overlap
+
+
+
+ def extract_segment_embedding(segment_dir, save_segment_dir, window_length):
+     """
+     Extract embedding for each segment
+     """
+
+     audio = whisper.load_audio(segment_dir)
+     print("AUDIO SHAPE:", audio.shape)
+
+     # # Duplicate the array to get a 30s chunk
+     # audio = np.tile(audio, int(30/window_length))
+
+     # print("AUDIO SHAPE:", audio.shape)
+     mel = whisper.log_mel_spectrogram(audio).to(whisper_model.device)
+
+
+     # print("MEL SHAPE", mel.shape)
+     # --- this code creates the correct shape of mel spectrogram
+     while True:
+         nF, nT = np.shape(mel)
+         # print(nF, nT)
+         if nT > 3000:
+             mel = mel[:, 0:3000]
+             break
+         else:
+             mel = torch.cat((mel, mel), -1)
+     mel = torch.unsqueeze(mel, 0)
+
+     wp_emb = whisper_model.embed_audio(mel)
+     print("Wp_emb shape:", wp_emb.shape)
+     # print("WB embedding:", wp_emb)
+
+     emb_1d = np.mean(wp_emb.cpu().detach().numpy(), axis=0)
+
+     emb_1d = np.mean(emb_1d, axis=0)
+
+     emb_1d = np.expand_dims(emb_1d, axis=0)
+     print("Speaker embedding shape", emb_1d.shape)
+
+     np.save(save_segment_dir + '/{}.npy'.format(segment_dir.split("/")[-1]), emb_1d, allow_pickle=True)
+
+     return emb_1d
+
+
+
+ def delete_segment_after_done(segments_dir):
+     """
+     Delete segment after extracting embedding
+     """
+     for segment in os.listdir(segments_dir):
+         if segment.endswith('.wav'):
+             segment_path = segments_dir + "/" + segment
+             cmd = 'rm ' + segment_path
+             os.system(cmd)
+
+
+
+
+ if __name__ == "__main__":
+     data_list = sorted(os.listdir(DATA_DIR))
+
+     for sample in data_list:
+         sample_path = DATA_DIR + "/" + sample
+         segment_save_dir = SEGMENT_DIR + "/" + sample
+         if not os.path.exists(segment_save_dir):
+             os.makedirs(segment_save_dir)
+         # # Extract segments
+         split_audio_with_ffmpeg(input_file=sample_path, output_dir=segment_save_dir)
+
+         # Extract embedding
+         segment_list = sorted(os.listdir(segment_save_dir), key=lambda x: x.split("_")[-2])
+         for segment in segment_list:
+             segment_path = segment_save_dir + "/" + segment
+             # Extract embedding for each segment
+             embed_1d = extract_segment_embedding(segment_dir=segment_path, save_segment_dir=segment_save_dir, window_length=window_length)
+
+         # Delete segment wav files after embeddings are extracted
+         delete_segment_after_done(segments_dir=segment_save_dir)
+
+
+
+
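The only non-obvious arithmetic in this script is the windowing loop in split_audio_with_ffmpeg: when the next window would run past the end of the file, the current window is stretched to the end instead of emitting a short trailing segment. A small self-contained sketch of that behaviour (illustrative only, not part of the commit):

    # Mirrors the loop above with window_length=0.4 and overlap=0.
    def plan_segments(duration, segment_length=0.4, overlap=0.0):
        segments, start, last = [], 0.0, False
        while start < duration and not last:
            end = round(min(start + segment_length, duration), 2)
            if end + segment_length > duration:  # fold the remainder into this window
                end = duration
                last = True
            segments.append((start, end))
            start += segment_length - overlap
        return segments

    print(plan_segments(1.0))   # [(0.0, 0.4), (0.4, 1.0)]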
train_mix_sae.py
ADDED
@@ -0,0 +1,343 @@
+ import torch
+ from torchvision import datasets, transforms
+ from torch.utils.data import DataLoader
+ import numpy as np
+ import argparse
+ from load_dataset import CustomDataset
+ from create_DER import createDER
+ from sklearn.cluster import AgglomerativeClustering, KMeans, SpectralClustering, DBSCAN, AffinityPropagation
+ from pyannote.metrics.diarization import DiarizationErrorRate
+ from pyannote.core import Segment, Timeline, Annotation
+ from sklearn.manifold import TSNE
+ from mix_sae import MoESparseAutoencodersCL
+ import os
+ import pandas as pd
+
+
+ ######### NOTE: WRITE FILE FOR EACH METHOD ########################
+ parser = argparse.ArgumentParser(description='Deep Clustering Network')
+
+ # Dataset parameters
+ parser.add_argument('--dir', default='./a_dataset/model_size_english/',
+                     help='dataset directory')
+ parser.add_argument('--input_dim', type=int, default=1280,
+                     help='input dimension')
+ parser.add_argument('--n-classes', type=int, default=2,
+                     help='output dimension')
+
+ # Training parameters
+ parser.add_argument('--lr', type=float, default=1e-3,
+                     help='learning rate (default: 1e-3)')
+ parser.add_argument('--wd', type=float, default=1e-4,
+                     help='weight decay (default: 1e-4)')
+ parser.add_argument('--batch-size', type=int, default=16,
+                     help='input batch size for training')
+ parser.add_argument('--batch_size_moe', type=int, default=16,
+                     help='input batch size for MoE training')
+
+ parser.add_argument('--epoch', type=int, default=50,
+                     help='number of epochs to train')
+ parser.add_argument('--pre-epoch', type=int, default=200,
+                     help='number of pre-train epochs')
+ # parser.add_argument('--pretrain_epochs', type=int, default=80,
+ #                     help='pretraining step epochs')
+ # parser.add_argument('--pretrain', type=bool, default=True,
+ #                     help='whether use pre-training')
+ # parser.add_argument('--main_train_epochs', type=int, default=150,
+ #                     help='main_train epochs')
+ # Model parameters
+ parser.add_argument('--lamda', type=float, default=1,
+                     help='coefficient of the reconstruction loss')
+ parser.add_argument('--beta', type=float, default=0.001,
+                     help=('coefficient of the regularization term on '
+                           'clustering'))
+ parser.add_argument('--hidden-dims', default=[256, 64],
+                     help='hidden layer dimensions')
+ parser.add_argument('--latent_dim', type=int, default=2,
+                     help='latent space dimension')
+ parser.add_argument('--n-clusters', type=int, default=2,
+                     help='number of clusters in the latent space')
+
+
+ # Utility parameters
+ parser.add_argument('--n-jobs', type=int, default=1,
+                     help='number of jobs to run in parallel')
+ parser.add_argument('--log-interval', type=int, default=20,
+                     help=('how many batches to wait before logging the '
+                           'training status'))
+ parser.add_argument("--window_length", type=float, default=0.2, help="window length")
+ parser.add_argument("--overlap", type=float, default=0, help="overlap")
+ parser.add_argument('--rho', type=float, default=0.2,
+                     help='sparsity target (rho)')
+ parser.add_argument('--pretrain_epochs', type=int, default=10,
+                     help='pretraining step epochs')
+ parser.add_argument('--pretrain_epochs_main', type=int, default=20,
+                     help='main-stage pretraining epochs')
+ parser.add_argument('--pretrain', type=bool, default=True,
+                     help='whether use pre-training')
+ parser.add_argument('--main_train_epochs', type=int, default=5,
+                     help='main_train epochs')
+ parser.add_argument('--sparsity_param', type=float, default=0.01,
+                     help='sparsity constraint param')
+ parser.add_argument('--cl_loss_param', type=float, default=1,
+                     help='classification loss param')
+
+ args = parser.parse_args()
+ label_dir = args.dir + "/label"
+ segment_dir = args.dir + "/large_segments_{}".format(args.window_length)
+ pyannote_label_dir = args.dir + "/label"
+ window_length = args.window_length
+ overlap = args.overlap
+
+ sample_list = sorted(os.listdir(segment_dir))
+ label_list = sorted(os.listdir(label_dir))
+ pyannote_label_list = sorted(os.listdir(pyannote_label_dir))
+
+
+ # # Create dataframe to store result
+ columns = ["Language", "Filename", "K-means_DER", "K-medoids_DER", "ONLY_PRE_DER", "MOE_CL_DER"]
+ # Create dataframe
+ df = pd.DataFrame(columns=columns)
+
+ for sample, label, py_label in zip(sample_list, label_list, pyannote_label_list):
+     print("Processing segments in folder {}".format(sample))
+     print("Label: ", label)
+
+     sample_path = segment_dir + "/" + sample
+     label_path = label_dir + '/' + label
+     pyannote_label_path = pyannote_label_dir + "/" + py_label
+     # ## -------------------------- BASELINE ML ----------------------------------
+     # agglo = AgglomerativeClustering(n_clusters=args.n_clusters)
+     # kmeans = KMeans(n_clusters=2)
+
+
+     # # Dataset for ML baseline
+     # ml_dataset = CustomDataset(sample_dir=sample_path, embed_dim=args.input_dim)
+     # data = ml_dataset.data
+     # # reduce_embeddings = dimension_reduce(embeddings=data, reduced_dims=2, reduce_method="pca")
+     # reduce_embeddings = data
+
+     # agglo_res = agglo.fit(reduce_embeddings)
+     # kmeans_res = kmeans.fit(reduce_embeddings)
+
+
+     # agglo_ref, agglo_hyp = createDER(label_path=label_path, sample_dir=sample_path, prediction=agglo_res.labels_, window_length=window_length, overlap=overlap)
+     # kmean_ref, kmean_hyp = createDER(label_path=label_path, sample_dir=sample_path, prediction=kmeans_res.labels_, window_length=window_length, overlap=overlap)
+
+     # # print("KMEANS HYP", kmean_hyp)
+     # # print("KMEDOIDS HYP", k_medoids_hyp)
+     # # print("AGGLO HYP", agglo_hyp)
+
+
+
+     # print("K-means label:", kmeans_res.labels_)
+     # print("Agglo label:", agglo_res.labels_)
+
+
+     der = DiarizationErrorRate(collar=0.25, skip_overlap=False)
+     # agglo_error = der(agglo_ref, agglo_hyp)
+     # kmeans_error = der(kmean_ref, kmean_hyp)
+
+
+     # print("Agglo DER: ", agglo_error)
+     # print("K-means DER: ", kmeans_error)
+
+
+     # Check tsne
+     # plt.style.use('grayscale')
+     # check_list = {"k-Means": kmeans_res, "k-Medoids": k_medoids_res, "Agglomerative": agglo_res}
+     # for algo in check_list:
+     #     y_pred = check_list[algo].labels_
+     #     tsne = TSNE(n_components=2, random_state=42)
+     #     X_tsne = tsne.fit_transform(reduce_embeddings)
+     #     colors = ['black', 'aqua']
+     #     for i in np.unique(y_pred):
+     #         plt.scatter(X_tsne[y_pred == i, 0], X_tsne[y_pred == i, 1], color=colors[i], label=str(i))
+     #     plt.xlabel('t-SNE feature 1')
+     #     plt.ylabel('t-SNE feature 2')
+     #     plt.title('t-SNE visualization with cluster labels for {}'.format(algo))
+     #     plt.legend()
+     #     plt.show()
+
+     # # --------------------------- DEEP CLUSTERING ------------------------------------
+
+     train_dataset = CustomDataset(sample_dir=sample_path, embed_dim=args.input_dim)
+     # train_dataset = CustomDataset(sample_dir=sample_path)
+
+     train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False)
+
+     test_dataset = CustomDataset(sample_dir=sample_path, embed_dim=args.input_dim)
+     test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
+
+     # # Pretrain
+     # rec_loss_list = model.pretrain(train_loader, args.pre_epoch)
+
+     # for e in range(args.epoch):
+     #     model.train()
+     #     model.fit(e, train_loader)
+
+     # y_pred = []
+     # latent_X_list = []
+     # model.eval()  # Set the model to evaluation mode
+
+
+     # for data in test_loader:
+     #     batch_size = data.size()[0]
+     #     data = data.view(batch_size, -1).to(model.device)
+     #     latent_X = model.autoencoder(data, latent=True)
+     #     print('Eval latent x', latent_X)
+     #     latent_X = latent_X.detach().cpu().numpy()
+     #     y_pred.append(model.kmeans.update_assign(latent_X).reshape(-1, 1))
+     #     latent_X_list.append(latent_X)
+
+
+     # y_pred = np.concatenate(y_pred, axis=0)
+     # y_pred = list(np.squeeze(y_pred))
+     # print("Y_pred", y_pred)
+     # latent_X_list = np.concatenate(latent_X_list, axis=0)
+     # algomerative = KMeans(n_clusters=args.n_clusters)
+     # cluster = algomerative.fit(latent_X_list)
+     # y_pred = cluster.labels_
+     # print("Y_pred", y_pred)
+     # print(latent_X_list)
+     # print(y_pred)
+     # # Perform t-SNE
+     # tsne = TSNE(n_components=2, random_state=42)
+     # X_tsne = tsne.fit_transform(latent_X_list)
+     # colors = ['black', 'aqua']
+     # for i in np.unique(y_pred):
+     #     plt.scatter(X_tsne[y_pred == i, 0], X_tsne[y_pred == i, 1], color=colors[i], label=str(i))
+     # plt.xlabel('t-SNE feature 1')
+     # plt.ylabel('t-SNE feature 2')
+     # plt.title('t-SNE visualization with cluster labels for DCN-SD')
+     # plt.legend()
+     # plt.show()
+
+     # # ------------------- SPECTRAL NET -----------------------------
+     # data = torch.tensor(train_dataset.data, dtype=torch.float)
+     # test_data = torch.tensor(test_dataset.data, dtype=torch.float)
+     # spectralnet = SpectralNet(n_clusters=2,
+     #                           should_use_ae=True,
+     #                           should_use_siamese=False,
+     #                           ae_hiddens=[256, 256, 512, 2],
+     #                           ae_epochs=150,
+     #                           ae_batch_size=128,
+     #                           ae_patience=30,
+     #                           # siamese_hiddens=[384, 384, 128, 2],
+     #                           # siamese_epochs=150,
+     #                           # siamese_batch_size=128,
+     #                           # siamese_patience=30,
+     #                           spectral_hiddens=[384, 384, 512, 2],
+     #                           spectral_epochs=300,
+     #                           spectral_batch_size=128,
+     #                           spectral_patience=60)
+     # spectralnet.fit(data)  # X is the dataset and it should be a torch.Tensor
+     # cluster_assignments = spectralnet.predict(data)  # Get the final assignments to cluster
+     # print("Spectral pred", cluster_assignments)
+
+
+     # # ------------------ MOE ------------------------------------
+     # moe = MoEAutoencoders(args=args)
+     # mixture_moe = moe.pretraining(train_loader)
+     # mixture_moe = moe.main_training(train_loader)
+     # moe_pred = moe.get_final_cluster(train_loader)
+     # print("MOE pred", moe_pred)
+
+
+     # # ------------------ MOE SPARSITY ------------------------------------
+     # moe_spa = MoESparseAutoencoders(args=args)
+     # mixture_moe_spa = moe_spa.pretraining(train_loader)
+     # mixture_moe_spa = moe_spa.main_training(train_loader)
+     # moe_spa_pred = moe_spa.get_final_cluster(train_loader)
+     # print("MOE SPARSITY pred", moe_spa_pred)
+     # plt.style.use("seaborn-v0_8-deep")
+     # ------------------ MOE SPARSITY CL ------------------------------------
+     moe_cl = MoESparseAutoencodersCL(args=args)
+     mixture_moe_cl, full_latent_X = moe_cl.pretraining(train_loader)
+     pre_label = moe_cl.psedo_label
+     # latent_X_list = np.concatenate(full_latent_X, axis=0)
+     # Perform t-SNE
+     # tsne = TSNE(n_components=2, random_state=42)
+     # X_tsne = tsne.fit_transform(full_latent_X)
+     # colors = ['red', 'green']
+     # for i in np.unique(pre_label):
+     #     plt.scatter(X_tsne[pre_label == i, 0], X_tsne[pre_label == i, 1], color=colors[i], label="Spk {}".format(str(i)))
+     # plt.legend()
+     # # plt.axis("off")
+     # plt.xticks([])
+     # plt.yticks([])
+
+     # plt.show()
+     mixture_moe_cl = moe_cl.main_training(train_loader)
+     moe_cl_pred = moe_cl.get_final_cluster(train_loader)
+     # for data in test_loader:
+     #     batch_size = data.size()[0]
+     #     data = data.view(batch_size, -1).to(model.device)
+     #     latent_X = model.autoencoder(data, latent=True)
+     #     print('Eval latent x', latent_X)
+     #     latent_X = latent_X.detach().cpu().numpy()
+     #     y_pred.append(model.kmeans.update_assign(latent_X).reshape(-1, 1))
+     #     latent_X_list.append(latent_X)
+
+
+
+
+
+     # spt_ref, spt_hyp = createDER(label_path=label_path, sample_dir=sample_path, prediction=cluster_assignments, window_length=window_length, overlap=overlap)
+     # spt_error = der(spt_ref, spt_hyp)
+
+     # print("Spectral Net DER: ", spt_error)
+
+     # moe_ref, moe_hyp = createDER(label_path=label_path, sample_dir=sample_path, prediction=moe_pred, window_length=window_length, overlap=overlap)
+     # moe_error = der(moe_ref, moe_hyp)
+     # print("MOE DER", moe_error)
+
+     # moe_spa_ref, moe_spa_hyp = createDER(label_path=label_path, sample_dir=sample_path, prediction=moe_spa_pred, window_length=window_length, overlap=overlap)
+     # moe_spa_error = der(moe_spa_ref, moe_spa_hyp)
+     # print("MOE SPA DER", moe_spa_error)
+
+     moe_cl_ref, moe_cl_hyp = createDER(label_path=label_path, sample_dir=sample_path, prediction=pre_label, window_length=window_length, overlap=overlap)
+     moe_cl_pre_error = der(moe_cl_ref, moe_cl_hyp)
+
+     print("MOE CL PRE DER", moe_cl_pre_error)
+
+     moe_cl_ref, moe_cl_hyp = createDER(label_path=label_path, sample_dir=sample_path, prediction=moe_cl_pred, window_length=window_length, overlap=overlap)
+     moe_cl_error = der(moe_cl_ref, moe_cl_hyp)
+     print("MOE CL pred", moe_cl_pred)
+     print("MOE CL DER", moe_cl_error)
+
+
+     # # Calculate PYANNOTE DER
+     # ref, py_ref = create_pyannote_timeline(label_path=label_path, pyannote_label_path=pyannote_label_path)
+     # py_error = der(ref, py_ref)
+     # print("Pyannote DER: ", py_error)
+     # # Write result for each algo
+     # with open(save_dir + "/{}.txt".format(file_name[:-4]), 'a') as file:
+     #     file.write("Algo: {}\n".format(algo))
+     #     file.write("Label \n")
+     #     file.write(" ".join(label_series) + "\n")
+     #     file.write("Pred \n")
+     #     file.write(" ".join(pred_series) + "\n")
+     #     file.write("DER: {}\n".format(error))
+     #     file.write("\n")
+     #     file.close()
+
+
+     # Update result to csv
+     language = args.dir.split("/")[-2]
+     file_name = sample
+     new_row = {"Language": language, "Filename": file_name, "MOE_CL": moe_cl_error}
+
+     df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)
+     df.to_csv("1111_co_large_MOE_CL_{}_{}.csv".format(language, window_length), index=False)
+     # Write result for each algo
+     with open("./1111_co_large_report_MOE_CL_{}_{}.txt".format(language, window_length), 'a') as file:
+         file.write("Filename: {}\n".format(file_name))
+         file.write("Pred \n")
+         for digit in moe_cl_pred:
+             file.write("{} ".format(digit))
+         file.write("\n")
+         file.write("DER: {}\n".format(moe_cl_error))
+         file.write("\n")
+         file.close()
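For orientation, the DER numbers this script reports come from pyannote.metrics: createDER returns a reference and a hypothesis Annotation, and DiarizationErrorRate is simply called on the pair with a 0.25 s collar. A minimal, self-contained illustration (the segment boundaries and labels below are invented, not part of the commit):

    from pyannote.core import Segment, Annotation
    from pyannote.metrics.diarization import DiarizationErrorRate

    ref = Annotation()                # ground-truth speaker turns
    ref[Segment(0.0, 2.0)] = "spk0"
    ref[Segment(2.0, 4.0)] = "spk1"

    hyp = Annotation()                # predicted turns; label names need not match
    hyp[Segment(0.0, 2.5)] = "A"
    hyp[Segment(2.5, 4.0)] = "B"

    der = DiarizationErrorRate(collar=0.25, skip_overlap=False)
    print(der(ref, hyp))              # fraction of reference speech that is missed, confused, or a false alarm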
whisper/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (9.65 kB). View file
|
|
whisper/__pycache__/audio.cpython-311.pyc
ADDED
Binary file (7.55 kB). View file
|
|
whisper/__pycache__/decoding.cpython-311.pyc
ADDED
Binary file (45.9 kB). View file
|
|
whisper/__pycache__/model.cpython-311.pyc
ADDED
Binary file (21.6 kB). View file
|
|
whisper/__pycache__/timing.cpython-311.pyc
ADDED
Binary file (17.9 kB). View file
|
|
whisper/__pycache__/tokenizer.cpython-311.pyc
ADDED
Binary file (19.4 kB). View file
|
|
whisper/__pycache__/transcribe.cpython-311.pyc
ADDED
Binary file (24.2 kB). View file
|
|
whisper/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (12.9 kB). View file
|
|
whisper/__pycache__/version.cpython-311.pyc
ADDED
Binary file (207 Bytes). View file
|
|