File size: 9,378 Bytes
9e275b8
 
 
 
 
 
70399da
 
 
 
 
 
 
 
 
 
9e275b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70399da
 
 
 
9e275b8
70399da
9e275b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70399da
 
 
 
9e275b8
 
 
70399da
 
 
 
9e275b8
 
 
 
 
 
70399da
 
 
 
 
 
9e275b8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
Taken from ESPNet, but heavily modified
"""

import torch

from Modules.GeneralLayers.Attention import RelPositionMultiHeadedAttention
from Modules.GeneralLayers.ConditionalLayerNorm import AdaIN1d
from Modules.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm
from Modules.GeneralLayers.Convolution import ConvolutionModule
from Modules.GeneralLayers.EncoderLayer import EncoderLayer
from Modules.GeneralLayers.LayerNorm import LayerNorm
from Modules.GeneralLayers.MultiLayeredConv1d import MultiLayeredConv1d
from Modules.GeneralLayers.MultiSequential import repeat
from Modules.GeneralLayers.PositionalEncoding import RelPositionalEncoding
from Modules.GeneralLayers.Swish import Swish
from Utility.utils import integrate_with_utt_embed


class Conformer(torch.nn.Module):
    """
    Conformer encoder module.

    A stack of ``num_blocks`` EncoderLayer blocks using relative-position
    multi-headed self-attention, with optional conditioning on an utterance
    embedding and on a per-sample language id.

    Args:
        conformer_type (str): Selects where the utterance embedding is integrated.
            "encoder": once, on the output of the last block.
            Any other value: on the input of every block.
        attention_dim (int): Dimension of attention (hidden size of the blocks).
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of EncoderLayer blocks.
        dropout_rate (float): Dropout rate, forwarded to the feed-forward modules
            and to every EncoderLayer.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        attention_dropout_rate (float): Dropout rate in attention.
        input_layer (Optional[torch.nn.Module]): Module applied to the input before
            anything else, or None to use the input as is. Any other value,
            including the default string "conv2d", raises a ValueError.
        normalize_before (bool): Forwarded to every EncoderLayer.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        macaron_style (bool): Whether to use macaron style for positionwise layer
            (adds a second feed-forward module to every block).
        use_cnn_module (bool): Whether to use convolution module.
        cnn_module_kernel (int): Kernel size of convolution module.
        zero_triu (bool): Forwarded to RelPositionMultiHeadedAttention.
        utt_embed (Optional[int]): Size of the utterance embedding, or None to
            disable utterance conditioning.
        lang_embs (Optional[int]): Number of entries of the language embedding
            table, or None to disable language conditioning.
        lang_emb_size (int): Size of one language embedding. A linear projection
            to attention_dim is added when the two sizes differ.
        use_output_norm (bool): Whether to apply a LayerNorm to the output. Skipped
            when the utterance embedding is integrated into the output instead.
        embedding_integration (str): "AdaIN" or "ConditionalLayerNorm" select the
            respective conditioning layer; any other value selects a Linear layer
            with attention_dim + utt_embed input features.

    Raises:
        ValueError: If input_layer is neither a torch.nn.Module nor None.

    """

    def __init__(self, conformer_type, attention_dim=256, attention_heads=4, linear_units=2048, num_blocks=6, dropout_rate=0.1, positional_dropout_rate=0.1,
                 attention_dropout_rate=0.0, input_layer="conv2d", normalize_before=True, concat_after=False, positionwise_conv_kernel_size=1,
                 macaron_style=False, use_cnn_module=False, cnn_module_kernel=31, zero_triu=False, utt_embed=None, lang_embs=None, lang_emb_size=16, use_output_norm=True, embedding_integration="AdaIN"):
        super().__init__()

        self.conv_subsampling_factor = 1
        self.use_output_norm = use_output_norm

        # Validate the input layer first so that a bad argument fails before anything is built.
        if isinstance(input_layer, torch.nn.Module):
            self.embed = input_layer
            self.art_embed_norm = LayerNorm(attention_dim)
            self.pos_enc = RelPositionalEncoding(attention_dim, positional_dropout_rate)
        elif input_layer is None:
            self.embed = None
            self.pos_enc = torch.nn.Sequential(RelPositionalEncoding(attention_dim, positional_dropout_rate))
        else:
            # f-string instead of concatenation: input_layer is not necessarily a str here
            raise ValueError(f"unknown input_layer: {input_layer}")

        if self.use_output_norm:
            self.output_norm = LayerNorm(attention_dim)
        self.utt_embed = utt_embed
        self.conformer_type = conformer_type
        self.use_conditional_layernorm_embedding_integration = embedding_integration in ["AdaIN", "ConditionalLayerNorm"]
        if utt_embed is not None:
            if conformer_type == "encoder":  # the encoder gets an additional conditioning signal added to its output
                if embedding_integration == "AdaIN":
                    self.encoder_embedding_projection = AdaIN1d(style_dim=utt_embed, num_features=attention_dim)
                elif embedding_integration == "ConditionalLayerNorm":
                    self.encoder_embedding_projection = ConditionalLayerNorm(speaker_embedding_dim=utt_embed, hidden_dim=attention_dim)
                else:
                    self.encoder_embedding_projection = torch.nn.Linear(attention_dim + utt_embed, attention_dim)
            else:  # every other type gets one conditioning layer per block
                if embedding_integration == "AdaIN":
                    self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: AdaIN1d(style_dim=utt_embed, num_features=attention_dim))
                elif embedding_integration == "ConditionalLayerNorm":
                    self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: ConditionalLayerNorm(speaker_embedding_dim=utt_embed, hidden_dim=attention_dim))
                else:
                    self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: torch.nn.Linear(attention_dim + utt_embed, attention_dim))
        if lang_embs is not None:
            self.language_embedding = torch.nn.Embedding(num_embeddings=lang_embs, embedding_dim=lang_emb_size)
            if lang_emb_size == attention_dim:
                # Identity rather than a lambda: it is parameter-free (state_dict is unaffected),
                # but unlike a lambda attribute it keeps the module picklable.
                self.language_embedding_projection = torch.nn.Identity()
            else:
                self.language_embedding_projection = torch.nn.Linear(lang_emb_size, attention_dim)
            self.language_emb_norm = LayerNorm(attention_dim)

        # self-attention module definition
        encoder_selfattn_layer = RelPositionMultiHeadedAttention
        encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu)

        # feed-forward module definition
        positionwise_layer = MultiLayeredConv1d
        positionwise_layer_args = (attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate,)

        # convolution module definition
        activation = Swish()
        convolution_layer = ConvolutionModule
        convolution_layer_args = (attention_dim, cnn_module_kernel, activation)

        self.encoders = repeat(num_blocks, lambda lnum: EncoderLayer(attention_dim, encoder_selfattn_layer(*encoder_selfattn_layer_args),
                                                                     positionwise_layer(*positionwise_layer_args),
                                                                     positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                                                                     convolution_layer(*convolution_layer_args) if use_cnn_module else None, dropout_rate,
                                                                     normalize_before, concat_after))

    def _integrate_utterance_embedding(self, hs, utterance_embedding, projection):
        """Condition the hidden states ``hs`` on the utterance embedding using ``projection``."""
        return integrate_with_utt_embed(hs=hs,
                                        utt_embeddings=utterance_embedding,
                                        projection=projection,
                                        embedding_training=self.use_conditional_layernorm_embedding_integration)

    def forward(self,
                xs,
                masks,
                utterance_embedding=None,
                lang_ids=None):
        """
        Encode input sequence.
        Args:
            utterance_embedding: embedding containing lots of conditioning signals
            lang_ids: ids of the languages per sample in the batch
                (assumed to be of shape (#batch,) -- TODO confirm against callers)
            xs (torch.Tensor): Input tensor (#batch, time, idim).
            masks (torch.Tensor): Mask tensor (#batch, time).
        Returns:
            torch.Tensor: Output tensor (#batch, time, attention_dim).
            torch.Tensor: Mask tensor (#batch, time).
        """

        if self.embed is not None:
            xs = self.embed(xs)
            xs = self.art_embed_norm(xs)

        if lang_ids is not None:
            lang_embs = self.language_embedding(lang_ids)
            # unsqueeze + transpose inserts a singleton axis at position 1 so the offset broadcasts over it
            projected_lang_embs = self.language_embedding_projection(lang_embs).unsqueeze(-1).transpose(1, 2)
            projected_lang_embs = self.language_emb_norm(projected_lang_embs)
            xs = xs + projected_lang_embs  # offset phoneme representation by language specific offset

        xs = self.pos_enc(xs)

        # Loop invariant: only non-"encoder" types integrate the utterance embedding in front of each block.
        integrate_per_block = bool(self.utt_embed) and self.conformer_type != "encoder"

        for encoder_index, encoder in enumerate(self.encoders):
            if integrate_per_block:
                projection = self.decoder_embedding_projections[encoder_index]
                if isinstance(xs, tuple):
                    # the positional encoding travels alongside the hidden states and is passed through untouched
                    x, pos_emb = xs[0], xs[1]
                    x = self._integrate_utterance_embedding(x, utterance_embedding, projection)
                    xs = (x, pos_emb)
                else:
                    xs = self._integrate_utterance_embedding(xs, utterance_embedding, projection)
            xs, masks = encoder(xs, masks)

        if isinstance(xs, tuple):
            xs = xs[0]  # drop the positional encoding, only the hidden states are returned

        if self.utt_embed and self.conformer_type == "encoder":
            xs = self._integrate_utterance_embedding(xs, utterance_embedding, self.encoder_embedding_projection)
        elif self.use_output_norm:
            xs = self.output_norm(xs)

        return xs, masks