|
""" |
|
Taken from ESPNet |
|
""" |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from .Attention import RelPositionMultiHeadedAttention |
|
from .Convolution import ConvolutionModule |
|
from .EncoderLayer import EncoderLayer |
|
from .LayerNorm import LayerNorm |
|
from .MultiLayeredConv1d import MultiLayeredConv1d |
|
from .MultiSequential import repeat |
|
from .PositionalEncoding import RelPositionalEncoding |
|
from .Swish import Swish |
|
|
|
|
|
class Conformer(torch.nn.Module): |
|
""" |
|
Conformer encoder module. |
|
|
|
Args: |
|
idim (int): Input dimension. |
|
attention_dim (int): Dimension of attention. |
|
attention_heads (int): The number of heads of multi head attention. |
|
linear_units (int): The number of units of position-wise feed forward. |
|
        num_blocks (int): The number of encoder blocks.
|
dropout_rate (float): Dropout rate. |
|
positional_dropout_rate (float): Dropout rate after adding positional encoding. |
|
attention_dropout_rate (float): Dropout rate in attention. |
|
        input_layer (torch.nn.Module or None): Module applied to the input before the encoder; if None, the features are passed directly to the positional encoding.
|
normalize_before (bool): Whether to use layer_norm before the first block. |
|
concat_after (bool): Whether to concat attention layer's input and output. |
|
if True, additional linear will be applied. |
|
i.e. x -> x + linear(concat(x, att(x))) |
|
if False, no additional linear will be applied. i.e. x -> x + att(x) |
|
        positionwise_conv_kernel_size (int): Kernel size of the position-wise conv1d layer.
        macaron_style (bool): Whether to use macaron style for the position-wise layer.
        use_cnn_module (bool): Whether to use the convolution module.
        cnn_module_kernel (int): Kernel size of the convolution module.
        zero_triu (bool): Whether to zero out the upper triangular part of the attention matrix.
        utt_embed (int): Dimension of the utterance embedding (None disables utterance conditioning).
        connect_utt_emb_at_encoder_out (bool): Whether to integrate the utterance embedding after the encoder blocks rather than before them.
        spk_emb_bottleneck_size (int): Size of the bottleneck the utterance embedding is projected into.
        lang_embs (int): Number of language embeddings (None disables language conditioning).
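
    Example (illustrative sketch, not part of the original code; assumes 80-dimensional
    input features and a simple linear pre-net as the input layer):

        >>> encoder = Conformer(idim=80, input_layer=torch.nn.Linear(80, 256))
        >>> xs = torch.randn(2, 100, 80)
        >>> masks = torch.ones(2, 1, 100, dtype=torch.bool)
        >>> outs, out_masks = encoder(xs, masks)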
|
|
|
""" |
|
|
|
def __init__(self, idim, attention_dim=256, attention_heads=4, linear_units=2048, num_blocks=6, dropout_rate=0.1, positional_dropout_rate=0.1, |
|
attention_dropout_rate=0.0, input_layer="conv2d", normalize_before=True, concat_after=False, positionwise_conv_kernel_size=1, |
|
macaron_style=False, use_cnn_module=False, cnn_module_kernel=31, zero_triu=False, utt_embed=None, connect_utt_emb_at_encoder_out=True, |
|
spk_emb_bottleneck_size=128, lang_embs=None): |
|
        super().__init__()
|
|
|
activation = Swish() |
|
self.conv_subsampling_factor = 1 |
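
        # no convolutional subsampling is used in this encoder, so the factor above stays at 1

        # the input layer is either a user-supplied module followed by relative positional
        # encoding or, if it is None, the relative positional encoding applied directly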
|
|
|
if isinstance(input_layer, torch.nn.Module): |
|
self.embed = input_layer |
|
self.pos_enc = RelPositionalEncoding(attention_dim, positional_dropout_rate) |
|
elif input_layer is None: |
|
self.embed = None |
|
self.pos_enc = torch.nn.Sequential(RelPositionalEncoding(attention_dim, positional_dropout_rate)) |
|
else: |
|
            raise ValueError(f"unknown input_layer: {input_layer}")
|
|
|
self.normalize_before = normalize_before |
|
|
|
self.connect_utt_emb_at_encoder_out = connect_utt_emb_at_encoder_out |
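
        # if an utterance embedding is used, it is squeezed through a bottleneck (Linear + Softsign)
        # and later concatenated with the hidden states and projected back to attention_dim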
|
if utt_embed is not None: |
|
self.hs_emb_projection = torch.nn.Linear(attention_dim + spk_emb_bottleneck_size, attention_dim) |
|
|
|
self.embedding_projection = torch.nn.Sequential(torch.nn.Linear(utt_embed, spk_emb_bottleneck_size), |
|
torch.nn.Softsign()) |
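
        # optional language embedding table; the looked-up vector is added to the encoder input in forward()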
|
if lang_embs is not None: |
|
self.language_embedding = torch.nn.Embedding(num_embeddings=lang_embs, embedding_dim=attention_dim) |
|
|
|
|
|
encoder_selfattn_layer = RelPositionMultiHeadedAttention |
|
encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu) |
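
        # the position-wise feed-forward module is realized as a multi-layered 1d convolution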
|
|
|
|
|
positionwise_layer = MultiLayeredConv1d |
|
positionwise_layer_args = (attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate,) |
|
|
|
|
|
convolution_layer = ConvolutionModule |
|
convolution_layer_args = (attention_dim, cnn_module_kernel, activation) |
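
        # each encoder block combines relative-position self-attention, a position-wise feed-forward
        # module (twice if macaron style is used) and an optional convolution module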
|
|
|
self.encoders = repeat(num_blocks, lambda lnum: EncoderLayer(attention_dim, encoder_selfattn_layer(*encoder_selfattn_layer_args), |
|
positionwise_layer(*positionwise_layer_args), |
|
positionwise_layer(*positionwise_layer_args) if macaron_style else None, |
|
convolution_layer(*convolution_layer_args) if use_cnn_module else None, dropout_rate, |
|
normalize_before, concat_after)) |
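
        # with pre-norm (normalize_before), a final LayerNorm is applied to the encoder output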
|
if self.normalize_before: |
|
self.after_norm = LayerNorm(attention_dim) |
|
|
|
def forward(self, xs, masks, utterance_embedding=None, lang_ids=None): |
|
""" |
|
Encode input sequence. |
|
|
|
Args: |
|
            xs (torch.Tensor): Input tensor (#batch, time, idim).
            masks (torch.Tensor): Mask tensor (#batch, time).
            utterance_embedding (torch.Tensor): Optional utterance-level embedding used as conditioning signal.
            lang_ids (torch.Tensor): Optional language IDs used to look up the language embedding.
|
|
|
Returns: |
|
torch.Tensor: Output tensor (#batch, time, attention_dim). |
|
torch.Tensor: Mask tensor (#batch, time). |
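
        Example (illustrative sketch; the 64-dimensional utterance embedding is an assumed value):

            >>> enc = Conformer(idim=80, input_layer=torch.nn.Linear(80, 256), utt_embed=64)
            >>> outs, out_masks = enc(torch.randn(2, 50, 80),
            ...                       torch.ones(2, 1, 50, dtype=torch.bool),
            ...                       utterance_embedding=torch.randn(2, 64))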
|
|
|
""" |
|
|
|
if self.embed is not None: |
|
xs = self.embed(xs) |
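
        # add a language embedding to every time step if language IDs are provided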
|
|
|
if lang_ids is not None: |
|
lang_embs = self.language_embedding(lang_ids) |
|
xs = xs + lang_embs |
|
|
|
if utterance_embedding is not None and not self.connect_utt_emb_at_encoder_out: |
|
xs = self._integrate_with_utt_embed(xs, utterance_embedding) |
|
|
|
xs = self.pos_enc(xs) |
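
        # the relative positional encoding returns a tuple of (features, positional embeddings),
        # which the encoder layers consume and pass along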
|
|
|
xs, masks = self.encoders(xs, masks) |
|
if isinstance(xs, tuple): |
|
xs = xs[0] |
|
|
|
if self.normalize_before: |
|
xs = self.after_norm(xs) |
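
        # by default, the utterance embedding is integrated after the encoder blocks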
|
|
|
if utterance_embedding is not None and self.connect_utt_emb_at_encoder_out: |
|
xs = self._integrate_with_utt_embed(xs, utterance_embedding) |
|
|
|
return xs, masks |
|
|
|
def _integrate_with_utt_embed(self, hs, utt_embeddings): |
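        """
        Concatenate a bottlenecked, L2-normalized utterance embedding to every time step and project back to attention_dim.
        """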
|
|
|
speaker_embeddings_projected = self.embedding_projection(utt_embeddings) |
|
|
|
speaker_embeddings_expanded = F.normalize(speaker_embeddings_projected).unsqueeze(1).expand(-1, hs.size(1), -1) |
|
hs = self.hs_emb_projection(torch.cat([hs, speaker_embeddings_expanded], dim=-1)) |
|
return hs |
|
|