# tdnn-aam / modeling_xvector.py
import math
import torch
import typing as tp
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from .helpers_xvector import Fbank
from .configuration_xvector import XvectorConfig
class InputNormalization(nn.Module):
spk_dict_mean: tp.Dict[int, torch.Tensor]
spk_dict_std: tp.Dict[int, torch.Tensor]
spk_dict_count: tp.Dict[int, int]
def __init__(
self,
mean_norm=True,
std_norm=True,
norm_type="global",
avg_factor=None,
requires_grad=False,
update_until_epoch=3,
):
super().__init__()
self.mean_norm = mean_norm
self.std_norm = std_norm
self.norm_type = norm_type
self.avg_factor = avg_factor
self.requires_grad = requires_grad
self.glob_mean = torch.tensor([0])
self.glob_std = torch.tensor([0])
self.spk_dict_mean = {}
self.spk_dict_std = {}
self.spk_dict_count = {}
self.weight = 1.0
self.count = 0
self.eps = 1e-10
self.update_until_epoch = update_until_epoch
def forward(self, input_values, lengths=None, spk_ids=torch.tensor([]), epoch=0):
"""Returns the tensor with the surrounding context.
Arguments
---------
x : tensor
A batch of tensors.
lengths : tensor
A batch of tensors containing the relative length of each
sentence (e.g, [0.7, 0.9, 1.0]). It is used to avoid
computing stats on zero-padded steps.
spk_ids : tensor containing the ids of each speaker (e.g, [0 10 6]).
It is used to perform per-speaker normalization when
norm_type='speaker'.
"""
        x = input_values
        N_batches = x.shape[0]
        if lengths is None:
            # Assume no padding: every sentence spans the full time dimension.
            lengths = torch.ones(N_batches, device=x.device)
        current_means = []
        current_stds = []
        for snt_id in range(N_batches):
            # Avoiding padded time steps
            actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()
# computing statistics
current_mean, current_std = self._compute_current_stats(
x[snt_id, 0:actual_size, ...]
)
current_means.append(current_mean)
current_stds.append(current_std)
if self.norm_type == "sentence":
x[snt_id] = (x[snt_id] - current_mean.data) / current_std.data
if self.norm_type == "speaker":
spk_id = int(spk_ids[snt_id][0])
if self.training:
if spk_id not in self.spk_dict_mean:
# Initialization of the dictionary
self.spk_dict_mean[spk_id] = current_mean
self.spk_dict_std[spk_id] = current_std
self.spk_dict_count[spk_id] = 1
else:
self.spk_dict_count[spk_id] = (
self.spk_dict_count[spk_id] + 1
)
if self.avg_factor is None:
self.weight = 1 / self.spk_dict_count[spk_id]
else:
self.weight = self.avg_factor
self.spk_dict_mean[spk_id] = (
(1 - self.weight) * self.spk_dict_mean[spk_id]
+ self.weight * current_mean
)
self.spk_dict_std[spk_id] = (
(1 - self.weight) * self.spk_dict_std[spk_id]
+ self.weight * current_std
)
                    # detach() is not in-place; reassign to drop the autograd graph
                    self.spk_dict_mean[spk_id] = self.spk_dict_mean[spk_id].detach()
                    self.spk_dict_std[spk_id] = self.spk_dict_std[spk_id].detach()
speaker_mean = self.spk_dict_mean[spk_id].data
speaker_std = self.spk_dict_std[spk_id].data
else:
if spk_id in self.spk_dict_mean:
speaker_mean = self.spk_dict_mean[spk_id].data
speaker_std = self.spk_dict_std[spk_id].data
else:
speaker_mean = current_mean.data
speaker_std = current_std.data
x[snt_id] = (x[snt_id] - speaker_mean) / speaker_std
if self.norm_type == "batch" or self.norm_type == "global":
current_mean = torch.mean(torch.stack(current_means), dim=0)
current_std = torch.mean(torch.stack(current_stds), dim=0)
if self.norm_type == "batch":
x = (x - current_mean.data) / (current_std.data)
if self.norm_type == "global":
if self.training:
if self.count == 0:
self.glob_mean = current_mean
self.glob_std = current_std
elif epoch < self.update_until_epoch:
if self.avg_factor is None:
self.weight = 1 / (self.count + 1)
else:
self.weight = self.avg_factor
self.glob_mean = (
1 - self.weight
) * self.glob_mean + self.weight * current_mean
self.glob_std = (
1 - self.weight
) * self.glob_std + self.weight * current_std
                    # detach() is not in-place; reassign to drop the autograd graph
                    self.glob_mean = self.glob_mean.detach()
                    self.glob_std = self.glob_std.detach()
self.count = self.count + 1
x = (x - self.glob_mean.data) / (self.glob_std.data)
return x
def _compute_current_stats(self, x):
"""Returns the tensor with the surrounding context.
Arguments
---------
x : tensor
A batch of tensors.
"""
# Compute current mean
if self.mean_norm:
current_mean = torch.mean(x, dim=0).detach().data
else:
current_mean = torch.tensor([0.0], device=x.device)
# Compute current std
if self.std_norm:
current_std = torch.std(x, dim=0).detach().data
else:
current_std = torch.tensor([1.0], device=x.device)
# Improving numerical stability of std
current_std = torch.max(
current_std, self.eps * torch.ones_like(current_std)
)
return current_mean, current_std
def _statistics_dict(self):
"""Fills the dictionary containing the normalization statistics."""
state = {}
state["count"] = self.count
state["glob_mean"] = self.glob_mean
state["glob_std"] = self.glob_std
state["spk_dict_mean"] = self.spk_dict_mean
state["spk_dict_std"] = self.spk_dict_std
state["spk_dict_count"] = self.spk_dict_count
return state
def _load_statistics_dict(self, state):
"""Loads the dictionary containing the statistics.
Arguments
---------
state : dict
A dictionary containing the normalization statistics.
"""
self.count = state["count"]
if isinstance(state["glob_mean"], int):
self.glob_mean = state["glob_mean"]
self.glob_std = state["glob_std"]
else:
self.glob_mean = state["glob_mean"] # .to(self.device_inp)
self.glob_std = state["glob_std"] # .to(self.device_inp)
        # Move the speaker statistics to the same device as the global stats
        device_inp = (
            self.glob_mean.device
            if torch.is_tensor(self.glob_mean)
            else torch.device("cpu")
        )
        self.spk_dict_mean = {}
        for spk in state["spk_dict_mean"]:
            self.spk_dict_mean[spk] = state["spk_dict_mean"][spk].to(device_inp)
        self.spk_dict_std = {}
        for spk in state["spk_dict_std"]:
            self.spk_dict_std[spk] = state["spk_dict_std"][spk].to(device_inp)
self.spk_dict_count = state["spk_dict_count"]
return state
def to(self, device):
"""Puts the needed tensors in the right device."""
self = super(InputNormalization, self).to(device)
self.glob_mean = self.glob_mean.to(device)
self.glob_std = self.glob_std.to(device)
for spk in self.spk_dict_mean:
self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
return self
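# Usage sketch (illustrative helper, not part of the original file): `lengths`
# holds the relative length of each utterance so that statistics are computed
# only on un-padded frames. Feature shape assumed here: (batch, time, n_mels).
def _example_input_normalization():
    norm = InputNormalization(norm_type="global")
    feats = torch.randn(2, 100, 40)     # (batch, time, features)
    lengths = torch.tensor([0.7, 1.0])  # relative lengths per utterance
    normalized = norm(feats, lengths)   # same shape as `feats`
    return normalized.shape             # torch.Size([2, 100, 40])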
class TdnnLayer(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
dilation=1,
stride=1,
padding=0,
padding_mode="reflect",
activation=torch.nn.LeakyReLU,
):
super(TdnnLayer, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.dilation = dilation
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
        self.activation = activation()
self.conv = nn.Conv1d(
self.in_channels,
self.out_channels,
self.kernel_size,
dilation=self.dilation,
padding=self.padding
)
# Set Affine=false to be compatible with the original kaldi version
# self.ln = nn.LayerNorm(out_channels, elementwise_affine=False)
self.norm = nn.BatchNorm1d(out_channels, affine=False)
def forward(self, x):
x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
out = self.conv(x)
        out = self.activation(out)
out = self.norm(out)
return out
def _manage_padding(
self, x, kernel_size: int, dilation: int, stride: int,
):
        # With stride=1 the padding amount returned by `get_padding_elem`
        # does not depend on the sequence length, so the channel count is
        # used as a stand-in for L_in here.
        L_in = self.in_channels
# Time padding
padding = get_padding_elem(L_in, stride, kernel_size, dilation)
# Applying padding
x = F.pad(x, padding, mode=self.padding_mode)
return x
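# Illustrative sketch (not part of the original file): with stride=1 the
# reflect padding applied in `_manage_padding` keeps the time dimension
# unchanged, so stacked TDNN layers preserve the frame count.
def _example_tdnn_layer():
    layer = TdnnLayer(in_channels=40, out_channels=512, kernel_size=5, dilation=1)
    layer.eval()                 # use BatchNorm running stats for determinism
    x = torch.randn(3, 40, 200)  # (batch, features, time)
    y = layer(x)
    return y.shape               # torch.Size([3, 512, 200])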
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
"""This function computes the number of elements to add for zero-padding.
Arguments
---------
L_in : int
stride: int
kernel_size : int
dilation : int
"""
if stride > 1:
padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]
else:
L_out = (
math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1
)
padding = [
math.floor((L_in - L_out) / 2),
math.floor((L_in - L_out) / 2),
]
return padding
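# For example, kernel_size=3 with dilation=2 and stride=1 gives padding [2, 2],
# which keeps the Conv1d output length equal to its input length.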
class StatisticsPooling(nn.Module):
def __init__(self, return_mean=True, return_std=True):
super().__init__()
# Small value for GaussNoise
self.eps = 1e-5
self.return_mean = return_mean
self.return_std = return_std
if not (self.return_mean or self.return_std):
raise ValueError(
"both of statistics are equal to False \n"
"consider enabling mean and/or std statistic pooling"
)
def forward(self, input_values, lengths=None):
"""Calculates mean and std for a batch (input tensor).
Arguments
---------
x : torch.Tensor
It represents a tensor for a mini-batch.
"""
x = input_values
if lengths is None:
if self.return_mean:
mean = x.mean(dim=1)
if self.return_std:
std = x.std(dim=1)
else:
mean = []
std = []
for snt_id in range(x.shape[0]):
# Avoiding padded time steps
actual_size = int(torch.round(lengths[snt_id] * x.shape[1]))
# computing statistics
if self.return_mean:
mean.append(
torch.mean(x[snt_id, 0:actual_size, ...], dim=0)
)
if self.return_std:
std.append(torch.std(x[snt_id, 0:actual_size, ...], dim=0))
if self.return_mean:
mean = torch.stack(mean)
if self.return_std:
std = torch.stack(std)
        if self.return_mean:
            # Add a small amount of Gaussian noise for numerical stability
            gnoise = self._get_gauss_noise(mean.size(), device=mean.device)
            mean += gnoise
if self.return_std:
std = std + self.eps
# Append mean and std of the batch
if self.return_mean and self.return_std:
pooled_stats = torch.cat((mean, std), dim=1)
pooled_stats = pooled_stats.unsqueeze(1)
elif self.return_mean:
pooled_stats = mean.unsqueeze(1)
elif self.return_std:
pooled_stats = std.unsqueeze(1)
return pooled_stats
def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
"""Returns a tensor of epsilon Gaussian noise.
Arguments
---------
        shape_of_tensor : torch.Size
            The size of the Gaussian noise tensor to generate.
"""
gnoise = torch.randn(shape_of_tensor, device=device)
gnoise -= torch.min(gnoise)
gnoise /= torch.max(gnoise)
        # Rescale the [0, 1] noise into the range [eps, 9 * eps]
        gnoise = self.eps * ((1 - 9) * gnoise + 9)
return gnoise
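# Usage sketch (illustrative helper, not part of the original file): statistics
# pooling maps a variable-length sequence of frame features to one
# utterance-level vector by concatenating the per-feature mean and std.
def _example_statistics_pooling():
    pool = StatisticsPooling()
    x = torch.randn(2, 150, 1500)       # (batch, time, features)
    lengths = torch.tensor([0.8, 1.0])  # relative lengths per utterance
    pooled = pool(x, lengths)
    return pooled.shape                 # torch.Size([2, 1, 3000])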
class XvectorEmbedder(nn.Module):
def __init__(
self,
in_channels=40,
activation=torch.nn.LeakyReLU,
tdnn_blocks=5,
tdnn_channels=[512, 512, 512, 512, 1500],
tdnn_kernel_sizes=[5, 3, 3, 1, 1],
tdnn_dilations=[1, 2, 3, 1, 1],
hidden_size=512,
) -> None:
super(XvectorEmbedder, self).__init__()
self.activation = activation
self.blocks = nn.ModuleList()
for block_index in range(tdnn_blocks):
out_channels = tdnn_channels[block_index]
tdnn = TdnnLayer(
in_channels,
out_channels,
kernel_size=tdnn_kernel_sizes[block_index],
dilation=tdnn_dilations[block_index],
activation=activation,
)
self.blocks.append(tdnn)
in_channels = tdnn_channels[block_index]
self.pooler = StatisticsPooling()
self.fc = nn.Linear(2 * out_channels, hidden_size)
def forward(self, input_values, lengths=None):
x = input_values
x = x.permute(0, 2, 1) # (B, T, F) -> (B, F, T)
for block in self.blocks:
x = block(x)
last_hidden_state = x.permute(0, 2, 1) # (B, F, T) -> (B, T, F)
pooler_output = self.pooler(last_hidden_state, lengths)
pooler_output = self.fc(pooler_output.squeeze(1))
return ModelOutput(
last_hidden_state=last_hidden_state,
pooler_output=pooler_output
)
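# Usage sketch (illustrative helper, not part of the original file): the
# embedder maps (batch, time, n_mels) features to frame-level TDNN activations
# and a fixed-size x-vector per utterance.
def _example_xvector_embedder():
    embedder = XvectorEmbedder(in_channels=40, hidden_size=512)
    embedder.eval()
    feats = torch.randn(2, 300, 40)  # (batch, time, n_mels)
    out = embedder(feats)
    # out.last_hidden_state: (2, 300, 1500); out.pooler_output: (2, 512)
    return out.pooler_output.shape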
class CosineSimilarityHead(torch.nn.Module):
"""
This class implements the cosine similarity on the top of features.
"""
def __init__(
self,
in_channels,
lin_blocks=0,
hidden_size=192,
num_classes=1211,
):
super().__init__()
self.blocks = nn.ModuleList()
for block_index in range(lin_blocks):
self.blocks.extend(
[
nn.BatchNorm1d(num_features=in_channels),
nn.Linear(in_features=in_channels, out_features=hidden_size),
]
)
in_channels = hidden_size
# Final Layer
self.weight = nn.Parameter(
torch.FloatTensor(num_classes, in_channels)
)
nn.init.xavier_uniform_(self.weight)
def forward(self, x):
"""Returns the output probabilities over speakers.
Arguments
---------
x : torch.Tensor
Torch tensor.
"""
for layer in self.blocks:
x = layer(x)
        # Normalize both embeddings and weights so the product is a cosine similarity
x = F.linear(F.normalize(x), F.normalize(self.weight))
return x
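# Usage sketch (illustrative helper, not part of the original file): the head
# returns cosine similarities in [-1, 1] between each x-vector and every class
# weight vector; a margin-based softmax loss (e.g. AAM-softmax) is typically
# applied on top of these scores during training.
def _example_cosine_head():
    head = CosineSimilarityHead(in_channels=512, num_classes=1211)
    emb = torch.randn(4, 512)  # batch of x-vectors
    scores = head(emb)         # (4, 1211) cosine similarities
    return scores.shape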
class XvectorPreTrainedModel(PreTrainedModel):
config_class = XvectorConfig
base_model_prefix = "xvector"
main_input_name = "input_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
nn.init.kaiming_normal_(module.weight.data)
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()
class XvectorModel(XvectorPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.compute_features = Fbank(
n_mels=config.n_mels,
sample_rate=config.sample_rate,
win_length=config.win_length,
hop_length=config.hop_length,
)
self.mean_var_norm = InputNormalization(
mean_norm=config.mean_norm,
std_norm=config.std_norm,
norm_type=config.norm_type
)
self.embedding_model = XvectorEmbedder(
in_channels=config.n_mels,
activation=nn.LeakyReLU,
tdnn_blocks=config.tdnn_blocks,
tdnn_channels=config.tdnn_channels,
tdnn_kernel_sizes=config.tdnn_kernel_sizes,
tdnn_dilations=config.tdnn_dilations,
hidden_size=config.hidden_size,
)
def forward(self, input_values, lengths=None):
x = input_values
# if attention_mask is None:
# attention_mask = torch.ones_like(input_values, device=x.device)
x = self.compute_features(x)
x = self.mean_var_norm(x, lengths)
output = self.embedding_model(x, lengths)
return output
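# End-to-end usage sketch (kept as comments because it assumes `XvectorConfig`
# provides defaults for the fields read above, e.g. n_mels, sample_rate,
# tdnn_channels, hidden_size):
#
#   config = XvectorConfig()
#   model = XvectorModel(config).eval()
#   waveforms = torch.randn(2, 16000)   # (batch, samples) raw audio
#   lengths = torch.tensor([1.0, 1.0])  # relative lengths
#   with torch.no_grad():
#       output = model(waveforms, lengths)
#   embeddings = output.pooler_output   # (batch, hidden_size) x-vectors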