import os

import torch
import torchaudio
from torch import nn
from speechbrain.lobes.features import Fbank
from speechbrain.lobes.models.Xvector import Xvector
from speechbrain.processing.features import InputNormalization


class Extractor(nn.Module):
    """Standalone x-vector speaker-embedding extractor.

    Rebuilds the SpeechBrain spkrec-xvect-voxceleb pipeline module by module
    so the whole model can be traced with TorchScript.
    """

    # Names of the pipeline modules; each may have a matching
    # "<name>.ckpt" file inside the pretrained-model directory.
    module_names = [
        "mean_var_norm",
        "compute_features",
        "embedding_model",
        "mean_var_norm_emb",
    ]

    def __init__(self, model_path, n_mels=24, device="cpu"):
        super().__init__()
        self.device = device

        # Filterbank feature extraction and per-sentence normalization
        self.compute_features = Fbank(n_mels=n_mels)
        self.mean_var_norm = InputNormalization(norm_type="sentence", std_norm=False)

        # TDNN-based x-vector embedding model (VoxCeleb recipe hyperparameters)
        self.embedding_model = Xvector(
            in_channels=n_mels,
            activation=torch.nn.LeakyReLU,
            tdnn_blocks=5,
            tdnn_channels=[512, 512, 512, 512, 1500],
            tdnn_kernel_sizes=[5, 3, 3, 1, 1],
            tdnn_dilations=[1, 2, 3, 1, 1],
            lin_neurons=512,
        )
        self.mean_var_norm_emb = InputNormalization(norm_type="global", std_norm=False)

        # Load the pretrained checkpoints. SpeechBrain's InputNormalization
        # exposes a custom `_load`; plain nn.Modules use `load_state_dict`.
        for mod_name in self.module_names:
            filename = os.path.join(model_path, f"{mod_name}.ckpt")
            module = getattr(self, mod_name)
            if os.path.exists(filename):
                if hasattr(module, "_load"):
                    print(f"Load: {filename}")
                    module._load(filename)
                else:
                    print(f"Load state dict: {filename}")
                    module.load_state_dict(torch.load(filename))
            # Move the module to the target device even when no checkpoint is
            # found, so every submodule lives on the same device as the inputs.
            module.to(self.device)

    @torch.no_grad()
    def forward(self, wavs, wav_lens=None, normalize=False):
        # Manage single waveforms in input
        if len(wavs.shape) == 1:
            wavs = wavs.unsqueeze(0)

        # Assign full relative length if wav_lens is not provided
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # Move the waveforms to the specified device
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()

        # Compute features and embeddings
        feats = self.compute_features(wavs)
        feats = self.mean_var_norm(feats, wav_lens)
        embeddings = self.embedding_model(feats, wav_lens)
        if normalize:
            embeddings = self.mean_var_norm_emb(
                embeddings,
                torch.ones(embeddings.shape[0], device=self.device),
            )
        return embeddings


MODEL_PATH = "pretrained_models/spkrec-xvect-voxceleb"

signal, fs = torchaudio.load("audio.wav")

device = "cuda" if torch.cuda.is_available() else "cpu"
extractor = Extractor(MODEL_PATH, device=device)

# Freeze the parameters and switch to eval mode before tracing
for p in extractor.parameters():
    p.requires_grad = False
extractor.eval()

embeddings_x = extractor(signal).cpu().squeeze()

# Trace the extractor with an example input and save it as TorchScript
traced_model = torch.jit.trace(extractor, signal)
torch.jit.save(traced_model, f"model_{device}.pt")

embeddings_t = traced_model(signal).squeeze()
print(embeddings_t)

# Reload the traced model from disk and run it again
model = torch.jit.load(f"model_{device}.pt")
emb_m = model(signal).squeeze()
print(emb_m)
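
# A minimal sanity check (a sketch, not part of the original script): the
# traced-and-reloaded model should reproduce the eager-mode embeddings on the
# same input up to numerical tolerance. The atol value is an assumption;
# tighten or loosen it for your hardware/precision.
assert torch.allclose(embeddings_x, emb_m.cpu(), atol=1e-5), \
    "traced/reloaded model diverges from eager-mode extractor"
print("traced model matches eager extractor")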