Spaces:
Runtime error
Runtime error
Upload 5 files
Browse files- tacotron_pytorch/__init__.py +5 -0
- tacotron_pytorch/attention.py +85 -0
- tacotron_pytorch/tacotron.py +318 -0
- tests/test_attention.py +44 -0
- tests/test_tacotron.py +47 -0
tacotron_pytorch/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
from __future__ import with_statement, print_function, absolute_import
|
3 |
+
|
4 |
+
from .tacotron import Tacotron
|
5 |
+
from .version import __version__
|
tacotron_pytorch/attention.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.autograd import Variable
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
|
7 |
+
class BahdanauAttention(nn.Module):
|
8 |
+
def __init__(self, dim):
|
9 |
+
super(BahdanauAttention, self).__init__()
|
10 |
+
self.query_layer = nn.Linear(dim, dim, bias=False)
|
11 |
+
self.tanh = nn.Tanh()
|
12 |
+
self.v = nn.Linear(dim, 1, bias=False)
|
13 |
+
|
14 |
+
def forward(self, query, processed_memory):
|
15 |
+
"""
|
16 |
+
Args:
|
17 |
+
query: (batch, 1, dim) or (batch, dim)
|
18 |
+
processed_memory: (batch, max_time, dim)
|
19 |
+
"""
|
20 |
+
if query.dim() == 2:
|
21 |
+
# insert time-axis for broadcasting
|
22 |
+
query = query.unsqueeze(1)
|
23 |
+
# (batch, 1, dim)
|
24 |
+
processed_query = self.query_layer(query)
|
25 |
+
|
26 |
+
# (batch, max_time, 1)
|
27 |
+
alignment = self.v(self.tanh(processed_query + processed_memory))
|
28 |
+
|
29 |
+
# (batch, max_time)
|
30 |
+
return alignment.squeeze(-1)
|
31 |
+
|
32 |
+
|
33 |
+
def get_mask_from_lengths(memory, memory_lengths):
|
34 |
+
"""Get mask tensor from list of length
|
35 |
+
|
36 |
+
Args:
|
37 |
+
memory: (batch, max_time, dim)
|
38 |
+
memory_lengths: array like
|
39 |
+
"""
|
40 |
+
mask = memory.data.new(memory.size(0), memory.size(1)).byte().zero_()
|
41 |
+
for idx, l in enumerate(memory_lengths):
|
42 |
+
mask[idx][:l] = 1
|
43 |
+
return ~mask
|
44 |
+
|
45 |
+
|
46 |
+
class AttentionWrapper(nn.Module):
|
47 |
+
def __init__(self, rnn_cell, attention_mechanism,
|
48 |
+
score_mask_value=-float("inf")):
|
49 |
+
super(AttentionWrapper, self).__init__()
|
50 |
+
self.rnn_cell = rnn_cell
|
51 |
+
self.attention_mechanism = attention_mechanism
|
52 |
+
self.score_mask_value = score_mask_value
|
53 |
+
|
54 |
+
def forward(self, query, attention, cell_state, memory,
|
55 |
+
processed_memory=None, mask=None, memory_lengths=None):
|
56 |
+
if processed_memory is None:
|
57 |
+
processed_memory = memory
|
58 |
+
if memory_lengths is not None and mask is None:
|
59 |
+
mask = get_mask_from_lengths(memory, memory_lengths)
|
60 |
+
|
61 |
+
# Concat input query and previous attention context
|
62 |
+
cell_input = torch.cat((query, attention), -1)
|
63 |
+
|
64 |
+
# Feed it to RNN
|
65 |
+
cell_output = self.rnn_cell(cell_input, cell_state)
|
66 |
+
|
67 |
+
# Alignment
|
68 |
+
# (batch, max_time)
|
69 |
+
alignment = self.attention_mechanism(cell_output, processed_memory)
|
70 |
+
|
71 |
+
if mask is not None:
|
72 |
+
mask = mask.view(query.size(0), -1)
|
73 |
+
alignment.data.masked_fill_(mask, self.score_mask_value)
|
74 |
+
|
75 |
+
# Normalize attention weight
|
76 |
+
alignment = F.softmax(alignment)
|
77 |
+
|
78 |
+
# Attention context vector
|
79 |
+
# (batch, 1, dim)
|
80 |
+
attention = torch.bmm(alignment.unsqueeze(1), memory)
|
81 |
+
|
82 |
+
# (batch, dim)
|
83 |
+
attention = attention.squeeze(1)
|
84 |
+
|
85 |
+
return cell_output, attention, alignment
|
tacotron_pytorch/tacotron.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
from __future__ import with_statement, print_function, absolute_import
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch.autograd import Variable
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
from .attention import BahdanauAttention, AttentionWrapper
|
9 |
+
from .attention import get_mask_from_lengths
|
10 |
+
|
11 |
+
|
12 |
+
class Prenet(nn.Module):
|
13 |
+
def __init__(self, in_dim, sizes=[256, 128]):
|
14 |
+
super(Prenet, self).__init__()
|
15 |
+
in_sizes = [in_dim] + sizes[:-1]
|
16 |
+
self.layers = nn.ModuleList(
|
17 |
+
[nn.Linear(in_size, out_size)
|
18 |
+
for (in_size, out_size) in zip(in_sizes, sizes)])
|
19 |
+
self.relu = nn.ReLU()
|
20 |
+
self.dropout = nn.Dropout(0.5)
|
21 |
+
|
22 |
+
def forward(self, inputs):
|
23 |
+
for linear in self.layers:
|
24 |
+
inputs = self.dropout(self.relu(linear(inputs)))
|
25 |
+
return inputs
|
26 |
+
|
27 |
+
|
28 |
+
class BatchNormConv1d(nn.Module):
|
29 |
+
def __init__(self, in_dim, out_dim, kernel_size, stride, padding,
|
30 |
+
activation=None):
|
31 |
+
super(BatchNormConv1d, self).__init__()
|
32 |
+
self.conv1d = nn.Conv1d(in_dim, out_dim,
|
33 |
+
kernel_size=kernel_size,
|
34 |
+
stride=stride, padding=padding, bias=False)
|
35 |
+
self.bn = nn.BatchNorm1d(out_dim)
|
36 |
+
self.activation = activation
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
x = self.conv1d(x)
|
40 |
+
if self.activation is not None:
|
41 |
+
x = self.activation(x)
|
42 |
+
return self.bn(x)
|
43 |
+
|
44 |
+
|
45 |
+
class Highway(nn.Module):
|
46 |
+
def __init__(self, in_size, out_size):
|
47 |
+
super(Highway, self).__init__()
|
48 |
+
self.H = nn.Linear(in_size, out_size)
|
49 |
+
self.H.bias.data.zero_()
|
50 |
+
self.T = nn.Linear(in_size, out_size)
|
51 |
+
self.T.bias.data.fill_(-1)
|
52 |
+
self.relu = nn.ReLU()
|
53 |
+
self.sigmoid = nn.Sigmoid()
|
54 |
+
|
55 |
+
def forward(self, inputs):
|
56 |
+
H = self.relu(self.H(inputs))
|
57 |
+
T = self.sigmoid(self.T(inputs))
|
58 |
+
return H * T + inputs * (1.0 - T)
|
59 |
+
|
60 |
+
|
61 |
+
class CBHG(nn.Module):
|
62 |
+
"""CBHG module: a recurrent neural network composed of:
|
63 |
+
- 1-d convolution banks
|
64 |
+
- Highway networks + residual connections
|
65 |
+
- Bidirectional gated recurrent units
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(self, in_dim, K=16, projections=[128, 128]):
|
69 |
+
super(CBHG, self).__init__()
|
70 |
+
self.in_dim = in_dim
|
71 |
+
self.relu = nn.ReLU()
|
72 |
+
self.conv1d_banks = nn.ModuleList(
|
73 |
+
[BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1,
|
74 |
+
padding=k // 2, activation=self.relu)
|
75 |
+
for k in range(1, K + 1)])
|
76 |
+
self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
77 |
+
|
78 |
+
in_sizes = [K * in_dim] + projections[:-1]
|
79 |
+
activations = [self.relu] * (len(projections) - 1) + [None]
|
80 |
+
self.conv1d_projections = nn.ModuleList(
|
81 |
+
[BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
|
82 |
+
padding=1, activation=ac)
|
83 |
+
for (in_size, out_size, ac) in zip(
|
84 |
+
in_sizes, projections, activations)])
|
85 |
+
|
86 |
+
self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
|
87 |
+
self.highways = nn.ModuleList(
|
88 |
+
[Highway(in_dim, in_dim) for _ in range(4)])
|
89 |
+
|
90 |
+
self.gru = nn.GRU(
|
91 |
+
in_dim, in_dim, 1, batch_first=True, bidirectional=True)
|
92 |
+
|
93 |
+
def forward(self, inputs, input_lengths=None):
|
94 |
+
# (B, T_in, in_dim)
|
95 |
+
x = inputs
|
96 |
+
|
97 |
+
# Needed to perform conv1d on time-axis
|
98 |
+
# (B, in_dim, T_in)
|
99 |
+
if x.size(-1) == self.in_dim:
|
100 |
+
x = x.transpose(1, 2)
|
101 |
+
|
102 |
+
T = x.size(-1)
|
103 |
+
|
104 |
+
# (B, in_dim*K, T_in)
|
105 |
+
# Concat conv1d bank outputs
|
106 |
+
x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1)
|
107 |
+
assert x.size(1) == self.in_dim * len(self.conv1d_banks)
|
108 |
+
x = self.max_pool1d(x)[:, :, :T]
|
109 |
+
|
110 |
+
for conv1d in self.conv1d_projections:
|
111 |
+
x = conv1d(x)
|
112 |
+
|
113 |
+
# (B, T_in, in_dim)
|
114 |
+
# Back to the original shape
|
115 |
+
x = x.transpose(1, 2)
|
116 |
+
|
117 |
+
if x.size(-1) != self.in_dim:
|
118 |
+
x = self.pre_highway(x)
|
119 |
+
|
120 |
+
# Residual connection
|
121 |
+
x += inputs
|
122 |
+
for highway in self.highways:
|
123 |
+
x = highway(x)
|
124 |
+
|
125 |
+
if input_lengths is not None:
|
126 |
+
x = nn.utils.rnn.pack_padded_sequence(
|
127 |
+
x, input_lengths, batch_first=True)
|
128 |
+
|
129 |
+
# (B, T_in, in_dim*2)
|
130 |
+
outputs, _ = self.gru(x)
|
131 |
+
|
132 |
+
if input_lengths is not None:
|
133 |
+
outputs, _ = nn.utils.rnn.pad_packed_sequence(
|
134 |
+
outputs, batch_first=True)
|
135 |
+
|
136 |
+
return outputs
|
137 |
+
|
138 |
+
|
139 |
+
class Encoder(nn.Module):
|
140 |
+
def __init__(self, in_dim):
|
141 |
+
super(Encoder, self).__init__()
|
142 |
+
self.prenet = Prenet(in_dim, sizes=[256, 128])
|
143 |
+
self.cbhg = CBHG(128, K=16, projections=[128, 128])
|
144 |
+
|
145 |
+
def forward(self, inputs, input_lengths=None):
|
146 |
+
inputs = self.prenet(inputs)
|
147 |
+
return self.cbhg(inputs, input_lengths)
|
148 |
+
|
149 |
+
|
150 |
+
class Decoder(nn.Module):
|
151 |
+
def __init__(self, in_dim, r):
|
152 |
+
super(Decoder, self).__init__()
|
153 |
+
self.in_dim = in_dim
|
154 |
+
self.r = r
|
155 |
+
self.prenet = Prenet(in_dim * r, sizes=[256, 128])
|
156 |
+
# (prenet_out + attention context) -> output
|
157 |
+
self.attention_rnn = AttentionWrapper(
|
158 |
+
nn.GRUCell(256 + 128, 256),
|
159 |
+
BahdanauAttention(256)
|
160 |
+
)
|
161 |
+
self.memory_layer = nn.Linear(256, 256, bias=False)
|
162 |
+
self.project_to_decoder_in = nn.Linear(512, 256)
|
163 |
+
|
164 |
+
self.decoder_rnns = nn.ModuleList(
|
165 |
+
[nn.GRUCell(256, 256) for _ in range(2)])
|
166 |
+
|
167 |
+
self.proj_to_mel = nn.Linear(256, in_dim * r)
|
168 |
+
self.max_decoder_steps = 200
|
169 |
+
|
170 |
+
def forward(self, encoder_outputs, inputs=None, memory_lengths=None):
|
171 |
+
"""
|
172 |
+
Decoder forward step.
|
173 |
+
|
174 |
+
If decoder inputs are not given (e.g., at testing time), as noted in
|
175 |
+
Tacotron paper, greedy decoding is adapted.
|
176 |
+
|
177 |
+
Args:
|
178 |
+
encoder_outputs: Encoder outputs. (B, T_encoder, dim)
|
179 |
+
inputs: Decoder inputs. i.e., mel-spectrogram. If None (at eval-time),
|
180 |
+
decoder outputs are used as decoder inputs.
|
181 |
+
memory_lengths: Encoder output (memory) lengths. If not None, used for
|
182 |
+
attention masking.
|
183 |
+
"""
|
184 |
+
B = encoder_outputs.size(0)
|
185 |
+
|
186 |
+
processed_memory = self.memory_layer(encoder_outputs)
|
187 |
+
if memory_lengths is not None:
|
188 |
+
mask = get_mask_from_lengths(processed_memory, memory_lengths)
|
189 |
+
else:
|
190 |
+
mask = None
|
191 |
+
|
192 |
+
# Run greedy decoding if inputs is None
|
193 |
+
greedy = inputs is None
|
194 |
+
|
195 |
+
if inputs is not None:
|
196 |
+
# Grouping multiple frames if necessary
|
197 |
+
if inputs.size(-1) == self.in_dim:
|
198 |
+
inputs = inputs.view(B, inputs.size(1) // self.r, -1)
|
199 |
+
assert inputs.size(-1) == self.in_dim * self.r
|
200 |
+
T_decoder = inputs.size(1)
|
201 |
+
|
202 |
+
# go frames
|
203 |
+
initial_input = Variable(
|
204 |
+
encoder_outputs.data.new(B, self.in_dim * self.r).zero_())
|
205 |
+
|
206 |
+
# Init decoder states
|
207 |
+
attention_rnn_hidden = Variable(
|
208 |
+
encoder_outputs.data.new(B, 256).zero_())
|
209 |
+
decoder_rnn_hiddens = [Variable(
|
210 |
+
encoder_outputs.data.new(B, 256).zero_())
|
211 |
+
for _ in range(len(self.decoder_rnns))]
|
212 |
+
current_attention = Variable(
|
213 |
+
encoder_outputs.data.new(B, 256).zero_())
|
214 |
+
|
215 |
+
# Time first (T_decoder, B, in_dim)
|
216 |
+
if inputs is not None:
|
217 |
+
inputs = inputs.transpose(0, 1)
|
218 |
+
|
219 |
+
outputs = []
|
220 |
+
alignments = []
|
221 |
+
|
222 |
+
t = 0
|
223 |
+
current_input = initial_input
|
224 |
+
while True:
|
225 |
+
if t > 0:
|
226 |
+
current_input = outputs[-1] if greedy else inputs[t - 1]
|
227 |
+
# Prenet
|
228 |
+
current_input = self.prenet(current_input)
|
229 |
+
|
230 |
+
# Attention RNN
|
231 |
+
attention_rnn_hidden, current_attention, alignment = self.attention_rnn(
|
232 |
+
current_input, current_attention, attention_rnn_hidden,
|
233 |
+
encoder_outputs, processed_memory=processed_memory, mask=mask)
|
234 |
+
|
235 |
+
# Concat RNN output and attention context vector
|
236 |
+
decoder_input = self.project_to_decoder_in(
|
237 |
+
torch.cat((attention_rnn_hidden, current_attention), -1))
|
238 |
+
|
239 |
+
# Pass through the decoder RNNs
|
240 |
+
for idx in range(len(self.decoder_rnns)):
|
241 |
+
decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
|
242 |
+
decoder_input, decoder_rnn_hiddens[idx])
|
243 |
+
# Residual connectinon
|
244 |
+
decoder_input = decoder_rnn_hiddens[idx] + decoder_input
|
245 |
+
|
246 |
+
output = decoder_input
|
247 |
+
output = self.proj_to_mel(output)
|
248 |
+
|
249 |
+
outputs += [output]
|
250 |
+
alignments += [alignment]
|
251 |
+
|
252 |
+
t += 1
|
253 |
+
|
254 |
+
if greedy:
|
255 |
+
if t > 1 and is_end_of_frames(output):
|
256 |
+
break
|
257 |
+
elif t > self.max_decoder_steps:
|
258 |
+
print("Warning! doesn't seems to be converged")
|
259 |
+
break
|
260 |
+
else:
|
261 |
+
if t >= T_decoder:
|
262 |
+
break
|
263 |
+
|
264 |
+
assert greedy or len(outputs) == T_decoder
|
265 |
+
|
266 |
+
# Back to batch first
|
267 |
+
alignments = torch.stack(alignments).transpose(0, 1)
|
268 |
+
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
269 |
+
|
270 |
+
return outputs, alignments
|
271 |
+
|
272 |
+
|
273 |
+
def is_end_of_frames(output, eps=0.2):
|
274 |
+
return (output.data <= eps).all()
|
275 |
+
|
276 |
+
|
277 |
+
class Tacotron(nn.Module):
|
278 |
+
def __init__(self, n_vocab, embedding_dim=256, mel_dim=80, linear_dim=1025,
|
279 |
+
r=5, padding_idx=None, use_memory_mask=False):
|
280 |
+
super(Tacotron, self).__init__()
|
281 |
+
self.mel_dim = mel_dim
|
282 |
+
self.linear_dim = linear_dim
|
283 |
+
self.use_memory_mask = use_memory_mask
|
284 |
+
self.embedding = nn.Embedding(n_vocab, embedding_dim,
|
285 |
+
padding_idx=padding_idx)
|
286 |
+
# Trying smaller std
|
287 |
+
self.embedding.weight.data.normal_(0, 0.3)
|
288 |
+
self.encoder = Encoder(embedding_dim)
|
289 |
+
self.decoder = Decoder(mel_dim, r)
|
290 |
+
|
291 |
+
self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
|
292 |
+
self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
|
293 |
+
|
294 |
+
def forward(self, inputs, targets=None, input_lengths=None):
|
295 |
+
B = inputs.size(0)
|
296 |
+
|
297 |
+
inputs = self.embedding(inputs)
|
298 |
+
# (B, T', in_dim)
|
299 |
+
encoder_outputs = self.encoder(inputs, input_lengths)
|
300 |
+
|
301 |
+
if self.use_memory_mask:
|
302 |
+
memory_lengths = input_lengths
|
303 |
+
else:
|
304 |
+
memory_lengths = None
|
305 |
+
# (B, T', mel_dim*r)
|
306 |
+
mel_outputs, alignments = self.decoder(
|
307 |
+
encoder_outputs, targets, memory_lengths=memory_lengths)
|
308 |
+
|
309 |
+
# Post net processing below
|
310 |
+
|
311 |
+
# Reshape
|
312 |
+
# (B, T, mel_dim)
|
313 |
+
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
|
314 |
+
|
315 |
+
linear_outputs = self.postnet(mel_outputs)
|
316 |
+
linear_outputs = self.last_linear(linear_outputs)
|
317 |
+
|
318 |
+
return mel_outputs, linear_outputs, alignments
|
tests/test_attention.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.autograd import Variable
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
from tacotron_pytorch.attention import BahdanauAttention, AttentionWrapper
|
6 |
+
from tacotron_pytorch.attention import get_mask_from_lengths
|
7 |
+
|
8 |
+
|
9 |
+
def test_attention_wrapper():
|
10 |
+
B = 2
|
11 |
+
|
12 |
+
encoder_outputs = Variable(torch.rand(B, 100, 256))
|
13 |
+
memory_lengths = [100, 50]
|
14 |
+
|
15 |
+
mask = get_mask_from_lengths(encoder_outputs, memory_lengths)
|
16 |
+
print("Mask size:", mask.size())
|
17 |
+
|
18 |
+
memory_layer = nn.Linear(256, 256)
|
19 |
+
query = Variable(torch.rand(B, 128))
|
20 |
+
|
21 |
+
attention_mechanism = BahdanauAttention(256)
|
22 |
+
|
23 |
+
# Attention context + input
|
24 |
+
rnn = nn.GRUCell(256 + 128, 256)
|
25 |
+
|
26 |
+
attention_rnn = AttentionWrapper(rnn, attention_mechanism)
|
27 |
+
initial_attention = Variable(torch.zeros(B, 256))
|
28 |
+
cell_state = Variable(torch.zeros(B, 256))
|
29 |
+
|
30 |
+
processed_memory = memory_layer(encoder_outputs)
|
31 |
+
|
32 |
+
cell_output, attention, alignment = attention_rnn(
|
33 |
+
query, initial_attention, cell_state, encoder_outputs,
|
34 |
+
processed_memory=processed_memory,
|
35 |
+
mask=None, memory_lengths=memory_lengths)
|
36 |
+
|
37 |
+
print("Cell output size:", cell_output.size())
|
38 |
+
print("Attention output size:", attention.size())
|
39 |
+
print("Alignment size:", alignment.size())
|
40 |
+
|
41 |
+
assert (alignment.sum(-1) == 1).data.all()
|
42 |
+
|
43 |
+
|
44 |
+
test_attention_wrapper()
|
tests/test_tacotron.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
import sys
|
3 |
+
from os.path import dirname, join
|
4 |
+
tacotron_lib_dir = join(dirname(__file__), "..", "lib", "tacotron")
|
5 |
+
sys.path.append(tacotron_lib_dir)
|
6 |
+
from text import text_to_sequence, symbols
|
7 |
+
import torch
|
8 |
+
from torch.autograd import Variable
|
9 |
+
from tacotron_pytorch import Tacotron
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
|
13 |
+
def _pad(seq, max_len):
|
14 |
+
return np.pad(seq, (0, max_len - len(seq)),
|
15 |
+
mode='constant', constant_values=0)
|
16 |
+
|
17 |
+
|
18 |
+
def test_taco():
|
19 |
+
B, T_out, D_out = 2, 400, 80
|
20 |
+
r = 5
|
21 |
+
T_encoder = T_out // r
|
22 |
+
|
23 |
+
texts = ["Thank you very much.", "Hello"]
|
24 |
+
seqs = [np.array(text_to_sequence(
|
25 |
+
t, ["english_cleaners"]), dtype=np.int) for t in texts]
|
26 |
+
input_lengths = np.array([len(s) for s in seqs])
|
27 |
+
max_len = np.max(input_lengths)
|
28 |
+
seqs = np.array([_pad(s, max_len) for s in seqs])
|
29 |
+
|
30 |
+
x = torch.LongTensor(seqs)
|
31 |
+
y = torch.rand(B, T_out, D_out)
|
32 |
+
x = Variable(x)
|
33 |
+
y = Variable(y)
|
34 |
+
|
35 |
+
model = Tacotron(n_vocab=len(symbols), r=r)
|
36 |
+
|
37 |
+
print("Encoder input shape: ", x.size())
|
38 |
+
print("Decoder input shape: ", y.size())
|
39 |
+
a, b, c = model(x, y, input_lengths=input_lengths)
|
40 |
+
print("Mel shape:", a.size())
|
41 |
+
print("Linear shape:", b.size())
|
42 |
+
print("Attention shape:", c.size())
|
43 |
+
|
44 |
+
assert c.size() == (B, T_encoder, max_len)
|
45 |
+
|
46 |
+
# Test greddy decoding
|
47 |
+
a, b, c = model(x, input_lengths=input_lengths)
|