ntt123 committed on
Commit
587b6c9
1 Parent(s): 6907c30
Files changed (21) hide show
  1. BUILD +44 -0
  2. Dockerfile +32 -0
  3. WORKSPACE +154 -0
  4. alphabet.txt +97 -0
  5. app.py +148 -0
  6. build_ext.sh +4 -0
  7. extract_tacotrons_model.py +5 -0
  8. extract_wavegru_model.py +5 -0
  9. inference.py +90 -0
  10. mynumbers.py +73 -0
  11. packages.txt +7 -0
  12. pooch.py +10 -0
  13. requirements.txt +12 -0
  14. tacotron.py +451 -0
  15. tacotron.toml +32 -0
  16. text.py +92 -0
  17. utils.py +74 -0
  18. wavegru.py +300 -0
  19. wavegru.yaml +14 -0
  20. wavegru_cpp.py +42 -0
  21. wavegru_mod.cc +150 -0
BUILD ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# [internal] load cc_fuzz_target.bzl
# [internal] load cc_proto_library.bzl
# [internal] load android_cc_test:def.bzl

load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

package(default_visibility = [":__subpackages__"])

licenses(["notice"])

# To run all cc_tests in this directory:
# bazel test //:all

# [internal] Command to run dsp_util_android_test.

# [internal] Command to run lyra_integration_android_test.

# Make the extension source visible to other packages.
exports_files(
    srcs = [
        "wavegru_mod.cc",
    ],
)

# pybind11 extension built from wavegru_mod.cc on top of //sparse_matmul.
# The py_library below refers to its output as ":wavegru_mod.so".
pybind_extension(
    name = "wavegru_mod",  # This name is not actually created!
    srcs = ["wavegru_mod.cc"],
    deps = [
        "//sparse_matmul",
    ],
)

# Python-side wrapper target that ships the compiled shared object as data.
py_library(
    name = "wavegru_mod",
    data = [":wavegru_mod.so"],
)

py_binary(
    name = "wavegru",
    srcs = ["wavegru.py"],
    deps = [
        ":wavegru_mod",
    ],
)
Dockerfile ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
# you will also find guides on how best to write your Dockerfile

FROM python:3.11

# Native build dependencies. The steps are chained with `&&` so that a failed
# package-index update aborts the build instead of being silently ignored, and
# the apt lists are removed in the same layer to keep the image small.
RUN apt-get update \
    && apt-get install -y libsndfile1-dev make autoconf automake libtool gcc pkg-config \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /code

COPY ./requirements.txt /code/requirements.txt

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

# Set up a new user named "user" with user ID 1000
RUN useradd -m -u 1000 user

# Switch to the "user" user
USER user

# Set home to the user's home directory.
# GRADIO_SERVER_NAME makes Gradio listen on all interfaces so the app is
# reachable from outside the container.
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH \
    GRADIO_SERVER_NAME=0.0.0.0

# Set the working directory to the user's home directory
WORKDIR $HOME/app

# Copy the current directory contents into the container at $HOME/app setting the owner to the user
COPY --chown=user . $HOME/app

# Build the wavegru_mod extension.
# NOTE(review): build_ext.sh calls ./bazelisk-linux-amd64, which is not among the
# files added by this commit - confirm it is present in the repository.
RUN bash ./build_ext.sh

# The Gradio entry point added by this commit is app.py (there is no main.py).
CMD ["python", "app.py"]
WORKSPACE ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ########################
2
+ # Platform Independent #
3
+ ########################
4
+
5
+ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository", "new_git_repository")
6
+ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
7
+
8
+ # GoogleTest/GoogleMock framework.
9
+ git_repository(
10
+ name = "com_google_googletest",
11
+ remote = "https://github.com/google/googletest.git",
12
+ tag = "release-1.10.0",
13
+ )
14
+
15
+ # Google benchmark.
16
+ http_archive(
17
+ name = "com_github_google_benchmark",
18
+ urls = ["https://github.com/google/benchmark/archive/bf585a2789e30585b4e3ce6baf11ef2750b54677.zip"], # 2020-11-26T11:14:03Z
19
+ strip_prefix = "benchmark-bf585a2789e30585b4e3ce6baf11ef2750b54677",
20
+ sha256 = "2a778d821997df7d8646c9c59b8edb9a573a6e04c534c01892a40aa524a7b68c",
21
+ )
22
+
23
+ # proto_library, cc_proto_library, and java_proto_library rules implicitly
24
+ # depend on @com_google_protobuf for protoc and proto runtimes.
25
+ # This statement defines the @com_google_protobuf repo.
26
+ git_repository(
27
+ name = "com_google_protobuf",
28
+ remote = "https://github.com/protocolbuffers/protobuf.git",
29
+ tag = "v3.15.4",
30
+ )
31
+
32
+ load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps")
33
+ protobuf_deps()
34
+
35
+ # Google Abseil Libs
36
+ git_repository(
37
+ name = "com_google_absl",
38
+ remote = "https://github.com/abseil/abseil-cpp.git",
39
+ branch = "lts_2020_09_23",
40
+ )
41
+
42
+ # Filesystem
43
+ # The new_* prefix is used because it is not a bazel project and there is
44
+ # no BUILD file in that repo.
45
+ FILESYSTEM_BUILD = """
46
+ cc_library(
47
+ name = "filesystem",
48
+ hdrs = glob(["include/ghc/*"]),
49
+ visibility = ["//visibility:public"],
50
+ )
51
+ """
52
+
53
+ new_git_repository(
54
+ name = "gulrak_filesystem",
55
+ remote = "https://github.com/gulrak/filesystem.git",
56
+ tag = "v1.3.6",
57
+ build_file_content = FILESYSTEM_BUILD
58
+ )
59
+
60
+ # Audio DSP
61
+ git_repository(
62
+ name = "com_google_audio_dsp",
63
+ remote = "https://github.com/google/multichannel-audio-tools.git",
64
+ # There are no tags for this repo, so we track the master branch (unpinned: builds are not reproducible).
65
+ branch = "master",
66
+ repo_mapping = {
67
+ "@com_github_glog_glog" : "@com_google_glog"
68
+ }
69
+ )
70
+
71
+
72
+ http_archive(
73
+ name = "pybind11_bazel",
74
+ strip_prefix = "pybind11_bazel-72cbbf1fbc830e487e3012862b7b720001b70672",
75
+ urls = ["https://github.com/pybind/pybind11_bazel/archive/72cbbf1fbc830e487e3012862b7b720001b70672.zip"],
76
+ )
77
+ # We still require the pybind library.
78
+ http_archive(
79
+ name = "pybind11",
80
+ build_file = "@pybind11_bazel//:pybind11.BUILD",
81
+ strip_prefix = "pybind11-2.9.0",
82
+ urls = ["https://github.com/pybind/pybind11/archive/v2.9.0.tar.gz"],
83
+ )
84
+ load("@pybind11_bazel//:python_configure.bzl", "python_configure")
85
+ python_configure(name = "local_config_python")
86
+
87
+
88
+
89
+ # Transitive dependencies of Audio DSP.
90
+ http_archive(
91
+ name = "eigen_archive",
92
+ build_file = "eigen.BUILD",
93
+ sha256 = "f3d69ac773ecaf3602cb940040390d4e71a501bb145ca9e01ce5464cf6d4eb68",
94
+ strip_prefix = "eigen-eigen-049af2f56331",
95
+ urls = [
96
+ "http://mirror.tensorflow.org/bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz",
97
+ "https://bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz",
98
+ ],
99
+ )
100
+
101
+ http_archive(
102
+ name = "fft2d",
103
+ build_file = "fft2d.BUILD",
104
+ sha256 = "ada7e99087c4ed477bfdf11413f2ba8db8a840ba9bbf8ac94f4f3972e2a7cec9",
105
+ urls = [
106
+ "http://www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz",
107
+ ],
108
+ )
109
+
110
+ # Google logging
111
+ git_repository(
112
+ name = "com_google_glog",
113
+ remote = "https://github.com/google/glog.git",
114
+ branch = "master"
115
+ )
116
+ # Dependency for glog
117
+ git_repository(
118
+ name = "com_github_gflags_gflags",
119
+ remote = "https://github.com/mchinen/gflags.git",
120
+ branch = "android_linking_fix"
121
+ )
122
+
123
+ # Bazel/build rules
124
+
125
+ http_archive(
126
+ name = "bazel_skylib",
127
+ urls = [
128
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
129
+ "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
130
+ ],
131
+ sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44",
132
+ )
133
+ load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")
134
+ bazel_skylib_workspace()
135
+
136
+ http_archive(
137
+ name = "rules_android",
138
+ sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806",
139
+ strip_prefix = "rules_android-0.1.1",
140
+ urls = ["https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip"],
141
+ )
142
+
143
+ # Google Maven Repository
144
+ GMAVEN_TAG = "20180625-1"
145
+
146
+ http_archive(
147
+ name = "gmaven_rules",
148
+ strip_prefix = "gmaven_rules-%s" % GMAVEN_TAG,
149
+ url = "https://github.com/bazelbuild/gmaven_rules/archive/%s.tar.gz" % GMAVEN_TAG,
150
+ )
151
+
152
+ load("@gmaven_rules//:gmaven.bzl", "gmaven_rules")
153
+
154
+ gmaven_rules()
alphabet.txt ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _
2
+
3
+
4
+ !
5
+ ,
6
+ .
7
+ :
8
+ ?
9
+ a
10
+ b
11
+ c
12
+ d
13
+ e
14
+ g
15
+ h
16
+ i
17
+ k
18
+ l
19
+ m
20
+ n
21
+ o
22
+ p
23
+ q
24
+ r
25
+ s
26
+ t
27
+ u
28
+ v
29
+ x
30
+ y
31
+ à
32
+ á
33
+ â
34
+ ã
35
+ è
36
+ é
37
+ ê
38
+ ì
39
+ í
40
+ ò
41
+ ó
42
+ ô
43
+ õ
44
+ ù
45
+ ú
46
+ ý
47
+ ă
48
+ đ
49
+ ĩ
50
+ ũ
51
+ ơ
52
+ ư
53
+
54
+
55
+
56
+
57
+
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+
67
+
68
+ ế
69
+
70
+
71
+
72
+
73
+
74
+
75
+
76
+
77
+
78
+
79
+
80
+
81
+
82
+
83
+
84
+
85
+
86
+
87
+
88
+
89
+
90
+
91
+
92
+
93
+
94
+
95
+
96
+
97
+
app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## build wavegru-cpp
2
+ # import os
3
+ # os.system("./bazelisk-linux-amd64 clean --expunge")
4
+ # os.system("./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native")
5
+
6
+ # install espeak
7
+ import os
8
+ import re
9
+ import unicodedata
10
+ import regex
11
+
12
# Build the compiled WaveGRU extension on first start when the shared object
# is not already next to this file.
# NOTE(review): the exit status of the build script is ignored here, so a
# failed build only surfaces later - consider checking it.
if not os.path.isfile("./wavegru_mod.so"):
    os.system("bash ./build_ext.sh")
