import math
import typing as tp

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from .helpers_svector import Fbank
from .configuration_svector import SvectorConfig


class InputNormalization(nn.Module):
    """Mean/variance normalization of a batch of feature sequences.

    Statistics are computed per utterance over dimension 1 (the time axis),
    ignoring padded steps, and then applied according to ``norm_type``:

    * ``"sentence"`` -- each utterance is normalized with its own statistics.
    * ``"speaker"``  -- each utterance is normalized with running statistics
      accumulated per speaker id (updated only in training mode).
    * ``"batch"``    -- the batch-averaged statistics are applied to all
      utterances.
    * ``"global"``   -- running statistics accumulated across batches
      (updated only in training mode; after the first batch, only while
      ``epoch < update_until_epoch``).

    Arguments
    ---------
    mean_norm : bool
        If True, subtract the mean; otherwise a mean of 0 is used.
    std_norm : bool
        If True, divide by the standard deviation; otherwise a std of 1 is
        used.
    norm_type : str
        One of ``"sentence"``, ``"speaker"``, ``"batch"`` or ``"global"``.
    avg_factor : float, optional
        Fixed weight of the newest statistics in the running-average updates.
        When None, the weight is ``1 / count`` (a cumulative average).
    requires_grad : bool
        Stored on the module; not otherwise used by this class.
    update_until_epoch : int
        Epoch after which the global running statistics stop being updated.
    """

    spk_dict_mean: tp.Dict[int, torch.Tensor]
    spk_dict_std: tp.Dict[int, torch.Tensor]
    spk_dict_count: tp.Dict[int, int]

    def __init__(
        self,
        mean_norm=True,
        std_norm=True,
        norm_type="global",
        avg_factor=None,
        requires_grad=False,
        update_until_epoch=3,
    ):
        super().__init__()
        self.mean_norm = mean_norm
        self.std_norm = std_norm
        self.norm_type = norm_type
        self.avg_factor = avg_factor
        self.requires_grad = requires_grad
        # Global running statistics. These are plain tensor attributes, not
        # registered buffers, so they are not part of state_dict() and are
        # not moved by a parent module's .to().
        self.glob_mean = torch.tensor([0])
        self.glob_std = torch.tensor([0])
        # Per-speaker running statistics, keyed by integer speaker id.
        self.spk_dict_mean = {}
        self.spk_dict_std = {}
        self.spk_dict_count = {}
        # Weight given to the newest statistics in the last running update.
        self.weight = 1.0
        # Number of training batches processed in "global" mode.
        self.count = 0
        # Lower bound applied to the std to avoid division by zero.
        self.eps = 1e-10
        self.update_until_epoch = update_until_epoch

    def forward(self, input_values, lengths=None, spk_ids=None, epoch=0):
        """Normalize a batch of feature sequences.

        Arguments
        ---------
        input_values : torch.Tensor
            Batch of features. Dimension 0 indexes utterances and dimension 1
            is the time axis over which statistics are computed
            (e.g. ``[batch, time, features]``).
        lengths : torch.Tensor, optional
            Relative length of each utterance (e.g. ``[0.7, 0.9, 1.0]``), used
            to exclude zero-padded steps from the statistics. When None, every
            utterance is treated as unpadded.
        spk_ids : torch.Tensor, optional
            Speaker id of each utterance, read as ``int(spk_ids[i][0])``.
            Required when ``norm_type="speaker"``.
        epoch : int
            Current epoch; only used by ``norm_type="global"``.

        Returns
        -------
        torch.Tensor
            The normalized features, same shape as ``input_values``. The
            input tensor itself is not modified.

        Raises
        ------
        ValueError
            If ``norm_type="speaker"`` and ``spk_ids`` is None.
        """
        x = input_values
        n_batches = x.shape[0]

        if lengths is None:
            # No padding information: use every time step of every utterance.
            lengths = torch.ones(n_batches, device=x.device)
        if self.norm_type == "speaker" and spk_ids is None:
            raise ValueError(
                "spk_ids must be provided when norm_type='speaker'"
            )
        if self.norm_type in ("sentence", "speaker"):
            # These modes overwrite rows in place; work on a copy so the
            # caller's tensor is left untouched.
            x = x.clone()

        current_means = []
        current_stds = []
        for snt_id in range(n_batches):
            # Only the first round(length * T) steps are real data.
            actual_size = int(torch.round(lengths[snt_id] * x.shape[1]))
            current_mean, current_std = self._compute_current_stats(
                x[snt_id, 0:actual_size, ...]
            )
            current_means.append(current_mean)
            current_stds.append(current_std)

            if self.norm_type == "sentence":
                x[snt_id] = (x[snt_id] - current_mean.data) / current_std.data

            if self.norm_type == "speaker":
                spk_id = int(spk_ids[snt_id][0])
                if self.training:
                    self._update_speaker_stats(
                        spk_id, current_mean, current_std
                    )
                    speaker_mean = self.spk_dict_mean[spk_id].data
                    speaker_std = self.spk_dict_std[spk_id].data
                elif spk_id in self.spk_dict_mean:
                    speaker_mean = self.spk_dict_mean[spk_id].data
                    speaker_std = self.spk_dict_std[spk_id].data
                else:
                    # Unseen speaker at eval time: fall back to the
                    # utterance's own statistics.
                    speaker_mean = current_mean.data
                    speaker_std = current_std.data
                x[snt_id] = (x[snt_id] - speaker_mean) / speaker_std

        if self.norm_type in ("batch", "global"):
            current_mean = torch.mean(torch.stack(current_means), dim=0)
            current_std = torch.mean(torch.stack(current_stds), dim=0)

            if self.norm_type == "batch":
                x = (x - current_mean.data) / current_std.data
            else:
                if self.training:
                    self._update_global_stats(current_mean, current_std, epoch)
                # The global stats are not buffers, so a parent module's
                # .to() does not move them; align them with the input.
                glob_mean = self.glob_mean.data.to(x.device)
                glob_std = self.glob_std.data.to(x.device)
                x = (x - glob_mean) / glob_std

        return x

    def _update_speaker_stats(self, spk_id, current_mean, current_std):
        """Fold one utterance's statistics into the running stats of ``spk_id``."""
        if spk_id not in self.spk_dict_mean:
            # First utterance of this speaker: initialize its entry.
            self.spk_dict_mean[spk_id] = current_mean
            self.spk_dict_std[spk_id] = current_std
            self.spk_dict_count[spk_id] = 1
            return

        self.spk_dict_count[spk_id] = self.spk_dict_count[spk_id] + 1
        if self.avg_factor is None:
            self.weight = 1 / self.spk_dict_count[spk_id]
        else:
            self.weight = self.avg_factor

        self.spk_dict_mean[spk_id] = (
            (1 - self.weight) * self.spk_dict_mean[spk_id]
            + self.weight * current_mean
        ).detach()
        self.spk_dict_std[spk_id] = (
            (1 - self.weight) * self.spk_dict_std[spk_id]
            + self.weight * current_std
        ).detach()

    def _update_global_stats(self, current_mean, current_std, epoch):
        """Fold one batch's statistics into the global running statistics."""
        if self.count == 0:
            self.glob_mean = current_mean
            self.glob_std = current_std
        elif epoch < self.update_until_epoch:
            if self.avg_factor is None:
                self.weight = 1 / (self.count + 1)
            else:
                self.weight = self.avg_factor
            self.glob_mean = (
                (1 - self.weight) * self.glob_mean + self.weight * current_mean
            ).detach()
            self.glob_std = (
                (1 - self.weight) * self.glob_std + self.weight * current_std
            ).detach()
        self.count = self.count + 1

    def _compute_current_stats(self, x):
        """Return the (mean, std) of ``x`` along dimension 0.

        Arguments
        ---------
        x : torch.Tensor
            Unpadded frames of a single utterance; dimension 0 is time.

        Returns
        -------
        tuple of torch.Tensor
            Detached mean (or ``[0.0]`` if ``mean_norm`` is False) and std
            (or ``[1.0]`` if ``std_norm`` is False), with the std clamped
            below at ``self.eps``.
        """
        if self.mean_norm:
            current_mean = torch.mean(x, dim=0).detach().data
        else:
            current_mean = torch.tensor([0.0], device=x.device)

        if self.std_norm:
            current_std = torch.std(x, dim=0).detach().data
        else:
            current_std = torch.tensor([1.0], device=x.device)

        # Improve numerical stability: never divide by (near) zero.
        current_std = torch.max(
            current_std, self.eps * torch.ones_like(current_std)
        )
        return current_mean, current_std

    def _statistics_dict(self):
        """Return a dict holding the running normalization statistics."""
        state = {}
        state["count"] = self.count
        state["glob_mean"] = self.glob_mean
        state["glob_std"] = self.glob_std
        state["spk_dict_mean"] = self.spk_dict_mean
        state["spk_dict_std"] = self.spk_dict_std
        state["spk_dict_count"] = self.spk_dict_count
        return state

    def _load_statistics_dict(self, state):
        """Restore running statistics from a dict made by ``_statistics_dict``.

        Arguments
        ---------
        state : dict
            A dictionary containing the normalization statistics.

        Returns
        -------
        dict
            The same ``state`` object that was passed in.
        """
        self.count = state["count"]
        self.glob_mean = state["glob_mean"]
        self.glob_std = state["glob_std"]

        # Nothing in this class assigns `device_inp`; if a caller has set it,
        # move the speaker statistics there, otherwise keep each tensor on
        # the device it was stored on.
        device = getattr(self, "device_inp", None)

        def _place(tensor):
            return tensor if device is None else tensor.to(device)

        self.spk_dict_mean = {
            spk: _place(value) for spk, value in state["spk_dict_mean"].items()
        }
        self.spk_dict_std = {
            spk: _place(value) for spk, value in state["spk_dict_std"].items()
        }
        self.spk_dict_count = state["spk_dict_count"]
        return state

    def to(self, *args, **kwargs):
        """Move the module and its non-buffer statistics tensors.

        Accepts the same arguments as ``nn.Module.to``. Note that this
        override only runs when ``.to()`` is called on this module directly.
        """
        self = super(InputNormalization, self).to(*args, **kwargs)
        if isinstance(self.glob_mean, torch.Tensor):
            self.glob_mean = self.glob_mean.to(*args, **kwargs)
        if isinstance(self.glob_std, torch.Tensor):
            self.glob_std = self.glob_std.to(*args, **kwargs)
        for spk in self.spk_dict_mean:
            self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(*args, **kwargs)
            self.spk_dict_std[spk] = self.spk_dict_std[spk].to(*args, **kwargs)
        return self


