sudip1310 committed
Commit 998b155
1 Parent(s): 78655fa

Upload 5 files

tacotron_pytorch/__init__.py ADDED
@@ -0,0 +1,5 @@
+ # coding: utf-8
+ from __future__ import with_statement, print_function, absolute_import
+
+ from .tacotron import Tacotron
+ from .version import __version__
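Note: the package `__init__` above imports `__version__` from `tacotron_pytorch/version.py`, which is not among the five files in this commit, so it is presumably expected to exist already in the repository. If it did not, a minimal placeholder would be enough for the import to resolve (the version string below is hypothetical):

    # tacotron_pytorch/version.py -- hypothetical placeholder, not part of this commit
    __version__ = "0.0.1"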
tacotron_pytorch/attention.py ADDED
@@ -0,0 +1,85 @@
+ import torch
+ from torch.autograd import Variable
+ from torch import nn
+ from torch.nn import functional as F
+
+
+ class BahdanauAttention(nn.Module):
+     def __init__(self, dim):
+         super(BahdanauAttention, self).__init__()
+         self.query_layer = nn.Linear(dim, dim, bias=False)
+         self.tanh = nn.Tanh()
+         self.v = nn.Linear(dim, 1, bias=False)
+
+     def forward(self, query, processed_memory):
+         """
+         Args:
+             query: (batch, 1, dim) or (batch, dim)
+             processed_memory: (batch, max_time, dim)
+         """
+         if query.dim() == 2:
+             # insert time-axis for broadcasting
+             query = query.unsqueeze(1)
+         # (batch, 1, dim)
+         processed_query = self.query_layer(query)
+
+         # (batch, max_time, 1)
+         alignment = self.v(self.tanh(processed_query + processed_memory))
+
+         # (batch, max_time)
+         return alignment.squeeze(-1)
+
+
+ def get_mask_from_lengths(memory, memory_lengths):
+     """Get mask tensor from a list of lengths
+
+     Args:
+         memory: (batch, max_time, dim)
+         memory_lengths: array-like of int
+     """
+     mask = memory.data.new(memory.size(0), memory.size(1)).byte().zero_()
+     for idx, l in enumerate(memory_lengths):
+         mask[idx][:l] = 1
+     return ~mask
+
+
+ class AttentionWrapper(nn.Module):
+     def __init__(self, rnn_cell, attention_mechanism,
+                  score_mask_value=-float("inf")):
+         super(AttentionWrapper, self).__init__()
+         self.rnn_cell = rnn_cell
+         self.attention_mechanism = attention_mechanism
+         self.score_mask_value = score_mask_value
+
+     def forward(self, query, attention, cell_state, memory,
+                 processed_memory=None, mask=None, memory_lengths=None):
+         if processed_memory is None:
+             processed_memory = memory
+         if memory_lengths is not None and mask is None:
+             mask = get_mask_from_lengths(memory, memory_lengths)
+
+         # Concat input query and previous attention context
+         cell_input = torch.cat((query, attention), -1)
+
+         # Feed it to RNN
+         cell_output = self.rnn_cell(cell_input, cell_state)
+
+         # Alignment
+         # (batch, max_time)
+         alignment = self.attention_mechanism(cell_output, processed_memory)
+
+         if mask is not None:
+             mask = mask.view(query.size(0), -1)
+             alignment.data.masked_fill_(mask, self.score_mask_value)
+
+         # Normalize attention weight
+         alignment = F.softmax(alignment, dim=-1)
+
+         # Attention context vector
+         # (batch, 1, dim)
+         attention = torch.bmm(alignment.unsqueeze(1), memory)
+
+         # (batch, dim)
+         attention = attention.squeeze(1)
+
+         return cell_output, attention, alignment
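A note on the masking convention in attention.py: `get_mask_from_lengths` returns a byte mask that is 1 exactly at the padded time steps, and `AttentionWrapper` fills those positions with `score_mask_value` (default -inf) before the softmax so that padding receives zero attention weight. Below is a minimal sketch of the expected behaviour; the shapes and lengths are made up, and it assumes an older PyTorch release where `~` on a ByteTensor acts as logical not, which is what this code relies on:

    # Sketch only; not part of the diff.
    import torch
    from tacotron_pytorch.attention import get_mask_from_lengths

    memory = torch.rand(2, 4, 8)                  # (batch=2, max_time=4, dim=8)
    mask = get_mask_from_lengths(memory, [3, 2])  # 1 marks padded time steps
    # Expected:
    # tensor([[0, 0, 0, 1],
    #         [0, 0, 1, 1]], dtype=torch.uint8)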
tacotron_pytorch/tacotron.py ADDED
@@ -0,0 +1,318 @@
+ # coding: utf-8
+ from __future__ import with_statement, print_function, absolute_import
+
+ import torch
+ from torch.autograd import Variable
+ from torch import nn
+
+ from .attention import BahdanauAttention, AttentionWrapper
+ from .attention import get_mask_from_lengths
+
+
+ class Prenet(nn.Module):
+     def __init__(self, in_dim, sizes=[256, 128]):
+         super(Prenet, self).__init__()
+         in_sizes = [in_dim] + sizes[:-1]
+         self.layers = nn.ModuleList(
+             [nn.Linear(in_size, out_size)
+              for (in_size, out_size) in zip(in_sizes, sizes)])
+         self.relu = nn.ReLU()
+         self.dropout = nn.Dropout(0.5)
+
+     def forward(self, inputs):
+         for linear in self.layers:
+             inputs = self.dropout(self.relu(linear(inputs)))
+         return inputs
+
+
+ class BatchNormConv1d(nn.Module):
+     def __init__(self, in_dim, out_dim, kernel_size, stride, padding,
+                  activation=None):
+         super(BatchNormConv1d, self).__init__()
+         self.conv1d = nn.Conv1d(in_dim, out_dim,
+                                 kernel_size=kernel_size,
+                                 stride=stride, padding=padding, bias=False)
+         self.bn = nn.BatchNorm1d(out_dim)
+         self.activation = activation
+
+     def forward(self, x):
+         x = self.conv1d(x)
+         if self.activation is not None:
+             x = self.activation(x)
+         return self.bn(x)
+
+
+ class Highway(nn.Module):
+     def __init__(self, in_size, out_size):
+         super(Highway, self).__init__()
+         self.H = nn.Linear(in_size, out_size)
+         self.H.bias.data.zero_()
+         self.T = nn.Linear(in_size, out_size)
+         self.T.bias.data.fill_(-1)
+         self.relu = nn.ReLU()
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, inputs):
+         H = self.relu(self.H(inputs))
+         T = self.sigmoid(self.T(inputs))
+         return H * T + inputs * (1.0 - T)
+
+
+ class CBHG(nn.Module):
+     """CBHG module: a recurrent neural network composed of:
+         - 1-d convolution banks
+         - Highway networks + residual connections
+         - Bidirectional gated recurrent units
+     """
+
+     def __init__(self, in_dim, K=16, projections=[128, 128]):
+         super(CBHG, self).__init__()
+         self.in_dim = in_dim
+         self.relu = nn.ReLU()
+         self.conv1d_banks = nn.ModuleList(
+             [BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1,
+                              padding=k // 2, activation=self.relu)
+              for k in range(1, K + 1)])
+         self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
+
+         in_sizes = [K * in_dim] + projections[:-1]
+         activations = [self.relu] * (len(projections) - 1) + [None]
+         self.conv1d_projections = nn.ModuleList(
+             [BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
+                              padding=1, activation=ac)
+              for (in_size, out_size, ac) in zip(
+                  in_sizes, projections, activations)])
+
+         self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
+         self.highways = nn.ModuleList(
+             [Highway(in_dim, in_dim) for _ in range(4)])
+
+         self.gru = nn.GRU(
+             in_dim, in_dim, 1, batch_first=True, bidirectional=True)
+
+     def forward(self, inputs, input_lengths=None):
+         # (B, T_in, in_dim)
+         x = inputs
+
+         # Needed to perform conv1d on time-axis
+         # (B, in_dim, T_in)
+         if x.size(-1) == self.in_dim:
+             x = x.transpose(1, 2)
+
+         T = x.size(-1)
+
+         # (B, in_dim*K, T_in)
+         # Concat conv1d bank outputs
+         x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1)
+         assert x.size(1) == self.in_dim * len(self.conv1d_banks)
+         x = self.max_pool1d(x)[:, :, :T]
+
+         for conv1d in self.conv1d_projections:
+             x = conv1d(x)
+
+         # (B, T_in, in_dim)
+         # Back to the original shape
+         x = x.transpose(1, 2)
+
+         if x.size(-1) != self.in_dim:
+             x = self.pre_highway(x)
+
+         # Residual connection
+         x += inputs
+         for highway in self.highways:
+             x = highway(x)
+
+         if input_lengths is not None:
+             x = nn.utils.rnn.pack_padded_sequence(
+                 x, input_lengths, batch_first=True)
+
+         # (B, T_in, in_dim*2)
+         outputs, _ = self.gru(x)
+
+         if input_lengths is not None:
+             outputs, _ = nn.utils.rnn.pad_packed_sequence(
+                 outputs, batch_first=True)
+
+         return outputs
+
+
+ class Encoder(nn.Module):
+     def __init__(self, in_dim):
+         super(Encoder, self).__init__()
+         self.prenet = Prenet(in_dim, sizes=[256, 128])
+         self.cbhg = CBHG(128, K=16, projections=[128, 128])
+
+     def forward(self, inputs, input_lengths=None):
+         inputs = self.prenet(inputs)
+         return self.cbhg(inputs, input_lengths)
+
+
+ class Decoder(nn.Module):
+     def __init__(self, in_dim, r):
+         super(Decoder, self).__init__()
+         self.in_dim = in_dim
+         self.r = r
+         self.prenet = Prenet(in_dim * r, sizes=[256, 128])
+         # (prenet_out + attention context) -> output
+         self.attention_rnn = AttentionWrapper(
+             nn.GRUCell(256 + 128, 256),
+             BahdanauAttention(256)
+         )
+         self.memory_layer = nn.Linear(256, 256, bias=False)
+         self.project_to_decoder_in = nn.Linear(512, 256)
+
+         self.decoder_rnns = nn.ModuleList(
+             [nn.GRUCell(256, 256) for _ in range(2)])
+
+         self.proj_to_mel = nn.Linear(256, in_dim * r)
+         self.max_decoder_steps = 200
+
+     def forward(self, encoder_outputs, inputs=None, memory_lengths=None):
+         """
+         Decoder forward step.
+
+         If decoder inputs are not given (e.g., at test time), greedy decoding
+         is used, as described in the Tacotron paper.
+
+         Args:
+             encoder_outputs: Encoder outputs. (B, T_encoder, dim)
+             inputs: Decoder inputs, i.e., mel-spectrogram. If None (at eval time),
+               decoder outputs are used as decoder inputs.
+             memory_lengths: Encoder output (memory) lengths. If not None, used for
+               attention masking.
+         """
+         B = encoder_outputs.size(0)
+
+         processed_memory = self.memory_layer(encoder_outputs)
+         if memory_lengths is not None:
+             mask = get_mask_from_lengths(processed_memory, memory_lengths)
+         else:
+             mask = None
+
+         # Run greedy decoding if inputs is None
+         greedy = inputs is None
+
+         if inputs is not None:
+             # Grouping multiple frames if necessary
+             if inputs.size(-1) == self.in_dim:
+                 inputs = inputs.view(B, inputs.size(1) // self.r, -1)
+             assert inputs.size(-1) == self.in_dim * self.r
+             T_decoder = inputs.size(1)
+
+         # <GO> frames (all-zero initial decoder input)
+         initial_input = Variable(
+             encoder_outputs.data.new(B, self.in_dim * self.r).zero_())
+
+         # Init decoder states
+         attention_rnn_hidden = Variable(
+             encoder_outputs.data.new(B, 256).zero_())
+         decoder_rnn_hiddens = [Variable(
+             encoder_outputs.data.new(B, 256).zero_())
+             for _ in range(len(self.decoder_rnns))]
+         current_attention = Variable(
+             encoder_outputs.data.new(B, 256).zero_())
+
+         # Time first (T_decoder, B, in_dim)
+         if inputs is not None:
+             inputs = inputs.transpose(0, 1)
+
+         outputs = []
+         alignments = []
+
+         t = 0
+         current_input = initial_input
+         while True:
+             if t > 0:
+                 current_input = outputs[-1] if greedy else inputs[t - 1]
+             # Prenet
+             current_input = self.prenet(current_input)
+
+             # Attention RNN
+             attention_rnn_hidden, current_attention, alignment = self.attention_rnn(
+                 current_input, current_attention, attention_rnn_hidden,
+                 encoder_outputs, processed_memory=processed_memory, mask=mask)
+
+             # Concat RNN output and attention context vector
+             decoder_input = self.project_to_decoder_in(
+                 torch.cat((attention_rnn_hidden, current_attention), -1))
+
+             # Pass through the decoder RNNs
+             for idx in range(len(self.decoder_rnns)):
+                 decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
+                     decoder_input, decoder_rnn_hiddens[idx])
+                 # Residual connection
+                 decoder_input = decoder_rnn_hiddens[idx] + decoder_input
+
+             output = decoder_input
+             output = self.proj_to_mel(output)
+
+             outputs += [output]
+             alignments += [alignment]
+
+             t += 1
+
+             if greedy:
+                 if t > 1 and is_end_of_frames(output):
+                     break
+                 elif t > self.max_decoder_steps:
+                     print("Warning! Reached max decoder steps; decoding does not seem to have converged.")
+                     break
+             else:
+                 if t >= T_decoder:
+                     break
+
+         assert greedy or len(outputs) == T_decoder
+
+         # Back to batch first
+         alignments = torch.stack(alignments).transpose(0, 1)
+         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
+
+         return outputs, alignments
+
+
+ def is_end_of_frames(output, eps=0.2):
+     return (output.data <= eps).all()
+
+
+ class Tacotron(nn.Module):
+     def __init__(self, n_vocab, embedding_dim=256, mel_dim=80, linear_dim=1025,
+                  r=5, padding_idx=None, use_memory_mask=False):
+         super(Tacotron, self).__init__()
+         self.mel_dim = mel_dim
+         self.linear_dim = linear_dim
+         self.use_memory_mask = use_memory_mask
+         self.embedding = nn.Embedding(n_vocab, embedding_dim,
+                                       padding_idx=padding_idx)
+         # Trying smaller std
+         self.embedding.weight.data.normal_(0, 0.3)
+         self.encoder = Encoder(embedding_dim)
+         self.decoder = Decoder(mel_dim, r)
+
+         self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
+         self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
+
+     def forward(self, inputs, targets=None, input_lengths=None):
+         B = inputs.size(0)
+
+         inputs = self.embedding(inputs)
+         # (B, T', in_dim)
+         encoder_outputs = self.encoder(inputs, input_lengths)
+
+         if self.use_memory_mask:
+             memory_lengths = input_lengths
+         else:
+             memory_lengths = None
+         # (B, T', mel_dim*r)
+         mel_outputs, alignments = self.decoder(
+             encoder_outputs, targets, memory_lengths=memory_lengths)
+
+         # Post net processing below
+
+         # Reshape
+         # (B, T, mel_dim)
+         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
+
+         linear_outputs = self.postnet(mel_outputs)
+         linear_outputs = self.last_linear(linear_outputs)
+
+         return mel_outputs, linear_outputs, alignments
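For orientation, here is a rough usage sketch of the `Tacotron` module defined above with its default hyperparameters (mel_dim=80, linear_dim=1025, r=5). The vocabulary size, batch size, and sequence lengths are made up for illustration, and the shape comments follow from the code rather than from any documentation in this commit:

    # Sketch only; values are illustrative, not part of the diff.
    import torch
    from tacotron_pytorch import Tacotron

    model = Tacotron(n_vocab=100, r=5)                        # n_vocab chosen arbitrarily
    text = torch.randint(1, 100, (2, 30), dtype=torch.long)   # (B, T_text) character IDs
    mel = torch.rand(2, 200, 80)                              # (B, T_mel, mel_dim), T_mel divisible by r
    mel_out, linear_out, align = model(text, mel)
    # mel_out:    (2, 200, 80)    decoder output, reshaped back to one frame per time step
    # linear_out: (2, 200, 1025)  postnet CBHG (80 * 2) -> last_linear
    # align:      (2, 40, 30)     one attention row per decoder step (T_mel // r)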
tests/test_attention.py ADDED
@@ -0,0 +1,44 @@
+ import torch
+ from torch.autograd import Variable
+ from torch import nn
+
+ from tacotron_pytorch.attention import BahdanauAttention, AttentionWrapper
+ from tacotron_pytorch.attention import get_mask_from_lengths
+
+
+ def test_attention_wrapper():
+     B = 2
+
+     encoder_outputs = Variable(torch.rand(B, 100, 256))
+     memory_lengths = [100, 50]
+
+     mask = get_mask_from_lengths(encoder_outputs, memory_lengths)
+     print("Mask size:", mask.size())
+
+     memory_layer = nn.Linear(256, 256)
+     query = Variable(torch.rand(B, 128))
+
+     attention_mechanism = BahdanauAttention(256)
+
+     # Attention context + input
+     rnn = nn.GRUCell(256 + 128, 256)
+
+     attention_rnn = AttentionWrapper(rnn, attention_mechanism)
+     initial_attention = Variable(torch.zeros(B, 256))
+     cell_state = Variable(torch.zeros(B, 256))
+
+     processed_memory = memory_layer(encoder_outputs)
+
+     cell_output, attention, alignment = attention_rnn(
+         query, initial_attention, cell_state, encoder_outputs,
+         processed_memory=processed_memory,
+         mask=None, memory_lengths=memory_lengths)
+
+     print("Cell output size:", cell_output.size())
+     print("Attention output size:", attention.size())
+     print("Alignment size:", alignment.size())
+
+     assert (alignment.sum(-1) == 1).data.all()
+
+
+ test_attention_wrapper()
tests/test_tacotron.py ADDED
@@ -0,0 +1,47 @@
+ # coding: utf-8
+ import sys
+ from os.path import dirname, join
+ tacotron_lib_dir = join(dirname(__file__), "..", "lib", "tacotron")
+ sys.path.append(tacotron_lib_dir)
+ from text import text_to_sequence, symbols
+ import torch
+ from torch.autograd import Variable
+ from tacotron_pytorch import Tacotron
+ import numpy as np
+
+
+ def _pad(seq, max_len):
+     return np.pad(seq, (0, max_len - len(seq)),
+                   mode='constant', constant_values=0)
+
+
+ def test_taco():
+     B, T_out, D_out = 2, 400, 80
+     r = 5
+     T_encoder = T_out // r
+
+     texts = ["Thank you very much.", "Hello"]
+     seqs = [np.array(text_to_sequence(
+         t, ["english_cleaners"]), dtype=np.int64) for t in texts]
+     input_lengths = np.array([len(s) for s in seqs])
+     max_len = np.max(input_lengths)
+     seqs = np.array([_pad(s, max_len) for s in seqs])
+
+     x = torch.LongTensor(seqs)
+     y = torch.rand(B, T_out, D_out)
+     x = Variable(x)
+     y = Variable(y)
+
+     model = Tacotron(n_vocab=len(symbols), r=r)
+
+     print("Encoder input shape: ", x.size())
+     print("Decoder input shape: ", y.size())
+     a, b, c = model(x, y, input_lengths=input_lengths)
+     print("Mel shape:", a.size())
+     print("Linear shape:", b.size())
+     print("Attention shape:", c.size())
+
+     assert c.size() == (B, T_encoder, max_len)
+
+     # Test greedy decoding
+     a, b, c = model(x, input_lengths=input_lengths)