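"""Speaker verification demo: embed two utterances with an ECAPA-TDNN model
on top of a self-supervised front end (WavLM, HuBERT, wav2vec 2.0 XLSR, or
UniSpeech-SAT) and score them by cosine similarity."""
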
import soundfile as sf
import torch
import fire
import torch.nn.functional as F
from torchaudio.transforms import Resample

from models.ecapa_tdnn import ECAPA_TDNN_SMALL

MODEL_LIST = ['ecapa_tdnn', 'hubert_large', 'wav2vec2_xlsr', 'unispeech_sat', 'wavlm_base_plus', 'wavlm_large']


def init_model(model_name, checkpoint=None):
    # Pick the ECAPA-TDNN configuration that matches the requested front end;
    # feat_dim is the feature dimension produced by that front end.
    if model_name == 'unispeech_sat':
        config_path = 'config/unispeech_sat.th'
        model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='unispeech_sat', config_path=config_path)
    elif model_name == 'wavlm_base_plus':
        model = ECAPA_TDNN_SMALL(feat_dim=768, feat_type='wavlm_base_plus', config_path=None)
    elif model_name == 'wavlm_large':
        model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
    elif model_name == 'hubert_large':
        model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='hubert_large_ll60k', config_path=None)
    elif model_name == 'wav2vec2_xlsr':
        model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wav2vec2_xlsr', config_path=None)
    else:
        # Fallback: plain ECAPA-TDNN on 40-dimensional filterbank features.
        model = ECAPA_TDNN_SMALL(feat_dim=40, feat_type='fbank')

    if checkpoint is not None:
        # Load weights onto CPU first; strict=False tolerates missing or
        # unexpected keys in the checkpoint.
        state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)
        model.load_state_dict(state_dict['model'], strict=False)
    return model
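
# A minimal usage sketch (hypothetical checkpoint path): build a WavLM-Large
# verifier from a locally downloaded fine-tuned checkpoint:
#   model = init_model('wavlm_large', checkpoint='ckpts/wavlm_large_finetune.pth')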


def verification(model_name, wav1, wav2, use_gpu=True, checkpoint=None):
    assert model_name in MODEL_LIST, 'The model_name should be in {}'.format(MODEL_LIST)
    model = init_model(model_name, checkpoint)

    wav1, sr1 = sf.read(wav1)
    wav2, sr2 = sf.read(wav2)

    wav1 = torch.from_numpy(wav1).unsqueeze(0).float()
    wav2 = torch.from_numpy(wav2).unsqueeze(0).float()
    # Resample both utterances to the 16 kHz rate the front ends expect
    # (a no-op when the source rate is already 16 kHz).
    resample1 = Resample(orig_freq=sr1, new_freq=16000)
    resample2 = Resample(orig_freq=sr2, new_freq=16000)
    wav1 = resample1(wav1)
    wav2 = resample2(wav2)

    if use_gpu:
        model = model.cuda()
        wav1 = wav1.cuda()
        wav2 = wav2.cuda()

    model.eval()
    with torch.no_grad():
        emb1 = model(wav1)
        emb2 = model(wav2)

    sim = F.cosine_similarity(emb1, emb2)
    return sim[0].item()
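
# Example (hypothetical paths): single-pair scoring returns a float cosine
# similarity; a higher value means the two clips are more likely the same speaker:
#   score = verification('wavlm_large', 'enroll.wav', 'test.wav', use_gpu=False)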


def verification_batch(model_name, batch_wav1, batch_wav2, use_gpu=True, checkpoint=None):
    assert model_name in MODEL_LIST, 'The model_name should be in {}'.format(MODEL_LIST)
    model = init_model(model_name, checkpoint)

    # Assume every file on each side of the batch shares the sample rate of
    # its first file, so one Resample transform can be reused per side.
    sr1 = sf.read(str(batch_wav1[0]))[1]
    sr2 = sf.read(str(batch_wav2[0]))[1]

    def load_clip(path, max_len=50000):
        # Truncate each clip to max_len samples and zero-pad shorter ones so
        # that torch.cat can stack the batch into a single tensor.
        wav = torch.from_numpy(sf.read(str(path))[0][:max_len]).unsqueeze(0).float()
        return F.pad(wav, (0, max_len - wav.shape[1]))

    batch_wav1 = [load_clip(wav) for wav in batch_wav1]
    batch_wav2 = [load_clip(wav) for wav in batch_wav2]

    resample1 = Resample(orig_freq=sr1, new_freq=16000)
    resample2 = Resample(orig_freq=sr2, new_freq=16000)

    batch_wav1 = torch.cat([resample1(wav) for wav in batch_wav1], 0)
    batch_wav2 = torch.cat([resample2(wav) for wav in batch_wav2], 0)

    if use_gpu:
        model = model.cuda()
        batch_wav1 = batch_wav1.cuda()
        batch_wav2 = batch_wav2.cuda()

    model.eval()
    with torch.no_grad():
        emb1 = model(batch_wav1)
        emb2 = model(batch_wav2)

    sim = F.cosine_similarity(emb1, emb2, dim=-1)
    return sim.cpu().numpy()
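
# A minimal usage sketch (hypothetical file paths): score aligned trial pairs
# in one forward pass, on CPU:
#   scores = verification_batch('wavlm_base_plus',
#                               ['spk1_a.wav', 'spk2_a.wav'],
#                               ['spk1_b.wav', 'spk2_b.wav'],
#                               use_gpu=False)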


if __name__ == "__main__":
    fire.Fire(verification)
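
# Command-line usage: python-fire maps flags onto the parameters of
# `verification`; e.g., assuming this file is saved as verification.py and the
# wav paths are placeholders:
#   python verification.py --model_name wavlm_large --wav1 a.wav --wav2 b.wav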