import math import torch import typing as tp import torch.nn as nn import torch.nn.functional as F from transformers.utils import ModelOutput from transformers.modeling_utils import PreTrainedModel from transformers.modeling_outputs import SequenceClassifierOutput from .helpers_ecapa import Fbank from .configuration_ecapa import EcapaConfig class InputNormalization(nn.Module): spk_dict_mean: tp.Dict[int, torch.Tensor] spk_dict_std: tp.Dict[int, torch.Tensor] spk_dict_count: tp.Dict[int, int] def __init__( self, mean_norm=True, std_norm=True, norm_type="global", avg_factor=None, requires_grad=False, update_until_epoch=3, ): super().__init__() self.mean_norm = mean_norm self.std_norm = std_norm self.norm_type = norm_type self.avg_factor = avg_factor self.requires_grad = requires_grad self.glob_mean = torch.tensor([0]) self.glob_std = torch.tensor([0]) self.spk_dict_mean = {} self.spk_dict_std = {} self.spk_dict_count = {} self.weight = 1.0 self.count = 0 self.eps = 1e-10 self.update_until_epoch = update_until_epoch def forward(self, input_values, lengths=None, spk_ids=torch.tensor([]), epoch=0): """Returns the tensor with the surrounding context. Arguments --------- x : tensor A batch of tensors. lengths : tensor A batch of tensors containing the relative length of each sentence (e.g, [0.7, 0.9, 1.0]). It is used to avoid computing stats on zero-padded steps. spk_ids : tensor containing the ids of each speaker (e.g, [0 10 6]). It is used to perform per-speaker normalization when norm_type='speaker'. """ x = input_values N_batches = x.shape[0] current_means = [] current_stds = [] for snt_id in range(N_batches): # Avoiding padded time steps # lengths = torch.sum(attention_mask, dim=1) # relative_lengths = lengths / torch.max(lengths) # actual_size = torch.round(relative_lengths[snt_id] * x.shape[1]).int() actual_size = torch.round(lengths[snt_id] * x.shape[1]).int() # computing statistics current_mean, current_std = self._compute_current_stats( x[snt_id, 0:actual_size, ...] ) current_means.append(current_mean) current_stds.append(current_std) if self.norm_type == "sentence": x[snt_id] = (x[snt_id] - current_mean.data) / current_std.data if self.norm_type == "speaker": spk_id = int(spk_ids[snt_id][0]) if self.training: if spk_id not in self.spk_dict_mean: # Initialization of the dictionary self.spk_dict_mean[spk_id] = current_mean self.spk_dict_std[spk_id] = current_std self.spk_dict_count[spk_id] = 1 else: self.spk_dict_count[spk_id] = ( self.spk_dict_count[spk_id] + 1 ) if self.avg_factor is None: self.weight = 1 / self.spk_dict_count[spk_id] else: self.weight = self.avg_factor self.spk_dict_mean[spk_id] = ( (1 - self.weight) * self.spk_dict_mean[spk_id] + self.weight * current_mean ) self.spk_dict_std[spk_id] = ( (1 - self.weight) * self.spk_dict_std[spk_id] + self.weight * current_std ) self.spk_dict_mean[spk_id].detach() self.spk_dict_std[spk_id].detach() speaker_mean = self.spk_dict_mean[spk_id].data speaker_std = self.spk_dict_std[spk_id].data else: if spk_id in self.spk_dict_mean: speaker_mean = self.spk_dict_mean[spk_id].data speaker_std = self.spk_dict_std[spk_id].data else: speaker_mean = current_mean.data speaker_std = current_std.data x[snt_id] = (x[snt_id] - speaker_mean) / speaker_std if self.norm_type == "batch" or self.norm_type == "global": current_mean = torch.mean(torch.stack(current_means), dim=0) current_std = torch.mean(torch.stack(current_stds), dim=0) if self.norm_type == "batch": x = (x - current_mean.data) / (current_std.data) if self.norm_type == "global": if self.training: if self.count == 0: self.glob_mean = current_mean self.glob_std = current_std elif epoch < self.update_until_epoch: if self.avg_factor is None: self.weight = 1 / (self.count + 1) else: self.weight = self.avg_factor self.glob_mean = ( 1 - self.weight ) * self.glob_mean + self.weight * current_mean self.glob_std = ( 1 - self.weight ) * self.glob_std + self.weight * current_std self.glob_mean.detach() self.glob_std.detach() self.count = self.count + 1 x = (x - self.glob_mean.data) / (self.glob_std.data) return x def _compute_current_stats(self, x): """Returns the tensor with the surrounding context. Arguments --------- x : tensor A batch of tensors. """ # Compute current mean if self.mean_norm: current_mean = torch.mean(x, dim=0).detach().data else: current_mean = torch.tensor([0.0], device=x.device) # Compute current std if self.std_norm: current_std = torch.std(x, dim=0).detach().data else: current_std = torch.tensor([1.0], device=x.device) # Improving numerical stability of std current_std = torch.max( current_std, self.eps * torch.ones_like(current_std) ) return current_mean, current_std def _statistics_dict(self): """Fills the dictionary containing the normalization statistics.""" state = {} state["count"] = self.count state["glob_mean"] = self.glob_mean state["glob_std"] = self.glob_std state["spk_dict_mean"] = self.spk_dict_mean state["spk_dict_std"] = self.spk_dict_std state["spk_dict_count"] = self.spk_dict_count return state def _load_statistics_dict(self, state): """Loads the dictionary containing the statistics. Arguments --------- state : dict A dictionary containing the normalization statistics. """ self.count = state["count"] if isinstance(state["glob_mean"], int): self.glob_mean = state["glob_mean"] self.glob_std = state["glob_std"] else: self.glob_mean = state["glob_mean"] # .to(self.device_inp) self.glob_std = state["glob_std"] # .to(self.device_inp) # Loading the spk_dict_mean in the right device self.spk_dict_mean = {} for spk in state["spk_dict_mean"]: self.spk_dict_mean[spk] = state["spk_dict_mean"][spk].to( self.device_inp ) # Loading the spk_dict_std in the right device self.spk_dict_std = {} for spk in state["spk_dict_std"]: self.spk_dict_std[spk] = state["spk_dict_std"][spk].to( self.device_inp ) self.spk_dict_count = state["spk_dict_count"] return state def to(self, device): """Puts the needed tensors in the right device.""" self = super(InputNormalization, self).to(device) self.glob_mean = self.glob_mean.to(device) self.glob_std = self.glob_std.to(device) for spk in self.spk_dict_mean: self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device) self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device) return self class TdnnLayer(nn.Module): def __init__( self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1, padding=0, padding_mode="reflect", activation=torch.nn.LeakyReLU, ): super(TdnnLayer, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size self.dilation = dilation self.stride = stride self.groups = groups self.padding = padding self.padding_mode = padding_mode self.activation = activation() self.conv = nn.Conv1d( self.in_channels, self.out_channels, self.kernel_size, dilation=self.dilation, padding=self.padding, groups=self.groups ) # Set Affine=false to be compatible with the original kaldi version # self.ln = nn.LayerNorm(out_channels, elementwise_affine=False) self.norm = nn.BatchNorm1d(out_channels, affine=False) def forward(self, x): x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride) out = self.conv(x) out = self.activation(out) out = self.norm(out) return out def _manage_padding( self, x, kernel_size: int, dilation: int, stride: int, ): # Detecting input shape L_in = self.in_channels # Time padding padding = get_padding_elem(L_in, stride, kernel_size, dilation) # Applying padding x = F.pad(x, padding, mode=self.padding_mode) return x def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int): """This function computes the number of elements to add for zero-padding. Arguments --------- L_in : int stride: int kernel_size : int dilation : int """ if stride > 1: padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)] else: L_out = ( math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1 ) padding = [ math.floor((L_in - L_out) / 2), math.floor((L_in - L_out) / 2), ] return padding class Res2NetBlock(torch.nn.Module): """An implementation of Res2NetBlock w/ dilation. Arguments --------- in_channels : int The number of channels expected in the input. out_channels : int The number of output channels. scale : int The scale of the Res2Net block. kernel_size: int The kernel size of the Res2Net block. dilation : int The dilation of the Res2Net block. Example ------- >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2) >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3) >>> out_tensor = layer(inp_tensor).transpose(1, 2) >>> out_tensor.shape torch.Size([8, 120, 64]) """ def __init__( self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1 ): super(Res2NetBlock, self).__init__() assert in_channels % scale == 0 assert out_channels % scale == 0 in_channel = in_channels // scale hidden_channel = out_channels // scale self.blocks = nn.ModuleList( [ TdnnLayer( in_channel, hidden_channel, kernel_size=kernel_size, dilation=dilation, ) for _ in range(scale - 1) ] ) self.scale = scale def forward(self, x): """Processes the input tensor x and returns an output tensor.""" y = [] for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)): if i == 0: y_i = x_i elif i == 1: y_i = self.blocks[i - 1](x_i) else: y_i = self.blocks[i - 1](x_i + y_i) y.append(y_i) y = torch.cat(y, dim=1) return y class SEBlock(nn.Module): """An implementation of squeeze-and-excitation block. Arguments --------- in_channels : int The number of input channels. se_channels : int The number of output channels after squeeze. out_channels : int The number of output channels. Example ------- >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2) >>> se_layer = SEBlock(64, 16, 64) >>> lengths = torch.rand((8,)) >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2) >>> out_tensor.shape torch.Size([8, 120, 64]) """ def __init__(self, in_channels, se_channels, out_channels): super(SEBlock, self).__init__() self.conv1 = nn.Conv1d( in_channels=in_channels, out_channels=se_channels, kernel_size=1 ) self.relu = torch.nn.ReLU(inplace=True) self.conv2 = nn.Conv1d( in_channels=se_channels, out_channels=out_channels, kernel_size=1 ) self.sigmoid = torch.nn.Sigmoid() def forward(self, x, lengths=None): """Processes the input tensor x and returns an output tensor.""" L = x.shape[-1] if lengths is not None: mask = length_to_mask(lengths * L, max_len=L, device=x.device) mask = mask.unsqueeze(1) total = mask.sum(dim=2, keepdim=True) s = (x * mask).sum(dim=2, keepdim=True) / total else: s = x.mean(dim=2, keepdim=True) s = self.relu(self.conv1(s)) s = self.sigmoid(self.conv2(s)) return s * x def length_to_mask(length, max_len=None, dtype=None, device=None): """Creates a binary mask for each sequence. Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 Arguments --------- length : torch.LongTensor Containing the length of each sequence in the batch. Must be 1D. max_len : int Max length for the mask, also the size of the second dimension. dtype : torch.dtype, default: None The dtype of the generated mask. device: torch.device, default: None The device to put the mask variable. Returns ------- mask : tensor The binary mask. Example ------- >>> length=torch.Tensor([1,2,3]) >>> mask=length_to_mask(length) >>> mask tensor([[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]]) """ assert len(length.shape) == 1 if max_len is None: max_len = length.max().long().item() # using arange to generate mask mask = torch.arange( max_len, device=length.device, dtype=length.dtype ).expand(len(length), max_len) < length.unsqueeze(1) if dtype is None: dtype = length.dtype if device is None: device = length.device mask = torch.as_tensor(mask, dtype=dtype, device=device) return mask class AttentiveStatisticsPooling(nn.Module): """This class implements an attentive statistic pooling layer for each channel. It returns the concatenated mean and std of the input tensor. Arguments --------- channels: int The number of input channels. attention_channels: int The number of attention channels. Example ------- >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2) >>> asp_layer = AttentiveStatisticsPooling(64) >>> lengths = torch.rand((8,)) >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2) >>> out_tensor.shape torch.Size([8, 1, 128]) """ def __init__(self, channels, attention_channels=128, global_context=True): super().__init__() self.eps = 1e-12 self.global_context = global_context if global_context: self.tdnn = TdnnLayer(channels * 3, attention_channels, 1, 1) else: self.tdnn = TdnnLayer(channels, attention_channels, 1, 1) self.tanh = nn.Tanh() self.conv = nn.Conv1d( in_channels=attention_channels, out_channels=channels, kernel_size=1 ) def forward(self, x, lengths=None): """Calculates mean and std for a batch (input tensor). Arguments --------- x : torch.Tensor Tensor of shape [N, C, L]. """ L = x.shape[-1] def _compute_statistics(x, m, dim=2, eps=self.eps): mean = (m * x).sum(dim) std = torch.sqrt( (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps) ) return mean, std if lengths is None: lengths = torch.ones(x.shape[0], device=x.device) # Make binary mask of shape [N, 1, L] mask = length_to_mask(lengths * L, max_len=L, device=x.device) mask = mask.unsqueeze(1) # Expand the temporal context of the pooling layer by allowing the # self-attention to look at global properties of the utterance. if self.global_context: # torch.std is unstable for backward computation # https://github.com/pytorch/pytorch/issues/4320 total = mask.sum(dim=2, keepdim=True).float() mean, std = _compute_statistics(x, mask / total) mean = mean.unsqueeze(2).repeat(1, 1, L) std = std.unsqueeze(2).repeat(1, 1, L) attn = torch.cat([x, mean, std], dim=1) else: attn = x # Apply layers attn = self.conv(self.tanh(self.tdnn(attn))) # Filter out zero-paddings attn = attn.masked_fill(mask == 0, float("-inf")) attn = F.softmax(attn, dim=2) mean, std = _compute_statistics(x, attn) # Append mean and std of the batch pooled_stats = torch.cat((mean, std), dim=1) pooled_stats = pooled_stats.unsqueeze(2) return pooled_stats class SERes2NetBlock(nn.Module): """An implementation of building block in ECAPA-TDNN, i.e., TDNN-Res2Net-TDNN-SEBlock. Arguments ---------- out_channels: int The number of output channels. res2net_scale: int The scale of the Res2Net block. kernel_size: int The kernel size of the TDNN blocks. dilation: int The dilation of the Res2Net block. activation : torch class A class for constructing the activation layers. groups: int Number of blocked connections from input channels to output channels. Example ------- >>> x = torch.rand(8, 120, 64).transpose(1, 2) >>> conv = SERes2NetBlock(64, 64, res2net_scale=4) >>> out = conv(x).transpose(1, 2) >>> out.shape torch.Size([8, 120, 64]) """ def __init__( self, in_channels, out_channels, res2net_scale=8, se_channels=128, kernel_size=1, dilation=1, activation=torch.nn.ReLU, groups=1, ): super().__init__() self.out_channels = out_channels self.tdnn1 = TdnnLayer( in_channels, out_channels, kernel_size=1, dilation=1, activation=activation, groups=groups, ) self.res2net_block = Res2NetBlock( out_channels, out_channels, res2net_scale, kernel_size, dilation ) self.tdnn2 = TdnnLayer( out_channels, out_channels, kernel_size=1, dilation=1, activation=activation, groups=groups, ) self.se_block = SEBlock(out_channels, se_channels, out_channels) self.shortcut = None if in_channels != out_channels: self.shortcut = nn.Conv1d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, ) def forward(self, x, lengths=None): """Processes the input tensor x and returns an output tensor.""" residual = x if self.shortcut: residual = self.shortcut(x) x = self.tdnn1(x) x = self.res2net_block(x) x = self.tdnn2(x) x = self.se_block(x, lengths) return x + residual class EcapaEmbedder(nn.Module): def __init__( self, in_channels=80, hidden_size=192, activation=torch.nn.ReLU, channels=[512, 512, 512, 512, 1536], kernel_sizes=[5, 3, 3, 3, 1], dilations=[1, 2, 3, 4, 1], attention_channels=128, res2net_scale=8, se_channels=128, global_context=True, groups=[1, 1, 1, 1, 1], ) -> None: super(EcapaEmbedder, self).__init__() self.channels = channels self.blocks = nn.ModuleList() # The initial TDNN layer self.blocks.append( TdnnLayer( in_channels, channels[0], kernel_sizes[0], dilations[0], activation=activation, groups=groups[0], ) ) # SE-Res2Net layers for i in range(1, len(channels) - 1): self.blocks.append( SERes2NetBlock( channels[i - 1], channels[i], res2net_scale=res2net_scale, se_channels=se_channels, kernel_size=kernel_sizes[i], dilation=dilations[i], activation=activation, groups=groups[i], ) ) # Multi-layer feature aggregation self.mfa = TdnnLayer( channels[-2] * (len(channels) - 2), channels[-1], kernel_sizes[-1], dilations[-1], activation=activation, groups=groups[-1], ) # Attentive Statistical Pooling self.asp = AttentiveStatisticsPooling( channels[-1], attention_channels=attention_channels, global_context=global_context, ) self.asp_bn = nn.BatchNorm1d(channels[-1] * 2) # Final linear transformation self.fc = nn.Conv1d( in_channels=channels[-1] * 2, out_channels=hidden_size, kernel_size=1, ) def forward(self, input_values, lengths=None): # Minimize transpose for efficiency x = input_values.transpose(1, 2) # lengths = torch.sum(attention_mask, dim=1) # lengths = lengths / torch.max(lengths) xl = [] for layer in self.blocks: try: x = layer(x, lengths) except TypeError: x = layer(x) xl.append(x) # Multi-layer feature aggregation x = torch.cat(xl[1:], dim=1) x = self.mfa(x) # Attentive Statistical Pooling x = self.asp(x, lengths) x = self.asp_bn(x) # Final linear transformation x = self.fc(x) pooler_output = x.transpose(1, 2) pooler_output = pooler_output.squeeze(1) return ModelOutput( # last_hidden_state=last_hidden_state, pooler_output=pooler_output ) class CosineSimilarityHead(torch.nn.Module): """ This class implements the cosine similarity on the top of features. """ def __init__( self, in_channels, lin_blocks=0, hidden_size=192, num_classes=1211, ): super().__init__() self.blocks = nn.ModuleList() for block_index in range(lin_blocks): self.blocks.extend( [ nn.BatchNorm1d(num_features=in_channels), nn.Linear(in_features=in_channels, out_features=hidden_size), ] ) in_channels = hidden_size # Final Layer self.weight = nn.Parameter( torch.FloatTensor(num_classes, in_channels) ) nn.init.xavier_uniform_(self.weight) def forward(self, x): """Returns the output probabilities over speakers. Arguments --------- x : torch.Tensor Torch tensor. """ for layer in self.blocks: x = layer(x) # Need to be normalized x = F.linear(F.normalize(x), F.normalize(self.weight)) return x class EcapaPreTrainedModel(PreTrainedModel): config_class = EcapaConfig base_model_prefix = "ecapa" main_input_name = "input_values" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): module.bias.data.zero_() module.weight.data.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight.data) if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() class EcapaModel(EcapaPreTrainedModel): def __init__(self, config): super().__init__(config) self.compute_features = Fbank( n_mels=config.n_mels, sample_rate=config.sample_rate, win_length=config.win_length, hop_length=config.hop_length, ) self.mean_var_norm = InputNormalization( mean_norm=config.mean_norm, std_norm=config.std_norm, norm_type=config.norm_type ) self.embedding_model = EcapaEmbedder( in_channels=config.n_mels, channels=config.channels, kernel_sizes=config.kernel_sizes, dilations=config.dilations, attention_channels=config.attention_channels, res2net_scale=config.res2net_scale, se_channels=config.se_channels, global_context=config.global_context, groups=config.groups, hidden_size=config.hidden_size ) def forward(self, input_values, lengths=None): x = input_values # if attention_mask is None: # attention_mask = torch.ones_like(input_values, device=x.device) x = self.compute_features(x) x = self.mean_var_norm(x, lengths) output = self.embedding_model(x, lengths) return ModelOutput( pooler_output=output.pooler_output, )