Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import librosa | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.utils.data | |
import torchaudio | |
EMBEDDER_PARAMS = { | |
'num_mels': 40, | |
'n_fft': 512, | |
'emb_dim': 256, | |
'lstm_hidden': 768, | |
'lstm_layers': 3, | |
'window': 80, | |
'stride': 40, | |
} | |
def set_requires_grad(nets, requires_grad=False): | |
"""Set requies_grad=Fasle for all the networks to avoid unnecessary | |
computations | |
Parameters: | |
nets (network list) -- a list of networks | |
requires_grad (bool) -- whether the networks require gradients or not | |
""" | |
if not isinstance(nets, list): | |
nets = [nets] | |
for net in nets: | |
if net is not None: | |
for param in net.parameters(): | |
param.requires_grad = requires_grad | |
class LinearNorm(nn.Module): | |
def __init__(self, hp): | |
super(LinearNorm, self).__init__() | |
self.linear_layer = nn.Linear(hp["lstm_hidden"], hp["emb_dim"]) | |
def forward(self, x): | |
return self.linear_layer(x) | |
class SpeechEmbedder(nn.Module): | |
def __init__(self, hp): | |
super(SpeechEmbedder, self).__init__() | |
self.lstm = nn.LSTM(hp["num_mels"], | |
hp["lstm_hidden"], | |
num_layers=hp["lstm_layers"], | |
batch_first=True) | |
self.proj = LinearNorm(hp) | |
self.hp = hp | |
def forward(self, mel): | |
# (num_mels, T) -> (num_mels, T', window) | |
mels = mel.unfold(1, self.hp["window"], self.hp["stride"]) | |
mels = mels.permute(1, 2, 0) # (T', window, num_mels) | |
x, _ = self.lstm(mels) # (T', window, lstm_hidden) | |
x = x[:, -1, :] # (T', lstm_hidden), use last frame only | |
x = self.proj(x) # (T', emb_dim) | |
x = x / torch.norm(x, p=2, dim=1, keepdim=True) # (T', emb_dim) | |
x = x.mean(dim=0) | |
if x.norm(p=2) != 0: | |
x = x / x.norm(p=2) | |
return x | |
class SpkrEmbedder(nn.Module): | |
RATE = 16000 | |
def __init__( | |
self, | |
embedder_path, | |
embedder_params=EMBEDDER_PARAMS, | |
rate=16000, | |
hop_length=160, | |
win_length=400, | |
pad=False, | |
): | |
super(SpkrEmbedder, self).__init__() | |
embedder_pt = torch.load(embedder_path, map_location="cpu") | |
self.embedder = SpeechEmbedder(embedder_params) | |
self.embedder.load_state_dict(embedder_pt) | |
self.embedder.eval() | |
set_requires_grad(self.embedder, requires_grad=False) | |
self.embedder_params = embedder_params | |
self.register_buffer('mel_basis', torch.from_numpy( | |
librosa.filters.mel( | |
sr=self.RATE, | |
n_fft=self.embedder_params["n_fft"], | |
n_mels=self.embedder_params["num_mels"]) | |
) | |
) | |
self.resample = None | |
if rate != self.RATE: | |
self.resample = torchaudio.transforms.Resample(rate, self.RATE) | |
self.hop_length = hop_length | |
self.win_length = win_length | |
self.pad = pad | |
def get_mel(self, y): | |
if self.pad and y.shape[-1] < 14000: | |
y = F.pad(y, (0, 14000 - y.shape[-1])) | |
window = torch.hann_window(self.win_length).to(y) | |
y = torch.stft(y, n_fft=self.embedder_params["n_fft"], | |
hop_length=self.hop_length, | |
win_length=self.win_length, | |
window=window) | |
magnitudes = torch.norm(y, dim=-1, p=2) ** 2 | |
mel = torch.log10(self.mel_basis @ magnitudes + 1e-6) | |
return mel | |
def forward(self, inputs): | |
dvecs = [] | |
for wav in inputs: | |
mel = self.get_mel(wav) | |
if mel.dim() == 3: | |
mel = mel.squeeze(0) | |
dvecs += [self.embedder(mel)] | |
dvecs = torch.stack(dvecs) | |
dvec = torch.mean(dvecs, dim=0) | |
dvec = dvec / torch.norm(dvec) | |
return dvec | |