add app
Browse files- BUILD +44 -0
- Dockerfile +32 -0
- WORKSPACE +154 -0
- alphabet.txt +97 -0
- app.py +148 -0
- build_ext.sh +4 -0
- extract_tacotrons_model.py +5 -0
- extract_wavegru_model.py +5 -0
- inference.py +90 -0
- mynumbers.py +73 -0
- packages.txt +7 -0
- pooch.py +10 -0
- requirements.txt +12 -0
- tacotron.py +451 -0
- tacotron.toml +32 -0
- text.py +92 -0
- utils.py +74 -0
- wavegru.py +300 -0
- wavegru.yaml +14 -0
- wavegru_cpp.py +42 -0
- wavegru_mod.cc +150 -0
BUILD
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# [internal] load cc_fuzz_target.bzl
|
2 |
+
# [internal] load cc_proto_library.bzl
|
3 |
+
# [internal] load android_cc_test:def.bzl
|
4 |
+
|
5 |
+
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
|
6 |
+
|
7 |
+
package(default_visibility = [":__subpackages__"])
|
8 |
+
|
9 |
+
licenses(["notice"])
|
10 |
+
|
11 |
+
# To run all cc_tests in this directory:
|
12 |
+
# bazel test //:all
|
13 |
+
|
14 |
+
# [internal] Command to run dsp_util_android_test.
|
15 |
+
|
16 |
+
# [internal] Command to run lyra_integration_android_test.
|
17 |
+
|
18 |
+
exports_files(
|
19 |
+
srcs = [
|
20 |
+
"wavegru_mod.cc",
|
21 |
+
],
|
22 |
+
)
|
23 |
+
|
24 |
+
pybind_extension(
|
25 |
+
name = "wavegru_mod", # This name is not actually created!
|
26 |
+
srcs = ["wavegru_mod.cc"],
|
27 |
+
deps = [
|
28 |
+
"//sparse_matmul",
|
29 |
+
],
|
30 |
+
)
|
31 |
+
|
32 |
+
py_library(
|
33 |
+
name = "wavegru_mod",
|
34 |
+
data = [":wavegru_mod.so"],
|
35 |
+
)
|
36 |
+
|
37 |
+
py_binary(
|
38 |
+
name = "wavegru",
|
39 |
+
srcs = ["wavegru.py"],
|
40 |
+
deps = [
|
41 |
+
":wavegru_mod"
|
42 |
+
],
|
43 |
+
)
|
44 |
+
|
Dockerfile
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
|
2 |
+
# you will also find guides on how best to write your Dockerfile
|
3 |
+
|
4 |
+
FROM python:3.11
|
5 |
+
|
6 |
+
RUN apt update; apt install libsndfile1-dev make autoconf automake libtool gcc pkg-config -y
|
7 |
+
|
8 |
+
WORKDIR /code
|
9 |
+
|
10 |
+
COPY ./requirements.txt /code/requirements.txt
|
11 |
+
|
12 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
13 |
+
|
14 |
+
# Set up a new user named "user" with user ID 1000
|
15 |
+
RUN useradd -m -u 1000 user
|
16 |
+
|
17 |
+
# Switch to the "user" user
|
18 |
+
USER user
|
19 |
+
|
20 |
+
# Set home to the user's home directory
|
21 |
+
ENV HOME=/home/user \
|
22 |
+
PATH=/home/user/.local/bin:$PATH
|
23 |
+
|
24 |
+
# Set the working directory to the user's home directory
|
25 |
+
WORKDIR $HOME/app
|
26 |
+
|
27 |
+
# Copy the current directory contents into the container at $HOME/app setting the owner to the user
|
28 |
+
COPY --chown=user . $HOME/app
|
29 |
+
|
30 |
+
RUN bash ./build_ext.sh
|
31 |
+
|
32 |
+
CMD ["python", "main.py"]
|
WORKSPACE
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
########################
|
2 |
+
# Platform Independent #
|
3 |
+
########################
|
4 |
+
|
5 |
+
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository", "new_git_repository")
|
6 |
+
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
7 |
+
|
8 |
+
# GoogleTest/GoogleMock framework.
|
9 |
+
git_repository(
|
10 |
+
name = "com_google_googletest",
|
11 |
+
remote = "https://github.com/google/googletest.git",
|
12 |
+
tag = "release-1.10.0",
|
13 |
+
)
|
14 |
+
|
15 |
+
# Google benchmark.
|
16 |
+
http_archive(
|
17 |
+
name = "com_github_google_benchmark",
|
18 |
+
urls = ["https://github.com/google/benchmark/archive/bf585a2789e30585b4e3ce6baf11ef2750b54677.zip"], # 2020-11-26T11:14:03Z
|
19 |
+
strip_prefix = "benchmark-bf585a2789e30585b4e3ce6baf11ef2750b54677",
|
20 |
+
sha256 = "2a778d821997df7d8646c9c59b8edb9a573a6e04c534c01892a40aa524a7b68c",
|
21 |
+
)
|
22 |
+
|
23 |
+
# proto_library, cc_proto_library, and java_proto_library rules implicitly
|
24 |
+
# depend on @com_google_protobuf for protoc and proto runtimes.
|
25 |
+
# This statement defines the @com_google_protobuf repo.
|
26 |
+
git_repository(
|
27 |
+
name = "com_google_protobuf",
|
28 |
+
remote = "https://github.com/protocolbuffers/protobuf.git",
|
29 |
+
tag = "v3.15.4",
|
30 |
+
)
|
31 |
+
|
32 |
+
load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps")
|
33 |
+
protobuf_deps()
|
34 |
+
|
35 |
+
# Google Abseil Libs
|
36 |
+
git_repository(
|
37 |
+
name = "com_google_absl",
|
38 |
+
remote = "https://github.com/abseil/abseil-cpp.git",
|
39 |
+
branch = "lts_2020_09_23",
|
40 |
+
)
|
41 |
+
|
42 |
+
# Filesystem
|
43 |
+
# The new_* prefix is used because it is not a bazel project and there is
|
44 |
+
# no BUILD file in that repo.
|
45 |
+
FILESYSTEM_BUILD = """
|
46 |
+
cc_library(
|
47 |
+
name = "filesystem",
|
48 |
+
hdrs = glob(["include/ghc/*"]),
|
49 |
+
visibility = ["//visibility:public"],
|
50 |
+
)
|
51 |
+
"""
|
52 |
+
|
53 |
+
new_git_repository(
|
54 |
+
name = "gulrak_filesystem",
|
55 |
+
remote = "https://github.com/gulrak/filesystem.git",
|
56 |
+
tag = "v1.3.6",
|
57 |
+
build_file_content = FILESYSTEM_BUILD
|
58 |
+
)
|
59 |
+
|
60 |
+
# Audio DSP
|
61 |
+
git_repository(
|
62 |
+
name = "com_google_audio_dsp",
|
63 |
+
remote = "https://github.com/google/multichannel-audio-tools.git",
|
64 |
+
# There are no tags for this repo, we are synced to bleeding edge.
|
65 |
+
branch = "master",
|
66 |
+
repo_mapping = {
|
67 |
+
"@com_github_glog_glog" : "@com_google_glog"
|
68 |
+
}
|
69 |
+
)
|
70 |
+
|
71 |
+
|
72 |
+
http_archive(
|
73 |
+
name = "pybind11_bazel",
|
74 |
+
strip_prefix = "pybind11_bazel-72cbbf1fbc830e487e3012862b7b720001b70672",
|
75 |
+
urls = ["https://github.com/pybind/pybind11_bazel/archive/72cbbf1fbc830e487e3012862b7b720001b70672.zip"],
|
76 |
+
)
|
77 |
+
# We still require the pybind library.
|
78 |
+
http_archive(
|
79 |
+
name = "pybind11",
|
80 |
+
build_file = "@pybind11_bazel//:pybind11.BUILD",
|
81 |
+
strip_prefix = "pybind11-2.9.0",
|
82 |
+
urls = ["https://github.com/pybind/pybind11/archive/v2.9.0.tar.gz"],
|
83 |
+
)
|
84 |
+
load("@pybind11_bazel//:python_configure.bzl", "python_configure")
|
85 |
+
python_configure(name = "local_config_python")
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
# Transitive dependencies of Audio DSP.
|
90 |
+
http_archive(
|
91 |
+
name = "eigen_archive",
|
92 |
+
build_file = "eigen.BUILD",
|
93 |
+
sha256 = "f3d69ac773ecaf3602cb940040390d4e71a501bb145ca9e01ce5464cf6d4eb68",
|
94 |
+
strip_prefix = "eigen-eigen-049af2f56331",
|
95 |
+
urls = [
|
96 |
+
"http://mirror.tensorflow.org/bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz",
|
97 |
+
"https://bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz",
|
98 |
+
],
|
99 |
+
)
|
100 |
+
|
101 |
+
http_archive(
|
102 |
+
name = "fft2d",
|
103 |
+
build_file = "fft2d.BUILD",
|
104 |
+
sha256 = "ada7e99087c4ed477bfdf11413f2ba8db8a840ba9bbf8ac94f4f3972e2a7cec9",
|
105 |
+
urls = [
|
106 |
+
"http://www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz",
|
107 |
+
],
|
108 |
+
)
|
109 |
+
|
110 |
+
# Google logging
|
111 |
+
git_repository(
|
112 |
+
name = "com_google_glog",
|
113 |
+
remote = "https://github.com/google/glog.git",
|
114 |
+
branch = "master"
|
115 |
+
)
|
116 |
+
# Dependency for glog
|
117 |
+
git_repository(
|
118 |
+
name = "com_github_gflags_gflags",
|
119 |
+
remote = "https://github.com/mchinen/gflags.git",
|
120 |
+
branch = "android_linking_fix"
|
121 |
+
)
|
122 |
+
|
123 |
+
# Bazel/build rules
|
124 |
+
|
125 |
+
http_archive(
|
126 |
+
name = "bazel_skylib",
|
127 |
+
urls = [
|
128 |
+
"https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
|
129 |
+
"https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
|
130 |
+
],
|
131 |
+
sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44",
|
132 |
+
)
|
133 |
+
load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")
|
134 |
+
bazel_skylib_workspace()
|
135 |
+
|
136 |
+
http_archive(
|
137 |
+
name = "rules_android",
|
138 |
+
sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806",
|
139 |
+
strip_prefix = "rules_android-0.1.1",
|
140 |
+
urls = ["https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip"],
|
141 |
+
)
|
142 |
+
|
143 |
+
# Google Maven Repository
|
144 |
+
GMAVEN_TAG = "20180625-1"
|
145 |
+
|
146 |
+
http_archive(
|
147 |
+
name = "gmaven_rules",
|
148 |
+
strip_prefix = "gmaven_rules-%s" % GMAVEN_TAG,
|
149 |
+
url = "https://github.com/bazelbuild/gmaven_rules/archive/%s.tar.gz" % GMAVEN_TAG,
|
150 |
+
)
|
151 |
+
|
152 |
+
load("@gmaven_rules//:gmaven.bzl", "gmaven_rules")
|
153 |
+
|
154 |
+
gmaven_rules()
|
alphabet.txt
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_
|
2 |
+
■
|
3 |
+
|
4 |
+
!
|
5 |
+
,
|
6 |
+
.
|
7 |
+
:
|
8 |
+
?
|
9 |
+
a
|
10 |
+
b
|
11 |
+
c
|
12 |
+
d
|
13 |
+
e
|
14 |
+
g
|
15 |
+
h
|
16 |
+
i
|
17 |
+
k
|
18 |
+
l
|
19 |
+
m
|
20 |
+
n
|
21 |
+
o
|
22 |
+
p
|
23 |
+
q
|
24 |
+
r
|
25 |
+
s
|
26 |
+
t
|
27 |
+
u
|
28 |
+
v
|
29 |
+
x
|
30 |
+
y
|
31 |
+
à
|
32 |
+
á
|
33 |
+
â
|
34 |
+
ã
|
35 |
+
è
|
36 |
+
é
|
37 |
+
ê
|
38 |
+
ì
|
39 |
+
í
|
40 |
+
ò
|
41 |
+
ó
|
42 |
+
ô
|
43 |
+
õ
|
44 |
+
ù
|
45 |
+
ú
|
46 |
+
ý
|
47 |
+
ă
|
48 |
+
đ
|
49 |
+
ĩ
|
50 |
+
ũ
|
51 |
+
ơ
|
52 |
+
ư
|
53 |
+
ạ
|
54 |
+
ả
|
55 |
+
ấ
|
56 |
+
ầ
|
57 |
+
ẩ
|
58 |
+
ẫ
|
59 |
+
ậ
|
60 |
+
ắ
|
61 |
+
ằ
|
62 |
+
ẳ
|
63 |
+
ẵ
|
64 |
+
ặ
|
65 |
+
ẹ
|
66 |
+
ẻ
|
67 |
+
ẽ
|
68 |
+
ế
|
69 |
+
ề
|
70 |
+
ể
|
71 |
+
ễ
|
72 |
+
ệ
|
73 |
+
ỉ
|
74 |
+
ị
|
75 |
+
ọ
|
76 |
+
ỏ
|
77 |
+
ố
|
78 |
+
ồ
|
79 |
+
ổ
|
80 |
+
ỗ
|
81 |
+
ộ
|
82 |
+
ớ
|
83 |
+
ờ
|
84 |
+
ở
|
85 |
+
ỡ
|
86 |
+
ợ
|
87 |
+
ụ
|
88 |
+
ủ
|
89 |
+
ứ
|
90 |
+
ừ
|
91 |
+
ử
|
92 |
+
ữ
|
93 |
+
ự
|
94 |
+
ỳ
|
95 |
+
ỵ
|
96 |
+
ỷ
|
97 |
+
ỹ
|
app.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## build wavegru-cpp
|
2 |
+
# import os
|
3 |
+
# os.system("./bazelisk-linux-amd64 clean --expunge")
|
4 |
+
# os.system("./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native")
|
5 |
+
|
6 |
+
# install espeak
|
7 |
+
import os
|
8 |
+
import re
|
9 |
+
import unicodedata
|
10 |
+
import regex
|
11 |
+
|
12 |
+
if not os.path.isfile("./wavegru_mod.so"):
|
13 |
+
os.system("bash ./build_ext.sh")
|
14 |
+
|
15 |
+
import gradio as gr
|
16 |
+
from inference import load_tacotron_model, load_wavegru_net, mel_to_wav, text_to_mel
|
17 |
+
from wavegru_cpp import extract_weight_mask, load_wavegru_cpp
|
18 |
+
|
19 |
+
|
20 |
+
alphabet, tacotron_net, tacotron_config = load_tacotron_model(
|
21 |
+
"./alphabet.txt", "./tacotron.toml", "./mono_tts_cbhg_small_0700000.ckpt"
|
22 |
+
)
|
23 |
+
|
24 |
+
wavegru_config, wavegru_net = load_wavegru_net(
|
25 |
+
"./wavegru.yaml", "./wavegru_vocoder_tpu_gta_preemphasis_pruning_0800000.ckpt"
|
26 |
+
)
|
27 |
+
|
28 |
+
wave_cpp_weight_mask = extract_weight_mask(wavegru_net)
|
29 |
+
wavecpp = load_wavegru_cpp(
|
30 |
+
wave_cpp_weight_mask, wavegru_config["upsample_factors"][-1]
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
space_re = regex.compile(r"\s+")
|
35 |
+
number_re = regex.compile("([0-9]+)")
|
36 |
+
digits = ["không", "một", "hai", "ba", "bốn", "năm", "sáu", "bảy", "tám", "chín"]
|
37 |
+
num_re = regex.compile(r"([0-9.,]*[0-9])")
|
38 |
+
alphabet_ = "aàáảãạăằắẳẵặâầấẩẫậeèéẻẽẹêềếểễệiìíỉĩịoòóỏõọôồốổỗộơờớởỡợuùúủũụưừứửữựyỳýỷỹỵbcdđghklmnpqrstvx"
|
39 |
+
keep_text_and_num_re = regex.compile(rf"[^\s{alphabet_}.,0-9]")
|
40 |
+
keep_text_re = regex.compile(rf"[^\s{alphabet_}]")
|
41 |
+
|
42 |
+
|
43 |
+
def read_number(num: str) -> str:
|
44 |
+
if len(num) == 1:
|
45 |
+
return digits[int(num)]
|
46 |
+
elif len(num) == 2 and num.isdigit():
|
47 |
+
n = int(num)
|
48 |
+
end = digits[n % 10]
|
49 |
+
if n == 10:
|
50 |
+
return "mười"
|
51 |
+
if n % 10 == 5:
|
52 |
+
end = "lăm"
|
53 |
+
if n % 10 == 0:
|
54 |
+
return digits[n // 10] + " mươi"
|
55 |
+
elif n < 20:
|
56 |
+
return "mười " + end
|
57 |
+
else:
|
58 |
+
if n % 10 == 1:
|
59 |
+
end = "mốt"
|
60 |
+
return digits[n // 10] + " mươi " + end
|
61 |
+
elif len(num) == 3 and num.isdigit():
|
62 |
+
n = int(num)
|
63 |
+
if n % 100 == 0:
|
64 |
+
return digits[n // 100] + " trăm"
|
65 |
+
elif num[1] == "0":
|
66 |
+
return digits[n // 100] + " trăm lẻ " + digits[n % 100]
|
67 |
+
else:
|
68 |
+
return digits[n // 100] + " trăm " + read_number(num[1:])
|
69 |
+
elif len(num) >= 4 and len(num) <= 6 and num.isdigit():
|
70 |
+
n = int(num)
|
71 |
+
n1 = n // 1000
|
72 |
+
return read_number(str(n1)) + " ngàn " + read_number(num[-3:])
|
73 |
+
elif "," in num:
|
74 |
+
n1, n2 = num.split(",")
|
75 |
+
return read_number(n1) + " phẩy " + read_number(n2)
|
76 |
+
elif "." in num:
|
77 |
+
parts = num.split(".")
|
78 |
+
if len(parts) == 2:
|
79 |
+
if parts[1] == "000":
|
80 |
+
return read_number(parts[0]) + " ngàn"
|
81 |
+
elif parts[1].startswith("00"):
|
82 |
+
end = digits[int(parts[1][2:])]
|
83 |
+
return read_number(parts[0]) + " ngàn lẻ " + end
|
84 |
+
else:
|
85 |
+
return read_number(parts[0]) + " ngàn " + read_number(parts[1])
|
86 |
+
elif len(parts) == 3:
|
87 |
+
return (
|
88 |
+
read_number(parts[0])
|
89 |
+
+ " triệu "
|
90 |
+
+ read_number(parts[1])
|
91 |
+
+ " ngàn "
|
92 |
+
+ read_number(parts[2])
|
93 |
+
)
|
94 |
+
return num
|
95 |
+
|
96 |
+
|
97 |
+
def normalize_text(text):
|
98 |
+
# lowercase
|
99 |
+
text = text.lower()
|
100 |
+
# unicode normalize
|
101 |
+
text = unicodedata.normalize("NFKC", text)
|
102 |
+
text = text.replace(".", ". ")
|
103 |
+
text = text.replace(",", ", ")
|
104 |
+
text = text.replace(";", "; ")
|
105 |
+
text = text.replace(":", ": ")
|
106 |
+
text = text.replace("!", "! ")
|
107 |
+
text = text.replace("?", "? ")
|
108 |
+
text = text.replace("(", "( ")
|
109 |
+
|
110 |
+
text = num_re.sub(r" \1 ", text)
|
111 |
+
words = text.split()
|
112 |
+
words = [read_number(w) if num_re.fullmatch(w) else w for w in words]
|
113 |
+
text = " ".join(words)
|
114 |
+
|
115 |
+
# remove redundant spaces
|
116 |
+
text = re.sub(r"\s+", " ", text)
|
117 |
+
# remove leading and trailing spaces
|
118 |
+
text = text.strip()
|
119 |
+
return text
|
120 |
+
|
121 |
+
|
122 |
+
def speak(text):
|
123 |
+
text = normalize_text(text)
|
124 |
+
mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
|
125 |
+
y = mel_to_wav(wavegru_net, wavecpp, mel, wavegru_config)
|
126 |
+
return 24_000, y
|
127 |
+
|
128 |
+
|
129 |
+
title = "WaveGRU-TTS"
|
130 |
+
description = "WaveGRU text-to-speech demo."
|
131 |
+
|
132 |
+
gr.Interface(
|
133 |
+
fn=speak,
|
134 |
+
inputs="text",
|
135 |
+
examples=[
|
136 |
+
"Trăm năm trong cõi người ta, chữ tài chữ mệnh khéo là ghét nhau.",
|
137 |
+
"Đoạn trường tân thanh, thường được biết đến với cái tên đơn giản là Truyện Kiều, là một truyện thơ của đại thi hào Nguyễn Du",
|
138 |
+
"Lục Vân Tiên quê ở huyện Đông Thành, khôi ngô tuấn tú, tài kiêm văn võ. Nghe tin triều đình mở khoa thi, Vân Tiên từ giã thầy xuống núi đua tài.",
|
139 |
+
"Lê Quý Đôn, tên thuở nhỏ là Lê Danh Phương, là vị quan thời Lê trung hưng, cũng là nhà thơ và được mệnh danh là nhà bác học l��n của Việt Nam trong thời phong kiến",
|
140 |
+
"Tất cả mọi người đều sinh ra có quyền bình đẳng. Tạo hóa cho họ những quyền không ai có thể xâm phạm được; trong những quyền ấy, có quyền được sống, quyền tự do và quyền mưu cầu hạnh phúc.",
|
141 |
+
],
|
142 |
+
outputs="audio",
|
143 |
+
title=title,
|
144 |
+
description=description,
|
145 |
+
theme="default",
|
146 |
+
allow_screenshot=False,
|
147 |
+
allow_flagging="never",
|
148 |
+
).launch(enable_queue=True)
|
build_ext.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pip install -U pip
|
2 |
+
pip install gradio==3.42.0
|
3 |
+
USE_BAZEL_VERSION=5.0.0 ./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native
|
4 |
+
cp -f bazel-bin/wavegru_mod.so .
|
extract_tacotrons_model.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
|
3 |
+
dic = pickle.load(open("./mono_tts_cbhg_small_0700000.ckpt", "rb"))
|
4 |
+
del dic["optim_state_dict"]
|
5 |
+
pickle.dump(dic, open("./mono_tts_cbhg_small_0700000.ckpt", "wb"))
|
extract_wavegru_model.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
|
3 |
+
dic = pickle.load(open("./wavegru_vocoder_tpu_gta_preemphasis_pruning_0800000.ckpt", "rb"))
|
4 |
+
del dic["optim_state_dict"]
|
5 |
+
pickle.dump(dic, open("./wavegru_vocoder_tpu_gta_preemphasis_pruning_0800000.ckpt", "wb"))
|
inference.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import jax
|
4 |
+
import jax.numpy as jnp
|
5 |
+
import librosa
|
6 |
+
import numpy as np
|
7 |
+
import pax
|
8 |
+
|
9 |
+
# from text import english_cleaners
|
10 |
+
from utils import (
|
11 |
+
create_tacotron_model,
|
12 |
+
load_tacotron_ckpt,
|
13 |
+
load_tacotron_config,
|
14 |
+
load_wavegru_ckpt,
|
15 |
+
load_wavegru_config,
|
16 |
+
)
|
17 |
+
from wavegru import WaveGRU
|
18 |
+
|
19 |
+
# os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = "./espeak/usr/lib/libespeak-ng.so.1.1.51"
|
20 |
+
# from phonemizer.backend import EspeakBackend
|
21 |
+
# backend = EspeakBackend("en-us", preserve_punctuation=True, with_stress=True)
|
22 |
+
|
23 |
+
|
24 |
+
def load_tacotron_model(alphabet_file, config_file, model_file):
|
25 |
+
"""load tacotron model to memory"""
|
26 |
+
with open(alphabet_file, "r", encoding="utf-8") as f:
|
27 |
+
alphabet = f.read().split("\n")
|
28 |
+
|
29 |
+
config = load_tacotron_config(config_file)
|
30 |
+
net = create_tacotron_model(config)
|
31 |
+
_, net, _ = load_tacotron_ckpt(net, None, model_file)
|
32 |
+
net = net.eval()
|
33 |
+
net = jax.device_put(net)
|
34 |
+
return alphabet, net, config
|
35 |
+
|
36 |
+
|
37 |
+
tacotron_inference_fn = pax.pure(lambda net, text: net.inference(text, max_len=2400))
|
38 |
+
|
39 |
+
|
40 |
+
def text_to_mel(net, text, alphabet, config):
|
41 |
+
"""convert text to mel spectrogram"""
|
42 |
+
# text = english_cleaners(text)
|
43 |
+
# text = backend.phonemize([text], strip=True)[0]
|
44 |
+
text = text + config["END_CHARACTER"]
|
45 |
+
text = text + config["PAD"] * (100 - (len(text) % 100))
|
46 |
+
tokens = []
|
47 |
+
for c in text:
|
48 |
+
if c in alphabet:
|
49 |
+
tokens.append(alphabet.index(c))
|
50 |
+
tokens = jnp.array(tokens, dtype=jnp.int32)
|
51 |
+
mel = tacotron_inference_fn(net, tokens[None])
|
52 |
+
return mel
|
53 |
+
|
54 |
+
|
55 |
+
def load_wavegru_net(config_file, model_file):
|
56 |
+
"""load wavegru to memory"""
|
57 |
+
config = load_wavegru_config(config_file)
|
58 |
+
net = WaveGRU(
|
59 |
+
mel_dim=config["mel_dim"],
|
60 |
+
rnn_dim=config["rnn_dim"],
|
61 |
+
upsample_factors=config["upsample_factors"],
|
62 |
+
has_linear_output=True,
|
63 |
+
)
|
64 |
+
_, net, _ = load_wavegru_ckpt(net, None, model_file)
|
65 |
+
net = net.eval()
|
66 |
+
net = jax.device_put(net)
|
67 |
+
return config, net
|
68 |
+
|
69 |
+
|
70 |
+
wavegru_inference = pax.pure(lambda net, mel: net.inference(mel, no_gru=True))
|
71 |
+
|
72 |
+
|
73 |
+
def mel_to_wav(net, netcpp, mel, config):
|
74 |
+
"""convert mel to wav"""
|
75 |
+
if len(mel.shape) == 2:
|
76 |
+
mel = mel[None]
|
77 |
+
pad = config["num_pad_frames"] // 2 + 2
|
78 |
+
mel = np.pad(mel, [(0, 0), (pad, pad), (0, 0)], mode="edge")
|
79 |
+
ft = wavegru_inference(net, mel)
|
80 |
+
ft = jax.device_get(ft[0])
|
81 |
+
wav = netcpp.inference(ft, 1.0)
|
82 |
+
wav = np.array(wav)
|
83 |
+
wav = librosa.mu_expand(wav - 127, mu=255)
|
84 |
+
wav = librosa.effects.deemphasis(wav, coef=0.86)
|
85 |
+
wav = wav * 2.0
|
86 |
+
wav = wav / max(1.0, np.max(np.abs(wav)))
|
87 |
+
wav = wav * 2**15
|
88 |
+
wav = np.clip(wav, a_min=-(2**15), a_max=(2**15) - 1)
|
89 |
+
wav = wav.astype(np.int16)
|
90 |
+
return wav
|
mynumbers.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" from https://github.com/keithito/tacotron """
|
2 |
+
|
3 |
+
import inflect
|
4 |
+
import re
|
5 |
+
|
6 |
+
|
7 |
+
_inflect = inflect.engine()
|
8 |
+
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
|
9 |
+
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
|
10 |
+
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
|
11 |
+
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
|
12 |
+
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
|
13 |
+
_number_re = re.compile(r"[0-9]+")
|
14 |
+
|
15 |
+
|
16 |
+
def _remove_commas(m):
|
17 |
+
return m.group(1).replace(",", "")
|
18 |
+
|
19 |
+
|
20 |
+
def _expand_decimal_point(m):
|
21 |
+
return m.group(1).replace(".", " point ")
|
22 |
+
|
23 |
+
|
24 |
+
def _expand_dollars(m):
|
25 |
+
match = m.group(1)
|
26 |
+
parts = match.split(".")
|
27 |
+
if len(parts) > 2:
|
28 |
+
return match + " dollars" # Unexpected format
|
29 |
+
dollars = int(parts[0]) if parts[0] else 0
|
30 |
+
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
31 |
+
if dollars and cents:
|
32 |
+
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
33 |
+
cent_unit = "cent" if cents == 1 else "cents"
|
34 |
+
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
|
35 |
+
elif dollars:
|
36 |
+
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
37 |
+
return "%s %s" % (dollars, dollar_unit)
|
38 |
+
elif cents:
|
39 |
+
cent_unit = "cent" if cents == 1 else "cents"
|
40 |
+
return "%s %s" % (cents, cent_unit)
|
41 |
+
else:
|
42 |
+
return "zero dollars"
|
43 |
+
|
44 |
+
|
45 |
+
def _expand_ordinal(m):
|
46 |
+
return _inflect.number_to_words(m.group(0))
|
47 |
+
|
48 |
+
|
49 |
+
def _expand_number(m):
|
50 |
+
num = int(m.group(0))
|
51 |
+
if num > 1000 and num < 3000:
|
52 |
+
if num == 2000:
|
53 |
+
return "two thousand"
|
54 |
+
elif num > 2000 and num < 2010:
|
55 |
+
return "two thousand " + _inflect.number_to_words(num % 100)
|
56 |
+
elif num % 100 == 0:
|
57 |
+
return _inflect.number_to_words(num // 100) + " hundred"
|
58 |
+
else:
|
59 |
+
return _inflect.number_to_words(
|
60 |
+
num, andword="", zero="oh", group=2
|
61 |
+
).replace(", ", " ")
|
62 |
+
else:
|
63 |
+
return _inflect.number_to_words(num, andword="")
|
64 |
+
|
65 |
+
|
66 |
+
def normalize_numbers(text):
|
67 |
+
text = re.sub(_comma_number_re, _remove_commas, text)
|
68 |
+
text = re.sub(_pounds_re, r"\1 pounds", text)
|
69 |
+
text = re.sub(_dollars_re, _expand_dollars, text)
|
70 |
+
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
71 |
+
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
72 |
+
text = re.sub(_number_re, _expand_number, text)
|
73 |
+
return text
|
packages.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
libsndfile1-dev
|
2 |
+
make
|
3 |
+
autoconf
|
4 |
+
automake
|
5 |
+
libtool
|
6 |
+
gcc
|
7 |
+
pkg-config
|
pooch.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def os_cache(x):
|
2 |
+
return x
|
3 |
+
|
4 |
+
|
5 |
+
def create(*args, **kwargs):
|
6 |
+
class T:
|
7 |
+
def load_registry(self, *args, **kwargs):
|
8 |
+
return None
|
9 |
+
|
10 |
+
return T()
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
inflect
|
2 |
+
jax
|
3 |
+
jaxlib
|
4 |
+
jinja2
|
5 |
+
librosa
|
6 |
+
numpy
|
7 |
+
pax3
|
8 |
+
pyyaml
|
9 |
+
toml
|
10 |
+
unidecode
|
11 |
+
phonemizer
|
12 |
+
gradio==3.42.0
|
tacotron.py
ADDED
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Tacotron + stepwise monotonic attention
|
3 |
+
"""
|
4 |
+
|
5 |
+
import jax
|
6 |
+
import jax.numpy as jnp
|
7 |
+
import pax
|
8 |
+
|
9 |
+
|
10 |
+
def conv_block(in_ft, out_ft, kernel_size, activation_fn, use_dropout):
|
11 |
+
"""
|
12 |
+
Conv >> LayerNorm >> activation >> Dropout
|
13 |
+
"""
|
14 |
+
f = pax.Sequential(
|
15 |
+
pax.Conv1D(in_ft, out_ft, kernel_size, with_bias=False),
|
16 |
+
pax.LayerNorm(out_ft, -1, True, True),
|
17 |
+
)
|
18 |
+
if activation_fn is not None:
|
19 |
+
f >>= activation_fn
|
20 |
+
if use_dropout:
|
21 |
+
f >>= pax.Dropout(0.5)
|
22 |
+
return f
|
23 |
+
|
24 |
+
|
25 |
+
class HighwayBlock(pax.Module):
|
26 |
+
"""
|
27 |
+
Highway block
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self, dim: int) -> None:
|
31 |
+
super().__init__()
|
32 |
+
self.dim = dim
|
33 |
+
self.fc = pax.Linear(dim, 2 * dim)
|
34 |
+
|
35 |
+
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
|
36 |
+
t, h = jnp.split(self.fc(x), 2, axis=-1)
|
37 |
+
t = jax.nn.sigmoid(t - 1.0) # bias toward keeping x
|
38 |
+
h = jax.nn.relu(h)
|
39 |
+
x = x * (1.0 - t) + h * t
|
40 |
+
return x
|
41 |
+
|
42 |
+
|
43 |
+
class BiGRU(pax.Module):
|
44 |
+
"""
|
45 |
+
Bidirectional GRU
|
46 |
+
"""
|
47 |
+
|
48 |
+
def __init__(self, dim):
|
49 |
+
super().__init__()
|
50 |
+
|
51 |
+
self.rnn_fwd = pax.GRU(dim, dim)
|
52 |
+
self.rnn_bwd = pax.GRU(dim, dim)
|
53 |
+
|
54 |
+
def __call__(self, x, reset_masks):
|
55 |
+
N = x.shape[0]
|
56 |
+
x_fwd = x
|
57 |
+
x_bwd = jnp.flip(x, axis=1)
|
58 |
+
x_fwd_states = self.rnn_fwd.initial_state(N)
|
59 |
+
x_bwd_states = self.rnn_bwd.initial_state(N)
|
60 |
+
x_fwd_states, x_fwd = pax.scan(
|
61 |
+
self.rnn_fwd, x_fwd_states, x_fwd, time_major=False
|
62 |
+
)
|
63 |
+
|
64 |
+
reset_masks = jnp.flip(reset_masks, axis=1)
|
65 |
+
x_bwd_states0 = x_bwd_states
|
66 |
+
|
67 |
+
def rnn_reset_core(prev, inputs):
|
68 |
+
x, reset_mask = inputs
|
69 |
+
|
70 |
+
def reset_state(x0, xt):
|
71 |
+
return jnp.where(reset_mask, x0, xt)
|
72 |
+
|
73 |
+
state, _ = self.rnn_bwd(prev, x)
|
74 |
+
state = jax.tree_map(reset_state, x_bwd_states0, state)
|
75 |
+
return state, state.hidden
|
76 |
+
|
77 |
+
x_bwd_states, x_bwd = pax.scan(
|
78 |
+
rnn_reset_core, x_bwd_states, (x_bwd, reset_masks), time_major=False
|
79 |
+
)
|
80 |
+
x_bwd = jnp.flip(x_bwd, axis=1)
|
81 |
+
x = jnp.concatenate((x_fwd, x_bwd), axis=-1)
|
82 |
+
return x
|
83 |
+
|
84 |
+
|
85 |
+
class CBHG(pax.Module):
|
86 |
+
"""
|
87 |
+
Conv Bank >> Highway net >> GRU
|
88 |
+
"""
|
89 |
+
|
90 |
+
def __init__(self, dim):
|
91 |
+
super().__init__()
|
92 |
+
self.convs = [conv_block(dim, dim, i, jax.nn.relu, False) for i in range(1, 17)]
|
93 |
+
self.conv_projection_1 = conv_block(16 * dim, dim, 3, jax.nn.relu, False)
|
94 |
+
self.conv_projection_2 = conv_block(dim, dim, 3, None, False)
|
95 |
+
|
96 |
+
self.highway = pax.Sequential(
|
97 |
+
HighwayBlock(dim), HighwayBlock(dim), HighwayBlock(dim), HighwayBlock(dim)
|
98 |
+
)
|
99 |
+
self.rnn = BiGRU(dim)
|
100 |
+
|
101 |
+
def __call__(self, x, x_mask):
|
102 |
+
conv_input = x * x_mask
|
103 |
+
fts = [f(conv_input) for f in self.convs]
|
104 |
+
residual = jnp.concatenate(fts, axis=-1)
|
105 |
+
residual = pax.max_pool(residual, 2, 1, "SAME", -1)
|
106 |
+
residual = self.conv_projection_1(residual * x_mask)
|
107 |
+
residual = self.conv_projection_2(residual * x_mask)
|
108 |
+
x = x + residual
|
109 |
+
x = self.highway(x)
|
110 |
+
x = self.rnn(x * x_mask, reset_masks=1 - x_mask)
|
111 |
+
return x * x_mask
|
112 |
+
|
113 |
+
|
114 |
+
class PreNet(pax.Module):
|
115 |
+
"""
|
116 |
+
Linear >> relu >> dropout >> Linear >> relu >> dropout
|
117 |
+
"""
|
118 |
+
|
119 |
+
def __init__(self, input_dim, hidden_dim, output_dim, always_dropout=True):
|
120 |
+
super().__init__()
|
121 |
+
self.fc1 = pax.Linear(input_dim, hidden_dim)
|
122 |
+
self.fc2 = pax.Linear(hidden_dim, output_dim)
|
123 |
+
self.rng_seq = pax.RngSeq()
|
124 |
+
self.always_dropout = always_dropout
|
125 |
+
|
126 |
+
def __call__(self, x, k1=None, k2=None):
|
127 |
+
x = self.fc1(x)
|
128 |
+
x = jax.nn.relu(x)
|
129 |
+
if self.always_dropout or self.training:
|
130 |
+
if k1 is None:
|
131 |
+
k1 = self.rng_seq.next_rng_key()
|
132 |
+
x = pax.dropout(k1, 0.5, x)
|
133 |
+
x = self.fc2(x)
|
134 |
+
x = jax.nn.relu(x)
|
135 |
+
if self.always_dropout or self.training:
|
136 |
+
if k2 is None:
|
137 |
+
k2 = self.rng_seq.next_rng_key()
|
138 |
+
x = pax.dropout(k2, 0.5, x)
|
139 |
+
return x
|
140 |
+
|
141 |
+
|
142 |
+
class Tacotron(pax.Module):
|
143 |
+
"""
|
144 |
+
Tacotron TTS model.
|
145 |
+
|
146 |
+
It uses stepwise monotonic attention for robust attention.
|
147 |
+
"""
|
148 |
+
|
149 |
+
def __init__(
|
150 |
+
self,
|
151 |
+
mel_dim: int,
|
152 |
+
attn_bias,
|
153 |
+
rr,
|
154 |
+
max_rr,
|
155 |
+
mel_min,
|
156 |
+
sigmoid_noise,
|
157 |
+
pad_token,
|
158 |
+
prenet_dim,
|
159 |
+
attn_hidden_dim,
|
160 |
+
attn_rnn_dim,
|
161 |
+
rnn_dim,
|
162 |
+
postnet_dim,
|
163 |
+
text_dim,
|
164 |
+
):
|
165 |
+
"""
|
166 |
+
New Tacotron model
|
167 |
+
|
168 |
+
Args:
|
169 |
+
mel_dim (int): dimension of log mel-spectrogram features.
|
170 |
+
attn_bias (float): control how "slow" the attention will
|
171 |
+
move forward at initialization.
|
172 |
+
rr (int): the reduction factor.
|
173 |
+
Number of predicted frame at each time step. Default is 2.
|
174 |
+
max_rr (int): max value of rr.
|
175 |
+
mel_min (float): the minimum value of mel features.
|
176 |
+
The <go> frame is filled by `log(mel_min)` values.
|
177 |
+
sigmoid_noise (float): the variance of gaussian noise added
|
178 |
+
to attention scores in training.
|
179 |
+
pad_token (int): the pad value at the end of text sequences.
|
180 |
+
prenet_dim (int): dimension of prenet output.
|
181 |
+
attn_hidden_dim (int): dimension of attention hidden vectors.
|
182 |
+
attn_rnn_dim (int): number of cells in the attention RNN.
|
183 |
+
rnn_dim (int): number of cells in the decoder RNNs.
|
184 |
+
postnet_dim (int): number of features in the postnet convolutions.
|
185 |
+
text_dim (int): dimension of text embedding vectors.
|
186 |
+
"""
|
187 |
+
super().__init__()
|
188 |
+
self.text_dim = text_dim
|
189 |
+
assert rr <= max_rr
|
190 |
+
self.rr = rr
|
191 |
+
self.max_rr = max_rr
|
192 |
+
self.mel_dim = mel_dim
|
193 |
+
self.mel_min = mel_min
|
194 |
+
self.sigmoid_noise = sigmoid_noise
|
195 |
+
self.pad_token = pad_token
|
196 |
+
self.prenet_dim = prenet_dim
|
197 |
+
|
198 |
+
# encoder submodules
|
199 |
+
self.encoder_embed = pax.Embed(256, text_dim)
|
200 |
+
self.encoder_pre_net = PreNet(text_dim, 256, prenet_dim, always_dropout=True)
|
201 |
+
self.encoder_cbhg = CBHG(prenet_dim)
|
202 |
+
|
203 |
+
# random key generator
|
204 |
+
self.rng_seq = pax.RngSeq()
|
205 |
+
|
206 |
+
# pre-net
|
207 |
+
self.decoder_pre_net = PreNet(mel_dim, 256, prenet_dim, always_dropout=True)
|
208 |
+
|
209 |
+
# decoder submodules
|
210 |
+
self.attn_rnn = pax.LSTM(prenet_dim + prenet_dim * 2, attn_rnn_dim)
|
211 |
+
self.text_key_fc = pax.Linear(prenet_dim * 2, attn_hidden_dim, with_bias=True)
|
212 |
+
self.attn_query_fc = pax.Linear(attn_rnn_dim, attn_hidden_dim, with_bias=False)
|
213 |
+
|
214 |
+
self.attn_V = pax.Linear(attn_hidden_dim, 1, with_bias=False)
|
215 |
+
self.attn_V_weight_norm = jnp.array(1.0 / jnp.sqrt(attn_hidden_dim))
|
216 |
+
self.attn_V_bias = jnp.array(attn_bias)
|
217 |
+
self.attn_log = jnp.zeros((1,))
|
218 |
+
self.decoder_input = pax.Linear(attn_rnn_dim + 2 * prenet_dim, rnn_dim)
|
219 |
+
self.decoder_rnn1 = pax.LSTM(rnn_dim, rnn_dim)
|
220 |
+
self.decoder_rnn2 = pax.LSTM(rnn_dim, rnn_dim)
|
221 |
+
# mel + end-of-sequence token
|
222 |
+
self.output_fc = pax.Linear(rnn_dim, (mel_dim + 1) * max_rr, with_bias=True)
|
223 |
+
|
224 |
+
# post-net
|
225 |
+
self.post_net = pax.Sequential(
|
226 |
+
conv_block(mel_dim, postnet_dim, 5, jax.nn.tanh, True),
|
227 |
+
conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
|
228 |
+
conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
|
229 |
+
conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
|
230 |
+
conv_block(postnet_dim, mel_dim, 5, None, True),
|
231 |
+
)
|
232 |
+
|
233 |
+
parameters = pax.parameters_method("attn_V_weight_norm", "attn_V_bias")
|
234 |
+
|
235 |
+
def encode_text(self, text: jnp.ndarray) -> jnp.ndarray:
|
236 |
+
"""
|
237 |
+
Encode text to a sequence of real vectors
|
238 |
+
"""
|
239 |
+
N, L = text.shape
|
240 |
+
text_mask = (text != self.pad_token)[..., None]
|
241 |
+
x = self.encoder_embed(text)
|
242 |
+
x = self.encoder_pre_net(x)
|
243 |
+
x = self.encoder_cbhg(x, text_mask)
|
244 |
+
return x
|
245 |
+
|
246 |
+
def go_frame(self, batch_size: int) -> jnp.ndarray:
|
247 |
+
"""
|
248 |
+
return the go frame
|
249 |
+
"""
|
250 |
+
return jnp.ones((batch_size, self.mel_dim)) * jnp.log(self.mel_min)
|
251 |
+
|
252 |
+
def decoder_initial_state(self, N: int, L: int):
|
253 |
+
"""
|
254 |
+
setup decoder initial state
|
255 |
+
"""
|
256 |
+
attn_context = jnp.zeros((N, self.prenet_dim * 2))
|
257 |
+
attn_pr = jax.nn.one_hot(
|
258 |
+
jnp.zeros((N,), dtype=jnp.int32), num_classes=L, axis=-1
|
259 |
+
)
|
260 |
+
|
261 |
+
attn_state = (self.attn_rnn.initial_state(N), attn_context, attn_pr)
|
262 |
+
decoder_rnn_states = (
|
263 |
+
self.decoder_rnn1.initial_state(N),
|
264 |
+
self.decoder_rnn2.initial_state(N),
|
265 |
+
)
|
266 |
+
return attn_state, decoder_rnn_states
|
267 |
+
|
268 |
+
def monotonic_attention(self, prev_state, inputs, envs):
|
269 |
+
"""
|
270 |
+
Stepwise monotonic attention
|
271 |
+
"""
|
272 |
+
attn_rnn_state, attn_context, prev_attn_pr = prev_state
|
273 |
+
x, attn_rng_key = inputs
|
274 |
+
text, text_key = envs
|
275 |
+
attn_rnn_input = jnp.concatenate((x, attn_context), axis=-1)
|
276 |
+
attn_rnn_state, attn_rnn_output = self.attn_rnn(attn_rnn_state, attn_rnn_input)
|
277 |
+
attn_query_input = attn_rnn_output
|
278 |
+
attn_query = self.attn_query_fc(attn_query_input)
|
279 |
+
attn_hidden = jnp.tanh(attn_query[:, None, :] + text_key)
|
280 |
+
score = self.attn_V(attn_hidden)
|
281 |
+
score = jnp.squeeze(score, axis=-1)
|
282 |
+
weight_norm = jnp.linalg.norm(self.attn_V.weight)
|
283 |
+
score = score * (self.attn_V_weight_norm / weight_norm)
|
284 |
+
score = score + self.attn_V_bias
|
285 |
+
noise = jax.random.normal(attn_rng_key, score.shape) * self.sigmoid_noise
|
286 |
+
pr_stay = jax.nn.sigmoid(score + noise)
|
287 |
+
pr_move = 1.0 - pr_stay
|
288 |
+
pr_new_location = pr_move * prev_attn_pr
|
289 |
+
pr_new_location = jnp.pad(
|
290 |
+
pr_new_location[:, :-1], ((0, 0), (1, 0)), constant_values=0
|
291 |
+
)
|
292 |
+
attn_pr = pr_stay * prev_attn_pr + pr_new_location
|
293 |
+
attn_context = jnp.einsum("NL,NLD->ND", attn_pr, text)
|
294 |
+
new_state = (attn_rnn_state, attn_context, attn_pr)
|
295 |
+
return new_state, attn_rnn_output
|
296 |
+
|
297 |
+
def zoneout_lstm(self, lstm_core, rng_key, zoneout_pr=0.1):
|
298 |
+
"""
|
299 |
+
Return a zoneout lstm core.
|
300 |
+
|
301 |
+
It will zoneout the new hidden states and keep the new cell states unchanged.
|
302 |
+
"""
|
303 |
+
|
304 |
+
def core(state, x):
|
305 |
+
new_state, _ = lstm_core(state, x)
|
306 |
+
h_old = state.hidden
|
307 |
+
h_new = new_state.hidden
|
308 |
+
mask = jax.random.bernoulli(rng_key, zoneout_pr, h_old.shape)
|
309 |
+
h_new = h_old * mask + h_new * (1.0 - mask)
|
310 |
+
return pax.LSTMState(h_new, new_state.cell), h_new
|
311 |
+
|
312 |
+
return core
|
313 |
+
|
314 |
+
def decoder_step(
|
315 |
+
self,
|
316 |
+
attn_state,
|
317 |
+
decoder_rnn_states,
|
318 |
+
rng_key,
|
319 |
+
mel,
|
320 |
+
text,
|
321 |
+
text_key,
|
322 |
+
call_pre_net=False,
|
323 |
+
):
|
324 |
+
"""
|
325 |
+
One decoder step
|
326 |
+
"""
|
327 |
+
if call_pre_net:
|
328 |
+
k1, k2, zk1, zk2, rng_key, rng_key_next = jax.random.split(rng_key, 6)
|
329 |
+
mel = self.decoder_pre_net(mel, k1, k2)
|
330 |
+
else:
|
331 |
+
zk1, zk2, rng_key, rng_key_next = jax.random.split(rng_key, 4)
|
332 |
+
attn_inputs = (mel, rng_key)
|
333 |
+
attn_envs = (text, text_key)
|
334 |
+
attn_state, attn_rnn_output = self.monotonic_attention(
|
335 |
+
attn_state, attn_inputs, attn_envs
|
336 |
+
)
|
337 |
+
(_, attn_context, attn_pr) = attn_state
|
338 |
+
(decoder_rnn_state1, decoder_rnn_state2) = decoder_rnn_states
|
339 |
+
decoder_rnn1_input = jnp.concatenate((attn_rnn_output, attn_context), axis=-1)
|
340 |
+
decoder_rnn1_input = self.decoder_input(decoder_rnn1_input)
|
341 |
+
decoder_rnn1 = self.zoneout_lstm(self.decoder_rnn1, zk1)
|
342 |
+
decoder_rnn_state1, decoder_rnn_output1 = decoder_rnn1(
|
343 |
+
decoder_rnn_state1, decoder_rnn1_input
|
344 |
+
)
|
345 |
+
decoder_rnn2_input = decoder_rnn1_input + decoder_rnn_output1
|
346 |
+
decoder_rnn2 = self.zoneout_lstm(self.decoder_rnn2, zk2)
|
347 |
+
decoder_rnn_state2, decoder_rnn_output2 = decoder_rnn2(
|
348 |
+
decoder_rnn_state2, decoder_rnn2_input
|
349 |
+
)
|
350 |
+
x = decoder_rnn1_input + decoder_rnn_output1 + decoder_rnn_output2
|
351 |
+
decoder_rnn_states = (decoder_rnn_state1, decoder_rnn_state2)
|
352 |
+
return attn_state, decoder_rnn_states, rng_key_next, x, attn_pr[0]
|
353 |
+
|
354 |
+
@jax.jit
|
355 |
+
def inference_step(
|
356 |
+
self, attn_state, decoder_rnn_states, rng_key, mel, text, text_key
|
357 |
+
):
|
358 |
+
"""one inference step"""
|
359 |
+
attn_state, decoder_rnn_states, rng_key, x, _ = self.decoder_step(
|
360 |
+
attn_state,
|
361 |
+
decoder_rnn_states,
|
362 |
+
rng_key,
|
363 |
+
mel,
|
364 |
+
text,
|
365 |
+
text_key,
|
366 |
+
call_pre_net=True,
|
367 |
+
)
|
368 |
+
x = self.output_fc(x)
|
369 |
+
N, D2 = x.shape
|
370 |
+
x = jnp.reshape(x, (N, self.max_rr, D2 // self.max_rr))
|
371 |
+
x = x[:, : self.rr, :]
|
372 |
+
x = jnp.reshape(x, (N, self.rr, -1))
|
373 |
+
mel = x[..., :-1]
|
374 |
+
eos_logit = x[..., -1]
|
375 |
+
eos_pr = jax.nn.sigmoid(eos_logit[0, -1])
|
376 |
+
eos_pr = jnp.where(eos_pr < 0.1, 0.0, eos_pr)
|
377 |
+
rng_key, eos_rng_key = jax.random.split(rng_key)
|
378 |
+
eos = jax.random.bernoulli(eos_rng_key, p=eos_pr)
|
379 |
+
return attn_state, decoder_rnn_states, rng_key, (mel, eos)
|
380 |
+
|
381 |
+
def inference(self, text, seed=42, max_len=1000):
|
382 |
+
"""
|
383 |
+
text to mel
|
384 |
+
"""
|
385 |
+
text = self.encode_text(text)
|
386 |
+
text_key = self.text_key_fc(text)
|
387 |
+
N, L, D = text.shape
|
388 |
+
assert N == 1
|
389 |
+
mel = self.go_frame(N)
|
390 |
+
|
391 |
+
attn_state, decoder_rnn_states = self.decoder_initial_state(N, L)
|
392 |
+
rng_key = jax.random.PRNGKey(seed)
|
393 |
+
mels = []
|
394 |
+
count = 0
|
395 |
+
while True:
|
396 |
+
count = count + 1
|
397 |
+
attn_state, decoder_rnn_states, rng_key, (mel, eos) = self.inference_step(
|
398 |
+
attn_state, decoder_rnn_states, rng_key, mel, text, text_key
|
399 |
+
)
|
400 |
+
mels.append(mel)
|
401 |
+
if eos.item() or count > max_len:
|
402 |
+
break
|
403 |
+
|
404 |
+
mel = mel[:, -1, :]
|
405 |
+
|
406 |
+
mels = jnp.concatenate(mels, axis=1)
|
407 |
+
mel = mel + self.post_net(mel)
|
408 |
+
return mels
|
409 |
+
|
410 |
+
def decode(self, mel, text):
|
411 |
+
"""
|
412 |
+
Attention mechanism + Decoder
|
413 |
+
"""
|
414 |
+
text_key = self.text_key_fc(text)
|
415 |
+
|
416 |
+
def scan_fn(prev_states, inputs):
|
417 |
+
attn_state, decoder_rnn_states = prev_states
|
418 |
+
x, rng_key = inputs
|
419 |
+
attn_state, decoder_rnn_states, _, output, attn_pr = self.decoder_step(
|
420 |
+
attn_state, decoder_rnn_states, rng_key, x, text, text_key
|
421 |
+
)
|
422 |
+
states = (attn_state, decoder_rnn_states)
|
423 |
+
return states, (output, attn_pr)
|
424 |
+
|
425 |
+
N, L, D = text.shape
|
426 |
+
decoder_states = self.decoder_initial_state(N, L)
|
427 |
+
rng_keys = self.rng_seq.next_rng_key(mel.shape[1])
|
428 |
+
rng_keys = jnp.stack(rng_keys, axis=1)
|
429 |
+
decoder_states, (x, attn_log) = pax.scan(
|
430 |
+
scan_fn,
|
431 |
+
decoder_states,
|
432 |
+
(mel, rng_keys),
|
433 |
+
time_major=False,
|
434 |
+
)
|
435 |
+
self.attn_log = attn_log
|
436 |
+
del decoder_states
|
437 |
+
x = self.output_fc(x)
|
438 |
+
|
439 |
+
N, T2, D2 = x.shape
|
440 |
+
x = jnp.reshape(x, (N, T2, self.max_rr, D2 // self.max_rr))
|
441 |
+
x = x[:, :, : self.rr, :]
|
442 |
+
x = jnp.reshape(x, (N, T2 * self.rr, -1))
|
443 |
+
mel = x[..., :-1]
|
444 |
+
eos = x[..., -1]
|
445 |
+
return mel, eos
|
446 |
+
|
447 |
+
def __call__(self, mel: jnp.ndarray, text: jnp.ndarray):
|
448 |
+
text = self.encode_text(text)
|
449 |
+
mel = self.decoder_pre_net(mel)
|
450 |
+
mel, eos = self.decode(mel, text)
|
451 |
+
return mel, mel + self.post_net(mel), eos
|
tacotron.toml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tacotron]
|
2 |
+
|
3 |
+
# training
|
4 |
+
BATCH_SIZE = 64
|
5 |
+
LR=1024e-6 # learning rate
|
6 |
+
MODEL_PREFIX = "mono_tts_cbhg_small"
|
7 |
+
LOG_DIR = "./logs"
|
8 |
+
CKPT_DIR = "./ckpts"
|
9 |
+
USE_MP = false # use mixed-precision training
|
10 |
+
|
11 |
+
# data
|
12 |
+
TF_DATA_DIR = "./tf_data" # tensorflow data directory
|
13 |
+
TF_GTA_DATA_DIR = "./tf_gta_data" # tf gta data directory
|
14 |
+
SAMPLE_RATE = 24000 # convert to this sample rate if needed
|
15 |
+
MEL_DIM = 80 # the dimension of melspectrogram features
|
16 |
+
MEL_MIN = 1e-5
|
17 |
+
PAD = "_" # padding character
|
18 |
+
PAD_TOKEN = 0
|
19 |
+
END_CHARACTER = "■" # to signal the end of the transcript
|
20 |
+
TEST_DATA_SIZE = 1024
|
21 |
+
|
22 |
+
# model
|
23 |
+
RR = 1 # reduction factor
|
24 |
+
MAX_RR=2
|
25 |
+
ATTN_BIAS = 0.0 # control how slow the attention moves forward
|
26 |
+
SIGMOID_NOISE = 2.0
|
27 |
+
PRENET_DIM = 128
|
28 |
+
TEXT_DIM = 256
|
29 |
+
RNN_DIM = 512
|
30 |
+
ATTN_RNN_DIM = 256
|
31 |
+
ATTN_HIDDEN_DIM = 128
|
32 |
+
POSTNET_DIM = 512
|
text.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# """ from https://github.com/keithito/tacotron """
|
2 |
+
|
3 |
+
# """
|
4 |
+
# Cleaners are transformations that run over the input text at both training and eval time.
|
5 |
+
|
6 |
+
# Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
7 |
+
# hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
8 |
+
# 1. "english_cleaners" for English text
|
9 |
+
# 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
10 |
+
# the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
11 |
+
# 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
12 |
+
# the symbols in symbols.py to match your data).
|
13 |
+
# """
|
14 |
+
|
15 |
+
# import re
|
16 |
+
# from mynumbers import normalize_numbers
|
17 |
+
# from unidecode import unidecode
|
18 |
+
|
19 |
+
# # Regular expression matching whitespace:
|
20 |
+
# _whitespace_re = re.compile(r"\s+")
|
21 |
+
|
22 |
+
# # List of (regular expression, replacement) pairs for abbreviations:
|
23 |
+
# _abbreviations = [
|
24 |
+
# (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
25 |
+
# for x in [
|
26 |
+
# ("mrs", "misess"),
|
27 |
+
# ("mr", "mister"),
|
28 |
+
# ("dr", "doctor"),
|
29 |
+
# ("st", "saint"),
|
30 |
+
# ("co", "company"),
|
31 |
+
# ("jr", "junior"),
|
32 |
+
# ("maj", "major"),
|
33 |
+
# ("gen", "general"),
|
34 |
+
# ("drs", "doctors"),
|
35 |
+
# ("rev", "reverend"),
|
36 |
+
# ("lt", "lieutenant"),
|
37 |
+
# ("hon", "honorable"),
|
38 |
+
# ("sgt", "sergeant"),
|
39 |
+
# ("capt", "captain"),
|
40 |
+
# ("esq", "esquire"),
|
41 |
+
# ("ltd", "limited"),
|
42 |
+
# ("col", "colonel"),
|
43 |
+
# ("ft", "fort"),
|
44 |
+
# ]
|
45 |
+
# ]
|
46 |
+
|
47 |
+
|
48 |
+
# def expand_abbreviations(text):
|
49 |
+
# for regex, replacement in _abbreviations:
|
50 |
+
# text = re.sub(regex, replacement, text)
|
51 |
+
# return text
|
52 |
+
|
53 |
+
|
54 |
+
# def expand_numbers(text):
|
55 |
+
# return normalize_numbers(text)
|
56 |
+
|
57 |
+
|
58 |
+
# def lowercase(text):
|
59 |
+
# return text.lower()
|
60 |
+
|
61 |
+
|
62 |
+
# def collapse_whitespace(text):
|
63 |
+
# return re.sub(_whitespace_re, " ", text)
|
64 |
+
|
65 |
+
|
66 |
+
# def convert_to_ascii(text):
|
67 |
+
# return unidecode(text)
|
68 |
+
|
69 |
+
|
70 |
+
# def basic_cleaners(text):
|
71 |
+
# """Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
72 |
+
# text = lowercase(text)
|
73 |
+
# text = collapse_whitespace(text)
|
74 |
+
# return text
|
75 |
+
|
76 |
+
|
77 |
+
# def transliteration_cleaners(text):
|
78 |
+
# """Pipeline for non-English text that transliterates to ASCII."""
|
79 |
+
# text = convert_to_ascii(text)
|
80 |
+
# text = lowercase(text)
|
81 |
+
# text = collapse_whitespace(text)
|
82 |
+
# return text
|
83 |
+
|
84 |
+
|
85 |
+
# def english_cleaners(text):
|
86 |
+
# """Pipeline for English text, including number and abbreviation expansion."""
|
87 |
+
# text = convert_to_ascii(text)
|
88 |
+
# text = lowercase(text)
|
89 |
+
# text = expand_numbers(text)
|
90 |
+
# text = expand_abbreviations(text)
|
91 |
+
# text = collapse_whitespace(text)
|
92 |
+
# return text
|
utils.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utility functions
|
3 |
+
"""
|
4 |
+
import pickle
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import pax
|
8 |
+
import toml
|
9 |
+
import yaml
|
10 |
+
|
11 |
+
from tacotron import Tacotron
|
12 |
+
|
13 |
+
|
14 |
+
def load_tacotron_config(config_file=Path("tacotron.toml")):
|
15 |
+
"""
|
16 |
+
Load the project configurations
|
17 |
+
"""
|
18 |
+
return toml.load(config_file)["tacotron"]
|
19 |
+
|
20 |
+
|
21 |
+
def load_tacotron_ckpt(net: pax.Module, optim: pax.Module, path):
|
22 |
+
"""
|
23 |
+
load checkpoint from disk
|
24 |
+
"""
|
25 |
+
with open(path, "rb") as f:
|
26 |
+
dic = pickle.load(f)
|
27 |
+
if net is not None:
|
28 |
+
net = net.load_state_dict(dic["model_state_dict"])
|
29 |
+
if optim is not None:
|
30 |
+
optim = optim.load_state_dict(dic["optim_state_dict"])
|
31 |
+
return dic["step"], net, optim
|
32 |
+
|
33 |
+
|
34 |
+
def create_tacotron_model(config):
|
35 |
+
"""
|
36 |
+
return a random initialized Tacotron model
|
37 |
+
"""
|
38 |
+
return Tacotron(
|
39 |
+
mel_dim=config["MEL_DIM"],
|
40 |
+
attn_bias=config["ATTN_BIAS"],
|
41 |
+
rr=config["RR"],
|
42 |
+
max_rr=config["MAX_RR"],
|
43 |
+
mel_min=config["MEL_MIN"],
|
44 |
+
sigmoid_noise=config["SIGMOID_NOISE"],
|
45 |
+
pad_token=config["PAD_TOKEN"],
|
46 |
+
prenet_dim=config["PRENET_DIM"],
|
47 |
+
attn_hidden_dim=config["ATTN_HIDDEN_DIM"],
|
48 |
+
attn_rnn_dim=config["ATTN_RNN_DIM"],
|
49 |
+
rnn_dim=config["RNN_DIM"],
|
50 |
+
postnet_dim=config["POSTNET_DIM"],
|
51 |
+
text_dim=config["TEXT_DIM"],
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
def load_wavegru_config(config_file):
|
56 |
+
"""
|
57 |
+
Load project configurations
|
58 |
+
"""
|
59 |
+
with open(config_file, "r", encoding="utf-8") as f:
|
60 |
+
return yaml.safe_load(f)
|
61 |
+
|
62 |
+
|
63 |
+
def load_wavegru_ckpt(net, optim, ckpt_file):
|
64 |
+
"""
|
65 |
+
load training checkpoint from file
|
66 |
+
"""
|
67 |
+
with open(ckpt_file, "rb") as f:
|
68 |
+
dic = pickle.load(f)
|
69 |
+
|
70 |
+
if net is not None:
|
71 |
+
net = net.load_state_dict(dic["net_state_dict"])
|
72 |
+
if optim is not None:
|
73 |
+
optim = optim.load_state_dict(dic["optim_state_dict"])
|
74 |
+
return dic["step"], net, optim
|
wavegru.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
WaveGRU model: melspectrogram => mu-law encoded waveform
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Tuple
|
6 |
+
|
7 |
+
import jax
|
8 |
+
import jax.numpy as jnp
|
9 |
+
import pax
|
10 |
+
from pax import GRUState
|
11 |
+
from tqdm.cli import tqdm
|
12 |
+
|
13 |
+
|
14 |
+
class ReLU(pax.Module):
|
15 |
+
def __call__(self, x):
|
16 |
+
return jax.nn.relu(x)
|
17 |
+
|
18 |
+
|
19 |
+
def dilated_residual_conv_block(dim, kernel, stride, dilation):
|
20 |
+
"""
|
21 |
+
Use dilated convs to enlarge the receptive field
|
22 |
+
"""
|
23 |
+
return pax.Sequential(
|
24 |
+
pax.Conv1D(dim, dim, kernel, stride, dilation, "VALID", with_bias=False),
|
25 |
+
pax.LayerNorm(dim, -1, True, True),
|
26 |
+
ReLU(),
|
27 |
+
pax.Conv1D(dim, dim, 1, 1, 1, "VALID", with_bias=False),
|
28 |
+
pax.LayerNorm(dim, -1, True, True),
|
29 |
+
ReLU(),
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
def tile_1d(x, factor):
|
34 |
+
"""
|
35 |
+
Tile tensor of shape N, L, D into N, L*factor, D
|
36 |
+
"""
|
37 |
+
N, L, D = x.shape
|
38 |
+
x = x[:, :, None, :]
|
39 |
+
x = jnp.tile(x, (1, 1, factor, 1))
|
40 |
+
x = jnp.reshape(x, (N, L * factor, D))
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
def up_block(in_dim, out_dim, factor, relu=True):
|
45 |
+
"""
|
46 |
+
Tile >> Conv >> BatchNorm >> ReLU
|
47 |
+
"""
|
48 |
+
f = pax.Sequential(
|
49 |
+
lambda x: tile_1d(x, factor),
|
50 |
+
pax.Conv1D(
|
51 |
+
in_dim, out_dim, 2 * factor, stride=1, padding="VALID", with_bias=False
|
52 |
+
),
|
53 |
+
pax.LayerNorm(out_dim, -1, True, True),
|
54 |
+
)
|
55 |
+
if relu:
|
56 |
+
f >>= ReLU()
|
57 |
+
return f
|
58 |
+
|
59 |
+
|
60 |
+
class Upsample(pax.Module):
|
61 |
+
"""
|
62 |
+
Upsample melspectrogram to match raw audio sample rate.
|
63 |
+
"""
|
64 |
+
|
65 |
+
def __init__(
|
66 |
+
self, input_dim, hidden_dim, rnn_dim, upsample_factors, has_linear_output=False
|
67 |
+
):
|
68 |
+
super().__init__()
|
69 |
+
self.input_conv = pax.Sequential(
|
70 |
+
pax.Conv1D(input_dim, hidden_dim, 1, with_bias=False),
|
71 |
+
pax.LayerNorm(hidden_dim, -1, True, True),
|
72 |
+
)
|
73 |
+
self.upsample_factors = upsample_factors
|
74 |
+
self.dilated_convs = [
|
75 |
+
dilated_residual_conv_block(hidden_dim, 3, 1, 2**i) for i in range(5)
|
76 |
+
]
|
77 |
+
self.up_factors = upsample_factors[:-1]
|
78 |
+
self.up_blocks = [
|
79 |
+
up_block(hidden_dim, hidden_dim, x) for x in self.up_factors[:-1]
|
80 |
+
]
|
81 |
+
self.up_blocks.append(
|
82 |
+
up_block(
|
83 |
+
hidden_dim,
|
84 |
+
hidden_dim if has_linear_output else 3 * rnn_dim,
|
85 |
+
self.up_factors[-1],
|
86 |
+
relu=False,
|
87 |
+
)
|
88 |
+
)
|
89 |
+
if has_linear_output:
|
90 |
+
self.x2zrh_fc = pax.Linear(hidden_dim, rnn_dim * 3)
|
91 |
+
self.has_linear_output = has_linear_output
|
92 |
+
|
93 |
+
self.final_tile = upsample_factors[-1]
|
94 |
+
|
95 |
+
def __call__(self, x, no_repeat=False):
|
96 |
+
x = self.input_conv(x)
|
97 |
+
for residual in self.dilated_convs:
|
98 |
+
y = residual(x)
|
99 |
+
pad = (x.shape[1] - y.shape[1]) // 2
|
100 |
+
x = x[:, pad:-pad, :] + y
|
101 |
+
|
102 |
+
for f in self.up_blocks:
|
103 |
+
x = f(x)
|
104 |
+
|
105 |
+
if self.has_linear_output:
|
106 |
+
x = self.x2zrh_fc(x)
|
107 |
+
|
108 |
+
if no_repeat:
|
109 |
+
return x
|
110 |
+
x = tile_1d(x, self.final_tile)
|
111 |
+
return x
|
112 |
+
|
113 |
+
|
114 |
+
class GRU(pax.Module):
|
115 |
+
"""
|
116 |
+
A customized GRU module.
|
117 |
+
"""
|
118 |
+
|
119 |
+
input_dim: int
|
120 |
+
hidden_dim: int
|
121 |
+
|
122 |
+
def __init__(self, hidden_dim: int):
|
123 |
+
super().__init__()
|
124 |
+
self.hidden_dim = hidden_dim
|
125 |
+
self.h_zrh_fc = pax.Linear(
|
126 |
+
hidden_dim,
|
127 |
+
hidden_dim * 3,
|
128 |
+
w_init=jax.nn.initializers.variance_scaling(
|
129 |
+
1, "fan_out", "truncated_normal"
|
130 |
+
),
|
131 |
+
)
|
132 |
+
|
133 |
+
def initial_state(self, batch_size: int) -> GRUState:
|
134 |
+
"""Create an all zeros initial state."""
|
135 |
+
return GRUState(jnp.zeros((batch_size, self.hidden_dim), dtype=jnp.float32))
|
136 |
+
|
137 |
+
def __call__(self, state: GRUState, x) -> Tuple[GRUState, jnp.ndarray]:
|
138 |
+
hidden = state.hidden
|
139 |
+
x_zrh = x
|
140 |
+
h_zrh = self.h_zrh_fc(hidden)
|
141 |
+
x_zr, x_h = jnp.split(x_zrh, [2 * self.hidden_dim], axis=-1)
|
142 |
+
h_zr, h_h = jnp.split(h_zrh, [2 * self.hidden_dim], axis=-1)
|
143 |
+
|
144 |
+
zr = x_zr + h_zr
|
145 |
+
zr = jax.nn.sigmoid(zr)
|
146 |
+
z, r = jnp.split(zr, 2, axis=-1)
|
147 |
+
|
148 |
+
h_hat = x_h + r * h_h
|
149 |
+
h_hat = jnp.tanh(h_hat)
|
150 |
+
|
151 |
+
h = (1 - z) * hidden + z * h_hat
|
152 |
+
return GRUState(h), h
|
153 |
+
|
154 |
+
|
155 |
+
class Pruner(pax.Module):
|
156 |
+
"""
|
157 |
+
Base class for pruners
|
158 |
+
"""
|
159 |
+
|
160 |
+
def compute_sparsity(self, step):
|
161 |
+
t = jnp.power(1 - (step * 1.0 - 1_000) / 200_000, 3)
|
162 |
+
z = 0.95 * jnp.clip(1.0 - t, a_min=0, a_max=1)
|
163 |
+
return z
|
164 |
+
|
165 |
+
def prune(self, step, weights):
|
166 |
+
"""
|
167 |
+
Return a mask
|
168 |
+
"""
|
169 |
+
z = self.compute_sparsity(step)
|
170 |
+
x = weights
|
171 |
+
H, W = x.shape
|
172 |
+
x = x.reshape(H // 4, 4, W // 4, 4)
|
173 |
+
x = jnp.abs(x)
|
174 |
+
x = jnp.sum(x, axis=(1, 3), keepdims=True)
|
175 |
+
q = jnp.quantile(jnp.reshape(x, (-1,)), z)
|
176 |
+
x = x >= q
|
177 |
+
x = jnp.tile(x, (1, 4, 1, 4))
|
178 |
+
x = jnp.reshape(x, (H, W))
|
179 |
+
return x
|
180 |
+
|
181 |
+
|
182 |
+
class GRUPruner(Pruner):
|
183 |
+
def __init__(self, gru):
|
184 |
+
super().__init__()
|
185 |
+
self.h_zrh_fc_mask = jnp.ones_like(gru.h_zrh_fc.weight) == 1
|
186 |
+
|
187 |
+
def __call__(self, gru: pax.GRU):
|
188 |
+
"""
|
189 |
+
Apply mask after an optimization step
|
190 |
+
"""
|
191 |
+
zrh_masked_weights = jnp.where(self.h_zrh_fc_mask, gru.h_zrh_fc.weight, 0)
|
192 |
+
gru = gru.replace_node(gru.h_zrh_fc.weight, zrh_masked_weights)
|
193 |
+
return gru
|
194 |
+
|
195 |
+
def update_mask(self, step, gru: pax.GRU):
|
196 |
+
"""
|
197 |
+
Update internal masks
|
198 |
+
"""
|
199 |
+
z_weight, r_weight, h_weight = jnp.split(gru.h_zrh_fc.weight, 3, axis=1)
|
200 |
+
z_mask = self.prune(step, z_weight)
|
201 |
+
r_mask = self.prune(step, r_weight)
|
202 |
+
h_mask = self.prune(step, h_weight)
|
203 |
+
self.h_zrh_fc_mask *= jnp.concatenate((z_mask, r_mask, h_mask), axis=1)
|
204 |
+
|
205 |
+
|
206 |
+
class LinearPruner(Pruner):
|
207 |
+
def __init__(self, linear):
|
208 |
+
super().__init__()
|
209 |
+
self.mask = jnp.ones_like(linear.weight) == 1
|
210 |
+
|
211 |
+
def __call__(self, linear: pax.Linear):
|
212 |
+
"""
|
213 |
+
Apply mask after an optimization step
|
214 |
+
"""
|
215 |
+
return linear.replace(weight=jnp.where(self.mask, linear.weight, 0))
|
216 |
+
|
217 |
+
def update_mask(self, step, linear: pax.Linear):
|
218 |
+
"""
|
219 |
+
Update internal masks
|
220 |
+
"""
|
221 |
+
self.mask *= self.prune(step, linear.weight)
|
222 |
+
|
223 |
+
|
224 |
+
class WaveGRU(pax.Module):
|
225 |
+
"""
|
226 |
+
WaveGRU vocoder model.
|
227 |
+
"""
|
228 |
+
|
229 |
+
def __init__(
|
230 |
+
self,
|
231 |
+
mel_dim=80,
|
232 |
+
rnn_dim=1024,
|
233 |
+
upsample_factors=(5, 3, 20),
|
234 |
+
has_linear_output=False,
|
235 |
+
):
|
236 |
+
super().__init__()
|
237 |
+
self.embed = pax.Embed(256, 3 * rnn_dim)
|
238 |
+
self.upsample = Upsample(
|
239 |
+
input_dim=mel_dim,
|
240 |
+
hidden_dim=512,
|
241 |
+
rnn_dim=rnn_dim,
|
242 |
+
upsample_factors=upsample_factors,
|
243 |
+
has_linear_output=has_linear_output,
|
244 |
+
)
|
245 |
+
self.rnn = GRU(rnn_dim)
|
246 |
+
self.o1 = pax.Linear(rnn_dim, rnn_dim)
|
247 |
+
self.o2 = pax.Linear(rnn_dim, 256)
|
248 |
+
self.gru_pruner = GRUPruner(self.rnn)
|
249 |
+
self.o1_pruner = LinearPruner(self.o1)
|
250 |
+
self.o2_pruner = LinearPruner(self.o2)
|
251 |
+
|
252 |
+
def output(self, x):
|
253 |
+
x = self.o1(x)
|
254 |
+
x = jax.nn.relu(x)
|
255 |
+
x = self.o2(x)
|
256 |
+
return x
|
257 |
+
|
258 |
+
def inference(self, mel, no_gru=False, seed=42):
|
259 |
+
"""
|
260 |
+
generate waveform form melspectrogram
|
261 |
+
"""
|
262 |
+
|
263 |
+
@jax.jit
|
264 |
+
def step(rnn_state, mel, rng_key, x):
|
265 |
+
x = self.embed(x)
|
266 |
+
x = x + mel
|
267 |
+
rnn_state, x = self.rnn(rnn_state, x)
|
268 |
+
x = self.output(x)
|
269 |
+
rng_key, next_rng_key = jax.random.split(rng_key, 2)
|
270 |
+
x = jax.random.categorical(rng_key, x, axis=-1)
|
271 |
+
return rnn_state, next_rng_key, x
|
272 |
+
|
273 |
+
y = self.upsample(mel, no_repeat=no_gru)
|
274 |
+
if no_gru:
|
275 |
+
return y
|
276 |
+
x = jnp.array([127], dtype=jnp.int32)
|
277 |
+
rnn_state = self.rnn.initial_state(1)
|
278 |
+
output = []
|
279 |
+
rng_key = jax.random.PRNGKey(seed)
|
280 |
+
for i in tqdm(range(y.shape[1])):
|
281 |
+
rnn_state, rng_key, x = step(rnn_state, y[:, i], rng_key, x)
|
282 |
+
output.append(x)
|
283 |
+
x = jnp.concatenate(output, axis=0)
|
284 |
+
return x
|
285 |
+
|
286 |
+
def __call__(self, mel, x):
|
287 |
+
x = self.embed(x)
|
288 |
+
y = self.upsample(mel)
|
289 |
+
pad_left = (x.shape[1] - y.shape[1]) // 2
|
290 |
+
pad_right = x.shape[1] - y.shape[1] - pad_left
|
291 |
+
x = x[:, pad_left:-pad_right]
|
292 |
+
x = x + y
|
293 |
+
_, x = pax.scan(
|
294 |
+
self.rnn,
|
295 |
+
self.rnn.initial_state(x.shape[0]),
|
296 |
+
x,
|
297 |
+
time_major=False,
|
298 |
+
)
|
299 |
+
x = self.output(x)
|
300 |
+
return x
|
wavegru.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## dsp
|
2 |
+
sample_rate : 24000
|
3 |
+
window_length: 50.0 # ms
|
4 |
+
hop_length: 12.5 # ms
|
5 |
+
mel_min: 1.0e-5 ## need .0 to make it a float
|
6 |
+
mel_dim: 80
|
7 |
+
n_fft: 2048
|
8 |
+
|
9 |
+
## wavegru
|
10 |
+
embed_dim: 32
|
11 |
+
rnn_dim: 1024
|
12 |
+
frames_per_sequence: 67
|
13 |
+
num_pad_frames: 62
|
14 |
+
upsample_factors: [5, 3, 20]
|
wavegru_cpp.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from wavegru_mod import WaveGRU
|
3 |
+
|
4 |
+
|
5 |
+
def extract_weight_mask(net):
|
6 |
+
data = {}
|
7 |
+
data["embed_weight"] = net.embed.weight
|
8 |
+
data["gru_h_zrh_weight"] = net.rnn.h_zrh_fc.weight
|
9 |
+
data["gru_h_zrh_mask"] = net.gru_pruner.h_zrh_fc_mask
|
10 |
+
data["gru_h_zrh_bias"] = net.rnn.h_zrh_fc.bias
|
11 |
+
|
12 |
+
data["o1_weight"] = net.o1.weight
|
13 |
+
data["o1_mask"] = net.o1_pruner.mask
|
14 |
+
data["o1_bias"] = net.o1.bias
|
15 |
+
data["o2_weight"] = net.o2.weight
|
16 |
+
data["o2_mask"] = net.o2_pruner.mask
|
17 |
+
data["o2_bias"] = net.o2.bias
|
18 |
+
return data
|
19 |
+
|
20 |
+
|
21 |
+
def load_wavegru_cpp(data, repeat_factor):
|
22 |
+
"""load wavegru weight to cpp object"""
|
23 |
+
embed = data["embed_weight"]
|
24 |
+
rnn_dim = data["gru_h_zrh_bias"].shape[0] // 3
|
25 |
+
net = WaveGRU(rnn_dim, repeat_factor)
|
26 |
+
net.load_embed(embed)
|
27 |
+
|
28 |
+
m = np.ascontiguousarray(data["gru_h_zrh_weight"].T)
|
29 |
+
mask = np.ascontiguousarray(data["gru_h_zrh_mask"].T)
|
30 |
+
b = data["gru_h_zrh_bias"]
|
31 |
+
|
32 |
+
o1 = np.ascontiguousarray(data["o1_weight"].T)
|
33 |
+
masko1 = np.ascontiguousarray(data["o1_mask"].T)
|
34 |
+
o1b = data["o1_bias"]
|
35 |
+
|
36 |
+
o2 = np.ascontiguousarray(data["o2_weight"].T)
|
37 |
+
masko2 = np.ascontiguousarray(data["o2_mask"].T)
|
38 |
+
o2b = data["o2_bias"]
|
39 |
+
|
40 |
+
net.load_weights(m, mask, b, o1, masko1, o1b, o2, masko2, o2b)
|
41 |
+
|
42 |
+
return net
|
wavegru_mod.cc
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
WaveGRU:
|
3 |
+
> Embed > GRU > O1 > O2 > Sampling > ...
|
4 |
+
*/
|
5 |
+
|
6 |
+
#include <pybind11/numpy.h>
|
7 |
+
#include <pybind11/pybind11.h>
|
8 |
+
#include <pybind11/stl.h>
|
9 |
+
|
10 |
+
#include <iostream>
|
11 |
+
#include <random>
|
12 |
+
#include <vector>
|
13 |
+
|
14 |
+
#include "sparse_matmul/sparse_matmul.h"
|
15 |
+
namespace py = pybind11;
|
16 |
+
using namespace std;
|
17 |
+
|
18 |
+
using fvec = std::vector<float>;
|
19 |
+
using ivec = std::vector<int>;
|
20 |
+
using fndarray = py::array_t<float>;
|
21 |
+
using indarray = py::array_t<int>;
|
22 |
+
using mat = csrblocksparse::CsrBlockSparseMatrix<float, float, int16_t>;
|
23 |
+
using vec = csrblocksparse::CacheAlignedVector<float>;
|
24 |
+
using masked_mat = csrblocksparse::MaskedSparseMatrix<float>;
|
25 |
+
|
26 |
+
mat create_mat(int h, int w) {
|
27 |
+
auto m = masked_mat(w, h, 0.90, 4, 4, 0.0, true);
|
28 |
+
auto a = mat(m);
|
29 |
+
return a;
|
30 |
+
}
|
31 |
+
|
32 |
+
struct WaveGRU {
|
33 |
+
int hidden_dim;
|
34 |
+
int repeat_factor;
|
35 |
+
mat m;
|
36 |
+
vec b;
|
37 |
+
vec z, r, hh, zrh;
|
38 |
+
vec fco1, fco2;
|
39 |
+
vec o1b, o2b;
|
40 |
+
vec t;
|
41 |
+
vec h;
|
42 |
+
vec logits;
|
43 |
+
mat o1, o2;
|
44 |
+
std::vector<vec> embed;
|
45 |
+
|
46 |
+
WaveGRU(int hidden_dim, int repeat_factor)
|
47 |
+
: hidden_dim(hidden_dim),
|
48 |
+
repeat_factor(repeat_factor),
|
49 |
+
b(3*hidden_dim),
|
50 |
+
t(3*hidden_dim),
|
51 |
+
zrh(3*hidden_dim),
|
52 |
+
z(hidden_dim),
|
53 |
+
r(hidden_dim),
|
54 |
+
hh(hidden_dim),
|
55 |
+
fco1(hidden_dim),
|
56 |
+
fco2(256),
|
57 |
+
h(hidden_dim),
|
58 |
+
o1b(hidden_dim),
|
59 |
+
o2b(256),
|
60 |
+
logits(256) {
|
61 |
+
m = create_mat(hidden_dim, 3*hidden_dim);
|
62 |
+
o1 = create_mat(hidden_dim, hidden_dim);
|
63 |
+
o2 = create_mat(hidden_dim, 256);
|
64 |
+
embed = std::vector<vec>();
|
65 |
+
for (int i = 0; i < 256; i++) {
|
66 |
+
embed.emplace_back(hidden_dim * 3);
|
67 |
+
embed[i].FillRandom();
|
68 |
+
}
|
69 |
+
}
|
70 |
+
|
71 |
+
void load_embed(fndarray embed_weights) {
|
72 |
+
auto a_embed = embed_weights.unchecked<2>();
|
73 |
+
for (int i = 0; i < 256; i++) {
|
74 |
+
for (int j = 0; j < hidden_dim * 3; j++) embed[i][j] = a_embed(i, j);
|
75 |
+
}
|
76 |
+
}
|
77 |
+
|
78 |
+
mat load_linear(vec& bias, fndarray w, indarray mask, fndarray b) {
|
79 |
+
auto w_ptr = static_cast<float*>(w.request().ptr);
|
80 |
+
auto mask_ptr = static_cast<int*>(mask.request().ptr);
|
81 |
+
auto rb = b.unchecked<1>();
|
82 |
+
// load bias, scale by 1/4
|
83 |
+
for (int i = 0; i < rb.shape(0); i++) bias[i] = rb(i) / 4;
|
84 |
+
// load weights
|
85 |
+
masked_mat mm(w.shape(0), w.shape(1), mask_ptr, w_ptr);
|
86 |
+
mat mmm(mm);
|
87 |
+
return mmm;
|
88 |
+
}
|
89 |
+
|
90 |
+
void load_weights(fndarray m, indarray m_mask, fndarray b,
|
91 |
+
fndarray o1, indarray o1_mask,
|
92 |
+
fndarray o1b, fndarray o2,
|
93 |
+
indarray o2_mask, fndarray o2b) {
|
94 |
+
this->m = load_linear(this->b, m, m_mask, b);
|
95 |
+
this->o1 = load_linear(this->o1b, o1, o1_mask, o1b);
|
96 |
+
this->o2 = load_linear(this->o2b, o2, o2_mask, o2b);
|
97 |
+
}
|
98 |
+
|
99 |
+
std::vector<int> inference(fndarray ft, float temperature) {
|
100 |
+
auto rft = ft.unchecked<2>();
|
101 |
+
int value = 127;
|
102 |
+
std::vector<int> signal(rft.shape(0) * repeat_factor);
|
103 |
+
h.FillZero();
|
104 |
+
for (int index = 0; index < signal.size(); index++) {
|
105 |
+
m.SpMM_bias(h, b, &zrh, false);
|
106 |
+
|
107 |
+
for (int i = 0; i < 3 * hidden_dim; i++) t[i] = embed[value][i] + rft(index / repeat_factor, i);
|
108 |
+
for (int i = 0; i < hidden_dim; i++) {
|
109 |
+
z[i] = zrh[i] + t[i];
|
110 |
+
r[i] = zrh[hidden_dim + i] + t[hidden_dim + i];
|
111 |
+
}
|
112 |
+
|
113 |
+
z.Sigmoid();
|
114 |
+
r.Sigmoid();
|
115 |
+
|
116 |
+
for (int i = 0; i < hidden_dim; i++) {
|
117 |
+
hh[i] = zrh[hidden_dim * 2 + i] * r[i] + t[hidden_dim * 2 + i];
|
118 |
+
}
|
119 |
+
hh.Tanh();
|
120 |
+
for (int i = 0; i < hidden_dim; i++) {
|
121 |
+
h[i] = (1. - z[i]) * h[i] + z[i] * hh[i];
|
122 |
+
}
|
123 |
+
o1.SpMM_bias(h, o1b, &fco1, true);
|
124 |
+
o2.SpMM_bias(fco1, o2b, &fco2, false);
|
125 |
+
// auto max_logit = fco2[0];
|
126 |
+
// for (int i = 1; i <= 255; ++i) {
|
127 |
+
// max_logit = max(max_logit, fco2[i]);
|
128 |
+
// }
|
129 |
+
// float total = 0.0;
|
130 |
+
// for (int i = 0; i <= 255; ++i) {
|
131 |
+
// logits[i] = csrblocksparse::fast_exp(fco2[i] - max_logit);
|
132 |
+
// total += logits[i];
|
133 |
+
// }
|
134 |
+
// for (int i = 0; i <= 255; ++i) {
|
135 |
+
// if (logits[i] < total / 1024.0) fco2[i] = -1e9;
|
136 |
+
// }
|
137 |
+
value = fco2.Sample(temperature);
|
138 |
+
signal[index] = value;
|
139 |
+
}
|
140 |
+
return signal;
|
141 |
+
}
|
142 |
+
};
|
143 |
+
|
144 |
+
PYBIND11_MODULE(wavegru_mod, m) {
|
145 |
+
py::class_<WaveGRU>(m, "WaveGRU")
|
146 |
+
.def(py::init<int, int>())
|
147 |
+
.def("load_embed", &WaveGRU::load_embed)
|
148 |
+
.def("load_weights", &WaveGRU::load_weights)
|
149 |
+
.def("inference", &WaveGRU::inference);
|
150 |
+
}
|