LTPhat committed
Commit 0f44ac1 · 1 Parent(s): 7c59ca2
Dockerfile CHANGED
@@ -36,7 +36,8 @@ RUN pip install streamlit --timeout 500
 RUN pip install ffmpeg-python --timeout 1000
 RUN pip install toml
 RUN pip install librosa
-
+RUN pip install pandas
+RUN pip install pyannote-audio
 # RUN pip uninstall ffmpeg --yes
 # RUN pip uninstall ffmpeg-python --yes
 
__pycache__/create_DER.cpython-311.pyc ADDED
Binary file (8.37 kB). View file
 
__pycache__/load_dataset.cpython-311.pyc ADDED
Binary file (7.01 kB). View file
 
__pycache__/mix_sae.cpython-311.pyc ADDED
Binary file (33 kB). View file
 
__pycache__/train_mix_sae.cpython-311.pyc ADDED
Binary file (8.73 kB). View file
 
app.py CHANGED
@@ -7,6 +7,8 @@ import whisper
 from sklearn.cluster import AgglomerativeClustering
 import torch
 import librosa
+from mix_sae import *
+
 
 UPLOAD_FOLDER = "./uploads"
 
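Note: the new wildcard import pulls every public name from mix_sae into app.py's namespace. The explicit form already used in app_test.py below is a narrower equivalent:

from mix_sae import MoESparseAutoencodersCL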
app_test.py ADDED
@@ -0,0 +1,191 @@
1
+ import streamlit as st
2
+ import numpy as np
3
+ import os
4
+ import whisper
5
+ from sklearn.cluster import AgglomerativeClustering
6
+ import torch
7
+ import librosa
8
+ from torch.utils.data import DataLoader
9
+ from mix_sae import MoESparseAutoencodersCL
10
+ from load_dataset import AutoEncoderDataset
11
+ import argparse
12
+
13
+
14
+ UPLOAD_FOLDER = "./uploads"
15
+ parser = argparse.ArgumentParser(description='Deep Clustering Network')
16
+ parser.add_argument('--input_dim', type=int, default=384,
17
+ help='input dimension')
18
+ # Model parameters
19
+ parser.add_argument('--lr', type=float, default=1e-3,
20
+ help='learning rate (default: 1e-3)')
21
+ parser.add_argument('--wd', type=float, default=1e-4,
22
+ help='weight decay (default: 1e-4)')
23
+ parser.add_argument('--batch-size', type=int, default=16,
24
+ help='input batch size for training')
25
+ parser.add_argument('--lamda', type=float, default=1,
26
+ help='coefficient of the reconstruction loss')
27
+ parser.add_argument('--beta', type=float, default=1,
28
+ help=('coefficient of the regularization term on '
29
+ 'clustering'))
30
+ parser.add_argument('--hidden-dims', default=[256, 128, 64, 32],
31
+ help='hidden layer dimensions')
32
+ parser.add_argument('--latent_dim', type=int, default=2,
33
+ help='latent space dimension')
34
+ parser.add_argument('--n-clusters', type=int, default=2,
35
+ help='number of clusters in the latent space')
36
+ parser.add_argument('--input-dim', type=int, default=384,
37
+ help='input dimension')
38
+ parser.add_argument('--n-classes', type=int, default=2,
39
+ help='output dimension')
40
+ parser.add_argument('--pretrain_epochs', type=int, default=30,
41
+ help='pretraining step epochs')
42
+ parser.add_argument('--pretrain_epochs_main', type=int, default=30,
43
+ help='pretraining step epochs')
44
+ parser.add_argument('--pretrain', type=bool, default=True,
45
+ help='whether use pre-training')
46
+ parser.add_argument('--main_train_epochs', type=int, default=5,
47
+ help='main_train epochs')
48
+ parser.add_argument('--rho', type=float, default=0.2,
49
+ help='sparsity target rho')
50
+ parser.add_argument('--sparsity_param', type=float, default=0.1,
51
+ help='sparsity constraint param')
52
+ parser.add_argument('--cl_loss_param', type=float, default=0.05,
53
+ help='classification loss param')
54
+ args = parser.parse_args()
55
+
56
+
57
+ def allowed_file(filename):
58
+ return '.' in filename and filename.rsplit('.', 1)[1].lower() in ['wav']
59
+
60
+
61
+ def process_wav(audio_file, speaker_number, model_type, run_device = 'cpu', sr = 16000):
62
+ embedding_dims = {"tiny": 384, 'small': 768, 'base': 512, 'medium':1024}
63
+ #---- get results from whisper model
64
+ whisper_model = whisper.load_model(model_type, run_device)
65
+ wp_results = whisper_model.transcribe(audio_file)
66
+ for ide in range(len(wp_results['segments'])):
67
+ del wp_results['segments'][ide]['seek']
68
+ del wp_results['segments'][ide]['tokens']
69
+ del wp_results['segments'][ide]['compression_ratio']
70
+ del wp_results['segments'][ide]['temperature']
71
+ del wp_results['segments'][ide]['avg_logprob']
72
+ del wp_results['segments'][ide]['no_speech_prob']
73
+
74
+ #---- solve each segment
75
+ segments = wp_results["segments"]
76
+
77
+ # >= 2 sentences
78
+ if len(segments) > 1:
79
+ embeddings = np.zeros(shape=(len(segments), embedding_dims[model_type]))
80
+
81
+ for i, segment in enumerate(segments):
82
+ start = int(segment["start"] * sr)
83
+ end = int(segment["end"] * sr)
84
+
85
+ # Extract a segment
86
+ audio = audio_file[start: end]
87
+ mel = whisper.log_mel_spectrogram(audio).to(whisper_model.device)
88
+
89
+ #--- this code to create the correct shape of mel spectrogram
90
+ while True:
91
+ nF, nT = np.shape(mel)
92
+ if nT > 3000:
93
+ mel = mel[:,0:3000]
94
+ break
95
+ else:
96
+ mel = torch.cat((mel, mel), -1)
97
+ mel = torch.unsqueeze(mel, 0)
98
+ wp_emb = whisper_model.embed_audio(mel)
99
+ #print(np.shape(wp_emb))
100
+
101
+ emb_1d = np.mean(wp_emb.cpu().detach().numpy(), axis=0)
102
+ emb_1d = np.mean(emb_1d, axis=0)
103
+ #print(np.shape(emb_1d))
104
+ #exit()
105
+ embeddings[i] = emb_1d
106
+ embeddings= np.array(embeddings, dtype="f")
107
+ train_loader = AutoEncoderDataset(embeddings)
108
+ train_loader = DataLoader(train_loader, batch_size = args.batch_size, shuffle = False)
109
+
110
+
111
+ moe_cl = MoESparseAutoencodersCL(args=args)
112
+ mixture_moe_cl, full_latent_X = moe_cl.pretraining(train_loader)
113
+ pre_label = moe_cl.psedo_label
114
+ mixture_moe_cl = moe_cl.main_training(train_loader)
115
+ moe_cl_pred = moe_cl.get_final_cluster(train_loader)
116
+ #--- clustering spk emb
117
+ # clustering = AgglomerativeClustering(speaker_number, compute_distances=True).fit(embeddings)
118
+ # labels = clustering.labels_
119
+
120
+ for i in range(len(segments)):
121
+ wp_results['segments'][i]["speaker"] = 'SPEAKER ' + str(pre_label[i] + 1)
122
+
123
+ # only one sentence
124
+ else:
125
+ wp_results['segments'][0]["speaker"] = 'SPEAKER 1'
126
+
127
+ return wp_results
128
+
129
+
130
+ def main():
131
+
132
+ title_style = """
133
+ <style>
134
+ .title {
135
+ text-align: center;
136
+ font-size: 40px;
137
+ }
138
+ </style>
139
+ """
140
+ st.markdown(
141
+ title_style,
142
+ unsafe_allow_html=True
143
+ )
144
+ title = """
145
+ <h1 class = "title" >Telephone Calls Speaker Diarization</h1>
146
+ </div>
147
+ """
148
+ st.markdown(title,
149
+ unsafe_allow_html=True)
150
+ # st.title("Speaker Diarization")
151
+
152
+
153
+ # Get user inputs
154
+ file = st.file_uploader("Upload a WAV file:", type=["wav"])
155
+ num_speakers = st.number_input("Number of speakers:", min_value=2, max_value=2)
156
+
157
+ model_list = ['tiny', 'small', 'base', 'medium']
158
+ model_type = st.selectbox("Select model type: ", model_list)
159
+
160
+ # Display the result
161
+ st.write("Your uploaded wav file: ")
162
+ st.audio(file, format = 'audio/wav')
163
+ if st.button("Submit"):
164
+ if file is not None:
165
+
166
+ # Read audio file using pydub
167
+ audio_file, _ = librosa.load(file, sr=16000)
168
+
169
+ # Process the uploaded file using the AI model
170
+ wp_results = process_wav(audio_file, num_speakers, model_type)
171
+
172
+ # Write result:
173
+ st.write("Segments:" )
174
+ for seg in wp_results['segments']:
175
+ seg['start'] = np.round(seg['start'], 1)
176
+ seg['end'] = np.round(seg['end'], 1)
177
+ st.write(seg)
178
+ st.write("Language: ", wp_results['language'])
179
+ st.write("Full text:")
180
+ st.write(wp_results['text'])
181
+ else:
182
+ print("Error")
183
+ st.write("\n\n---\n\n")
184
+ st.write("Built with Docker and Streamlit")
185
+ st.link_button("Paper link: https://arxiv.org/abs/2407.01963", "https://arxiv.org/abs/2407.01963")
186
+ return
187
+
188
+
189
+
190
+ if __name__ == "__main__":
191
+ main()
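Note: the while-loop in process_wav tiles a short segment's log-mel spectrogram until it exceeds the 3000 frames (30 s of audio) that Whisper's encoder expects, then truncates. A minimal standalone sketch of that shaping step (helper name hypothetical):

import torch

def tile_and_trim_mel(mel: torch.Tensor, target_frames: int = 3000) -> torch.Tensor:
    # Double the (n_mels, n_frames) spectrogram along time until it covers
    # target_frames, then cut it back; mirrors the loop in process_wav.
    while mel.shape[-1] <= target_frames:
        mel = torch.cat((mel, mel), dim=-1)
    return mel[:, :target_frames]

The result is unsqueezed to a batch of one before whisper_model.embed_audio, exactly as in the loop above.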
create_DER.py ADDED
@@ -0,0 +1,232 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import argparse
4
+ import os
5
+ import simpleder
6
+ from pyannote.metrics.diarization import DiarizationErrorRate
7
+ from pyannote.core import Segment, Annotation
8
+
9
+
10
+
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument("--label_dir", type = str, default="./datasets/spanish/human_label", help= "rttm label dir")
13
+
14
+ opt = parser.parse_args()
15
+ LABEL_DIR = opt.label_dir
16
+ column = ['speaker','file_name','number' ,'start', 'duration', 'na1', 'na2', 'label', 'na3', 'na4']
17
+
18
+
19
+ def createDER(label_path, sample_dir, prediction, window_length = 0.5, overlap = 0.0):
20
+ """
21
+ Extract series from label and prediction for calculating DER
22
+ """
23
+
24
+ # df = pd.read_csv(label_path, delimiter=' ', header=None, usecols=column, names=column)
25
+ # ref = []
26
+
27
+ # prev_end = 0
28
+ # # Assign label
29
+ # for row in df.iterrows():
30
+ # row_item = row[1]
31
+ # start = np.round(row_item['start'], 2)
32
+ # end = np.round(row_item['start'] + row_item['duration'], 2)
33
+ # # Avoid overlap
34
+ # if start < prev_end:
35
+ # start = prev_end
36
+ # # Avoid error label
37
+ # if start > end:
38
+ # continue
39
+ # ref.append((row_item['label'], start, end))
40
+ # prev_end = end
41
+
42
+
43
+ df = pd.read_csv(label_path, delimiter=' ', header=None, usecols=column, names=column)
44
+ refer = Annotation(uri='label')
45
+
46
+ # Assign label
47
+ prev_end = 0
48
+ for row in df.iterrows():
49
+ row_item = row[1]
50
+ start = np.round(row_item['start'], 2)
51
+ end = np.round(row_item['start'] + row_item['duration'], 2)
52
+ # Avoid overlap
53
+ if start < prev_end:
54
+ start = prev_end
55
+ # Avoid error label
56
+ if start > end:
57
+ continue
58
+ refer[Segment(start, end)] = row_item['label']
59
+ prev_end = end
60
+ print("******EXTRACT LABEL DONE***********")
61
+
62
+ # assert len(os.listdir(sample_dir)) == len(prediction)
63
+ segment_list = sorted(os.listdir(sample_dir), key= lambda x: float(x.split("_")[-2]))
64
+
65
+ # Create index mapping to store start-end index of consecutive segments
66
+ index_mapping = {}
67
+ start_index = 0
68
+ current_value = prediction[0]
69
+
70
+ for i in range(1, len(prediction)):
71
+ if prediction[i] != current_value:
72
+ index_mapping[(start_index, i - 1)] = current_value
73
+ start_index = i
74
+ current_value = prediction[i]
75
+
76
+ # Handle the last consecutive sequence
77
+ index_mapping[(start_index, len(prediction) - 1)] = current_value
78
+
79
+
80
+ # Assign label to consecutive segments
81
+ hyp = []
82
+ for key, value in index_mapping.items():
83
+ start_index = key[0]
84
+ end_index = key[1]
85
+ speaker_label = "spk0{}".format(value)
86
+ if overlap != 0:
87
+ start_time = np.round(overlap * start_index, 2)
88
+ if start_index == end_index:
89
+ end_time = np.round(start_time + window_length, 2)
90
+ else:
91
+ end_time = np.round(overlap * end_index + window_length, 2)
92
+ # Non-overlap
93
+ else:
94
+ start_time = np.round(window_length * start_index, 2)
95
+ if start_index == end_index:
96
+ end_time = np.round(start_time + window_length, 2)
97
+ else:
98
+ end_time = np.round((end_index + 1) * window_length, 2)
99
+
100
+ hyp.append((speaker_label, start_time, end_time))
101
+
102
+
103
+ hypo = Annotation(uri='hypo')
104
+ for item in hyp:
105
+ hypo[Segment(item[1], item[2])] = item[0]
106
+
107
+ print("******EXTRACT HYP DONE***********")
108
+
109
+ return refer, hypo
110
+
111
+
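Note: the index_mapping block above is a run-length grouping of per-segment predictions into (start_index, end_index) -> speaker spans; a self-contained sketch on a toy prediction list:

prediction = [0, 0, 1, 1, 1, 0]
index_mapping, start_index, current_value = {}, 0, prediction[0]
for i in range(1, len(prediction)):
    if prediction[i] != current_value:
        index_mapping[(start_index, i - 1)] = current_value
        start_index, current_value = i, prediction[i]
index_mapping[(start_index, len(prediction) - 1)] = current_value
print(index_mapping)  # {(0, 1): 0, (2, 4): 1, (5, 5): 0}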
112
+ def create_DER_pyannote(label_path, pyannote_label_path):
113
+ # df = pd.read_csv(label_path, delimiter=' ', header=None, usecols=column, names=column)
114
+
115
+ # ref = []
116
+ # prev_end = 0
117
+ # # Assign label
118
+ # for row in df.iterrows():
119
+ # row_item = row[1]
120
+ # start = np.round(row_item['start'], 2)
121
+ # end = np.round(row_item['start'] + row_item['duration'], 2)
122
+ # # Avoid overlap
123
+ # if start < prev_end:
124
+ # start = prev_end
125
+ # # Avoid error label
126
+ # if start > end:
127
+ # continue
128
+ # ref.append((row_item['label'], start, end))
129
+ # prev_end = end
130
+ df = pd.read_csv(label_path, delimiter=' ', header=None, usecols=column, names=column)
131
+ refer = Annotation(uri='label')
132
+
133
+ # Assign label
134
+ for row in df.iterrows():
135
+ row_item = row[1]
136
+ start = np.round(row_item['start'], 2)
137
+ end = np.round(row_item['start'] + row_item['duration'], 2)
138
+ # # Avoid overlap
139
+ # if start < prev_end:
140
+ # start = prev_end
141
+ # # Avoid error label
142
+ # if start > end:
143
+ # continue
144
+ refer[Segment(start, end)] = row_item['label']
145
+ # ref.append((row_item['label'], start, end))
146
+ # prev_end = end
147
+
148
+ print("******EXTRACT LABEL DONE***********")
149
+
150
+ df = pd.read_csv(pyannote_label_path, delimiter=' ', header=None, usecols=column, names=column)
151
+ print(df)
152
+ pyannote_ref = []
153
+ prev_end = 0
154
+ # Assign label
155
+ for row in df.iterrows():
156
+ row_item = row[1]
157
+ start = np.round(row_item['start'], 2)
158
+ end = np.round(row_item['start'] + row_item['duration'], 2)
159
+ # Avoid overlap
160
+ if start < prev_end:
161
+ start = prev_end
162
+ # Avoid error label
163
+ if start > end:
164
+ continue
165
+ pyannote_ref.append((row_item['label'], start, end))
166
+ prev_end = end
167
+ print("******EXTRACT PYANNOTE LABEL DONE***********")
168
+ return refer, pyannote_ref
169
+
170
+
171
+ def create_pyannote_timeline(label_path, pyannote_label_path):
172
+ df = pd.read_csv(label_path, delimiter=' ', header=None, usecols=column, names=column)
173
+ refer = Annotation(uri='label')
174
+ # ref = []
175
+ # prev_end = 0
176
+ # Assign label
177
+ for row in df.iterrows():
178
+ row_item = row[1]
179
+ start = np.round(row_item['start'], 2)
180
+ end = np.round(row_item['start'] + row_item['duration'], 2)
181
+ # # Avoid overlap
182
+ # if start < prev_end:
183
+ # start = prev_end
184
+ # # Avoid error label
185
+ # if start > end:
186
+ # continue
187
+ refer[Segment(start, end)] = row_item['label']
188
+ # ref.append((row_item['label'], start, end))
189
+ # prev_end = end
190
+
191
+ print("******EXTRACT LABEL DONE***********")
192
+
193
+ df = pd.read_csv(pyannote_label_path, delimiter=' ', header=None, usecols=column, names=column)
194
+ py_refer = Annotation(uri='py_label')
195
+ ref = []
196
+ # prev_end = 0
197
+ # Assign label
198
+ for row in df.iterrows():
199
+ row_item = row[1]
200
+ start = np.round(row_item['start'], 2)
201
+ end = np.round(row_item['start'] + row_item['duration'], 2)
202
+ # # Avoid overlap
203
+ # if start < prev_end:
204
+ # start = prev_end
205
+ # # Avoid error label
206
+ # if start > end:
207
+ # continue
208
+ py_refer[Segment(start, end)] = row_item['label']
209
+ # ref.append((row_item['label'], start, end))
210
+ # prev_end = end
211
+
212
+ print("******EXTRACT PY LABEL DONE***********")
213
+
214
+ return refer, py_refer
215
+
216
+
217
+
218
+ if __name__ == "__main__":
219
+ label_dir = "datasets/spanish/human_label"
220
+ py_label_dir = "datasets/spanish/label"
221
+ label_list = sorted(os.listdir(label_dir))[:11]
222
+ py_label_list = sorted(os.listdir(py_label_dir))[:11]
223
+ with open("./compare_py_label_PY.txt", "w") as file:
224
+ for label, py_label in zip(label_list, py_label_list):
225
+ label_path = label_dir + "/" + label
226
+ py_label_path = py_label_dir + '/' + py_label
227
+ ref, py_ref = create_pyannote_timeline(label_path=label_path, pyannote_label_path=py_label_path)
228
+ der = DiarizationErrorRate(collar=0.0, skip_overlap=False)
229
+ error = der(ref, py_ref)
230
+ file.write(str(label) + " PYANNOTE err: " + str(error) + "\n")
231
+ file.close()
232
+
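Note: a minimal end-to-end example of the pyannote DER computation used in __main__, with hypothetical two-speaker annotations:

from pyannote.core import Segment, Annotation
from pyannote.metrics.diarization import DiarizationErrorRate

reference = Annotation(uri='demo')
reference[Segment(0.0, 2.0)] = 'spk01'
reference[Segment(2.0, 4.0)] = 'spk02'

hypothesis = Annotation(uri='demo')
hypothesis[Segment(0.0, 2.5)] = 'spk01'
hypothesis[Segment(2.5, 4.0)] = 'spk02'

metric = DiarizationErrorRate(collar=0.0, skip_overlap=False)
print(metric(reference, hypothesis))  # 0.125: 0.5 s of the 4 s reference is confused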
load_dataset.py ADDED
@@ -0,0 +1,114 @@
1
+ import torch
2
+ from torchvision import datasets, transforms
3
+ from torch.utils.data import DataLoader
4
+ import numpy as np
5
+ import os
6
+
7
+
8
+ class CustomDataset(torch.utils.data.Dataset):
9
+ def __init__(self, sample_dir, embed_dim = 384, train = False, time = 2):
10
+ self.sample_dir = sample_dir
11
+ self.n_segments = len(os.listdir(self.sample_dir))
12
+ self.data = np.zeros((self.n_segments, embed_dim))
13
+
14
+ # Sorted segment based on start_time
15
+ self.sorted_segments = sorted(os.listdir(sample_dir), key= lambda x: float(x.split("_")[-2]))
16
+
17
+ # Assign segments
18
+ for idx, segment_npy in enumerate(self.sorted_segments):
19
+ segment_path = self.sample_dir + "/" + segment_npy
20
+ segment_embed = np.load(segment_path)
21
+ self.data[idx] = segment_embed
22
+
23
+ if train:
24
+ for _ in range(time): # '_' avoids shadowing the 'time' parameter; each pass doubles the data
25
+ self.data = np.concatenate((self.data, self.data), axis = 0)
26
+
27
+ def __len__(self):
28
+ return len(self.data)
29
+
30
+ def __getitem__(self, idx):
31
+ sample = torch.from_numpy(self.data[idx]).float()
32
+ return sample
33
+
34
+
35
+ class AutoEncoderDataset(torch.utils.data.Dataset):
36
+ """
37
+ Create dataset from predefined tensor for each autoencoder in MOE
38
+ """
39
+ def __init__(self, data):
40
+ self.data = data
41
+ def __len__(self):
42
+ return len(self.data)
43
+ def __getitem__(self, idx):
44
+ return self.data[idx]
45
+
46
+
47
+ import argparse
48
+ parser = argparse.ArgumentParser(description='Deep Clustering Network')
49
+
50
+ # Dataset parameters
51
+ parser.add_argument('--dir', default='./datasets/spanish/',
52
+ help='dataset directory')
53
+ parser.add_argument('--input_dim', type=int, default=384,
54
+ help='input dimension')
55
+ parser.add_argument('--n-classes', type=int, default=2,
56
+ help='output dimension')
57
+
58
+ # Training parameters
59
+ parser.add_argument('--lr', type=float, default=1e-3,
60
+ help='learning rate (default: 1e-3)')
61
+ parser.add_argument('--wd', type=float, default=1e-4,
62
+ help='weight decay (default: 1e-4)')
63
+ parser.add_argument('--batch-size', type=int, default=16,
64
+ help='input batch size for training')
65
+ parser.add_argument('--epoch', type=int, default=50,
66
+ help='number of epochs to train')
67
+ parser.add_argument('--pre-epoch', type=int, default=100,
68
+ help='number of pre-train epochs')
69
+ parser.add_argument('--pretrain', type=bool, default=True,
70
+ help='whether use pre-training')
71
+
72
+ # Model parameters
73
+ parser.add_argument('--lamda', type=float, default=1,
74
+ help='coefficient of the reconstruction loss')
75
+ parser.add_argument('--beta', type=float, default=1,
76
+ help=('coefficient of the regularization term on '
77
+ 'clustering'))
78
+ parser.add_argument('--hidden-dims', default=[256, 128, 64, 32, 16],
79
+ help='hidden layer dimensions')
80
+ parser.add_argument('--latent_dim', type=int, default=2,
81
+ help='latent space dimension')
82
+ parser.add_argument('--n-clusters', type=int, default=2,
83
+ help='number of clusters in the latent space')
84
+
85
+ parser.add_argument('--n_1Dconv', type=int, default=4,
86
+ help='n_1dconv')
87
+ parser.add_argument('--kernel_size', default=[7, 5, 3, 3],
88
+ help='kernel_size')
89
+ parser.add_argument('--stride', type = int, default=1,
90
+ help='stride')
91
+ parser.add_argument('--num_blocks', type = int, default=4,
92
+ help='num_blocks')
93
+ parser.add_argument('--channels', type = int, default=[128, 64, 32, 16],
94
+ help='channels')
95
+
96
+ # Utility parameters
97
+ parser.add_argument('--n-jobs', type=int, default=1,
98
+ help='number of jobs to run in parallel')
99
+ parser.add_argument('--log-interval', type=int, default=20,
100
+ help=('how many batches to wait before logging the '
101
+ 'training status'))
102
+ parser.add_argument("--window_length", type = float, default= 0.4, help="window length")
103
+ parser.add_argument("--overlap", type = float, default= 0, help="overlap")
104
+
105
+
106
+ args = parser.parse_args()
107
+
108
+ if __name__ == "__main__":
109
+ # Example usage:
110
+ sample_dir = "datasets/spanish/segments/0096_[cut_193sec].wav"
111
+ dataset = CustomDataset(sample_dir=sample_dir,train= False)
112
+ # dataset = CustomDataset(sample_dir=sample_dir,train= False)
113
+
114
+
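Note: AutoEncoderDataset just wraps an in-memory array, so it can be exercised with stand-in data; a usage sketch (the random embeddings are hypothetical):

import numpy as np
from torch.utils.data import DataLoader

embeddings = np.random.randn(10, 384).astype("float32")
loader = DataLoader(AutoEncoderDataset(embeddings), batch_size=4, shuffle=False)
for batch in loader:
    print(batch.shape)  # torch.Size([4, 384]) for full batches, [2, 384] for the last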
mix_sae.py ADDED
@@ -0,0 +1,672 @@
1
+ import torch.nn as nn
2
+ from collections import OrderedDict
3
+ import torch
4
+ import argparse
5
+ import torch.nn.init as init
6
+ import numpy as np
7
+ from sklearn.cluster import KMeans, SpectralClustering
8
+ from load_dataset import AutoEncoderDataset
9
+ from torch.utils.data import DataLoader
10
+ from load_dataset import *
11
+ import torch.nn.functional as F
12
+ import matplotlib.pyplot as plt
13
+ #----------------------VERSION 3: SPARSE AUTOENCODER KL PENALTY AND ENTROPY LOSS--------------------
14
+ import random
15
+
16
+ random.seed(10)
17
+
18
+ parser = argparse.ArgumentParser(description='Deep Clustering Network')
19
+ parser.add_argument('--input_dim', type=int, default=384,
20
+ help='input dimension')
21
+ # Model parameters
22
+ parser.add_argument('--lr', type=float, default=1e-3,
23
+ help='learning rate (default: 1e-3)')
24
+ parser.add_argument('--wd', type=float, default=1e-4,
25
+ help='weight decay (default: 1e-4)')
26
+ parser.add_argument('--batch-size', type=int, default=16,
27
+ help='input batch size for training')
28
+ parser.add_argument('--lamda', type=float, default=1,
29
+ help='coefficient of the reconstruction loss')
30
+ parser.add_argument('--beta', type=float, default=1,
31
+ help=('coefficient of the regularization term on '
32
+ 'clustering'))
33
+ parser.add_argument('--hidden-dims', default=[256, 128, 64, 32],
34
+ help='hidden layer dimensions')
35
+ parser.add_argument('--latent_dim', type=int, default=2,
36
+ help='latent space dimension')
37
+ parser.add_argument('--n-clusters', type=int, default=2,
38
+ help='number of clusters in the latent space')
39
+ parser.add_argument('--input-dim', type=int, default=384,
40
+ help='input dimension')
41
+ parser.add_argument('--n-classes', type=int, default=2,
42
+ help='output dimension')
43
+ parser.add_argument('--pretrain_epochs', type=int, default=80,
44
+ help='pretraining step epochs')
45
+ parser.add_argument('--pretrain_epochs_main', type=int, default=80,
46
+ help='pretraining step epochs')
47
+ parser.add_argument('--pretrain', type=bool, default=True,
48
+ help='whether use pre-training')
49
+ parser.add_argument('--main_train_epochs', type=int, default=80,
50
+ help='main_train epochs')
51
+ parser.add_argument('--rho', type=float, default=0.2,
52
+ help='sparsity target rho')
53
+ parser.add_argument('--sparsity_param', type=float, default=0.1,
54
+ help='sparsity constraint param')
55
+ parser.add_argument('--cl_loss_param', type=float, default=0.05,
56
+ help='classification loss param')
57
+ args = parser.parse_args()
58
+
59
+
60
+
61
+
62
+ class AutoEncoder(nn.Module):
63
+
64
+ def __init__(self, args):
65
+ super(AutoEncoder, self).__init__()
66
+ self.args = args
67
+ self.input_dim = args.input_dim
68
+ self.output_dim = self.input_dim
69
+ self.hidden_dims = args.hidden_dims
70
+ self.hidden_dims.append(args.latent_dim)
71
+ self.dims_list = (args.hidden_dims +
72
+ args.hidden_dims[:-1][::-1]) # mirrored structure
73
+ self.n_layers = len(self.dims_list)
74
+ self.latent_dim = args.latent_dim
75
+ self.n_clusters = args.n_clusters
76
+ self.RHO = args.rho
77
+
78
+ # Validation check
79
+ assert self.n_layers % 2 > 0
80
+ assert self.dims_list[self.n_layers // 2] == self.latent_dim
81
+
82
+ # Encoder Network
83
+ layers = OrderedDict()
84
+ for idx, hidden_dim in enumerate(self.hidden_dims):
85
+ if idx == 0:
86
+ layers.update(
87
+ {
88
+ 'linear0': nn.Linear(self.input_dim, hidden_dim),
89
+ # 'linear0': CustomDense(self.input_dim, hidden_dim),
90
+ # 'activation0': nn.LeakyReLU()
91
+ # 'activation0': nn.ReLU()
92
+ }
93
+ )
94
+ else:
95
+ layers.update(
96
+ {
97
+ 'linear{}'.format(idx): nn.Linear(
98
+ self.hidden_dims[idx-1], hidden_dim),
99
+ # 'linear{}'.format(idx): CustomDense(self.hidden_dims[idx-1], hidden_dim),
100
+ # 'activation{}'.format(idx): nn.LeakyReLU(),
101
+ # 'activation{}'.format(idx): nn.ELU(),
102
+ 'activation{}'.format(idx): nn.LeakyReLU(),
103
+ # 'dropout{}'.format(idx): nn.Dropout(0.5),
104
+ 'bn{}'.format(idx): nn.BatchNorm1d(
105
+ self.hidden_dims[idx]),
106
+
107
+
108
+ # 'bn{}'.format(idx): nn.BatchNorm1d(
109
+ # self.hidden_dims[idx])
110
+ }
111
+ )
112
+ self.encoder = nn.Sequential(layers)
113
+
114
+ # Decoder Network
115
+ layers = OrderedDict()
116
+ tmp_hidden_dims = self.hidden_dims[::-1]
117
+ for idx, hidden_dim in enumerate(tmp_hidden_dims):
118
+ if idx == len(tmp_hidden_dims) - 1:
119
+ layers.update(
120
+ {
121
+ 'linear{}'.format(idx): nn.Linear(
122
+ hidden_dim, self.output_dim),
123
+ # 'activation{}'.format(idx):nn.ReLU()
124
+ # 'activation{}'.format(idx): nn.LeakyReLU(),
125
+ # 'activation{}'.format(idx): nn.ELU(),
126
+ # 'linear{}'.format(idx): CustomDense(hidden_dim, self.output_dim),
127
+ }
128
+ )
129
+ else:
130
+ layers.update(
131
+ {
132
+ 'linear{}'.format(idx): nn.Linear(
133
+ hidden_dim, tmp_hidden_dims[idx+1]),
134
+ # 'linear{}'.format(idx): CustomDense(
135
+ # hidden_dim, tmp_hidden_dims[idx+1]),
136
+ # 'activation{}'.format(idx): nn.ELU(),
137
+ 'activation{}'.format(idx): nn.LeakyReLU(),
138
+ # 'dropout{}'.format(idx): nn.Dropout(0.5),
139
+ 'bn{}'.format(idx): nn.BatchNorm1d(
140
+ tmp_hidden_dims[idx+1]),
141
+ # 'activation{}'.format(idx): nn.ELU(),
142
+
143
+ # 'bn{}'.format(idx): nn.BatchNorm1d(
144
+ # tmp_hidden_dims[idx+1])
145
+ }
146
+ )
147
+ self.decoder = nn.Sequential(layers)
148
+ # Apply Xavier weight initialization to all linear layers
149
+ for m in self.modules():
150
+ if isinstance(m, nn.Linear):
151
+ init.xavier_normal_(m.weight)
152
+ init.constant_(m.bias, 0) # Initialize biases to 0
153
+ def __repr__(self):
154
+ repr_str = '[Structure]: {}-'.format(self.input_dim)
155
+ for idx, dim in enumerate(self.dims_list):
156
+ repr_str += '{}-'.format(dim)
157
+ repr_str += str(self.output_dim) + '\n'
158
+ repr_str += '[n_layers]: {}'.format(self.n_layers) + '\n'
159
+ repr_str += '[n_clusters]: {}'.format(self.n_clusters) + '\n'
160
+ repr_str += '[input_dims]: {}'.format(self.input_dim)
161
+ return repr_str
162
+
163
+ def __str__(self):
164
+ return self.__repr__()
165
+
166
+ def forward(self, X, latent=False):
167
+ output = self.encoder(X)
168
+ if latent:
169
+ return output
170
+ return self.decoder(output)
171
+
172
+
173
+ class VAE(nn.Module):
174
+ def __init__(self, args):
175
+ super(VAE, self).__init__()
176
+ self.args = args
177
+ self.input_dim = args.input_dim
178
+ self.output_dim = self.input_dim
179
+ self.hidden_dims = args.hidden_dims
180
+ self.latent_dim = args.latent_dim
181
+ self.n_clusters = args.n_clusters
182
+
183
+ # Encoder Network
184
+ layers = OrderedDict()
185
+ for idx, hidden_dim in enumerate(self.hidden_dims):
186
+ if idx == 0:
187
+ layers.update(
188
+ {
189
+ 'linear0': nn.Linear(self.input_dim, hidden_dim),
190
+
191
+ }
192
+ )
193
+ else:
194
+ layers.update(
195
+ {
196
+ 'linear{}'.format(idx): nn.Linear(
197
+ self.hidden_dims[idx-1], hidden_dim),
198
+
199
+ 'activation{}'.format(idx): nn.ReLU(),
200
+
201
+ 'bn{}'.format(idx): nn.BatchNorm1d(
202
+ self.hidden_dims[idx])
203
+ }
204
+ )
205
+ self.encoder = nn.Sequential(layers)
206
+
207
+ # Decoder Network
208
+ layers = OrderedDict()
209
+ tmp_hidden_dims = self.hidden_dims[::-1]
210
+ for idx, hidden_dim in enumerate(tmp_hidden_dims):
211
+ if idx == len(tmp_hidden_dims) - 1:
212
+ layers.update(
213
+ {
214
+ 'linear{}'.format(idx): nn.Linear(
215
+ hidden_dim, self.output_dim),
216
+ }
217
+ )
218
+ else:
219
+ layers.update(
220
+ {
221
+ 'linear{}'.format(idx): nn.Linear(
222
+ hidden_dim, tmp_hidden_dims[idx+1]),
223
+ 'activation{}'.format(idx): nn.ReLU(),
224
+ 'bn{}'.format(idx): nn.BatchNorm1d(
225
+ tmp_hidden_dims[idx+1])
226
+ }
227
+ )
228
+ self.decoder = nn.Sequential(layers)
229
+ self.fc_mu = nn.Linear(self.hidden_dims[-1], self.latent_dim)
230
+ self.fc_var = nn.Linear(self.hidden_dims[-1], self.latent_dim)
231
+
232
+ self.decode_input_linear = nn.Linear(self.latent_dim, self.hidden_dims[-1])
233
+ # Apply Xavier weight initialization to all linear layers
234
+ for m in self.modules():
235
+ if isinstance(m, nn.Linear):
236
+ init.xavier_normal_(m.weight)
237
+ init.constant_(m.bias, 0) # Initialize biases to 0
238
+
239
+ def encode(self, x):
240
+ x = self.encoder(x)
241
+ mu = self.fc_mu(x)
242
+ log_var = self.fc_var(x)
243
+
244
+ return [mu, log_var]
245
+
246
+ def decode(self, x):
247
+ x = self.decode_input_linear(x)
248
+ x = self.decoder(x)
249
+ return x
250
+
251
+ def reparameterize(self, mu, logvar):
252
+ """
253
+ Reparameterization trick to sample from N(mu, var) from
254
+ N(0,1).
255
+ """
256
+ std = torch.exp(0.5 * logvar)
257
+ eps = torch.randn_like(std)
258
+ return eps * std + mu
259
+
260
+
261
+ def loss_function(self, x_hat, x, mu, log_var, kld_weight = 1):
262
+ """
263
+ Computes the VAE loss function.
264
+ KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
265
+ """
266
+
267
+ rec_loss = torch.nn.functional.mse_loss(x_hat, x)
268
+
269
+ kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
270
+
271
+ loss = rec_loss + kld_weight * kld_loss
272
+
273
+ return [loss, rec_loss.detach(), -kld_loss.detach()]
274
+
275
+ def forward(self, x):
276
+ """
277
+ Forward VAE
278
+ Return: [output, input, mu, var]
279
+ """
280
+ # Encoder
281
+ mu, log_var = self.encode(x)
282
+ # Sample
283
+ z = self.reparameterize(mu, log_var)
284
+ # Decoder
285
+ output = self.decode(z)
286
+
287
+ return [output, x, mu, log_var]
288
+
289
+
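Note: the closed-form KL term in VAE.loss_function, -0.5 * sum(1 + log_var - mu^2 - exp(log_var)), agrees with the docstring's log(1/sigma) + (sigma^2 + mu^2)/2 - 1/2; a quick numeric check:

import torch

mu, log_var = torch.tensor([0.5]), torch.tensor([0.2])
kld = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp())
sigma = torch.exp(0.5 * log_var)
direct = torch.log(1 / sigma) + (sigma ** 2 + mu ** 2) / 2 - 0.5
print(kld.item(), direct.item())  # both ~0.136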
290
+ class ClusterNet(nn.Module):
291
+
292
+ def __init__(self, input_dim, hidden_dims = [128], n_clusters=2):
293
+ """ClusterNet("""
294
+ super(ClusterNet, self).__init__()
295
+ layers = []
296
+ for i in range(len(hidden_dims)):
297
+ if i == 0:
298
+ layers.append(nn.Linear(input_dim, hidden_dims[i]))
299
+ layers.append(nn.LeakyReLU())
300
+ # layers.append(nn.Dropout(0.5))
301
+ # layers.append(nn.BatchNorm1d(hidden_dims[i])),
302
+ else:
303
+ layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i])),
304
+ layers.append(nn.LeakyReLU())
305
+ # layers.append(nn.Dropout(0.5))
306
+ # layers.append(nn.BatchNorm1d(hidden_dims[i])),
307
+ # Last layer
308
+ layers.append(nn.Sequential(
309
+ nn.Flatten(),
310
+ nn.Linear(hidden_dims[-1], n_clusters),
311
+ nn.Softmax(dim = 1),
312
+ ))
313
+
314
+ self.layers = nn.Sequential(*layers)
315
+ # Apply Xavier weight initialization to all linear layers
316
+ for m in self.modules():
317
+ if isinstance(m, nn.Linear):
318
+ init.xavier_normal_(m.weight)
319
+ init.constant_(m.bias, 0) # Initialize biases to 0
320
+ def forward(self, x):
321
+ """Extract the feature vectors."""
322
+ features = x
323
+ for layer in self.layers:
324
+ features = layer(features)
325
+ return features
326
+
327
+
328
+ class MoESparseAutoencodersCL(nn.Module):
329
+ """
330
+ Mixture of Expert DNN-Autoencoder
331
+ """
332
+ def __init__(self, args):
333
+ super(MoESparseAutoencodersCL, self).__init__()
334
+ self.args = args
335
+ self.input_dim = args.input_dim
336
+ self.output_dim = self.input_dim
337
+ self.hidden_dims = args.hidden_dims
338
+ self.latent_dim = args.latent_dim
339
+ self.n_clusters = args.n_clusters
340
+ self.pretrain_epochs = args.pretrain_epochs
341
+ self.pretrain_epochs_main = args.pretrain_epochs_main
342
+ self.main_train_epochs = args.main_train_epochs
343
+ self.device = "cpu"
344
+ # Define main autoencoder at pretraining
345
+ self.main_autoencoder = AutoEncoder(args=args)
346
+ self.RHO = args.rho
347
+ self.BETA = args.sparsity_param
348
+ self.psedo_label = None
349
+ # Clustering algorithm for pre-training
350
+ self.cluster_algo = None
351
+ self.cl_loss_param = args.cl_loss_param
352
+
353
+ # Define autoencoder expert in mixture
354
+ self.moe = {}
355
+ for i in range(self.n_clusters):
356
+ self.moe[i] = AutoEncoder(args)
357
+ # Add cluster net (gating network) to moe
358
+ self.moe['cluster_net'] = ClusterNet(input_dim= self.input_dim, n_clusters=self.n_clusters)
359
+
360
+
361
+ def kl_divergence(self, rho, rho_hat):
362
+ rho_hat = torch.mean(torch.sigmoid(rho_hat), 1) # sigmoid maps activations into (0, 1) so they act as probabilities
363
+ rho = torch.tensor([rho] * len(rho_hat)).to(self.device)
364
+ return torch.sum(rho * torch.log(rho/rho_hat) + (1 - rho) * torch.log((1 - rho)/(1 - rho_hat)))
365
+
366
+ # define the sparse loss function
367
+ def sparse_loss(self, rho, X, model):
368
+ values = X
369
+ loss = 0
370
+ model_children = list(model.children())
371
+ for i in range(len(model_children)):
372
+ values = model_children[i](values)
373
+ loss += self.kl_divergence(rho, values)
374
+ return loss / X.shape[0]
375
+
376
+
377
+
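Note: kl_divergence above is the Bernoulli-KL sparsity penalty, pushing each layer's mean activation rho_hat toward the target rho; a standalone sketch:

import torch

def bernoulli_kl(rho: float, rho_hat: torch.Tensor) -> torch.Tensor:
    # KL(Bern(rho) || Bern(rho_hat)) summed over units; rho_hat must lie in (0, 1)
    rho_t = torch.full_like(rho_hat, rho)
    return torch.sum(rho_t * torch.log(rho_t / rho_hat)
                     + (1 - rho_t) * torch.log((1 - rho_t) / (1 - rho_hat)))

rho_hat = torch.sigmoid(torch.randn(32))  # mean activations squashed into (0, 1)
print(bernoulli_kl(0.2, rho_hat))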
378
+ def batchwise_entropy_loss(self, cluster_outputs):
379
+ """
380
+ Calculate batch wise entropy loss
381
+ """
382
+ X = torch.mean(cluster_outputs, axis = 0)
383
+ return torch.special.entr(X).sum()
384
+
385
+
386
+
387
+ def loss_function(self, expert_outputs, cluster_net_outputs, X, psedo_label):
388
+ """
389
+ Compute loss function in a batch
390
+ Loss = L - Beta * Entropy(cluster_net_outputs)
391
+ L = -log [p_i * exp (-(xhat_i - x_i) ** 2)]
392
+ """
393
+
394
+ # Create one-hot pseudo-label
395
+ # print("Expert output" , expert_outputs)
396
+ encoded_arr = np.zeros((len(psedo_label), self.n_clusters), dtype=float)
397
+ for i in range(len(psedo_label)):
398
+ encoded_arr[i][psedo_label[i]] = 1
399
+ # print("Cluster network output", cluster_net_outputs)
400
+ # print("Encoded arr:", encoded_arr)
401
+ # Cross entropy loss
402
+ entropy_criterion = nn.CrossEntropyLoss()
403
+ entropy_loss = entropy_criterion(cluster_net_outputs, torch.tensor(psedo_label, dtype=torch.long))
404
+ # print("Entropy loss", entropy_loss)
405
+
406
+
407
+ # MOE reconstruction loss
408
+ loss = 0
409
+ for i in range(self.n_clusters):
410
+ mse = -((expert_outputs[i] - X)**2).mean(axis=1)
411
+ loss += cluster_net_outputs[:, i] * torch.exp(mse)
412
+
413
+ moe_loss = -torch.log(loss).sum()
414
+ # print('MOE loss', moe_loss)
415
+ return moe_loss - self.cl_loss_param * entropy_loss
416
+ # return moe_loss
417
+
418
+
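Note: a numeric sketch of the mixture objective computed above, L = -sum_n log( sum_i p_i(x_n) * exp(-MSE_i(x_n)) ), using random stand-in tensors:

import torch

batch, n_experts, dim = 4, 2, 8
X = torch.randn(batch, dim)
expert_outputs = [torch.randn(batch, dim) for _ in range(n_experts)]
gates = torch.softmax(torch.randn(batch, n_experts), dim=1)  # plays the cluster net

likelihood = torch.zeros(batch)
for i in range(n_experts):
    mse = ((expert_outputs[i] - X) ** 2).mean(dim=1)
    likelihood = likelihood + gates[:, i] * torch.exp(-mse)
moe_loss = -torch.log(likelihood).sum()
print(moe_loss)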
419
+ def train_one_autoencoder(self, autoencoder, optimizer, criterion, data_loader, number_of_epochs, sparsity, rho, name='main', verbose=False):
420
+ """
421
+ Training one autoencoder
422
+ """
423
+ print('Training %s ...'%(name))
424
+ for epoch in range(number_of_epochs):
425
+
426
+ running_loss = 0.0
427
+ autoencoder.train()
428
+ for batch_index, (data) in enumerate(data_loader):
429
+ batch_size = data.size()[0]
430
+ # Duplicate if batch has one sample (handle one-sample err)
431
+ if batch_size == 1:
432
+ data = torch.cat([data, data], dim=0)
433
+ batch_size = 2
434
+
435
+ data = data.to(self.device).view(batch_size, -1)
436
+ # Get output decoder
437
+ rec_X = autoencoder(data)
438
+
439
+ if sparsity:
440
+ # Get latent
441
+ reg_loss = criterion(data, rec_X)
442
+ sparse_loss = self.sparse_loss(rho=rho, X = data, model=autoencoder)
443
+ loss = reg_loss + self.BETA * sparse_loss
444
+ # if batch_index & 100 == 0:
445
+ # print("Reg-loss: {} , Sparse-loss: {}".format(reg_loss, sparse_loss))
446
+ else:
447
+ loss = criterion(data, rec_X)
448
+ optimizer.zero_grad()
449
+ loss.backward()
450
+ optimizer.step()
451
+ running_loss += loss.data.numpy()
452
+ if batch_index % 200 ==0 and verbose:
453
+ print('epoch %d loss: %.5f batch: %d' % (epoch, running_loss/((batch_index + 1)), (batch_index + 1)*batch_size))
454
+ if batch_index != 0 and batch_index % 1000 == 0:
455
+ break
456
+ print('Done training %s'%(name))
457
+
458
+
459
+ def pretraining(self, dataloader):
460
+ """
461
+ Pretraining step
462
+ 1) Train a single main_autoencoder for the entire dataset
463
+ 2) Apply k-means for the embedding space after training to get label for cluster net
464
+ 3) Training i-th autoencoder using i-th assigned samples by K-means from the entire dataset
465
+ """
466
+ #---------Training main_autoencoder---------------
467
+ criterion = nn.MSELoss()
468
+ optimizer = torch.optim.Adam(self.main_autoencoder.parameters(), lr=args.lr, weight_decay = args.wd)
469
+
470
+ self.train_one_autoencoder(autoencoder=self.main_autoencoder, optimizer=optimizer,
471
+ criterion=criterion, data_loader= dataloader,
472
+ number_of_epochs=self.pretrain_epochs_main, name= "main_autoencoder",
473
+ verbose=True,
474
+ sparsity=False,
475
+ rho= self.RHO
476
+ )
477
+
478
+ # ----------K-means clustering --------------------------
479
+ print("------Clustering---------")
480
+
481
+ # Get latent X
482
+ batch_X = []
483
+ for batch_idx, (data) in enumerate(dataloader):
484
+ batch_size = data.size()[0]
485
+ # Duplicate if batch has one sample
486
+ if batch_size == 1:
487
+ data = torch.cat([data, data], dim=0)
488
+ batch_size = 2
489
+ data = data.to(self.device).view(batch_size, -1)
490
+ latent_X = self.main_autoencoder(data, latent=True)
491
+ print("BATCH LATENT X", latent_X)
492
+ batch_X.append(latent_X.detach().cpu().numpy())
493
+ full_latent_X = np.vstack(batch_X)
494
+
495
+ # Clustering
496
+ # self.cluster_algo = AgglomerativeClustering(n_clusters=self.n_clusters).fit(full_latent_X)
497
+ # print("Cluster algo", self.cluster_algo)
498
+ # self.cluster_algo.fit(full_latent_X)
499
+
500
+ # self.cluster_algo = KMeans(n_clusters=self.n_clusters, n_init= self.n_clusters, init="k-means++", random_state=42).fit(full_latent_X)
501
+ self.cluster_algo = SpectralClustering(n_clusters=self.n_clusters, random_state=42).fit(full_latent_X)
502
+ self.psedo_label = self.cluster_algo.labels_
503
+ print("Done clustering!")
504
+ print("Original label:", self.psedo_label)
505
+ # tsne = TSNE(n_components=2, random_state=42)
506
+ # X_tsne = tsne.fit_transform(full_latent_X)
507
+ # colors = ['black', 'red']
508
+ # for i in np.unique(self.cluster_algo.labels_):
509
+ # plt.scatter(X_tsne[self.cluster_algo.labels_ == i, 0], X_tsne[self.cluster_algo.labels_ == i, 1], color=colors[i], label=str(i))
510
+ # plt.xlabel('t-SNE feature 1')
511
+ # plt.ylabel('t-SNE feature 2')
512
+ # plt.legend()
513
+ # plt.show()
514
+
515
+ # ---------Training each autoencoder expert with predefined label from K-means---------------
516
+ for i in range(self.n_clusters):
517
+ # Get full dataset through batch loop
518
+ dataset = []
519
+ for batch_idx, (data) in enumerate(dataloader):
520
+ batch_size = data.size()[0]
521
+ # # Duplicate if batch has one sample
522
+ if batch_size == 1:
523
+ data = torch.cat([data, data], dim=0)
524
+ batch_size = 2
525
+ dataset.append(data.detach().cpu().numpy())
526
+ dataset = np.vstack(dataset)
527
+
528
+ # Extract data for specific expert i
529
+ data_expert_i = dataset[self.cluster_algo.labels_ == i]
530
+ data_expert_i = AutoEncoderDataset(data = data_expert_i)
531
+ dataset_expert_i = DataLoader(data_expert_i, batch_size = args.batch_size, shuffle = False)
532
+ optimizer = torch.optim.Adam(self.moe[i].parameters(), lr=args.lr, weight_decay = args.wd)
533
+ criterion = nn.MSELoss()
534
+ # Train expert_i
535
+ self.train_one_autoencoder(autoencoder=self.moe[i], optimizer=optimizer,
536
+ criterion=criterion, data_loader=dataset_expert_i,
537
+ number_of_epochs=self.pretrain_epochs, name="Expert {}".format(i),
538
+ verbose=True, sparsity=True, rho = self.RHO)
539
+
540
+ print("Done Pretraining step !")
541
+
542
+ return self.moe, full_latent_X
543
+
544
+
545
+ def get_expert_outputs(self, X, latent = False):
546
+ """
547
+ Get output of experts in a batch
548
+ Return: List of output of each expert
549
+ """
550
+ output = []
551
+ for i in range(self.n_clusters):
552
+ if latent:
553
+ output_expert_i = self.moe[i](X, latent = True)
554
+ else:
555
+ output_expert_i = self.moe[i](X)
556
+ output.append(output_expert_i)
557
+ return output
558
+
559
+
560
+ def main_training(self, dataloader, name = "MOE", verbose = True):
561
+ """
562
+ Main training to optimize loss function L = -log [p_i * exp (-(xhat_i - x_i) ** 2)]
563
+ """
564
+ print('Training %s ...'%(name))
565
+
566
+ # Add parameters
567
+ params = list(self.moe['cluster_net'].parameters())
568
+ for i in range(self.n_clusters):
569
+ params += list(self.moe[i].parameters())
570
+ self.moe[i].train()
571
+
572
+ optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay = args.wd)
573
+
574
+ for epoch in range(self.main_train_epochs):
575
+ running_loss = 0.0
576
+ self.moe['cluster_net'].train()
577
+
578
+ for batch_index, (data) in enumerate(dataloader):
579
+ batch_size = data.size()[0]
580
+ # Duplicate if batch has one sample (the last batch)
581
+ if batch_size == 1:
582
+ data = torch.cat([data, data], dim=0)
583
+ batch_size = 2
584
+ # Get psedo-label
585
+ psedo_label = self.psedo_label[batch_index: batch_index + batch_size]
586
+
587
+ # Get decoder output
588
+ expert_outputs = self.get_expert_outputs(data)
589
+ # Get latent output
590
+ # latent_outputs = self.get_expert_outputs(data, latent=True)
591
+
592
+ # # Concate k-latent outputs
593
+ # latent_tensor = latent_outputs[0]
594
+ # for i in range(1, len(latent_outputs)):
595
+ # latent_tensor = torch.hstack((latent_tensor, latent_outputs[i]))
596
+ # # if batch_index % 100 == 0:
597
+ # # # print("Latent tensor", latent_tensor)
598
+
599
+ clustering_net_outputs = self.moe['cluster_net'](data)
600
+
601
+ # print("Cluster net output", clustering_net_outputs)
602
+ loss = self.loss_function(expert_outputs=expert_outputs,
603
+ cluster_net_outputs=clustering_net_outputs, X = data,
604
+ psedo_label=psedo_label)
605
+
606
+ optimizer.zero_grad()
607
+ loss.backward()
608
+ optimizer.step()
609
+ running_loss += loss.data.numpy()
610
+ if batch_index % 100 ==0 and verbose:
611
+ print('epoch %d loss: %.5f batch: %d' % (epoch, running_loss/((batch_index + 1)), (batch_index + 1)*batch_size))
612
+ if batch_index != 0 and batch_index % 1000 == 0:
613
+ break
614
+
615
+ # Update pseudo-label
616
+ if epoch != 0 and epoch % 10 == 0:
617
+ self.psedo_label = self.get_final_cluster(dataloader)
618
+ print("Updated psedo label!")
619
+ print("######################################")
620
+ print("New psedo label: ", self.psedo_label)
621
+ # self.cl_loss_param = self.cl_loss_param * 2
622
+
623
+ print("Done main training!")
624
+ return self.moe
625
+
626
+
627
+ def get_final_cluster(self, test_loader):
628
+ """
629
+ Assign final cluster for clustering based on cluster_net
630
+ """
631
+ # Convert to eval mode
632
+ for i in range(self.n_clusters):
633
+ self.moe[i].eval()
634
+ self.moe['cluster_net'].eval()
635
+ total_pred = []
636
+ for batch_idx, (data) in enumerate(test_loader):
637
+ batch_size = data.size()[0]
638
+ data = data.view(batch_size, -1).to(self.device)
639
+ # Get the hard assignment label
640
+ with torch.no_grad():
641
+ # # Get latent output
642
+ # latent_outputs = self.get_expert_outputs(data, latent=True)
643
+
644
+ # # Concate k-latent outputs
645
+ # latent_tensor = latent_outputs[0]
646
+ # for i in range(1, len(latent_outputs)):
647
+ # latent_tensor = torch.hstack((latent_tensor, latent_outputs[i]))
648
+
649
+ cluster_pred = self.moe['cluster_net'](data)
650
+ cluster_pred = cluster_pred.cpu().numpy()
651
+ batch_pred = np.argmax(cluster_pred, axis = 1)
652
+ total_pred.append(batch_pred)
653
+ total_pred = np.concatenate(total_pred, axis=0)
654
+ return total_pred
655
+
656
+
657
+ if __name__ == "__main__":
658
+ # sample_dir = "da_datasets/da_spanish/segments_0.2/1569_[cut_127sec].wav"
659
+ # dataset = CustomDataset(sample_dir=sample_dir,train= False)
660
+ # train_loader = DataLoader(dataset, batch_size = args.batch_size, shuffle = False)
661
+ moe = MoESparseAutoencodersCL(args)
662
+ # mixture = moe.pretraining(train_loader)
663
+ # mixture = moe.main_training(train_loader)
664
+ # pred = moe.get_final_cluster(train_loader)
665
+ # print(pred)
666
+ total_params = sum(p.numel() for p in moe.parameters() if p.requires_grad)
667
+ print(total_params)
668
+
669
+
670
+
671
+
672
+
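Note: the commented-out lines in __main__ sketch the intended pipeline; an uncommented version, assuming embeddings shaped (n_segments, args.input_dim) (np, DataLoader, and AutoEncoderDataset are already imported at the top of this file):

embeddings = np.random.randn(40, args.input_dim).astype("f")  # stand-in data
loader = DataLoader(AutoEncoderDataset(embeddings), batch_size=args.batch_size, shuffle=False)
moe = MoESparseAutoencodersCL(args)
moe.pretraining(loader)               # main AE -> spectral clustering -> per-expert AEs
moe.main_training(loader)             # joint expert + gating optimization
pred = moe.get_final_cluster(loader)  # hard cluster id per segment

shuffle=False matters here: main_training indexes self.psedo_label by batch position, so the loader must preserve segment order.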
segment_process.py ADDED
@@ -0,0 +1,156 @@
1
+ import os
2
+ import subprocess
3
+ import numpy as np
4
+ import whisper
5
+ import torch
6
+ import argparse
7
+
8
+
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument("--data_dir", type=str, default= "./samples", help="data wav dir")
11
+ parser.add_argument("--segment_dir", type=str, default= "./segments_0.4", help="segment dir")
12
+ parser.add_argument("--model_type", type = str, default="tiny", help="model type")
13
+ parser.add_argument("--run_device", type = str, default="cpu", help="run device")
14
+ parser.add_argument("--window_length", type = float, default= 0.4, help="window length")
15
+ parser.add_argument("--overlap", type = float, default= 0, help="overlap")
16
+
17
+
18
+ # Define
19
+ opt = parser.parse_args()
20
+ DATA_DIR = opt.data_dir
21
+ SEGMENT_DIR = opt.segment_dir
22
+ model_type = opt.model_type
23
+ run_device = opt.run_device
24
+ window_length = opt.window_length
25
+ overlap = opt.overlap
26
+
27
+
28
+ if not os.path.exists(SEGMENT_DIR):
29
+ os.makedirs(SEGMENT_DIR)
30
+
31
+ # Load model
32
+ whisper_model = whisper.load_model(model_type, run_device)
33
+ embedding_dims = {"tiny": 384, 'small': 768, 'base': 512, 'medium':1024} # Whisper encoder widths; 'small' is 768, matching app_test.py
34
+
35
+
36
+
37
+
38
+ def extract_segment(input_file, output_file, start_time, end_time):
39
+ """
40
+ Extract one segment given start_time and end_time
41
+ input_file: input .wav file
42
+ output_file: extracted .wav segment
43
+ start_time, end_time: start-end time of the segment
44
+ """
45
+ # split_file_name = f'./{input_file}_segment_{start_time}_{end_time}.wav'
46
+ cmd= 'ffmpeg -i '+input_file+' -acodec copy -ss '+str(start_time)+' -to '+str(end_time)+' '+ output_file
47
+ os.system(cmd)
48
+
49
+
50
+
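Note: building the ffmpeg command by string concatenation breaks on file names with spaces; a quoting-safe sketch of the same call via subprocess (function name hypothetical):

import subprocess

def extract_segment_safe(input_file, output_file, start_time, end_time):
    # Same ffmpeg invocation as extract_segment, but arguments are passed
    # as a list, so no shell parsing is involved.
    subprocess.run(['ffmpeg', '-i', input_file, '-acodec', 'copy',
                    '-ss', str(start_time), '-to', str(end_time), output_file],
                   check=True)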
51
+ def split_audio_with_ffmpeg(input_file, output_dir, segment_length=window_length, overlap=overlap):
52
+ """
53
+ Extract all segments from original audio
54
+ """
55
+ input_filename = input_file.split("/")[-1]
56
+ if not os.path.exists(output_dir):
57
+ os.makedirs(output_dir)
58
+
59
+ # duration = float(subprocess.check_output(['ffprobe', '-i', input_file, '-show_entries', 'format=duration', '-v', 'quiet', '-of', 'csv=%s' % ("p=0")]))
60
+ duration = float(subprocess.check_output(['ffprobe', '-i', input_file, '-show_entries', 'format=duration', '-v', 'quiet', '-of', 'csv=%s' % ("p=0")]))
61
+
62
+ start_time = 0
63
+ last_flag = False
64
+ while start_time < duration and last_flag == False:
65
+ end_time = np.round(min(start_time + segment_length, duration), 2)
66
+ # Cover the last segment
67
+ if end_time + segment_length > duration:
68
+ end_time = duration
69
+ last_flag = True
70
+ output_file = os.path.join(output_dir, f"{input_filename}_segment_{start_time}_{end_time}.wav")
71
+ extract_segment(input_file, output_file, start_time, end_time)
72
+ start_time += segment_length - overlap
73
+
74
+
75
+
76
+ def extract_segment_embedding(segment_dir, save_segment_dir, window_length):
77
+ """
78
+ Extract embedding for each segment
79
+ """
80
+
81
+ audio = whisper.load_audio(segment_dir)
82
+ print("AUDIO SHAPE:", audio.shape)
83
+
84
+ # #Duplicate the array to get 30s chunk
85
+ # audio = np.tile(audio, int(30/window_length))
86
+
87
+ # print("AUDIO SHAPE:", audio.shape)
88
+ mel = whisper.log_mel_spectrogram(audio).to(whisper_model.device)
89
+
90
+
91
+ # print("MEL SHAPE", mel.shape)
92
+ #--- this code to create the correct shape of mel spectrogram
93
+ while True:
94
+ nF, nT = np.shape(mel)
95
+ # print(nF, nT)
96
+ if nT > 3000:
97
+ mel = mel[:,0:3000]
98
+ break
99
+ else:
100
+ mel = torch.cat((mel, mel), -1)
101
+ mel = torch.unsqueeze(mel, 0)
102
+
103
+ wp_emb = whisper_model.embed_audio(mel)
104
+ print("Wb_emb shape:", wp_emb.shape)
105
+ # print("WB embedding:", wp_emb)
106
+
107
+ emb_1d = np.mean(wp_emb.cpu().detach().numpy(), axis=0)
108
+
109
+ emb_1d = np.mean(emb_1d, axis=0)
110
+
111
+ emb_1d = np.expand_dims(emb_1d, axis = 0)
112
+ print("Speaker embedding shape", emb_1d.shape)
113
+
114
+ np.save(save_segment_dir + '/{}.npy'.format(segment_dir.split("/")[-1]), emb_1d, allow_pickle=True)
115
+
116
+ return emb_1d
117
+
118
+
119
+
120
+ def delete_segment_after_done(segments_dir):
121
+ """
122
+ Delete segment after extracting embedding
123
+ """
124
+ for segment in os.listdir(segments_dir):
125
+ if segment.endswith('.wav'):
126
+ segment_path = segments_dir + "/" + segment
127
+ cmd = 'rm '+ segment_path
128
+ os.system(cmd)
129
+
130
+
131
+
132
+
133
+ if __name__ == "__main__":
134
+ data_list = sorted(os.listdir(DATA_DIR))
135
+
136
+ for sample in data_list:
137
+ sample_path = DATA_DIR + "/" + sample
138
+ segment_save_dir = SEGMENT_DIR + "/" + sample
139
+ if not os.path.exists(segment_save_dir):
140
+ os.makedirs(segment_save_dir)
141
+ # # Extract segments
142
+ split_audio_with_ffmpeg(input_file=sample_path, output_dir=segment_save_dir)
143
+
144
+ # Extract embedding
145
+ segment_list = sorted(os.listdir(segment_save_dir), key= lambda x: x.split("_")[-2])
146
+ for segment in segment_list:
147
+ segment_path = segment_save_dir + "/" + segment
148
+ # Extract embedding each segment
149
+ embed_1d = extract_segment_embedding(segment_dir=segment_path, save_segment_dir= segment_save_dir,window_length=window_length)
150
+
151
+ # Delete segment wav files after embeddings are extracted
152
+ delete_segment_after_done(segments_dir=segment_save_dir)
153
+
154
+
155
+
156
+
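Note: split_audio_with_ffmpeg walks a sliding window over the file and folds the short tail into the final cut; a dry run of the schedule for a hypothetical 1.5 s file, 0.4 s windows, no overlap:

import numpy as np

duration, segment_length, overlap = 1.5, 0.4, 0.0
start_time, last_flag = 0.0, False
while start_time < duration and not last_flag:
    end_time = np.round(min(start_time + segment_length, duration), 2)
    if end_time + segment_length > duration:  # fold the tail into the last segment
        end_time, last_flag = duration, True
    print(f"[{start_time:.2f}, {end_time:.2f})")  # [0.00, 0.40) [0.40, 0.80) [0.80, 1.50)
    start_time += segment_length - overlap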
train_mix_sae.py ADDED
@@ -0,0 +1,343 @@
1
+ import torch
2
+ from torchvision import datasets, transforms
3
+ from torch.utils.data import DataLoader
4
+ import numpy as np
5
+ import argparse
6
+ from load_dataset import CustomDataset
7
+ from create_DER import createDER
8
+ from sklearn.cluster import AgglomerativeClustering, KMeans, SpectralClustering, DBSCAN, AffinityPropagation
9
+ from pyannote.metrics.diarization import DiarizationErrorRate
10
+ from pyannote.core import Segment, Timeline, Annotation
11
+ from sklearn.manifold import TSNE
12
+ from mix_sae import MoESparseAutoencodersCL
13
+ import os
14
+ import pandas as pd
15
+
16
+
17
+ ######### NOTE: WRITE FILE FOR EACH METHOD ########################
18
+ parser = argparse.ArgumentParser(description='Deep Clustering Network')
19
+
20
+ # Dataset parameters
21
+ parser.add_argument('--dir', default='./a_dataset/model_size_english/',
22
+ help='dataset directory')
23
+ parser.add_argument('--input_dim', type=int, default=1280,
24
+ help='input dimension')
25
+ parser.add_argument('--n-classes', type=int, default=2,
26
+ help='output dimension')
27
+
28
+ # Training parameters
29
+ parser.add_argument('--lr', type=float, default=1e-3,
30
+ help='learning rate (default: 1e-3)')
31
+ parser.add_argument('--wd', type=float, default=1e-4,
32
+ help='weight decay (default: 1e-4)')
33
+ parser.add_argument('--batch-size', type=int, default=16,
34
+ help='input batch size for training')
35
+ parser.add_argument('--batch_size_moe', type=int, default=16,
36
+ help='input batch size for training')
37
+
38
+ parser.add_argument('--epoch', type=int, default=50,
39
+ help='number of epochs to train')
40
+ parser.add_argument('--pre-epoch', type=int, default=200,
41
+ help='number of pre-train epochs')
42
+ # parser.add_argument('--pretrain_epochs', type=int, default=80,
43
+ # help='pretraining step epochs')
44
+ # parser.add_argument('--pretrain', type=bool, default=True,
45
+ # help='whether use pre-training')
46
+ # parser.add_argument('--main_train_epochs', type=int, default=150,
47
+ # help='main_train epochs')
48
+ # Model parameters
49
+ parser.add_argument('--lamda', type=float, default=1,
50
+ help='coefficient of the reconstruction loss')
51
+ parser.add_argument('--beta', type=float, default=0.001,
52
+ help=('coefficient of the regularization term on '
53
+ 'clustering'))
54
+ parser.add_argument('--hidden-dims', default=[256, 64],
55
+ help='hidden layer dimensions')
56
+ parser.add_argument('--latent_dim', type=int, default=2,
57
+ help='latent space dimension')
58
+ parser.add_argument('--n-clusters', type=int, default=2,
59
+ help='number of clusters in the latent space')
60
+
61
+
62
+ # Utility parameters
63
+ parser.add_argument('--n-jobs', type=int, default=1,
64
+ help='number of jobs to run in parallel')
65
+ parser.add_argument('--log-interval', type=int, default=20,
66
+ help=('how many batches to wait before logging the '
67
+ 'training status'))
68
+ parser.add_argument("--window_length", type = float, default= 0.2, help="window length")
69
+ parser.add_argument("--overlap", type = float, default= 0, help="overlap")
70
+ parser.add_argument('--rho', type=float, default=0.2,
71
+ help='sparsity target rho')
72
+ parser.add_argument('--pretrain_epochs', type=int, default=10,
73
+ help='pretraining step epochs')
74
+ parser.add_argument('--pretrain_epochs_main', type=int, default= 20,
75
+ help='pretraining step epochs')
76
+ parser.add_argument('--pretrain', type=bool, default=True,
77
+ help='whether use pre-training')
78
+ parser.add_argument('--main_train_epochs', type=int, default =5,
79
+ help='main_train epochs')
80
+ parser.add_argument('--sparsity_param', type=float, default=0.01,
81
+ help='sparsity constract param')
82
+ parser.add_argument('--cl_loss_param', type=float, default= 1,
83
+ help='clasification loss param')
84
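+ # NOTE (a reading of the defaults; the exact objective lives in mix_sae):
+ # --rho and --sparsity_param steer the sparsity penalty on the expert
+ # autoencoders' hidden activations, while --cl_loss_param weights the
+ # pseudo-label classification loss used in the main MoE training stage.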
+
+ args = parser.parse_args()
+ label_dir = args.dir + "/label"
+ segment_dir = args.dir + "/large_segments_{}".format(args.window_length)
+ pyannote_label_dir = args.dir + "/label"
+ window_length = args.window_length
+ overlap = args.overlap
+
+ sample_list = sorted(os.listdir(segment_dir))
+ label_list = sorted(os.listdir(label_dir))
+ pyannote_label_list = sorted(os.listdir(pyannote_label_dir))
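+ # Sorting keeps the per-sample segment folders and their reference label files
+ # aligned by filename, so the three lists can be zipped below.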
+
+
+ # Create dataframe to store results
+ columns = ["Language", "Filename", "K-means_DER", "K-medoids_DER", "ONLY_PRE_DER", "MOE_CL_DER"]
+ df = pd.DataFrame(columns=columns)
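+ # Only the ONLY_PRE_DER and MOE_CL_DER columns are filled by this run; the
+ # K-means / K-medoids columns are placeholders for the commented-out baselines.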
+
+ for sample, label, py_label in zip(sample_list, label_list, pyannote_label_list):
+     print("Processing segments in folder {}".format(sample))
+     print("Label: ", label)
+
+     sample_path = segment_dir + "/" + sample
+     label_path = label_dir + '/' + label
+     pyannote_label_path = pyannote_label_dir + "/" + py_label
+     # ## -------------------------- BASELINE ML ----------------------------------
+     # agglo = AgglomerativeClustering(n_clusters=args.n_clusters)
+     # kmeans = KMeans(n_clusters=2)
+
+     # # Dataset for ML baseline
+     # ml_dataset = CustomDataset(sample_dir=sample_path, embed_dim=args.input_dim)
+     # data = ml_dataset.data
+     # # reduce_embeddings = dimension_reduce(embeddings=data, reduced_dims=2, reduce_method="pca")
+     # reduce_embeddings = data
+
+     # agglo_res = agglo.fit(reduce_embeddings)
+     # kmeans_res = kmeans.fit(reduce_embeddings)
+
+     # agglo_ref, agglo_hyp = createDER(label_path=label_path, sample_dir=sample_path, prediction=agglo_res.labels_, window_length=window_length, overlap=overlap)
+     # kmean_ref, kmean_hyp = createDER(label_path=label_path, sample_dir=sample_path, prediction=kmeans_res.labels_, window_length=window_length, overlap=overlap)
+
+     # # print("KMEANS HYP", kmean_hyp)
+     # # print("KMEDOIDS HYP", k_medoids_hyp)
+     # # print("AGGLO HYP", agglo_hyp)
+
+     # print("K-means label:", kmeans_res.labels_)
+     # print("Agglo label:", agglo_res.labels_)
+
+     der = DiarizationErrorRate(collar=0.25, skip_overlap=False)
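+     # collar=0.25 forgives a 0.25 s window around each reference segment
+     # boundary; skip_overlap=False also scores regions of overlapped speech.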
+     # agglo_error = der(agglo_ref, agglo_hyp)
+     # kmeans_error = der(kmean_ref, kmean_hyp)
+
+     # print("Agglo DER: ", agglo_error)
+     # print("K-means DER: ", kmeans_error)
+
+     # Check t-SNE
+     # plt.style.use('grayscale')
+     # check_list = {"k-Means": kmeans_res, "k-Medoids": k_medoids_res, "Agglomerative": agglo_res}
+     # for algo in check_list:
+     #     y_pred = check_list[algo].labels_
+     #     tsne = TSNE(n_components=2, random_state=42)
+     #     X_tsne = tsne.fit_transform(reduce_embeddings)
+     #     colors = ['black', 'aqua']
+     #     for i in np.unique(y_pred):
+     #         plt.scatter(X_tsne[y_pred == i, 0], X_tsne[y_pred == i, 1], color=colors[i], label=str(i))
+     #     plt.xlabel('t-SNE feature 1')
+     #     plt.ylabel('t-SNE feature 2')
+     #     plt.title('t-SNE visualization with cluster labels for {}'.format(algo))
+     #     plt.legend()
+     #     plt.show()
+
+     # # --------------------------- DEEP CLUSTERING ------------------------------------
+     train_dataset = CustomDataset(sample_dir=sample_path, embed_dim=args.input_dim)
+     # train_dataset = CustomDataset(sample_dir=sample_path)
+
+     train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False)
+
+     test_dataset = CustomDataset(sample_dir=sample_path, embed_dim=args.input_dim)
+     test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
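+     # shuffle=False keeps the segments in their on-disk (temporal) order so the
+     # predicted labels line up with the windows createDER reconstructs; the
+     # train and test loaders deliberately read the same unlabeled segments,
+     # since the clustering here is unsupervised.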
+
+     # # Pretrain
+     # rec_loss_list = model.pretrain(train_loader, args.pre_epoch)
+
+     # for e in range(args.epoch):
+     #     model.train()
+     #     model.fit(e, train_loader)
+
+     # y_pred = []
+     # latent_X_list = []
+     # model.eval()  # Set the model to evaluation mode
+
+     # for data in test_loader:
+     #     batch_size = data.size()[0]
+     #     data = data.view(batch_size, -1).to(model.device)
+     #     latent_X = model.autoencoder(data, latent=True)
+     #     print('Eval latent x', latent_X)
+     #     latent_X = latent_X.detach().cpu().numpy()
+     #     y_pred.append(model.kmeans.update_assign(latent_X).reshape(-1, 1))
+     #     latent_X_list.append(latent_X)
+
+     # y_pred = np.concatenate(y_pred, axis=0)
+     # y_pred = list(np.squeeze(y_pred))
+     # print("Y_pred", y_pred)
+     # latent_X_list = np.concatenate(latent_X_list, axis=0)
+     # algomerative = KMeans(n_clusters=args.n_clusters)
+     # cluster = algomerative.fit(latent_X_list)
+     # y_pred = cluster.labels_
+     # print("Y_pred", y_pred)
+     # print(latent_X_list)
+     # print(y_pred)
+     # # Perform t-SNE
+     # tsne = TSNE(n_components=2, random_state=42)
+     # X_tsne = tsne.fit_transform(latent_X_list)
+     # colors = ['black', 'aqua']
+     # for i in np.unique(y_pred):
+     #     plt.scatter(X_tsne[y_pred == i, 0], X_tsne[y_pred == i, 1], color=colors[i], label=str(i))
+     # plt.xlabel('t-SNE feature 1')
+     # plt.ylabel('t-SNE feature 2')
+     # plt.title('t-SNE visualization with cluster labels for DCN-SD')
+     # plt.legend()
+     # plt.show()
+
+     # # ------------------- SPECTRAL NET -----------------------------
+     # data = torch.tensor(train_dataset.data, dtype=torch.float)
+     # test_data = torch.tensor(test_dataset.data, dtype=torch.float)
+     # spectralnet = SpectralNet(n_clusters=2,
+     #                           should_use_ae=True,
+     #                           should_use_siamese=False,
+     #                           ae_hiddens=[256, 256, 512, 2],
+     #                           ae_epochs=150,
+     #                           ae_batch_size=128,
+     #                           ae_patience=30,
+     #                           # siamese_hiddens=[384, 384, 128, 2],
+     #                           # siamese_epochs=150,
+     #                           # siamese_batch_size=128,
+     #                           # siamese_patience=30,
+     #                           spectral_hiddens=[384, 384, 512, 2],
+     #                           spectral_epochs=300,
+     #                           spectral_batch_size=128,
+     #                           spectral_patience=60)
+     # spectralnet.fit(data)  # X is the dataset and it should be a torch.Tensor
+     # cluster_assignments = spectralnet.predict(data)  # Get the final assignments to clusters
+     # print("Spectral pred", cluster_assignments)
+
+
+     # # ------------------ MOE ------------------------------------
+     # moe = MoEAutoencoders(args=args)
+     # mixture_moe = moe.pretraining(train_loader)
+     # mixture_moe = moe.main_training(train_loader)
+     # moe_pred = moe.get_final_cluster(train_loader)
+     # print("MOE pred", moe_pred)
+
+
+     # # ------------------ MOE SPARSITY ------------------------------------
+     # moe_spa = MoESparseAutoencoders(args=args)
+     # mixture_moe_spa = moe_spa.pretraining(train_loader)
+     # mixture_moe_spa = moe_spa.main_training(train_loader)
+     # moe_spa_pred = moe_spa.get_final_cluster(train_loader)
+     # print("MOE SPARSITY pred", moe_spa_pred)
+     # plt.style.use("seaborn-v0_8-deep")
+     # ------------------ MOE SPARSITY CL ------------------------------------
+     moe_cl = MoESparseAutoencodersCL(args=args)
+     mixture_moe_cl, full_latent_X = moe_cl.pretraining(train_loader)
+     pre_label = moe_cl.psedo_label
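+     # "psedo_label" (attribute name as spelled in mix_sae) holds the cluster
+     # assignment obtained from pretraining alone; it is scored separately
+     # below so the pretraining-only DER can be compared with the full model.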
+     # latent_X_list = np.concatenate(full_latent_X, axis=0)
+     # Perform t-SNE
+     # tsne = TSNE(n_components=2, random_state=42)
+     # X_tsne = tsne.fit_transform(full_latent_X)
+     # colors = ['red', 'green']
+     # for i in np.unique(pre_label):
+     #     plt.scatter(X_tsne[pre_label == i, 0], X_tsne[pre_label == i, 1], color=colors[i], label="Spk {}".format(str(i)))
+     # plt.legend()
+     # # plt.axis("off")
+     # plt.xticks([])
+     # plt.yticks([])
+
+     # plt.show()
+     mixture_moe_cl = moe_cl.main_training(train_loader)
+     moe_cl_pred = moe_cl.get_final_cluster(train_loader)
+
+     # spt_ref, spt_hyp = createDER(label_path=label_path, sample_dir=sample_path, prediction=cluster_assignments, window_length=window_length, overlap=overlap)
+     # spt_error = der(spt_ref, spt_hyp)
+
+     # print("Spectral Net DER: ", spt_error)
+
+     # moe_ref, moe_hyp = createDER(label_path=label_path, sample_dir=sample_path, prediction=moe_pred, window_length=window_length, overlap=overlap)
+     # moe_error = der(moe_ref, moe_hyp)
+     # print("MOE DER", moe_error)
+
+     # moe_spa_ref, moe_spa_hyp = createDER(label_path=label_path, sample_dir=sample_path, prediction=moe_spa_pred, window_length=window_length, overlap=overlap)
+     # moe_spa_error = der(moe_spa_ref, moe_spa_hyp)
+     # print("MOE SPA DER", moe_spa_error)
+
+     moe_cl_pre_ref, moe_cl_pre_hyp = createDER(label_path=label_path, sample_dir=sample_path, prediction=pre_label, window_length=window_length, overlap=overlap)
+     moe_cl_pre_error = der(moe_cl_pre_ref, moe_cl_pre_hyp)
+     print("MOE CL PRE DER", moe_cl_pre_error)
+
+     moe_cl_ref, moe_cl_hyp = createDER(label_path=label_path, sample_dir=sample_path, prediction=moe_cl_pred, window_length=window_length, overlap=overlap)
+     moe_cl_error = der(moe_cl_ref, moe_cl_hyp)
+     print("MOE CL pred", moe_cl_pred)
+     print("MOE CL DER", moe_cl_error)
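+     # For reference, a DER evaluation boils down to comparing two pyannote
+     # Annotations (illustrative values, not tied to this dataset):
+     #     ref = Annotation(); ref[Segment(0.0, 2.0)] = "spk0"; ref[Segment(2.0, 4.0)] = "spk1"
+     #     hyp = Annotation(); hyp[Segment(0.0, 2.2)] = "A";    hyp[Segment(2.2, 4.0)] = "B"
+     #     der(ref, hyp)  # fraction of mis-attributed speech time, after an optimal speaker mapping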
+
+     # # Calculate PYANNOTE DER
+     # ref, py_ref = create_pyannote_timeline(label_path=label_path, pyannote_label_path=pyannote_label_path)
+     # py_error = der(ref, py_ref)
+     # print("Pyannote DER: ", py_error)
+     # # Write result for each algo
+     # with open(save_dir + "/{}.txt".format(file_name[:-4]), 'a') as file:
+     #     file.write("Algo: {}\n".format(algo))
+     #     file.write("Label \n")
+     #     file.write(" ".join(label_series) + "\n")
+     #     file.write("Pred \n")
+     #     file.write(" ".join(pred_series) + "\n")
+     #     file.write("DER: {}\n".format(error))
+     #     file.write("\n")
+
+     # Update result to csv
+     language = args.dir.split("/")[-2]
+     file_name = sample
+     new_row = {"Language": language, "Filename": file_name,
+                "ONLY_PRE_DER": moe_cl_pre_error, "MOE_CL_DER": moe_cl_error}
+
+     df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)
+     # Re-write the CSV on every iteration so partial results survive a crash
+     df.to_csv("1111_co_large_MOE_CL_{}_{}.csv".format(language, window_length), index=False)
+     # Write result for each algo
+     with open("./1111_co_large_report_MOE_CL_{}_{}.txt".format(language, window_length), 'a') as file:
+         file.write("Filename: {}\n".format(file_name))
+         file.write("Pred \n")
+         for digit in moe_cl_pred:
+             file.write("{} ".format(digit))
+         file.write("\n")
+         file.write("DER: {}\n".format(moe_cl_error))
+         file.write("\n")
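+
+ # Example invocation (the script filename here is assumed; adjust --dir to the
+ # actual dataset layout with label/ and large_segments_<window>/ subfolders):
+ #     python evaluate_moe_cl.py --dir ./a_dataset/model_size_english/ --window_length 0.2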
whisper/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (9.65 kB). View file
 
whisper/__pycache__/audio.cpython-311.pyc ADDED
Binary file (7.55 kB). View file
 
whisper/__pycache__/decoding.cpython-311.pyc ADDED
Binary file (45.9 kB). View file
 
whisper/__pycache__/model.cpython-311.pyc ADDED
Binary file (21.6 kB). View file
 
whisper/__pycache__/timing.cpython-311.pyc ADDED
Binary file (17.9 kB). View file
 
whisper/__pycache__/tokenizer.cpython-311.pyc ADDED
Binary file (19.4 kB). View file
 
whisper/__pycache__/transcribe.cpython-311.pyc ADDED
Binary file (24.2 kB). View file
 
whisper/__pycache__/utils.cpython-311.pyc ADDED
Binary file (12.9 kB). View file
 
whisper/__pycache__/version.cpython-311.pyc ADDED
Binary file (207 Bytes). View file