File size: 9,222 Bytes
c56c253
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from speaker_encoder.hparams import *
from speaker_encoder import audio
from pathlib import Path
from typing import Union, List
from torch import nn
from time import perf_counter as timer
import numpy as np
import torch


class SpeakerEncoder(nn.Module):
    """
    Speaker (voice) encoder: an LSTM stack followed by a linear projection and a ReLU, whose 
    output is L2-normalized to produce a fixed-size embedding per utterance.
    
    The layer sizes (mel_n_channels, model_hidden_size, model_num_layers, model_embedding_size) 
    and the audio framing constants come from the hparams module imported at the top of the file.
    """
    
    def __init__(self, weights_fpath, device: Union[str, torch.device]=None, verbose=True):
        """
        Builds the network, loads the pretrained weights and moves the model to the target device.
        
        :param weights_fpath: path of the checkpoint to load (handed to torch.load). The 
        checkpoint must contain a "model_state" entry holding the state dict of this module.
        :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). 
        If None, defaults to cuda if it is available on your machine, otherwise the model will 
        run on cpu. Outputs of the embed_* methods are always returned on the cpu, as numpy 
        arrays.
        :param verbose: if True, prints the device and the time taken to load the model.
        """
        super().__init__()
        
        # Define the network
        self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
        self.linear = nn.Linear(model_hidden_size, model_embedding_size)
        self.relu = nn.ReLU()
        
        # Get the target device
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, str):
            device = torch.device(device)
        self.device = device
        
        # Load the pretrained model's weights. They are first mapped to the cpu so that a 
        # checkpoint saved on a gpu can be loaded on a machine without one.
        # NOTE: torch.load unpickles the file, only load checkpoints from sources you trust.
        start = timer()
        checkpoint = torch.load(weights_fpath, map_location="cpu")
        
        # strict=False tolerates checkpoints whose keys do not exactly match this module 
        # (missing or unexpected keys are ignored rather than raising).
        self.load_state_dict(checkpoint["model_state"], strict=False)
        self.to(device)
        
        if verbose:
            print("Loaded the voice encoder model on %s in %.2f seconds." % 
                  (device.type, timer() - start))

    def forward(self, mels: torch.FloatTensor):
        """
        Computes the embeddings of a batch of utterance spectrograms.
        
        :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape 
        (batch_size, n_frames, n_channels) 
        :return: the embeddings as a float32 tensor of shape (batch_size, embedding_size). 
        Embeddings are non-negative and L2-normed, thus they lie in the range [0, 1]. An 
        embedding that is entirely zeroed by the ReLU is returned as a zero vector instead of NaN.
        """
        # Pass the input through the LSTM layers and retrieve the final hidden state of the last 
        # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings.
        _, (hidden, _) = self.lstm(mels)
        embeds_raw = self.relu(self.linear(hidden[-1]))
        
        # normalize() divides by max(norm, eps): identical to a plain division by the norm for 
        # any non-degenerate embedding, but an all-zero row no longer produces NaNs.
        return nn.functional.normalize(embeds_raw, p=2, dim=1)
    
    @staticmethod
    def _l2_normalize(vector: np.ndarray) -> np.ndarray:
        """
        L2-normalizes a numpy vector, returning it unchanged when its norm is exactly zero so 
        that no division by zero (and thus no NaN) occurs.
        """
        norm = np.linalg.norm(vector, 2)
        if norm == 0:
            return vector
        return vector / norm
    
    @staticmethod
    def compute_partial_slices(n_samples: int, rate, min_coverage):
        """
        Computes where to split an utterance waveform and its corresponding mel spectrogram to 
        obtain partial utterances of <partials_n_frames> each. Both the waveform and the 
        mel spectrogram slices are returned, so as to make each partial utterance waveform 
        correspond to its spectrogram.
    
        The returned ranges may be indexing further than the length of the waveform. It is 
        recommended that you pad the waveform with zeros up to wav_slices[-1].stop.
    
        :param n_samples: the number of samples in the waveform
        :param rate: how many partial utterances should occur per second. Partial utterances must 
        cover the span of the entire utterance, thus the rate should not be lower than the inverse 
        of the duration of a partial utterance. By default, partial utterances are 1.6s long and 
        the minimum rate is thus 0.625.
        :param min_coverage: a value in ]0, 1]. When reaching the last partial utterance, it may 
        or may not have enough frames. If at least <min_coverage> of <partials_n_frames> are 
        present, then the last partial utterance will be considered by zero-padding the audio. 
        Otherwise, it will be discarded. If there aren't enough frames for one partial utterance, 
        this parameter is ignored so that the function always returns at least one slice.
        :return: the waveform slices and mel spectrogram slices as lists of array slices. Index 
        respectively the waveform and the mel spectrogram with these slices to obtain the partial 
        utterances.
        :raises AssertionError: if <min_coverage> is out of range, or if <rate> is too high or 
        too low for the configured frame sizes.
        """
        assert 0 < min_coverage <= 1
        
        # Compute how many frames separate two partial utterances
        samples_per_frame = int((sampling_rate * mel_window_step / 1000))
        n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
        frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
        assert 0 < frame_step, "The rate is too high"
        assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
            (sampling_rate / (samples_per_frame * partials_n_frames))
        
        # Compute the slices. The max() guarantees at least one slice, even for utterances 
        # shorter than a single partial.
        wav_slices, mel_slices = [], []
        steps = max(1, n_frames - partials_n_frames + frame_step + 1)
        for i in range(0, steps, frame_step):
            mel_range = np.array([i, i + partials_n_frames])
            wav_range = mel_range * samples_per_frame
            mel_slices.append(slice(*mel_range))
            wav_slices.append(slice(*wav_range))
        
        # Evaluate whether extra padding is warranted or not: drop the last partial if too 
        # little of it is made of actual audio, unless it is the only one.
        last_wav_range = wav_slices[-1]
        coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
        if coverage < min_coverage and len(mel_slices) > 1:
            mel_slices = mel_slices[:-1]
            wav_slices = wav_slices[:-1]
        
        return wav_slices, mel_slices
    
    def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75):
        """
        Computes an embedding for a single utterance. The utterance is divided in partial 
        utterances and an embedding is computed for each. The complete utterance embedding is the 
        L2-normed average embedding of the partial utterances.
        
        TODO: independent batched version of this function
    
        :param wav: a preprocessed utterance waveform as a numpy array of float32
        :param return_partials: if True, the partial embeddings will also be returned along with 
        the wav slices corresponding to each partial utterance.
        :param rate: how many partial utterances should occur per second. Partial utterances must 
        cover the span of the entire utterance, thus the rate should not be lower than the inverse 
        of the duration of a partial utterance. By default, partial utterances are 1.6s long and 
        the minimum rate is thus 0.625.
        :param min_coverage: a value in ]0, 1]. When reaching the last partial utterance, it may 
        or may not have enough frames. If at least <min_coverage> of <partials_n_frames> are 
        present, then the last partial utterance will be considered by zero-padding the audio. 
        Otherwise, it will be discarded. If there aren't enough frames for one partial utterance, 
        this parameter is ignored so that the function always returns at least one slice.
        :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If 
        <return_partials> is True, the partial utterances as a numpy array of float32 of shape 
        (n_partials, model_embedding_size) and the wav partials as a list of slices will also be 
        returned.
        """
        # Compute where to split the utterance into partials and pad the waveform with zeros if 
        # the partial utterances cover a larger range. 
        wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
        max_wave_length = wav_slices[-1].stop
        if max_wave_length > len(wav):
            wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
        
        # Split the utterance into partials and forward them through the model
        mel = audio.wav_to_mel_spectrogram(wav)
        mels = np.array([mel[s] for s in mel_slices])
        with torch.no_grad():
            mels = torch.from_numpy(mels).to(self.device)
            partial_embeds = self(mels).cpu().numpy()
        
        # Compute the utterance embedding from the partial embeddings
        raw_embed = np.mean(partial_embeds, axis=0)
        embed = self._l2_normalize(raw_embed)
        
        if return_partials:
            return embed, partial_embeds, wav_slices
        return embed
    
    def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
        """
        Compute the embedding of a collection of wavs (presumably from the same speaker) by 
        averaging their embedding and L2-normalizing it.
        
        :param wavs: list of wavs as numpy arrays of float32.
        :param kwargs: extra arguments to embed_utterance()
        :return: the embedding as a numpy array of float32 of shape (model_embedding_size,).
        :raises ValueError: if <wavs> is empty, as there is then nothing to average.
        """
        utterance_embeds = [self.embed_utterance(wav, return_partials=False, **kwargs)
                            for wav in wavs]
        if not utterance_embeds:
            raise ValueError("embed_speaker() requires at least one wav.")
        
        raw_embed = np.mean(utterance_embeds, axis=0)
        return self._l2_normalize(raw_embed)