import sys
import os

import torch

# ERes2NetV2 and the kaldi feature-extraction helpers live in the local
# eres2net/ directory, so it is added to the import path first.
sys.path.append(f"{os.getcwd()}/eres2net")

# Pretrained ERes2NetV2 speaker-verification checkpoint.
sv_path = "pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"

from ERes2NetV2 import ERes2NetV2
import kaldi as Kaldi


class SV:
    def __init__(self, device, is_half):
        # Load the pretrained weights on CPU, then build and populate the model.
        pretrained_state = torch.load(sv_path, map_location="cpu", weights_only=False)
        embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
        embedding_model.load_state_dict(pretrained_state)
        embedding_model.eval()
        self.embedding_model = embedding_model
        # Move the model to the target device, in half precision if requested.
        if not is_half:
            self.embedding_model = self.embedding_model.to(device)
        else:
            self.embedding_model = self.embedding_model.half().to(device)
        self.is_half = is_half

    def compute_embedding3(self, wav):
        with torch.no_grad():
            if self.is_half:
                wav = wav.half()
            # Extract 80-dimensional fbank features per utterance and stack them into a batch.
            feat = torch.stack(
                [Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]
            )
            sv_emb = self.embedding_model.forward3(feat)
        return sv_emb
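

# Minimal usage sketch, not part of the original file: it assumes CUDA is
# available, the checkpoint exists at sv_path, and `wav` is a batch of
# 16 kHz mono waveforms shaped (batch, samples).
if __name__ == "__main__":
    sv = SV(device="cuda", is_half=False)
    wav = torch.randn(1, 16000)  # one second of dummy 16 kHz audio
    emb = sv.compute_embedding3(wav)
    print(emb.shape)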