# VoxSIM/score.py
import os
import numpy
import librosa
import torch
import torch.nn.functional as F

from ssl_ecapa_model import SSL_ECAPA_TDNN
from huggingface_hub import hf_hub_download

def loadWav(filename, max_frames: int = 400, num_eval: int = 10):
    # Maximum audio length: 160 samples = 10 ms hop at 16 kHz, plus 240
    # samples to cover the final 25 ms analysis window.
    max_audio = max_frames * 160 + 240

    # Read the wav file (resampled to 16 kHz) and keep a copy of the
    # full-length signal.
    audio, sr = librosa.load(filename, sr=16000)
    audio_org = audio.copy()

    audiosize = audio.shape[0]

    # Pad short clips by wrapping so that at least one full crop fits.
    if audiosize <= max_audio:
        shortage = max_audio - audiosize + 1
        audio = numpy.pad(audio, (0, shortage), 'wrap')
        audiosize = audio.shape[0]

    # Evenly spaced start positions for num_eval fixed-length crops.
    startframe = numpy.linspace(0, audiosize - max_audio, num=num_eval)

    feats = []
    if max_frames == 0:
        # No cropping: return the whole utterance as a single-row batch.
        feats.append(audio)
        feat = numpy.stack(feats, axis=0).astype(numpy.float32)
        return torch.FloatTensor(feat)
    else:
        for asf in startframe:
            feats.append(audio[int(asf):int(asf) + max_audio])
        feat = numpy.stack(feats, axis=0).astype(numpy.float32)
        return torch.FloatTensor(feat), torch.FloatTensor(
            numpy.stack([audio_org], axis=0).astype(numpy.float32))
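
# Note on loadWav's return values: for max_frames > 0 it returns a pair of
# tensors, a (num_eval, max_audio) batch of fixed-length crops plus a
# (1, full_length) tensor holding the whole utterance. A minimal sanity
# check, assuming the defaults above (the .wav path is a placeholder):
#
#   crops, full = loadWav("example.wav")
#   assert crops.shape == (10, 400 * 160 + 240)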

def loadModel(ckpt_path):
    model = SSL_ECAPA_TDNN(feat_dim=1024, emb_dim=256, feat_type='wavlm_large')
    # Fetch the checkpoint from the Hugging Face Hub if it is not on disk.
    if not os.path.isfile(ckpt_path):
        print("Downloading model from Hugging Face Hub...")
        ckpt_path = hf_hub_download(repo_id="junseok520/voxsim-models",
                                    filename=ckpt_path, local_dir="./")
    model.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
    return model
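
# Usage sketch for loadModel: the checkpoint name doubles as the filename to
# fetch from the junseok520/voxsim-models Hub repo when it is not found locally.
#
#   model = loadModel("voxsim_wavlm_ecapa.model")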

class Score:
    """Predict a speaker-similarity (VoxSIM) score for a pair of audio clips."""

    def __init__(
            self,
            ckpt_path: str = "voxsim_wavlm_ecapa.model",
            device: str = "cuda"):
        """
        Args:
            ckpt_path: path to a pretrained checkpoint of the voxsim evaluator.
            device: torch device to run the model on, e.g. "cuda" or "cpu".
        """
        print(f"Using device: {device}")
        self.device = device
        self.model = loadModel(ckpt_path).to(self.device)
        self.model.eval()

    def score(self, inp_wavs: torch.Tensor, inp_wav: torch.Tensor,
              ref_wavs: torch.Tensor, ref_wav: torch.Tensor) -> float:
        # Flatten any leading batch dimensions and move inputs to the device.
        inp_wavs = inp_wavs.reshape(-1, inp_wavs.shape[-1]).to(self.device)
        inp_wav = inp_wav.reshape(-1, inp_wav.shape[-1]).to(self.device)
        ref_wavs = ref_wavs.reshape(-1, ref_wavs.shape[-1]).to(self.device)
        ref_wav = ref_wav.reshape(-1, ref_wav.shape[-1]).to(self.device)
        with torch.no_grad():
            # L2-normalize the embeddings so the matmuls below compute
            # cosine similarities.
            input_emb_1 = F.normalize(self.model.forward(inp_wavs), p=2, dim=1).detach()
            input_emb_2 = F.normalize(self.model.forward(inp_wav), p=2, dim=1).detach()
            ref_emb_1 = F.normalize(self.model.forward(ref_wavs), p=2, dim=1).detach()
            ref_emb_2 = F.normalize(self.model.forward(ref_wav), p=2, dim=1).detach()
        emb_size = input_emb_1.shape[-1]
        input_emb_1 = input_emb_1.reshape(-1, emb_size)
        input_emb_2 = input_emb_2.reshape(-1, emb_size)
        ref_emb_1 = ref_emb_1.reshape(-1, emb_size)
        ref_emb_2 = ref_emb_2.reshape(-1, emb_size)
        # Average all pairwise cosine similarities between the crop embeddings
        # (score_1) and between the full-utterance embeddings (score_2),
        # then average the two.
        score_1 = torch.mean(torch.matmul(input_emb_1, ref_emb_1.T))
        score_2 = torch.mean(torch.matmul(input_emb_2, ref_emb_2.T))
        score = (score_1 + score_2) / 2
        return score.detach().cpu().item()
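

if __name__ == "__main__":
    # Minimal usage sketch; the two .wav paths below are placeholders.
    # Load a generated and a reference utterance, embed both, and print
    # their VoxSIM similarity score.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    scorer = Score(ckpt_path="voxsim_wavlm_ecapa.model", device=device)
    inp_wavs, inp_wav = loadWav("generated.wav")
    ref_wavs, ref_wav = loadWav("reference.wav")
    print(f"VoxSIM score: {scorer.score(inp_wavs, inp_wav, ref_wavs, ref_wav):.4f}")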