# speaker-embedding / xvector_sincnet.py
# Author: nguyenvulebinh — "upload infer utils" (commit fd06d88)
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
from asteroid_filterbanks import Encoder, ParamSincFB
def merge_dict(defaults: dict, custom: Optional[dict] = None) -> dict:
    """Return a copy of ``defaults``, overridden by entries from ``custom``.

    Parameters
    ----------
    defaults : dict
        Base key/value pairs.
    custom : dict, optional
        Overrides; keys present here replace those in ``defaults``.
        Annotated as ``Optional[dict]`` (the previous ``dict = None``
        annotation was incorrect — ``None`` is not a ``dict``).

    Returns
    -------
    dict
        A new dict; neither input is mutated.
    """
    params = dict(defaults)
    if custom is not None:
        params.update(custom)
    return params
class StatsPool(nn.Module):
    """Statistics pooling layer.

    Collapses the temporal axis by concatenating the per-channel mean
    with the (unbiased) standard deviation, optionally using per-frame
    weights.

    Reference
    ---------
    https://en.wikipedia.org/wiki/Weighted_arithmetic_mean
    """

    def forward(
        self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Pool frame-level features into a single statistics vector.

        Parameters
        ----------
        sequences : (batch, channel, frames) torch.Tensor
            Frame-level features.
        weights : (batch, frames) torch.Tensor, optional
            When provided, mean and standard deviation are weighted
            per frame.

        Returns
        -------
        (batch, 2 * channel) torch.Tensor
            Mean concatenated with (unbiased) standard deviation.
        """
        # Unweighted path: plain temporal statistics.
        if weights is None:
            return torch.cat(
                [sequences.mean(dim=2), sequences.std(dim=2, unbiased=True)],
                dim=1,
            )

        # (batch, frames) -> (batch, 1, frames) so weights broadcast
        # across channels.
        weights = weights.unsqueeze(dim=1)

        n_frames = sequences.shape[2]
        n_weights = weights.shape[2]
        if n_frames != n_weights:
            warnings.warn(
                f"Mismatch between frames ({n_frames}) and weights ({n_weights}) numbers."
            )
            # Resample the weight curve onto the frame axis.
            weights = F.interpolate(
                weights, size=n_frames, mode="linear", align_corners=False
            )

        w_sum = weights.sum(dim=2)
        w_mean = torch.sum(sequences * weights, dim=2) / w_sum

        # Unbiased weighted variance (reliability-weights correction:
        # denominator is V1 - V2/V1).
        centered_sq = torch.square(sequences - w_mean.unsqueeze(2))
        w_sq_sum = torch.square(weights).sum(dim=2)
        w_var = torch.sum(centered_sq * weights, dim=2) / (w_sum - w_sq_sum / w_sum)

        return torch.cat([w_mean, torch.sqrt(w_var)], dim=1)
class SincNet(nn.Module):
    """SincNet front-end: a learnable sinc filterbank followed by two
    plain Conv1d stages, each with max-pooling, instance norm, and
    leaky-ReLU.
    """

    def __init__(self, sample_rate: int = 16000, stride: int = 1):
        super().__init__()

        if sample_rate != 16000:
            raise NotImplementedError("PyanNet only supports 16kHz audio for now.")
        # TODO: add support for other sample rate. it should be enough to multiply
        # kernel_size by (sample_rate / 16000). but this needs to be double-checked.

        self.stride = stride

        # Per-sample normalization of the raw waveform.
        self.wav_norm1d = nn.InstanceNorm1d(1, affine=True)

        # Stage 0 is the parametrized sinc filterbank; stages 1-2 are
        # ordinary convolutions.
        stages = [
            Encoder(
                ParamSincFB(
                    80,
                    251,
                    stride=self.stride,
                    sample_rate=sample_rate,
                    min_low_hz=50,
                    min_band_hz=50,
                )
            ),
            nn.Conv1d(80, 60, 5, stride=1),
            nn.Conv1d(60, 60, 5, stride=1),
        ]
        out_channels = [80, 60, 60]

        self.conv1d = nn.ModuleList(stages)
        self.pool1d = nn.ModuleList(
            nn.MaxPool1d(3, stride=3, padding=0, dilation=1) for _ in out_channels
        )
        self.norm1d = nn.ModuleList(
            nn.InstanceNorm1d(ch, affine=True) for ch in out_channels
        )

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        """Extract frame-level features from raw audio.

        Parameters
        ----------
        waveforms : (batch, channel, sample) torch.Tensor
            Raw waveform batch.
        """
        x = self.wav_norm1d(waveforms)
        for stage, (conv, pool, norm) in enumerate(
            zip(self.conv1d, self.pool1d, self.norm1d)
        ):
            x = conv(x)
            # Take the magnitude of the first (sinc filterbank) stage only:
            # https://github.com/mravanelli/SincNet/issues/4
            if stage == 0:
                x = torch.abs(x)
            x = F.leaky_relu(norm(pool(x)))
        return x
class XVectorSincNet(nn.Module):
    """x-vector speaker embedding model with a SincNet front-end.

    SincNet features are passed through a stack of TDNN (dilated Conv1d)
    layers, statistics-pooled over time, and projected to a fixed-size
    embedding.

    Parameters
    ----------
    sample_rate : int
        Audio sample rate. Only 16000 is supported by SincNet.
    sincnet : dict, optional
        Keyword arguments forwarded to ``SincNet``. Defaults to
        ``{"stride": 10, "sample_rate": 16000}``.
    dimension : int
        Dimension of the output embedding.
    """

    SINCNET_DEFAULTS = {"stride": 10}

    def __init__(
        self,
        sample_rate: int = 16000,
        sincnet: Optional[dict] = None,
        dimension: int = 512,
    ):
        super().__init__()

        # Fix for the mutable-default-argument pitfall: the original
        # signature used `sincnet: dict = dict(stride=10, sample_rate=16000)`.
        # The same default is reproduced here via a None sentinel.
        if sincnet is None:
            sincnet = dict(stride=10, sample_rate=16000)
        sincnet = merge_dict(self.SINCNET_DEFAULTS, sincnet)
        # The explicit `sample_rate` argument always wins over the dict.
        sincnet["sample_rate"] = sample_rate

        self.sincnet = SincNet(**sincnet)

        # Standard x-vector TDNN stack: five Conv1d layers with
        # increasing receptive field, each followed by LeakyReLU + BN.
        in_channel = 60  # SincNet output channels
        self.tdnns = nn.ModuleList()
        out_channels = [512, 512, 512, 512, 1500]
        kernel_sizes = [5, 3, 3, 1, 1]
        dilations = [1, 2, 3, 1, 1]

        for out_channel, kernel_size, dilation in zip(
            out_channels, kernel_sizes, dilations
        ):
            self.tdnns.extend(
                [
                    nn.Conv1d(
                        in_channels=in_channel,
                        out_channels=out_channel,
                        kernel_size=kernel_size,
                        dilation=dilation,
                    ),
                    nn.LeakyReLU(),
                    nn.BatchNorm1d(out_channel),
                ]
            )
            in_channel = out_channel

        self.stats_pool = StatsPool()
        # StatsPool concatenates mean and std, hence `in_channel * 2`.
        self.embedding = nn.Linear(in_channel * 2, dimension)

    def forward(
        self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute speaker embeddings.

        Parameters
        ----------
        waveforms : torch.Tensor
            Batch of waveforms with shape (batch, channel, sample).
        weights : torch.Tensor, optional
            Batch of frame weights with shape (batch, frame).

        Returns
        -------
        torch.Tensor
            (batch, dimension) speaker embeddings.
        """
        outputs = self.sincnet(waveforms).squeeze(dim=1)
        for tdnn in self.tdnns:
            outputs = tdnn(outputs)
        outputs = self.stats_pool(outputs, weights=weights)
        return self.embedding(outputs)
""" Load model
def cal_xvector_sincnet_embedding(xvector_model, ref_wav, max_length=5, sr=16000):
wavs = []
for i in range(0, len(ref_wav), max_length*sr):
wav = ref_wav[i:i + max_length*sr]
wav = np.concatenate([wav, np.zeros(max(0, max_length * sr - len(wav)))])
wavs.append(wav)
wavs = torch.from_numpy(np.stack(wavs))
if use_gpu:
wavs = wavs.cuda()
embed = xvector_model(wavs.unsqueeze(1).float())
return torch.mean(embed, dim=0).detach().cpu()
xvector_model = XVectorSincNet()
model_file = "model-bin/speaker_embedding/xvector_sincnet.pt"
meta = torch.load(model_file, map_location='cpu')['state_dict']
print('load_xvector_sincnet_model', xvector_model.load_state_dict(meta, strict=False))
xvector_model = xvector_model.eval()
for param in xvector_model.parameters():
param.requires_grad = False
"""