# DINO-HuVITS / src/wespeaker_campplus/pooling_layers.py
# Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pooling functions to aggregate frame-level deep features
into segment-level speaker embeddings
High-order statistics are surprisingly effective, TSDP acts similarly as TSTP,
even though we remove the mean statistic, on Voxceleb.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
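
# Shape convention used throughout this file: frame-level input is
# (B, F, T) for TDNN-style backbones or (B, C, F, T) for ResNet-style
# backbones; the last dimension is always time.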
class TAP(nn.Module):
"""
Temporal average pooling, only first-order mean is considered
"""
def __init__(self, in_dim=0, **kwargs):
super(TAP, self).__init__()
self.in_dim = in_dim
def forward(self, x):
pooling_mean = x.mean(dim=-1)
        # Flatten so that 4-dim (B, C, F, T) input also yields a 2-dim output
pooling_mean = pooling_mean.flatten(start_dim=1)
return pooling_mean
def get_out_dim(self):
self.out_dim = self.in_dim
return self.out_dim
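
# Illustrative usage (a sketch, not part of the original wespeaker file):
# TAP simply averages frame features over time, so (B, F, T) -> (B, F).
#
#   >>> x = torch.randn(4, 80, 200)  # 4 utterances, 80-dim features, 200 frames
#   >>> TAP(in_dim=80)(x).shape
#   torch.Size([4, 80])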
class TSDP(nn.Module):
"""
Temporal standard deviation pooling, only second-order std is considered
"""
def __init__(self, in_dim=0, **kwargs):
super(TSDP, self).__init__()
self.in_dim = in_dim
def forward(self, x):
# The last dimension is the temporal axis
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
pooling_std = pooling_std.flatten(start_dim=1)
return pooling_std
def get_out_dim(self):
self.out_dim = self.in_dim
return self.out_dim
class TSTP(nn.Module):
"""
    Temporal statistics pooling: concatenate mean and std, as used in
    the x-vector architecture.
    Comment: simple concatenation cannot make full use of both statistics.
"""
def __init__(self, in_dim=0, **kwargs):
super(TSTP, self).__init__()
self.in_dim = in_dim
def forward(self, x):
# The last dimension is the temporal axis
pooling_mean = x.mean(dim=-1)
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
pooling_mean = pooling_mean.flatten(start_dim=1)
pooling_std = pooling_std.flatten(start_dim=1)
stats = torch.cat((pooling_mean, pooling_std), 1)
return stats
def get_out_dim(self):
self.out_dim = self.in_dim * 2
return self.out_dim
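
# Illustrative usage (sketch): TSTP doubles the feature dimension by
# concatenating mean and std over time, (B, F, T) -> (B, 2F), matching
# get_out_dim() = in_dim * 2.
#
#   >>> x = torch.randn(4, 80, 200)
#   >>> TSTP(in_dim=80)(x).shape
#   torch.Size([4, 160])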
class ASTP(nn.Module):
"""Attentive statistics pooling: Channel- and context-dependent
statistics pooling, first used in ECAPA_TDNN.
"""
def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False, **kwargs):
super(ASTP, self).__init__()
self.in_dim = in_dim
self.global_context_att = global_context_att
        # Use Conv1d with kernel_size == 1 rather than Linear, so we don't
        # need to transpose the input.
if global_context_att:
self.linear1 = nn.Conv1d(
in_dim * 3, bottleneck_dim, kernel_size=1
) # equals W and b in the paper
else:
self.linear1 = nn.Conv1d(
in_dim, bottleneck_dim, kernel_size=1
) # equals W and b in the paper
self.linear2 = nn.Conv1d(
bottleneck_dim, in_dim, kernel_size=1
) # equals V and k in the paper
def forward(self, x):
"""
x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
or a 4-dimensional tensor in resnet architecture (B,C,F,T)
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
"""
if len(x.shape) == 4:
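            # (B, C, F, T) -> (B, C*F, T): merge channel and frequency dims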
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
assert len(x.shape) == 3
if self.global_context_att:
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
context_std = torch.sqrt(
torch.var(x, dim=-1, keepdim=True) + 1e-7
).expand_as(x)
x_in = torch.cat((x, context_mean, context_std), dim=1)
else:
x_in = x
        # DON'T use ReLU here! Training may fail to converge with ReLU.
alpha = torch.tanh(self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
alpha = torch.softmax(self.linear2(alpha), dim=2)
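        # Attention-weighted statistics over time: mean = E_alpha[x] and
        # std = sqrt(E_alpha[x^2] - E_alpha[x]^2), where alpha sums to 1
        # along the time axis for every channel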
mean = torch.sum(alpha * x, dim=2)
var = torch.sum(alpha * (x**2), dim=2) - mean**2
std = torch.sqrt(var.clamp(min=1e-7))
return torch.cat([mean, std], dim=1)
def get_out_dim(self):
self.out_dim = 2 * self.in_dim
return self.out_dim
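
# Illustrative usage (sketch): ASTP learns a per-channel, per-frame attention
# weight and pools with attention-weighted mean and std, (B, F, T) -> (B, 2F).
#
#   >>> x = torch.randn(4, 512, 200)
#   >>> ASTP(in_dim=512, bottleneck_dim=128)(x).shape
#   torch.Size([4, 1024])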
class MHASTP(torch.nn.Module):
"""Multi head attentive statistics pooling
Reference:
Self Multi-Head Attention for Speaker Recognition
https://arxiv.org/pdf/1906.09890.pdf
"""
def __init__(
self, in_dim, layer_num=2, head_num=2, d_s=1, bottleneck_dim=64, **kwargs
):
super(MHASTP, self).__init__()
        # make sure in_dim is divisible by head_num
        assert in_dim % head_num == 0
self.in_dim = in_dim
self.head_num = head_num
        d_model = in_dim // head_num
channel_dims = [bottleneck_dim for i in range(layer_num + 1)]
if d_s > 1:
d_s = d_model
else:
d_s = 1
self.d_s = d_s
channel_dims[0], channel_dims[-1] = d_model, d_s
heads_att_trans = []
        for _ in range(self.head_num):
att_trans = nn.Sequential()
for i in range(layer_num - 1):
att_trans.add_module(
"att_" + str(i),
nn.Conv1d(channel_dims[i], channel_dims[i + 1], 1, 1),
)
att_trans.add_module("tanh" + str(i), nn.Tanh())
att_trans.add_module(
"att_" + str(layer_num - 1),
nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num], 1, 1),
)
heads_att_trans.append(att_trans)
self.heads_att_trans = nn.ModuleList(heads_att_trans)
def forward(self, input):
"""
input: a 3-dimensional tensor in xvector architecture
or a 4-dimensional tensor in resnet architecture
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
"""
        if len(input.shape) == 4:  # (B, C, F, T) -> (B, C*F, T)
input = input.reshape(
input.shape[0], input.shape[1] * input.shape[2], input.shape[3]
)
assert len(input.shape) == 3
bs, f_dim, t_dim = input.shape
        # Split channels into head_num groups and pool each head separately
        chunks = torch.chunk(input, self.head_num, 1)
        chunks_out = []
for i, layer in enumerate(self.heads_att_trans):
att_score = layer(chunks[i])
alpha = F.softmax(att_score, dim=-1)
mean = torch.sum(alpha * chunks[i], dim=2)
var = torch.sum(alpha * chunks[i] ** 2, dim=2) - mean**2
std = torch.sqrt(var.clamp(min=1e-7))
chunks_out.append(torch.cat((mean, std), dim=1))
out = torch.cat(chunks_out, dim=1)
return out
def get_out_dim(self):
self.out_dim = 2 * self.in_dim
return self.out_dim
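
# Illustrative usage (sketch): MHASTP splits the F channels into head_num
# groups of size F / head_num, applies attentive statistics pooling per head,
# and concatenates the per-head statistics, so the output is still (B, 2F).
#
#   >>> x = torch.randn(4, 512, 200)
#   >>> MHASTP(in_dim=512, head_num=8)(x).shape
#   torch.Size([4, 1024])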
class MQMHASTP(torch.nn.Module):
"""An attentive pooling
Reference:
multi query multi head attentive statistics pooling
https://arxiv.org/pdf/2110.05042.pdf
Args:
in_dim: the feature dimension of input
layer_num: the number of layer in the pooling layer
query_num: the number of querys
head_num: the number of heads
bottleneck_dim: the bottleneck dimension
SA (H = 1, Q = 1, n = 2, d_s = 1) ref:
https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf
MHA (H > 1, Q = 1, n = 1, d_s = 1) ref:
https://arxiv.org/pdf/1906.09890.pdf
AS (H = 1, Q > 1, n = 2, d_s = 1) ref:
https://arxiv.org/pdf/1803.10963.pdf
VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref:
http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf
"""
def __init__(
self,
in_dim,
layer_num=2,
query_num=2,
head_num=8,
d_s=2,
bottleneck_dim=64,
**kwargs
):
super(MQMHASTP, self).__init__()
self.n_query = nn.ModuleList(
[
MHASTP(
in_dim,
layer_num=layer_num,
head_num=head_num,
d_s=d_s,
bottleneck_dim=bottleneck_dim,
)
for i in range(query_num)
]
)
self.query_num = query_num
self.in_dim = in_dim
def forward(self, input):
"""
input: a 3-dimensional tensor in xvector architecture
or a 4-dimensional tensor in resnet architecture
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
"""
        if len(input.shape) == 4:  # (B, C, F, T) -> (B, C*F, T)
input = input.reshape(
input.shape[0], input.shape[1] * input.shape[2], input.shape[3]
)
assert len(input.shape) == 3
res = []
        for layer in self.n_query:
            res.append(layer(input))
out = torch.cat(res, dim=-1)
return out
def get_out_dim(self):
self.out_dim = self.in_dim * 2 * self.query_num
return self.out_dim
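

if __name__ == "__main__":
    # A minimal shape sanity check (an added sketch, not part of the original
    # wespeaker file): every pooling layer maps (B, F, T) frame features to
    # (B, get_out_dim()) segment-level embeddings.
    x = torch.randn(4, 512, 200)  # batch=4, feature dim=512, 200 frames
    for pooling in [
        TAP(in_dim=512),
        TSDP(in_dim=512),
        TSTP(in_dim=512),
        ASTP(in_dim=512),
        MHASTP(in_dim=512, head_num=8),
        MQMHASTP(in_dim=512, query_num=2, head_num=8),
    ]:
        out = pooling(x)
        assert out.shape == (4, pooling.get_out_dim())
        print(pooling.__class__.__name__, tuple(out.shape))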