|
import math
import typing as tp

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_ecapa import EcapaConfig
from .helpers_ecapa import Fbank
|
|
|
|
|
class InputNormalization(nn.Module): |
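    """Performs mean and/or variance normalization of the input tensor.

    Arguments
    ---------
    mean_norm : bool
        Whether to normalize the mean.
    std_norm : bool
        Whether to normalize the standard deviation.
    norm_type : str
        One of 'sentence', 'batch', 'speaker', or 'global'.
    avg_factor : float
        Weight of the new statistics when updating the running statistics.
        If None, a running average over all processed batches is used.
    requires_grad : bool
        Whether the statistics require gradients (they are detached in this
        implementation).
    update_until_epoch : int
        The epoch after which the global statistics are no longer updated.

    Example
    -------
    >>> inp_tensor = torch.rand([1, 100, 80])
    >>> norm = InputNormalization(norm_type="sentence")
    >>> out_tensor = norm(inp_tensor, lengths=torch.ones(1))
    >>> out_tensor.shape
    torch.Size([1, 100, 80])
    """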
|
|
|
spk_dict_mean: tp.Dict[int, torch.Tensor] |
|
spk_dict_std: tp.Dict[int, torch.Tensor] |
|
spk_dict_count: tp.Dict[int, int] |
|
|
|
def __init__( |
|
self, |
|
mean_norm=True, |
|
std_norm=True, |
|
norm_type="global", |
|
avg_factor=None, |
|
requires_grad=False, |
|
update_until_epoch=3, |
|
): |
|
super().__init__() |
|
self.mean_norm = mean_norm |
|
self.std_norm = std_norm |
|
self.norm_type = norm_type |
|
self.avg_factor = avg_factor |
|
self.requires_grad = requires_grad |
|
self.glob_mean = torch.tensor([0]) |
|
self.glob_std = torch.tensor([0]) |
|
self.spk_dict_mean = {} |
|
self.spk_dict_std = {} |
|
self.spk_dict_count = {} |
|
self.weight = 1.0 |
|
self.count = 0 |
|
self.eps = 1e-10 |
|
self.update_until_epoch = update_until_epoch |
|
|
|
    def forward(self, input_values, lengths=None, spk_ids=torch.tensor([]), epoch=0):
        """Normalizes the input tensor according to ``norm_type``.

        Arguments
        ---------
        input_values : tensor
            A batch of tensors.
        lengths : tensor
            A batch of tensors containing the relative length of each
            sentence (e.g., [0.7, 0.9, 1.0]). It is used to avoid
            computing stats on zero-padded steps. If None, all steps
            are assumed to be valid.
        spk_ids : tensor
            The ids of each speaker (e.g., [0, 10, 6]). It is used to
            perform per-speaker normalization when norm_type='speaker'.
        epoch : int
            The current epoch count, used to stop updating the global
            statistics after ``update_until_epoch``.
        """
        x = input_values
        N_batches = x.shape[0]
        if lengths is None:
            lengths = torch.ones(N_batches, device=x.device)
|
|
|
current_means = [] |
|
current_stds = [] |
|
|
|
for snt_id in range(N_batches): |
|
|
|
|
|
|
|
|
|
actual_size = torch.round(lengths[snt_id] * x.shape[1]).int() |
|
|
|
|
|
current_mean, current_std = self._compute_current_stats( |
|
x[snt_id, 0:actual_size, ...] |
|
) |
|
|
|
current_means.append(current_mean) |
|
current_stds.append(current_std) |
|
|
|
if self.norm_type == "sentence": |
|
x[snt_id] = (x[snt_id] - current_mean.data) / current_std.data |
|
|
|
if self.norm_type == "speaker": |
|
spk_id = int(spk_ids[snt_id][0]) |
|
|
|
if self.training: |
|
if spk_id not in self.spk_dict_mean: |
|
|
|
self.spk_dict_mean[spk_id] = current_mean |
|
self.spk_dict_std[spk_id] = current_std |
|
self.spk_dict_count[spk_id] = 1 |
|
|
|
else: |
|
                        self.spk_dict_count[spk_id] += 1
|
|
|
if self.avg_factor is None: |
|
self.weight = 1 / self.spk_dict_count[spk_id] |
|
else: |
|
self.weight = self.avg_factor |
|
|
|
self.spk_dict_mean[spk_id] = ( |
|
(1 - self.weight) * self.spk_dict_mean[spk_id] |
|
+ self.weight * current_mean |
|
) |
|
self.spk_dict_std[spk_id] = ( |
|
(1 - self.weight) * self.spk_dict_std[spk_id] |
|
+ self.weight * current_std |
|
) |
|
|
|
                        self.spk_dict_mean[spk_id] = self.spk_dict_mean[spk_id].detach()
                        self.spk_dict_std[spk_id] = self.spk_dict_std[spk_id].detach()
|
|
|
speaker_mean = self.spk_dict_mean[spk_id].data |
|
speaker_std = self.spk_dict_std[spk_id].data |
|
else: |
|
if spk_id in self.spk_dict_mean: |
|
speaker_mean = self.spk_dict_mean[spk_id].data |
|
speaker_std = self.spk_dict_std[spk_id].data |
|
else: |
|
speaker_mean = current_mean.data |
|
speaker_std = current_std.data |
|
|
|
x[snt_id] = (x[snt_id] - speaker_mean) / speaker_std |
|
|
|
if self.norm_type == "batch" or self.norm_type == "global": |
|
current_mean = torch.mean(torch.stack(current_means), dim=0) |
|
current_std = torch.mean(torch.stack(current_stds), dim=0) |
|
|
|
if self.norm_type == "batch": |
|
x = (x - current_mean.data) / (current_std.data) |
|
|
|
if self.norm_type == "global": |
|
if self.training: |
|
if self.count == 0: |
|
self.glob_mean = current_mean |
|
self.glob_std = current_std |
|
|
|
elif epoch < self.update_until_epoch: |
|
if self.avg_factor is None: |
|
self.weight = 1 / (self.count + 1) |
|
else: |
|
self.weight = self.avg_factor |
|
|
|
self.glob_mean = ( |
|
1 - self.weight |
|
) * self.glob_mean + self.weight * current_mean |
|
|
|
self.glob_std = ( |
|
1 - self.weight |
|
) * self.glob_std + self.weight * current_std |
|
|
|
                self.glob_mean = self.glob_mean.detach()
                self.glob_std = self.glob_std.detach()
|
|
|
self.count = self.count + 1 |
|
|
|
x = (x - self.glob_mean.data) / (self.glob_std.data) |
|
|
|
return x |
|
|
|
    def _compute_current_stats(self, x):
        """Computes the mean and std of the input tensor.

        Arguments
        ---------
        x : tensor
            A single sentence, truncated to its actual length.
        """
|
|
|
if self.mean_norm: |
|
current_mean = torch.mean(x, dim=0).detach().data |
|
else: |
|
current_mean = torch.tensor([0.0], device=x.device) |
|
|
|
|
|
if self.std_norm: |
|
current_std = torch.std(x, dim=0).detach().data |
|
else: |
|
current_std = torch.tensor([1.0], device=x.device) |
|
|
|
|
|
current_std = torch.max( |
|
current_std, self.eps * torch.ones_like(current_std) |
|
) |
|
|
|
return current_mean, current_std |
|
|
|
def _statistics_dict(self): |
|
"""Fills the dictionary containing the normalization statistics.""" |
|
state = {} |
|
state["count"] = self.count |
|
state["glob_mean"] = self.glob_mean |
|
state["glob_std"] = self.glob_std |
|
state["spk_dict_mean"] = self.spk_dict_mean |
|
state["spk_dict_std"] = self.spk_dict_std |
|
state["spk_dict_count"] = self.spk_dict_count |
|
|
|
return state |
|
|
|
def _load_statistics_dict(self, state): |
|
"""Loads the dictionary containing the statistics. |
|
Arguments |
|
--------- |
|
state : dict |
|
A dictionary containing the normalization statistics. |
|
""" |
|
self.count = state["count"] |
|
if isinstance(state["glob_mean"], int): |
|
self.glob_mean = state["glob_mean"] |
|
self.glob_std = state["glob_std"] |
|
else: |
|
self.glob_mean = state["glob_mean"] |
|
self.glob_std = state["glob_std"] |
|
|
|
|
|
self.spk_dict_mean = {} |
|
for spk in state["spk_dict_mean"]: |
|
self.spk_dict_mean[spk] = state["spk_dict_mean"][spk].to( |
|
self.device_inp |
|
) |
|
|
|
|
|
self.spk_dict_std = {} |
|
for spk in state["spk_dict_std"]: |
|
self.spk_dict_std[spk] = state["spk_dict_std"][spk].to( |
|
self.device_inp |
|
) |
|
|
|
self.spk_dict_count = state["spk_dict_count"] |
|
|
|
return state |
|
|
|
def to(self, device): |
|
"""Puts the needed tensors in the right device.""" |
|
self = super(InputNormalization, self).to(device) |
|
self.glob_mean = self.glob_mean.to(device) |
|
self.glob_std = self.glob_std.to(device) |
|
for spk in self.spk_dict_mean: |
|
self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device) |
|
self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device) |
|
return self |
|
|
|
|
|
class TdnnLayer(nn.Module): |
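    """A time-delay (1D convolution) layer followed by an activation and
    batch normalization, with "same"-style padding applied in ``forward``.

    Arguments
    ---------
    in_channels : int
        The number of input channels.
    out_channels : int
        The number of output channels.
    kernel_size : int
        The kernel size of the convolution.
    dilation : int
        The dilation of the convolution.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 64, 120])
    >>> layer = TdnnLayer(64, 32, kernel_size=3, dilation=2)
    >>> out_tensor = layer(inp_tensor)
    >>> out_tensor.shape
    torch.Size([8, 32, 120])
    """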
|
|
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
kernel_size, |
|
dilation=1, |
|
stride=1, |
|
groups=1, |
|
padding=0, |
|
padding_mode="reflect", |
|
activation=torch.nn.LeakyReLU, |
|
): |
|
super(TdnnLayer, self).__init__() |
|
self.in_channels = in_channels |
|
self.out_channels = out_channels |
|
self.kernel_size = kernel_size |
|
self.dilation = dilation |
|
self.stride = stride |
|
self.groups = groups |
|
self.padding = padding |
|
self.padding_mode = padding_mode |
|
self.activation = activation() |
|
|
|
        self.conv = nn.Conv1d(
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=self.padding,
            groups=self.groups,
        )
|
|
|
|
|
|
|
self.norm = nn.BatchNorm1d(out_channels, affine=False) |
|
|
|
def forward(self, x): |
|
x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride) |
|
out = self.conv(x) |
|
out = self.activation(out) |
|
out = self.norm(out) |
|
return out |
|
|
|
    def _manage_padding(
        self, x, kernel_size: int, dilation: int, stride: int,
    ):
        # The padding is derived from the channel dimension, mirroring the
        # reference implementation; for stride == 1 the input length cancels
        # out, giving the usual "same" padding of dilation * (kernel_size - 1)
        # elements in total.
        L_in = self.in_channels

        padding = get_padding_elem(L_in, stride, kernel_size, dilation)

        x = F.pad(x, padding, mode=self.padding_mode)

        return x
|
|
|
|
|
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """Computes the number of elements of left/right padding for a
    "same"-style convolution.

    Arguments
    ---------
    L_in : int
        The length of the input signal.
    stride : int
        The stride of the convolution.
    kernel_size : int
        The kernel size of the convolution.
    dilation : int
        The dilation of the convolution.

    Returns
    -------
    padding : list of int
        The [left, right] number of elements to pad.
    """
|
if stride > 1: |
|
padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)] |
|
|
|
else: |
|
L_out = ( |
|
math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1 |
|
) |
|
padding = [ |
|
math.floor((L_in - L_out) / 2), |
|
math.floor((L_in - L_out) / 2), |
|
] |
|
return padding |
|
|
|
|
|
class Res2NetBlock(torch.nn.Module): |
|
"""An implementation of Res2NetBlock w/ dilation. |
|
Arguments |
|
--------- |
|
in_channels : int |
|
The number of channels expected in the input. |
|
out_channels : int |
|
The number of output channels. |
|
scale : int |
|
The scale of the Res2Net block. |
|
    kernel_size : int
|
The kernel size of the Res2Net block. |
|
dilation : int |
|
The dilation of the Res2Net block. |
|
Example |
|
------- |
|
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2) |
|
>>> layer = Res2NetBlock(64, 64, scale=4, dilation=3) |
|
>>> out_tensor = layer(inp_tensor).transpose(1, 2) |
|
>>> out_tensor.shape |
|
torch.Size([8, 120, 64]) |
|
""" |
|
|
|
def __init__( |
|
self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1 |
|
): |
|
super(Res2NetBlock, self).__init__() |
|
assert in_channels % scale == 0 |
|
assert out_channels % scale == 0 |
|
|
|
in_channel = in_channels // scale |
|
hidden_channel = out_channels // scale |
|
|
|
self.blocks = nn.ModuleList( |
|
[ |
|
TdnnLayer( |
|
in_channel, |
|
hidden_channel, |
|
kernel_size=kernel_size, |
|
dilation=dilation, |
|
) |
|
for _ in range(scale - 1) |
|
] |
|
) |
|
self.scale = scale |
|
|
|
def forward(self, x): |
|
"""Processes the input tensor x and returns an output tensor.""" |
|
y = [] |
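        # Split the channels into `scale` groups: the first group passes
        # through unchanged, and each subsequent group is convolved together
        # with the previous group's output (hierarchical residual connections).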
|
for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)): |
|
if i == 0: |
|
y_i = x_i |
|
elif i == 1: |
|
y_i = self.blocks[i - 1](x_i) |
|
else: |
|
y_i = self.blocks[i - 1](x_i + y_i) |
|
y.append(y_i) |
|
y = torch.cat(y, dim=1) |
|
return y |
|
|
|
|
|
class SEBlock(nn.Module): |
|
"""An implementation of squeeze-and-excitation block. |
|
Arguments |
|
--------- |
|
in_channels : int |
|
The number of input channels. |
|
se_channels : int |
|
The number of output channels after squeeze. |
|
out_channels : int |
|
The number of output channels. |
|
Example |
|
------- |
|
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2) |
|
>>> se_layer = SEBlock(64, 16, 64) |
|
>>> lengths = torch.rand((8,)) |
|
>>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2) |
|
>>> out_tensor.shape |
|
torch.Size([8, 120, 64]) |
|
""" |
|
|
|
def __init__(self, in_channels, se_channels, out_channels): |
|
super(SEBlock, self).__init__() |
|
|
|
self.conv1 = nn.Conv1d( |
|
in_channels=in_channels, out_channels=se_channels, kernel_size=1 |
|
) |
|
self.relu = torch.nn.ReLU(inplace=True) |
|
self.conv2 = nn.Conv1d( |
|
in_channels=se_channels, out_channels=out_channels, kernel_size=1 |
|
) |
|
self.sigmoid = torch.nn.Sigmoid() |
|
|
|
def forward(self, x, lengths=None): |
|
"""Processes the input tensor x and returns an output tensor.""" |
|
L = x.shape[-1] |
|
if lengths is not None: |
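            # Average only over the valid (unpadded) frames indicated by
            # the binary mask.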
|
mask = length_to_mask(lengths * L, max_len=L, device=x.device) |
|
mask = mask.unsqueeze(1) |
|
total = mask.sum(dim=2, keepdim=True) |
|
s = (x * mask).sum(dim=2, keepdim=True) / total |
|
else: |
|
s = x.mean(dim=2, keepdim=True) |
|
|
|
s = self.relu(self.conv1(s)) |
|
s = self.sigmoid(self.conv2(s)) |
|
|
|
return s * x |
|
|
|
|
|
def length_to_mask(length, max_len=None, dtype=None, device=None): |
|
"""Creates a binary mask for each sequence. |
|
Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 |
|
Arguments |
|
--------- |
|
length : torch.LongTensor |
|
Containing the length of each sequence in the batch. Must be 1D. |
|
max_len : int |
|
Max length for the mask, also the size of the second dimension. |
|
dtype : torch.dtype, default: None |
|
The dtype of the generated mask. |
|
device: torch.device, default: None |
|
The device to put the mask variable. |
|
Returns |
|
------- |
|
mask : tensor |
|
The binary mask. |
|
Example |
|
------- |
|
>>> length=torch.Tensor([1,2,3]) |
|
>>> mask=length_to_mask(length) |
|
>>> mask |
|
tensor([[1., 0., 0.], |
|
[1., 1., 0.], |
|
[1., 1., 1.]]) |
|
""" |
|
assert len(length.shape) == 1 |
|
|
|
if max_len is None: |
|
max_len = length.max().long().item() |
|
mask = torch.arange( |
|
max_len, device=length.device, dtype=length.dtype |
|
).expand(len(length), max_len) < length.unsqueeze(1) |
|
|
|
if dtype is None: |
|
dtype = length.dtype |
|
|
|
if device is None: |
|
device = length.device |
|
|
|
mask = torch.as_tensor(mask, dtype=dtype, device=device) |
|
return mask |
|
|
|
|
|
class AttentiveStatisticsPooling(nn.Module): |
|
"""This class implements an attentive statistic pooling layer for each channel. |
|
It returns the concatenated mean and std of the input tensor. |
|
Arguments |
|
--------- |
|
    channels : int
        The number of input channels.
    attention_channels : int
        The number of attention channels.
    global_context : bool
        Whether to use the utterance-level mean and std as additional
        context when computing the attention.
|
Example |
|
------- |
|
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2) |
|
>>> asp_layer = AttentiveStatisticsPooling(64) |
|
>>> lengths = torch.rand((8,)) |
|
>>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2) |
|
>>> out_tensor.shape |
|
torch.Size([8, 1, 128]) |
|
""" |
|
|
|
def __init__(self, channels, attention_channels=128, global_context=True): |
|
super().__init__() |
|
|
|
self.eps = 1e-12 |
|
self.global_context = global_context |
|
if global_context: |
|
self.tdnn = TdnnLayer(channels * 3, attention_channels, 1, 1) |
|
else: |
|
self.tdnn = TdnnLayer(channels, attention_channels, 1, 1) |
|
self.tanh = nn.Tanh() |
|
self.conv = nn.Conv1d( |
|
in_channels=attention_channels, out_channels=channels, kernel_size=1 |
|
) |
|
|
|
def forward(self, x, lengths=None): |
|
"""Calculates mean and std for a batch (input tensor). |
|
Arguments |
|
--------- |
|
x : torch.Tensor |
|
Tensor of shape [N, C, L]. |
|
""" |
|
L = x.shape[-1] |
|
|
|
def _compute_statistics(x, m, dim=2, eps=self.eps): |
|
mean = (m * x).sum(dim) |
|
std = torch.sqrt( |
|
(m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps) |
|
) |
|
return mean, std |
|
|
|
if lengths is None: |
|
lengths = torch.ones(x.shape[0], device=x.device) |
|
|
|
|
|
mask = length_to_mask(lengths * L, max_len=L, device=x.device) |
|
mask = mask.unsqueeze(1) |
|
|
|
|
|
|
|
if self.global_context: |
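            # Concatenate the utterance-level mean and std (broadcast over
            # time) with the frame-level features as attention context.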
|
|
|
|
|
total = mask.sum(dim=2, keepdim=True).float() |
|
mean, std = _compute_statistics(x, mask / total) |
|
mean = mean.unsqueeze(2).repeat(1, 1, L) |
|
std = std.unsqueeze(2).repeat(1, 1, L) |
|
attn = torch.cat([x, mean, std], dim=1) |
|
else: |
|
attn = x |
|
|
|
|
|
attn = self.conv(self.tanh(self.tdnn(attn))) |
|
|
|
|
|
attn = attn.masked_fill(mask == 0, float("-inf")) |
|
|
|
attn = F.softmax(attn, dim=2) |
|
mean, std = _compute_statistics(x, attn) |
|
|
|
pooled_stats = torch.cat((mean, std), dim=1) |
|
pooled_stats = pooled_stats.unsqueeze(2) |
|
|
|
return pooled_stats |
|
|
|
|
|
|
|
class SERes2NetBlock(nn.Module): |
|
"""An implementation of building block in ECAPA-TDNN, i.e., |
|
TDNN-Res2Net-TDNN-SEBlock. |
|
Arguments |
|
---------- |
|
out_channels: int |
|
The number of output channels. |
|
res2net_scale: int |
|
The scale of the Res2Net block. |
|
kernel_size: int |
|
The kernel size of the TDNN blocks. |
|
dilation: int |
|
The dilation of the Res2Net block. |
|
activation : torch class |
|
A class for constructing the activation layers. |
|
groups: int |
|
Number of blocked connections from input channels to output channels. |
|
Example |
|
------- |
|
>>> x = torch.rand(8, 120, 64).transpose(1, 2) |
|
>>> conv = SERes2NetBlock(64, 64, res2net_scale=4) |
|
>>> out = conv(x).transpose(1, 2) |
|
>>> out.shape |
|
torch.Size([8, 120, 64]) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels, |
|
out_channels, |
|
res2net_scale=8, |
|
se_channels=128, |
|
kernel_size=1, |
|
dilation=1, |
|
activation=torch.nn.ReLU, |
|
groups=1, |
|
): |
|
super().__init__() |
|
self.out_channels = out_channels |
|
self.tdnn1 = TdnnLayer( |
|
in_channels, |
|
out_channels, |
|
kernel_size=1, |
|
dilation=1, |
|
activation=activation, |
|
groups=groups, |
|
) |
|
self.res2net_block = Res2NetBlock( |
|
out_channels, out_channels, res2net_scale, kernel_size, dilation |
|
) |
|
self.tdnn2 = TdnnLayer( |
|
out_channels, |
|
out_channels, |
|
kernel_size=1, |
|
dilation=1, |
|
activation=activation, |
|
groups=groups, |
|
) |
|
self.se_block = SEBlock(out_channels, se_channels, out_channels) |
|
|
|
self.shortcut = None |
|
if in_channels != out_channels: |
|
self.shortcut = nn.Conv1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=1, |
|
) |
|
|
|
def forward(self, x, lengths=None): |
|
"""Processes the input tensor x and returns an output tensor.""" |
|
residual = x |
|
if self.shortcut: |
|
residual = self.shortcut(x) |
|
|
|
x = self.tdnn1(x) |
|
x = self.res2net_block(x) |
|
x = self.tdnn2(x) |
|
x = self.se_block(x, lengths) |
|
|
|
return x + residual |
|
|
|
|
|
class EcapaEmbedder(nn.Module): |
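    """The ECAPA-TDNN embedding model: an initial TDNN layer, a stack of
    SERes2Net blocks, multi-layer feature aggregation, attentive statistics
    pooling, and a final linear projection to the embedding size.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 80])
    >>> embedder = EcapaEmbedder(in_channels=80, hidden_size=192)
    >>> output = embedder(inp_tensor)
    >>> output.pooler_output.shape
    torch.Size([8, 192])
    """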
|
|
|
def __init__( |
|
self, |
|
in_channels=80, |
|
hidden_size=192, |
|
activation=torch.nn.ReLU, |
|
channels=[512, 512, 512, 512, 1536], |
|
kernel_sizes=[5, 3, 3, 3, 1], |
|
dilations=[1, 2, 3, 4, 1], |
|
attention_channels=128, |
|
res2net_scale=8, |
|
se_channels=128, |
|
global_context=True, |
|
groups=[1, 1, 1, 1, 1], |
|
) -> None: |
|
super(EcapaEmbedder, self).__init__() |
|
self.channels = channels |
|
self.blocks = nn.ModuleList() |
|
|
|
|
|
self.blocks.append( |
|
TdnnLayer( |
|
in_channels, |
|
channels[0], |
|
kernel_sizes[0], |
|
dilations[0], |
|
activation=activation, |
|
groups=groups[0], |
|
) |
|
) |
|
|
|
|
|
for i in range(1, len(channels) - 1): |
|
self.blocks.append( |
|
SERes2NetBlock( |
|
channels[i - 1], |
|
channels[i], |
|
res2net_scale=res2net_scale, |
|
se_channels=se_channels, |
|
kernel_size=kernel_sizes[i], |
|
dilation=dilations[i], |
|
activation=activation, |
|
groups=groups[i], |
|
) |
|
) |
|
|
|
|
|
self.mfa = TdnnLayer( |
|
channels[-2] * (len(channels) - 2), |
|
channels[-1], |
|
kernel_sizes[-1], |
|
dilations[-1], |
|
activation=activation, |
|
groups=groups[-1], |
|
) |
|
|
|
|
|
self.asp = AttentiveStatisticsPooling( |
|
channels[-1], |
|
attention_channels=attention_channels, |
|
global_context=global_context, |
|
) |
|
self.asp_bn = nn.BatchNorm1d(channels[-1] * 2) |
|
|
|
|
|
self.fc = nn.Conv1d( |
|
in_channels=channels[-1] * 2, |
|
out_channels=hidden_size, |
|
kernel_size=1, |
|
) |
|
|
|
def forward(self, input_values, lengths=None): |
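        """Extracts the speaker embedding from a batch of feature tensors.

        Arguments
        ---------
        input_values : torch.Tensor
            Tensor of shape [batch, time, channels].
        lengths : torch.Tensor
            The relative lengths of each sequence in the batch.
        """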
|
|
|
x = input_values.transpose(1, 2) |
|
|
|
|
|
|
|
xl = [] |
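        # Some blocks (e.g. SERes2NetBlock) accept the lengths argument,
        # while plain TdnnLayer does not, hence the TypeError fallback below.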
|
for layer in self.blocks: |
|
try: |
|
x = layer(x, lengths) |
|
except TypeError: |
|
x = layer(x) |
|
xl.append(x) |
|
|
|
|
|
x = torch.cat(xl[1:], dim=1) |
|
x = self.mfa(x) |
|
|
|
|
|
x = self.asp(x, lengths) |
|
x = self.asp_bn(x) |
|
|
|
|
|
x = self.fc(x) |
|
|
|
pooler_output = x.transpose(1, 2) |
|
pooler_output = pooler_output.squeeze(1) |
|
        return ModelOutput(pooler_output=pooler_output)
|
|
|
|
|
class CosineSimilarityHead(torch.nn.Module): |
|
""" |
|
This class implements the cosine similarity on the top of features. |
|
""" |
|
def __init__( |
|
self, |
|
in_channels, |
|
lin_blocks=0, |
|
hidden_size=192, |
|
num_classes=1211, |
|
): |
|
super().__init__() |
|
self.blocks = nn.ModuleList() |
|
|
|
for block_index in range(lin_blocks): |
|
self.blocks.extend( |
|
[ |
|
nn.BatchNorm1d(num_features=in_channels), |
|
nn.Linear(in_features=in_channels, out_features=hidden_size), |
|
] |
|
) |
|
in_channels = hidden_size |
|
|
|
|
|
self.weight = nn.Parameter( |
|
torch.FloatTensor(num_classes, in_channels) |
|
) |
|
nn.init.xavier_uniform_(self.weight) |
|
|
|
    def forward(self, x):
        """Returns the cosine similarity between the input embeddings and
        the speaker weight vectors.

        Arguments
        ---------
        x : torch.Tensor
            Torch tensor of shape [N, in_channels].
        """
|
for layer in self.blocks: |
|
x = layer(x) |
|
|
|
|
|
x = F.linear(F.normalize(x), F.normalize(self.weight)) |
|
return x |
|
|
|
|
|
class EcapaPreTrainedModel(PreTrainedModel): |
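    """An abstract class that handles weights initialization and provides a
    simple interface for downloading and loading pretrained ECAPA models."""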
|
|
|
config_class = EcapaConfig |
|
base_model_prefix = "ecapa" |
|
main_input_name = "input_values" |
|
supports_gradient_checkpointing = True |
|
|
|
def _init_weights(self, module): |
|
"""Initialize the weights""" |
|
if isinstance(module, nn.Linear): |
|
|
|
|
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
|
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): |
|
module.bias.data.zero_() |
|
module.weight.data.fill_(1.0) |
|
elif isinstance(module, nn.Conv1d): |
|
nn.init.kaiming_normal_(module.weight.data) |
|
|
|
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: |
|
module.bias.data.zero_() |
|
|
|
|
|
class EcapaModel(EcapaPreTrainedModel): |
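    """ECAPA-TDNN speaker model: computes log Mel filterbank features from
    raw waveforms, normalizes them, and extracts fixed-size speaker
    embeddings.

    A minimal usage sketch, assuming the default ``EcapaConfig`` provides
    the required feature and architecture fields:

    >>> config = EcapaConfig()  # doctest: +SKIP
    >>> model = EcapaModel(config)  # doctest: +SKIP
    >>> waveforms = torch.rand(2, 16000)  # doctest: +SKIP
    >>> embeddings = model(waveforms).pooler_output  # doctest: +SKIP
    """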
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.compute_features = Fbank( |
|
n_mels=config.n_mels, |
|
sample_rate=config.sample_rate, |
|
win_length=config.win_length, |
|
hop_length=config.hop_length, |
|
) |
|
self.mean_var_norm = InputNormalization( |
|
mean_norm=config.mean_norm, |
|
std_norm=config.std_norm, |
|
norm_type=config.norm_type |
|
) |
|
self.embedding_model = EcapaEmbedder( |
|
in_channels=config.n_mels, |
|
channels=config.channels, |
|
kernel_sizes=config.kernel_sizes, |
|
dilations=config.dilations, |
|
attention_channels=config.attention_channels, |
|
res2net_scale=config.res2net_scale, |
|
se_channels=config.se_channels, |
|
global_context=config.global_context, |
|
groups=config.groups, |
|
hidden_size=config.hidden_size |
|
) |
|
|
|
def forward(self, input_values, lengths=None): |
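        """Computes speaker embeddings from a batch of raw waveforms of
        shape [batch, samples]."""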
|
x = input_values |
|
|
|
|
|
x = self.compute_features(x) |
|
x = self.mean_var_norm(x, lengths) |
|
output = self.embedding_model(x, lengths) |
|
return ModelOutput( |
|
pooler_output=output.pooler_output, |
|
) |