14
+
15
+ import gradio as gr
16
+ from inference import load_tacotron_model, load_wavegru_net, mel_to_wav, text_to_mel
17
+ from wavegru_cpp import extract_weight_mask, load_wavegru_cpp
18
+
19
+
20
# Load the Tacotron (text -> mel) network together with its alphabet and config.
alphabet, tacotron_net, tacotron_config = load_tacotron_model(
    "./alphabet.txt", "./tacotron.toml", "./mono_tts_cbhg_small_0700000.ckpt"
)

# Load the WaveGRU vocoder network and its config.
wavegru_config, wavegru_net = load_wavegru_net(
    "./wavegru.yaml", "./wavegru_vocoder_tpu_gta_preemphasis_pruning_0800000.ckpt"
)

# Extract the vocoder weights/masks and hand them to the C++ implementation,
# together with the last upsample factor from the config.
wave_cpp_weight_mask = extract_weight_mask(wavegru_net)
wavecpp = load_wavegru_cpp(
    wave_cpp_weight_mask, wavegru_config["upsample_factors"][-1]
)


# Patterns and tables used by the text normalisation below.
space_re = regex.compile(r"\s+")
number_re = regex.compile("([0-9]+)")
# Vietnamese names of the digits 0-9, indexed by digit value.
digits = ["không", "một", "hai", "ba", "bốn", "năm", "sáu", "bảy", "tám", "chín"]
# A numeric token: optional digits/dots/commas followed by a final digit.
num_re = regex.compile(r"([0-9.,]*[0-9])")
alphabet_ = "aàáảãạăằắẳẵặâầấẩẫậeèéẻẽẹêềếểễệiìíỉĩịoòóỏõọôồốổỗộơờớởỡợuùúủũụưừứửữựyỳýỷỹỵbcdđghklmnpqrstvx"
keep_text_and_num_re = regex.compile(rf"[^\s{alphabet_}.,0-9]")
keep_text_re = regex.compile(rf"[^\s{alphabet_}]")
41
+
42
+
43
def _is_plain_digits(token):
    """Return True when ``token`` is a non-empty run of ASCII digits."""
    return token.isascii() and token.isdigit()


def _read_digit_by_digit(num):
    """Spell a digit string one digit at a time (e.g. "05" -> "không năm")."""
    return " ".join(digits[int(ch)] for ch in num)


def _read_pieces(pieces, separator, original):
    """Fallback for separators that do not form a valid number.

    Every non-empty piece is read on its own and the separator is kept as a
    stand-alone punctuation token between them. ``original`` is returned
    unchanged when there is nothing to read.
    """
    spoken = [read_number(piece) for piece in pieces if piece]
    if not spoken:
        return original
    return f" {separator} ".join(spoken)


def _read_grouped(groups):
    """Read thousands-grouped digits given as [head, tail] or [head, middle, tail].

    Every group after the first must hold exactly three digits. All-zero
    groups after the head are skipped, so "1.000.000" is read "một triệu".
    """
    *leading, tail = groups
    scales = ("triệu", "ngàn")[-len(leading):]
    words = []
    for position, (group, scale) in enumerate(zip(leading, scales)):
        if position > 0 and group == "000":
            continue  # nothing to say for an all-zero thousands group
        words.append(read_number(group) + " " + scale)
    if tail == "000":
        pass
    elif tail.startswith("00"):
        words.append("lẻ " + digits[int(tail[2])])
    else:
        words.append(read_number(tail))
    return " ".join(words)


