# ecapa-aam/modeling_ecapa.py
import math
import torch
import typing as tp
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from .helpers_ecapa import Fbank
from .configuration_ecapa import EcapaConfig
class InputNormalization(nn.Module):
spk_dict_mean: tp.Dict[int, torch.Tensor]
spk_dict_std: tp.Dict[int, torch.Tensor]
spk_dict_count: tp.Dict[int, int]
def __init__(
self,
mean_norm=True,
std_norm=True,
norm_type="global",
avg_factor=None,
requires_grad=False,
update_until_epoch=3,
):
super().__init__()
self.mean_norm = mean_norm
self.std_norm = std_norm
self.norm_type = norm_type
self.avg_factor = avg_factor
self.requires_grad = requires_grad
self.glob_mean = torch.tensor([0])
self.glob_std = torch.tensor([0])
self.spk_dict_mean = {}
self.spk_dict_std = {}
self.spk_dict_count = {}
self.weight = 1.0
self.count = 0
self.eps = 1e-10
self.update_until_epoch = update_until_epoch
    def forward(self, input_values, lengths=None, spk_ids=torch.tensor([]), epoch=0):
        """Normalizes the input tensor.
        Arguments
        ---------
        input_values : tensor
            A batch of tensors.
        lengths : tensor
            A batch of tensors containing the relative length of each
            sentence (e.g., [0.7, 0.9, 1.0]). It is used to avoid
            computing stats on zero-padded steps.
        spk_ids : tensor
            Tensor containing the ids of each speaker (e.g., [0, 10, 6]).
            It is used to perform per-speaker normalization when
            norm_type='speaker'.
        epoch : int
            The current training epoch, used to stop updating the global
            statistics after `update_until_epoch`.
        """
        x = input_values
        # Default to full-length utterances when no lengths are provided
        if lengths is None:
            lengths = torch.ones(x.shape[0], device=x.device)
N_batches = x.shape[0]
current_means = []
current_stds = []
for snt_id in range(N_batches):
            # Avoid padded time steps when computing statistics
            actual_size = torch.round(lengths[snt_id] * x.shape[1]).int()
# computing statistics
current_mean, current_std = self._compute_current_stats(
x[snt_id, 0:actual_size, ...]
)
current_means.append(current_mean)
current_stds.append(current_std)
if self.norm_type == "sentence":
x[snt_id] = (x[snt_id] - current_mean.data) / current_std.data
if self.norm_type == "speaker":
spk_id = int(spk_ids[snt_id][0])
if self.training:
if spk_id not in self.spk_dict_mean:
# Initialization of the dictionary
self.spk_dict_mean[spk_id] = current_mean
self.spk_dict_std[spk_id] = current_std
self.spk_dict_count[spk_id] = 1
else:
self.spk_dict_count[spk_id] = (
self.spk_dict_count[spk_id] + 1
)
if self.avg_factor is None:
self.weight = 1 / self.spk_dict_count[spk_id]
else:
self.weight = self.avg_factor
self.spk_dict_mean[spk_id] = (
(1 - self.weight) * self.spk_dict_mean[spk_id]
+ self.weight * current_mean
)
self.spk_dict_std[spk_id] = (
(1 - self.weight) * self.spk_dict_std[spk_id]
+ self.weight * current_std
)
                        # detach() is not in-place: reassign to drop the graph
                        self.spk_dict_mean[spk_id] = self.spk_dict_mean[spk_id].detach()
                        self.spk_dict_std[spk_id] = self.spk_dict_std[spk_id].detach()
speaker_mean = self.spk_dict_mean[spk_id].data
speaker_std = self.spk_dict_std[spk_id].data
else:
if spk_id in self.spk_dict_mean:
speaker_mean = self.spk_dict_mean[spk_id].data
speaker_std = self.spk_dict_std[spk_id].data
else:
speaker_mean = current_mean.data
speaker_std = current_std.data
x[snt_id] = (x[snt_id] - speaker_mean) / speaker_std
if self.norm_type == "batch" or self.norm_type == "global":
current_mean = torch.mean(torch.stack(current_means), dim=0)
current_std = torch.mean(torch.stack(current_stds), dim=0)
if self.norm_type == "batch":
x = (x - current_mean.data) / (current_std.data)
if self.norm_type == "global":
if self.training:
if self.count == 0:
self.glob_mean = current_mean
self.glob_std = current_std
elif epoch < self.update_until_epoch:
if self.avg_factor is None:
self.weight = 1 / (self.count + 1)
else:
self.weight = self.avg_factor
self.glob_mean = (
1 - self.weight
) * self.glob_mean + self.weight * current_mean
self.glob_std = (
1 - self.weight
) * self.glob_std + self.weight * current_std
                    # detach() is not in-place: reassign to drop the graph
                    self.glob_mean = self.glob_mean.detach()
                    self.glob_std = self.glob_std.detach()
self.count = self.count + 1
x = (x - self.glob_mean.data) / (self.glob_std.data)
return x
    def _compute_current_stats(self, x):
        """Computes the mean and std of the input tensor.
        Arguments
        ---------
        x : tensor
            A batch of tensors.
        """
        # Compute current mean
        if self.mean_norm:
            current_mean = torch.mean(x, dim=0).detach()
        else:
            current_mean = torch.tensor([0.0], device=x.device)
        # Compute current std
        if self.std_norm:
            current_std = torch.std(x, dim=0).detach()
        else:
            current_std = torch.tensor([1.0], device=x.device)
# Improving numerical stability of std
current_std = torch.max(
current_std, self.eps * torch.ones_like(current_std)
)
return current_mean, current_std
def _statistics_dict(self):
"""Fills the dictionary containing the normalization statistics."""
state = {}
state["count"] = self.count
state["glob_mean"] = self.glob_mean
state["glob_std"] = self.glob_std
state["spk_dict_mean"] = self.spk_dict_mean
state["spk_dict_std"] = self.spk_dict_std
state["spk_dict_count"] = self.spk_dict_count
return state
    def _load_statistics_dict(self, state):
        """Loads the dictionary containing the statistics.
        Arguments
        ---------
        state : dict
            A dictionary containing the normalization statistics.
        """
        self.count = state["count"]
        # Target device for the restored statistics
        device = self.glob_mean.device if torch.is_tensor(self.glob_mean) else torch.device("cpu")
        if isinstance(state["glob_mean"], int):
            self.glob_mean = state["glob_mean"]
            self.glob_std = state["glob_std"]
        else:
            self.glob_mean = state["glob_mean"].to(device)
            self.glob_std = state["glob_std"].to(device)
        # Load the spk_dict_mean on the right device
        self.spk_dict_mean = {}
        for spk in state["spk_dict_mean"]:
            self.spk_dict_mean[spk] = state["spk_dict_mean"][spk].to(device)
        # Load the spk_dict_std on the right device
        self.spk_dict_std = {}
        for spk in state["spk_dict_std"]:
            self.spk_dict_std[spk] = state["spk_dict_std"][spk].to(device)
        self.spk_dict_count = state["spk_dict_count"]
        return state
def to(self, device):
"""Puts the needed tensors in the right device."""
self = super(InputNormalization, self).to(device)
self.glob_mean = self.glob_mean.to(device)
self.glob_std = self.glob_std.to(device)
for spk in self.spk_dict_mean:
self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
return self
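# Usage sketch for InputNormalization: with the default norm_type="global",
# running statistics are accumulated during training (until
# `update_until_epoch`) and reused at inference:
# >>> norm = InputNormalization(norm_type="global")
# >>> feats = torch.rand(4, 300, 80)  # [batch, frames, n_mels]
# >>> norm(feats, torch.ones(4)).shape
# torch.Size([4, 300, 80])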
class TdnnLayer(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
dilation=1,
stride=1,
groups=1,
padding=0,
padding_mode="reflect",
activation=torch.nn.LeakyReLU,
):
super(TdnnLayer, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.dilation = dilation
self.stride = stride
self.groups = groups
self.padding = padding
self.padding_mode = padding_mode
self.activation = activation()
        self.conv = nn.Conv1d(
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            dilation=self.dilation,
            stride=self.stride,
            padding=self.padding,
            groups=self.groups,
        )
        # Set affine=False to be compatible with the original Kaldi version
        self.norm = nn.BatchNorm1d(out_channels, affine=False)
def forward(self, x):
x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
out = self.conv(x)
out = self.activation(out)
out = self.norm(out)
return out
    def _manage_padding(
        self, x, kernel_size: int, dilation: int, stride: int,
    ):
        # Length of the time dimension of the input
        L_in = x.shape[-1]
        # Time padding
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)
        # Applying padding
        x = F.pad(x, padding, mode=self.padding_mode)
        return x
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
"""This function computes the number of elements to add for zero-padding.
Arguments
---------
L_in : int
stride: int
kernel_size : int
dilation : int
"""
if stride > 1:
padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]
else:
L_out = (
math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1
)
padding = [
math.floor((L_in - L_out) / 2),
math.floor((L_in - L_out) / 2),
]
return padding
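# Sanity check for get_padding_elem: a stride-1 conv with kernel_size=3 and
# dilation=2 shortens the signal by dilation * (kernel_size - 1) = 4 steps,
# so two elements of padding on each side preserve the length:
# >>> get_padding_elem(L_in=100, stride=1, kernel_size=3, dilation=2)
# [2, 2]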
class Res2NetBlock(torch.nn.Module):
"""An implementation of Res2NetBlock w/ dilation.
Arguments
---------
in_channels : int
The number of channels expected in the input.
out_channels : int
The number of output channels.
scale : int
The scale of the Res2Net block.
kernel_size: int
The kernel size of the Res2Net block.
dilation : int
The dilation of the Res2Net block.
Example
-------
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
>>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
>>> out_tensor = layer(inp_tensor).transpose(1, 2)
>>> out_tensor.shape
torch.Size([8, 120, 64])
"""
def __init__(
self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
):
super(Res2NetBlock, self).__init__()
assert in_channels % scale == 0
assert out_channels % scale == 0
in_channel = in_channels // scale
hidden_channel = out_channels // scale
self.blocks = nn.ModuleList(
[
TdnnLayer(
in_channel,
hidden_channel,
kernel_size=kernel_size,
dilation=dilation,
)
for _ in range(scale - 1)
]
)
self.scale = scale
def forward(self, x):
"""Processes the input tensor x and returns an output tensor."""
y = []
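        # Split the channels into `scale` groups: the first group passes
        # through unchanged, the second goes through its TDNN layer, and each
        # later group is summed with the previous output before its TDNN
        # layer (hierarchical residual connections, as in Res2Net).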
for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
if i == 0:
y_i = x_i
elif i == 1:
y_i = self.blocks[i - 1](x_i)
else:
y_i = self.blocks[i - 1](x_i + y_i)
y.append(y_i)
y = torch.cat(y, dim=1)
return y
class SEBlock(nn.Module):
"""An implementation of squeeze-and-excitation block.
Arguments
---------
in_channels : int
The number of input channels.
se_channels : int
The number of output channels after squeeze.
out_channels : int
The number of output channels.
Example
-------
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
>>> se_layer = SEBlock(64, 16, 64)
>>> lengths = torch.rand((8,))
>>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
>>> out_tensor.shape
torch.Size([8, 120, 64])
"""
def __init__(self, in_channels, se_channels, out_channels):
super(SEBlock, self).__init__()
self.conv1 = nn.Conv1d(
in_channels=in_channels, out_channels=se_channels, kernel_size=1
)
self.relu = torch.nn.ReLU(inplace=True)
self.conv2 = nn.Conv1d(
in_channels=se_channels, out_channels=out_channels, kernel_size=1
)
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x, lengths=None):
"""Processes the input tensor x and returns an output tensor."""
L = x.shape[-1]
if lengths is not None:
mask = length_to_mask(lengths * L, max_len=L, device=x.device)
mask = mask.unsqueeze(1)
total = mask.sum(dim=2, keepdim=True)
s = (x * mask).sum(dim=2, keepdim=True) / total
else:
s = x.mean(dim=2, keepdim=True)
s = self.relu(self.conv1(s))
s = self.sigmoid(self.conv2(s))
return s * x
def length_to_mask(length, max_len=None, dtype=None, device=None):
"""Creates a binary mask for each sequence.
Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3
Arguments
---------
length : torch.LongTensor
Containing the length of each sequence in the batch. Must be 1D.
max_len : int
Max length for the mask, also the size of the second dimension.
dtype : torch.dtype, default: None
The dtype of the generated mask.
device: torch.device, default: None
The device to put the mask variable.
Returns
-------
mask : tensor
The binary mask.
Example
-------
>>> length=torch.Tensor([1,2,3])
>>> mask=length_to_mask(length)
>>> mask
tensor([[1., 0., 0.],
[1., 1., 0.],
[1., 1., 1.]])
"""
    assert len(length.shape) == 1
    if max_len is None:
        max_len = length.max().long().item()
    # Use arange to build the mask
    mask = torch.arange(
        max_len, device=length.device, dtype=length.dtype
    ).expand(len(length), max_len) < length.unsqueeze(1)
if dtype is None:
dtype = length.dtype
if device is None:
device = length.device
mask = torch.as_tensor(mask, dtype=dtype, device=device)
return mask
class AttentiveStatisticsPooling(nn.Module):
"""This class implements an attentive statistic pooling layer for each channel.
It returns the concatenated mean and std of the input tensor.
Arguments
---------
channels: int
The number of input channels.
attention_channels: int
The number of attention channels.
Example
-------
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
>>> asp_layer = AttentiveStatisticsPooling(64)
>>> lengths = torch.rand((8,))
>>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
>>> out_tensor.shape
torch.Size([8, 1, 128])
"""
def __init__(self, channels, attention_channels=128, global_context=True):
super().__init__()
self.eps = 1e-12
self.global_context = global_context
if global_context:
self.tdnn = TdnnLayer(channels * 3, attention_channels, 1, 1)
else:
self.tdnn = TdnnLayer(channels, attention_channels, 1, 1)
self.tanh = nn.Tanh()
self.conv = nn.Conv1d(
in_channels=attention_channels, out_channels=channels, kernel_size=1
)
def forward(self, x, lengths=None):
"""Calculates mean and std for a batch (input tensor).
Arguments
---------
x : torch.Tensor
Tensor of shape [N, C, L].
"""
L = x.shape[-1]
def _compute_statistics(x, m, dim=2, eps=self.eps):
mean = (m * x).sum(dim)
std = torch.sqrt(
(m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
)
return mean, std
if lengths is None:
lengths = torch.ones(x.shape[0], device=x.device)
# Make binary mask of shape [N, 1, L]
mask = length_to_mask(lengths * L, max_len=L, device=x.device)
mask = mask.unsqueeze(1)
# Expand the temporal context of the pooling layer by allowing the
# self-attention to look at global properties of the utterance.
if self.global_context:
# torch.std is unstable for backward computation
# https://github.com/pytorch/pytorch/issues/4320
total = mask.sum(dim=2, keepdim=True).float()
mean, std = _compute_statistics(x, mask / total)
mean = mean.unsqueeze(2).repeat(1, 1, L)
std = std.unsqueeze(2).repeat(1, 1, L)
attn = torch.cat([x, mean, std], dim=1)
else:
attn = x
# Apply layers
attn = self.conv(self.tanh(self.tdnn(attn)))
# Filter out zero-paddings
attn = attn.masked_fill(mask == 0, float("-inf"))
attn = F.softmax(attn, dim=2)
mean, std = _compute_statistics(x, attn)
# Append mean and std of the batch
pooled_stats = torch.cat((mean, std), dim=1)
pooled_stats = pooled_stats.unsqueeze(2)
return pooled_stats
class SERes2NetBlock(nn.Module):
"""An implementation of building block in ECAPA-TDNN, i.e.,
TDNN-Res2Net-TDNN-SEBlock.
Arguments
----------
out_channels: int
The number of output channels.
res2net_scale: int
The scale of the Res2Net block.
kernel_size: int
The kernel size of the TDNN blocks.
dilation: int
The dilation of the Res2Net block.
activation : torch class
A class for constructing the activation layers.
groups: int
Number of blocked connections from input channels to output channels.
Example
-------
>>> x = torch.rand(8, 120, 64).transpose(1, 2)
>>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
>>> out = conv(x).transpose(1, 2)
>>> out.shape
torch.Size([8, 120, 64])
"""
def __init__(
self,
in_channels,
out_channels,
res2net_scale=8,
se_channels=128,
kernel_size=1,
dilation=1,
activation=torch.nn.ReLU,
groups=1,
):
super().__init__()
self.out_channels = out_channels
self.tdnn1 = TdnnLayer(
in_channels,
out_channels,
kernel_size=1,
dilation=1,
activation=activation,
groups=groups,
)
self.res2net_block = Res2NetBlock(
out_channels, out_channels, res2net_scale, kernel_size, dilation
)
self.tdnn2 = TdnnLayer(
out_channels,
out_channels,
kernel_size=1,
dilation=1,
activation=activation,
groups=groups,
)
self.se_block = SEBlock(out_channels, se_channels, out_channels)
self.shortcut = None
if in_channels != out_channels:
self.shortcut = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
)
def forward(self, x, lengths=None):
"""Processes the input tensor x and returns an output tensor."""
residual = x
if self.shortcut:
residual = self.shortcut(x)
x = self.tdnn1(x)
x = self.res2net_block(x)
x = self.tdnn2(x)
x = self.se_block(x, lengths)
return x + residual
class EcapaEmbedder(nn.Module):
def __init__(
self,
in_channels=80,
hidden_size=192,
activation=torch.nn.ReLU,
channels=[512, 512, 512, 512, 1536],
kernel_sizes=[5, 3, 3, 3, 1],
dilations=[1, 2, 3, 4, 1],
attention_channels=128,
res2net_scale=8,
se_channels=128,
global_context=True,
groups=[1, 1, 1, 1, 1],
) -> None:
super(EcapaEmbedder, self).__init__()
self.channels = channels
self.blocks = nn.ModuleList()
# The initial TDNN layer
self.blocks.append(
TdnnLayer(
in_channels,
channels[0],
kernel_sizes[0],
dilations[0],
activation=activation,
groups=groups[0],
)
)
# SE-Res2Net layers
for i in range(1, len(channels) - 1):
self.blocks.append(
SERes2NetBlock(
channels[i - 1],
channels[i],
res2net_scale=res2net_scale,
se_channels=se_channels,
kernel_size=kernel_sizes[i],
dilation=dilations[i],
activation=activation,
groups=groups[i],
)
)
# Multi-layer feature aggregation
self.mfa = TdnnLayer(
channels[-2] * (len(channels) - 2),
channels[-1],
kernel_sizes[-1],
dilations[-1],
activation=activation,
groups=groups[-1],
)
# Attentive Statistical Pooling
self.asp = AttentiveStatisticsPooling(
channels[-1],
attention_channels=attention_channels,
global_context=global_context,
)
self.asp_bn = nn.BatchNorm1d(channels[-1] * 2)
# Final linear transformation
self.fc = nn.Conv1d(
in_channels=channels[-1] * 2,
out_channels=hidden_size,
kernel_size=1,
)
    def forward(self, input_values, lengths=None):
        # Transpose once up front to [batch, channels, time]
        x = input_values.transpose(1, 2)
xl = []
for layer in self.blocks:
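            # SERes2NetBlock accepts `lengths`; the plain TdnnLayer does not,
            # so fall back to calling the layer without it.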
try:
x = layer(x, lengths)
except TypeError:
x = layer(x)
xl.append(x)
# Multi-layer feature aggregation
x = torch.cat(xl[1:], dim=1)
x = self.mfa(x)
# Attentive Statistical Pooling
x = self.asp(x, lengths)
x = self.asp_bn(x)
# Final linear transformation
x = self.fc(x)
pooler_output = x.transpose(1, 2)
pooler_output = pooler_output.squeeze(1)
        return ModelOutput(pooler_output=pooler_output)
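# Shape walk-through for EcapaEmbedder with the defaults above:
# >>> embedder = EcapaEmbedder(in_channels=80, hidden_size=192)
# >>> feats = torch.rand(4, 300, 80)  # [batch, frames, n_mels]
# >>> embedder(feats, torch.ones(4)).pooler_output.shape
# torch.Size([4, 192])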
class CosineSimilarityHead(torch.nn.Module):
"""
This class implements the cosine similarity on the top of features.
"""
def __init__(
self,
in_channels,
lin_blocks=0,
hidden_size=192,
num_classes=1211,
):
super().__init__()
self.blocks = nn.ModuleList()
for block_index in range(lin_blocks):
self.blocks.extend(
[
nn.BatchNorm1d(num_features=in_channels),
nn.Linear(in_features=in_channels, out_features=hidden_size),
]
)
in_channels = hidden_size
# Final Layer
self.weight = nn.Parameter(
torch.FloatTensor(num_classes, in_channels)
)
nn.init.xavier_uniform_(self.weight)
def forward(self, x):
"""Returns the output probabilities over speakers.
Arguments
---------
x : torch.Tensor
Torch tensor.
"""
for layer in self.blocks:
x = layer(x)
# Need to be normalized
x = F.linear(F.normalize(x), F.normalize(self.weight))
return x
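# Example for CosineSimilarityHead: cosine scores for a batch of embeddings
# against `num_classes` class centres; each score lies in [-1, 1]. An
# AAM-softmax loss typically rescales and margins these cosines during
# training.
# >>> head = CosineSimilarityHead(in_channels=192, num_classes=1211)
# >>> head(torch.randn(4, 192)).shape
# torch.Size([4, 1211])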
class EcapaPreTrainedModel(PreTrainedModel):
config_class = EcapaConfig
base_model_prefix = "ecapa"
main_input_name = "input_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
nn.init.kaiming_normal_(module.weight.data)
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()
class EcapaModel(EcapaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.compute_features = Fbank(
n_mels=config.n_mels,
sample_rate=config.sample_rate,
win_length=config.win_length,
hop_length=config.hop_length,
)
self.mean_var_norm = InputNormalization(
mean_norm=config.mean_norm,
std_norm=config.std_norm,
norm_type=config.norm_type
)
self.embedding_model = EcapaEmbedder(
in_channels=config.n_mels,
channels=config.channels,
kernel_sizes=config.kernel_sizes,
dilations=config.dilations,
attention_channels=config.attention_channels,
res2net_scale=config.res2net_scale,
se_channels=config.se_channels,
global_context=config.global_context,
groups=config.groups,
hidden_size=config.hidden_size
)
    def forward(self, input_values, lengths=None):
        x = input_values
        # Default to full-length utterances when no lengths are provided
        if lengths is None:
            lengths = torch.ones(x.shape[0], device=x.device)
        x = self.compute_features(x)
        x = self.mean_var_norm(x, lengths)
        output = self.embedding_model(x, lengths)
return ModelOutput(
pooler_output=output.pooler_output,
)
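# Usage sketch (assumes EcapaConfig provides defaults for the fields used
# above, e.g. n_mels, sample_rate, hidden_size, and that Fbank returns
# features shaped [batch, frames, n_mels]):
# >>> config = EcapaConfig()
# >>> model = EcapaModel(config)
# >>> waveform = torch.randn(2, 16000)  # [batch, samples], 1 s at 16 kHz
# >>> model(waveform, torch.ones(2)).pooler_output.shape
# torch.Size([2, 192])  # assuming hidden_size=192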