JustinLin610's picture
first commit
ee21b96
raw
history blame
No virus
4.1 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchaudio
# Default hyper-parameters shared by SpeechEmbedder and SpkrEmbedder.
EMBEDDER_PARAMS = dict(
    num_mels=40,  # mel bands: mel filterbank size and LSTM input size
    n_fft=512,  # STFT size for the spectrogram and the mel filterbank
    emb_dim=256,  # output size of the projection (speaker embedding size)
    lstm_hidden=768,  # LSTM hidden size / projection input size
    lstm_layers=3,  # number of stacked LSTM layers
    window=80,  # spectrogram frames per embedder window
    stride=40,  # hop, in frames, between consecutive windows
)
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of one or more networks.

    Passing ``requires_grad=False`` (the default) freezes the networks so no
    gradients are computed for them.

    Parameters:
        nets (nn.Module | list | tuple) -- a single network, or a list/tuple
            of networks; ``None`` entries are skipped
        requires_grad (bool) -- whether the networks require gradients or not
    """
    # Accept tuples as well as lists; anything else is treated as one network.
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad
class LinearNorm(nn.Module):
    """Affine projection from ``hp["lstm_hidden"]`` to ``hp["emb_dim"]`` features.

    No normalization is applied here despite the class name; callers
    normalize the projected vectors themselves.
    """

    def __init__(self, hp):
        super().__init__()
        in_features = hp["lstm_hidden"]
        out_features = hp["emb_dim"]
        # The attribute name determines the state-dict keys, so it must stay
        # ``linear_layer`` for pretrained checkpoints to load.
        self.linear_layer = nn.Linear(in_features, out_features)

    def forward(self, x):
        projected = self.linear_layer(x)
        return projected
class SpeechEmbedder(nn.Module):
    """LSTM encoder mapping one mel spectrogram to a single unit-norm vector.

    The spectrogram is cut into overlapping windows of ``hp["window"]``
    frames with hop ``hp["stride"]``. Each window runs through the LSTM;
    its last output step is projected to ``hp["emb_dim"]`` and L2-normalised.
    The per-window vectors are then averaged and normalised again.
    """

    def __init__(self, hp):
        super(SpeechEmbedder, self).__init__()
        self.lstm = nn.LSTM(hp["num_mels"],
                            hp["lstm_hidden"],
                            num_layers=hp["lstm_layers"],
                            batch_first=True)
        self.proj = LinearNorm(hp)
        self.hp = hp

    def forward(self, mel):
        """Embed ``mel`` of shape (num_mels, T); returns an (emb_dim,) tensor.

        T must be at least ``hp["window"]``, otherwise ``Tensor.unfold`` fails.
        The result has unit L2 norm unless every window embedding is zero,
        in which case a zero vector is returned.
        """
        # (num_mels, T) -> (num_mels, T', window)
        mels = mel.unfold(1, self.hp["window"], self.hp["stride"])
        mels = mels.permute(1, 2, 0)  # (T', window, num_mels)
        x, _ = self.lstm(mels)  # (T', window, lstm_hidden)
        x = x[:, -1, :]  # (T', lstm_hidden), use last frame only
        x = self.proj(x)  # (T', emb_dim)
        # F.normalize clamps the divisor at a tiny eps, so an all-zero window
        # embedding stays zero instead of turning into NaN (0 / 0); non-zero
        # rows are divided by their exact L2 norm as before.
        x = F.normalize(x, p=2, dim=1)  # (T', emb_dim)
        x = x.mean(dim=0)
        norm = x.norm(p=2)
        if norm != 0:
            x = x / norm
        return x
class SpkrEmbedder(nn.Module):
    """Compute one L2-normalised speaker embedding from raw waveforms.

    Each waveform is resampled to ``RATE`` when ``rate`` differs. It is then
    converted to a log10 mel power spectrogram and embedded by a frozen,
    pretrained ``SpeechEmbedder``. The per-waveform embeddings are averaged
    and re-normalised.
    """

    # Sample rate (Hz) the mel filterbank is built for; input audio is
    # resampled to this rate when ``rate`` differs.
    RATE = 16000
    # With ``pad=True``, shorter waveforms are right-padded with zeros to this
    # many samples. With the default hop_length=160 and the centred STFT, this
    # gives 1 + 14000 // 160 = 88 frames, enough for one 80-frame window.
    MIN_SAMPLES = 14000

    def __init__(
        self,
        embedder_path,
        embedder_params=None,
        rate=16000,
        hop_length=160,
        win_length=400,
        pad=False,
    ):
        """
        Args:
            embedder_path: path to a ``SpeechEmbedder`` state dict.
            embedder_params: hyper-parameter dict; ``None`` (default) means
                ``EMBEDDER_PARAMS``.
            rate: sample rate of the waveforms given to ``forward``.
            hop_length: STFT hop size in samples.
            win_length: STFT window size in samples.
            pad: right-pad waveforms shorter than ``MIN_SAMPLES`` with zeros.
        """
        super(SpkrEmbedder, self).__init__()
        if embedder_params is None:
            # Resolved here rather than as a default argument so a shared
            # mutable dict is not bound into the signature.
            embedder_params = EMBEDDER_PARAMS
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only load trusted checkpoints.
        embedder_pt = torch.load(embedder_path, map_location="cpu")
        self.embedder = SpeechEmbedder(embedder_params)
        self.embedder.load_state_dict(embedder_pt)
        self.embedder.eval()
        # The embedder is used frozen; no gradients are computed for it.
        set_requires_grad(self.embedder, requires_grad=False)
        self.embedder_params = embedder_params
        self.register_buffer('mel_basis', torch.from_numpy(
            librosa.filters.mel(
                sr=self.RATE,
                n_fft=self.embedder_params["n_fft"],
                n_mels=self.embedder_params["num_mels"])
        )
        )
        self.resample = None
        if rate != self.RATE:
            self.resample = torchaudio.transforms.Resample(rate, self.RATE)
        self.hop_length = hop_length
        self.win_length = win_length
        self.pad = pad

    def get_mel(self, y):
        """Return the log10 mel power spectrogram of waveform ``y``.

        ``y`` is expected to already be at ``RATE``. A 1-D waveform gives a
        (num_mels, frames) tensor; a leading batch dimension is preserved.
        """
        if self.pad and y.shape[-1] < self.MIN_SAMPLES:
            y = F.pad(y, (0, self.MIN_SAMPLES - y.shape[-1]))
        window = torch.hann_window(self.win_length).to(y)
        # Recent PyTorch requires return_complex for real inputs. |X|^2 of the
        # complex STFT equals the squared L2 norm over the (real, imag) pair
        # that the legacy real-valued output exposed.
        spec = torch.stft(
            y,
            n_fft=self.embedder_params["n_fft"],
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            return_complex=True,
        )
        magnitudes = spec.abs() ** 2
        # The epsilon keeps log10 finite on silent frames.
        mel = torch.log10(self.mel_basis @ magnitudes + 1e-6)
        return mel

    def forward(self, inputs):
        """Embed each waveform in ``inputs`` and return their normalised mean.

        ``inputs`` is iterated one waveform at a time; each waveform is at the
        ``rate`` given to ``__init__``.
        """
        dvecs = []
        for wav in inputs:
            # Bring audio to RATE before padding and the mel filterbank,
            # both of which assume RATE.
            if self.resample is not None:
                wav = self.resample(wav)
            mel = self.get_mel(wav)
            if mel.dim() == 3:
                # Drop a singleton leading batch dim, e.g. from a (1, samples) input.
                mel = mel.squeeze(0)
            dvecs.append(self.embedder(mel))
        dvec = torch.stack(dvecs).mean(dim=0)
        norm = torch.norm(dvec)
        # Leave an all-zero embedding as zeros instead of producing NaN.
        if norm != 0:
            dvec = dvec / norm
        return dvec