# tracing_code.py: trace the SpeechBrain x-vector (VoxCeleb) speaker-embedding
# extractor to TorchScript.

import os

import torch
import torchaudio
from torch import nn

from speechbrain.lobes.features import Fbank
from speechbrain.lobes.models.Xvector import Xvector
from speechbrain.processing.features import InputNormalization


class Extractor(nn.Module):
    """X-vector speaker-embedding pipeline in a single module: Fbank
    features, utterance-level feature normalization, TDNN embedding
    network, and an optional global embedding normalization."""

    # Sub-modules whose pretrained checkpoints are read from `model_path`.
    module_names = [
        "mean_var_norm",
        "compute_features",
        "embedding_model",
        "mean_var_norm_emb",
    ]
    def __init__(self, model_path, n_mels=24, device="cpu"):
        super().__init__()
        self.device = device
        # Log Mel filterbank feature extraction.
        self.compute_features = Fbank(n_mels=n_mels)
        # Per-utterance mean normalization of the features.
        self.mean_var_norm = InputNormalization(norm_type="sentence", std_norm=False)
        # TDNN x-vector network with a 512-dim embedding layer.
        self.embedding_model = Xvector(
            in_channels=n_mels,
            activation=torch.nn.LeakyReLU,
            tdnn_blocks=5,
            tdnn_channels=[512, 512, 512, 512, 1500],
            tdnn_kernel_sizes=[5, 3, 3, 1, 1],
            tdnn_dilations=[1, 2, 3, 1, 1],
            lin_neurons=512,
        )
        # Global mean normalization of the embeddings.
        self.mean_var_norm_emb = InputNormalization(norm_type="global", std_norm=False)
        # Restore pretrained weights for each sub-module: modules that
        # implement SpeechBrain's `_load` use it, the rest fall back to
        # `load_state_dict`.
        for mod_name in self.module_names:
            filename = os.path.join(model_path, f"{mod_name}.ckpt")
            module = getattr(self, mod_name)
            if os.path.exists(filename):
                if hasattr(module, "_load"):
                    print(f"Load: {filename}")
                    module._load(filename)
                else:
                    print(f"Load State Dict: {filename}")
                    module.load_state_dict(
                        torch.load(filename, map_location=self.device)
                    )
            module.to(self.device)

    @torch.no_grad()
    def forward(self, wavs, wav_lens=None, normalize=False):
        # Handle a single waveform given as a 1-D tensor.
        if len(wavs.shape) == 1:
            wavs = wavs.unsqueeze(0)
        # Assume full-length utterances when wav_lens is not given.
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)
        # Move the batch to the target device.
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()
        # Features -> utterance-level normalization -> x-vector embedding.
        feats = self.compute_features(wavs)
        feats = self.mean_var_norm(feats, wav_lens)
        embeddings = self.embedding_model(feats, wav_lens)
        if normalize:
            embeddings = self.mean_var_norm_emb(
                embeddings, torch.ones(embeddings.shape[0], device=self.device)
            )
        return embeddings
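

# A minimal sketch of a downstream use: speaker-verification scoring as the
# cosine similarity between two utterances' x-vectors. Not part of the
# original script; assumes 1-D embeddings after squeezing.
def verification_score(extractor, wav_a, wav_b):
    emb_a = extractor(wav_a).squeeze()
    emb_b = extractor(wav_b).squeeze()
    return torch.nn.functional.cosine_similarity(emb_a, emb_b, dim=0).item()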


MODEL_PATH = "pretrained_models/spkrec-xvect-voxceleb"

signal, fs = torchaudio.load("audio.wav")
device = "cuda" if torch.cuda.is_available() else "cpu"
extractor = Extractor(MODEL_PATH, device=device)
# Inference only: freeze all parameters and switch to eval mode.
for p in extractor.parameters():
    p.requires_grad = False
extractor.eval()

# Reference embeddings from the eager-mode model.
embeddings_x = extractor(signal).cpu().squeeze()
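
# With lin_neurons=512, a single utterance yields a 512-dim x-vector
# after squeezing.
print(embeddings_x.shape)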

# Tracing: record the forward pass on an example input and serialize the
# result as a standalone TorchScript module.
traced_model = torch.jit.trace(extractor, signal)
torch.jit.save(traced_model, f"model_{device}.pt")
embeddings_t = traced_model(signal).squeeze()
print(embeddings_t)
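
# Sanity check: the traced module should reproduce the eager embeddings
# up to numerical noise.
print(torch.allclose(embeddings_x, embeddings_t.cpu(), atol=1e-5))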

# Reload the serialized module and confirm it produces the same embeddings.
model = torch.jit.load(f"model_{device}.pt")
emb_m = model(signal).squeeze()
print(emb_m)
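
# Example downstream use (a sketch; "audio2.wav" is a hypothetical second
# utterance): compare two recordings with the helper defined above.
signal2, _ = torchaudio.load("audio2.wav")
print(f"speaker similarity: {verification_score(extractor, signal, signal2):.3f}")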