xiaoxuezi committed on
Commit
ce7b81a
1 Parent(s): aea15e1
SpeakerNet.py ADDED
@@ -0,0 +1,283 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy, sys, random
+ from DatasetLoader import test_dataset_loader
+ import importlib
+ import time, itertools
+ from utils.log import init_log
+ from tqdm import tqdm
+ import wandb
+ from tuneThreshold import *
+
+
+ class SpeakerNet(nn.Module):
+
+     def __init__(self, model, trainfunc, nPerSpeaker):
+         super(SpeakerNet, self).__init__()
+
+         self.model = model
+         self.loss = trainfunc
+         self.nPerSpeaker = nPerSpeaker
+
+     def forward(self, data, label=None):
+
+         data = data.reshape(-1, data.size()[-1])
+         outp = self.model(data)
+
+         if label is None:
+             return outp
+
+         else:
+             emb = outp.reshape(-1, self.nPerSpeaker, outp.size()[-1]).squeeze(1)
+             nloss, prec1 = self.loss(emb, label)
+
+             return nloss, prec1
+
+
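# A minimal usage sketch of SpeakerNet (toy stand-ins, not part of this commit):
# with no label, forward() returns embeddings; with a label it regroups the
# embeddings into (batch, nPerSpeaker, dim) and defers to trainfunc for the loss
# and top-1 precision.
_encoder = nn.Linear(16000, 192)                    # toy "model": raw waveform -> 192-d embedding
def _toy_loss(emb, label):                          # toy "trainfunc" with the expected signature
    return emb.pow(2).mean(), torch.tensor(0.0)

_net = SpeakerNet(model=_encoder, trainfunc=_toy_loss, nPerSpeaker=2)
_wav = torch.randn(4, 2, 16000)                     # (speakers, nPerSpeaker, samples)
_emb = _net(_wav)                                   # eval mode: (8, 192) embeddings
_loss, _prec1 = _net(_wav, label=torch.arange(4))   # train mode: loss and precision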
+ class Trainer(object):
+
+     def __init__(self, cfg, model, optimizer, scheduler, device):
+         self.cfg = cfg
+         self.model = model
+         self.optimizer = optimizer
+         self.scheduler = scheduler
+         self.device = device
+         logging = init_log(cfg.save_dir)
+         self._print = logging.info
+         self.best = 0
+         self.test_eer = 0
+         self.test_mindcf = 0
+         self.best_model = []
+
+     def train(self, epoch, dataloader):
+         self.model.train()
+         pbar = tqdm(dataloader)
+         loss = 0
+         top1 = 0
+         index = 0
+         counter = 0
+
+         for data in pbar:
+             x, label = data[0].to(self.device), data[1].long().to(self.device)
+             nloss, prec1 = self.model(x, label)
+
+             self.optimizer.zero_grad()
+             nloss.backward()
+             self.optimizer.step()
+             # self.scheduler.step()
+
+             loss += nloss.detach().cpu().item()
+             top1 += prec1.detach().cpu().item()
+             index += x.size(0)
+             counter += 1
+
+             if self.cfg.wandb:
+                 wandb.log({
+                     "epoch": epoch,
+                     "train_acc": top1 / counter,
+                     "train_loss": loss / counter,
+                 })
+             pbar.set_description("Train Epoch:%3d ,Tloss:%.3f, Tacc:%.3f" % (epoch, loss/counter, top1/counter))
+
+         # self.scheduler.step()
+         self._print('epoch:{} - train loss: {:.3f} and train acc: {:.3f} total sample: {}'.format(
+             epoch, loss/counter, top1/counter, index))
+
+     def test(self, epoch, test_list, test_path, nDataLoaderThread, eval_frames, num_eval=10):
+
+         self.model.eval()
+         feats = {}
+
+         # read all lines
+         with open(test_list) as f:
+             lines = f.readlines()
+         files = list(itertools.chain(*[x.strip().split()[-2:] for x in lines]))
+         setfiles = list(set(files))
+         setfiles.sort()
+
+         # Define test data loader
+         test_dataset = test_dataset_loader(setfiles, test_path, eval_frames=eval_frames, num_eval=num_eval)
+
+         test_loader = torch.utils.data.DataLoader(
+             test_dataset,
+             batch_size=1,
+             shuffle=False,
+             num_workers=nDataLoaderThread,
+             drop_last=False,
+             sampler=None
+         )
+
+         # Extract features for every wav
+         for idx, data in enumerate(tqdm(test_loader)):
+             inp1 = data[0][0].to(self.device)  # data[0]: (1, 10, 1024); data[1]: 'id10270/GWXujl-xAVM/00017.wav'
+             with torch.no_grad():
+                 ref_feat = self.model(inp1).detach().cpu()
+             feats[data[1][0]] = ref_feat
+
+         all_scores = []
+         all_labels = []
+         all_trials = []
+
+         # Read files and compute all scores
+         for idx, line in enumerate(tqdm(lines)):
+             data = line.split()
+
+             # Append random label if missing
+             if len(data) == 2:
+                 data = [random.randint(0, 1)] + data
+             ref_feat = feats[data[1]].to(self.device)
+             com_feat = feats[data[2]].to(self.device)
+
+             if self.model.loss.test_normalize:
+                 ref_feat = F.normalize(ref_feat, p=2, dim=1)
+                 com_feat = F.normalize(com_feat, p=2, dim=1)
+
+             # dist = F.pairwise_distance(ref_feat.unsqueeze(-1),
+             #                            com_feat.unsqueeze(-1).transpose(0, 2)).detach().cpu().numpy()
+             # score = -1 * numpy.mean(dist)
+             dist = F.cosine_similarity(ref_feat.unsqueeze(-1),
+                                        com_feat.unsqueeze(-1).transpose(0, 2)).detach().cpu().numpy()
+             score = numpy.mean(dist)
+
+             all_scores.append(score)
+             all_labels.append(int(data[0]))
+             all_trials.append(data[1] + " " + data[2])
+
+         result = tuneThresholdfromScore(all_scores, all_labels, [1, 0.1])
+         fnrs, fprs, thresholds = ComputeErrorRates(all_scores, all_labels)
+         mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, self.cfg.dcf_p_target, self.cfg.dcf_c_miss, self.cfg.dcf_c_fa)
+         self.test_eer = result[1]
+         self.test_mindcf = mindcf
+         self.threshold = threshold
+         if self.cfg.wandb:
+             wandb.log({
+                 "test_eer": self.test_eer,
+                 "test_MinDCF": self.test_mindcf,
+             })
+         self._print('epoch:{} - test EER: {:.3f} and test MinDCF: {:.3f} total sample: {} threshold: {:.3f}'.format(
+             epoch, self.test_eer, self.test_mindcf, len(lines), self.threshold))
+
+         return self.test_eer
+
+     def save_model(self, epoch):
+         if self.test_eer < self.best or self.best == 0:
+             self.best = self.test_eer
+             if self.cfg.wandb:
+                 wandb.run.summary["best_accuracy"] = self.best
+             model_state_dict = self.model.state_dict()
+             optimizer_state_dict = self.optimizer.state_dict()
+             scheduler_state_dict = self.scheduler.state_dict()
+             file_save_path = 'epoch:%d,EER:%.4f,MinDCF:%.4f' % (epoch, self.test_eer, self.test_mindcf)
+             if not os.path.exists(self.cfg.save_dir):
+                 os.mkdir(self.cfg.save_dir)
+             torch.save({
+                 'epoch': epoch,
+                 'test_eer': self.test_eer,
+                 'test_mindcf': self.test_mindcf,
+                 'model_state_dict': model_state_dict,
+                 'optimizer_state_dict': optimizer_state_dict,
+                 'scheduler_state_dict': scheduler_state_dict},
+                 os.path.join(self.cfg.save_dir, file_save_path))
+             self.best_model.append(file_save_path)
+             if len(self.best_model) > 3:
+                 del_file = os.path.join(self.cfg.save_dir, self.best_model.pop(0))
+                 if os.path.exists(del_file):
+                     os.remove(del_file)
+                 else:
+                     print("does not exist: {}".format(del_file))
+         # Save a checkpoint every 20 epochs
+         if epoch % 20 == 0:
+             model_state_dict = self.model.state_dict()
+             optimizer_state_dict = self.optimizer.state_dict()
+             scheduler_state_dict = self.scheduler.state_dict()
+             file_save_path = 'epoch:%d,EER:%.4f,MinDCF:%.4f' % (epoch, self.test_eer, self.test_mindcf)
+             if not os.path.exists(self.cfg.save_dir):
+                 os.mkdir(self.cfg.save_dir)
+             if not os.path.exists(os.path.join(self.cfg.save_dir, file_save_path)):
+                 torch.save({
+                     'epoch': epoch,
+                     'test_eer': self.test_eer,
+                     'test_mindcf': self.test_mindcf,
+                     'model_state_dict': model_state_dict,
+                     'optimizer_state_dict': optimizer_state_dict,
+                     'scheduler_state_dict': scheduler_state_dict},
+                     os.path.join(self.cfg.save_dir, file_save_path))
+
+     def scoretxt(self, score_file, test_list, test_path, eval_frames, num_eval=10):
+
+         self.model.eval()
+         feats = {}
+
+         # read all lines
+         with open(test_list) as f:
+             lines = f.readlines()
+         files = list(itertools.chain(*[x.strip().split()[-2:] for x in lines]))
+         setfiles = list(set(files))
+         setfiles.sort()
+
+         # Define test data loader
+         test_dataset = test_dataset_loader(setfiles, test_path, eval_frames=eval_frames, num_eval=num_eval)
+
+         test_loader = torch.utils.data.DataLoader(
+             test_dataset,
+             batch_size=1,
+             shuffle=False,
+             drop_last=False,
+             sampler=None
+         )
+
+         # Extract features for every wav
+         for idx, data in enumerate(tqdm(test_loader)):
+             inp1 = data[0][0].to(self.device)  # data[0]: (1, 10, 1024); data[1]: 'id10270/GWXujl-xAVM/00017.wav'
+             with torch.no_grad():
+                 ref_feat = self.model(inp1).detach().cpu()
+             feats[data[1][0]] = ref_feat
+
+         f = open(score_file, "w")
+         # Read files and compute all scores
+         for idx, line in enumerate(tqdm(lines)):
+             data = line.split()
+
+             ref_feat = feats[data[-2]].to(self.device)
+             com_feat = feats[data[-1]].to(self.device)
+
+             if self.model.loss.test_normalize:
+                 ref_feat = F.normalize(ref_feat, p=2, dim=1)
+                 com_feat = F.normalize(com_feat, p=2, dim=1)
+
+             # dist = F.pairwise_distance(ref_feat.unsqueeze(-1),
+             #                            com_feat.unsqueeze(-1).transpose(0, 2)).detach().cpu().numpy()
+             # score = -1 * numpy.mean(dist)
+             dist = F.cosine_similarity(ref_feat.unsqueeze(-1),
+                                        com_feat.unsqueeze(-1).transpose(0, 2)).detach().cpu().numpy()
+             score = numpy.mean(dist)
+
+             score_line = str(score) + " " + data[-2] + " " + data[-1]
+             f.write(score_line + '\n')
+         f.close()
+
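# A minimal sketch of the segment-pair scoring used by test() and scoretxt() above
# (toy tensors, not real embeddings): each utterance yields num_eval=10 segment
# embeddings; the unsqueeze/transpose broadcast forms a (10, 10) cosine-similarity
# matrix, and the trial score is its mean.
_ref = F.normalize(torch.randn(10, 192), p=2, dim=1)           # 10 eval segments of utterance A
_com = F.normalize(torch.randn(10, 192), p=2, dim=1)           # 10 eval segments of utterance B
_sim = F.cosine_similarity(_ref.unsqueeze(-1),                 # (10, 192, 1)
                           _com.unsqueeze(-1).transpose(0, 2)) # (1, 192, 10) -> broadcast -> (10, 10)
_score = _sim.mean().item()                                    # trial score: mean over all segment pairs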
app.py ADDED
@@ -0,0 +1,77 @@
+ import gradio as gr
+ import torch
+ import net
+ import argparse
+ from config import set_cfg, cfg
+ from SpeakerNet import *
+ import lossfunction
+ from DatasetLoader import loadWAV
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config_name", type=str, default="ECAPA_TDNN_data_cfg", help="the config name that will serve as the base config")
+ parser.add_argument("--resume", default="train_models/epoch_37_ECAPA_TDNN2.48", type=str, help="resume path")
+ args = parser.parse_args()
+ global cfg
+ assert args.config_name is not None
+ if args.config_name:
+     set_cfg(args.config_name)
+ cfg.replace(vars(args))
+ del args
+
+ device = torch.device("cpu")
+ model = getattr(net, cfg.model)().to(device)
+ loss = getattr(lossfunction, cfg.loss)(cfg.nOut, cfg.nClasses).to(device)
+ model = SpeakerNet(model=model, trainfunc=loss, nPerSpeaker=cfg.nPerSpeaker)
+
+ ckpt = torch.load("train_models/epoch_37_ECAPA_TDNN2.48", map_location="cpu")
+ model.load_state_dict(ckpt['model_state_dict'], strict=False)
+ print("Checkpoint loaded!")
+
+ model.eval()
+
+ def SpeakerVerification(path1, path2):
+     inp1 = loadWAV(path1, max_frames=300, evalmode=True)
+     inp2 = loadWAV(path2, max_frames=300, evalmode=True)
+     # print(inp1.shape)
+     inp1 = torch.FloatTensor(inp1)
+     inp2 = torch.FloatTensor(inp2)
+     # print(inp1.shape)
+     with torch.no_grad():
+         emb1 = model(inp1).detach().cpu()
+         emb2 = model(inp2).detach().cpu()
+         emb1 = F.normalize(emb1, p=2, dim=1)
+         emb2 = F.normalize(emb2, p=2, dim=1)
+     dist = F.cosine_similarity(emb1.unsqueeze(-1), emb2.unsqueeze(-1).transpose(0, 2)).numpy()
+     score = numpy.mean(dist)
+     print(score)
+     # threshold = 0.414
+     if score >= 0.414:
+         output = "Same speaker"
+     else:
+         output = "Different speakers"
+
+     return output
+
+ inputs = [
+     gr.inputs.Audio(source="upload", type="filepath", label="Speaker #1", optional=False),
+     gr.inputs.Audio(source="upload", type="filepath", label="Speaker #2", optional=False)
+ ]
+
+ examples = [["example/speaker1-1.wav", "example/speaker1-2.wav"],
+             ["example/speaker1-1.wav", "example/speaker2-1.wav"],
+             ["example/speaker2-1.wav", "example/speaker2-2.wav"],
+             ["example/speaker1-2.wav", "example/speaker2-2.wav"],
+             ["example/speaker3-1.wav", "example/speaker3-2.wav"],
+             ["example/speaker3-1.wav", "example/speaker4-1.wav"],
+             ["example/speaker4-1.wav", "example/speaker4-2.wav"],
+             ["example/speaker3-2.wav", "example/speaker4-2.wav"],
+             ["example/speaker4-1.wav", "example/speaker5-2.wav"],
+             ]
+
+ iface = gr.Interface(fn=SpeakerVerification, inputs=inputs, outputs="text", examples=examples)
+ iface.launch(share=True)
+
+ if __name__ == '__main__':
+     # print(SpeakerVerification("example/speaker1-1.wav", "example/speaker1-2.wav"))
+     pass
config.py ADDED
@@ -0,0 +1,517 @@
+ class Config(object):
+     def __init__(self, config_dict: dict):
+         for key, val in config_dict.items():
+             if val is not None:
+                 self.__setattr__(key, val)
+
+     def copy(self, new_config_dict={}):
+         ret = Config(vars(self))
+         for key, val in new_config_dict.items():
+             if val is not None:
+                 ret.__setattr__(key, val)
+         return ret
+
+     def replace(self, new_config_dict):
+         if isinstance(new_config_dict, Config):
+             new_config_dict = vars(new_config_dict)
+
+         for key, val in new_config_dict.items():
+             if val is not None:
+                 self.__setattr__(key, val)
+
+     def print(self):
+         for k, v in vars(self).items():
+             print(k, '=', v)
+
+     # def parser_val(self, val):
+     #     if isinstance(val, dict):
+     #         return Config(val)
+     #     elif isinstance(val, list):
+     #         for i in range(len(val)):
+     #             if val is not None:
+     #                 val[i] = self.parser_val(val[i])
+     #         return val
+     #     else:
+     #         return val
+
+     def __str__(self):
+         return str(vars(self))
+
+
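# A minimal sketch of how Config behaves (illustrative keys/values only): copy()
# returns a new overridden instance, replace() mutates in place, and None values
# are skipped, so unset argparse defaults never clobber existing settings.
_c = Config({"lr": 0.001, "batch_size": 64})
_c2 = _c.copy({"batch_size": 128})       # new Config: lr=0.001, batch_size=128
_c.replace({"lr": 0.01, "seed": None})   # in place: lr=0.01; None is ignored, no seed set
print(_c)                                # {'lr': 0.01, 'batch_size': 64}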
+ base_config = Config({
+     "project": "speaker_verification",
+     "name": "VGGVox",
+     "save_dir": "train_models/",
+     "resume": "",
+
+     # Training and test data
+     "dataset": Config({
+         "name": "voxceleb2_wav",
+         "train_list": "data/train_list.txt",
+         "test_list": "data/veri_list.txt",
+         "train_path": "data/voxceleb2",
+         "test_path": "data/voxceleb1",
+         "musan_path": "data/musan_split",  # noise files
+         "rir_path": "data/RIRS_NOISES/simulated_rirs",  # reverberation (RIR) files
+     }),
+
+     # Data loader
+     "max_frames": 300,  # frame length during training
+     "eval_frames": 300,
+     "batch_size": 64,
+     "max_seg_per_spk": 500,  # maximum number of utterances per speaker
+     "nDataLoaderThread": 16,  # number of data-loading workers
+     "augment": True,  # whether to apply data augmentation
+     "seed": 10,
+     "segment": 1,
+
+     # Training details
+     "test_interval": 1,  # test every N epochs
+     "max_epoch": 500,
+
+     # Model definition
+     "n_mels": 40,
+     "log_input": False,
+     "model": "Vgg",
+     "encoder_type": "SAP",
+     "nOut": 512,
+
+     # Loss functions
+     "loss": "SoftmaxProto",  # loss function
+     "hard_prob": 0.5,
+     "hard_rank": 10,
+     "margin": 0.2,
+     "scale": 30,
+     "nPerSpeaker": 2,  # number of utterances drawn per speaker
+     "nClasses": 5994,
+
+     # Optimizer
+     "optimizer": "adam",
+     "scheduler": "steplr",
+     "lr": 0.001,
+     "lr_decay": 0.95,
+     "weight_decay": 0,
+
+     # Evaluation parameters
+     "dcf_p_target": 0.05,
+     "dcf_c_miss": 1,
+     "dcf_c_fa": 1,
+
+     # eval
+     "eval": False,
+ })
+
+ cfg = base_config
+
+ vgg_cfg = Config({
+     "name": "vgg_spectrogram",
+     "model": "vgg",
+     "batch_size": 64,
+     "nPerSpeaker": 2,
+ })
+
+ Unet_cfg = Config({
+     "name": "Unet",
+     "model": "UNetVgg",
+     "batch_size": 48,
+     "nPerSpeaker": 2,
+     "loss": "Unetloss"
+ })
+
+ UnetMask_cfg = Config({
+     "name": "UnetMask",
+     "model": "UNetVggMask",
+     "batch_size": 16,
+     "nPerSpeaker": 2,
+     "segment": 3,
+     "loss": "UnetMaskloss"
+ })
+
+ ECAPA_TDNN_cfg = Config({
+     "name": "ECAPA_TDNNm",
+     "model": "ECAPA_TDNN",
+     "loss": "AamSoftmaxProto",
+     "batch_size": 180,
+     "nPerSpeaker": 2,
+     "nOut": 192,
+ })
+
+ ECAPA_TDNNm_cfg = Config({
+     "name": "ECAPA_TDNNm",
+     "model": "ECAPA_TDNN",
+     "batch_size": 180,
+     "nPerSpeaker": 2,
+     "nOut": 192,
+ })
+
+ ECAPA_TDNN1024_cfg = Config({
+     "name": "ECAPA_TDNN1024",
+     "model": "ECAPA_TDNN",
+     "batch_size": 80,
+     "nPerSpeaker": 2,
+     "channels": 1024,
+     "nOut": 192,
+ })
+
+ ECAPA_TDNN_ks5_cfg = Config({
+     "name": "ECAPA_TDNN_ks5",
+     "model": "ECAPA_TDNN_ks5",
+     "batch_size": 180,
+     "nPerSpeaker": 2,
+     "nOut": 192,
+ })
+
+ ECAPA_TDNN_L2_cfg = Config({
+     "name": "ECAPA_TDNN_L2_pre",
+     "model": "ECAPA_TDNN_L2",
+     "batch_size": 180,
+     "nPerSpeaker": 2,
+     "nOut": 192,
+     "resume": "train_models/speaker_verification_ECAPA_TDNN/20210725/epoch:47,EER:2.5981,MinDCF:0.1912"
+ })
+
+ ECAPA_TDNN_br_cfg = Config({
+     "name": "ECAPA_TDNN_br",
+     "model": "ECAPA_TDNN_br",
+     "batch_size": 180,
+     "nPerSpeaker": 2,
+     "nOut": 192,
+ })
+
+ ECAPATDNN_cfg = Config({
+     "name": "ECAPATDNN",
+     "model": "ECAPATDNN",
+     "batch_size": 110,
+     "nPerSpeaker": 2,
+     "nOut": 192,
+     "input_size": 80,
+ })
+
+ HRNet_cfg = Config({
+     "name": "hrnet",
+     "model": "hrnet",
+     "max_frames": 224,
+     "eval_frames": 224,
+     "batch_size": 48,
+     "nPerSpeaker": 2,
+     "nOut": 1024,
+     "input_size": 224*224,
+
+     "model_cfg": Config({
+         "hrnet_name": "w48",
+         "STAGE1": {
+             "NUM_MODULES": 1,
+             "NUM_BRANCHES": 1,
+             "BLOCK": "BOTTLENECK",
+             "NUM_BLOCKS": [4],
+             "NUM_CHANNELS": [64],
+             "FUSE_METHOD": "SUM"
+         },
+         "STAGE2": {
+             "NUM_MODULES": 1,
+             "NUM_BRANCHES": 2,
+             "BLOCK": "BASIC",
+             "NUM_BLOCKS": [4, 4],
+             "NUM_CHANNELS": [18, 36],
+             "FUSE_METHOD": "SUM"
+         },
+         "STAGE3": {
+             "NUM_MODULES": 4,
+             "NUM_BRANCHES": 3,
+             "BLOCK": "BASIC",
+             "NUM_BLOCKS": [4, 4, 4],
+             "NUM_CHANNELS": [18, 36, 72],
+             "FUSE_METHOD": "SUM"
+         },
+         "STAGE4": {
+             "NUM_MODULES": 3,
+             "NUM_BRANCHES": 4,
+             "BLOCK": "BASIC",
+             "NUM_BLOCKS": [4, 4, 4, 4],
+             "NUM_CHANNELS": [18, 36, 72, 144],
+             "FUSE_METHOD": "SUM"
+         },
+     }),
+ })
+
+ VGG_TDNN_cfg = Config({
+     "name": "Vggtdnn1",
+     "model": "Vggtdnn",
+     "batch_size": 256,
+     "nOut": 512,
+     "nDataLoaderThread": 16,
+ })
+
+ ResNetSE34V2_cfg = Config({
+     "name": "ResNetSE34V2",
+     "model": "ResNetSE34V2",
+     "batch_size": 128,
+     "nOut": 512,
+     "nDataLoaderThread": 16,
+ })
+
+ HRTDNN_cfg = Config({
+     "name": "hrtdnn",
+     "model": "hrtdnn",
+     "max_frames": 300,
+     "eval_frames": 300,
+     "batch_size": 96,
+     "nPerSpeaker": 2,
+     "nOut": 256,
+
+     "model_cfg": Config({
+         "hrnet_name": "hrtdnn",
+         "STAGE1": {
+             "NUM_BRANCHES": 1,
+             "BLOCK": 'TDNNBlock',
+             "NUM_CHANNELS": [128],
+             "FUSE_METHOD": "SUM"
+         },
+         "STAGE2": {
+             "NUM_BRANCHES": 2,
+             "BLOCK": 'TDNNBlock',
+             "NUM_CHANNELS": [128, 512],
+             "FUSE_METHOD": "SUM"
+         },
+         "STAGE3": {
+             "NUM_BRANCHES": 3,
+             "BLOCK": 'TDNNBlock',
+             "NUM_CHANNELS": [128, 512, 1024],
+             "FUSE_METHOD": "SUM"
+         },
+     }),
+ })
+
+ ResTDNN_cfg = Config({
+     "name": "ResTDNN",
+     "model": "ResTDNN",
+     "batch_size": 110,
+     "nOut": 256,
+     "nDataLoaderThread": 16,
+ })
+
+ TDNN_VGG_cfg = Config({
+     "name": "TDNN_VGG",
+     "model": "TDNN_VGG",
+     "batch_size": 64,
+     "nOut": 256,
+     "nDataLoaderThread": 16,
+ })
+
+ ResNet_TDNN_cfg = Config({
+     "name": "ResNet_TDNN",
+     "model": "ResNet_TDNN",
+     "batch_size": 96,
+     "nOut": 192,
+     "nDataLoaderThread": 16,
+ })
+
+ ResNet_TDNNa_cfg = Config({
+     "name": "ResNet_TDNNa",
+     "model": "ResNet_TDNN",
+     "batch_size": 96,
+     "nOut": 192,
+     "nDataLoaderThread": 16,
+ })
+
+ ResNet_TDNNaam_cfg = Config({
+     "name": "ResNet_TDNNaam",
+     "model": "ResNet_TDNN",
+     "loss": "AamSoftmaxProto",
+     "margin": 0.2,
+     "scale": 30,
+     "batch_size": 96,
+     "nOut": 192,
+     "nDataLoaderThread": 16,
+     "augment": True,
+ })
+
+ TDNN_ResNet_cfg = Config({
+     "name": "TDNN_ResNet",
+     "model": "TDNN_ResNet",
+     "batch_size": 48,
+     "nOut": 256,
+     "nDataLoaderThread": 16,
+ })
+
+ hr_tdnn_cfg = Config({
+     "name": "hr_tdnn",
+     "model": "hr_tdnn",
+     "batch_size": 46,
+     "nOut": 192,
+     "nDataLoaderThread": 16,
+ })
+
+ ECAPA_TDNNma_cfg = Config({
+     "name": "ECAPA_TDNNma",
+     "model": "ECAPA_TDNN",
+     "batch_size": 180,
+     "nPerSpeaker": 2,
+     "nOut": 192,
+     "augment": True,
+ })
+
+ ECAPA_TDNNaam_cfg = Config({
+     "name": "ECAPA_TDNNaam",
+     "model": "ECAPA_TDNN",
+     "loss": "AamSoftmax",
+     "batch_size": 360,
+     "nPerSpeaker": 1,
+     "nOut": 192,
+     "augment": True,
+ })
+
+ ECAPA_TDNNaam1_cfg = Config({
+     "name": "ECAPA_TDNNaam1",
+     "model": "ECAPA_TDNN",
+     "loss": "AdditiveAngularMargin",
+     "batch_size": 360,
+     "nPerSpeaker": 1,
+     "nOut": 192,
+     "augment": True,
+ })
+
+ ECAPA_TDNNaam2_cfg = Config({
+     "name": "ECAPA_TDNNaam2",
+     "model": "ECAPA_TDNN",
+     "loss": "AamSoftmax",
+     "margin": 0.2,
+     "scale": 30,
+     "batch_size": 360,
+     "nPerSpeaker": 1,
+     "nOut": 192,
+     "augment": True,
+ })
+
+ ECAPA_TDNNaam3_cfg = Config({
+     "name": "ECAPA_TDNNaam3",
+     "model": "ECAPA_TDNN",
+     "loss": "AamSoftmax",
+     "margin": 0.1,
+     "scale": 30,
+     "batch_size": 360,
+     "nPerSpeaker": 1,
+     "nOut": 192,
+     "augment": True,
+ })
+
+ ECAPA_TDNN_aamproto_cfg = Config({
+     "name": "ECAPA_TDNN_aamproto",
+     "model": "ECAPA_TDNN",
+     "loss": "AamSoftmaxProto",
+     "batch_size": 180,
+     "nPerSpeaker": 2,
+     "nOut": 192,
+     "augment": True,
+ })
+
+ ECAPA_TDNN_aamproto1_cfg = Config({
+     "name": "ECAPA_TDNN_aamproto1",
+     "model": "ECAPA_TDNN",
+     "loss": "AamSoftmaxProto",
+     "margin": 0.2,
+     "scale": 30,
+     "batch_size": 180,
+     "nPerSpeaker": 2,
+     "nOut": 192,
+     "augment": True,
+ })
+
+ ECAPA_TDNN0_cfg = Config({
+     "name": "ECAPA_TDNN-1lr0.001",
+     "model": "ECAPA_TDNN",
+     "loss": "AamSoftmax",
+     "batch_size": 360,
+     "nOut": 192,
+     "nPerSpeaker": 1,
+     "resume": "train_models/speaker_verification_ECAPA_TDNN0/20210928/epoch:25,EER:2.4125,MinDCF:0.1537",
+ })
+
+ SwinTransformer_cfg = Config({
+     "name": "SwinTransformer",
+     "model": "SwinTransformer",
+     "loss": "SoftmaxProto",
+     "max_frames": 224,
+     "eval_frames": 224,
+     "n_mels": 224,
+     "batch_size": 90,
+     "nPerSpeaker": 2,
+     "nOut": 192,
+     "augment": True,
+     "lr": 5e-5,
+ })
+
+ ECAPA_TDNN_aampre_cfg = Config({
+     "name": "ECAPA_TDNN_aampre",
+     "model": "ECAPA_TDNN",
+     "loss": "AamSoftmaxProto",
+     "batch_size": 180,
+     "nOut": 192,
+     "nPerSpeaker": 2,
+     "resume": "train_models/speaker_verification_ECAPA_TDNNma/20210908/epoch:67,EER:2.3224,MinDCF:0.1658",
+ })
+
+ # Switched to the new dataloader
+ ECAPA_TDNN_data_cfg = Config({
+     "name": "ECAPA_TDNN_data",
+     "model": "ECAPA_TDNN",
+     "loss": "AamSoftmax",
+     "batch_size": 360,
+     "nPerSpeaker": 1,
+     "nOut": 192,
+     "augment": True,
+ })
+
+ # Standard ECAPA_TDNN with a CyclicLR learning-rate schedule
+ ECAPA_TDNNaam_cyclr_cfg = Config({
+     "name": "ECAPA_TDNNaam_cyclr",
+     "model": "ECAPA_TDNN",
+     "loss": "AamSoftmax",
+     "margin": 0.2,
+     "scale": 30,
+     "batch_size": 360,
+     "nPerSpeaker": 1,
+     "nOut": 192,
+     "augment": True,
+ })
+
+ # ResNet_TDNN with the switched data loading, softmax-family loss only
+ ResNet_TDNNaam_data_cfg = Config({
+     "name": "ResNet_TDNNaam_data",
+     "model": "ResNet_TDNN",
+     "loss": "AamSoftmax",
+     "margin": 0.2,
+     "scale": 30,
+     "batch_size": 192,
+     "nOut": 192,
+     "nDataLoaderThread": 16,
+     "nPerSpeaker": 1,
+     "augment": True,
+ })
+
+ # Switched dataloader and cyclical LR
+ ECAPA_TDNN_dataClr_cfg = Config({
+     "name": "ECAPA_TDNN_dataClr",
+     "model": "ECAPA_TDNN",
+     "loss": "AamSoftmax",
+     "batch_size": 360,
+     "nPerSpeaker": 1,
+     "nOut": 192,
+     "augment": True,
+ })
+
+ def set_cfg(config_name: str):
+     """ Sets the active config. Works even if cfg is already imported! """
+     global cfg
+     # Note this is not just an eval because I'm lazy, but also because it can
+     # be used like ssd300_config.copy({'max_size': 400}) for extreme fine-tuning
+     cfg.replace(eval(config_name))
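# A minimal sketch of selecting a config by name; the string must match one of
# the Config objects defined above, and its values overlay base_config in place.
from config import set_cfg, cfg

set_cfg("ECAPA_TDNNaam_cfg")
print(cfg.model, cfg.loss)  # ECAPA_TDNN AamSoftmax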
dataloader.py ADDED
@@ -0,0 +1,75 @@
+ import torch
+ from torch.utils.data import Dataset
+ from DatasetLoader import AugmentWAV, loadWAV
+ import os
+ import numpy as np
+ import random
+
+
+ class TrainDataset(Dataset):
+     def __init__(self, train_list, train_path, augment, musan_path, rir_path, max_frames):
+         self.train_list = train_list
+         self.max_frames = max_frames
+         self.augment_wav = AugmentWAV(musan_path=musan_path, rir_path=rir_path, max_frames=max_frames)
+         self.augment = augment
+         self.musan_path = musan_path
+         self.rir_path = rir_path
+
+         with open(train_list) as dataset_file:
+             lines = dataset_file.readlines()
+
+         dictkeys = list(set([x.split()[0] for x in lines]))
+         dictkeys.sort()
+         dictkeys = {key: ii for ii, key in enumerate(dictkeys)}
+
+         np.random.seed(100)
+         np.random.shuffle(lines)
+
+         self.data_list = []
+         self.data_label = []
+
+         for lidx, line in enumerate(lines):
+             data = line.strip().split()
+             speaker_label = dictkeys[data[0]]
+             filename = os.path.join(train_path, data[1])
+
+             self.data_list.append(filename)
+             self.data_label.append(speaker_label)
+
+     def __getitem__(self, index):
+         audio = loadWAV(self.data_list[index], self.max_frames, evalmode=False)
+         if self.augment:
+             augtype = random.randint(0, 4)  # inclusive of both 0 and 4; 0 means no augmentation
+             if augtype == 1:
+                 audio = self.augment_wav.reverberate(audio)
+             elif augtype == 2:
+                 audio = self.augment_wav.additive_noise('music', audio)
+             elif augtype == 3:
+                 audio = self.augment_wav.additive_noise('speech', audio)
+             elif augtype == 4:
+                 audio = self.augment_wav.additive_noise('noise', audio)
+
+         return torch.FloatTensor(audio), self.data_label[index]
+
+     def __len__(self):
+         return len(self.data_list)
+
+
+ if __name__ == "__main__":
+     train_dataset = TrainDataset(train_list="data/train_list.txt", augment=True,
+                                  musan_path="data/musan_split", rir_path="data/RIRS_NOISES/simulated_rirs",
+                                  max_frames=300, train_path="data/voxceleb2")
+     train_loader = torch.utils.data.DataLoader(
+         train_dataset,
+         batch_size=32,
+         pin_memory=False,
+         drop_last=True,
+     )
+     x, y = next(iter(train_loader))
+     print("x:", x.shape, "y:", y.shape)
+
example/.DS_Store ADDED
Binary file (6.15 kB)
example/speaker1-1.wav ADDED
Binary file (277 kB)
example/speaker1-2.wav ADDED
Binary file (247 kB)
example/speaker2-1.wav ADDED
Binary file (202 kB)
example/speaker2-2.wav ADDED
Binary file (169 kB)
example/speaker3-1.wav ADDED
Binary file (102 kB)
example/speaker3-2.wav ADDED
Binary file (112 kB)
example/speaker4-1.wav ADDED
Binary file (132 kB)
example/speaker4-2.wav ADDED
Binary file (415 kB)
example/speaker5-1.wav ADDED
Binary file (113 kB)
example/speaker5-2.wav ADDED
Binary file (120 kB)
requirements.txt ADDED
@@ -0,0 +1 @@
+ wandb
train.py ADDED
@@ -0,0 +1,186 @@
+ import os
+ import argparse
+ from datetime import datetime
+ import wandb
+ import torch
+ import torch.backends.cudnn as cudnn
+ from torch import optim
+ from torch.utils.data import DataLoader
+ from torchinfo import summary
+ from timm.scheduler.cosine_lr import CosineLRScheduler
+
+ import lossfunction
+ import net
+ from DatasetLoader import *
+ from dataloader import TrainDataset
+ from SpeakerNet import *
+ from config import set_cfg, cfg
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--config_name", type=str, default="", help="the config name that will serve as the base config")
+     parser.add_argument("--project", default=None, type=str, help="project name")
+     parser.add_argument("--name", default=None, type=str, help="run name")
+     parser.add_argument("--save_dir", default=None, type=str, help="save path")
+     parser.add_argument("--resume", default=None, type=str, help="resume path")
+     parser.add_argument("--dataset", default=None, type=str, help="dataset path")
+
+     parser.add_argument("--epoch", default=None, type=int, help="max epoch")
+     parser.add_argument("--test_freq", default=None, type=int, help="test frequency in epochs")
+     parser.add_argument("--batch_size", default=None, type=int, help="batch size")
+     parser.add_argument("--lr", default=None, type=float, help="learning rate")
+     parser.add_argument("--seed", default=None, type=int)
+
+     parser.add_argument("--wandb", action='store_true', default=False, help='use wandb to log')
+     parser.add_argument("--note", type=str, default="", help='wandb note')
+
+     parser.add_argument('--eval', dest='eval', action='store_true', default=False, help='eval only')
+     parser.add_argument('--score', dest='score', action='store_true', default=False, help='write a score file only')
+
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     global cfg
+     args = get_args()
+     assert args.config_name is not None
+     if args.config_name:
+         set_cfg(args.config_name)
+     cfg.replace(vars(args))
+     del args
+
+     cfg.save_dir = os.path.join(cfg.save_dir, cfg.project + '_' + cfg.name, datetime.now().strftime('%Y%m%d'))
+     if not os.path.exists(cfg.save_dir):
+         os.makedirs(cfg.save_dir)
+
+     if cfg.wandb:
+         wandb.login(host="http://49.233.11.7:8080", key="local-7dc64cc63778f0723dc202d2624a97cef7043120")
+         wandb.init(project=cfg.project, name=cfg.name, config=cfg, save_code=True, notes=cfg.note)
+
+     # cudnn related settings
+     cudnn.benchmark = True
+     torch.backends.cudnn.deterministic = False
+     torch.backends.cudnn.enabled = True
+
+     start_epoch = 1
+
+     # --------------- model ---------------
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     # device = torch.device("cpu")
+     # model = getattr(net, cfg.model)(cfg.nOut, cfg.encoder_type, cfg.log_input).to(device)
+     # ------ECAPA_TDNN.yaml------ResNet_TDNN----
+     model = getattr(net, cfg.model)().to(device)
+
+     # loss = getattr(lossfunction, cfg.loss)(cfg.nOut, cfg.nClasses, cfg.margin, cfg.scale).to(device)
+     # ----aamsoftmax----
+     loss = getattr(lossfunction, cfg.loss)(cfg.nOut, cfg.nClasses).to(device)
+
+     # model = SpeakerUnet(model=model, trainfunc=loss, nPerSpeaker=cfg.nPerSpeaker, segment=cfg.segment)
+     model = SpeakerNet(model=model, trainfunc=loss, nPerSpeaker=cfg.nPerSpeaker)
+     # swin
+     optimizer = optim.AdamW(model.parameters(), eps=1e-8, betas=(0.9, 0.999),
+                             lr=cfg.lr, weight_decay=0.05)
+     # optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=0.000002)
+     # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 50, 70], gamma=0.1, last_epoch=-1)
+     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5,
+                                                      threshold=0.001, threshold_mode='rel',
+                                                      cooldown=0, min_lr=1e-5, eps=1e-08, verbose=True)
+     # scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=cfg.lr, max_lr=0.003, mode='triangular2',
+     #                                         step_size_up=12000, cycle_momentum=False)
+
+     if cfg.resume:
+         # ckpt = torch.load(cfg.resume, map_location="cpu")
+         ckpt = torch.load(cfg.resume)
+         model.load_state_dict(ckpt['model_state_dict'], strict=False)
+         # optimizer.load_state_dict(ckpt['optimizer_state_dict'])
+         # scheduler.load_state_dict(ckpt['scheduler_state_dict'])
+         # start_epoch = ckpt['epoch'] + 1
+         print("Checkpoint loaded!")
+         # print(model)
+
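+         # A fuller resume sketch (assumption: checkpoints produced by Trainer.save_model,
+         # which also store optimizer/scheduler state and the epoch; the path is a placeholder):
+         #     ckpt = torch.load("train_models/<checkpoint>", map_location="cpu")
+         #     model.load_state_dict(ckpt['model_state_dict'])
+         #     optimizer.load_state_dict(ckpt['optimizer_state_dict'])
+         #     scheduler.load_state_dict(ckpt['scheduler_state_dict'])
+         #     start_epoch = ckpt['epoch'] + 1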
+     # test, eval, train
+     trainer = Trainer(cfg, model, optimizer, scheduler, device)
+
+     # --------------- score ---------------
+     if cfg.score:
+         score_dir = os.path.join('score', cfg.name + "_" + datetime.now().strftime('%Y%m%d'))
+         if not os.path.exists(score_dir):
+             os.makedirs(score_dir)
+         score_file = os.path.join(score_dir, 'scores.txt')
+         trainer.scoretxt(score_file, 'data/voxsrc2021_blind.txt', 'data/voxsrc2021', cfg.eval_frames)
+         # trainer.scoretxt(score_file, cfg.dataset.test_list, cfg.dataset.test_path, cfg.eval_frames)
+     # --------------- eval ---------------
+     elif cfg.eval:
+         trainer.test(0, cfg.dataset.test_list, cfg.dataset.test_path, cfg.nDataLoaderThread, cfg.eval_frames)
+     else:
+         # --------------- training ---------------
+         train_dataset = train_dataset_loader(train_list=cfg.dataset.train_list,
+                                              augment=cfg.augment, musan_path=cfg.dataset.musan_path,
+                                              rir_path=cfg.dataset.rir_path, max_frames=cfg.max_frames,
+                                              segment=cfg.segment, train_path=cfg.dataset.train_path)
+
+         train_sampler = train_dataset_sampler(train_dataset, nPerSpeaker=cfg.nPerSpeaker,
+                                               max_seg_per_spk=cfg.max_seg_per_spk, batch_size=cfg.batch_size,
+                                               seed=cfg.seed)
+
+         # train_dataset = TrainDataset(train_list=cfg.dataset.train_list,
+         #                              augment=cfg.augment, musan_path=cfg.dataset.musan_path,
+         #                              rir_path=cfg.dataset.rir_path, max_frames=cfg.max_frames,
+         #                              train_path=cfg.dataset.train_path)
+
+         train_loader = torch.utils.data.DataLoader(
+             train_dataset,
+             batch_size=cfg.batch_size,
+             num_workers=cfg.nDataLoaderThread,
+             sampler=train_sampler,
+             pin_memory=False,
+             drop_last=True,
+         )
+
+         x, y = next(iter(train_loader))
+         print('x.shape:', x.shape, 'y.shape:', y.shape)
+         print('x.dtype:', x.dtype, 'y.dtype:', y.dtype)
+
+         summary(model, input_size=tuple(x.shape))
+
+         it = 0
+         min_eer = float("inf")
+         for epoch in range(start_epoch, cfg.max_epoch):
+             trainer.train(epoch, train_loader)
+             if epoch % cfg.test_interval == 0:
+                 eer = trainer.test(epoch, cfg.dataset.test_list, cfg.dataset.test_path, cfg.nDataLoaderThread,
+                                    cfg.eval_frames)
+                 scheduler.step(eer)
+                 # # -----CyclicLR-----
+                 # if eer < min_eer:
+                 #     min_eer = eer
+                 #     it = 0
+                 # else:
+                 #     it += 1
+                 #
+                 # if it >= 8:
+                 #     lr = cfg.lr * 0.1
+                 #     trainer.scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=lr, max_lr=cfg.lr,
+                 #                                                     mode='triangular2',
+                 #                                                     step_size_up=6000, cycle_momentum=False)
+                 #     it = 0
+                 # # -----CyclicLR-----
+                 trainer.save_model(epoch)
+
+     print("finished")
+
+
+ if __name__ == "__main__":
+     main()
+
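# Example invocations (assumed, not from this commit; the checkpoint path is a
# placeholder). Config names come from config.py and flags from get_args():
#     python train.py --config_name ECAPA_TDNNaam_cfg --wandb
#     python train.py --config_name ECAPA_TDNNaam_cfg --resume train_models/<checkpoint> --eval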
tuneThreshold.py ADDED
@@ -0,0 +1,62 @@
+ from sklearn import metrics
+ import numpy
+ from operator import itemgetter
+
+
+ def tuneThresholdfromScore(scores, labels, target_fa, target_fr=None):
+     fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=1)
+     fnr = 1 - tpr
+
+     tunedThreshold = []
+     if target_fr:
+         for tfr in target_fr:
+             idx = numpy.nanargmin(numpy.absolute(tfr - fnr))
+             tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]])
+
+     for tfa in target_fa:
+         idx = numpy.nanargmin(numpy.absolute(tfa - fpr))  # nanargmin returns the index of the minimum, ignoring NaNs
+         tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]])
+
+     idxE = numpy.nanargmin(numpy.absolute(fnr - fpr))
+     eer = max(fpr[idxE], fnr[idxE]) * 100  # EER: the operating point where FPR and FNR cross
+
+     return tunedThreshold, eer, fpr, fnr
+
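# A small worked example with toy scores (1 = same speaker, 0 = different): the
# EER is read off where the false-positive and false-negative curves cross.
_scores = [0.9, 0.8, 0.6, 0.4, 0.3, 0.1]
_labels = [1, 1, 0, 1, 0, 0]
_tuned, _eer, _fpr, _fnr = tuneThresholdfromScore(_scores, _labels, target_fa=[0.1])
print(_eer)  # ~33.3: at the crossover, one positive in three is rejected and one negative in three accepted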
+ # Creates a list of false-negative rates, a list of false-positive rates
+ # and a list of decision thresholds that give those error-rates.
+ def ComputeErrorRates(scores, labels):
+     sorted_indexes, thresholds = zip(*sorted([(index, threshold) for index, threshold in enumerate(scores)],
+                                              key=itemgetter(1)))
+     labels = [labels[i] for i in sorted_indexes]
+     fnrs = []  # running count of positives at or below each threshold (misses)
+     fprs = []  # running count of negatives at or below each threshold (correct rejections)
+
+     for i in range(0, len(labels)):
+         if i == 0:
+             fnrs.append(labels[i])
+             fprs.append(1 - labels[i])
+         else:
+             fnrs.append(fnrs[i-1] + labels[i])
+             fprs.append(fprs[i-1] + 1 - labels[i])
+
+     fnrs_norm = sum(labels)  # number of positive samples
+     fprs_norm = len(labels) - fnrs_norm  # number of negative samples
+     fnrs = [x / float(fnrs_norm) for x in fnrs]  # false rejection rate: fraction of positives misclassified
+     fprs = [1 - x / float(fprs_norm) for x in fprs]  # false acceptance rate: fraction of negatives misclassified
+
+     return fnrs, fprs, thresholds
+
+ # Computes the minimum of the detection cost function. The comments refer to
+ # equations in Section 3 of the NIST 2016 Speaker Recognition Evaluation Plan.
+ def ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa):
+     min_c_det = float("inf")
+     min_c_det_threshold = thresholds[0]
+     for i in range(0, len(fnrs)):
+         c_det = c_miss * fnrs[i] * p_target + c_fa * fprs[i] * (1 - p_target)
+         if c_det < min_c_det:
+             min_c_det = c_det
+             min_c_det_threshold = thresholds[i]
+
+     c_def = min(c_miss * p_target, c_fa * (1 - p_target))
+     min_dcf = min_c_det / c_def
+     return min_dcf, min_c_det_threshold
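# A minimal sketch chaining the two helpers, with the same cost parameters as
# base_config (p_target=0.05, c_miss=1, c_fa=1); scores/labels are toy values.
_fnrs, _fprs, _thresholds = ComputeErrorRates([0.9, 0.8, 0.6, 0.4, 0.3, 0.1], [1, 1, 0, 1, 0, 0])
_mindcf, _thr = ComputeMinDcf(_fnrs, _fprs, _thresholds, p_target=0.05, c_miss=1, c_fa=1)
print(_mindcf, _thr)  # normalized minimum detection cost and its operating threshold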