Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy, sys, random | |
from DatasetLoader import test_dataset_loader | |
import importlib | |
import time, itertools | |
from utils.log import init_log | |
from tqdm import tqdm | |
import wandb | |
from tuneThreshold import * | |
class SpeakerNet(nn.Module):
    """Wrap an embedding model together with its training loss.

    Args:
        model: Callable/module that maps a 2-D batch (one row per utterance
            segment) to embeddings.
        trainfunc: Loss callable invoked as ``trainfunc(emb, label)`` and
            returning ``(loss, precision)``.
        nPerSpeaker: Number of utterances per speaker in a training batch;
            used to regroup the flat embedding batch before the loss.
    """

    def __init__(self, model, trainfunc, nPerSpeaker):
        super().__init__()
        self.model = model
        self.loss = trainfunc
        self.nPerSpeaker = nPerSpeaker

    def forward(self, data, label=None):
        """Embed ``data`` and, when ``label`` is given, evaluate the loss.

        Args:
            data: Tensor whose last dimension is the per-sample input; every
                leading dimension is flattened into the batch dimension.
            label: Optional labels forwarded unchanged to the loss.

        Returns:
            The embeddings when ``label`` is None, otherwise the
            ``(loss, precision)`` pair returned by the loss callable.
        """
        # Collapse all leading dimensions so the model sees one row per sample.
        data = data.reshape(-1, data.size()[-1])
        outp = self.model(data)

        # Identity check: ``label == None`` is an elementwise comparison for
        # array-like labels and is ambiguous (or an error) in a boolean context.
        if label is None:
            return outp

        # Regroup as (speakers, nPerSpeaker, dim); squeeze(1) drops the middle
        # axis only when nPerSpeaker == 1.
        emb = outp.reshape(-1, self.nPerSpeaker, outp.size()[-1]).squeeze(1)
        nloss, prec1 = self.loss(emb, label)
        return nloss, prec1
