Spaces:
Runtime error
Runtime error
app.py
Browse files- SpeakerNet.py +283 -0
- app.py +77 -0
- config.py +517 -0
- dataloader.py +75 -0
- example/.DS_Store +0 -0
- example/speaker1-1.wav +0 -0
- example/speaker1-2.wav +0 -0
- example/speaker2-1.wav +0 -0
- example/speaker2-2.wav +0 -0
- example/speaker3-1.wav +0 -0
- example/speaker3-2.wav +0 -0
- example/speaker4-1.wav +0 -0
- example/speaker4-2.wav +0 -0
- example/speaker5-1.wav +0 -0
- example/speaker5-2.wav +0 -0
- requirements.txt +1 -0
- train.py +186 -0
- tuneThreshold.py +62 -0
SpeakerNet.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import numpy, sys, random
|
6 |
+
from DatasetLoader import test_dataset_loader
|
7 |
+
import importlib
|
8 |
+
import time, itertools
|
9 |
+
from utils.log import init_log
|
10 |
+
from tqdm import tqdm
|
11 |
+
import wandb
|
12 |
+
from tuneThreshold import *
|
13 |
+
|
14 |
+
|
15 |
+
class SpeakerNet(nn.Module):
|
16 |
+
|
17 |
+
def __init__(self, model, trainfunc, nPerSpeaker):
|
18 |
+
super(SpeakerNet, self).__init__()
|
19 |
+
|
20 |
+
self.model = model
|
21 |
+
self.loss = trainfunc
|
22 |
+
self.nPerSpeaker = nPerSpeaker
|
23 |
+
|
24 |
+
def forward(self, data, label=None):
|
25 |
+
|
26 |
+
data = data.reshape(-1, data.size()[-1])
|
27 |
+
outp = self.model(data)
|
28 |
+
|
29 |
+
if label == None:
|
30 |
+
return outp
|
31 |
+
|
32 |
+
else:
|
33 |
+
|
34 |
+
emb = outp.reshape(-1, self.nPerSpeaker, outp.size()[-1]).squeeze(1)
|
35 |
+
nloss, prec1 = self.loss(emb, label)
|
36 |
+
|
37 |
+
|
38 |
+
return nloss, prec1
|
39 |
+
|
40 |
+
|
41 |
+
class Trainer(object):
|
42 |
+
|
43 |
+
def __init__(self, cfg, model, optimizer, scheduler, device):
|
44 |
+
self.cfg = cfg
|
45 |
+
self.model = model
|
46 |
+
self.optimizer = optimizer
|
47 |
+
self.scheduler = scheduler
|
48 |
+
self.device = device
|
49 |
+
logging = init_log(cfg.save_dir)
|
50 |
+
self._print = logging.info
|
51 |
+
self.best = 0
|
52 |
+
self.test_eer = 0
|
53 |
+
self.test_mindcf = 0
|
54 |
+
self.best_model = []
|
55 |
+
|
56 |
+
def train(self, epoch, dataloader):
|
57 |
+
self.model.train()
|
58 |
+
pbar = tqdm(dataloader)
|
59 |
+
loss = 0
|
60 |
+
top1 = 0
|
61 |
+
index = 0
|
62 |
+
counter = 0
|
63 |
+
|
64 |
+
for data in pbar:
|
65 |
+
x, label = data[0].to(self.device), data[1].long().to(self.device)
|
66 |
+
nloss, prec1 = self.model(x, label)
|
67 |
+
|
68 |
+
self.optimizer.zero_grad()
|
69 |
+
nloss.backward()
|
70 |
+
self.optimizer.step()
|
71 |
+
# self.scheduler.step()
|
72 |
+
|
73 |
+
loss += nloss.detach().cpu().item()
|
74 |
+
top1 += prec1.detach().cpu().item()
|
75 |
+
index += x.size(0)
|
76 |
+
counter += 1
|
77 |
+
|
78 |
+
if self.cfg.wandb:
|
79 |
+
wandb.log({
|
80 |
+
"epoch": epoch,
|
81 |
+
"train_acc": top1 / counter,
|
82 |
+
"train_loss": loss / counter,
|
83 |
+
})
|
84 |
+
pbar.set_description("Train Epoch:%3d ,Tloss:%.3f, Tacc:%.3f" % (epoch, loss/counter, top1/counter))
|
85 |
+
|
86 |
+
# self.scheduler.step()
|
87 |
+
self._print('epoch:{} - train loss: {:.3f} and train acc: {:.3f} total sample: {}'.format(
|
88 |
+
epoch, loss/counter, top1/counter, index))
|
89 |
+
|
90 |
+
def test(self, epoch, test_list, test_path, nDataLoaderThread, eval_frames, num_eval=10):
|
91 |
+
|
92 |
+
self.model.eval()
|
93 |
+
feats = {}
|
94 |
+
|
95 |
+
# read all lines
|
96 |
+
with open(test_list) as f:
|
97 |
+
lines = f.readlines()
|
98 |
+
files = list(itertools.chain(*[x.strip().split()[-2:] for x in lines]))
|
99 |
+
setfiles = list(set(files))
|
100 |
+
setfiles.sort()
|
101 |
+
|
102 |
+
# Define test data loader
|
103 |
+
test_dataset = test_dataset_loader(setfiles, test_path, eval_frames=eval_frames, num_eval=num_eval)
|
104 |
+
|
105 |
+
test_loader = torch.utils.data.DataLoader(
|
106 |
+
test_dataset,
|
107 |
+
batch_size=1,
|
108 |
+
shuffle=False,
|
109 |
+
num_workers=nDataLoaderThread,
|
110 |
+
drop_last=False,
|
111 |
+
sampler=None
|
112 |
+
)
|
113 |
+
|
114 |
+
# Extract features for every wav
|
115 |
+
for idx, data in enumerate(tqdm(test_loader)):
|
116 |
+
|
117 |
+
inp1 = data[0][0].to(self.device) # (data[0]:(1,10,1024),data[1]:'id10270/GWXujl-xAVM/00017.wav')
|
118 |
+
with torch.no_grad():
|
119 |
+
ref_feat = self.model(inp1).detach().cpu()
|
120 |
+
|
121 |
+
feats[data[1][0]] = ref_feat
|
122 |
+
|
123 |
+
all_scores = []
|
124 |
+
all_labels = []
|
125 |
+
all_trials = []
|
126 |
+
|
127 |
+
# Read files and compute all scores
|
128 |
+
for idx, line in enumerate(tqdm(lines)):
|
129 |
+
|
130 |
+
data = line.split()
|
131 |
+
|
132 |
+
# Append random label if missing
|
133 |
+
if len(data) == 2:
|
134 |
+
data = [random.randint(0, 1)] + data
|
135 |
+
ref_feat = feats[data[1]].to(self.device)
|
136 |
+
com_feat = feats[data[2]].to(self.device)
|
137 |
+
|
138 |
+
if self.model.loss.test_normalize:
|
139 |
+
ref_feat = F.normalize(ref_feat, p=2, dim=1)
|
140 |
+
com_feat = F.normalize(com_feat, p=2, dim=1)
|
141 |
+
|
142 |
+
# dist = F.pairwise_distance(ref_feat.unsqueeze(-1),
|
143 |
+
# com_feat.unsqueeze(-1).transpose(0, 2)).detach().cpu().numpy()
|
144 |
+
#
|
145 |
+
# score = -1 * numpy.mean(dist)
|
146 |
+
dist = F.cosine_similarity(ref_feat.unsqueeze(-1),
|
147 |
+
com_feat.unsqueeze(-1).transpose(0, 2)).detach().cpu().numpy()
|
148 |
+
score = numpy.mean(dist)
|
149 |
+
|
150 |
+
all_scores.append(score)
|
151 |
+
all_labels.append(int(data[0]))
|
152 |
+
all_trials.append(data[1] + " " + data[2])
|
153 |
+
|
154 |
+
result = tuneThresholdfromScore(all_scores, all_labels, [1, 0.1])
|
155 |
+
fnrs, fprs, thresholds = ComputeErrorRates(all_scores, all_labels)
|
156 |
+
mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, self.cfg.dcf_p_target, self.cfg.dcf_c_miss, self.cfg.dcf_c_fa)
|
157 |
+
self.test_eer = result[1]
|
158 |
+
self.test_mindcf = mindcf
|
159 |
+
self.threshold = threshold
|
160 |
+
if self.cfg.wandb:
|
161 |
+
wandb.log({
|
162 |
+
"test_eer": self.test_eer,
|
163 |
+
"test_MinDCF": self.test_mindcf,
|
164 |
+
})
|
165 |
+
self._print('epoch:{} - test EER: {:.3f} and test MinDCF: {:.3f} total sample: {} threshold: {:.3f}'.format(
|
166 |
+
epoch, self.test_eer, self.test_mindcf, len(lines), self.threshold))
|
167 |
+
|
168 |
+
return self.test_eer
|
169 |
+
|
170 |
+
def save_model(self, epoch):
|
171 |
+
if self.test_eer < self.best or self.best == 0:
|
172 |
+
self.best = self.test_eer
|
173 |
+
if self.cfg.wandb:
|
174 |
+
wandb.run.summary["best_accuracy"] = self.best
|
175 |
+
model_state_dict = self.model.state_dict()
|
176 |
+
optimizer_state_dict = self.optimizer.state_dict()
|
177 |
+
scheduler_state_dict = self.scheduler.state_dict()
|
178 |
+
file_save_path = 'epoch:%d,EER:%.4f,MinDCF:%.4f' % (epoch, self.test_eer, self.test_mindcf)
|
179 |
+
if not os.path.exists(self.cfg.save_dir):
|
180 |
+
os.mkdir(self.cfg.save_dir)
|
181 |
+
torch.save({
|
182 |
+
'epoch': epoch,
|
183 |
+
'test_eer': self.test_eer,
|
184 |
+
'test_mindcf': self.test_mindcf,
|
185 |
+
'model_state_dict': model_state_dict,
|
186 |
+
'optimizer_state_dict': optimizer_state_dict,
|
187 |
+
'scheduler_state_dict': scheduler_state_dict},
|
188 |
+
os.path.join(self.cfg.save_dir, file_save_path))
|
189 |
+
self.best_model.append(file_save_path)
|
190 |
+
if len(self.best_model) > 3:
|
191 |
+
del_file = os.path.join(self.cfg.save_dir, self.best_model.pop(0))
|
192 |
+
if os.path.exists(del_file):
|
193 |
+
os.remove(del_file)
|
194 |
+
else:
|
195 |
+
print("no exists {}".format(del_file))
|
196 |
+
# 每20个epoch保存一下
|
197 |
+
if epoch % 20 == 0:
|
198 |
+
model_state_dict = self.model.state_dict()
|
199 |
+
optimizer_state_dict = self.optimizer.state_dict()
|
200 |
+
scheduler_state_dict = self.scheduler.state_dict()
|
201 |
+
file_save_path = 'epoch:%d,EER:%.4f,MinDCF:%.4f' % (epoch, self.test_eer, self.test_mindcf)
|
202 |
+
if not os.path.exists(self.cfg.save_dir):
|
203 |
+
os.mkdir(self.cfg.save_dir)
|
204 |
+
if not os.path.exists(os.path.join(self.cfg.save_dir, file_save_path)):
|
205 |
+
torch.save({
|
206 |
+
'epoch': epoch,
|
207 |
+
'test_eee': self.test_eer,
|
208 |
+
'test_mindcf': self.test_mindcf,
|
209 |
+
'model_state_dict': model_state_dict,
|
210 |
+
'optimizer_state_dict': optimizer_state_dict,
|
211 |
+
'scheduler_state_dict': scheduler_state_dict},
|
212 |
+
os.path.join(self.cfg.save_dir, file_save_path))
|
213 |
+
|
214 |
+
def scoretxt(self, score_file, test_list, test_path, eval_frames, num_eval=10):
|
215 |
+
|
216 |
+
self.model.eval()
|
217 |
+
feats = {}
|
218 |
+
|
219 |
+
# read all lines
|
220 |
+
with open(test_list) as f:
|
221 |
+
lines = f.readlines()
|
222 |
+
files = list(itertools.chain(*[x.strip().split()[-2:] for x in lines]))
|
223 |
+
setfiles = list(set(files))
|
224 |
+
setfiles.sort()
|
225 |
+
|
226 |
+
# Define test data loader
|
227 |
+
test_dataset = test_dataset_loader(setfiles, test_path, eval_frames=eval_frames, num_eval=num_eval)
|
228 |
+
|
229 |
+
test_loader = torch.utils.data.DataLoader(
|
230 |
+
test_dataset,
|
231 |
+
batch_size=1,
|
232 |
+
shuffle=False,
|
233 |
+
drop_last=False,
|
234 |
+
sampler=None
|
235 |
+
)
|
236 |
+
|
237 |
+
# Extract features for every wav
|
238 |
+
for idx, data in enumerate(tqdm(test_loader)):
|
239 |
+
|
240 |
+
inp1 = data[0][0].to(self.device) # (data[0]:(1,10,1024),data[1]:'id10270/GWXujl-xAVM/00017.wav')
|
241 |
+
with torch.no_grad():
|
242 |
+
ref_feat = self.model(inp1).detach().cpu()
|
243 |
+
|
244 |
+
feats[data[1][0]] = ref_feat
|
245 |
+
|
246 |
+
|
247 |
+
f = open(score_file, "w")
|
248 |
+
# Read files and compute all scores
|
249 |
+
for idx, line in enumerate(tqdm(lines)):
|
250 |
+
|
251 |
+
data = line.split()
|
252 |
+
|
253 |
+
# Append random label if missing
|
254 |
+
ref_feat = feats[data[-2]].to(self.device)
|
255 |
+
com_feat = feats[data[-1]].to(self.device)
|
256 |
+
|
257 |
+
if self.model.loss.test_normalize:
|
258 |
+
ref_feat = F.normalize(ref_feat, p=2, dim=1)
|
259 |
+
com_feat = F.normalize(com_feat, p=2, dim=1)
|
260 |
+
|
261 |
+
# dist = F.pairwise_distance(ref_feat.unsqueeze(-1),
|
262 |
+
# com_feat.unsqueeze(-1).transpose(0, 2)).detach().cpu().numpy()
|
263 |
+
#
|
264 |
+
# score = -1 * numpy.mean(dist)
|
265 |
+
dist = F.cosine_similarity(ref_feat.unsqueeze(-1),
|
266 |
+
com_feat.unsqueeze(-1).transpose(0, 2)).detach().cpu().numpy()
|
267 |
+
score = numpy.mean(dist)
|
268 |
+
|
269 |
+
score_line = str(score) + " " + data[-2] + " " + data[-1]
|
270 |
+
f.write(score_line+'\n')
|
271 |
+
f.close()
|
272 |
+
|
273 |
+
|
274 |
+
|
275 |
+
|
276 |
+
|
277 |
+
|
278 |
+
|
279 |
+
|
280 |
+
|
281 |
+
|
282 |
+
|
283 |
+
|
app.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import net
|
4 |
+
import argparse
|
5 |
+
from config import set_cfg, cfg
|
6 |
+
from SpeakerNet import *
|
7 |
+
import lossfunction
|
8 |
+
from DatasetLoader import loadWAV
|
9 |
+
|
10 |
+
parser = argparse.ArgumentParser()
|
11 |
+
parser.add_argument("--config_name", type=str, default="ECAPA_TDNN_data_cfg", help="the configs name that will as a base configs")
|
12 |
+
parser.add_argument("--resume", default="train_models/epoch_37_ECAPA_TDNN2.48", type=str, help="resume path")
|
13 |
+
args = parser.parse_args()
|
14 |
+
global cfg
|
15 |
+
assert args.config_name is not None
|
16 |
+
if args.config_name:
|
17 |
+
set_cfg(args.config_name)
|
18 |
+
cfg.replace(vars(args))
|
19 |
+
del args
|
20 |
+
|
21 |
+
device = torch.device("cpu")
|
22 |
+
model = getattr(net, cfg.model)().to(device)
|
23 |
+
loss = getattr(lossfunction, cfg.loss)(cfg.nOut, cfg.nClasses).to(device)
|
24 |
+
model = SpeakerNet(model=model, trainfunc=loss, nPerSpeaker=cfg.nPerSpeaker)
|
25 |
+
|
26 |
+
ckpt = torch.load("train_models/epoch_37_ECAPA_TDNN2.48", map_location="cpu")
|
27 |
+
model.load_state_dict(ckpt['model_state_dict'], strict=False)
|
28 |
+
print("checkpoint加载完毕!")
|
29 |
+
|
30 |
+
model.eval()
|
31 |
+
|
32 |
+
def SpeakerVerification(path1,path2):
|
33 |
+
inp1 = loadWAV(path1, max_frames=300, evalmode=True)
|
34 |
+
inp2 = loadWAV(path2, max_frames=300, evalmode=True)
|
35 |
+
# print(inp1.shape)
|
36 |
+
inp1 = torch.FloatTensor(inp1)
|
37 |
+
inp2 = torch.FloatTensor(inp2)
|
38 |
+
# print(inp1.shape)
|
39 |
+
with torch.no_grad():
|
40 |
+
emb1 = model(inp1).detach().cpu()
|
41 |
+
emb2 = model(inp2).detach().cpu()
|
42 |
+
emb1 = F.normalize(emb1, p=2, dim=1)
|
43 |
+
emb2 = F.normalize(emb2, p=2, dim=1)
|
44 |
+
dist = F.cosine_similarity(emb1.unsqueeze(-1), emb2.unsqueeze(-1).transpose(0, 2)).numpy()
|
45 |
+
score = numpy.mean(dist)
|
46 |
+
print(score)
|
47 |
+
# threshold = 0.414
|
48 |
+
if score >= 0.414:
|
49 |
+
output = "同一个人"
|
50 |
+
else:
|
51 |
+
output = "不同的人"
|
52 |
+
|
53 |
+
return output
|
54 |
+
|
55 |
+
inputs = [
|
56 |
+
gr.inputs.Audio(source="upload", type="filepath", label="Speaker #1", optional=False),
|
57 |
+
gr.inputs.Audio(source="upload", type="filepath", label="Speaker #2", optional=False)
|
58 |
+
]
|
59 |
+
|
60 |
+
|
61 |
+
examples = [["example/speaker1-1.wav", "example/speaker1-2.wav"],
|
62 |
+
["example/speaker1-1.wav", "example/speaker2-1.wav"],
|
63 |
+
["example/speaker2-1.wav", "example/speaker2-2.wav"],
|
64 |
+
["example/speaker1-2.wav", "example/speaker2-2.wav"],
|
65 |
+
["example/speaker3-1.wav", "example/speaker3-2.wav"],
|
66 |
+
["example/speaker3-1.wav", "example/speaker4-1.wav"],
|
67 |
+
["example/speaker4-1.wav", "example/speaker4-2.wav"],
|
68 |
+
["example/speaker3-2.wav", "example/speaker4-2.wav"],
|
69 |
+
["example/speaker4-1.wav", "example/speaker5-2.wav"],
|
70 |
+
]
|
71 |
+
|
72 |
+
iface = gr.Interface(fn=SpeakerVerification, inputs=inputs, outputs="text", examples=examples)
|
73 |
+
iface.launch(share=True)
|
74 |
+
|
75 |
+
if __name__ == '__main__':
|
76 |
+
# print(SpeakerVerification("example/speaker1-1.wav", "example/speaker1-2.wav"))
|
77 |
+
pass
|
config.py
ADDED
@@ -0,0 +1,517 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class Config(object):
|
2 |
+
def __init__(self, config_dict: dict):
|
3 |
+
for key, val in config_dict.items():
|
4 |
+
if val is not None:
|
5 |
+
self.__setattr__(key, val)
|
6 |
+
|
7 |
+
def copy(self, new_config_dict={}):
|
8 |
+
ret = Config(vars(self))
|
9 |
+
for key, val in new_config_dict.items():
|
10 |
+
if val is not None:
|
11 |
+
ret.__setattr__(key, val)
|
12 |
+
return ret
|
13 |
+
|
14 |
+
def replace(self, new_config_dict):
|
15 |
+
if isinstance(new_config_dict, Config):
|
16 |
+
new_config_dict = vars(new_config_dict)
|
17 |
+
|
18 |
+
for key, val in new_config_dict.items():
|
19 |
+
if val is not None:
|
20 |
+
self.__setattr__(key, val)
|
21 |
+
|
22 |
+
def print(self):
|
23 |
+
for k, v in vars(self).items():
|
24 |
+
print(k, '=', v)
|
25 |
+
|
26 |
+
# def parser_val(self, val):
|
27 |
+
# if isinstance(val, dict):
|
28 |
+
# return Config(val)
|
29 |
+
# elif isinstance(val, list):
|
30 |
+
# for i in range(len(val)):
|
31 |
+
# if val is not None:
|
32 |
+
# val[i] = self.parser_val(val[i])
|
33 |
+
# return val
|
34 |
+
# else:
|
35 |
+
# return val
|
36 |
+
|
37 |
+
def __str__(self):
|
38 |
+
return str(vars(self))
|
39 |
+
|
40 |
+
|
41 |
+
base_config = Config({
|
42 |
+
"project": "speaker_verification",
|
43 |
+
"name": "VGGVox",
|
44 |
+
"save_dir": "train_models/",
|
45 |
+
"resume": "",
|
46 |
+
|
47 |
+
# Training and test data
|
48 |
+
"dataset": Config({
|
49 |
+
"name": "voxceleb2_wav",
|
50 |
+
"train_list": "data/train_list.txt",
|
51 |
+
"test_list": "data/veri_list.txt",
|
52 |
+
"train_path": "data/voxceleb2",
|
53 |
+
"test_path": "data/voxceleb1",
|
54 |
+
"musan_path": "data/musan_split", # 噪声文件
|
55 |
+
"rir_path": "data/RIRS_NOISES/simulated_rirs", # 混响文件
|
56 |
+
}),
|
57 |
+
|
58 |
+
|
59 |
+
# Data loader
|
60 |
+
"max_frames": 300, # 训练时帧长
|
61 |
+
"eval_frames": 300,
|
62 |
+
"batch_size": 64,
|
63 |
+
"max_seg_per_spk": 500, # 每个说话人最大的语音段数
|
64 |
+
"nDataLoaderThread": 16, # 多线程加载
|
65 |
+
"augment": True, # 是否数据增强
|
66 |
+
"seed": 10,
|
67 |
+
"segment": 1,
|
68 |
+
|
69 |
+
# Training details
|
70 |
+
"test_interval": 1, # 测试间隔
|
71 |
+
"max_epoch": 500,
|
72 |
+
|
73 |
+
# Model definition
|
74 |
+
"n_mels": 40,
|
75 |
+
"log_input": False,
|
76 |
+
"model": "Vgg",
|
77 |
+
"encoder_type": "SAP",
|
78 |
+
"nOut": 512,
|
79 |
+
|
80 |
+
# Loss functions
|
81 |
+
"loss": "SoftmaxProto", # lossfunction function
|
82 |
+
"hard_prob": 0.5,
|
83 |
+
"hard_rank": 10,
|
84 |
+
"margin": 0.2,
|
85 |
+
"scale": 30,
|
86 |
+
"nPerSpeaker": 2, # 同一段语音取多少组
|
87 |
+
"nClasses": 5994,
|
88 |
+
|
89 |
+
# Optimizer
|
90 |
+
"optimizer": "adam",
|
91 |
+
"scheduler": "steplr",
|
92 |
+
"lr": 0.001,
|
93 |
+
"lr_decay": 0.95,
|
94 |
+
"weight_decay": 0,
|
95 |
+
|
96 |
+
# Evaluation parameters
|
97 |
+
"dcf_p_target": 0.05,
|
98 |
+
"dcf_c_miss": 1,
|
99 |
+
"dcf_c_fa": 1,
|
100 |
+
|
101 |
+
# eval
|
102 |
+
"eval": False,
|
103 |
+
})
|
104 |
+
|
105 |
+
cfg = base_config
|
106 |
+
|
107 |
+
vgg_cfg = Config({
|
108 |
+
"name": "vgg_spectrogram",
|
109 |
+
"model": "vgg",
|
110 |
+
"batch_size": 64,
|
111 |
+
"nPerSpeaker": 2,
|
112 |
+
})
|
113 |
+
|
114 |
+
Unet_cfg = Config({
|
115 |
+
"name": "Unet",
|
116 |
+
"model": "UNetVgg",
|
117 |
+
"batch_size": 48,
|
118 |
+
"nPerSpeaker": 2,
|
119 |
+
"loss": "Unetloss"
|
120 |
+
})
|
121 |
+
|
122 |
+
UnetMask_cfg = Config({
|
123 |
+
"name": "UnetMask",
|
124 |
+
"model": "UNetVggMask",
|
125 |
+
"batch_size": 16,
|
126 |
+
"nPerSpeaker": 2,
|
127 |
+
"segment": 3,
|
128 |
+
"loss": "UnetMaskloss"
|
129 |
+
})
|
130 |
+
|
131 |
+
ECAPA_TDNN_cfg = Config({
|
132 |
+
"name": "ECAPA_TDNNm",
|
133 |
+
"model": "ECAPA_TDNN",
|
134 |
+
"loss": "AamSoftmaxProto",
|
135 |
+
"batch_size": 180,
|
136 |
+
"nPerSpeaker": 2,
|
137 |
+
"nOut": 192,
|
138 |
+
})
|
139 |
+
|
140 |
+
ECAPA_TDNNm_cfg = Config({
|
141 |
+
"name": "ECAPA_TDNNm",
|
142 |
+
"model": "ECAPA_TDNN",
|
143 |
+
"batch_size": 180,
|
144 |
+
"nPerSpeaker": 2,
|
145 |
+
"nOut": 192,
|
146 |
+
})
|
147 |
+
|
148 |
+
ECAPA_TDNN1024_cfg = Config({
|
149 |
+
"name": "ECAPA_TDNN1024",
|
150 |
+
"model": "ECAPA_TDNN",
|
151 |
+
"batch_size": 80,
|
152 |
+
"nPerSpeaker": 2,
|
153 |
+
"channels": 1024,
|
154 |
+
"nOut": 192,
|
155 |
+
})
|
156 |
+
|
157 |
+
ECAPA_TDNN_ks5_cfg = Config({
|
158 |
+
"name": "ECAPA_TDNN_ks5",
|
159 |
+
"model": "ECAPA_TDNN_ks5",
|
160 |
+
"batch_size": 180,
|
161 |
+
"nPerSpeaker": 2,
|
162 |
+
"nOut": 192,
|
163 |
+
})
|
164 |
+
|
165 |
+
ECAPA_TDNN_L2_cfg = Config({
|
166 |
+
"name": "ECAPA_TDNN_L2_pre",
|
167 |
+
"model": "ECAPA_TDNN_L2",
|
168 |
+
"batch_size": 180,
|
169 |
+
"nPerSpeaker": 2,
|
170 |
+
"nOut": 192,
|
171 |
+
"resume": "train_models/speaker_verification_ECAPA_TDNN/20210725/epoch:47,EER:2.5981,MinDCF:0.1912"
|
172 |
+
})
|
173 |
+
|
174 |
+
ECAPA_TDNN_br_cfg = Config({
|
175 |
+
"name": "ECAPA_TDNN_br",
|
176 |
+
"model": "ECAPA_TDNN_br",
|
177 |
+
"batch_size": 180,
|
178 |
+
"nPerSpeaker": 2,
|
179 |
+
"nOut": 192,
|
180 |
+
})
|
181 |
+
|
182 |
+
ECAPATDNN_cfg = Config({
|
183 |
+
"name": "ECAPATDNN",
|
184 |
+
"model": "ECAPATDNN",
|
185 |
+
"batch_size": 110,
|
186 |
+
"nPerSpeaker": 2,
|
187 |
+
"nOut": 192,
|
188 |
+
"input_size": 80,
|
189 |
+
})
|
190 |
+
|
191 |
+
HRNet_cfg = Config({
|
192 |
+
"name": "hrnet",
|
193 |
+
"model": "hrnet",
|
194 |
+
"max_frames": 224,
|
195 |
+
"eval_frames": 224,
|
196 |
+
"batch_size": 48,
|
197 |
+
"nPerSpeaker": 2,
|
198 |
+
"nOut": 1024,
|
199 |
+
"input_size": 224*224,
|
200 |
+
|
201 |
+
"model_cfg": Config({
|
202 |
+
"hrnet_name": "w48",
|
203 |
+
"STAGE1": {
|
204 |
+
"NUM_MODULES": 1,
|
205 |
+
"NUM_BRANCHES": 1,
|
206 |
+
"BLOCK": "BOTTLENECK",
|
207 |
+
"NUM_BLOCKS": [4],
|
208 |
+
"NUM_CHANNELS": [64],
|
209 |
+
"FUSE_METHOD": "SUM"
|
210 |
+
},
|
211 |
+
"STAGE2": {
|
212 |
+
"NUM_MODULES": 1,
|
213 |
+
"NUM_BRANCHES": 2,
|
214 |
+
"BLOCK": "BASIC",
|
215 |
+
"NUM_BLOCKS": [4, 4],
|
216 |
+
"NUM_CHANNELS": [18, 36],
|
217 |
+
"FUSE_METHOD": "SUM"
|
218 |
+
},
|
219 |
+
"STAGE3": {
|
220 |
+
"NUM_MODULES": 4,
|
221 |
+
"NUM_BRANCHES": 3,
|
222 |
+
"BLOCK": "BASIC",
|
223 |
+
"NUM_BLOCKS": [4, 4, 4],
|
224 |
+
"NUM_CHANNELS": [18, 36, 72],
|
225 |
+
"FUSE_METHOD": "SUM"
|
226 |
+
},
|
227 |
+
"STAGE4": {
|
228 |
+
"NUM_MODULES": 3,
|
229 |
+
"NUM_BRANCHES": 4,
|
230 |
+
"BLOCK": "BASIC",
|
231 |
+
"NUM_BLOCKS": [4, 4, 4, 4],
|
232 |
+
"NUM_CHANNELS": [18, 36, 72, 144],
|
233 |
+
"FUSE_METHOD": "SUM"
|
234 |
+
},
|
235 |
+
}),
|
236 |
+
|
237 |
+
})
|
238 |
+
|
239 |
+
VGG_TDNN_cfg = Config({
|
240 |
+
"name": "Vggtdnn1",
|
241 |
+
"model": "Vggtdnn",
|
242 |
+
"batch_size": 256,
|
243 |
+
"nOut": 512,
|
244 |
+
"nDataLoaderThread": 16,
|
245 |
+
})
|
246 |
+
|
247 |
+
ResNetSE34V2_cfg = Config({
|
248 |
+
"name": "ResNetSE34V2",
|
249 |
+
"model": "ResNetSE34V2",
|
250 |
+
"batch_size": 128,
|
251 |
+
"nOut": 512,
|
252 |
+
"nDataLoaderThread": 16,
|
253 |
+
})
|
254 |
+
|
255 |
+
HRTDNN_cfg = Config({
|
256 |
+
"name": "hrtdnn",
|
257 |
+
"model": "hrtdnn",
|
258 |
+
"max_frames": 300,
|
259 |
+
"eval_frames": 300,
|
260 |
+
"batch_size": 96,
|
261 |
+
"nPerSpeaker": 2,
|
262 |
+
"nOut": 256,
|
263 |
+
|
264 |
+
"model_cfg": Config({
|
265 |
+
"hrnet_name": "hrtdnn",
|
266 |
+
"STAGE1": {
|
267 |
+
"NUM_BRANCHES": 1,
|
268 |
+
"BLOCK": 'TDNNBlock',
|
269 |
+
"NUM_CHANNELS": [128],
|
270 |
+
"FUSE_METHOD": "SUM"
|
271 |
+
},
|
272 |
+
"STAGE2": {
|
273 |
+
"NUM_BRANCHES": 2,
|
274 |
+
"BLOCK": 'TDNNBlock',
|
275 |
+
"NUM_CHANNELS": [128, 512],
|
276 |
+
"FUSE_METHOD": "SUM"
|
277 |
+
},
|
278 |
+
"STAGE3": {
|
279 |
+
"NUM_BRANCHES": 3,
|
280 |
+
"BLOCK": 'TDNNBlock',
|
281 |
+
"NUM_CHANNELS": [128, 512, 1024],
|
282 |
+
"FUSE_METHOD": "SUM"
|
283 |
+
},
|
284 |
+
|
285 |
+
}),
|
286 |
+
|
287 |
+
})
|
288 |
+
|
289 |
+
ResTDNN_cfg = Config({
|
290 |
+
"name": "ResTDNN",
|
291 |
+
"model": "ResTDNN",
|
292 |
+
"batch_size": 110,
|
293 |
+
"nOut": 256,
|
294 |
+
"nDataLoaderThread": 16,
|
295 |
+
})
|
296 |
+
|
297 |
+
TDNN_VGG_cfg = Config({
|
298 |
+
"name": "TDNN_VGG",
|
299 |
+
"model": "TDNN_VGG",
|
300 |
+
"batch_size": 64,
|
301 |
+
"nOut": 256,
|
302 |
+
"nDataLoaderThread": 16,
|
303 |
+
})
|
304 |
+
|
305 |
+
ResNet_TDNN_cfg = Config({
|
306 |
+
"name": "ResNet_TDNN",
|
307 |
+
"model": "ResNet_TDNN",
|
308 |
+
"batch_size": 96,
|
309 |
+
"nOut": 192,
|
310 |
+
"nDataLoaderThread": 16,
|
311 |
+
})
|
312 |
+
|
313 |
+
ResNet_TDNNa_cfg = Config({
|
314 |
+
"name": "ResNet_TDNNa",
|
315 |
+
"model": "ResNet_TDNN",
|
316 |
+
"batch_size": 96,
|
317 |
+
"nOut": 192,
|
318 |
+
"nDataLoaderThread": 16,
|
319 |
+
})
|
320 |
+
|
321 |
+
ResNet_TDNNaam_cfg = Config({
|
322 |
+
"name": "ResNet_TDNNaam",
|
323 |
+
"model": "ResNet_TDNN",
|
324 |
+
"loss": "AamSoftmaxProto",
|
325 |
+
"margin": 0.2,
|
326 |
+
"scale": 30,
|
327 |
+
"batch_size": 96,
|
328 |
+
"nOut": 192,
|
329 |
+
"nDataLoaderThread": 16,
|
330 |
+
"augment": True,
|
331 |
+
})
|
332 |
+
|
333 |
+
TDNN_ResNet_cfg = Config({
|
334 |
+
"name": "TDNN_ResNet",
|
335 |
+
"model": "TDNN_ResNet",
|
336 |
+
"batch_size": 48,
|
337 |
+
"nOut": 256,
|
338 |
+
"nDataLoaderThread": 16,
|
339 |
+
})
|
340 |
+
|
341 |
+
hr_tdnn_cfg = Config({
|
342 |
+
"name": "hr_tdnn",
|
343 |
+
"model": "hr_tdnn",
|
344 |
+
"batch_size": 46,
|
345 |
+
"nOut": 192,
|
346 |
+
"nDataLoaderThread": 16,
|
347 |
+
})
|
348 |
+
|
349 |
+
|
350 |
+
ECAPA_TDNNma_cfg = Config({
|
351 |
+
"name": "ECAPA_TDNNma",
|
352 |
+
"model": "ECAPA_TDNN",
|
353 |
+
"batch_size": 180,
|
354 |
+
"nPerSpeaker": 2,
|
355 |
+
"nOut": 192,
|
356 |
+
"augment": True,
|
357 |
+
})
|
358 |
+
|
359 |
+
ECAPA_TDNNaam_cfg = Config({
|
360 |
+
"name": "ECAPA_TDNNaam",
|
361 |
+
"model": "ECAPA_TDNN",
|
362 |
+
"loss": "AamSoftmax",
|
363 |
+
"batch_size": 360,
|
364 |
+
"nPerSpeaker": 1,
|
365 |
+
"nOut": 192,
|
366 |
+
"augment": True,
|
367 |
+
})
|
368 |
+
|
369 |
+
ECAPA_TDNNaam1_cfg = Config({
|
370 |
+
"name": "ECAPA_TDNNaam1",
|
371 |
+
"model": "ECAPA_TDNN",
|
372 |
+
"loss": "AdditiveAngularMargin",
|
373 |
+
"batch_size": 360,
|
374 |
+
"nPerSpeaker": 1,
|
375 |
+
"nOut": 192,
|
376 |
+
"augment": True,
|
377 |
+
})
|
378 |
+
|
379 |
+
ECAPA_TDNNaam2_cfg = Config({
|
380 |
+
"name": "ECAPA_TDNNaam2",
|
381 |
+
"model": "ECAPA_TDNN",
|
382 |
+
"loss": "AamSoftmax",
|
383 |
+
"margin": 0.2,
|
384 |
+
"scale": 30,
|
385 |
+
"batch_size": 360,
|
386 |
+
"nPerSpeaker": 1,
|
387 |
+
"nOut": 192,
|
388 |
+
"augment": True,
|
389 |
+
|
390 |
+
})
|
391 |
+
|
392 |
+
ECAPA_TDNNaam3_cfg = Config({
|
393 |
+
"name": "ECAPA_TDNNaam3",
|
394 |
+
"model": "ECAPA_TDNN",
|
395 |
+
"loss": "AamSoftmax",
|
396 |
+
"margin": 0.1,
|
397 |
+
"scale": 30,
|
398 |
+
"batch_size": 360,
|
399 |
+
"nPerSpeaker": 1,
|
400 |
+
"nOut": 192,
|
401 |
+
"augment": True,
|
402 |
+
|
403 |
+
})
|
404 |
+
|
405 |
+
ECAPA_TDNN_aamproto_cfg = Config({
|
406 |
+
"name": "ECAPA_TDNN_aamproto",
|
407 |
+
"model": "ECAPA_TDNN",
|
408 |
+
"loss": "AamSoftmaxProto",
|
409 |
+
"batch_size": 180,
|
410 |
+
"nPerSpeaker": 2,
|
411 |
+
"nOut": 192,
|
412 |
+
"augment": True,
|
413 |
+
})
|
414 |
+
|
415 |
+
ECAPA_TDNN_aamproto1_cfg = Config({
|
416 |
+
"name": "ECAPA_TDNN_aamproto1",
|
417 |
+
"model": "ECAPA_TDNN",
|
418 |
+
"loss": "AamSoftmaxProto",
|
419 |
+
"margin": 0.2,
|
420 |
+
"scale": 30,
|
421 |
+
"batch_size": 180,
|
422 |
+
"nPerSpeaker": 2,
|
423 |
+
"nOut": 192,
|
424 |
+
"augment": True,
|
425 |
+
})
|
426 |
+
|
427 |
+
ECAPA_TDNN0_cfg = Config({
|
428 |
+
"name": "ECAPA_TDNN-1lr0.001",
|
429 |
+
"model": "ECAPA_TDNN",
|
430 |
+
"loss": "AamSoftmax",
|
431 |
+
"batch_size": 360,
|
432 |
+
"nOut": 192,
|
433 |
+
"nPerSpeaker": 1,
|
434 |
+
"resume": "train_models/speaker_verification_ECAPA_TDNN0/20210928/epoch:25,EER:2.4125,MinDCF:0.1537",
|
435 |
+
})
|
436 |
+
|
437 |
+
SwinTransformer_cfg = Config({
|
438 |
+
"name": "SwinTransformer",
|
439 |
+
"model": "SwinTransformer",
|
440 |
+
"loss": "SoftmaxProto",
|
441 |
+
"max_frames": 224,
|
442 |
+
"eval_frames": 224,
|
443 |
+
"n_mels": 224,
|
444 |
+
"batch_size": 90,
|
445 |
+
"nPerSpeaker": 2,
|
446 |
+
"nOut": 192,
|
447 |
+
"augment": True,
|
448 |
+
"lr": 5e-5,
|
449 |
+
})
|
450 |
+
|
451 |
+
ECAPA_TDNN_aampre_cfg = Config({
|
452 |
+
"name": "ECAPA_TDNN_aampre",
|
453 |
+
"model": "ECAPA_TDNN",
|
454 |
+
"loss": "AamSoftmaxProto",
|
455 |
+
"batch_size": 180,
|
456 |
+
"nOut": 192,
|
457 |
+
"nPerSpeaker": 2,
|
458 |
+
"resume": "train_models/speaker_verification_ECAPA_TDNNma/20210908/epoch:67,EER:2.3224,MinDCF:0.1658",
|
459 |
+
})
|
460 |
+
|
461 |
+
# 更换dataloader
|
462 |
+
ECAPA_TDNN_data_cfg = Config({
|
463 |
+
"name": "ECAPA_TDNN_data",
|
464 |
+
"model": "ECAPA_TDNN",
|
465 |
+
"loss": "AamSoftmax",
|
466 |
+
"batch_size": 360,
|
467 |
+
"nPerSpeaker": 1,
|
468 |
+
"nOut": 192,
|
469 |
+
"augment": True,
|
470 |
+
|
471 |
+
})
|
472 |
+
|
473 |
+
# 标准的ECAPA_TDNN 学习率CyclicLR
|
474 |
+
ECAPA_TDNNaam_cyclr_cfg = Config({
|
475 |
+
"name": "ECAPA_TDNNaam_cyclr",
|
476 |
+
"model": "ECAPA_TDNN",
|
477 |
+
"loss": "AamSoftmax",
|
478 |
+
"margin": 0.2,
|
479 |
+
"scale": 30,
|
480 |
+
"batch_size": 360,
|
481 |
+
"nPerSpeaker": 1,
|
482 |
+
"nOut": 192,
|
483 |
+
"augment": True,
|
484 |
+
|
485 |
+
})
|
486 |
+
|
487 |
+
# 跟换数据加载的ResNet_TDNN只用softmax
|
488 |
+
ResNet_TDNNaam_data_cfg = Config({
|
489 |
+
"name": "ResNet_TDNNaam_data",
|
490 |
+
"model": "ResNet_TDNN",
|
491 |
+
"loss": "AamSoftmax",
|
492 |
+
"margin": 0.2,
|
493 |
+
"scale": 30,
|
494 |
+
"batch_size": 192,
|
495 |
+
"nOut": 192,
|
496 |
+
"nDataLoaderThread": 16,
|
497 |
+
"nPerSpeaker": 1,
|
498 |
+
"augment": True,
|
499 |
+
})
|
500 |
+
|
501 |
+
# 更换dataloader, 和cyclical lr
|
502 |
+
ECAPA_TDNN_dataClr_cfg = Config({
|
503 |
+
"name": "ECAPA_TDNN_dataClr",
|
504 |
+
"model": "ECAPA_TDNN",
|
505 |
+
"loss": "AamSoftmax",
|
506 |
+
"batch_size": 360,
|
507 |
+
"nPerSpeaker": 1,
|
508 |
+
"nOut": 192,
|
509 |
+
"augment": True,
|
510 |
+
})
|
511 |
+
|
512 |
+
def set_cfg(config_name: str):
|
513 |
+
""" Sets the active configs. Works even if cfg is already imported! """
|
514 |
+
global cfg
|
515 |
+
# Note this is not just an eval because I'm lazy, but also because it can
|
516 |
+
# be used like ssd300_config.copy({'max_size': 400}) for extreme fine-tuning
|
517 |
+
cfg.replace(eval(config_name))
|
dataloader.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
from DatasetLoader import AugmentWAV, loadWAV
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import random
|
7 |
+
|
8 |
+
|
9 |
+
class TrainDataset(Dataset):
|
10 |
+
def __init__(self, train_list, train_path, augment, musan_path, rir_path, max_frames,):
|
11 |
+
self.train_list = train_list
|
12 |
+
self.max_frames = max_frames
|
13 |
+
self.augment_wav = AugmentWAV(musan_path=musan_path, rir_path=rir_path, max_frames=max_frames)
|
14 |
+
self.augment = augment
|
15 |
+
self.musan_path = musan_path
|
16 |
+
self.rir_path = rir_path
|
17 |
+
|
18 |
+
with open(train_list) as dataset_file:
|
19 |
+
lines = dataset_file.readlines()
|
20 |
+
|
21 |
+
dictkeys = list(set([x.split()[0] for x in lines]))
|
22 |
+
dictkeys.sort()
|
23 |
+
dictkeys = {key: ii for ii, key in enumerate(dictkeys)}
|
24 |
+
|
25 |
+
np.random.seed(100)
|
26 |
+
np.random.shuffle(lines)
|
27 |
+
|
28 |
+
self.data_list = []
|
29 |
+
self.data_label = []
|
30 |
+
|
31 |
+
for lidx, line in enumerate(lines):
|
32 |
+
data = line.strip().split()
|
33 |
+
speaker_label = dictkeys[data[0]]
|
34 |
+
filename = os.path.join(train_path, data[1])
|
35 |
+
|
36 |
+
self.data_list.append(filename)
|
37 |
+
self.data_label.append(speaker_label)
|
38 |
+
|
39 |
+
def __getitem__(self, index):
|
40 |
+
|
41 |
+
audio = loadWAV(self.data_list[index], self.max_frames, evalmode=False)
|
42 |
+
if self.augment:
|
43 |
+
augtype = random.randint(0, 4) # 包括0,4
|
44 |
+
if augtype == 1:
|
45 |
+
audio = self.augment_wav.reverberate(audio)
|
46 |
+
elif augtype == 2:
|
47 |
+
audio = self.augment_wav.additive_noise('music', audio)
|
48 |
+
elif augtype == 3:
|
49 |
+
audio = self.augment_wav.additive_noise('speech', audio)
|
50 |
+
elif augtype == 4:
|
51 |
+
audio = self.augment_wav.additive_noise('noise', audio)
|
52 |
+
|
53 |
+
return torch.FloatTensor(audio), self.data_label[index]
|
54 |
+
|
55 |
+
def __len__(self):
|
56 |
+
return len(self.data_list)
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
train_dataset = TrainDataset(train_list="data/train_list.txt", augment=True,
|
61 |
+
musan_path="data/musan_split", rir_path="data/RIRS_NOISES/simulated_rirs",
|
62 |
+
max_frames=300, train_path="data/voxceleb2")
|
63 |
+
train_loader = torch.utils.data.DataLoader(
|
64 |
+
train_dataset,
|
65 |
+
batch_size=32,
|
66 |
+
pin_memory=False,
|
67 |
+
drop_last=True,
|
68 |
+
)
|
69 |
+
x, y = iter(train_loader).next()
|
70 |
+
print("x:", x.shape, "y:", y.shape)
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
|
example/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
example/speaker1-1.wav
ADDED
Binary file (277 kB). View file
|
|
example/speaker1-2.wav
ADDED
Binary file (247 kB). View file
|
|
example/speaker2-1.wav
ADDED
Binary file (202 kB). View file
|
|
example/speaker2-2.wav
ADDED
Binary file (169 kB). View file
|
|
example/speaker3-1.wav
ADDED
Binary file (102 kB). View file
|
|
example/speaker3-2.wav
ADDED
Binary file (112 kB). View file
|
|
example/speaker4-1.wav
ADDED
Binary file (132 kB). View file
|
|
example/speaker4-2.wav
ADDED
Binary file (415 kB). View file
|
|
example/speaker5-1.wav
ADDED
Binary file (113 kB). View file
|
|
example/speaker5-2.wav
ADDED
Binary file (120 kB). View file
|
|
requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
wandb
|
train.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import optim
|
2 |
+
import argparse
|
3 |
+
from datetime import datetime
|
4 |
+
import wandb
|
5 |
+
import torch.backends.cudnn as cudnn
|
6 |
+
from torch import optim
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from torchinfo import summary
|
9 |
+
from timm.scheduler.cosine_lr import CosineLRScheduler
|
10 |
+
|
11 |
+
import lossfunction
|
12 |
+
import net
|
13 |
+
from DatasetLoader import *
|
14 |
+
from dataloader import TrainDataset
|
15 |
+
from SpeakerNet import *
|
16 |
+
from config import set_cfg, cfg
|
17 |
+
|
18 |
+
|
19 |
+
def get_args():
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
|
22 |
+
parser.add_argument("--config_name", type=str, default="", help="the configs name that will as a base configs")
|
23 |
+
parser.add_argument("--project", default=None, type=str, help="project name")
|
24 |
+
parser.add_argument("--name", default=None, type=str, help="run name")
|
25 |
+
parser.add_argument("--save_dir", default=None, type=str, help="save path")
|
26 |
+
parser.add_argument("--resume", default=None, type=str, help="resume path")
|
27 |
+
parser.add_argument("--dataset", default=None, type=str, help="dataset path")
|
28 |
+
|
29 |
+
parser.add_argument("--epoch", default=None, type=int, help="max epoch")
|
30 |
+
parser.add_argument("--test_freq", default=None, type=int, help="frequency test epoch")
|
31 |
+
parser.add_argument("--batch_size", default=None, type=int, help="batch size")
|
32 |
+
parser.add_argument("--lr", default=None, type=float, help="learning rate")
|
33 |
+
parser.add_argument("--seed", default=None, type=int)
|
34 |
+
|
35 |
+
parser.add_argument("--wandb", action='store_true', default=False, help='use wandb to log ')
|
36 |
+
parser.add_argument("--note", type=str, default="", help='wandb note')
|
37 |
+
|
38 |
+
parser.add_argument('--eval', dest='eval', action='store_true', default=False, help='Eval only')
|
39 |
+
parser.add_argument('--score', dest='score', action='store_true', default=False, help='Eval only')
|
40 |
+
|
41 |
+
args = parser.parse_args()
|
42 |
+
return args
|
43 |
+
|
44 |
+
|
45 |
+
def main():
|
46 |
+
global cfg
|
47 |
+
args = get_args()
|
48 |
+
assert args.config_name is not None
|
49 |
+
if args.config_name:
|
50 |
+
set_cfg(args.config_name)
|
51 |
+
cfg.replace(vars(args))
|
52 |
+
del args
|
53 |
+
|
54 |
+
cfg.save_dir = os.path.join(cfg.save_dir, cfg.project + '_' + cfg.name, datetime.now().strftime('%Y%m%d'))
|
55 |
+
if not os.path.exists(cfg.save_dir):
|
56 |
+
os.makedirs(cfg.save_dir)
|
57 |
+
|
58 |
+
if cfg.wandb:
|
59 |
+
wandb.login(host="http://49.233.11.7:8080", key="local-7dc64cc63778f0723dc202d2624a97cef7043120")
|
60 |
+
wandb.init(project=cfg.project, name=cfg.name, config=cfg, save_code=True, notes=cfg.note)
|
61 |
+
|
62 |
+
# cudnn related setting
|
63 |
+
cudnn.benchmark = True
|
64 |
+
torch.backends.cudnn.deterministic = False
|
65 |
+
torch.backends.cudnn.enabled = True
|
66 |
+
|
67 |
+
start_epoch = 1
|
68 |
+
|
69 |
+
# ---------------模型---------------
|
70 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
71 |
+
# device = torch.device("cpu")
|
72 |
+
# model = getattr(net, cfg.model)(cfg.nOut, cfg.encoder_type, cfg.log_input).to(device)
|
73 |
+
# ------ECAPA_TDNN.yaml------ResNet_TDNN----
|
74 |
+
model = getattr(net, cfg.model)().to(device)
|
75 |
+
|
76 |
+
# loss = getattr(lossfunction, cfg.loss)(cfg.nOut, cfg.nClasses, cfg.margin, cfg.scale).to(device)
|
77 |
+
# ----aamsoftmax----
|
78 |
+
loss = getattr(lossfunction, cfg.loss)(cfg.nOut, cfg.nClasses).to(device)
|
79 |
+
|
80 |
+
# model = SpeakerUnet(model=model, trainfunc=loss, nPerSpeaker=cfg.nPerSpeaker, segment=cfg.segment)
|
81 |
+
model = SpeakerNet(model=model, trainfunc=loss, nPerSpeaker=cfg.nPerSpeaker)
|
82 |
+
# swin
|
83 |
+
optimizer = optim.AdamW(model.parameters(), eps=1e-8, betas=(0.9, 0.999),
|
84 |
+
lr=cfg.lr, weight_decay=0.05)
|
85 |
+
# optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=0.000002)
|
86 |
+
# scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 50, 70], gamma=0.1, last_epoch=-1)
|
87 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5,
|
88 |
+
threshold=0.001, threshold_mode='rel',
|
89 |
+
cooldown=0, min_lr=1e-5, eps=1e-08, verbose=True)
|
90 |
+
# scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=cfg.lr, max_lr=0.003, mode='triangular2',
|
91 |
+
# step_size_up=12000, cycle_momentum=False)
|
92 |
+
|
93 |
+
|
94 |
+
if cfg.resume:
|
95 |
+
# ckpt = torch.load(cfg.resume, map_location="cpu")
|
96 |
+
ckpt = torch.load(cfg.resume)
|
97 |
+
model.load_state_dict(ckpt['model_state_dict'], strict=False)
|
98 |
+
# optimizer.load_state_dict(ckpt['optimizer_state_dict'])
|
99 |
+
# scheduler.load_state_dict(ckpt['scheduler_state_dict'])
|
100 |
+
# start_epoch = ckpt['epoch'] + 1
|
101 |
+
print("checkpoint加载完毕!")
|
102 |
+
# print(model)
|
103 |
+
|
104 |
+
# test, eval, train
|
105 |
+
trainer = Trainer(cfg, model, optimizer, scheduler, device)
|
106 |
+
|
107 |
+
# ---------------score--------------
|
108 |
+
if cfg.score:
|
109 |
+
score_dir = os.path.join('score', cfg.name+"_"+datetime.now().strftime('%Y%m%d'))
|
110 |
+
if not os.path.exists(score_dir):
|
111 |
+
os.makedirs(score_dir)
|
112 |
+
score_file = os.path.join(score_dir, 'scores.txt')
|
113 |
+
trainer.scoretxt(score_file, 'data/voxsrc2021_blind.txt', 'data/voxsrc2021', cfg.eval_frames)
|
114 |
+
# trainer.scoretxt(score_file, cfg.dataset.test_list, cfg.dataset.test_path, cfg.eval_frames)
|
115 |
+
# ---------------eval--------------
|
116 |
+
elif cfg.eval:
|
117 |
+
trainer.test(0, cfg.dataset.test_list, cfg.dataset.test_path, cfg.nDataLoaderThread, cfg.eval_frames)
|
118 |
+
else:
|
119 |
+
# ---------------训练--------------
|
120 |
+
train_dataset = train_dataset_loader(train_list=cfg.dataset.train_list,
|
121 |
+
augment=cfg.augment, musan_path=cfg.dataset.musan_path,
|
122 |
+
rir_path=cfg.dataset.rir_path, max_frames=cfg.max_frames,
|
123 |
+
segment=cfg.segment, train_path=cfg.dataset.train_path)
|
124 |
+
|
125 |
+
train_sampler = train_dataset_sampler(train_dataset, nPerSpeaker=cfg.nPerSpeaker,
|
126 |
+
max_seg_per_spk=cfg.max_seg_per_spk, batch_size=cfg.batch_size,
|
127 |
+
seed=cfg.seed)
|
128 |
+
|
129 |
+
# train_dataset = TrainDataset(train_list=cfg.dataset.train_list,
|
130 |
+
# augment=cfg.augment, musan_path=cfg.dataset.musan_path,
|
131 |
+
# rir_path=cfg.dataset.rir_path, max_frames=cfg.max_frames,
|
132 |
+
# train_path=cfg.dataset.train_path)
|
133 |
+
|
134 |
+
train_loader = torch.utils.data.DataLoader(
|
135 |
+
train_dataset,
|
136 |
+
batch_size=cfg.batch_size,
|
137 |
+
num_workers=cfg.nDataLoaderThread,
|
138 |
+
sampler=train_sampler,
|
139 |
+
pin_memory=False,
|
140 |
+
drop_last=True,
|
141 |
+
)
|
142 |
+
|
143 |
+
x, y = iter(train_loader).next()
|
144 |
+
print('x.shape:', x.shape, 'y.shape:', y.shape)
|
145 |
+
print('x.dtype:', x.dtype, 'y.dtype:', y.dtype)
|
146 |
+
|
147 |
+
summary(model, input_size=(tuple(x.shape)))
|
148 |
+
|
149 |
+
it = 0
|
150 |
+
min_eer = float("inf")
|
151 |
+
for epoch in range(start_epoch, cfg.max_epoch):
|
152 |
+
trainer.train(epoch, train_loader)
|
153 |
+
if epoch % cfg.test_interval == 0:
|
154 |
+
eer = trainer.test(epoch, cfg.dataset.test_list, cfg.dataset.test_path, cfg.nDataLoaderThread,
|
155 |
+
cfg.eval_frames)
|
156 |
+
scheduler.step(eer)
|
157 |
+
# # -----Clr------
|
158 |
+
# if eer < min_eer:
|
159 |
+
# min_eer = eer
|
160 |
+
# it = 0
|
161 |
+
#
|
162 |
+
# else:
|
163 |
+
# it += 1
|
164 |
+
#
|
165 |
+
# if it >= 8:
|
166 |
+
# lr = cfg.lr * 0.1
|
167 |
+
# trainer.scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=lr, max_lr=cfg.lr,
|
168 |
+
# mode='triangular2',
|
169 |
+
# step_size_up=6000, cycle_momentum=False)
|
170 |
+
# it = 0
|
171 |
+
# # -----Clr------
|
172 |
+
trainer.save_model(epoch)
|
173 |
+
|
174 |
+
print("finishing")
|
175 |
+
|
176 |
+
|
177 |
+
if __name__ == "__main__":
|
178 |
+
main()
|
179 |
+
|
180 |
+
|
181 |
+
|
182 |
+
|
183 |
+
|
184 |
+
|
185 |
+
|
186 |
+
|
tuneThreshold.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sklearn import metrics
|
2 |
+
import numpy
|
3 |
+
from operator import itemgetter
|
4 |
+
|
5 |
+
|
6 |
+
def tuneThresholdfromScore(scores, labels, target_fa, target_fr=None):
|
7 |
+
fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=1)
|
8 |
+
fnr = 1 - tpr
|
9 |
+
|
10 |
+
tunedThreshold = []
|
11 |
+
if target_fr:
|
12 |
+
for tfr in target_fr:
|
13 |
+
idx = numpy.nanargmin(numpy.absolute((tfr - fnr)))
|
14 |
+
tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]])
|
15 |
+
|
16 |
+
for tfa in target_fa:
|
17 |
+
idx = numpy.nanargmin(numpy.absolute((tfa - fpr))) # numpy.where(fpr<=tfa)[0][-1] nanargmin 返回轴上最小的值忽略Nans
|
18 |
+
tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]])
|
19 |
+
|
20 |
+
idxE = numpy.nanargmin(numpy.absolute((fnr - fpr)))
|
21 |
+
eer = max(fpr[idxE], fnr[idxE]) * 100
|
22 |
+
|
23 |
+
return tunedThreshold, eer, fpr, fnr
|
24 |
+
|
25 |
+
# Creates a list of false-negative rates, a list of false-positive rates
|
26 |
+
# and a list of decision thresholds that give those error-rates.
|
27 |
+
def ComputeErrorRates(scores, labels):
|
28 |
+
sorted_indexes, thresholds = zip(*sorted([(index, threshold) for index, threshold in enumerate(scores)],
|
29 |
+
key=itemgetter(1)))
|
30 |
+
labels = [labels[i] for i in sorted_indexes]
|
31 |
+
fnrs = [] # 负样本接受
|
32 |
+
fprs = [] # 正样本接受
|
33 |
+
|
34 |
+
for i in range(0, len(labels)):
|
35 |
+
if i == 0:
|
36 |
+
fnrs.append(labels[i])
|
37 |
+
fprs.append(1 - labels[i])
|
38 |
+
else:
|
39 |
+
fnrs.append(fnrs[i-1] + labels[i])
|
40 |
+
fprs.append(fprs[i-1] + 1 - labels[i])
|
41 |
+
|
42 |
+
fnrs_norm = sum(labels) # 真正样本个数
|
43 |
+
fprs_norm = len(labels) - fnrs_norm # 负样本个数
|
44 |
+
fnrs = [x / float(fnrs_norm) for x in fnrs] # 错误的拒绝 正样本分错的比例
|
45 |
+
fprs = [1 - x / float(fprs_norm) for x in fprs] # 错误接受 负样本分错的比例
|
46 |
+
|
47 |
+
return fnrs, fprs, thresholds
|
48 |
+
|
49 |
+
# Computes the minimum of the detection cost function. The comments refer to
|
50 |
+
# equations in Section 3 of the NIST 2016 Speaker Recognition Evaluation Plan.
|
51 |
+
def ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa):
|
52 |
+
min_c_det = float("inf")
|
53 |
+
min_c_det_threshold = thresholds[0]
|
54 |
+
for i in range(0, len(fnrs)):
|
55 |
+
c_det = c_miss * fnrs[i] * p_target + c_fa * fprs[i] * (1 - p_target)
|
56 |
+
if c_det < min_c_det:
|
57 |
+
min_c_det = c_det
|
58 |
+
min_c_det_threshold = thresholds[i]
|
59 |
+
|
60 |
+
c_def = min(c_miss * p_target, c_fa * (1 - p_target))
|
61 |
+
min_dcf = min_c_det / c_def
|
62 |
+
return min_dcf, min_c_det_threshold
|