import os

import torchaudio
import torch
from torch import nn

from speechbrain.lobes.models.Xvector import Xvector
from speechbrain.lobes.features import Fbank
from speechbrain.processing.features import InputNormalization
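
# The folder passed to Extractor should hold per-module checkpoints named
# "<module>.ckpt" (see Extractor.module_names below); any file that is
# missing is simply skipped, so parameter-free modules such as Fbank
# need no checkpoint.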


class Extractor(nn.Module):
    # Sub-modules restored from the pretrained folder, one
    # "<name>.ckpt" file per module
    module_names = [
        "mean_var_norm",
        "compute_features",
        "embedding_model",
        "mean_var_norm_emb",
    ]

    def __init__(self, model_path, n_mels=24, device="cpu"):
        super().__init__()
        self.device = device
        self.compute_features = Fbank(n_mels=n_mels)
        self.mean_var_norm = InputNormalization(norm_type="sentence", std_norm=False)
        # Architecture hyperparameters must match the pretrained
        # checkpoint being loaded
        self.embedding_model = Xvector(
            in_channels=n_mels,
            activation=torch.nn.LeakyReLU,
            tdnn_blocks=5,
            tdnn_channels=[512, 512, 512, 512, 1500],
            tdnn_kernel_sizes=[5, 3, 3, 1, 1],
            tdnn_dilations=[1, 2, 3, 1, 1],
            lin_neurons=512,
        )
        self.mean_var_norm_emb = InputNormalization(norm_type="global", std_norm=False)
        for mod_name in self.module_names:
            filename = os.path.join(model_path, f"{mod_name}.ckpt")
            module = getattr(self, mod_name)
            if os.path.exists(filename):
                if hasattr(module, "_load"):
                    # SpeechBrain normalization modules expose a custom
                    # _load hook for their own checkpoint format
                    print(f"Load: {filename}")
                    module._load(filename)
                else:
                    print(f"Load state dict: {filename}")
                    module.load_state_dict(
                        torch.load(filename, map_location=self.device)
                    )
            module.to(self.device)

    @torch.no_grad()
    def forward(self, wavs, wav_lens=None, normalize=False):
        # Promote a single unbatched waveform to a batch of one
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)

        # Default to full relative length when wav_lens is not provided
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # Move data to the target device
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()

        # Compute features and embeddings
        feats = self.compute_features(wavs)
        feats = self.mean_var_norm(feats, wav_lens)
        embeddings = self.embedding_model(feats, wav_lens)
        if normalize:
            embeddings = self.mean_var_norm_emb(
                embeddings, torch.ones(embeddings.shape[0], device=self.device)
            )
        return embeddings
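

# Usage sketch (illustrative, not part of the original script): x-vector
# embeddings are typically compared with cosine similarity for speaker
# verification. `cosine_score` is a hypothetical helper name.
def cosine_score(emb_a, emb_b):
    """Cosine similarity between two (flattened) embedding tensors."""
    a, b = emb_a.flatten(), emb_b.flatten()
    return torch.nn.functional.cosine_similarity(a, b, dim=0).item()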


MODEL_PATH = "pretrained_models/spkrec-xvect-voxceleb"
# Note: the pretrained VoxCeleb model expects 16 kHz input audio
signal, fs = torchaudio.load("audio.wav")

# Fall back to CPU when CUDA is unavailable
device = "cuda" if torch.cuda.is_available() else "cpu"
extractor = Extractor(MODEL_PATH, device=device)

# Freeze all parameters; the extractor is inference-only
for p in extractor.parameters():
    p.requires_grad = False

extractor.eval()
embeddings_x = extractor(signal).cpu().squeeze()

# Tracing: torch.jit.trace records the ops executed for this particular
# input, so the Python branches in forward (batch promotion, default
# wav_lens) are frozen to the path taken by `signal`
traced_model = torch.jit.trace(extractor, signal)
torch.jit.save(traced_model, f"model_{device}.pt")
embeddings_t = traced_model(signal).squeeze()
print(embeddings_t)

# Reload the traced model and check it reproduces the embeddings
model = torch.jit.load(f"model_{device}.pt")
emb_m = model(signal).squeeze()
print(emb_m)
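
# Consistency check (a minimal sketch): the traced-and-reloaded model
# should match the eager forward pass up to numerical tolerance.
assert torch.allclose(embeddings_x, emb_m.cpu(), atol=1e-5), \
    "traced model output diverges from eager output"
print(f"Self cosine score: {cosine_score(embeddings_x, emb_m.cpu()):.4f}")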