class TdnnLayer(nn.Module):
    """1-D convolution -> activation -> (non-affine) batch normalization.

    Input is ``[batch, channels, time]`` as required by ``nn.Conv1d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        padding=0,
        padding_mode="reflect",
        activation=torch.nn.LeakyReLU,
    ):
        super(TdnnLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        # Activation *class*; it is instantiated in forward().
        self.activation = activation

        self.conv = nn.Conv1d(
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            dilation=self.dilation,
            padding=self.padding,
        )

        # affine=False to be compatible with the original Kaldi version.
        # self.ln = nn.LayerNorm(out_channels, elementwise_affine=False)
        self.norm = nn.BatchNorm1d(out_channels, affine=False)

    def forward(self, x):
        """Pad, convolve, activate and normalize ``x``."""
        x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
        out = self.conv(x)
        out = self.activation()(out)
        out = self.norm(out)
        return out

    def _manage_padding(
        self,
        x,
        kernel_size: int,
        dilation: int,
        stride: int,
    ):
        """Pad the last (time) axis of ``x`` using ``self.padding_mode``."""
        # get_padding_elem's result does not actually depend on L_in (for
        # stride 1 it reduces to dilation * (kernel_size - 1) // 2 per side),
        # so passing in_channels here is harmless.
        L_in = self.in_channels
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)
        x = F.pad(x, padding, mode=self.padding_mode)
        return x


def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """Compute the ``[left, right]`` padding that keeps a conv "same"-sized.

    Arguments
    ---------
    L_in : int
        Input length.
    stride : int
    kernel_size : int
    dilation : int

    Returns
    -------
    list of int
        Number of elements to add on the left and right of the time axis.
    """
    if stride > 1:
        padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]
    else:
        L_out = (
            math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1
        )
        padding = [
            math.floor((L_in - L_out) / 2),
            math.floor((L_in - L_out) / 2),
        ]
    return padding


class StatisticsPooling(nn.Module):
    """Pool a sequence into its mean and/or standard deviation over dim 1.

    Arguments
    ---------
    return_mean : bool
        Include the (noise-perturbed) mean in the output.
    return_std : bool
        Include the standard deviation in the output.

    Raises
    ------
    ValueError
        If both ``return_mean`` and ``return_std`` are False.
    """

    def __init__(self, return_mean=True, return_std=True):
        super().__init__()

        # Scale of the small noise added to the mean and offset added to std.
        self.eps = 1e-5
        self.return_mean = return_mean
        self.return_std = return_std
        if not (self.return_mean or self.return_std):
            raise ValueError(
                "both of statistics are equal to False \n"
                "consider enabling mean and/or std statistic pooling"
            )

    def forward(self, input_values, lengths=None):
        """Compute pooled statistics for a mini-batch.

        Arguments
        ---------
        input_values : torch.Tensor
            Batch whose dimension 1 is pooled over (e.g. ``[batch, time, dim]``).
        lengths : torch.Tensor, optional
            Relative length of each utterance; padded steps are excluded.

        Returns
        -------
        torch.Tensor
            ``[batch, 1, k * dim]`` where ``k`` is the number of enabled
            statistics (mean and std are concatenated in that order).
        """
        x = input_values
        if lengths is None:
            if self.return_mean:
                mean = x.mean(dim=1)
            if self.return_std:
                std = x.std(dim=1)
        else:
            mean = []
            std = []
            for snt_id in range(x.shape[0]):
                # Only the first round(length * T) steps are real data.
                actual_size = int(torch.round(lengths[snt_id] * x.shape[1]))
                valid = x[snt_id, 0:actual_size, ...]
                if self.return_mean:
                    mean.append(torch.mean(valid, dim=0))
                if self.return_std:
                    std.append(torch.std(valid, dim=0))
            if self.return_mean:
                mean = torch.stack(mean)
            if self.return_std:
                std = torch.stack(std)

        if self.return_mean:
            gnoise = self._get_gauss_noise(mean.size(), device=mean.device)
            mean += gnoise
        if self.return_std:
            std = std + self.eps

        if self.return_mean and self.return_std:
            pooled_stats = torch.cat((mean, std), dim=1)
        elif self.return_mean:
            pooled_stats = mean
        else:
            pooled_stats = std
        return pooled_stats.unsqueeze(1)

    def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
        """Return small positive noise in ``[eps, 9 * eps]`` of the given shape.

        Arguments
        ---------
        shape_of_tensor : torch.Size
            Shape of the noise tensor to generate.
        device : str or torch.device
            Device on which to create the noise.
        """
        gnoise = torch.randn(shape_of_tensor, device=device)
        # Rescale to [0, 1], then map linearly to [1, 9] and scale by eps.
        gnoise -= torch.min(gnoise)
        gnoise /= torch.max(gnoise)
        gnoise = self.eps * ((1 - 9) * gnoise + 9)
        return gnoise


class SvectorEmbedder(nn.Module):
    """TDNN front-end + Transformer encoder + statistics pooling + projection."""

    def __init__(
        self,
        in_channels=40,
        num_heads=8,
        num_layers=5,
        activation=torch.nn.LeakyReLU,
        hidden_size=512,
    ) -> None:
        super(SvectorEmbedder, self).__init__()
        self.tdnn = TdnnLayer(
            in_channels=in_channels,
            out_channels=hidden_size,
            kernel_size=1,
            dilation=1,
            activation=activation,
        )
        # NOTE(review): batch_first defaults to False, so the encoder expects
        # [seq, batch, dim], but forward() passes [batch, time, dim]. Attention
        # therefore runs across the batch axis rather than over time. Left
        # unchanged because existing checkpoints were produced with this
        # layout; confirm before switching to batch_first=True.
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=num_heads)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.pooler = StatisticsPooling()
        # Mean and std are concatenated by the pooler, hence 2 * hidden_size.
        self.fc = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self, input_values, lengths=None):
        """Embed a batch of feature sequences.

        Arguments
        ---------
        input_values : torch.Tensor
            Features shaped ``[B, T, F]``.
        lengths : torch.Tensor, optional
            Relative lengths forwarded to the statistics pooling.

        Returns
        -------
        ModelOutput
            ``last_hidden_state`` (encoder output) and ``pooler_output``
            (``[B, hidden_size]`` embedding).
        """
        x = input_values
        x = self.tdnn(x.transpose(1, 2))
        last_hidden_state = self.transformer_encoder(x.transpose(1, 2))
        pooler_output = self.pooler(last_hidden_state, lengths)
        pooler_output = self.fc(pooler_output.squeeze(1))
        return ModelOutput(
            last_hidden_state=last_hidden_state, pooler_output=pooler_output
        )


class CosineSimilarityHead(torch.nn.Module):
    """Cosine-similarity classifier on top of fixed-size features.

    Optional ``lin_blocks`` of (BatchNorm1d, Linear) are applied first; the
    output is the cosine similarity between the features and one learned
    weight vector per class.
    """

    def __init__(
        self,
        in_channels,
        lin_blocks=0,
        hidden_size=192,
        num_classes=1211,
    ):
        super().__init__()
        self.blocks = nn.ModuleList()

        for _ in range(lin_blocks):
            self.blocks.extend(
                [
                    nn.BatchNorm1d(num_features=in_channels),
                    nn.Linear(in_features=in_channels, out_features=hidden_size),
                ]
            )
            in_channels = hidden_size

        # Final layer: one weight vector per class.
        self.weight = nn.Parameter(
            torch.FloatTensor(num_classes, in_channels)
        )
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        """Return cosine similarities of shape ``[batch, num_classes]``.

        Arguments
        ---------
        x : torch.Tensor
            Features shaped ``[batch, in_channels]``.
        """
        for layer in self.blocks:
            x = layer(x)

        # L2-normalize both sides so the linear map yields cosine similarity.
        x = F.linear(F.normalize(x), F.normalize(self.weight))
        return x


class SvectorPreTrainedModel(PreTrainedModel):
    """Base class wiring the s-vector models into the HF PreTrainedModel API."""

    config_class = SvectorConfig
    base_model_prefix = "svector"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of ``module``."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            # Norm layers built without affine parameters have None here.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight.data)

        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
            module.bias.data.zero_()


class SvectorModel(SvectorPreTrainedModel):
    """Filterbank features -> input normalization -> s-vector embedder."""

    def __init__(self, config):
        super().__init__(config)
        self.compute_features = Fbank(
            n_mels=config.n_mels,
            sample_rate=config.sample_rate,
            win_length=config.win_length,
            hop_length=config.hop_length,
        )
        self.mean_var_norm = InputNormalization(
            mean_norm=config.mean_norm,
            std_norm=config.std_norm,
            norm_type=config.norm_type,
        )
        self.embedding_model = SvectorEmbedder(
            in_channels=config.n_mels,
            activation=nn.LeakyReLU,
            num_heads=config.num_heads,
            num_layers=config.num_layers,
            hidden_size=config.hidden_size,
        )

    def forward(self, input_values, lengths=None):
        """Compute s-vector embeddings.

        Arguments
        ---------
        input_values : torch.Tensor
            Input passed to ``Fbank``; presumably a batch of raw waveforms
            ``[batch, samples]`` -- TODO confirm against helpers_svector.
        lengths : torch.Tensor, optional
            Relative length of each utterance; None means no padding.

        Returns
        -------
        ModelOutput
            The ``SvectorEmbedder`` output (``last_hidden_state`` and
            ``pooler_output``).
        """
        x = input_values
        x = self.compute_features(x)
        x = self.mean_var_norm(x, lengths)
        output = self.embedding_model(x, lengths)
        return output