def _read_integer(num):
    """Read a string made only of ASCII digits."""
    length = len(num)
    if length == 1:
        return digits[int(num)]
    if length == 2:
        if num[0] == "0":
            # "05" is not fifteen: read the digits separately.
            return _read_digit_by_digit(num)
        tens, ones = divmod(int(num), 10)
        if ones == 0:
            return "mười" if tens == 1 else digits[tens] + " mươi"
        end = "lăm" if ones == 5 else digits[ones]
        if tens == 1:
            return "mười " + end
        if ones == 1:
            end = "mốt"
        return digits[tens] + " mươi " + end
    if length == 3:
        value = int(num)
        hundreds = digits[value // 100] + " trăm"
        if value % 100 == 0:
            return hundreds
        if num[1] == "0":
            return hundreds + " lẻ " + digits[value % 100]
        return hundreds + " " + read_number(num[1:])
    if length <= 6:
        thousands = read_number(str(int(num) // 1000)) + " ngàn"
        tail = num[-3:]
        # Round thousands read the same as their dotted form ("1.000").
        return thousands if tail == "000" else thousands + " " + read_number(tail)
    if length <= 9:
        return _read_grouped([num[:-6], num[-6:-3], num[-3:]])
    # Too long to name: spell the digits rather than returning them unread.
    return _read_digit_by_digit(num)


def read_number(num: str) -> str:
    """Spell out a numeric token in Vietnamese.

    Args:
        num: a token made of digits, optionally with "," as the decimal mark
            and "." as the thousands separator (e.g. "15", "1.000", "1,5").

    Returns:
        The spoken form. Tokens whose separators do not form a valid number
        (e.g. "1,2,3" or "1.00") are read piece by piece with the separator
        kept as punctuation, instead of raising. A token with nothing
        readable in it is returned unchanged.
    """
    if _is_plain_digits(num):
        return _read_integer(num)
    if "," in num:
        parts = num.split(",")
        if len(parts) == 2 and all(parts):
            return read_number(parts[0]) + " phẩy " + read_number(parts[1])
        return _read_pieces(parts, ",", num)
    if "." in num:
        parts = num.split(".")
        well_grouped = (
            len(parts) in (2, 3)
            and _is_plain_digits(parts[0])
            and all(len(g) == 3 and _is_plain_digits(g) for g in parts[1:])
        )
        if well_grouped:
            return _read_grouped(parts)
        return _read_pieces(parts, ".", num)
    return num


def normalize_text(text):
    """Lowercase and normalise ``text`` and spell out the numbers it contains.

    Returns the normalised text with single spaces between tokens and no
    leading or trailing whitespace.
    """
    # lowercase
    text = text.lower()
    # unicode normalize
    text = unicodedata.normalize("NFKC", text)
    # Put a space after "." and "," unless the mark sits between two digits:
    # there it is a decimal/thousands separator that read_number must see.
    text = re.sub(r"(?<![0-9])[.,]|[.,](?![0-9])", r"\g<0> ", text)
    for mark in (";", ":", "!", "?", "("):
        text = text.replace(mark, mark + " ")

    # Isolate numeric tokens, then replace each one by its spoken form.
    text = num_re.sub(r" \1 ", text)
    words = text.split()
    words = [read_number(w) if num_re.fullmatch(w) else w for w in words]
    text = " ".join(words)

    # remove redundant spaces
    text = re.sub(r"\s+", " ", text)
    # remove leading and trailing spaces
    text = text.strip()
    return text
120
+
121
+
122
def speak(text):
    """Synthesise ``text`` and return ``(sample_rate, waveform)``.

    The text is normalised, converted to a mel spectrogram by the Tacotron
    network and then to audio samples by the WaveGRU vocoder.
    """
    normalized = normalize_text(text)
    mel = text_to_mel(tacotron_net, normalized, alphabet, tacotron_config)
    waveform = mel_to_wav(wavegru_net, wavecpp, mel, wavegru_config)
    sample_rate = 24_000
    return sample_rate, waveform
127
+
128
+
129
title = "WaveGRU-TTS"
description = "WaveGRU text-to-speech demo."

# Gradio UI: a text box in, synthesised audio out. The fourth example had two
# replacement characters in place of a letter; the word is restored as "lớn".
gr.Interface(
    fn=speak,
    inputs="text",
    examples=[
        "Trăm năm trong cõi người ta, chữ tài chữ mệnh khéo là ghét nhau.",
        "Đoạn trường tân thanh, thường được biết đến với cái tên đơn giản là Truyện Kiều, là một truyện thơ của đại thi hào Nguyễn Du",
        "Lục Vân Tiên quê ở huyện Đông Thành, khôi ngô tuấn tú, tài kiêm văn võ. Nghe tin triều đình mở khoa thi, Vân Tiên từ giã thầy xuống núi đua tài.",
        "Lê Quý Đôn, tên thuở nhỏ là Lê Danh Phương, là vị quan thời Lê trung hưng, cũng là nhà thơ và được mệnh danh là nhà bác học lớn của Việt Nam trong thời phong kiến",
        "Tất cả mọi người đều sinh ra có quyền bình đẳng. Tạo hóa cho họ những quyền không ai có thể xâm phạm được; trong những quyền ấy, có quyền được sống, quyền tự do và quyền mưu cầu hạnh phúc.",
    ],
    outputs="audio",
    title=title,
    description=description,
    theme="default",
    allow_screenshot=False,
    allow_flagging="never",
).launch(enable_queue=True)
build_ext.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
#!/usr/bin/env bash
# Build the wavegru_mod pybind11 extension with Bazel and copy the shared
# object next to the Python sources.
#
# Fail fast: without this, a failed install or build was only noticed later,
# when the extension turned out to be missing.
set -euo pipefail

pip install -U pip
pip install gradio==3.42.0

# NOTE(review): -march=native targets the CPU of the build machine; confirm
# the runtime host supports the same instruction set.
USE_BAZEL_VERSION=5.0.0 ./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native
cp -f bazel-bin/wavegru_mod.so .
extract_tacotrons_model.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import pickle
2
+
3
# Strip the optimizer state from the Tacotron checkpoint, in place.
CKPT_PATH = "./mono_tts_cbhg_small_0700000.ckpt"

# NOTE: pickle.load executes arbitrary code from the file; only run this on a
# checkpoint you trust.
with open(CKPT_PATH, "rb") as f:
    dic = pickle.load(f)

# pop() instead of del so that running the script a second time is harmless.
dic.pop("optim_state_dict", None)

# The checkpoint is fully loaded and the read handle closed before the same
# path is reopened for writing.
with open(CKPT_PATH, "wb") as f:
    pickle.dump(dic, f)
extract_wavegru_model.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import pickle
2
+
3
# Strip the optimizer state from the WaveGRU checkpoint, in place.
CKPT_PATH = "./wavegru_vocoder_tpu_gta_preemphasis_pruning_0800000.ckpt"

# NOTE: pickle.load executes arbitrary code from the file; only run this on a
# checkpoint you trust.
with open(CKPT_PATH, "rb") as f:
    dic = pickle.load(f)

# pop() instead of del so that running the script a second time is harmless.
dic.pop("optim_state_dict", None)

# The checkpoint is fully loaded and the read handle closed before the same
# path is reopened for writing.
with open(CKPT_PATH, "wb") as f:
    pickle.dump(dic, f)
inference.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+ import librosa
6
+ import numpy as np
7
+ import pax
8
+
9
+ # from text import english_cleaners
10
+ from utils import (
11
+ create_tacotron_model,
12
+ load_tacotron_ckpt,
13
+ load_tacotron_config,
14
+ load_wavegru_ckpt,
15
+ load_wavegru_config,
16
+ )
17
+ from wavegru import WaveGRU
18
+
19
+ # os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = "./espeak/usr/lib/libespeak-ng.so.1.1.51"
20
+ # from phonemizer.backend import EspeakBackend
21
+ # backend = EspeakBackend("en-us", preserve_punctuation=True, with_stress=True)
22
+
23
+
24
def load_tacotron_model(alphabet_file, config_file, model_file):
    """Load the alphabet, config and Tacotron network.

    Returns:
        ``(alphabet, net, config)`` where ``alphabet`` is the list of lines of
        ``alphabet_file`` and ``net`` is the network restored from
        ``model_file``, switched to eval mode and placed on the default device.
    """
    with open(alphabet_file, "r", encoding="utf-8") as handle:
        symbols = handle.read().split("\n")

    cfg = load_tacotron_config(config_file)
    model = create_tacotron_model(cfg)
    _, model, _ = load_tacotron_ckpt(model, None, model_file)
    model = jax.device_put(model.eval())
    return symbols, model, cfg
35
+
36
+
37
# Pure (side-effect free) wrapper around the network's inference method.
tacotron_inference_fn = pax.pure(lambda net, text: net.inference(text, max_len=2400))


def text_to_mel(net, text, alphabet, config):
    """Convert text to a mel spectrogram.

    The end character is appended, the text is padded with ``config["PAD"]``
    up to the next multiple of 100 characters, and every character found in
    ``alphabet`` is mapped to its index there.
    """
    padded = text + config["END_CHARACTER"]
    padded = padded + config["PAD"] * (100 - (len(padded) % 100))

    # First index of every symbol, matching what ``alphabet.index`` returns.
    first_index = {}
    for position, symbol in enumerate(alphabet):
        first_index.setdefault(symbol, position)

    # NOTE(review): characters missing from the alphabet are dropped after the
    # padding length was computed, so the token count may not be a multiple of
    # 100 - confirm this is intended.
    token_ids = [first_index[ch] for ch in padded if ch in first_index]
    batch = jnp.array(token_ids, dtype=jnp.int32)[None]
    return tacotron_inference_fn(net, batch)
53
+
54
+
55
def load_wavegru_net(config_file, model_file):
    """Load the WaveGRU config and network.

    Returns:
        ``(config, net)`` where ``net`` is restored from ``model_file``,
        switched to eval mode and placed on the default device.
    """
    cfg = load_wavegru_config(config_file)
    model = WaveGRU(
        mel_dim=cfg["mel_dim"],
        rnn_dim=cfg["rnn_dim"],
        upsample_factors=cfg["upsample_factors"],
        has_linear_output=True,
    )
    _, model, _ = load_wavegru_ckpt(model, None, model_file)
    model = jax.device_put(model.eval())
    return cfg, model
68
+
69
+
70
# Pure wrapper around the network's inference method with the GRU step
# disabled (no_gru=True).
wavegru_inference = pax.pure(lambda net, mel: net.inference(mel, no_gru=True))


def mel_to_wav(net, netcpp, mel, config):
    """Convert a mel spectrogram to 16-bit PCM samples.

    ``net`` produces the features, ``netcpp`` turns them into mu-law encoded
    samples, which are then expanded, de-emphasised, scaled and clipped.
    """
    if mel.ndim == 2:
        mel = mel[None]  # add the batch axis
    n_pad = config["num_pad_frames"] // 2 + 2
    mel = np.pad(mel, [(0, 0), (n_pad, n_pad), (0, 0)], mode="edge")

    features = wavegru_inference(net, mel)
    features = jax.device_get(features[0])
    encoded = np.array(netcpp.inference(features, 1.0))

    wav = librosa.mu_expand(encoded - 127, mu=255)
    wav = librosa.effects.deemphasis(wav, coef=0.86)
    wav = wav * 2.0
    # Rescale only when the peak exceeds 1.0.
    wav = wav / max(1.0, np.max(np.abs(wav)))

    full_scale = 2**15
    wav = wav * full_scale
    wav = np.clip(wav, a_min=-full_scale, a_max=full_scale - 1)
    return wav.astype(np.int16)
mynumbers.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ from https://github.com/keithito/tacotron """
2
+
3
+ import inflect
4
+ import re
5
+
6
+
7
# Engine used to turn integers into English words.
_inflect = inflect.engine()

# Numbers written with thousands commas, e.g. "1,234".
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
# Decimal numbers, e.g. "3.14".
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
# Amounts prefixed by a pound sign.
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
# Amounts prefixed by a dollar sign.
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
# Ordinals such as "1st", "2nd", "3rd", "4th".
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
# Any remaining run of digits.
_number_re = re.compile(r"[0-9]+")
14
+
15
+
16
+ def _remove_commas(m):
17
+ return m.group(1).replace(",", "")
18
+
19
+
20
+ def _expand_decimal_point(m):
21
+ return m.group(1).replace(".", " point ")
22
+
23
+
24
+ def _expand_dollars(m):
25
+ match = m.group(1)
26
+ parts = match.split(".")
27
+ if len(parts) > 2:
28
+ return match + " dollars" # Unexpected format
29
+ dollars = int(parts[0]) if parts[0] else 0
30
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
31
+ if dollars and cents:
32
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
33
+ cent_unit = "cent" if cents == 1 else "cents"
34
+ return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
35
+ elif dollars:
36
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
37
+ return "%s %s" % (dollars, dollar_unit)
38
+ elif cents:
39
+ cent_unit = "cent" if cents == 1 else "cents"
40
+ return "%s %s" % (cents, cent_unit)
41
+ else:
42
+ return "zero dollars"
43
+
44
+
45
def _expand_ordinal(m):
    """Spell out the matched ordinal (the whole match) in words."""
    ordinal = m.group(0)
    return _inflect.number_to_words(ordinal)
47
+
48
+
49
def _expand_number(m):
    """Spell out the matched integer in words.

    Values strictly between 1000 and 3000 get special handling (they are read
    in two-digit groups unless they are round hundreds or 2000-2009).
    """
    value = int(m.group(0))
    if not (1000 < value < 3000):
        return _inflect.number_to_words(value, andword="")
    if value == 2000:
        return "two thousand"
    if 2000 < value < 2010:
        return "two thousand " + _inflect.number_to_words(value % 100)
    if value % 100 == 0:
        return _inflect.number_to_words(value // 100) + " hundred"
    spoken = _inflect.number_to_words(value, andword="", zero="oh", group=2)
    return spoken.replace(", ", " ")
64
+
65
+
66
def normalize_numbers(text):
    """Rewrite the numbers in ``text`` as words.

    The substitutions run in a fixed order: commas are removed first, currency
    and decimals are expanded next, and plain integers are handled last.
    """
    steps = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r"\1 pounds"),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, replacement in steps:
        text = re.sub(pattern, replacement, text)
    return text
packages.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ libsndfile1-dev
2
+ make
3
+ autoconf
4
+ automake
5
+ libtool
6
+ gcc
7
+ pkg-config
pooch.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
def os_cache(x):
    """Stub: return the argument unchanged."""
    result = x
    return result
3
+
4
+
5
def create(*args, **kwargs):
    """Stub: ignore all arguments and return an object whose
    ``load_registry`` method does nothing."""

    class T:
        def load_registry(self, *args, **kwargs):
            """Accept anything and return None."""
            return None

    stub = T()
    return stub
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ inflect
2
+ jax
3
+ jaxlib
4
+ jinja2
5
+ librosa
6
+ numpy
7
+ pax3
8
+ pyyaml
9
+ toml
10
+ unidecode
11
+ phonemizer
12
+ gradio==3.42.0
tacotron.py ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tacotron + stepwise monotonic attention
3
+ """
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import pax
8
+
9
+
10
def conv_block(in_ft, out_ft, kernel_size, activation_fn, use_dropout):
    """
    Build a Conv1D >> LayerNorm stack, optionally followed by an activation
    and by dropout with rate 0.5.
    """
    block = pax.Sequential(
        pax.Conv1D(in_ft, out_ft, kernel_size, with_bias=False),
        pax.LayerNorm(out_ft, -1, True, True),
    )
    # Optional stages are appended in this fixed order.
    if activation_fn is not None:
        block >>= activation_fn
    if use_dropout:
        block >>= pax.Dropout(0.5)
    return block
23
+
24
+
25
class HighwayBlock(pax.Module):
    """
    Highway layer: mixes the input with a ReLU transform of it through a
    learned sigmoid gate.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim
        # One projection yields both the gate logits and the candidate.
        self.fc = pax.Linear(dim, 2 * dim)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        gate_logits, candidate = jnp.split(self.fc(x), 2, axis=-1)
        # Shifting the logits by -1 biases the gate toward keeping x.
        gate = jax.nn.sigmoid(gate_logits - 1.0)
        candidate = jax.nn.relu(candidate)
        return x * (1.0 - gate) + candidate * gate
41
+
42
+
43
class BiGRU(pax.Module):
    """
    Bidirectional GRU

    A forward GRU and a backward GRU are run over the sequence (axis 1) and
    their outputs are concatenated on the last axis.
    """

    def __init__(self, dim):
        super().__init__()

        self.rnn_fwd = pax.GRU(dim, dim)
        self.rnn_bwd = pax.GRU(dim, dim)

    def __call__(self, x, reset_masks):
        """
        Args:
            x: input sequence; axis 0 is the batch, axis 1 is time.
            reset_masks: per-step mask; wherever it is set, the backward GRU
                state is reset to its initial state after that step.

        Returns:
            Forward and backward outputs concatenated on the last axis.
        """
        N = x.shape[0]
        x_fwd = x
        x_bwd = jnp.flip(x, axis=1)
        x_fwd_states = self.rnn_fwd.initial_state(N)
        x_bwd_states = self.rnn_bwd.initial_state(N)
        x_fwd_states, x_fwd = pax.scan(
            self.rnn_fwd, x_fwd_states, x_fwd, time_major=False
        )

        # The backward pass sees the sequence reversed, so its masks are
        # reversed too.
        reset_masks = jnp.flip(reset_masks, axis=1)
        x_bwd_states0 = x_bwd_states

        def rnn_reset_core(prev, inputs):
            x, reset_mask = inputs

            def reset_state(x0, xt):
                return jnp.where(reset_mask, x0, xt)

            state, _ = self.rnn_bwd(prev, x)
            # jax.tree_map has been removed from recent JAX releases;
            # jax.tree_util.tree_map is the equivalent that works across
            # versions.
            state = jax.tree_util.tree_map(reset_state, x_bwd_states0, state)
            return state, state.hidden

        x_bwd_states, x_bwd = pax.scan(
            rnn_reset_core, x_bwd_states, (x_bwd, reset_masks), time_major=False
        )
        # Flip back so both directions are aligned in time.
        x_bwd = jnp.flip(x_bwd, axis=1)
        x = jnp.concatenate((x_fwd, x_bwd), axis=-1)
        return x
83
+
84
+
85
class CBHG(pax.Module):
    """
    Conv bank >> max pool >> conv projections (residual) >> highway net >> BiGRU
    """

    def __init__(self, dim):
        super().__init__()
        # Bank of 16 convolutions with kernel sizes 1..16.
        self.convs = [
            conv_block(dim, dim, width, jax.nn.relu, False) for width in range(1, 17)
        ]
        self.conv_projection_1 = conv_block(16 * dim, dim, 3, jax.nn.relu, False)
        self.conv_projection_2 = conv_block(dim, dim, 3, None, False)

        self.highway = pax.Sequential(*(HighwayBlock(dim) for _ in range(4)))
        self.rnn = BiGRU(dim)

    def __call__(self, x, x_mask):
        masked_input = x * x_mask
        bank = jnp.concatenate([conv(masked_input) for conv in self.convs], axis=-1)
        bank = pax.max_pool(bank, 2, 1, "SAME", -1)
        # The mask is re-applied before every convolution.
        projected = self.conv_projection_1(bank * x_mask)
        projected = self.conv_projection_2(projected * x_mask)
        hidden = self.highway(x + projected)
        # Masked-out positions reset the backward GRU state.
        hidden = self.rnn(hidden * x_mask, reset_masks=1 - x_mask)
        return hidden * x_mask
112
+
113
+
114
class PreNet(pax.Module):
    """
    Two Linear >> relu stages, each followed by dropout with rate 0.5.

    Dropout is applied when ``always_dropout`` is set or the module is in
    training mode.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, always_dropout=True):
        super().__init__()
        self.fc1 = pax.Linear(input_dim, hidden_dim)
        self.fc2 = pax.Linear(hidden_dim, output_dim)
        self.rng_seq = pax.RngSeq()
        self.always_dropout = always_dropout

    def __call__(self, x, k1=None, k2=None):
        """Apply the pre-net; ``k1``/``k2`` are optional dropout keys, drawn
        from the module's own key sequence when omitted."""
        use_dropout = self.always_dropout or self.training

        x = jax.nn.relu(self.fc1(x))
        if use_dropout:
            first_key = self.rng_seq.next_rng_key() if k1 is None else k1
            x = pax.dropout(first_key, 0.5, x)

        x = jax.nn.relu(self.fc2(x))
        if use_dropout:
            second_key = self.rng_seq.next_rng_key() if k2 is None else k2
            x = pax.dropout(second_key, 0.5, x)
        return x
140
+
141
+
142
+ class Tacotron(pax.Module):
143
+ """
144
+ Tacotron TTS model.
145
+
146
+ It uses stepwise monotonic attention for robust attention.
147
+ """
148
+
149
+ def __init__(
150
+ self,
151
+ mel_dim: int,
152
+ attn_bias,
153
+ rr,
154
+ max_rr,
155
+ mel_min,
156
+ sigmoid_noise,
157
+ pad_token,
158
+ prenet_dim,
159
+ attn_hidden_dim,
160
+ attn_rnn_dim,
161
+ rnn_dim,
162
+ postnet_dim,
163
+ text_dim,
164
+ ):
165
+ """
166
+ New Tacotron model
167
+
168
+ Args:
169
+ mel_dim (int): dimension of log mel-spectrogram features.
170
+ attn_bias (float): control how "slow" the attention will
171
+ move forward at initialization.
172
+ rr (int): the reduction factor.
173
+ Number of predicted frame at each time step. Default is 2.
174
+ max_rr (int): max value of rr.
175
+ mel_min (float): the minimum value of mel features.
176
+ The <go> frame is filled by `log(mel_min)` values.
177
+ sigmoid_noise (float): the variance of gaussian noise added
178
+ to attention scores in training.
179
+ pad_token (int): the pad value at the end of text sequences.
180
+ prenet_dim (int): dimension of prenet output.
181
+ attn_hidden_dim (int): dimension of attention hidden vectors.
182
+ attn_rnn_dim (int): number of cells in the attention RNN.
183
+ rnn_dim (int): number of cells in the decoder RNNs.
184
+ postnet_dim (int): number of features in the postnet convolutions.
185
+ text_dim (int): dimension of text embedding vectors.
186
+ """
187
+ super().__init__()
188
+ self.text_dim = text_dim
189
+ assert rr <= max_rr
190
+ self.rr = rr
191
+ self.max_rr = max_rr
192
+ self.mel_dim = mel_dim
193
+ self.mel_min = mel_min
194
+ self.sigmoid_noise = sigmoid_noise
195
+ self.pad_token = pad_token
196
+ self.prenet_dim = prenet_dim
197
+
198
+ # encoder submodules
199
+ self.encoder_embed = pax.Embed(256, text_dim)
200
+ self.encoder_pre_net = PreNet(text_dim, 256, prenet_dim, always_dropout=True)
201
+ self.encoder_cbhg = CBHG(prenet_dim)
202
+
203
+ # random key generator
204
+ self.rng_seq = pax.RngSeq()
205
+
206
+ # pre-net
207
+ self.decoder_pre_net = PreNet(mel_dim, 256, prenet_dim, always_dropout=True)
208
+
209
+ # decoder submodules
210
+ self.attn_rnn = pax.LSTM(prenet_dim + prenet_dim * 2, attn_rnn_dim)
211
+ self.text_key_fc = pax.Linear(prenet_dim * 2, attn_hidden_dim, with_bias=True)
212
+ self.attn_query_fc = pax.Linear(attn_rnn_dim, attn_hidden_dim, with_bias=False)
213
+
214
+ self.attn_V = pax.Linear(attn_hidden_dim, 1, with_bias=False)
215
+ self.attn_V_weight_norm = jnp.array(1.0 / jnp.sqrt(attn_hidden_dim))
216
+ self.attn_V_bias = jnp.array(attn_bias)
217
+ self.attn_log = jnp.zeros((1,))
218
+ self.decoder_input = pax.Linear(attn_rnn_dim + 2 * prenet_dim, rnn_dim)
219
+ self.decoder_rnn1 = pax.LSTM(rnn_dim, rnn_dim)
220
+ self.decoder_rnn2 = pax.LSTM(rnn_dim, rnn_dim)
221
+ # mel + end-of-sequence token
222
+ self.output_fc = pax.Linear(rnn_dim, (mel_dim + 1) * max_rr, with_bias=True)
223
+
224
+ # post-net
225
+ self.post_net = pax.Sequential(
226
+ conv_block(mel_dim, postnet_dim, 5, jax.nn.tanh, True),
227
+ conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
228
+ conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
229
+ conv_block(postnet_dim, postnet_dim, 5, jax.nn.tanh, True),
230
+ conv_block(postnet_dim, mel_dim, 5, None, True),
231
+ )
232
+
233
+ parameters = pax.parameters_method("attn_V_weight_norm", "attn_V_bias")
234
+
235
+ def encode_text(self, text: jnp.ndarray) -> jnp.ndarray:
236
+ """
237
+ Encode text to a sequence of real vectors
238
+ """
239
+ N, L = text.shape
240
+ text_mask = (text != self.pad_token)[..., None]
241
+ x = self.encoder_embed(text)
242
+ x = self.encoder_pre_net(x)
243
+ x = self.encoder_cbhg(x, text_mask)
244
+ return x
245
+
246
+ def go_frame(self, batch_size: int) -> jnp.ndarray:
247
+ """
248
+ return the go frame
249
+ """
250
+ return jnp.ones((batch_size, self.mel_dim)) * jnp.log(self.mel_min)
251
+
252
+ def decoder_initial_state(self, N: int, L: int):
253
+ """
254
+ setup decoder initial state
255
+ """
256
+ attn_context = jnp.zeros((N, self.prenet_dim * 2))
257
+ attn_pr = jax.nn.one_hot(
258
+ jnp.zeros((N,), dtype=jnp.int32), num_classes=L, axis=-1
259
+ )
260
+
261
+ attn_state = (self.attn_rnn.initial_state(N), attn_context, attn_pr)
262
+ decoder_rnn_states = (
263
+ self.decoder_rnn1.initial_state(N),
264
+ self.decoder_rnn2.initial_state(N),
265
+ )
266
+ return attn_state, decoder_rnn_states
267
+
268
def monotonic_attention(self, prev_state, inputs, envs):
    """
    One step of stepwise monotonic attention.

    prev_state: (attention RNN state, previous context, previous attention probabilities)
    inputs: (input vector for this step, PRNG key used for the sigmoid noise)
    envs: (encoded text, projected text keys)

    Returns ``(new_state, attention_rnn_output)`` where ``new_state`` has the
    same layout as ``prev_state``.
    """
    rnn_state, context, prev_pr = prev_state
    x, noise_key = inputs
    text, text_key = envs

    # The attention RNN sees the step input together with the previous context.
    rnn_state, rnn_out = self.attn_rnn(
        rnn_state, jnp.concatenate((x, context), axis=-1)
    )

    # Additive energies, rescaled by attn_V_weight_norm / ||attn_V.weight|| and shifted by the bias.
    query = self.attn_query_fc(rnn_out)
    energy = self.attn_V(jnp.tanh(query[:, None, :] + text_key))
    energy = jnp.squeeze(energy, axis=-1)
    v_norm = jnp.linalg.norm(self.attn_V.weight)
    energy = energy * (self.attn_V_weight_norm / v_norm)
    energy = energy + self.attn_V_bias

    # Noisy sigmoid: probability of staying on each text position.
    jitter = jax.random.normal(noise_key, energy.shape) * self.sigmoid_noise
    stay = jax.nn.sigmoid(energy + jitter)
    move = 1.0 - stay

    # Probability mass that moves is shifted one position to the right;
    # whatever would move past the last position is dropped.
    shifted = jnp.pad(
        (move * prev_pr)[:, :-1], ((0, 0), (1, 0)), constant_values=0
    )
    attn_pr = stay * prev_pr + shifted
    context = jnp.einsum("NL,NLD->ND", attn_pr, text)
    return (rnn_state, context, attn_pr), rnn_out
296
+
297
def zoneout_lstm(self, lstm_core, rng_key, zoneout_pr=0.1):
    """
    Wrap ``lstm_core`` so that zoneout is applied to its hidden state.

    Each hidden unit keeps its previous value with probability ``zoneout_pr``
    (sampled from ``rng_key``); the cell state always takes the newly
    computed value. The returned callable has the ``(state, x)`` signature
    of an LSTM core and returns ``(new_state, new_hidden)``.
    """

    def zoned_step(state, x):
        updated, _ = lstm_core(state, x)
        previous_h = state.hidden
        fresh_h = updated.hidden
        keep_old = jax.random.bernoulli(rng_key, zoneout_pr, previous_h.shape)
        mixed_h = previous_h * keep_old + fresh_h * (1.0 - keep_old)
        return pax.LSTMState(mixed_h, updated.cell), mixed_h

    return zoned_step
313
+
314
def decoder_step(
    self,
    attn_state,
    decoder_rnn_states,
    rng_key,
    mel,
    text,
    text_key,
    call_pre_net=False,
):
    """
    Advance the attention mechanism and the two decoder LSTMs by one step.

    When ``call_pre_net`` is true the input frame is first passed through
    ``self.decoder_pre_net`` with two dedicated PRNG keys.

    Returns ``(attn_state, decoder_rnn_states, next_rng_key, x, attn_pr[0])``
    where ``x`` is the sum of the decoder input projection and both LSTM
    outputs, and ``attn_pr[0]`` is the attention distribution of the first
    batch element.
    """
    # Split order matches the unpacking order: optional pre-net keys first,
    # then the two zoneout keys, the attention-noise key and the key handed back.
    keys = jax.random.split(rng_key, 6 if call_pre_net else 4)
    if call_pre_net:
        mel = self.decoder_pre_net(mel, keys[0], keys[1])
    zone_key1, zone_key2, attn_key, next_key = keys[-4:]

    attn_state, attn_out = self.monotonic_attention(
        attn_state, (mel, attn_key), (text, text_key)
    )
    _, context, attn_pr = attn_state
    state1, state2 = decoder_rnn_states

    rnn_in = self.decoder_input(jnp.concatenate((attn_out, context), axis=-1))
    first_lstm = self.zoneout_lstm(self.decoder_rnn1, zone_key1)
    state1, out1 = first_lstm(state1, rnn_in)
    second_lstm = self.zoneout_lstm(self.decoder_rnn2, zone_key2)
    state2, out2 = second_lstm(state2, rnn_in + out1)

    x = rnn_in + out1 + out2
    return attn_state, (state1, state2), next_key, x, attn_pr[0]
353
+
354
@jax.jit
def inference_step(
    self, attn_state, decoder_rnn_states, rng_key, mel, text, text_key
):
    """
    Jitted single decoding step used at inference time.

    Runs ``decoder_step`` with the pre-net enabled, projects the result with
    ``self.output_fc`` and keeps the first ``self.rr`` of ``self.max_rr``
    frames. The last channel of each frame is the end-of-sequence logit.
    EOS is sampled from the sigmoid of the last frame's logit for the first
    batch element; probabilities below 0.1 are treated as 0.

    Returns ``(attn_state, decoder_rnn_states, rng_key, (mel, eos))``.
    """
    attn_state, decoder_rnn_states, rng_key, hidden, _ = self.decoder_step(
        attn_state,
        decoder_rnn_states,
        rng_key,
        mel,
        text,
        text_key,
        call_pre_net=True,
    )
    projected = self.output_fc(hidden)
    batch, width = projected.shape
    frames = jnp.reshape(projected, (batch, self.max_rr, width // self.max_rr))
    frames = frames[:, : self.rr, :]
    frames = jnp.reshape(frames, (batch, self.rr, -1))
    mel, eos_logit = frames[..., :-1], frames[..., -1]

    eos_pr = jax.nn.sigmoid(eos_logit[0, -1])
    eos_pr = jnp.where(eos_pr < 0.1, 0.0, eos_pr)
    rng_key, eos_key = jax.random.split(rng_key)
    eos = jax.random.bernoulli(eos_key, p=eos_pr)
    return attn_state, decoder_rnn_states, rng_key, (mel, eos)
380
+
381
def inference(self, text, seed=42, max_len=1000):
    """
    Generate mel frames from a single token sequence (text to mel).

    text: token ids; after encoding the batch size must be 1.
    seed: seed for the PRNG key that drives all sampling in the decoder.
    max_len: upper bound on the number of decoder steps (at least one step
        is always executed).

    Returns the frames produced by the decoder steps, concatenated along
    axis 1. The post-net is not applied to the returned frames.
    """
    text = self.encode_text(text)
    text_key = self.text_key_fc(text)
    N, L, D = text.shape
    assert N == 1
    mel = self.go_frame(N)

    attn_state, decoder_rnn_states = self.decoder_initial_state(N, L)
    rng_key = jax.random.PRNGKey(seed)
    mels = []
    while True:
        attn_state, decoder_rnn_states, rng_key, (mel, eos) = self.inference_step(
            attn_state, decoder_rnn_states, rng_key, mel, text, text_key
        )
        mels.append(mel)
        # Stop on a sampled end-of-sequence or once max_len steps have run.
        # (The previous `count > max_len` test allowed max_len + 1 steps.)
        if eos.item() or len(mels) >= max_len:
            break
        # Only the last generated frame is fed back as the next input.
        mel = mel[:, -1, :]

    # NOTE(review): the original ran `self.post_net` on the final frame here and
    # discarded the result; that dead computation was removed. If post-net
    # refinement of the output was intended, it has to be applied to the
    # concatenated frames instead. Confirm against the vocoder's training data.
    return jnp.concatenate(mels, axis=1)
409
+
410
def decode(self, mel, text):
    """
    Run attention and the decoder over a whole input sequence with a scan.

    mel: per-step decoder inputs, batch-major (time on axis 1).
    text: encoded text of shape (N, L, D).

    Stores the per-step attention log in ``self.attn_log`` and returns
    ``(mel, eos)``: the projected frames without their last channel, and
    that last channel.
    """
    text_key = self.text_key_fc(text)

    def step(carry, step_inputs):
        attn_state, rnn_states = carry
        frame, key = step_inputs
        attn_state, rnn_states, _, out, attn_pr = self.decoder_step(
            attn_state, rnn_states, key, frame, text, text_key
        )
        return (attn_state, rnn_states), (out, attn_pr)

    N, L, D = text.shape
    initial_states = self.decoder_initial_state(N, L)
    # One PRNG key per time step, stacked so that time lies on axis 1 like `mel`.
    step_keys = jnp.stack(self.rng_seq.next_rng_key(mel.shape[1]), axis=1)
    _, (x, attn_log) = pax.scan(
        step,
        initial_states,
        (mel, step_keys),
        time_major=False,
    )
    self.attn_log = attn_log
    x = self.output_fc(x)

    # Keep the first `rr` of `max_rr` frames predicted at every step.
    N, T2, D2 = x.shape
    x = jnp.reshape(x, (N, T2, self.max_rr, D2 // self.max_rr))
    x = x[:, :, : self.rr, :]
    x = jnp.reshape(x, (N, T2 * self.rr, -1))
    return x[..., :-1], x[..., -1]
446
+
447
+ def __call__(self, mel: jnp.ndarray, text: jnp.ndarray):
448
+ text = self.encode_text(text)
449
+ mel = self.decoder_pre_net(mel)
450
+ mel, eos = self.decode(mel, text)
451
+ return mel, mel + self.post_net(mel), eos
tacotron.toml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tacotron]
2
+
3
+ # training
4
+ BATCH_SIZE = 64
5
+ LR=1024e-6 # learning rate
6
+ MODEL_PREFIX = "mono_tts_cbhg_small"
7
+ LOG_DIR = "./logs"
8
+ CKPT_DIR = "./ckpts"
9
+ USE_MP = false # use mixed-precision training
10
+
11
+ # data
12
+ TF_DATA_DIR = "./tf_data" # tensorflow data directory
13
+ TF_GTA_DATA_DIR = "./tf_gta_data" # tf gta data directory
14
+ SAMPLE_RATE = 24000 # convert to this sample rate if needed
15
+ MEL_DIM = 80 # the dimension of melspectrogram features
16
+ MEL_MIN = 1e-5
17
+ PAD = "_" # padding character
18
+ PAD_TOKEN = 0
19
+ END_CHARACTER = "■" # to signal the end of the transcript
20
+ TEST_DATA_SIZE = 1024
21
+
22
+ # model
23
+ RR = 1 # reduction factor
24
+ MAX_RR=2
25
+ ATTN_BIAS = 0.0 # control how slow the attention moves forward
26
+ SIGMOID_NOISE = 2.0
27
+ PRENET_DIM = 128
28
+ TEXT_DIM = 256
29
+ RNN_DIM = 512
30
+ ATTN_RNN_DIM = 256
31
+ ATTN_HIDDEN_DIM = 128
32
+ POSTNET_DIM = 512
text.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # """ from https://github.com/keithito/tacotron """
2
+
3
+ # """
4
+ # Cleaners are transformations that run over the input text at both training and eval time.
5
+
6
+ # Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7
+ # hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8
+ # 1. "english_cleaners" for English text
9
+ # 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10
+ # the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11
+ # 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12
+ # the symbols in symbols.py to match your data).
13
+ # """
14
+
15
+ # import re
16
+ # from mynumbers import normalize_numbers
17
+ # from unidecode import unidecode
18
+
19
+ # # Regular expression matching whitespace:
20
+ # _whitespace_re = re.compile(r"\s+")
21
+
22
+ # # List of (regular expression, replacement) pairs for abbreviations:
23
+ # _abbreviations = [
24
+ # (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
25
+ # for x in [
26
+ # ("mrs", "misess"),
27
+ # ("mr", "mister"),
28
+ # ("dr", "doctor"),
29
+ # ("st", "saint"),
30
+ # ("co", "company"),
31
+ # ("jr", "junior"),
32
+ # ("maj", "major"),
33
+ # ("gen", "general"),
34
+ # ("drs", "doctors"),
35
+ # ("rev", "reverend"),
36
+ # ("lt", "lieutenant"),
37
+ # ("hon", "honorable"),
38
+ # ("sgt", "sergeant"),
39
+ # ("capt", "captain"),
40
+ # ("esq", "esquire"),
41
+ # ("ltd", "limited"),
42
+ # ("col", "colonel"),
43
+ # ("ft", "fort"),
44
+ # ]
45
+ # ]
46
+
47
+
48
+ # def expand_abbreviations(text):
49
+ # for regex, replacement in _abbreviations:
50
+ # text = re.sub(regex, replacement, text)
51
+ # return text
52
+
53
+
54
+ # def expand_numbers(text):
55
+ # return normalize_numbers(text)
56
+
57
+
58
+ # def lowercase(text):
59
+ # return text.lower()
60
+
61
+
62
+ # def collapse_whitespace(text):
63
+ # return re.sub(_whitespace_re, " ", text)
64
+
65
+
66
+ # def convert_to_ascii(text):
67
+ # return unidecode(text)
68
+
69
+
70
+ # def basic_cleaners(text):
71
+ # """Basic pipeline that lowercases and collapses whitespace without transliteration."""
72
+ # text = lowercase(text)
73
+ # text = collapse_whitespace(text)
74
+ # return text
75
+
76
+
77
+ # def transliteration_cleaners(text):
78
+ # """Pipeline for non-English text that transliterates to ASCII."""
79
+ # text = convert_to_ascii(text)
80
+ # text = lowercase(text)
81
+ # text = collapse_whitespace(text)
82
+ # return text
83
+
84
+
85
+ # def english_cleaners(text):
86
+ # """Pipeline for English text, including number and abbreviation expansion."""
87
+ # text = convert_to_ascii(text)
88
+ # text = lowercase(text)
89
+ # text = expand_numbers(text)
90
+ # text = expand_abbreviations(text)
91
+ # text = collapse_whitespace(text)
92
+ # return text
utils.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions
3
+ """
4
+ import pickle
5
+ from pathlib import Path
6
+
7
+ import pax
8
+ import toml
9
+ import yaml
10
+
11
+ from tacotron import Tacotron
12
+
13
+
14
def load_tacotron_config(config_file=Path("tacotron.toml")):
    """
    Parse the TOML config file and return its ``[tacotron]`` table.
    """
    parsed = toml.load(config_file)
    return parsed["tacotron"]
19
+
20
+
21
def load_tacotron_ckpt(net: pax.Module, optim: pax.Module, path):
    """
    Restore a Tacotron training checkpoint from a pickle file.

    ``net`` and ``optim`` may each be ``None``; a ``None`` argument is
    returned unchanged, otherwise the result of its ``load_state_dict`` is
    returned. Returns ``(step, net, optim)``.
    """
    # NOTE(review): pickle can run arbitrary code while loading; only open trusted checkpoints.
    with open(path, "rb") as ckpt_file:
        checkpoint = pickle.load(ckpt_file)

    restored_net = net
    if net is not None:
        restored_net = net.load_state_dict(checkpoint["model_state_dict"])
    restored_optim = optim
    if optim is not None:
        restored_optim = optim.load_state_dict(checkpoint["optim_state_dict"])
    return checkpoint["step"], restored_net, restored_optim
32
+
33
+
34
def create_tacotron_model(config):
    """
    Build a freshly initialized Tacotron model from a config mapping.

    Every upper-case config key listed below is forwarded to ``Tacotron``
    as the keyword argument with the same name in lower case.
    """
    option_names = (
        "MEL_DIM",
        "ATTN_BIAS",
        "RR",
        "MAX_RR",
        "MEL_MIN",
        "SIGMOID_NOISE",
        "PAD_TOKEN",
        "PRENET_DIM",
        "ATTN_HIDDEN_DIM",
        "ATTN_RNN_DIM",
        "RNN_DIM",
        "POSTNET_DIM",
        "TEXT_DIM",
    )
    model_kwargs = {name.lower(): config[name] for name in option_names}
    return Tacotron(**model_kwargs)
53
+
54
+
55
def load_wavegru_config(config_file):
    """
    Read the WaveGRU YAML config file and return the parsed content.
    """
    with open(config_file, "r", encoding="utf-8") as stream:
        settings = yaml.safe_load(stream)
    return settings
61
+
62
+
63
def load_wavegru_ckpt(net, optim, ckpt_file):
    """
    Restore a WaveGRU training checkpoint from a pickle file.

    ``net`` and ``optim`` may each be ``None``; a ``None`` argument is
    returned unchanged, otherwise the result of its ``load_state_dict`` is
    returned. Returns ``(step, net, optim)``.
    """
    # NOTE(review): pickle can run arbitrary code while loading; only open trusted checkpoints.
    with open(ckpt_file, "rb") as stream:
        checkpoint = pickle.load(stream)

    restored_net = net
    if net is not None:
        restored_net = net.load_state_dict(checkpoint["net_state_dict"])
    restored_optim = optim
    if optim is not None:
        restored_optim = optim.load_state_dict(checkpoint["optim_state_dict"])
    return checkpoint["step"], restored_net, restored_optim
wavegru.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WaveGRU model: melspectrogram => mu-law encoded waveform
3
+ """
4
+
5
+ from typing import Tuple
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import pax
10
+ from pax import GRUState
11
+ from tqdm.cli import tqdm
12
+
13
+
14
class ReLU(pax.Module):
    """
    Parameter-free module that applies ``jax.nn.relu`` to its input.
    """

    def __call__(self, x):
        activated = jax.nn.relu(x)
        return activated
17
+
18
+
19
def dilated_residual_conv_block(dim, kernel, stride, dilation):
    """
    Build a dilated conv stack that enlarges the receptive field.

    Layers, in order: dilated Conv1D, LayerNorm, ReLU, then a kernel-1
    Conv1D, LayerNorm, ReLU. All convolutions use "VALID" padding and no
    bias. The residual addition itself is done by the caller.
    """
    # Layers are instantiated in the same order in which they are applied.
    layers = [
        pax.Conv1D(dim, dim, kernel, stride, dilation, "VALID", with_bias=False),
        pax.LayerNorm(dim, -1, True, True),
        ReLU(),
        pax.Conv1D(dim, dim, 1, 1, 1, "VALID", with_bias=False),
        pax.LayerNorm(dim, -1, True, True),
        ReLU(),
    ]
    return pax.Sequential(*layers)
31
+
32
+
33
def tile_1d(x, factor):
    """
    Repeat every time step of an (N, L, D) tensor ``factor`` times in a row.

    The result has shape (N, L * factor, D).
    """
    N, L, D = x.shape  # unpacking also rejects inputs that are not 3-D
    return jnp.repeat(x, factor, axis=1)
42
+
43
+
44
def up_block(in_dim, out_dim, factor, relu=True):
    """
    Tile >> Conv >> LayerNorm >> ReLU (the ReLU only when ``relu`` is true).

    The convolution has kernel size ``2 * factor``, stride 1, "VALID"
    padding and no bias.
    """
    conv = pax.Conv1D(
        in_dim, out_dim, 2 * factor, stride=1, padding="VALID", with_bias=False
    )
    norm = pax.LayerNorm(out_dim, -1, True, True)
    f = pax.Sequential(lambda x: tile_1d(x, factor), conv, norm)
    if not relu:
        return f
    f >>= ReLU()
    return f
58
+
59
+
60
class Upsample(pax.Module):
    """
    Upsample a melspectrogram towards the raw audio sample rate.

    The network is an input conv, five dilated residual conv blocks, one
    ``up_block`` per entry of ``upsample_factors[:-1]``, and a final plain
    tiling by ``upsample_factors[-1]``.
    """

    def __init__(
        self, input_dim, hidden_dim, rnn_dim, upsample_factors, has_linear_output=False
    ):
        super().__init__()
        self.input_conv = pax.Sequential(
            pax.Conv1D(input_dim, hidden_dim, 1, with_bias=False),
            pax.LayerNorm(hidden_dim, -1, True, True),
        )
        self.upsample_factors = upsample_factors
        # Dilations 1, 2, 4, 8, 16.
        self.dilated_convs = [
            dilated_residual_conv_block(hidden_dim, 3, 1, dilation)
            for dilation in (1, 2, 4, 8, 16)
        ]
        self.up_factors = upsample_factors[:-1]
        self.up_blocks = [
            up_block(hidden_dim, hidden_dim, factor) for factor in self.up_factors[:-1]
        ]
        # The last up block has no ReLU; without a linear output layer it
        # emits the 3 * rnn_dim features directly.
        last_out_dim = hidden_dim if has_linear_output else 3 * rnn_dim
        self.up_blocks.append(
            up_block(hidden_dim, last_out_dim, self.up_factors[-1], relu=False)
        )
        if has_linear_output:
            self.x2zrh_fc = pax.Linear(hidden_dim, rnn_dim * 3)
        self.has_linear_output = has_linear_output

        self.final_tile = upsample_factors[-1]

    def __call__(self, x, no_repeat=False):
        """
        Upsample ``x``; with ``no_repeat`` the final tiling by the last factor is skipped.
        """
        x = self.input_conv(x)
        for block in self.dilated_convs:
            y = block(x)
            # "VALID" convs shorten the sequence; crop the skip path symmetrically to match.
            trim = (x.shape[1] - y.shape[1]) // 2
            x = x[:, trim:-trim, :] + y

        for up in self.up_blocks:
            x = up(x)

        if self.has_linear_output:
            x = self.x2zrh_fc(x)

        return x if no_repeat else tile_1d(x, self.final_tile)
112
+
113
+
114
class GRU(pax.Module):
    """
    A customized GRU cell.

    Only the hidden-to-gates projection lives in this module; the input
    ``x`` passed to ``__call__`` is used directly as the input-side
    contribution and is split at ``2 * hidden_dim`` into the update/reset
    part and the candidate part.
    """

    input_dim: int
    hidden_dim: int

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        hidden_init = jax.nn.initializers.variance_scaling(
            1, "fan_out", "truncated_normal"
        )
        self.h_zrh_fc = pax.Linear(hidden_dim, hidden_dim * 3, w_init=hidden_init)

    def initial_state(self, batch_size: int) -> GRUState:
        """Return a float32 all-zeros state of shape (batch_size, hidden_dim)."""
        zeros = jnp.zeros((batch_size, self.hidden_dim), dtype=jnp.float32)
        return GRUState(zeros)

    def __call__(self, state: GRUState, x) -> Tuple[GRUState, jnp.ndarray]:
        """Advance the cell by one step; returns (new state, new hidden)."""
        h_prev = state.hidden
        h_proj = self.h_zrh_fc(h_prev)
        split_at = [2 * self.hidden_dim]
        x_gates, x_cand = jnp.split(x, split_at, axis=-1)
        h_gates, h_cand = jnp.split(h_proj, split_at, axis=-1)

        # Update gate z and reset gate r.
        gates = jax.nn.sigmoid(x_gates + h_gates)
        z, r = jnp.split(gates, 2, axis=-1)

        # The reset gate scales only the hidden-side candidate term.
        candidate = jnp.tanh(x_cand + r * h_cand)

        h_new = (1 - z) * h_prev + z * candidate
        return GRUState(h_new), h_new
153
+
154
+
155
class Pruner(pax.Module):
    """
    Base class for pruners: sparsity schedule plus 4x4 block magnitude pruning.
    """

    def compute_sparsity(self, step):
        """
        Target sparsity at ``step``.

        Follows a cubic schedule that is 0 up to step 1,000 and saturates
        at 0.95 from step 201,000 on.
        """
        progress = (step * 1.0 - 1_000) / 200_000
        remaining = jnp.power(1 - progress, 3)
        return 0.95 * jnp.clip(1.0 - remaining, a_min=0, a_max=1)

    def prune(self, step, weights):
        """
        Return a boolean keep-mask with the shape of ``weights``.

        ``weights`` is a 2-D array handled in 4x4 blocks; a block is kept
        when the sum of its absolute values reaches the quantile given by
        the current target sparsity.
        """
        sparsity = self.compute_sparsity(step)
        H, W = weights.shape
        blocks = jnp.abs(weights.reshape(H // 4, 4, W // 4, 4))
        block_mass = jnp.sum(blocks, axis=(1, 3), keepdims=True)
        threshold = jnp.quantile(jnp.reshape(block_mass, (-1,)), sparsity)
        keep = block_mass >= threshold
        # Expand the per-block decision back to one entry per weight.
        return jnp.reshape(jnp.tile(keep, (1, 4, 1, 4)), (H, W))
180
+
181
+
182
class GRUPruner(Pruner):
    """
    Pruner for the hidden-to-gates weight matrix ``h_zrh_fc`` of a GRU.
    """

    def __init__(self, gru):
        super().__init__()
        # Start with an all-True mask: nothing is pruned yet.
        self.h_zrh_fc_mask = jnp.ones_like(gru.h_zrh_fc.weight) == 1

    def __call__(self, gru: pax.GRU):
        """
        Zero the masked-out weights; meant to run after an optimization step.
        """
        weight = gru.h_zrh_fc.weight
        masked_weight = jnp.where(self.h_zrh_fc_mask, weight, 0)
        return gru.replace_node(gru.h_zrh_fc.weight, masked_weight)

    def update_mask(self, step, gru: pax.GRU):
        """
        Recompute the masks of the z, r and h sub-matrices separately and
        intersect them with the stored mask.
        """
        sub_weights = jnp.split(gru.h_zrh_fc.weight, 3, axis=1)
        sub_masks = tuple(self.prune(step, part) for part in sub_weights)
        self.h_zrh_fc_mask *= jnp.concatenate(sub_masks, axis=1)
204
+
205
+
206
class LinearPruner(Pruner):
    """
    Pruner for the weight matrix of a linear layer.
    """

    def __init__(self, linear):
        super().__init__()
        # Start with an all-True mask: nothing is pruned yet.
        self.mask = jnp.ones_like(linear.weight) == 1

    def __call__(self, linear: pax.Linear):
        """
        Zero the masked-out weights; meant to run after an optimization step.
        """
        masked_weight = jnp.where(self.mask, linear.weight, 0)
        return linear.replace(weight=masked_weight)

    def update_mask(self, step, linear: pax.Linear):
        """
        Intersect the stored mask with a freshly computed one.
        """
        fresh_mask = self.prune(step, linear.weight)
        self.mask *= fresh_mask
222
+
223
+
224
class WaveGRU(pax.Module):
    """
    WaveGRU vocoder model.

    An embedding of the previous 8-bit sample is added to the upsampled mel
    features, fed to a GRU, and mapped by two linear layers to 256 logits.
    """

    def __init__(
        self,
        mel_dim=80,
        rnn_dim=1024,
        upsample_factors=(5, 3, 20),
        has_linear_output=False,
    ):
        super().__init__()
        self.embed = pax.Embed(256, 3 * rnn_dim)
        self.upsample = Upsample(
            input_dim=mel_dim,
            hidden_dim=512,
            rnn_dim=rnn_dim,
            upsample_factors=upsample_factors,
            has_linear_output=has_linear_output,
        )
        self.rnn = GRU(rnn_dim)
        self.o1 = pax.Linear(rnn_dim, rnn_dim)
        self.o2 = pax.Linear(rnn_dim, 256)
        self.gru_pruner = GRUPruner(self.rnn)
        self.o1_pruner = LinearPruner(self.o1)
        self.o2_pruner = LinearPruner(self.o2)

    def output(self, x):
        """Map GRU outputs to 256 logits: Linear >> ReLU >> Linear."""
        x = self.o1(x)
        x = jax.nn.relu(x)
        x = self.o2(x)
        return x

    def inference(self, mel, no_gru=False, seed=42):
        """
        Generate a waveform from a melspectrogram.

        With ``no_gru`` the upsampled features (without the final tiling)
        are returned instead of samples. Otherwise samples are drawn one at
        a time, starting from the value 127 and a zero GRU state of batch
        size 1, and returned concatenated along axis 0.
        """

        @jax.jit
        def step(rnn_state, mel, rng_key, x):
            x = self.embed(x)
            x = x + mel
            rnn_state, x = self.rnn(rnn_state, x)
            x = self.output(x)
            rng_key, next_rng_key = jax.random.split(rng_key, 2)
            x = jax.random.categorical(rng_key, x, axis=-1)
            return rnn_state, next_rng_key, x

        y = self.upsample(mel, no_repeat=no_gru)
        if no_gru:
            return y
        x = jnp.array([127], dtype=jnp.int32)
        rnn_state = self.rnn.initial_state(1)
        output = []
        rng_key = jax.random.PRNGKey(seed)
        for i in tqdm(range(y.shape[1])):
            rnn_state, rng_key, x = step(rnn_state, y[:, i], rng_key, x)
            output.append(x)
        x = jnp.concatenate(output, axis=0)
        return x

    def __call__(self, mel, x):
        """
        Compute logits for a sample sequence ``x`` conditioned on ``mel``.

        The embedded samples are cropped to the length of the upsampled
        features before the two are added.
        """
        x = self.embed(x)
        y = self.upsample(mel)
        pad_left = (x.shape[1] - y.shape[1]) // 2
        pad_right = x.shape[1] - y.shape[1] - pad_left
        # Use an explicit end index: `-pad_right` would select nothing when
        # pad_right == 0 (lengths already equal), because `-0` is just `0`.
        x = x[:, pad_left : x.shape[1] - pad_right]
        x = x + y
        _, x = pax.scan(
            self.rnn,
            self.rnn.initial_state(x.shape[0]),
            x,
            time_major=False,
        )
        x = self.output(x)
        return x
wavegru.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## dsp
2
+ sample_rate : 24000
3
+ window_length: 50.0 # ms
4
+ hop_length: 12.5 # ms
5
+ mel_min: 1.0e-5 ## need .0 to make it a float
6
+ mel_dim: 80
7
+ n_fft: 2048
8
+
9
+ ## wavegru
10
+ embed_dim: 32
11
+ rnn_dim: 1024
12
+ frames_per_sequence: 67
13
+ num_pad_frames: 62
14
+ upsample_factors: [5, 3, 20]
wavegru_cpp.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from wavegru_mod import WaveGRU
3
+
4
+
5
def extract_weight_mask(net):
    """
    Collect the embedding, the GRU and output-layer weights and biases, and
    the pruning masks of ``net`` into a flat dictionary.
    """
    gru_fc = net.rnn.h_zrh_fc
    return {
        "embed_weight": net.embed.weight,
        "gru_h_zrh_weight": gru_fc.weight,
        "gru_h_zrh_mask": net.gru_pruner.h_zrh_fc_mask,
        "gru_h_zrh_bias": gru_fc.bias,
        "o1_weight": net.o1.weight,
        "o1_mask": net.o1_pruner.mask,
        "o1_bias": net.o1.bias,
        "o2_weight": net.o2.weight,
        "o2_mask": net.o2_pruner.mask,
        "o2_bias": net.o2.bias,
    }
19
+
20
+
21
def load_wavegru_cpp(data, repeat_factor):
    """
    Create the C++ WaveGRU object and load the extracted weights into it.

    ``data`` is a dictionary with the keys produced by
    ``extract_weight_mask``. Weight and mask matrices are handed over
    transposed and C-contiguous; biases are passed as they are.
    """

    def transposed(key):
        return np.ascontiguousarray(data[key].T)

    embed = data["embed_weight"]
    # The GRU bias holds the z, r and h parts back to back.
    rnn_dim = data["gru_h_zrh_bias"].shape[0] // 3
    net = WaveGRU(rnn_dim, repeat_factor)
    net.load_embed(embed)

    gru_args = (
        transposed("gru_h_zrh_weight"),
        transposed("gru_h_zrh_mask"),
        data["gru_h_zrh_bias"],
    )
    o1_args = (transposed("o1_weight"), transposed("o1_mask"), data["o1_bias"])
    o2_args = (transposed("o2_weight"), transposed("o2_mask"), data["o2_bias"])

    net.load_weights(*gru_args, *o1_args, *o2_args)

    return net
wavegru_mod.cc ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ WaveGRU:
3
+ > Embed > GRU > O1 > O2 > Sampling > ...
4
+ */
5
+
6
+ #include <pybind11/numpy.h>
7
+ #include <pybind11/pybind11.h>
8
+ #include <pybind11/stl.h>
9
+
10
+ #include <iostream>
11
+ #include <random>
12
+ #include <vector>
13
+
14
+ #include "sparse_matmul/sparse_matmul.h"
15
+ namespace py = pybind11;
16
+ using namespace std;
17
+
18
+ using fvec = std::vector<float>;
19
+ using ivec = std::vector<int>;
20
+ using fndarray = py::array_t<float>;
21
+ using indarray = py::array_t<int>;
22
+ using mat = csrblocksparse::CsrBlockSparseMatrix<float, float, int16_t>;
23
+ using vec = csrblocksparse::CacheAlignedVector<float>;
24
+ using masked_mat = csrblocksparse::MaskedSparseMatrix<float>;
25
+
26
+ mat create_mat(int h, int w) {
27
+ auto m = masked_mat(w, h, 0.90, 4, 4, 0.0, true);
28
+ auto a = mat(m);
29
+ return a;
30
+ }
31
+
32
+ struct WaveGRU {
33
+ int hidden_dim;
34
+ int repeat_factor;
35
+ mat m;
36
+ vec b;
37
+ vec z, r, hh, zrh;
38
+ vec fco1, fco2;
39
+ vec o1b, o2b;
40
+ vec t;
41
+ vec h;
42
+ vec logits;
43
+ mat o1, o2;
44
+ std::vector<vec> embed;
45
+
46
+ WaveGRU(int hidden_dim, int repeat_factor)
47
+ : hidden_dim(hidden_dim),
48
+ repeat_factor(repeat_factor),
49
+ b(3*hidden_dim),
50
+ t(3*hidden_dim),
51
+ zrh(3*hidden_dim),
52
+ z(hidden_dim),
53
+ r(hidden_dim),
54
+ hh(hidden_dim),
55
+ fco1(hidden_dim),
56
+ fco2(256),
57
+ h(hidden_dim),
58
+ o1b(hidden_dim),
59
+ o2b(256),
60
+ logits(256) {
61
+ m = create_mat(hidden_dim, 3*hidden_dim);
62
+ o1 = create_mat(hidden_dim, hidden_dim);
63
+ o2 = create_mat(hidden_dim, 256);
64
+ embed = std::vector<vec>();
65
+ for (int i = 0; i < 256; i++) {
66
+ embed.emplace_back(hidden_dim * 3);
67
+ embed[i].FillRandom();
68
+ }
69
+ }
70
+
71
+ void load_embed(fndarray embed_weights) {
72
+ auto a_embed = embed_weights.unchecked<2>();
73
+ for (int i = 0; i < 256; i++) {
74
+ for (int j = 0; j < hidden_dim * 3; j++) embed[i][j] = a_embed(i, j);
75
+ }
76
+ }
77
+
78
+ mat load_linear(vec& bias, fndarray w, indarray mask, fndarray b) {
79
+ auto w_ptr = static_cast<float*>(w.request().ptr);
80
+ auto mask_ptr = static_cast<int*>(mask.request().ptr);
81
+ auto rb = b.unchecked<1>();
82
+ // load bias, scale by 1/4
83
+ for (int i = 0; i < rb.shape(0); i++) bias[i] = rb(i) / 4;
84
+ // load weights
85
+ masked_mat mm(w.shape(0), w.shape(1), mask_ptr, w_ptr);
86
+ mat mmm(mm);
87
+ return mmm;
88
+ }
89
+
90
+ void load_weights(fndarray m, indarray m_mask, fndarray b,
91
+ fndarray o1, indarray o1_mask,
92
+ fndarray o1b, fndarray o2,
93
+ indarray o2_mask, fndarray o2b) {
94
+ this->m = load_linear(this->b, m, m_mask, b);
95
+ this->o1 = load_linear(this->o1b, o1, o1_mask, o1b);
96
+ this->o2 = load_linear(this->o2b, o2, o2_mask, o2b);
97
+ }
98
+
99
+ std::vector<int> inference(fndarray ft, float temperature) {
100
+ auto rft = ft.unchecked<2>();
101
+ int value = 127;
102
+ std::vector<int> signal(rft.shape(0) * repeat_factor);
103
+ h.FillZero();
104
+ for (int index = 0; index < signal.size(); index++) {
105
+ m.SpMM_bias(h, b, &zrh, false);
106
+
107
+ for (int i = 0; i < 3 * hidden_dim; i++) t[i] = embed[value][i] + rft(index / repeat_factor, i);
108
+ for (int i = 0; i < hidden_dim; i++) {
109
+ z[i] = zrh[i] + t[i];
110
+ r[i] = zrh[hidden_dim + i] + t[hidden_dim + i];
111
+ }
112
+
113
+ z.Sigmoid();
114
+ r.Sigmoid();
115
+
116
+ for (int i = 0; i < hidden_dim; i++) {
117
+ hh[i] = zrh[hidden_dim * 2 + i] * r[i] + t[hidden_dim * 2 + i];
118
+ }
119
+ hh.Tanh();
120
+ for (int i = 0; i < hidden_dim; i++) {
121
+ h[i] = (1. - z[i]) * h[i] + z[i] * hh[i];
122
+ }
123
+ o1.SpMM_bias(h, o1b, &fco1, true);
124
+ o2.SpMM_bias(fco1, o2b, &fco2, false);
125
+ // auto max_logit = fco2[0];
126
+ // for (int i = 1; i <= 255; ++i) {
127
+ // max_logit = max(max_logit, fco2[i]);
128
+ // }
129
+ // float total = 0.0;
130
+ // for (int i = 0; i <= 255; ++i) {
131
+ // logits[i] = csrblocksparse::fast_exp(fco2[i] - max_logit);
132
+ // total += logits[i];
133
+ // }
134
+ // for (int i = 0; i <= 255; ++i) {
135
+ // if (logits[i] < total / 1024.0) fco2[i] = -1e9;
136
+ // }
137
+ value = fco2.Sample(temperature);
138
+ signal[index] = value;
139
+ }
140
+ return signal;
141
+ }
142
+ };
143
+
144
+ PYBIND11_MODULE(wavegru_mod, m) {
145
+ py::class_<WaveGRU>(m, "WaveGRU")
146
+ .def(py::init<int, int>())
147
+ .def("load_embed", &WaveGRU::load_embed)
148
+ .def("load_weights", &WaveGRU::load_weights)
149
+ .def("inference", &WaveGRU::inference);
150
+ }