class Trainer(object):
    """Run training, evaluation, checkpointing and score export for a model.

    Args:
        cfg: Configuration object. This class reads ``save_dir``, ``wandb``,
            ``dcf_p_target``, ``dcf_c_miss`` and ``dcf_c_fa`` from it.
        model: Module called as ``model(x, label)`` during training (returning
            ``(loss, precision)``) and as ``model(x)`` during evaluation
            (returning embeddings). ``model.loss.test_normalize`` is read when
            scoring.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Scheduler whose ``state_dict()`` is stored in checkpoints.
        device: Device that inputs and embeddings are moved to.
    """

    # How many "best" checkpoints are kept on disk before the oldest is removed.
    _MAX_BEST_CHECKPOINTS = 3
    # A checkpoint is also written whenever ``epoch`` is a multiple of this.
    _PERIODIC_SAVE_EVERY = 20

    def __init__(self, cfg, model, optimizer, scheduler, device):
        self.cfg = cfg
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        logger = init_log(cfg.save_dir)
        self._print = logger.info
        # 0 doubles as the "no best result yet" sentinel in save_model().
        self.best = 0
        self.test_eer = 0
        self.test_mindcf = 0
        # Overwritten by test(); initialised so the attribute always exists.
        self.threshold = 0
        # File names of the best checkpoints written so far, oldest first.
        self.best_model = []

    def train(self, epoch, dataloader):
        """Train the model for one pass over ``dataloader``.

        Each batch is expected to provide the inputs at index 0 and the labels
        at index 1. Running averages of loss and precision are shown on the
        progress bar, optionally sent to wandb, and logged once at the end.

        NOTE(review): ``self.scheduler`` is never stepped in this method;
        presumably the caller does so -- confirm.

        Args:
            epoch: Epoch number, used only for logging.
            dataloader: Iterable of training batches.
        """
        self.model.train()
        pbar = tqdm(dataloader)
        loss_sum = 0
        acc_sum = 0
        n_samples = 0
        n_batches = 0
        for batch in pbar:
            x = batch[0].to(self.device)
            label = batch[1].long().to(self.device)

            nloss, prec1 = self.model(x, label)
            self.optimizer.zero_grad()
            nloss.backward()
            self.optimizer.step()

            loss_sum += nloss.item()
            acc_sum += prec1.item()
            n_samples += x.size(0)
            n_batches += 1

            if self.cfg.wandb:
                wandb.log({
                    "epoch": epoch,
                    "train_acc": acc_sum / n_batches,
                    "train_loss": loss_sum / n_batches,
                })
            pbar.set_description(
                "Train Epoch:%3d ,Tloss:%.3f, Tacc:%.3f"
                % (epoch, loss_sum / n_batches, acc_sum / n_batches))

        if n_batches == 0:
            # Nothing to average; avoid dividing by zero in the summary line.
            self._print('epoch:{} - train skipped: dataloader produced no batches'.format(epoch))
            return

        self._print('epoch:{} - train loss: {:.3f} and train acc: {:.3f} total sample: {}'.format(
            epoch, loss_sum / n_batches, acc_sum / n_batches, n_samples))

    def test(self, epoch, test_list, test_path, nDataLoaderThread, eval_frames, num_eval=10):
        """Score every trial in ``test_list`` and compute EER and MinDCF.

        Each non-blank line of ``test_list`` is ``label file1 file2``; when a
        line has only the two file names a random 0/1 label is used.

        Args:
            epoch: Epoch number, used only for logging.
            test_list: Path of the trial list file.
            test_path: Forwarded to ``test_dataset_loader``.
            nDataLoaderThread: ``num_workers`` of the evaluation DataLoader.
            eval_frames: Forwarded to ``test_dataset_loader``.
            num_eval: Forwarded to ``test_dataset_loader``.

        Returns:
            The EER taken from ``tuneThresholdfromScore`` (also stored in
            ``self.test_eer``).
        """
        self.model.eval()
        trials = self._read_trials(test_list)
        feats = self._extract_embeddings(
            trials, test_path, eval_frames, num_eval, num_workers=nDataLoaderThread)

        all_scores = []
        all_labels = []
        for fields in tqdm(trials):
            # Prepend a random label if the line carries only the two files.
            if len(fields) == 2:
                fields = [random.randint(0, 1)] + fields
            # The files are the last two fields, matching what was extracted.
            all_scores.append(self._score_pair(feats[fields[-2]], feats[fields[-1]]))
            all_labels.append(int(fields[0]))

        result = tuneThresholdfromScore(all_scores, all_labels, [1, 0.1])
        fnrs, fprs, thresholds = ComputeErrorRates(all_scores, all_labels)
        mindcf, threshold = ComputeMinDcf(
            fnrs, fprs, thresholds,
            self.cfg.dcf_p_target, self.cfg.dcf_c_miss, self.cfg.dcf_c_fa)

        self.test_eer = result[1]
        self.test_mindcf = mindcf
        self.threshold = threshold
        if self.cfg.wandb:
            wandb.log({
                "test_eer": self.test_eer,
                "test_MinDCF": self.test_mindcf,
            })
        self._print('epoch:{} - test EER: {:.3f} and test MinDCF: {:.3f} total sample: {} threshold: {:.3f}'.format(
            epoch, self.test_eer, self.test_mindcf, len(trials), self.threshold))
        return self.test_eer

    def save_model(self, epoch):
        """Write a checkpoint when the last test improved or on periodic epochs.

        A checkpoint is written when ``self.test_eer`` is lower than
        ``self.best`` (or no best exists yet, i.e. ``self.best == 0``), and
        also when ``epoch`` is a multiple of ``_PERIODIC_SAVE_EVERY`` and no
        file with the same name exists. Only the most recent
        ``_MAX_BEST_CHECKPOINTS`` best checkpoints are kept on disk.

        NOTE(review): a best checkpoint written on a periodic epoch is still
        subject to rotation and can be deleted later -- confirm this is
        intended.

        Args:
            epoch: Epoch number stored in the checkpoint and its file name.
        """
        is_best = self.test_eer < self.best or self.best == 0
        is_periodic = epoch % self._PERIODIC_SAVE_EVERY == 0
        if not (is_best or is_periodic):
            return

        file_name = 'epoch:%d,EER:%.4f,MinDCF:%.4f' % (epoch, self.test_eer, self.test_mindcf)
        file_path = os.path.join(self.cfg.save_dir, file_name)
        # makedirs also creates missing parents and tolerates an existing dir.
        os.makedirs(self.cfg.save_dir, exist_ok=True)

        if is_best:
            self.best = self.test_eer
            if self.cfg.wandb:
                # The summary key says "accuracy" but the stored value is the EER.
                wandb.run.summary["best_accuracy"] = self.best

        # A best result always (re)writes its file; a purely periodic save
        # never overwrites an existing one.
        if is_best or not os.path.exists(file_path):
            torch.save({
                'epoch': epoch,
                'test_eer': self.test_eer,
                'test_mindcf': self.test_mindcf,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict()},
                file_path)

        if is_best:
            self.best_model.append(file_name)
            if len(self.best_model) > self._MAX_BEST_CHECKPOINTS:
                del_file = os.path.join(self.cfg.save_dir, self.best_model.pop(0))
                if os.path.exists(del_file):
                    os.remove(del_file)
                else:
                    print("no exists {}".format(del_file))

    def scoretxt(self, score_file, test_list, test_path, eval_frames, num_eval=10):
        """Write one ``score file1 file2`` line per trial to ``score_file``.

        The two files of a trial are the last two whitespace-separated fields
        of each non-blank line of ``test_list``; any label is ignored.

        Args:
            score_file: Path of the text file to (over)write.
            test_list: Path of the trial list file.
            test_path: Forwarded to ``test_dataset_loader``.
            eval_frames: Forwarded to ``test_dataset_loader``.
            num_eval: Forwarded to ``test_dataset_loader``.
        """
        self.model.eval()
        trials = self._read_trials(test_list)
        feats = self._extract_embeddings(trials, test_path, eval_frames, num_eval)

        # The context manager closes the file even if scoring raises.
        with open(score_file, "w") as out:
            for fields in tqdm(trials):
                score = self._score_pair(feats[fields[-2]], feats[fields[-1]])
                out.write(str(score) + " " + fields[-2] + " " + fields[-1] + '\n')

    @staticmethod
    def _read_trials(test_list):
        """Return the whitespace-split fields of each non-blank line of ``test_list``."""
        with open(test_list) as f:
            return [fields for fields in (line.split() for line in f) if fields]

    def _extract_embeddings(self, trials, test_path, eval_frames, num_eval, num_workers=0):
        """Return ``{file name: embedding on CPU}`` for every file named in ``trials``."""
        # Each file is embedded once, in sorted order, however many trials use it.
        setfiles = sorted({name for fields in trials for name in fields[-2:]})

        test_dataset = test_dataset_loader(
            setfiles, test_path, eval_frames=eval_frames, num_eval=num_eval)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,
            sampler=None,
        )

        feats = {}
        with torch.no_grad():
            for data in tqdm(test_loader):
                # data[0] holds the batched segments and data[1] the file name,
                # e.g. data[0]: (1, 10, 1024), data[1]: 'id10270/GWXujl-xAVM/00017.wav'.
                inp1 = data[0][0].to(self.device)
                feats[data[1][0]] = self.model(inp1).detach().cpu()
        return feats

    def _score_pair(self, ref_feat, com_feat):
        """Return the mean cosine similarity over all row pairs of two embeddings."""
        # Assumes both embeddings are 2-D (segments, dim) -- TODO confirm.
        ref_feat = ref_feat.to(self.device)
        com_feat = com_feat.to(self.device)
        if self.model.loss.test_normalize:
            ref_feat = F.normalize(ref_feat, p=2, dim=1)
            com_feat = F.normalize(com_feat, p=2, dim=1)
        # (N, D, 1) against (1, D, M) broadcasts to every row pair; the
        # similarity is reduced over dim 1, giving an (N, M) matrix.
        dist = F.cosine_similarity(
            ref_feat.unsqueeze(-1),
            com_feat.unsqueeze(-1).transpose(0, 2)).detach().cpu().numpy()
        return numpy.mean(dist)