TheComputerMan commited on
Commit
0c5616f
1 Parent(s): 3ae2f30

Upload Conformer.py

Browse files
Files changed (1) hide show
  1. Conformer.py +144 -0
Conformer.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Taken from ESPNet
3
+ """
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from Layers.Attention import RelPositionMultiHeadedAttention
9
+ from Layers.Convolution import ConvolutionModule
10
+ from Layers.EncoderLayer import EncoderLayer
11
+ from Layers.LayerNorm import LayerNorm
12
+ from Layers.MultiLayeredConv1d import MultiLayeredConv1d
13
+ from Layers.MultiSequential import repeat
14
+ from Layers.PositionalEncoding import RelPositionalEncoding
15
+ from Layers.Swish import Swish
16
+
17
+
18
+ class Conformer(torch.nn.Module):
19
+ """
20
+ Conformer encoder module.
21
+
22
+ Args:
23
+ idim (int): Input dimension.
24
+ attention_dim (int): Dimension of attention.
25
+ attention_heads (int): The number of heads of multi head attention.
26
+ linear_units (int): The number of units of position-wise feed forward.
27
+ num_blocks (int): The number of decoder blocks.
28
+ dropout_rate (float): Dropout rate.
29
+ positional_dropout_rate (float): Dropout rate after adding positional encoding.
30
+ attention_dropout_rate (float): Dropout rate in attention.
31
+ input_layer (Union[str, torch.nn.Module]): Input layer type.
32
+ normalize_before (bool): Whether to use layer_norm before the first block.
33
+ concat_after (bool): Whether to concat attention layer's input and output.
34
+ if True, additional linear will be applied.
35
+ i.e. x -> x + linear(concat(x, att(x)))
36
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
37
+ positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
38
+ positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
39
+ macaron_style (bool): Whether to use macaron style for positionwise layer.
40
+ pos_enc_layer_type (str): Conformer positional encoding layer type.
41
+ selfattention_layer_type (str): Conformer attention layer type.
42
+ activation_type (str): Conformer activation function type.
43
+ use_cnn_module (bool): Whether to use convolution module.
44
+ cnn_module_kernel (int): Kernerl size of convolution module.
45
+ padding_idx (int): Padding idx for input_layer=embed.
46
+
47
+ """
48
+
49
+ def __init__(self, idim, attention_dim=256, attention_heads=4, linear_units=2048, num_blocks=6, dropout_rate=0.1, positional_dropout_rate=0.1,
50
+ attention_dropout_rate=0.0, input_layer="conv2d", normalize_before=True, concat_after=False, positionwise_conv_kernel_size=1,
51
+ macaron_style=False, use_cnn_module=False, cnn_module_kernel=31, zero_triu=False, utt_embed=None, connect_utt_emb_at_encoder_out=True,
52
+ spk_emb_bottleneck_size=128, lang_embs=None):
53
+ super(Conformer, self).__init__()
54
+
55
+ activation = Swish()
56
+ self.conv_subsampling_factor = 1
57
+
58
+ if isinstance(input_layer, torch.nn.Module):
59
+ self.embed = input_layer
60
+ self.pos_enc = RelPositionalEncoding(attention_dim, positional_dropout_rate)
61
+ elif input_layer is None:
62
+ self.embed = None
63
+ self.pos_enc = torch.nn.Sequential(RelPositionalEncoding(attention_dim, positional_dropout_rate))
64
+ else:
65
+ raise ValueError("unknown input_layer: " + input_layer)
66
+
67
+ self.normalize_before = normalize_before
68
+
69
+ self.connect_utt_emb_at_encoder_out = connect_utt_emb_at_encoder_out
70
+ if utt_embed is not None:
71
+ self.hs_emb_projection = torch.nn.Linear(attention_dim + spk_emb_bottleneck_size, attention_dim)
72
+ # embedding projection derived from https://arxiv.org/pdf/1705.08947.pdf
73
+ self.embedding_projection = torch.nn.Sequential(torch.nn.Linear(utt_embed, spk_emb_bottleneck_size),
74
+ torch.nn.Softsign())
75
+ if lang_embs is not None:
76
+ self.language_embedding = torch.nn.Embedding(num_embeddings=lang_embs, embedding_dim=attention_dim)
77
+
78
+ # self-attention module definition
79
+ encoder_selfattn_layer = RelPositionMultiHeadedAttention
80
+ encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu)
81
+
82
+ # feed-forward module definition
83
+ positionwise_layer = MultiLayeredConv1d
84
+ positionwise_layer_args = (attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate,)
85
+
86
+ # convolution module definition
87
+ convolution_layer = ConvolutionModule
88
+ convolution_layer_args = (attention_dim, cnn_module_kernel, activation)
89
+
90
+ self.encoders = repeat(num_blocks, lambda lnum: EncoderLayer(attention_dim, encoder_selfattn_layer(*encoder_selfattn_layer_args),
91
+ positionwise_layer(*positionwise_layer_args),
92
+ positionwise_layer(*positionwise_layer_args) if macaron_style else None,
93
+ convolution_layer(*convolution_layer_args) if use_cnn_module else None, dropout_rate,
94
+ normalize_before, concat_after))
95
+ if self.normalize_before:
96
+ self.after_norm = LayerNorm(attention_dim)
97
+
98
+ def forward(self, xs, masks, utterance_embedding=None, lang_ids=None):
99
+ """
100
+ Encode input sequence.
101
+
102
+ Args:
103
+ utterance_embedding: embedding containing lots of conditioning signals
104
+ step: indicator for when to start updating the embedding function
105
+ xs (torch.Tensor): Input tensor (#batch, time, idim).
106
+ masks (torch.Tensor): Mask tensor (#batch, time).
107
+
108
+ Returns:
109
+ torch.Tensor: Output tensor (#batch, time, attention_dim).
110
+ torch.Tensor: Mask tensor (#batch, time).
111
+
112
+ """
113
+
114
+ if self.embed is not None:
115
+ xs = self.embed(xs)
116
+
117
+ if lang_ids is not None:
118
+ lang_embs = self.language_embedding(lang_ids)
119
+ xs = xs + lang_embs # offset the phoneme distribution of a language
120
+
121
+ if utterance_embedding is not None and not self.connect_utt_emb_at_encoder_out:
122
+ xs = self._integrate_with_utt_embed(xs, utterance_embedding)
123
+
124
+ xs = self.pos_enc(xs)
125
+
126
+ xs, masks = self.encoders(xs, masks)
127
+ if isinstance(xs, tuple):
128
+ xs = xs[0]
129
+
130
+ if self.normalize_before:
131
+ xs = self.after_norm(xs)
132
+
133
+ if utterance_embedding is not None and self.connect_utt_emb_at_encoder_out:
134
+ xs = self._integrate_with_utt_embed(xs, utterance_embedding)
135
+
136
+ return xs, masks
137
+
138
+ def _integrate_with_utt_embed(self, hs, utt_embeddings):
139
+ # project embedding into smaller space
140
+ speaker_embeddings_projected = self.embedding_projection(utt_embeddings)
141
+ # concat hidden states with spk embeds and then apply projection
142
+ speaker_embeddings_expanded = F.normalize(speaker_embeddings_projected).unsqueeze(1).expand(-1, hs.size(1), -1)
143
+ hs = self.hs_emb_projection(torch.cat([hs, speaker_embeddings_expanded], dim=-1))
144
+ return hs