Spaces:
Runtime error
Runtime error
NTT123
commited on
Commit
·
df1ad02
1
Parent(s):
73eaac3
a slow but working model
Browse files- .gitattributes +2 -0
- alphabet.txt +41 -0
- app.py +30 -4
- inference.py +82 -0
- packages.txt +1 -0
- pooch.py +10 -0
- pretrained_model_ljs_500k.ckpt +3 -0
- requirements.txt +10 -0
- tacotron.py +446 -0
- tacotron.toml +31 -0
- text.py +87 -0
- utils.py +74 -0
- wavegru.py +234 -0
- wavegru.yaml +14 -0
- wavegru_vocoder_tpu_gta_preemphasis_pruning_v7_0040000.ckpt +3 -0
.gitattributes
CHANGED
@@ -25,3 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
pretrained_model_ljs_500k.ckpt filter=lfs diff=lfs merge=lfs -text
|
29 |
+
wavegru_vocoder_tpu_gta_preemphasis_pruning_v7_0040000.ckpt filter=lfs diff=lfs merge=lfs -text
|
alphabet.txt
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_
|
2 |
+
|
3 |
+
!
|
4 |
+
"
|
5 |
+
'
|
6 |
+
(
|
7 |
+
)
|
8 |
+
,
|
9 |
+
-
|
10 |
+
.
|
11 |
+
:
|
12 |
+
;
|
13 |
+
?
|
14 |
+
[
|
15 |
+
]
|
16 |
+
a
|
17 |
+
b
|
18 |
+
c
|
19 |
+
d
|
20 |
+
e
|
21 |
+
f
|
22 |
+
g
|
23 |
+
h
|
24 |
+
i
|
25 |
+
j
|
26 |
+
k
|
27 |
+
l
|
28 |
+
m
|
29 |
+
n
|
30 |
+
o
|
31 |
+
p
|
32 |
+
q
|
33 |
+
r
|
34 |
+
s
|
35 |
+
t
|
36 |
+
u
|
37 |
+
v
|
38 |
+
w
|
39 |
+
x
|
40 |
+
y
|
41 |
+
z
|
app.py
CHANGED
@@ -1,7 +1,33 @@
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
|
4 |
-
return "Hello " + name + "!!"
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
+
from inference import load_tacotron_model, load_wavegru_net, text_to_mel, mel_to_wav
|
|
|
4 |
|
5 |
+
alphabet, tacotron_net, tacotron_config = load_tacotron_model(
|
6 |
+
"./alphabet.txt", "./tacotron.toml", "./pretrained_model_ljs_500k.ckpt"
|
7 |
+
)
|
8 |
+
|
9 |
+
|
10 |
+
wavegru_config, wavegru_net = load_wavegru_net(
|
11 |
+
"./wavegru.yaml", "./wavegru_vocoder_tpu_gta_preemphasis_pruning_v7_0040000.ckpt"
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
def speak(text):
|
16 |
+
mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
|
17 |
+
y = mel_to_wav(wavegru_net, mel, wavegru_config)
|
18 |
+
return 24_000, y
|
19 |
+
|
20 |
+
|
21 |
+
title = "WaveGRU-TTS"
|
22 |
+
description = "WaveGRU text-to-speech demo."
|
23 |
+
|
24 |
+
gr.Interface(
|
25 |
+
fn=speak,
|
26 |
+
inputs="text",
|
27 |
+
outputs="audio",
|
28 |
+
title=title,
|
29 |
+
description=description,
|
30 |
+
theme="default",
|
31 |
+
allow_screenshot=False,
|
32 |
+
allow_flagging="never",
|
33 |
+
).launch(debug=False)
|
inference.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import jax
|
2 |
+
import jax.numpy as jnp
|
3 |
+
import librosa
|
4 |
+
import numpy as np
|
5 |
+
import pax
|
6 |
+
|
7 |
+
from text import english_cleaners
|
8 |
+
from utils import (
|
9 |
+
create_tacotron_model,
|
10 |
+
load_tacotron_ckpt,
|
11 |
+
load_tacotron_config,
|
12 |
+
load_wavegru_ckpt,
|
13 |
+
load_wavegru_config,
|
14 |
+
)
|
15 |
+
from wavegru import WaveGRU
|
16 |
+
|
17 |
+
|
18 |
+
def load_tacotron_model(alphabet_file, config_file, model_file):
|
19 |
+
"""load tacotron model to memory"""
|
20 |
+
with open(alphabet_file, "r", encoding="utf-8") as f:
|
21 |
+
alphabet = f.read().split("\n")
|
22 |
+
|
23 |
+
config = load_tacotron_config(config_file)
|
24 |
+
net = create_tacotron_model(config)
|
25 |
+
_, net, _ = load_tacotron_ckpt(net, None, model_file)
|
26 |
+
net = net.eval()
|
27 |
+
net = jax.device_put(net)
|
28 |
+
return alphabet, net, config
|
29 |
+
|
30 |
+
|
31 |
+
tacotron_inference_fn = pax.pure(lambda net, text: net.inference(text, max_len=10000))
|
32 |
+
|
33 |
+
|
34 |
+
def text_to_mel(net, text, alphabet, config):
|
35 |
+
"""convert text to mel spectrogram"""
|
36 |
+
text = english_cleaners(text)
|
37 |
+
text = text + config["PAD"] * (100 - (len(text) % 100))
|
38 |
+
tokens = [alphabet.index(c) for c in text]
|
39 |
+
tokens = jnp.array(tokens, dtype=jnp.int32)
|
40 |
+
mel = tacotron_inference_fn(net, tokens[None])
|
41 |
+
return mel
|
42 |
+
|
43 |
+
|
44 |
+
def load_wavegru_net(config_file, model_file):
|
45 |
+
"""load wavegru to memory"""
|
46 |
+
config = load_wavegru_config(config_file)
|
47 |
+
net = WaveGRU(
|
48 |
+
mel_dim=config["mel_dim"],
|
49 |
+
embed_dim=config["embed_dim"],
|
50 |
+
rnn_dim=config["rnn_dim"],
|
51 |
+
upsample_factors=config["upsample_factors"],
|
52 |
+
)
|
53 |
+
_, net, _ = load_wavegru_ckpt(net, None, model_file)
|
54 |
+
net = net.eval()
|
55 |
+
net = jax.device_put(net)
|
56 |
+
return config, net
|
57 |
+
|
58 |
+
|
59 |
+
wavegru_inference = pax.pure(lambda net, mel: net.inference(mel, no_gru=False))
|
60 |
+
|
61 |
+
|
62 |
+
def mel_to_wav(net, mel, config):
|
63 |
+
"""convert mel to wav"""
|
64 |
+
if len(mel.shape) == 2:
|
65 |
+
mel = mel[None]
|
66 |
+
pad = config["num_pad_frames"] // 2 + 4
|
67 |
+
mel = np.pad(
|
68 |
+
mel,
|
69 |
+
[(0, 0), (pad, pad), (0, 0)],
|
70 |
+
constant_values=np.log(config["mel_min"]),
|
71 |
+
)
|
72 |
+
x = wavegru_inference(net, mel)
|
73 |
+
x = jax.device_get(x)
|
74 |
+
|
75 |
+
wav = librosa.mu_expand(x - 127, mu=255)
|
76 |
+
wav = librosa.effects.deemphasis(wav, coef=0.86)
|
77 |
+
wav = wav * 2.0
|
78 |
+
wav = wav / max(1.0, np.max(np.abs(wav)))
|
79 |
+
wav = wav * 2**15
|
80 |
+
wav = np.clip(wav, a_min=-(2**15), a_max=(2**15) - 1)
|
81 |
+
wav = wav.astype(np.int16)
|
82 |
+
return wav
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
libsndfile1-dev
|
pooch.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def os_cache(x):
|
2 |
+
return x
|
3 |
+
|
4 |
+
|
5 |
+
def create(*args, **kwargs):
|
6 |
+
class T:
|
7 |
+
def load_registry(self, *args, **kwargs):
|
8 |
+
return None
|
9 |
+
|
10 |
+
return T()
|
pretrained_model_ljs_500k.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4eabdcac35cd016469d17678f9549bd25d1c9bf66c9089ea9f0632619ba91194
|
3 |
+
size 53221435
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
jax==0.3.1
|
2 |
+
jaxlib==0.3.0
|
3 |
+
numpy==1.22.3
|
4 |
+
librosa==0.9.1
|
5 |
+
pax3==0.5.6
|
6 |
+
gradio
|
7 |
+
jinja2
|
8 |
+
toml==0.10.2
|
9 |
+
unidecode==1.3.4
|
10 |
+
pyyaml==6.0
|
tacotron.py
ADDED
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Tacotron + stepwise monotonic attention
|
3 |
+
"""
|
4 |
+
|
5 |
+
import jax
|
6 |
+
import jax.numpy as jnp
|
7 |
+
import pax
|
8 |
+
|
9 |
+
|
10 |
+
def conv_block(in_ft, out_ft, kernel_size, activation_fn, use_dropout):
|
11 |
+
"""
|
12 |
+
Conv >> LayerNorm >> activation >> Dropout
|
13 |
+
"""
|
14 |
+
f = pax.Sequential(
|
15 |
+
pax.Conv1D(in_ft, out_ft, kernel_size, with_bias=False),
|
16 |
+
pax.LayerNorm(out_ft, -1, True, True),
|
17 |
+
)
|
18 |
+
if activation_fn is not None:
|
19 |
+
f >>= activation_fn
|
20 |
+
if use_dropout:
|
21 |
+
f >>= pax.Dropout(0.5)
|
22 |
+
return f
|
23 |
+
|
24 |
+
|
25 |
+
class HighwayBlock(pax.Module):
|
26 |
+
"""
|
27 |
+
Highway block
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self, dim: int) -> None:
|
31 |
+
super().__init__()
|
32 |
+
self.dim = dim
|
33 |
+
self.fc = pax.Linear(dim, 2 * dim)
|
34 |
+
|
35 |
+
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
|
36 |
+
t, h = jnp.split(self.fc(x), 2, axis=-1)
|
37 |
+
t = jax.nn.sigmoid(t - 1.0) # bias toward keeping x
|
38 |
+
h = jax.nn.relu(h)
|
39 |
+
x = x * (1.0 - t) + h * t
|
40 |
+
return x
|
41 |
+
|
42 |
+
|
43 |
+
class BiGRU(pax.Module):
|
44 |
+
"""
|
45 |
+
Bidirectional GRU
|
46 |
+
"""
|
47 |
+
|
48 |
+
def __init__(self, dim):
|
49 |
+
super().__init__()
|
50 |
+
|
51 |
+
self.rnn_fwd = pax.GRU(dim, dim)
|
52 |
+
self.rnn_bwd = pax.GRU(dim, dim)
|
53 |
+
|
54 |
+
def __call__(self, x, reset_masks):
|
55 |
+
N = x.shape[0]
|
56 |
+
x_fwd = x
|
57 |
+
x_bwd = jnp.flip(x, axis=1)
|
58 |
+
x_fwd_states = self.rnn_fwd.initial_state(N)
|
59 |
+
x_bwd_states = self.rnn_bwd.initial_state(N)
|
60 |
+
x_fwd_states, x_fwd = pax.scan(
|
61 |
+
self.rnn_fwd, x_fwd_states, x_fwd, time_major=False
|
62 |
+
)
|
63 |
+
|
64 |
+
reset_masks = jnp.flip(reset_masks, axis=1)
|
65 |
+
x_bwd_states0 = x_bwd_states
|
66 |
+
|
67 |
+
def rnn_reset_core(prev, inputs):
|
68 |
+
x, reset_mask = inputs
|
69 |
+
|
70 |
+
def reset_state(x0, xt):
|
71 |
+
return jnp.where(reset_mask, x0, xt)
|
72 |
+
|
73 |
+
state, _ = self.rnn_bwd(prev, x)
|
74 |
+
state = jax.tree_map(reset_state, x_bwd_states0, state)
|
75 |
+
return state, state.hidden
|
76 |
+
|
77 |
+
x_bwd_states, x_bwd = pax.scan(
|
78 |
+
rnn_reset_core, x_bwd_states, (x_bwd, reset_masks), time_major=False
|
79 |
+
)
|
80 |
+
x_bwd = jnp.flip(x_bwd, axis=1)
|
81 |
+
x = jnp.concatenate((x_fwd, x_bwd), axis=-1)
|
82 |
+
return x
|
83 |
+
|
84 |
+
|
85 |
+
class CBHG(pax.Module):
|
86 |
+
"""
|
87 |
+
Conv Bank >> Highway net >> GRU
|
88 |
+
"""
|
89 |
+
|
90 |
+
def __init__(self, dim):
|
91 |
+
super().__init__()
|
92 |
+
self.convs = [conv_block(dim, dim, i, jax.nn.relu, False) for i in range(1, 17)]
|
93 |
+
self.conv_projection_1 = conv_block(16 * dim, dim, 3, jax.nn.relu, False)
|
94 |
+
self.conv_projection_2 = conv_block(dim, dim, 3, None, False)
|
95 |
+
|
96 |
+
self.highway = pax.Sequential(
|
97 |
+
HighwayBlock(dim), HighwayBlock(dim), HighwayBlock(dim), HighwayBlock(dim)
|
98 |
+
)
|
99 |
+
self.rnn = BiGRU(dim)
|
100 |
+
|
101 |
+
def __call__(self, x, x_mask):
|
102 |
+
conv_input = x * x_mask
|
103 |
+
fts = [f(conv_input) for f in self.convs]
|
104 |
+
residual = jnp.concatenate(fts, axis=-1)
|
105 |
+
residual = pax.max_pool(residual, 2, 1, "SAME", -1)
|
106 |
+
residual = self.conv_projection_1(residual * x_mask)
|
107 |
+
residual = self.conv_projection_2(residual * x_mask)
|
108 |
+
x = x + residual
|
109 |
+
x = self.highway(x)
|
110 |
+
x = self.rnn(x * x_mask, reset_masks=1 - x_mask)
|
111 |
+
return x * x_mask
|
112 |
+
|
113 |
+
|
114 |
+
class PreNet(pax.Module):
|
115 |
+
"""
|
116 |
+
Linear >> relu >> dropout >> Linear >> relu >> dropout
|
117 |
+
"""
|
118 |
+
|
119 |
+
def __init__(self, input_dim, hidden_dim, output_dim, always_dropout=True):
|
120 |
+
super().__init__()
|
121 |
+
self.fc1 = pax.Linear(input_dim, hidden_dim)
|
122 |
+
self.fc2 = pax.Linear(hidden_dim, output_dim)
|
123 |
+
self.rng_seq = pax.RngSeq()
|
124 |
+
self.always_dropout = always_dropout
|
125 |
+
|
126 |
+
def __call__(self, x, k1=None, k2=None):
|
127 |
+
x = self.fc1(x)
|
128 |
+
x = jax.nn.relu(x)
|
129 |
+
if self.always_dropout or self.training:
|
130 |
+
if k1 is None:
|
131 |
+
k1 = self.rng_seq.next_rng_key()
|
132 |
+
x = pax.dropout(k1, 0.5, x)
|
133 |
+
x = self.fc2(x)
|
134 |
+
x = jax.nn.relu(x)
|
135 |
+
if self.always_dropout or self.training:
|
136 |
+
if k2 is None:
|
137 |
+
k2 = self.rng_seq.next_rng_key()
|
138 |
+
x = pax.dropout(k2, 0.5, x)
|
139 |
+
return x
|
140 |
+
|
141 |
+
|
142 |
+
class Tacotron(pax.Module):
|
143 |
+
"""
|
144 |
+
Tacotron TTS model.
|
145 |
+
|
146 |
+
It uses stepwise monotonic attention for robust attention.
|
147 |
+
"""
|
148 |
+
|
149 |
+
def __init__(
|
150 |
+
self,
|
151 |
+
mel_dim: int,
|
152 |
+
attn_bias,
|
153 |
+
rr,
|
154 |
+
max_rr,
|
155 |
+
mel_min,
|
156 |
+
sigmoid_noise,
|
157 |
+
pad_token,
|
158 |
+
prenet_dim,
|
159 |
+
attn_hidden_dim,
|
160 |
+
attn_rnn_dim,
|
161 |
+
rnn_dim,
|
162 |
+
postnet_dim,
|
163 |
+
text_dim,
|
164 |
+
):
|
165 |
+
"""
|
166 |
+
New Tacotron model
|
167 |
+
|
168 |
+
Args:
|
169 |
+
mel_dim (int): dimension of log mel-spectrogram features.
|
170 |
+
attn_bias (float): control how "slow" the attention will
|
171 |
+
move forward at initialization.
|
172 |
+
rr (int): the reduction factor.
|
173 |
+
Number of predicted frame at each time step. Default is 2.
|
174 |
+
max_rr (int): max value of rr.
|
175 |
+
mel_min (float): the minimum value of mel features.
|
176 |
+
The <go> frame is filled by `log(mel_min)` values.
|
177 |
+
sigmoid_noise (float): the variance of gaussian noise added
|
178 |
+
to attention scores in training.
|
179 |
+
pad_token (int): the pad value at the end of text sequences.
|
180 |
+
prenet_dim (int): dimension of prenet output.
|
181 |
+
attn_hidden_dim (int): dimension of attention hidden vectors.
|
182 |
+
attn_rnn_dim (int): number of cells in the attention RNN.
|
183 |
+
rnn_dim (int): number of cells in the decoder RNNs.
|
184 |
+
postnet_dim (int): number of features in the postnet convolutions.
|
185 |
+
text_dim (int): dimension of text embedding vectors.
|
186 |
+
"""
|
187 |
+
super().__init__()
|
188 |
+
self.text_dim = text_dim
|
189 |
+
assert rr <= max_rr
|
190 |
+
self.rr = rr
|
191 |
+
self.max_rr = max_rr
|
192 |
+
self.mel_dim = mel_dim
|
193 |
+
self.mel_min = mel_min
|
194 |
+
self.sigmoid_noise = sigmoid_noise
|
195 |
+
self.pad_token = pad_token
|
196 |
+
self.prenet_dim = prenet_dim
|
197 |
+
|
198 |
+
# encoder submodules
|
199 |
+
self.encoder_embed = pax.Embed(256, text_dim)
|
200 |
+
self.encoder_pre_net = PreNet(text_dim, 256, prenet_dim, always_dropout=True)
|
201 |
+
self.encoder_cbhg = CBHG(prenet_dim)
|
202 |
+
|
203 |
+
# random key generator
|
204 |
+
self.rng_seq = pax.RngSeq()
|
205 |
+
|
206 |
+
# pre-net
|
207 |
+
self.decoder_pre_net = PreNet(mel_dim, 256, prenet_dim, always_dropout=True)
|
208 |
+
|
209 |
+
# decoder submodules
|
210 |
+
self.attn_rnn = pax.LSTM(prenet_dim + prenet_dim * 2, attn_rnn_dim)
|
211 |
+
self.text_key_fc = pax.Linear(prenet_dim * 2, attn_hidden_dim, with_bias=True)
|
212 |
+
self.attn_query_fc = pax.Linear(attn_rnn_dim, attn_hidden_dim, with_bias=False)
|
213 |
+
|
214 |
+
self.attn_V = pax.Linear(attn_hidden_dim, 1, with_bias=False)
|
215 |
+
self.attn_V_weight_norm = jnp.array(1.0 / jnp.sqrt(attn_hidden_dim))
|
216 |
+
self.attn_V_bias = jnp.array(attn_bias)
|
217 |
+
self.attn_log = jnp.zeros((1,))
|
218 |
+
self.decoder_input = pax.Linear(attn_rnn_dim + 2 * prenet_dim, rnn_dim)
|
219 |
+
self.decoder_rnn1 = pax.LSTM(rnn_dim, rnn_dim)
|
220 |
+
self.decoder_rnn2 = pax.LSTM(rnn_dim, rnn_dim)
|
221 |
+
# mel + end-of-sequence token
|
222 |
+
self.output_fc = pax.Linear(rnn_dim, (mel_dim + 1) * max_rr, with_bias=True)
|
223 |
+
|
224 |
+
# post-net
|
225 |
+
self.post_net = pax.Sequential(
|
226 |
+
conv_block(mel_dim, postnet_dim, 5, jax.nn.tanh, True),
|
227 |
+
conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
|
228 |
+
conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
|
229 |
+
conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
|
230 |
+
conv_block(postnet_dim, mel_dim, 5, None, True),
|
231 |
+
)
|
232 |
+
|
233 |
+
parameters = pax.parameters_method("attn_V_weight_norm", "attn_V_bias")
|
234 |
+
|
235 |
+
def encode_text(self, text: jnp.ndarray) -> jnp.ndarray:
|
236 |
+
"""
|
237 |
+
Encode text to a sequence of real vectors
|
238 |
+
"""
|
239 |
+
N, L = text.shape
|
240 |
+
text_mask = (text != self.pad_token)[..., None]
|
241 |
+
x = self.encoder_embed(text)
|
242 |
+
x = self.encoder_pre_net(x)
|
243 |
+
x = self.encoder_cbhg(x, text_mask)
|
244 |
+
return x
|
245 |
+
|
246 |
+
def go_frame(self, batch_size: int) -> jnp.ndarray:
|
247 |
+
"""
|
248 |
+
return the go frame
|
249 |
+
"""
|
250 |
+
return jnp.ones((batch_size, self.mel_dim)) * jnp.log(self.mel_min)
|
251 |
+
|
252 |
+
def decoder_initial_state(self, N: int, L: int):
|
253 |
+
"""
|
254 |
+
setup decoder initial state
|
255 |
+
"""
|
256 |
+
attn_context = jnp.zeros((N, self.prenet_dim * 2))
|
257 |
+
attn_pr = jax.nn.one_hot(
|
258 |
+
jnp.zeros((N,), dtype=jnp.int32), num_classes=L, axis=-1
|
259 |
+
)
|
260 |
+
|
261 |
+
attn_state = (self.attn_rnn.initial_state(N), attn_context, attn_pr)
|
262 |
+
decoder_rnn_states = (
|
263 |
+
self.decoder_rnn1.initial_state(N),
|
264 |
+
self.decoder_rnn2.initial_state(N),
|
265 |
+
)
|
266 |
+
return attn_state, decoder_rnn_states
|
267 |
+
|
268 |
+
def monotonic_attention(self, prev_state, inputs, envs):
|
269 |
+
"""
|
270 |
+
Stepwise monotonic attention
|
271 |
+
"""
|
272 |
+
attn_rnn_state, attn_context, prev_attn_pr = prev_state
|
273 |
+
x, attn_rng_key = inputs
|
274 |
+
text, text_key = envs
|
275 |
+
attn_rnn_input = jnp.concatenate((x, attn_context), axis=-1)
|
276 |
+
attn_rnn_state, attn_rnn_output = self.attn_rnn(attn_rnn_state, attn_rnn_input)
|
277 |
+
attn_query_input = attn_rnn_output
|
278 |
+
attn_query = self.attn_query_fc(attn_query_input)
|
279 |
+
attn_hidden = jnp.tanh(attn_query[:, None, :] + text_key)
|
280 |
+
score = self.attn_V(attn_hidden)
|
281 |
+
score = jnp.squeeze(score, axis=-1)
|
282 |
+
weight_norm = jnp.linalg.norm(self.attn_V.weight)
|
283 |
+
score = score * (self.attn_V_weight_norm / weight_norm)
|
284 |
+
score = score + self.attn_V_bias
|
285 |
+
noise = jax.random.normal(attn_rng_key, score.shape) * self.sigmoid_noise
|
286 |
+
pr_stay = jax.nn.sigmoid(score + noise)
|
287 |
+
pr_move = 1.0 - pr_stay
|
288 |
+
pr_new_location = pr_move * prev_attn_pr
|
289 |
+
pr_new_location = jnp.pad(
|
290 |
+
pr_new_location[:, :-1], ((0, 0), (1, 0)), constant_values=0
|
291 |
+
)
|
292 |
+
attn_pr = pr_stay * prev_attn_pr + pr_new_location
|
293 |
+
attn_context = jnp.einsum("NL,NLD->ND", attn_pr, text)
|
294 |
+
new_state = (attn_rnn_state, attn_context, attn_pr)
|
295 |
+
return new_state, attn_rnn_output
|
296 |
+
|
297 |
+
def zoneout_lstm(self, lstm_core, rng_key, zoneout_pr=0.1):
|
298 |
+
"""
|
299 |
+
Return a zoneout lstm core.
|
300 |
+
|
301 |
+
It will zoneout the new hidden states and keep the new cell states unchanged.
|
302 |
+
"""
|
303 |
+
|
304 |
+
def core(state, x):
|
305 |
+
new_state, _ = lstm_core(state, x)
|
306 |
+
h_old = state.hidden
|
307 |
+
h_new = new_state.hidden
|
308 |
+
mask = jax.random.bernoulli(rng_key, zoneout_pr, h_old.shape)
|
309 |
+
h_new = h_old * mask + h_new * (1.0 - mask)
|
310 |
+
return pax.LSTMState(h_new, new_state.cell), h_new
|
311 |
+
|
312 |
+
return core
|
313 |
+
|
314 |
+
def decoder_step(
|
315 |
+
self,
|
316 |
+
attn_state,
|
317 |
+
decoder_rnn_states,
|
318 |
+
rng_key,
|
319 |
+
mel,
|
320 |
+
text,
|
321 |
+
text_key,
|
322 |
+
call_pre_net=False,
|
323 |
+
):
|
324 |
+
"""
|
325 |
+
One decoder step
|
326 |
+
"""
|
327 |
+
if call_pre_net:
|
328 |
+
k1, k2, zk1, zk2, rng_key, rng_key_next = jax.random.split(rng_key, 6)
|
329 |
+
mel = self.decoder_pre_net(mel, k1, k2)
|
330 |
+
else:
|
331 |
+
zk1, zk2, rng_key, rng_key_next = jax.random.split(rng_key, 4)
|
332 |
+
attn_inputs = (mel, rng_key)
|
333 |
+
attn_envs = (text, text_key)
|
334 |
+
attn_state, attn_rnn_output = self.monotonic_attention(
|
335 |
+
attn_state, attn_inputs, attn_envs
|
336 |
+
)
|
337 |
+
(_, attn_context, attn_pr) = attn_state
|
338 |
+
(decoder_rnn_state1, decoder_rnn_state2) = decoder_rnn_states
|
339 |
+
decoder_rnn1_input = jnp.concatenate((attn_rnn_output, attn_context), axis=-1)
|
340 |
+
decoder_rnn1_input = self.decoder_input(decoder_rnn1_input)
|
341 |
+
decoder_rnn1 = self.zoneout_lstm(self.decoder_rnn1, zk1)
|
342 |
+
decoder_rnn_state1, decoder_rnn_output1 = decoder_rnn1(
|
343 |
+
decoder_rnn_state1, decoder_rnn1_input
|
344 |
+
)
|
345 |
+
decoder_rnn2_input = decoder_rnn1_input + decoder_rnn_output1
|
346 |
+
decoder_rnn2 = self.zoneout_lstm(self.decoder_rnn2, zk2)
|
347 |
+
decoder_rnn_state2, decoder_rnn_output2 = decoder_rnn2(
|
348 |
+
decoder_rnn_state2, decoder_rnn2_input
|
349 |
+
)
|
350 |
+
x = decoder_rnn1_input + decoder_rnn_output1 + decoder_rnn_output2
|
351 |
+
decoder_rnn_states = (decoder_rnn_state1, decoder_rnn_state2)
|
352 |
+
return attn_state, decoder_rnn_states, rng_key_next, x, attn_pr[0]
|
353 |
+
|
354 |
+
@jax.jit
|
355 |
+
def inference_step(
|
356 |
+
self, attn_state, decoder_rnn_states, rng_key, mel, text, text_key
|
357 |
+
):
|
358 |
+
"""one inference step"""
|
359 |
+
attn_state, decoder_rnn_states, rng_key, x, _ = self.decoder_step(
|
360 |
+
attn_state,
|
361 |
+
decoder_rnn_states,
|
362 |
+
rng_key,
|
363 |
+
mel,
|
364 |
+
text,
|
365 |
+
text_key,
|
366 |
+
call_pre_net=True,
|
367 |
+
)
|
368 |
+
x = self.output_fc(x)
|
369 |
+
N, D2 = x.shape
|
370 |
+
x = jnp.reshape(x, (N, self.max_rr, D2 // self.max_rr))
|
371 |
+
x = x[:, : self.rr, :]
|
372 |
+
x = jnp.reshape(x, (N, self.rr, -1))
|
373 |
+
mel = x[..., :-1]
|
374 |
+
eos = x[..., -1]
|
375 |
+
return attn_state, decoder_rnn_states, rng_key, (mel, eos)
|
376 |
+
|
377 |
+
def inference(self, text, seed=42, max_len=1000):
|
378 |
+
"""
|
379 |
+
text to mel
|
380 |
+
"""
|
381 |
+
text = self.encode_text(text)
|
382 |
+
text_key = self.text_key_fc(text)
|
383 |
+
N, L, D = text.shape
|
384 |
+
mel = self.go_frame(N)
|
385 |
+
|
386 |
+
attn_state, decoder_rnn_states = self.decoder_initial_state(N, L)
|
387 |
+
rng_key = jax.random.PRNGKey(seed)
|
388 |
+
mels = []
|
389 |
+
count = 0
|
390 |
+
while True:
|
391 |
+
count = count + 1
|
392 |
+
attn_state, decoder_rnn_states, rng_key, (mel, eos) = self.inference_step(
|
393 |
+
attn_state, decoder_rnn_states, rng_key, mel, text, text_key
|
394 |
+
)
|
395 |
+
mels.append(mel)
|
396 |
+
if eos[0, -1].item() > 0 or count > max_len:
|
397 |
+
break
|
398 |
+
|
399 |
+
mel = mel[:, -1, :]
|
400 |
+
|
401 |
+
mels = jnp.concatenate(mels, axis=1)
|
402 |
+
mel = mel + self.post_net(mel)
|
403 |
+
return mels
|
404 |
+
|
405 |
+
def decode(self, mel, text):
|
406 |
+
"""
|
407 |
+
Attention mechanism + Decoder
|
408 |
+
"""
|
409 |
+
text_key = self.text_key_fc(text)
|
410 |
+
|
411 |
+
def scan_fn(prev_states, inputs):
|
412 |
+
attn_state, decoder_rnn_states = prev_states
|
413 |
+
x, rng_key = inputs
|
414 |
+
attn_state, decoder_rnn_states, _, output, attn_pr = self.decoder_step(
|
415 |
+
attn_state, decoder_rnn_states, rng_key, x, text, text_key
|
416 |
+
)
|
417 |
+
states = (attn_state, decoder_rnn_states)
|
418 |
+
return states, (output, attn_pr)
|
419 |
+
|
420 |
+
N, L, D = text.shape
|
421 |
+
decoder_states = self.decoder_initial_state(N, L)
|
422 |
+
rng_keys = self.rng_seq.next_rng_key(mel.shape[1])
|
423 |
+
rng_keys = jnp.stack(rng_keys, axis=1)
|
424 |
+
decoder_states, (x, attn_log) = pax.scan(
|
425 |
+
scan_fn,
|
426 |
+
decoder_states,
|
427 |
+
(mel, rng_keys),
|
428 |
+
time_major=False,
|
429 |
+
)
|
430 |
+
self.attn_log = attn_log
|
431 |
+
del decoder_states
|
432 |
+
x = self.output_fc(x)
|
433 |
+
|
434 |
+
N, T2, D2 = x.shape
|
435 |
+
x = jnp.reshape(x, (N, T2, self.max_rr, D2 // self.max_rr))
|
436 |
+
x = x[:, :, : self.rr, :]
|
437 |
+
x = jnp.reshape(x, (N, T2 * self.rr, -1))
|
438 |
+
mel = x[..., :-1]
|
439 |
+
eos = x[..., -1]
|
440 |
+
return mel, eos
|
441 |
+
|
442 |
+
def __call__(self, mel: jnp.ndarray, text: jnp.ndarray):
|
443 |
+
text = self.encode_text(text)
|
444 |
+
mel = self.decoder_pre_net(mel)
|
445 |
+
mel, eos = self.decode(mel, text)
|
446 |
+
return mel, mel + self.post_net(mel), eos
|
tacotron.toml
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tacotron]
|
2 |
+
|
3 |
+
# training
|
4 |
+
BATCH_SIZE = 64
|
5 |
+
LR=1024e-6 # learning rate
|
6 |
+
MODEL_PREFIX = "mono_tts_cbhg_small"
|
7 |
+
LOG_DIR = "./logs"
|
8 |
+
CKPT_DIR = "./ckpts"
|
9 |
+
USE_MP = false # use mixed-precision training
|
10 |
+
|
11 |
+
# data
|
12 |
+
TF_DATA_DIR = "./tf_data" # tensorflow data directory
|
13 |
+
TF_GTA_DATA_DIR = "./tf_gta_data" # tf gta data directory
|
14 |
+
SAMPLE_RATE = 24000 # convert to this sample rate if needed
|
15 |
+
MEL_DIM = 80 # the dimension of melspectrogram features
|
16 |
+
MEL_MIN = 1e-5
|
17 |
+
PAD = "_" # padding character
|
18 |
+
PAD_TOKEN = 0
|
19 |
+
TEST_DATA_SIZE = 1024
|
20 |
+
|
21 |
+
# model
|
22 |
+
RR = 2 # reduction factor
|
23 |
+
MAX_RR=2
|
24 |
+
ATTN_BIAS = 0.0 # control how slow the attention moves forward
|
25 |
+
SIGMOID_NOISE = 2.0
|
26 |
+
PRENET_DIM = 128
|
27 |
+
TEXT_DIM = 256
|
28 |
+
RNN_DIM = 512
|
29 |
+
ATTN_RNN_DIM = 256
|
30 |
+
ATTN_HIDDEN_DIM = 128
|
31 |
+
POSTNET_DIM = 512
|
text.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" from https://github.com/keithito/tacotron """
|
2 |
+
|
3 |
+
"""
|
4 |
+
Cleaners are transformations that run over the input text at both training and eval time.
|
5 |
+
|
6 |
+
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
7 |
+
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
8 |
+
1. "english_cleaners" for English text
|
9 |
+
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
10 |
+
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
11 |
+
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
12 |
+
the symbols in symbols.py to match your data).
|
13 |
+
"""
|
14 |
+
|
15 |
+
import re
|
16 |
+
|
17 |
+
from unidecode import unidecode
|
18 |
+
|
19 |
+
# Regular expression matching whitespace:
|
20 |
+
_whitespace_re = re.compile(r"\s+")
|
21 |
+
|
22 |
+
# List of (regular expression, replacement) pairs for abbreviations:
|
23 |
+
_abbreviations = [
|
24 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
25 |
+
for x in [
|
26 |
+
("mrs", "misess"),
|
27 |
+
("mr", "mister"),
|
28 |
+
("dr", "doctor"),
|
29 |
+
("st", "saint"),
|
30 |
+
("co", "company"),
|
31 |
+
("jr", "junior"),
|
32 |
+
("maj", "major"),
|
33 |
+
("gen", "general"),
|
34 |
+
("drs", "doctors"),
|
35 |
+
("rev", "reverend"),
|
36 |
+
("lt", "lieutenant"),
|
37 |
+
("hon", "honorable"),
|
38 |
+
("sgt", "sergeant"),
|
39 |
+
("capt", "captain"),
|
40 |
+
("esq", "esquire"),
|
41 |
+
("ltd", "limited"),
|
42 |
+
("col", "colonel"),
|
43 |
+
("ft", "fort"),
|
44 |
+
]
|
45 |
+
]
|
46 |
+
|
47 |
+
|
48 |
+
def expand_abbreviations(text):
|
49 |
+
for regex, replacement in _abbreviations:
|
50 |
+
text = re.sub(regex, replacement, text)
|
51 |
+
return text
|
52 |
+
|
53 |
+
|
54 |
+
def lowercase(text):
|
55 |
+
return text.lower()
|
56 |
+
|
57 |
+
|
58 |
+
def collapse_whitespace(text):
|
59 |
+
return re.sub(_whitespace_re, " ", text)
|
60 |
+
|
61 |
+
|
62 |
+
def convert_to_ascii(text):
|
63 |
+
return unidecode(text)
|
64 |
+
|
65 |
+
|
66 |
+
def basic_cleaners(text):
|
67 |
+
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
68 |
+
text = lowercase(text)
|
69 |
+
text = collapse_whitespace(text)
|
70 |
+
return text
|
71 |
+
|
72 |
+
|
73 |
+
def transliteration_cleaners(text):
|
74 |
+
"""Pipeline for non-English text that transliterates to ASCII."""
|
75 |
+
text = convert_to_ascii(text)
|
76 |
+
text = lowercase(text)
|
77 |
+
text = collapse_whitespace(text)
|
78 |
+
return text
|
79 |
+
|
80 |
+
|
81 |
+
def english_cleaners(text):
|
82 |
+
"""Pipeline for English text, including number and abbreviation expansion."""
|
83 |
+
text = convert_to_ascii(text)
|
84 |
+
text = lowercase(text)
|
85 |
+
text = expand_abbreviations(text)
|
86 |
+
text = collapse_whitespace(text)
|
87 |
+
return text
|
utils.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utility functions
|
3 |
+
"""
|
4 |
+
import pickle
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import pax
|
8 |
+
import toml
|
9 |
+
import yaml
|
10 |
+
|
11 |
+
from tacotron import Tacotron
|
12 |
+
|
13 |
+
|
14 |
+
def load_tacotron_config(config_file=Path("tacotron.toml")):
|
15 |
+
"""
|
16 |
+
Load the project configurations
|
17 |
+
"""
|
18 |
+
return toml.load(config_file)["tacotron"]
|
19 |
+
|
20 |
+
|
21 |
+
def load_tacotron_ckpt(net: pax.Module, optim: pax.Module, path):
|
22 |
+
"""
|
23 |
+
load checkpoint from disk
|
24 |
+
"""
|
25 |
+
with open(path, "rb") as f:
|
26 |
+
dic = pickle.load(f)
|
27 |
+
if net is not None:
|
28 |
+
net = net.load_state_dict(dic["model_state_dict"])
|
29 |
+
if optim is not None:
|
30 |
+
optim = optim.load_state_dict(dic["optim_state_dict"])
|
31 |
+
return dic["step"], net, optim
|
32 |
+
|
33 |
+
|
34 |
+
def create_tacotron_model(config):
|
35 |
+
"""
|
36 |
+
return a random initialized Tacotron model
|
37 |
+
"""
|
38 |
+
return Tacotron(
|
39 |
+
mel_dim=config["MEL_DIM"],
|
40 |
+
attn_bias=config["ATTN_BIAS"],
|
41 |
+
rr=config["RR"],
|
42 |
+
max_rr=config["MAX_RR"],
|
43 |
+
mel_min=config["MEL_MIN"],
|
44 |
+
sigmoid_noise=config["SIGMOID_NOISE"],
|
45 |
+
pad_token=config["PAD_TOKEN"],
|
46 |
+
prenet_dim=config["PRENET_DIM"],
|
47 |
+
attn_hidden_dim=config["ATTN_HIDDEN_DIM"],
|
48 |
+
attn_rnn_dim=config["ATTN_RNN_DIM"],
|
49 |
+
rnn_dim=config["RNN_DIM"],
|
50 |
+
postnet_dim=config["POSTNET_DIM"],
|
51 |
+
text_dim=config["TEXT_DIM"],
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
def load_wavegru_config(config_file):
|
56 |
+
"""
|
57 |
+
Load project configurations
|
58 |
+
"""
|
59 |
+
with open(config_file, "r", encoding="utf-8") as f:
|
60 |
+
return yaml.safe_load(f)
|
61 |
+
|
62 |
+
|
63 |
+
def load_wavegru_ckpt(net, optim, ckpt_file):
|
64 |
+
"""
|
65 |
+
load training checkpoint from file
|
66 |
+
"""
|
67 |
+
with open(ckpt_file, "rb") as f:
|
68 |
+
dic = pickle.load(f)
|
69 |
+
|
70 |
+
if net is not None:
|
71 |
+
net = net.load_state_dict(dic["net_state_dict"])
|
72 |
+
if optim is not None:
|
73 |
+
optim = optim.load_state_dict(dic["optim_state_dict"])
|
74 |
+
return dic["step"], net, optim
|
wavegru.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
WaveGRU model: melspectrogram => mu-law encoded waveform
|
3 |
+
"""
|
4 |
+
|
5 |
+
import jax
|
6 |
+
import jax.numpy as jnp
|
7 |
+
import pax
|
8 |
+
|
9 |
+
|
10 |
+
class ReLU(pax.Module):
|
11 |
+
def __call__(self, x):
|
12 |
+
return jax.nn.relu(x)
|
13 |
+
|
14 |
+
|
15 |
+
def dilated_residual_conv_block(dim, kernel, stride, dilation):
|
16 |
+
"""
|
17 |
+
Use dilated convs to enlarge the receptive field
|
18 |
+
"""
|
19 |
+
return pax.Sequential(
|
20 |
+
pax.Conv1D(dim, dim, kernel, stride, dilation, "VALID", with_bias=False),
|
21 |
+
pax.LayerNorm(dim, -1, True, True),
|
22 |
+
ReLU(),
|
23 |
+
pax.Conv1D(dim, dim, 1, 1, 1, "VALID", with_bias=False),
|
24 |
+
pax.LayerNorm(dim, -1, True, True),
|
25 |
+
ReLU(),
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
def tile_1d(x, factor):
|
30 |
+
"""
|
31 |
+
Tile tensor of shape N, L, D into N, L*factor, D
|
32 |
+
"""
|
33 |
+
N, L, D = x.shape
|
34 |
+
x = x[:, :, None, :]
|
35 |
+
x = jnp.tile(x, (1, 1, factor, 1))
|
36 |
+
x = jnp.reshape(x, (N, L * factor, D))
|
37 |
+
return x
|
38 |
+
|
39 |
+
|
40 |
+
def up_block(dim, factor):
|
41 |
+
"""
|
42 |
+
Tile >> Conv >> BatchNorm >> ReLU
|
43 |
+
"""
|
44 |
+
return pax.Sequential(
|
45 |
+
lambda x: tile_1d(x, factor),
|
46 |
+
pax.Conv1D(dim, dim, 2 * factor, stride=1, padding="VALID", with_bias=False),
|
47 |
+
pax.LayerNorm(dim, -1, True, True),
|
48 |
+
ReLU(),
|
49 |
+
)
|
50 |
+
|
51 |
+
|
52 |
+
class Upsample(pax.Module):
|
53 |
+
"""
|
54 |
+
Upsample melspectrogram to match raw audio sample rate.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(self, input_dim, upsample_factors):
|
58 |
+
super().__init__()
|
59 |
+
self.input_conv = pax.Sequential(
|
60 |
+
pax.Conv1D(input_dim, 512, 1, with_bias=False),
|
61 |
+
pax.LayerNorm(512, -1, True, True),
|
62 |
+
)
|
63 |
+
self.upsample_factors = upsample_factors
|
64 |
+
self.dilated_convs = [
|
65 |
+
dilated_residual_conv_block(512, 3, 1, 2**i) for i in range(5)
|
66 |
+
]
|
67 |
+
self.up_factors = upsample_factors[:-1]
|
68 |
+
self.up_blocks = [up_block(512, x) for x in self.up_factors]
|
69 |
+
self.final_tile = upsample_factors[-1]
|
70 |
+
|
71 |
+
def __call__(self, x):
|
72 |
+
x = self.input_conv(x)
|
73 |
+
for residual in self.dilated_convs:
|
74 |
+
y = residual(x)
|
75 |
+
pad = (x.shape[1] - y.shape[1]) // 2
|
76 |
+
x = x[:, pad:-pad, :] + y
|
77 |
+
|
78 |
+
for f in self.up_blocks:
|
79 |
+
x = f(x)
|
80 |
+
|
81 |
+
x = tile_1d(x, self.final_tile)
|
82 |
+
return x
|
83 |
+
|
84 |
+
|
85 |
+
class Pruner(pax.Module):
|
86 |
+
"""
|
87 |
+
Base class for pruners
|
88 |
+
"""
|
89 |
+
|
90 |
+
def __init__(self, update_freq=500):
|
91 |
+
super().__init__()
|
92 |
+
self.update_freq = update_freq
|
93 |
+
|
94 |
+
def compute_sparsity(self, step):
|
95 |
+
"""
|
96 |
+
Two-stages pruning
|
97 |
+
"""
|
98 |
+
t = jnp.power(1 - (step * 1.0 - 1_000) / 300_000, 3)
|
99 |
+
z = 0.5 * jnp.clip(1.0 - t, a_min=0, a_max=1)
|
100 |
+
for i in range(4):
|
101 |
+
t = jnp.power(1 - (step * 1.0 - 1_000 - 400_000 - i * 200_000) / 100_000, 3)
|
102 |
+
z = z + 0.1 * jnp.clip(1 - t, a_min=0, a_max=1)
|
103 |
+
return z
|
104 |
+
|
105 |
+
def prune(self, step, weights):
|
106 |
+
"""
|
107 |
+
Return a mask
|
108 |
+
"""
|
109 |
+
z = self.compute_sparsity(step)
|
110 |
+
x = weights
|
111 |
+
H, W = x.shape
|
112 |
+
x = x.reshape(H // 4, 4, W // 4, 4)
|
113 |
+
x = jnp.abs(x)
|
114 |
+
x = jnp.sum(x, axis=(1, 3), keepdims=True)
|
115 |
+
q = jnp.quantile(jnp.reshape(x, (-1,)), z)
|
116 |
+
x = x >= q
|
117 |
+
x = jnp.tile(x, (1, 4, 1, 4))
|
118 |
+
x = jnp.reshape(x, (H, W))
|
119 |
+
return x
|
120 |
+
|
121 |
+
|
122 |
+
class GRUPruner(Pruner):
|
123 |
+
def __init__(self, gru, update_freq=500):
|
124 |
+
super().__init__(update_freq=update_freq)
|
125 |
+
self.xh_zr_fc_mask = jnp.ones_like(gru.xh_zr_fc.weight) == 1
|
126 |
+
self.xh_h_fc_mask = jnp.ones_like(gru.xh_h_fc.weight) == 1
|
127 |
+
|
128 |
+
def __call__(self, gru: pax.GRU):
|
129 |
+
"""
|
130 |
+
Apply mask after an optimization step
|
131 |
+
"""
|
132 |
+
zr_masked_weights = jnp.where(self.xh_zr_fc_mask, gru.xh_zr_fc.weight, 0)
|
133 |
+
gru = gru.replace_node(gru.xh_zr_fc.weight, zr_masked_weights)
|
134 |
+
h_masked_weights = jnp.where(self.xh_h_fc_mask, gru.xh_h_fc.weight, 0)
|
135 |
+
gru = gru.replace_node(gru.xh_h_fc.weight, h_masked_weights)
|
136 |
+
return gru
|
137 |
+
|
138 |
+
def update_mask(self, step, gru: pax.GRU):
|
139 |
+
"""
|
140 |
+
Update internal masks
|
141 |
+
"""
|
142 |
+
xh_z_weight, xh_r_weight = jnp.split(gru.xh_zr_fc.weight, 2, axis=1)
|
143 |
+
xh_z_weight = self.prune(step, xh_z_weight)
|
144 |
+
xh_r_weight = self.prune(step, xh_r_weight)
|
145 |
+
self.xh_zr_fc_mask *= jnp.concatenate((xh_z_weight, xh_r_weight), axis=1)
|
146 |
+
self.xh_h_fc_mask *= self.prune(step, gru.xh_h_fc.weight)
|
147 |
+
|
148 |
+
|
149 |
+
class LinearPruner(Pruner):
|
150 |
+
def __init__(self, linear, update_freq=500):
|
151 |
+
super().__init__(update_freq=update_freq)
|
152 |
+
self.mask = jnp.ones_like(linear.weight) == 1
|
153 |
+
|
154 |
+
def __call__(self, linear: pax.Linear):
|
155 |
+
"""
|
156 |
+
Apply mask after an optimization step
|
157 |
+
"""
|
158 |
+
return linear.replace(weight=jnp.where(self.mask, linear.weight, 0))
|
159 |
+
|
160 |
+
def update_mask(self, step, linear: pax.Linear):
|
161 |
+
"""
|
162 |
+
Update internal masks
|
163 |
+
"""
|
164 |
+
self.mask *= self.prune(step, linear.weight)
|
165 |
+
|
166 |
+
|
167 |
+
class WaveGRU(pax.Module):
|
168 |
+
"""
|
169 |
+
WaveGRU vocoder model
|
170 |
+
"""
|
171 |
+
|
172 |
+
def __init__(
|
173 |
+
self, mel_dim=80, embed_dim=32, rnn_dim=512, upsample_factors=(5, 4, 3, 5)
|
174 |
+
):
|
175 |
+
super().__init__()
|
176 |
+
self.embed = pax.Embed(256, embed_dim)
|
177 |
+
self.upsample = Upsample(input_dim=mel_dim, upsample_factors=upsample_factors)
|
178 |
+
self.rnn = pax.GRU(embed_dim + rnn_dim, rnn_dim)
|
179 |
+
self.o1 = pax.Linear(rnn_dim, rnn_dim)
|
180 |
+
self.o2 = pax.Linear(rnn_dim, 256)
|
181 |
+
self.gru_pruner = GRUPruner(self.rnn)
|
182 |
+
self.o1_pruner = LinearPruner(self.o1)
|
183 |
+
self.o2_pruner = LinearPruner(self.o2)
|
184 |
+
|
185 |
+
def output(self, x):
|
186 |
+
x = self.o1(x)
|
187 |
+
x = jax.nn.relu(x)
|
188 |
+
x = self.o2(x)
|
189 |
+
return x
|
190 |
+
|
191 |
+
@jax.jit
|
192 |
+
def inference_step(self, rnn_state, mel, rng_key, x):
|
193 |
+
"""one inference step"""
|
194 |
+
x = self.embed(x)
|
195 |
+
x = jnp.concatenate((x, mel), axis=-1)
|
196 |
+
rnn_state, x = self.rnn(rnn_state, x)
|
197 |
+
x = self.output(x)
|
198 |
+
rng_key, next_rng_key = jax.random.split(rng_key, 2)
|
199 |
+
x = jax.random.categorical(rng_key, x, axis=-1)
|
200 |
+
return rnn_state, next_rng_key, x
|
201 |
+
|
202 |
+
def inference(self, mel, no_gru=False, seed=42):
|
203 |
+
"""
|
204 |
+
generate waveform form melspectrogram
|
205 |
+
"""
|
206 |
+
|
207 |
+
y = self.upsample(mel)
|
208 |
+
if no_gru:
|
209 |
+
return y
|
210 |
+
x = jnp.array([127], dtype=jnp.int32)
|
211 |
+
rnn_state = self.rnn.initial_state(1)
|
212 |
+
output = []
|
213 |
+
rng_key = jax.random.PRNGKey(seed)
|
214 |
+
for i in range(y.shape[1]):
|
215 |
+
rnn_state, rng_key, x = self.inference_step(rnn_state, y[:, i], rng_key, x)
|
216 |
+
output.append(x)
|
217 |
+
x = jnp.concatenate(output, axis=0)
|
218 |
+
return x
|
219 |
+
|
220 |
+
def __call__(self, mel, x):
|
221 |
+
x = self.embed(x)
|
222 |
+
y = self.upsample(mel)
|
223 |
+
pad_left = (x.shape[1] - y.shape[1]) // 2
|
224 |
+
pad_right = x.shape[1] - y.shape[1] - pad_left
|
225 |
+
x = x[:, pad_left:-pad_right]
|
226 |
+
x = jnp.concatenate((x, y), axis=-1)
|
227 |
+
_, x = pax.scan(
|
228 |
+
self.rnn,
|
229 |
+
self.rnn.initial_state(x.shape[0]),
|
230 |
+
x,
|
231 |
+
time_major=False,
|
232 |
+
)
|
233 |
+
x = self.output(x)
|
234 |
+
return x
|
wavegru.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## dsp
|
2 |
+
sample_rate : 24000
|
3 |
+
window_length: 50.0 # ms
|
4 |
+
hop_length: 12.5 # ms
|
5 |
+
mel_min: 1.0e-5 ## need .0 to make it a float
|
6 |
+
mel_dim: 80
|
7 |
+
n_fft: 2048
|
8 |
+
|
9 |
+
## wavegru
|
10 |
+
embed_dim: 32
|
11 |
+
rnn_dim: 512
|
12 |
+
frames_per_sequence: 67
|
13 |
+
num_pad_frames: 62
|
14 |
+
upsample_factors: [5, 4, 3, 5]
|
wavegru_vocoder_tpu_gta_preemphasis_pruning_v7_0040000.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c09ed822c5daac0afbd19e8ba4e0ded26dd5732e0efd13ce193c3f54c4e63f54
|
3 |
+
size 56479599
|