NTT123 commited on
Commit
df1ad02
·
1 Parent(s): 73eaac3

a slow but working model

Browse files
.gitattributes CHANGED
@@ -25,3 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ pretrained_model_ljs_500k.ckpt filter=lfs diff=lfs merge=lfs -text
29
+ wavegru_vocoder_tpu_gta_preemphasis_pruning_v7_0040000.ckpt filter=lfs diff=lfs merge=lfs -text
alphabet.txt ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _
2
+
3
+ !
4
+ "
5
+ '
6
+ (
7
+ )
8
+ ,
9
+ -
10
+ .
11
+ :
12
+ ;
13
+ ?
14
+ [
15
+ ]
16
+ a
17
+ b
18
+ c
19
+ d
20
+ e
21
+ f
22
+ g
23
+ h
24
+ i
25
+ j
26
+ k
27
+ l
28
+ m
29
+ n
30
+ o
31
+ p
32
+ q
33
+ r
34
+ s
35
+ t
36
+ u
37
+ v
38
+ w
39
+ x
40
+ y
41
+ z
app.py CHANGED
@@ -1,7 +1,33 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
+ from inference import load_tacotron_model, load_wavegru_net, text_to_mel, mel_to_wav
 
4
 
5
+ alphabet, tacotron_net, tacotron_config = load_tacotron_model(
6
+ "./alphabet.txt", "./tacotron.toml", "./pretrained_model_ljs_500k.ckpt"
7
+ )
8
+
9
+
10
+ wavegru_config, wavegru_net = load_wavegru_net(
11
+ "./wavegru.yaml", "./wavegru_vocoder_tpu_gta_preemphasis_pruning_v7_0040000.ckpt"
12
+ )
13
+
14
+
15
+ def speak(text):
16
+ mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
17
+ y = mel_to_wav(wavegru_net, mel, wavegru_config)
18
+ return 24_000, y
19
+
20
+
21
+ title = "WaveGRU-TTS"
22
+ description = "WaveGRU text-to-speech demo."
23
+
24
+ gr.Interface(
25
+ fn=speak,
26
+ inputs="text",
27
+ outputs="audio",
28
+ title=title,
29
+ description=description,
30
+ theme="default",
31
+ allow_screenshot=False,
32
+ allow_flagging="never",
33
+ ).launch(debug=False)
inference.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import librosa
4
+ import numpy as np
5
+ import pax
6
+
7
+ from text import english_cleaners
8
+ from utils import (
9
+ create_tacotron_model,
10
+ load_tacotron_ckpt,
11
+ load_tacotron_config,
12
+ load_wavegru_ckpt,
13
+ load_wavegru_config,
14
+ )
15
+ from wavegru import WaveGRU
16
+
17
+
18
+ def load_tacotron_model(alphabet_file, config_file, model_file):
19
+ """load tacotron model to memory"""
20
+ with open(alphabet_file, "r", encoding="utf-8") as f:
21
+ alphabet = f.read().split("\n")
22
+
23
+ config = load_tacotron_config(config_file)
24
+ net = create_tacotron_model(config)
25
+ _, net, _ = load_tacotron_ckpt(net, None, model_file)
26
+ net = net.eval()
27
+ net = jax.device_put(net)
28
+ return alphabet, net, config
29
+
30
+
31
+ tacotron_inference_fn = pax.pure(lambda net, text: net.inference(text, max_len=10000))
32
+
33
+
34
+ def text_to_mel(net, text, alphabet, config):
35
+ """convert text to mel spectrogram"""
36
+ text = english_cleaners(text)
37
+ text = text + config["PAD"] * (100 - (len(text) % 100))
38
+ tokens = [alphabet.index(c) for c in text]
39
+ tokens = jnp.array(tokens, dtype=jnp.int32)
40
+ mel = tacotron_inference_fn(net, tokens[None])
41
+ return mel
42
+
43
+
44
+ def load_wavegru_net(config_file, model_file):
45
+ """load wavegru to memory"""
46
+ config = load_wavegru_config(config_file)
47
+ net = WaveGRU(
48
+ mel_dim=config["mel_dim"],
49
+ embed_dim=config["embed_dim"],
50
+ rnn_dim=config["rnn_dim"],
51
+ upsample_factors=config["upsample_factors"],
52
+ )
53
+ _, net, _ = load_wavegru_ckpt(net, None, model_file)
54
+ net = net.eval()
55
+ net = jax.device_put(net)
56
+ return config, net
57
+
58
+
59
+ wavegru_inference = pax.pure(lambda net, mel: net.inference(mel, no_gru=False))
60
+
61
+
62
+ def mel_to_wav(net, mel, config):
63
+ """convert mel to wav"""
64
+ if len(mel.shape) == 2:
65
+ mel = mel[None]
66
+ pad = config["num_pad_frames"] // 2 + 4
67
+ mel = np.pad(
68
+ mel,
69
+ [(0, 0), (pad, pad), (0, 0)],
70
+ constant_values=np.log(config["mel_min"]),
71
+ )
72
+ x = wavegru_inference(net, mel)
73
+ x = jax.device_get(x)
74
+
75
+ wav = librosa.mu_expand(x - 127, mu=255)
76
+ wav = librosa.effects.deemphasis(wav, coef=0.86)
77
+ wav = wav * 2.0
78
+ wav = wav / max(1.0, np.max(np.abs(wav)))
79
+ wav = wav * 2**15
80
+ wav = np.clip(wav, a_min=-(2**15), a_max=(2**15) - 1)
81
+ wav = wav.astype(np.int16)
82
+ return wav
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ libsndfile1-dev
pooch.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ def os_cache(x):
2
+ return x
3
+
4
+
5
+ def create(*args, **kwargs):
6
+ class T:
7
+ def load_registry(self, *args, **kwargs):
8
+ return None
9
+
10
+ return T()
pretrained_model_ljs_500k.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4eabdcac35cd016469d17678f9549bd25d1c9bf66c9089ea9f0632619ba91194
3
+ size 53221435
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ jax==0.3.1
2
+ jaxlib==0.3.0
3
+ numpy==1.22.3
4
+ librosa==0.9.1
5
+ pax3==0.5.6
6
+ gradio
7
+ jinja2
8
+ toml==0.10.2
9
+ unidecode==1.3.4
10
+ pyyaml==6.0
tacotron.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tacotron + stepwise monotonic attention
3
+ """
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import pax
8
+
9
+
10
+ def conv_block(in_ft, out_ft, kernel_size, activation_fn, use_dropout):
11
+ """
12
+ Conv >> LayerNorm >> activation >> Dropout
13
+ """
14
+ f = pax.Sequential(
15
+ pax.Conv1D(in_ft, out_ft, kernel_size, with_bias=False),
16
+ pax.LayerNorm(out_ft, -1, True, True),
17
+ )
18
+ if activation_fn is not None:
19
+ f >>= activation_fn
20
+ if use_dropout:
21
+ f >>= pax.Dropout(0.5)
22
+ return f
23
+
24
+
25
+ class HighwayBlock(pax.Module):
26
+ """
27
+ Highway block
28
+ """
29
+
30
+ def __init__(self, dim: int) -> None:
31
+ super().__init__()
32
+ self.dim = dim
33
+ self.fc = pax.Linear(dim, 2 * dim)
34
+
35
+ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
36
+ t, h = jnp.split(self.fc(x), 2, axis=-1)
37
+ t = jax.nn.sigmoid(t - 1.0) # bias toward keeping x
38
+ h = jax.nn.relu(h)
39
+ x = x * (1.0 - t) + h * t
40
+ return x
41
+
42
+
43
+ class BiGRU(pax.Module):
44
+ """
45
+ Bidirectional GRU
46
+ """
47
+
48
+ def __init__(self, dim):
49
+ super().__init__()
50
+
51
+ self.rnn_fwd = pax.GRU(dim, dim)
52
+ self.rnn_bwd = pax.GRU(dim, dim)
53
+
54
+ def __call__(self, x, reset_masks):
55
+ N = x.shape[0]
56
+ x_fwd = x
57
+ x_bwd = jnp.flip(x, axis=1)
58
+ x_fwd_states = self.rnn_fwd.initial_state(N)
59
+ x_bwd_states = self.rnn_bwd.initial_state(N)
60
+ x_fwd_states, x_fwd = pax.scan(
61
+ self.rnn_fwd, x_fwd_states, x_fwd, time_major=False
62
+ )
63
+
64
+ reset_masks = jnp.flip(reset_masks, axis=1)
65
+ x_bwd_states0 = x_bwd_states
66
+
67
+ def rnn_reset_core(prev, inputs):
68
+ x, reset_mask = inputs
69
+
70
+ def reset_state(x0, xt):
71
+ return jnp.where(reset_mask, x0, xt)
72
+
73
+ state, _ = self.rnn_bwd(prev, x)
74
+ state = jax.tree_map(reset_state, x_bwd_states0, state)
75
+ return state, state.hidden
76
+
77
+ x_bwd_states, x_bwd = pax.scan(
78
+ rnn_reset_core, x_bwd_states, (x_bwd, reset_masks), time_major=False
79
+ )
80
+ x_bwd = jnp.flip(x_bwd, axis=1)
81
+ x = jnp.concatenate((x_fwd, x_bwd), axis=-1)
82
+ return x
83
+
84
+
85
+ class CBHG(pax.Module):
86
+ """
87
+ Conv Bank >> Highway net >> GRU
88
+ """
89
+
90
+ def __init__(self, dim):
91
+ super().__init__()
92
+ self.convs = [conv_block(dim, dim, i, jax.nn.relu, False) for i in range(1, 17)]
93
+ self.conv_projection_1 = conv_block(16 * dim, dim, 3, jax.nn.relu, False)
94
+ self.conv_projection_2 = conv_block(dim, dim, 3, None, False)
95
+
96
+ self.highway = pax.Sequential(
97
+ HighwayBlock(dim), HighwayBlock(dim), HighwayBlock(dim), HighwayBlock(dim)
98
+ )
99
+ self.rnn = BiGRU(dim)
100
+
101
+ def __call__(self, x, x_mask):
102
+ conv_input = x * x_mask
103
+ fts = [f(conv_input) for f in self.convs]
104
+ residual = jnp.concatenate(fts, axis=-1)
105
+ residual = pax.max_pool(residual, 2, 1, "SAME", -1)
106
+ residual = self.conv_projection_1(residual * x_mask)
107
+ residual = self.conv_projection_2(residual * x_mask)
108
+ x = x + residual
109
+ x = self.highway(x)
110
+ x = self.rnn(x * x_mask, reset_masks=1 - x_mask)
111
+ return x * x_mask
112
+
113
+
114
+ class PreNet(pax.Module):
115
+ """
116
+ Linear >> relu >> dropout >> Linear >> relu >> dropout
117
+ """
118
+
119
+ def __init__(self, input_dim, hidden_dim, output_dim, always_dropout=True):
120
+ super().__init__()
121
+ self.fc1 = pax.Linear(input_dim, hidden_dim)
122
+ self.fc2 = pax.Linear(hidden_dim, output_dim)
123
+ self.rng_seq = pax.RngSeq()
124
+ self.always_dropout = always_dropout
125
+
126
+ def __call__(self, x, k1=None, k2=None):
127
+ x = self.fc1(x)
128
+ x = jax.nn.relu(x)
129
+ if self.always_dropout or self.training:
130
+ if k1 is None:
131
+ k1 = self.rng_seq.next_rng_key()
132
+ x = pax.dropout(k1, 0.5, x)
133
+ x = self.fc2(x)
134
+ x = jax.nn.relu(x)
135
+ if self.always_dropout or self.training:
136
+ if k2 is None:
137
+ k2 = self.rng_seq.next_rng_key()
138
+ x = pax.dropout(k2, 0.5, x)
139
+ return x
140
+
141
+
142
+ class Tacotron(pax.Module):
143
+ """
144
+ Tacotron TTS model.
145
+
146
+ It uses stepwise monotonic attention for robust attention.
147
+ """
148
+
149
+ def __init__(
150
+ self,
151
+ mel_dim: int,
152
+ attn_bias,
153
+ rr,
154
+ max_rr,
155
+ mel_min,
156
+ sigmoid_noise,
157
+ pad_token,
158
+ prenet_dim,
159
+ attn_hidden_dim,
160
+ attn_rnn_dim,
161
+ rnn_dim,
162
+ postnet_dim,
163
+ text_dim,
164
+ ):
165
+ """
166
+ New Tacotron model
167
+
168
+ Args:
169
+ mel_dim (int): dimension of log mel-spectrogram features.
170
+ attn_bias (float): control how "slow" the attention will
171
+ move forward at initialization.
172
+ rr (int): the reduction factor.
173
+ Number of predicted frame at each time step. Default is 2.
174
+ max_rr (int): max value of rr.
175
+ mel_min (float): the minimum value of mel features.
176
+ The <go> frame is filled by `log(mel_min)` values.
177
+ sigmoid_noise (float): the variance of gaussian noise added
178
+ to attention scores in training.
179
+ pad_token (int): the pad value at the end of text sequences.
180
+ prenet_dim (int): dimension of prenet output.
181
+ attn_hidden_dim (int): dimension of attention hidden vectors.
182
+ attn_rnn_dim (int): number of cells in the attention RNN.
183
+ rnn_dim (int): number of cells in the decoder RNNs.
184
+ postnet_dim (int): number of features in the postnet convolutions.
185
+ text_dim (int): dimension of text embedding vectors.
186
+ """
187
+ super().__init__()
188
+ self.text_dim = text_dim
189
+ assert rr <= max_rr
190
+ self.rr = rr
191
+ self.max_rr = max_rr
192
+ self.mel_dim = mel_dim
193
+ self.mel_min = mel_min
194
+ self.sigmoid_noise = sigmoid_noise
195
+ self.pad_token = pad_token
196
+ self.prenet_dim = prenet_dim
197
+
198
+ # encoder submodules
199
+ self.encoder_embed = pax.Embed(256, text_dim)
200
+ self.encoder_pre_net = PreNet(text_dim, 256, prenet_dim, always_dropout=True)
201
+ self.encoder_cbhg = CBHG(prenet_dim)
202
+
203
+ # random key generator
204
+ self.rng_seq = pax.RngSeq()
205
+
206
+ # pre-net
207
+ self.decoder_pre_net = PreNet(mel_dim, 256, prenet_dim, always_dropout=True)
208
+
209
+ # decoder submodules
210
+ self.attn_rnn = pax.LSTM(prenet_dim + prenet_dim * 2, attn_rnn_dim)
211
+ self.text_key_fc = pax.Linear(prenet_dim * 2, attn_hidden_dim, with_bias=True)
212
+ self.attn_query_fc = pax.Linear(attn_rnn_dim, attn_hidden_dim, with_bias=False)
213
+
214
+ self.attn_V = pax.Linear(attn_hidden_dim, 1, with_bias=False)
215
+ self.attn_V_weight_norm = jnp.array(1.0 / jnp.sqrt(attn_hidden_dim))
216
+ self.attn_V_bias = jnp.array(attn_bias)
217
+ self.attn_log = jnp.zeros((1,))
218
+ self.decoder_input = pax.Linear(attn_rnn_dim + 2 * prenet_dim, rnn_dim)
219
+ self.decoder_rnn1 = pax.LSTM(rnn_dim, rnn_dim)
220
+ self.decoder_rnn2 = pax.LSTM(rnn_dim, rnn_dim)
221
+ # mel + end-of-sequence token
222
+ self.output_fc = pax.Linear(rnn_dim, (mel_dim + 1) * max_rr, with_bias=True)
223
+
224
+ # post-net
225
+ self.post_net = pax.Sequential(
226
+ conv_block(mel_dim, postnet_dim, 5, jax.nn.tanh, True),
227
+ conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
228
+ conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
229
+ conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
230
+ conv_block(postnet_dim, mel_dim, 5, None, True),
231
+ )
232
+
233
+ parameters = pax.parameters_method("attn_V_weight_norm", "attn_V_bias")
234
+
235
+ def encode_text(self, text: jnp.ndarray) -> jnp.ndarray:
236
+ """
237
+ Encode text to a sequence of real vectors
238
+ """
239
+ N, L = text.shape
240
+ text_mask = (text != self.pad_token)[..., None]
241
+ x = self.encoder_embed(text)
242
+ x = self.encoder_pre_net(x)
243
+ x = self.encoder_cbhg(x, text_mask)
244
+ return x
245
+
246
+ def go_frame(self, batch_size: int) -> jnp.ndarray:
247
+ """
248
+ return the go frame
249
+ """
250
+ return jnp.ones((batch_size, self.mel_dim)) * jnp.log(self.mel_min)
251
+
252
+ def decoder_initial_state(self, N: int, L: int):
253
+ """
254
+ setup decoder initial state
255
+ """
256
+ attn_context = jnp.zeros((N, self.prenet_dim * 2))
257
+ attn_pr = jax.nn.one_hot(
258
+ jnp.zeros((N,), dtype=jnp.int32), num_classes=L, axis=-1
259
+ )
260
+
261
+ attn_state = (self.attn_rnn.initial_state(N), attn_context, attn_pr)
262
+ decoder_rnn_states = (
263
+ self.decoder_rnn1.initial_state(N),
264
+ self.decoder_rnn2.initial_state(N),
265
+ )
266
+ return attn_state, decoder_rnn_states
267
+
268
+ def monotonic_attention(self, prev_state, inputs, envs):
269
+ """
270
+ Stepwise monotonic attention
271
+ """
272
+ attn_rnn_state, attn_context, prev_attn_pr = prev_state
273
+ x, attn_rng_key = inputs
274
+ text, text_key = envs
275
+ attn_rnn_input = jnp.concatenate((x, attn_context), axis=-1)
276
+ attn_rnn_state, attn_rnn_output = self.attn_rnn(attn_rnn_state, attn_rnn_input)
277
+ attn_query_input = attn_rnn_output
278
+ attn_query = self.attn_query_fc(attn_query_input)
279
+ attn_hidden = jnp.tanh(attn_query[:, None, :] + text_key)
280
+ score = self.attn_V(attn_hidden)
281
+ score = jnp.squeeze(score, axis=-1)
282
+ weight_norm = jnp.linalg.norm(self.attn_V.weight)
283
+ score = score * (self.attn_V_weight_norm / weight_norm)
284
+ score = score + self.attn_V_bias
285
+ noise = jax.random.normal(attn_rng_key, score.shape) * self.sigmoid_noise
286
+ pr_stay = jax.nn.sigmoid(score + noise)
287
+ pr_move = 1.0 - pr_stay
288
+ pr_new_location = pr_move * prev_attn_pr
289
+ pr_new_location = jnp.pad(
290
+ pr_new_location[:, :-1], ((0, 0), (1, 0)), constant_values=0
291
+ )
292
+ attn_pr = pr_stay * prev_attn_pr + pr_new_location
293
+ attn_context = jnp.einsum("NL,NLD->ND", attn_pr, text)
294
+ new_state = (attn_rnn_state, attn_context, attn_pr)
295
+ return new_state, attn_rnn_output
296
+
297
+ def zoneout_lstm(self, lstm_core, rng_key, zoneout_pr=0.1):
298
+ """
299
+ Return a zoneout lstm core.
300
+
301
+ It will zoneout the new hidden states and keep the new cell states unchanged.
302
+ """
303
+
304
+ def core(state, x):
305
+ new_state, _ = lstm_core(state, x)
306
+ h_old = state.hidden
307
+ h_new = new_state.hidden
308
+ mask = jax.random.bernoulli(rng_key, zoneout_pr, h_old.shape)
309
+ h_new = h_old * mask + h_new * (1.0 - mask)
310
+ return pax.LSTMState(h_new, new_state.cell), h_new
311
+
312
+ return core
313
+
314
+ def decoder_step(
315
+ self,
316
+ attn_state,
317
+ decoder_rnn_states,
318
+ rng_key,
319
+ mel,
320
+ text,
321
+ text_key,
322
+ call_pre_net=False,
323
+ ):
324
+ """
325
+ One decoder step
326
+ """
327
+ if call_pre_net:
328
+ k1, k2, zk1, zk2, rng_key, rng_key_next = jax.random.split(rng_key, 6)
329
+ mel = self.decoder_pre_net(mel, k1, k2)
330
+ else:
331
+ zk1, zk2, rng_key, rng_key_next = jax.random.split(rng_key, 4)
332
+ attn_inputs = (mel, rng_key)
333
+ attn_envs = (text, text_key)
334
+ attn_state, attn_rnn_output = self.monotonic_attention(
335
+ attn_state, attn_inputs, attn_envs
336
+ )
337
+ (_, attn_context, attn_pr) = attn_state
338
+ (decoder_rnn_state1, decoder_rnn_state2) = decoder_rnn_states
339
+ decoder_rnn1_input = jnp.concatenate((attn_rnn_output, attn_context), axis=-1)
340
+ decoder_rnn1_input = self.decoder_input(decoder_rnn1_input)
341
+ decoder_rnn1 = self.zoneout_lstm(self.decoder_rnn1, zk1)
342
+ decoder_rnn_state1, decoder_rnn_output1 = decoder_rnn1(
343
+ decoder_rnn_state1, decoder_rnn1_input
344
+ )
345
+ decoder_rnn2_input = decoder_rnn1_input + decoder_rnn_output1
346
+ decoder_rnn2 = self.zoneout_lstm(self.decoder_rnn2, zk2)
347
+ decoder_rnn_state2, decoder_rnn_output2 = decoder_rnn2(
348
+ decoder_rnn_state2, decoder_rnn2_input
349
+ )
350
+ x = decoder_rnn1_input + decoder_rnn_output1 + decoder_rnn_output2
351
+ decoder_rnn_states = (decoder_rnn_state1, decoder_rnn_state2)
352
+ return attn_state, decoder_rnn_states, rng_key_next, x, attn_pr[0]
353
+
354
+ @jax.jit
355
+ def inference_step(
356
+ self, attn_state, decoder_rnn_states, rng_key, mel, text, text_key
357
+ ):
358
+ """one inference step"""
359
+ attn_state, decoder_rnn_states, rng_key, x, _ = self.decoder_step(
360
+ attn_state,
361
+ decoder_rnn_states,
362
+ rng_key,
363
+ mel,
364
+ text,
365
+ text_key,
366
+ call_pre_net=True,
367
+ )
368
+ x = self.output_fc(x)
369
+ N, D2 = x.shape
370
+ x = jnp.reshape(x, (N, self.max_rr, D2 // self.max_rr))
371
+ x = x[:, : self.rr, :]
372
+ x = jnp.reshape(x, (N, self.rr, -1))
373
+ mel = x[..., :-1]
374
+ eos = x[..., -1]
375
+ return attn_state, decoder_rnn_states, rng_key, (mel, eos)
376
+
377
+ def inference(self, text, seed=42, max_len=1000):
378
+ """
379
+ text to mel
380
+ """
381
+ text = self.encode_text(text)
382
+ text_key = self.text_key_fc(text)
383
+ N, L, D = text.shape
384
+ mel = self.go_frame(N)
385
+
386
+ attn_state, decoder_rnn_states = self.decoder_initial_state(N, L)
387
+ rng_key = jax.random.PRNGKey(seed)
388
+ mels = []
389
+ count = 0
390
+ while True:
391
+ count = count + 1
392
+ attn_state, decoder_rnn_states, rng_key, (mel, eos) = self.inference_step(
393
+ attn_state, decoder_rnn_states, rng_key, mel, text, text_key
394
+ )
395
+ mels.append(mel)
396
+ if eos[0, -1].item() > 0 or count > max_len:
397
+ break
398
+
399
+ mel = mel[:, -1, :]
400
+
401
+ mels = jnp.concatenate(mels, axis=1)
402
+ mel = mel + self.post_net(mel)
403
+ return mels
404
+
405
+ def decode(self, mel, text):
406
+ """
407
+ Attention mechanism + Decoder
408
+ """
409
+ text_key = self.text_key_fc(text)
410
+
411
+ def scan_fn(prev_states, inputs):
412
+ attn_state, decoder_rnn_states = prev_states
413
+ x, rng_key = inputs
414
+ attn_state, decoder_rnn_states, _, output, attn_pr = self.decoder_step(
415
+ attn_state, decoder_rnn_states, rng_key, x, text, text_key
416
+ )
417
+ states = (attn_state, decoder_rnn_states)
418
+ return states, (output, attn_pr)
419
+
420
+ N, L, D = text.shape
421
+ decoder_states = self.decoder_initial_state(N, L)
422
+ rng_keys = self.rng_seq.next_rng_key(mel.shape[1])
423
+ rng_keys = jnp.stack(rng_keys, axis=1)
424
+ decoder_states, (x, attn_log) = pax.scan(
425
+ scan_fn,
426
+ decoder_states,
427
+ (mel, rng_keys),
428
+ time_major=False,
429
+ )
430
+ self.attn_log = attn_log
431
+ del decoder_states
432
+ x = self.output_fc(x)
433
+
434
+ N, T2, D2 = x.shape
435
+ x = jnp.reshape(x, (N, T2, self.max_rr, D2 // self.max_rr))
436
+ x = x[:, :, : self.rr, :]
437
+ x = jnp.reshape(x, (N, T2 * self.rr, -1))
438
+ mel = x[..., :-1]
439
+ eos = x[..., -1]
440
+ return mel, eos
441
+
442
+ def __call__(self, mel: jnp.ndarray, text: jnp.ndarray):
443
+ text = self.encode_text(text)
444
+ mel = self.decoder_pre_net(mel)
445
+ mel, eos = self.decode(mel, text)
446
+ return mel, mel + self.post_net(mel), eos
tacotron.toml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tacotron]
2
+
3
+ # training
4
+ BATCH_SIZE = 64
5
+ LR=1024e-6 # learning rate
6
+ MODEL_PREFIX = "mono_tts_cbhg_small"
7
+ LOG_DIR = "./logs"
8
+ CKPT_DIR = "./ckpts"
9
+ USE_MP = false # use mixed-precision training
10
+
11
+ # data
12
+ TF_DATA_DIR = "./tf_data" # tensorflow data directory
13
+ TF_GTA_DATA_DIR = "./tf_gta_data" # tf gta data directory
14
+ SAMPLE_RATE = 24000 # convert to this sample rate if needed
15
+ MEL_DIM = 80 # the dimension of melspectrogram features
16
+ MEL_MIN = 1e-5
17
+ PAD = "_" # padding character
18
+ PAD_TOKEN = 0
19
+ TEST_DATA_SIZE = 1024
20
+
21
+ # model
22
+ RR = 2 # reduction factor
23
+ MAX_RR=2
24
+ ATTN_BIAS = 0.0 # control how slow the attention moves forward
25
+ SIGMOID_NOISE = 2.0
26
+ PRENET_DIM = 128
27
+ TEXT_DIM = 256
28
+ RNN_DIM = 512
29
+ ATTN_RNN_DIM = 256
30
+ ATTN_HIDDEN_DIM = 128
31
+ POSTNET_DIM = 512
text.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ from https://github.com/keithito/tacotron """
2
+
3
+ """
4
+ Cleaners are transformations that run over the input text at both training and eval time.
5
+
6
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8
+ 1. "english_cleaners" for English text
9
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12
+ the symbols in symbols.py to match your data).
13
+ """
14
+
15
+ import re
16
+
17
+ from unidecode import unidecode
18
+
19
+ # Regular expression matching whitespace:
20
+ _whitespace_re = re.compile(r"\s+")
21
+
22
+ # List of (regular expression, replacement) pairs for abbreviations:
23
+ _abbreviations = [
24
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
25
+ for x in [
26
+ ("mrs", "misess"),
27
+ ("mr", "mister"),
28
+ ("dr", "doctor"),
29
+ ("st", "saint"),
30
+ ("co", "company"),
31
+ ("jr", "junior"),
32
+ ("maj", "major"),
33
+ ("gen", "general"),
34
+ ("drs", "doctors"),
35
+ ("rev", "reverend"),
36
+ ("lt", "lieutenant"),
37
+ ("hon", "honorable"),
38
+ ("sgt", "sergeant"),
39
+ ("capt", "captain"),
40
+ ("esq", "esquire"),
41
+ ("ltd", "limited"),
42
+ ("col", "colonel"),
43
+ ("ft", "fort"),
44
+ ]
45
+ ]
46
+
47
+
48
+ def expand_abbreviations(text):
49
+ for regex, replacement in _abbreviations:
50
+ text = re.sub(regex, replacement, text)
51
+ return text
52
+
53
+
54
+ def lowercase(text):
55
+ return text.lower()
56
+
57
+
58
+ def collapse_whitespace(text):
59
+ return re.sub(_whitespace_re, " ", text)
60
+
61
+
62
+ def convert_to_ascii(text):
63
+ return unidecode(text)
64
+
65
+
66
+ def basic_cleaners(text):
67
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
68
+ text = lowercase(text)
69
+ text = collapse_whitespace(text)
70
+ return text
71
+
72
+
73
+ def transliteration_cleaners(text):
74
+ """Pipeline for non-English text that transliterates to ASCII."""
75
+ text = convert_to_ascii(text)
76
+ text = lowercase(text)
77
+ text = collapse_whitespace(text)
78
+ return text
79
+
80
+
81
+ def english_cleaners(text):
82
+ """Pipeline for English text, including number and abbreviation expansion."""
83
+ text = convert_to_ascii(text)
84
+ text = lowercase(text)
85
+ text = expand_abbreviations(text)
86
+ text = collapse_whitespace(text)
87
+ return text
utils.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions
3
+ """
4
+ import pickle
5
+ from pathlib import Path
6
+
7
+ import pax
8
+ import toml
9
+ import yaml
10
+
11
+ from tacotron import Tacotron
12
+
13
+
14
+ def load_tacotron_config(config_file=Path("tacotron.toml")):
15
+ """
16
+ Load the project configurations
17
+ """
18
+ return toml.load(config_file)["tacotron"]
19
+
20
+
21
+ def load_tacotron_ckpt(net: pax.Module, optim: pax.Module, path):
22
+ """
23
+ load checkpoint from disk
24
+ """
25
+ with open(path, "rb") as f:
26
+ dic = pickle.load(f)
27
+ if net is not None:
28
+ net = net.load_state_dict(dic["model_state_dict"])
29
+ if optim is not None:
30
+ optim = optim.load_state_dict(dic["optim_state_dict"])
31
+ return dic["step"], net, optim
32
+
33
+
34
+ def create_tacotron_model(config):
35
+ """
36
+ return a random initialized Tacotron model
37
+ """
38
+ return Tacotron(
39
+ mel_dim=config["MEL_DIM"],
40
+ attn_bias=config["ATTN_BIAS"],
41
+ rr=config["RR"],
42
+ max_rr=config["MAX_RR"],
43
+ mel_min=config["MEL_MIN"],
44
+ sigmoid_noise=config["SIGMOID_NOISE"],
45
+ pad_token=config["PAD_TOKEN"],
46
+ prenet_dim=config["PRENET_DIM"],
47
+ attn_hidden_dim=config["ATTN_HIDDEN_DIM"],
48
+ attn_rnn_dim=config["ATTN_RNN_DIM"],
49
+ rnn_dim=config["RNN_DIM"],
50
+ postnet_dim=config["POSTNET_DIM"],
51
+ text_dim=config["TEXT_DIM"],
52
+ )
53
+
54
+
55
+ def load_wavegru_config(config_file):
56
+ """
57
+ Load project configurations
58
+ """
59
+ with open(config_file, "r", encoding="utf-8") as f:
60
+ return yaml.safe_load(f)
61
+
62
+
63
+ def load_wavegru_ckpt(net, optim, ckpt_file):
64
+ """
65
+ load training checkpoint from file
66
+ """
67
+ with open(ckpt_file, "rb") as f:
68
+ dic = pickle.load(f)
69
+
70
+ if net is not None:
71
+ net = net.load_state_dict(dic["net_state_dict"])
72
+ if optim is not None:
73
+ optim = optim.load_state_dict(dic["optim_state_dict"])
74
+ return dic["step"], net, optim
wavegru.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WaveGRU model: melspectrogram => mu-law encoded waveform
3
+ """
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import pax
8
+
9
+
10
+ class ReLU(pax.Module):
11
+ def __call__(self, x):
12
+ return jax.nn.relu(x)
13
+
14
+
15
+ def dilated_residual_conv_block(dim, kernel, stride, dilation):
16
+ """
17
+ Use dilated convs to enlarge the receptive field
18
+ """
19
+ return pax.Sequential(
20
+ pax.Conv1D(dim, dim, kernel, stride, dilation, "VALID", with_bias=False),
21
+ pax.LayerNorm(dim, -1, True, True),
22
+ ReLU(),
23
+ pax.Conv1D(dim, dim, 1, 1, 1, "VALID", with_bias=False),
24
+ pax.LayerNorm(dim, -1, True, True),
25
+ ReLU(),
26
+ )
27
+
28
+
29
+ def tile_1d(x, factor):
30
+ """
31
+ Tile tensor of shape N, L, D into N, L*factor, D
32
+ """
33
+ N, L, D = x.shape
34
+ x = x[:, :, None, :]
35
+ x = jnp.tile(x, (1, 1, factor, 1))
36
+ x = jnp.reshape(x, (N, L * factor, D))
37
+ return x
38
+
39
+
40
+ def up_block(dim, factor):
41
+ """
42
+ Tile >> Conv >> BatchNorm >> ReLU
43
+ """
44
+ return pax.Sequential(
45
+ lambda x: tile_1d(x, factor),
46
+ pax.Conv1D(dim, dim, 2 * factor, stride=1, padding="VALID", with_bias=False),
47
+ pax.LayerNorm(dim, -1, True, True),
48
+ ReLU(),
49
+ )
50
+
51
+
52
+ class Upsample(pax.Module):
53
+ """
54
+ Upsample melspectrogram to match raw audio sample rate.
55
+ """
56
+
57
+ def __init__(self, input_dim, upsample_factors):
58
+ super().__init__()
59
+ self.input_conv = pax.Sequential(
60
+ pax.Conv1D(input_dim, 512, 1, with_bias=False),
61
+ pax.LayerNorm(512, -1, True, True),
62
+ )
63
+ self.upsample_factors = upsample_factors
64
+ self.dilated_convs = [
65
+ dilated_residual_conv_block(512, 3, 1, 2**i) for i in range(5)
66
+ ]
67
+ self.up_factors = upsample_factors[:-1]
68
+ self.up_blocks = [up_block(512, x) for x in self.up_factors]
69
+ self.final_tile = upsample_factors[-1]
70
+
71
+ def __call__(self, x):
72
+ x = self.input_conv(x)
73
+ for residual in self.dilated_convs:
74
+ y = residual(x)
75
+ pad = (x.shape[1] - y.shape[1]) // 2
76
+ x = x[:, pad:-pad, :] + y
77
+
78
+ for f in self.up_blocks:
79
+ x = f(x)
80
+
81
+ x = tile_1d(x, self.final_tile)
82
+ return x
83
+
84
+
85
+ class Pruner(pax.Module):
86
+ """
87
+ Base class for pruners
88
+ """
89
+
90
+ def __init__(self, update_freq=500):
91
+ super().__init__()
92
+ self.update_freq = update_freq
93
+
94
+ def compute_sparsity(self, step):
95
+ """
96
+ Two-stages pruning
97
+ """
98
+ t = jnp.power(1 - (step * 1.0 - 1_000) / 300_000, 3)
99
+ z = 0.5 * jnp.clip(1.0 - t, a_min=0, a_max=1)
100
+ for i in range(4):
101
+ t = jnp.power(1 - (step * 1.0 - 1_000 - 400_000 - i * 200_000) / 100_000, 3)
102
+ z = z + 0.1 * jnp.clip(1 - t, a_min=0, a_max=1)
103
+ return z
104
+
105
+ def prune(self, step, weights):
106
+ """
107
+ Return a mask
108
+ """
109
+ z = self.compute_sparsity(step)
110
+ x = weights
111
+ H, W = x.shape
112
+ x = x.reshape(H // 4, 4, W // 4, 4)
113
+ x = jnp.abs(x)
114
+ x = jnp.sum(x, axis=(1, 3), keepdims=True)
115
+ q = jnp.quantile(jnp.reshape(x, (-1,)), z)
116
+ x = x >= q
117
+ x = jnp.tile(x, (1, 4, 1, 4))
118
+ x = jnp.reshape(x, (H, W))
119
+ return x
120
+
121
+
122
+ class GRUPruner(Pruner):
123
+ def __init__(self, gru, update_freq=500):
124
+ super().__init__(update_freq=update_freq)
125
+ self.xh_zr_fc_mask = jnp.ones_like(gru.xh_zr_fc.weight) == 1
126
+ self.xh_h_fc_mask = jnp.ones_like(gru.xh_h_fc.weight) == 1
127
+
128
+ def __call__(self, gru: pax.GRU):
129
+ """
130
+ Apply mask after an optimization step
131
+ """
132
+ zr_masked_weights = jnp.where(self.xh_zr_fc_mask, gru.xh_zr_fc.weight, 0)
133
+ gru = gru.replace_node(gru.xh_zr_fc.weight, zr_masked_weights)
134
+ h_masked_weights = jnp.where(self.xh_h_fc_mask, gru.xh_h_fc.weight, 0)
135
+ gru = gru.replace_node(gru.xh_h_fc.weight, h_masked_weights)
136
+ return gru
137
+
138
+ def update_mask(self, step, gru: pax.GRU):
139
+ """
140
+ Update internal masks
141
+ """
142
+ xh_z_weight, xh_r_weight = jnp.split(gru.xh_zr_fc.weight, 2, axis=1)
143
+ xh_z_weight = self.prune(step, xh_z_weight)
144
+ xh_r_weight = self.prune(step, xh_r_weight)
145
+ self.xh_zr_fc_mask *= jnp.concatenate((xh_z_weight, xh_r_weight), axis=1)
146
+ self.xh_h_fc_mask *= self.prune(step, gru.xh_h_fc.weight)
147
+
148
+
149
+ class LinearPruner(Pruner):
150
+ def __init__(self, linear, update_freq=500):
151
+ super().__init__(update_freq=update_freq)
152
+ self.mask = jnp.ones_like(linear.weight) == 1
153
+
154
+ def __call__(self, linear: pax.Linear):
155
+ """
156
+ Apply mask after an optimization step
157
+ """
158
+ return linear.replace(weight=jnp.where(self.mask, linear.weight, 0))
159
+
160
+ def update_mask(self, step, linear: pax.Linear):
161
+ """
162
+ Update internal masks
163
+ """
164
+ self.mask *= self.prune(step, linear.weight)
165
+
166
+
167
+ class WaveGRU(pax.Module):
168
+ """
169
+ WaveGRU vocoder model
170
+ """
171
+
172
+ def __init__(
173
+ self, mel_dim=80, embed_dim=32, rnn_dim=512, upsample_factors=(5, 4, 3, 5)
174
+ ):
175
+ super().__init__()
176
+ self.embed = pax.Embed(256, embed_dim)
177
+ self.upsample = Upsample(input_dim=mel_dim, upsample_factors=upsample_factors)
178
+ self.rnn = pax.GRU(embed_dim + rnn_dim, rnn_dim)
179
+ self.o1 = pax.Linear(rnn_dim, rnn_dim)
180
+ self.o2 = pax.Linear(rnn_dim, 256)
181
+ self.gru_pruner = GRUPruner(self.rnn)
182
+ self.o1_pruner = LinearPruner(self.o1)
183
+ self.o2_pruner = LinearPruner(self.o2)
184
+
185
+ def output(self, x):
186
+ x = self.o1(x)
187
+ x = jax.nn.relu(x)
188
+ x = self.o2(x)
189
+ return x
190
+
191
+ @jax.jit
192
+ def inference_step(self, rnn_state, mel, rng_key, x):
193
+ """one inference step"""
194
+ x = self.embed(x)
195
+ x = jnp.concatenate((x, mel), axis=-1)
196
+ rnn_state, x = self.rnn(rnn_state, x)
197
+ x = self.output(x)
198
+ rng_key, next_rng_key = jax.random.split(rng_key, 2)
199
+ x = jax.random.categorical(rng_key, x, axis=-1)
200
+ return rnn_state, next_rng_key, x
201
+
202
+ def inference(self, mel, no_gru=False, seed=42):
203
+ """
204
+ generate waveform form melspectrogram
205
+ """
206
+
207
+ y = self.upsample(mel)
208
+ if no_gru:
209
+ return y
210
+ x = jnp.array([127], dtype=jnp.int32)
211
+ rnn_state = self.rnn.initial_state(1)
212
+ output = []
213
+ rng_key = jax.random.PRNGKey(seed)
214
+ for i in range(y.shape[1]):
215
+ rnn_state, rng_key, x = self.inference_step(rnn_state, y[:, i], rng_key, x)
216
+ output.append(x)
217
+ x = jnp.concatenate(output, axis=0)
218
+ return x
219
+
220
+ def __call__(self, mel, x):
221
+ x = self.embed(x)
222
+ y = self.upsample(mel)
223
+ pad_left = (x.shape[1] - y.shape[1]) // 2
224
+ pad_right = x.shape[1] - y.shape[1] - pad_left
225
+ x = x[:, pad_left:-pad_right]
226
+ x = jnp.concatenate((x, y), axis=-1)
227
+ _, x = pax.scan(
228
+ self.rnn,
229
+ self.rnn.initial_state(x.shape[0]),
230
+ x,
231
+ time_major=False,
232
+ )
233
+ x = self.output(x)
234
+ return x
wavegru.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## dsp
2
+ sample_rate : 24000
3
+ window_length: 50.0 # ms
4
+ hop_length: 12.5 # ms
5
+ mel_min: 1.0e-5 ## need .0 to make it a float
6
+ mel_dim: 80
7
+ n_fft: 2048
8
+
9
+ ## wavegru
10
+ embed_dim: 32
11
+ rnn_dim: 512
12
+ frames_per_sequence: 67
13
+ num_pad_frames: 62
14
+ upsample_factors: [5, 4, 3, 5]
wavegru_vocoder_tpu_gta_preemphasis_pruning_v7_0040000.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c09ed822c5daac0afbd19e8ba4e0ded26dd5732e0efd13ce193c3f54c4e63f54
3
+ size 56479599