#
Spaces:
Running

Running

from speaker_encoder.params_data import * | |

from speaker_encoder.model import SpeakerEncoder | |

from speaker_encoder.audio import preprocess_wav # We want to expose this function from here | |

from matplotlib import cm | |

from speaker_encoder import audio | |

from pathlib import Path | |

import matplotlib.pyplot as plt | |

import numpy as np | |

import torch | |

_model = None # type: SpeakerEncoder | |

_device = None # type: torch.device | |

def load_model(weights_fpath: Path, device=None): | |

""" | |

Loads the model in memory. If this function is not explicitely called, it will be run on the | |

first call to embed_frames() with the default weights file. | |

:param weights_fpath: the path to saved model weights. | |

:param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The | |

model will be loaded and will run on this device. Outputs will however always be on the cpu. | |

If None, will default to your GPU if it"s available, otherwise your CPU. | |

""" | |

# TODO: I think the slow loading of the encoder might have something to do with the device it | |

# was saved on. Worth investigating. | |

global _model, _device | |

if device is None: | |

_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |

elif isinstance(device, str): | |

_device = torch.device(device) | |

_model = SpeakerEncoder(_device, torch.device("cpu")) | |

checkpoint = torch.load(weights_fpath) | |

_model.load_state_dict(checkpoint["model_state"]) | |

_model.eval() | |

print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"])) | |

def is_loaded(): | |

return _model is not None | |

def embed_frames_batch(frames_batch): | |

""" | |

Computes embeddings for a batch of mel spectrogram. | |

:param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape | |

(batch_size, n_frames, n_channels) | |

:return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size) | |

""" | |

if _model is None: | |

raise Exception("Model was not loaded. Call load_model() before inference.") | |

frames = torch.from_numpy(frames_batch).to(_device) | |

embed = _model.forward(frames).detach().cpu().numpy() | |

return embed | |

def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames, | |

min_pad_coverage=0.75, overlap=0.5): | |

""" | |

Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain | |

partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel | |

spectrogram slices are returned, so as to make each partial utterance waveform correspond to | |

its spectrogram. This function assumes that the mel spectrogram parameters used are those | |

defined in params_data.py. | |

The returned ranges may be indexing further than the length of the waveform. It is | |

recommended that you pad the waveform with zeros up to wave_slices[-1].stop. | |

:param n_samples: the number of samples in the waveform | |

:param partial_utterance_n_frames: the number of mel spectrogram frames in each partial | |

utterance | |

:param min_pad_coverage: when reaching the last partial utterance, it may or may not have | |

enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present, | |

then the last partial utterance will be considered, as if we padded the audio. Otherwise, | |

it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial | |

utterance, this parameter is ignored so that the function always returns at least 1 slice. | |

:param overlap: by how much the partial utterance should overlap. If set to 0, the partial | |

utterances are entirely disjoint. | |

:return: the waveform slices and mel spectrogram slices as lists of array slices. Index | |

respectively the waveform and the mel spectrogram with these slices to obtain the partial | |

utterances. | |

""" | |

assert 0 <= overlap < 1 | |

assert 0 < min_pad_coverage <= 1 | |

samples_per_frame = int((sampling_rate * mel_window_step / 1000)) | |

n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) | |

frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1) | |

# Compute the slices | |

wav_slices, mel_slices = [], [] | |

steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1) | |

for i in range(0, steps, frame_step): | |

mel_range = np.array([i, i + partial_utterance_n_frames]) | |

wav_range = mel_range * samples_per_frame | |

mel_slices.append(slice(*mel_range)) | |

wav_slices.append(slice(*wav_range)) | |

# Evaluate whether extra padding is warranted or not | |

last_wav_range = wav_slices[-1] | |

coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start) | |

if coverage < min_pad_coverage and len(mel_slices) > 1: | |

mel_slices = mel_slices[:-1] | |

wav_slices = wav_slices[:-1] | |

return wav_slices, mel_slices | |

def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs): | |

""" | |

Computes an embedding for a single utterance. | |

# TODO: handle multiple wavs to benefit from batching on GPU | |

:param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32 | |

:param using_partials: if True, then the utterance is split in partial utterances of | |

<partial_utterance_n_frames> frames and the utterance embedding is computed from their | |

normalized average. If False, the utterance is instead computed from feeding the entire | |

spectogram to the network. | |

:param return_partials: if True, the partial embeddings will also be returned along with the | |

wav slices that correspond to the partial embeddings. | |

:param kwargs: additional arguments to compute_partial_splits() | |

:return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If | |

<return_partials> is True, the partial utterances as a numpy array of float32 of shape | |

(n_partials, model_embedding_size) and the wav partials as a list of slices will also be | |

returned. If <using_partials> is simultaneously set to False, both these values will be None | |

instead. | |

""" | |

# Process the entire utterance if not using partials | |

if not using_partials: | |

frames = audio.wav_to_mel_spectrogram(wav) | |

embed = embed_frames_batch(frames[None, ...])[0] | |

if return_partials: | |

return embed, None, None | |

return embed | |

# Compute where to split the utterance into partials and pad if necessary | |

wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs) | |

max_wave_length = wave_slices[-1].stop | |

if max_wave_length >= len(wav): | |

wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") | |

# Split the utterance into partials | |

frames = audio.wav_to_mel_spectrogram(wav) | |

frames_batch = np.array([frames[s] for s in mel_slices]) | |

partial_embeds = embed_frames_batch(frames_batch) | |

# Compute the utterance embedding from the partial embeddings | |

raw_embed = np.mean(partial_embeds, axis=0) | |

embed = raw_embed / np.linalg.norm(raw_embed, 2) | |

if return_partials: | |

return embed, partial_embeds, wave_slices | |

return embed | |

def embed_speaker(wavs, **kwargs): | |

raise NotImplemented() | |

def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)): | |

if ax is None: | |

ax = plt.gca() | |

if shape is None: | |

height = int(np.sqrt(len(embed))) | |

shape = (height, -1) | |

embed = embed.reshape(shape) | |

cmap = cm.get_cmap() | |

mappable = ax.imshow(embed, cmap=cmap) | |

cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04) | |

cbar.set_clim(*color_range) | |

ax.set_xticks([]), ax.set_yticks([]) | |

ax.set_title(title) | |